/* Copyright 2013 Akira Ohta (akohta001@gmail.com)
    This file is part of ntch.

    The ntch is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    The ntch is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with ntch.  If not, see <http://www.gnu.org/licenses/>.
    
*/
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/hmac.h> 
#include <openssl/sha.h>
#include <openssl/aes.h>

#include "utils/nt_std_t.h"
#include "utils/base64.h"
#include "utils/crypt.h"

/* Magic value stored in the public handle's chk_sum field by nt_crypt_alloc().
 * Every public entry point asserts it (debug builds only) to catch pointers
 * that were never produced by nt_crypt_alloc() or have been corrupted. */
#define NT_CRYPT_CHK_SUM (1578428)

/* Private object behind an nt_crypt_handle.
 * `handle` MUST remain the first member: nt_crypt_alloc() hands out
 * &cryptp->handle, and the public functions convert that pointer back to
 * nt_crypt_tp with a plain cast. */
typedef struct tag_nt_crypt_t *nt_crypt_tp;
typedef struct tag_nt_crypt_t{
	nt_crypt_handle_t handle;
	int ref_count;            /* outstanding references; object is freed when it drops to 0 */
	unsigned char key[256/8]; /* AES-256 key (first 32 bytes of the PBKDF2 output) */
	unsigned char iv[16];     /* CBC IV (next 16 PBKDF2 bytes); the same IV is used for every message */
	
}nt_crypt_t;


/* Handle created by nt_crypt_lib_init(). It is NULL/0 when no password was
 * supplied or initialisation failed; nt_crypt_get_handle() returns NULL then. */
static nt_crypt_handle g_crypt_handle;

/* Forward declarations of the file-local helpers defined below. */
static BOOL validate_rfc2898param(const char *salt,
		int iteration, const char **error_msg);
static nt_crypt_handle nt_crypt_alloc(const char *rfc2898_salt,
		int rfc2898_iteration, const char *aes256pass,
		const char **error_msg);
static void pkcs5_padding(unsigned char *data, size_t num_pad);
static int pkcs5_length(unsigned char *input, size_t in_len);

/* Create the process-wide crypt handle from a password, salt and PBKDF2
 * iteration count.
 *
 * Returns TRUE when a handle was created. Returns FALSE when the password is
 * NULL/empty (error_msg is left untouched) or when nt_crypt_alloc() fails
 * (error_msg may then be set by nt_crypt_alloc()).
 * NOTE(review): an already-initialised g_crypt_handle is overwritten without
 * being released. */
BOOL nt_crypt_lib_init(const char *rfc2898_salt,
		int rfc2898_iteration, const char *aes256pass,
		const char **error_msg)
{
	/* No password: no global handle, so nt_crypt_get_handle() yields NULL. */
	if(aes256pass == NULL || aes256pass[0] == '\0'){
		g_crypt_handle = NULL;
		return FALSE;
	}

	g_crypt_handle = nt_crypt_alloc(rfc2898_salt, rfc2898_iteration,
			aes256pass, error_msg);

	return (g_crypt_handle != NULL) ? TRUE : FALSE;
}

/* Drop the library's own reference to the global crypt handle.
 * Handles previously obtained through nt_crypt_get_handle() remain valid until
 * their holders release them. The global pointer is cleared, so a later
 * nt_crypt_get_handle() returns NULL and a repeated call here is a no-op,
 * instead of touching an object that may already have been freed. */
void nt_crypt_lib_finish()
{
	if(g_crypt_handle){
		nt_crypt_release_ref(g_crypt_handle);
		g_crypt_handle = NULL;
	}
}

/* Return the global crypt handle with its reference count incremented, or
 * NULL when nt_crypt_lib_init() did not create one. The caller owns the new
 * reference and must balance it with nt_crypt_release_ref(). */
nt_crypt_handle nt_crypt_get_handle()
{
	nt_crypt_handle shared = g_crypt_handle;

	if(shared != NULL)
		nt_crypt_add_ref(shared);
	return shared;
}

/* Encrypt in_buf with AES-256-CBC using the handle's key and IV.
 *
 * The plaintext is padded by pkcs5_padding(): 1..AES_BLOCK_SIZE pad bytes are
 * always added, and a full block of padding is encoded as 0x00 rather than 0x10.
 *
 * handle      - valid handle from nt_crypt_alloc()/nt_crypt_get_handle().
 * in_buf      - plaintext; may be NULL only when in_buf_len is 0.
 * out_buf     - receives the ciphertext.
 * out_buf_len - must be at least (in_buf_len / AES_BLOCK_SIZE + 1) * AES_BLOCK_SIZE.
 *
 * Returns the ciphertext length, or -1 when out_buf is too small, the result
 * would not fit in an int, or memory allocation fails.
 * NOTE(review): the IV is the same for every call on a handle (it is derived
 * once in nt_crypt_alloc()), so equal plaintexts give equal ciphertexts. */
int nt_crypt_encrypt(nt_crypt_handle handle,
		const unsigned char *in_buf, size_t in_buf_len,
		unsigned char *out_buf, size_t out_buf_len)
{
	nt_crypt_tp cryptp;
	unsigned char iv[AES_BLOCK_SIZE];
	size_t length, num_pad, i;
	unsigned char *data;
	volatile unsigned char *wipe;
	AES_KEY aes_key;
	
	assert(handle);
	assert(handle->chk_sum == NT_CRYPT_CHK_SUM);
	cryptp = (nt_crypt_tp)handle;
	assert(cryptp->ref_count > 0);
	
	/* Padded length is at most in_buf_len + AES_BLOCK_SIZE. Reject sizes where
	 * that would overflow size_t or not be representable in the int result. */
	if(in_buf_len > (size_t)INT_MAX - AES_BLOCK_SIZE)
		return -1;
	
	length = (in_buf_len / AES_BLOCK_SIZE + 1) * AES_BLOCK_SIZE;
	if(length > out_buf_len)
		return -1;
	
	num_pad = length - in_buf_len;
	data = malloc(length);
	if(!data)
		return -1;
	
	/* memcpy with a NULL source is undefined even for 0 bytes. */
	if(in_buf_len > 0)
		memcpy(data, in_buf, in_buf_len);
	pkcs5_padding(data + in_buf_len, num_pad);
	
	/* CBC updates the IV in place; work on a copy so the handle's IV is kept. */
	memcpy(iv, cryptp->iv, sizeof(iv));
	AES_set_encrypt_key(cryptp->key, 256, &aes_key);
	AES_cbc_encrypt(data, out_buf, length, &aes_key, iv, AES_ENCRYPT);
	
	/* Scrub the plaintext copy and the expanded key. Writes through a volatile
	 * lvalue cannot be optimised away, unlike a plain memset before free(). */
	wipe = data;
	for(i = 0; i < length; i++)
		wipe[i] = 0;
	wipe = (volatile unsigned char *)&aes_key;
	for(i = 0; i < sizeof(aes_key); i++)
		wipe[i] = 0;
	free(data);
	
	return (int)length;
}

/* Decrypt AES-256-CBC ciphertext produced by nt_crypt_encrypt() with the
 * same handle.
 *
 * in_buf_len must be a non-zero multiple of AES_BLOCK_SIZE, no larger than
 * out_buf_len and INT_MAX. out_buf receives all in_buf_len decrypted bytes,
 * including the padding.
 *
 * Returns the plaintext length (padding excluded) as computed by
 * pkcs5_length(), or -1 on invalid lengths or invalid padding. */
int nt_crypt_decrypt(nt_crypt_handle handle,
		const unsigned char *in_buf, size_t in_buf_len,
		unsigned char *out_buf, size_t out_buf_len)
{
	nt_crypt_tp cryptp;
	unsigned char iv[AES_BLOCK_SIZE];
	AES_KEY aes_key;
	volatile unsigned char *wipe;
	size_t i;
	
	assert(handle);
	assert(handle->chk_sum == NT_CRYPT_CHK_SUM);
	cryptp = (nt_crypt_tp)handle;
	assert(cryptp->ref_count > 0);
	
	/* An empty input would make the padding check read out_buf[-1]. */
	if(in_buf_len == 0 || in_buf_len % AES_BLOCK_SIZE)
		return -1;
	if(in_buf_len > out_buf_len)
		return -1;
	/* Keep the returned length representable as an int. */
	if(in_buf_len > (size_t)INT_MAX)
		return -1;
	
	/* CBC updates the IV in place; work on a copy so the handle's IV is kept. */
	memcpy(iv, cryptp->iv, sizeof(iv));
	AES_set_decrypt_key(cryptp->key, 256, &aes_key);
	AES_cbc_encrypt(in_buf, out_buf, in_buf_len, &aes_key, iv, AES_DECRYPT);
	
	/* Scrub the expanded key; volatile stores are not elided by the optimiser. */
	wipe = (volatile unsigned char *)&aes_key;
	for(i = 0; i < sizeof(aes_key); i++)
		wipe[i] = 0;
	
	return pkcs5_length(out_buf, in_buf_len);
}


static nt_crypt_handle nt_crypt_alloc(const char *rfc2898_salt,
		int rfc2898_iteration, const char *aes256pass,
		const char **error_msg)
{
	nt_crypt_tp cryptp;
	unsigned char result[256];
	
	cryptp = malloc(sizeof(nt_crypt_t));
	if(!cryptp)
		return FALSE;
	
	if(!validate_rfc2898param(
		rfc2898_salt, rfc2898_iteration, error_msg)){
		free(cryptp);
		return FALSE;
	}
	if(!PKCS5_PBKDF2_HMAC_SHA1(aes256pass, strlen(aes256pass),
			(const unsigned char*)rfc2898_salt, strlen(rfc2898_salt),
			rfc2898_iteration, 32+16, result)){
		*error_msg = "PKCS5_PBKDF2_HMAC_SHA1 呼び出しに失敗しました";
		free(cryptp);
		return FALSE;
	}
	
	memcpy(cryptp->key, result, 32);
	memcpy(cryptp->iv, result+32, 16);
	cryptp->handle.chk_sum = NT_CRYPT_CHK_SUM;
	cryptp->ref_count = 1;
	return &cryptp->handle;
}

/* Take an additional reference on a live crypt handle.
 * Returns the reference count after the increment. The debug-build asserts
 * reject NULL, foreign (bad chk_sum) and already-released handles. */
int nt_crypt_add_ref(nt_crypt_handle handle)
{
	nt_crypt_tp obj;

	assert(handle);
	assert(handle->chk_sum == NT_CRYPT_CHK_SUM);

	/* `handle` is the first member of nt_crypt_t, so the cast is exact. */
	obj = (nt_crypt_tp)handle;
	assert(obj->ref_count > 0);

	obj->ref_count += 1;
	return obj->ref_count;
}

/* Drop one reference on a crypt handle.
 * Returns the remaining reference count. When it reaches 0, the object
 * (including its AES key and IV) is wiped and freed, and 0 is returned; the
 * handle must not be used afterwards. */
int nt_crypt_release_ref(nt_crypt_handle handle)
{
	nt_crypt_tp cryptp;
	volatile unsigned char *wipe;
	size_t i;
	
	assert(handle);
	assert(handle->chk_sum == NT_CRYPT_CHK_SUM);
	cryptp = (nt_crypt_tp)handle;
	assert(cryptp->ref_count > 0);
	if(0 != --cryptp->ref_count)
		return cryptp->ref_count;
	
	/* Last reference: zero the whole object so key material does not linger
	 * in freed heap memory. Clearing chk_sum also makes the debug asserts more
	 * likely to catch a stale handle. Volatile stores cannot be elided before
	 * free(). */
	wipe = (volatile unsigned char *)cryptp;
	for(i = 0; i < sizeof(*cryptp); i++)
		wipe[i] = 0;
	free(cryptp);
	return 0;
}

/* Check the PBKDF2 (RFC 2898) parameters before key derivation.
 * The salt must be a non-NULL string of at least 8 bytes (by strlen), and
 * the iteration count must be at least 1000.
 * Returns TRUE when both hold. Otherwise stores a message in *error_msg
 * (error_msg must be non-NULL) and returns FALSE. */
static BOOL validate_rfc2898param(
		const char *salt, int iteration,
		const char **error_msg)
{
	const char *problem = NULL;

	if(salt == NULL || strlen(salt) < 8)
		problem = "SALT値が不正です";
	else if(iteration < 1000)
		problem = "PBKDF2 繰り返し回数は1000以上を指定して下さい";

	if(problem != NULL){
		*error_msg = problem;
		return FALSE;
	}
	return TRUE;
}

/* Write num_pad (1..AES_BLOCK_SIZE) padding bytes starting at data.
 * Every pad byte holds num_pad, except that a full block of padding is
 * written as 0x00 bytes instead of AES_BLOCK_SIZE. pkcs5_length() decodes
 * this same convention. */
static void pkcs5_padding(unsigned char *data, size_t num_pad)
{
	unsigned char fill;

	assert(num_pad > 0 && num_pad <= AES_BLOCK_SIZE);
	fill = (num_pad == AES_BLOCK_SIZE) ? 0 : (unsigned char)num_pad;
	memset(data, fill, num_pad);
}

/* Return the unpadded length of a decrypted buffer of in_len bytes, or -1 if
 * the buffer or its padding is invalid.
 *
 * in_len must be a non-zero multiple of the 16-byte AES block. The last byte
 * gives the pad count: 1..15 means that many pad bytes, and 0 means a full
 * 16-byte block of 0x00 padding, matching pkcs5_padding(). Every pad byte must
 * equal the last byte. */
static int pkcs5_length(unsigned char *input, size_t in_len)
{
	size_t pad_num, i;
	unsigned char pad_byte;

	/* in_len == 0 would make input[in_len-1] an out-of-bounds read. */
	if(in_len == 0 || in_len % 16)
		return -1;

	pad_byte = input[in_len-1];
	if(pad_byte > 15)
		return -1;
	pad_num = (pad_byte == 0) ? 16 : pad_byte;

	/* Reject corrupted data whose padding run is not uniform. */
	for(i = in_len - pad_num; i < in_len; i++){
		if(input[i] != pad_byte)
			return -1;
	}
	return (int)(in_len - pad_num);
}

BOOL nt_make_sha1_string(const char *org_name, 
			char *outbuf, size_t outbuf_len)
{
	assert(outbuf != NULL);
	assert(outbuf_len > SHA_DIGEST_LENGTH*2+1);
	assert(org_name != NULL);

	SHA_CTX ctx;
	unsigned char md[SHA_DIGEST_LENGTH]; 
	char *outptr;
	int i , len;
	
	SHA1_Init(&ctx);

	len =  strlen(org_name);
	
	SHA1_Update(&ctx, org_name, len+1);

	SHA1_Final(&(md[0]), &ctx);
 
 	outptr = outbuf;
	for(i = 0; i < SHA_DIGEST_LENGTH; i++){
		sprintf(outptr, "%02x", md[i]);
		outptr += 2;
	}
	while(outptr < (outbuf+outbuf_len)){
		*outptr = '\0';
		outptr++;
	}
	return TRUE;
}
