/**********************************************************************
 
	Copyright (C) 2003-2005
	Hirohisa MORI <joshua@nichibun.ac.jp>
	Tomohito Nakajima <nakajima@zeta.co.jp>
 
	This program is free software; you can redistribute it 
	and/or modify it under the terms of the GLOBALBASE 
	Library General Public License (G-LGPL) as published by 

	http://www.globalbase.org/
 
	This program is distributed in the hope that it will be 
	useful, but WITHOUT ANY WARRANTY; without even the 
	implied warranty of MERCHANTABILITY or FITNESS FOR A 
	PARTICULAR PURPOSE.

**********************************************************************/

#include <string.h>
#include <stdio.h>
#include <stdlib.h>

#include "openssl/rsa.h"

#include "ossl.h"


#ifdef GBSSL_TEST
#define d_alloc malloc
#define d_f_ree free
#define gbrsa_test_main main
#else
#include "memory_debug.h"
#endif

#define RSA_PKCS1_PADDING_SIZE 11

/*
extern struct rsa_st;
typedef struct rsa_st RSA;

extern RSA *d2i_RSAPrivateKey(RSA **a, const unsigned char **in, long len);
extern int i2d_RSAPrivateKey(const RSA *a, unsigned char **out);

extern RSA *d2i_RSAPublicKey(RSA **a, const unsigned char **in, long len);
extern int i2d_RSAPublicKey(const RSA *a, unsigned char **out);

extern int RSA_private_decrypt(int flen, const unsigned char *from, unsigned char *to,
	     RSA *rsa, int padding);
extern int RSA_public_encrypt(int flen, const unsigned char *from, unsigned char *to,
	     RSA *rsa, int padding);
extern int RSA_size(const RSA *r);
extern void RSA_free(RSA *r);
#define RSA_PKCS1_PADDING	1
*/

/*
 * Decrypt org with the RSA private key wrapped by private_key.
 *
 * org->data is consumed in chunks of RSA_size(rsa) bytes (the last chunk
 * may be shorter). Each chunk is passed to RSA_private_decrypt() with
 * PKCS#1 v1.5 padding, and the recovered plaintexts are concatenated.
 *
 * private_key: key object whose obj.rsa_key.rsa is used. It is neither
 *              modified nor released here.
 * org:         ciphertext, org->len bytes at org->data.
 *
 * Returns the result of oSSL_data_new(OSSL_RAW, ...) wrapping the plaintext
 * buffer. Returns 0 if the key is missing or unusable, the output buffer
 * cannot be allocated, or any chunk fails to decrypt.
 */
oSSL_data *
oSSL_RSA_decode(
oSSL_object * private_key,
oSSL_data * org)
{
	RSA *rsa;
	unsigned char *in_ptr;
	int max_rsa_decode_size;
	int decode_size;
	int decoded_size;
	int left_data_size;
	int block_count;
	int max_decoded_len;
	/* Initialised so that the error path is safe before allocation. */
	unsigned char *total_decoded_data = NULL;
	int total_decoded_len = 0;

	rsa = private_key->obj.rsa_key.rsa;
	if(rsa == NULL){
		goto err;
	}
	max_rsa_decode_size = RSA_size(rsa);
	if(max_rsa_decode_size <= 0){
		/* Guards the divisions below. */
		goto err;
	}

	in_ptr = (unsigned char*)org->data;
	left_data_size = org->len;

	/*
	 * A ciphertext chunk never decrypts to more than RSA_size(rsa) bytes,
	 * so one modulus-sized slot per chunk is always enough. The ceiling
	 * division is written without "len + size - 1" to avoid int overflow.
	 */
	block_count = 0;
	if(left_data_size > 0){
		block_count = (left_data_size - 1) / max_rsa_decode_size + 1;
	}
	max_decoded_len = max_rsa_decode_size * block_count;
	total_decoded_data = (unsigned char *)d_alloc(max_decoded_len);
	if(total_decoded_data == NULL && max_decoded_len > 0){
		goto err;
	}

	while(left_data_size > 0){
		decode_size = (left_data_size < max_rsa_decode_size)
			? left_data_size : max_rsa_decode_size;
		decoded_size = RSA_private_decrypt(decode_size, in_ptr,
			&total_decoded_data[total_decoded_len], rsa, RSA_PKCS1_PADDING);
		/* RSA_private_decrypt() reports failure with a negative value. */
		if(decoded_size < 0){
			goto err;
		}
		total_decoded_len += decoded_size;
		in_ptr += decode_size;
		left_data_size -= decode_size;
	}

	return oSSL_data_new(OSSL_RAW, total_decoded_data, total_decoded_len, TRUE, max_decoded_len);
err:
	/* d_f_ree() is a project wrapper; do not rely on it accepting NULL. */
	if(total_decoded_data != NULL){
		d_f_ree(total_decoded_data);
	}
	/* rsa is owned by private_key and must not be freed here. */
	return 0;
}

/*
 * Encrypt org with the RSA public key wrapped by public_key.
 *
 * org->data is split into chunks of at most
 * RSA_size(rsa) - RSA_PKCS1_PADDING_SIZE bytes, the room left after
 * PKCS#1 v1.5 padding. Each chunk is encrypted with RSA_public_encrypt(),
 * and the RSA_size(rsa)-byte ciphertexts are concatenated.
 *
 * public_key: key object whose obj.rsa_key.rsa is used. It is neither
 *             modified nor released here.
 * org:        plaintext, org->len bytes at org->data.
 *
 * Returns the result of oSSL_data_new(OSSL_RAW, ...) wrapping the
 * ciphertext buffer. Returns 0 if the key is missing or too small, the
 * output buffer cannot be allocated, or any chunk fails to encrypt.
 */
oSSL_data *
oSSL_RSA_encode(oSSL_object *public_key, oSSL_data * org)
{
	RSA *rsa;
	unsigned char *in_ptr;
	unsigned char *out_ptr;
	int rsa_block_size;
	int max_rsa_encode_size;
	int encode_size;
	int encoded_size;
	int left_data_size;
	int block_count;
	int total_encode_len = 0;
	/* Initialised so that the error path is safe before allocation. */
	unsigned char *total_encode_data = NULL;

	rsa = public_key->obj.rsa_key.rsa;
	if(rsa == NULL){
		goto err;
	}
	rsa_block_size = RSA_size(rsa);
	max_rsa_encode_size = rsa_block_size - RSA_PKCS1_PADDING_SIZE;
	if(max_rsa_encode_size <= 0){
		/* The modulus is too small to hold any data; also guards the division. */
		goto err;
	}

	in_ptr = (unsigned char*)org->data;
	left_data_size = org->len;

	/* Ceiling division without "len + size - 1" to avoid int overflow. */
	block_count = 0;
	if(left_data_size > 0){
		block_count = (left_data_size - 1) / max_rsa_encode_size + 1;
	}
	total_encode_len = rsa_block_size * block_count;
	total_encode_data = (unsigned char *)d_alloc(total_encode_len);
	if(total_encode_data == NULL && total_encode_len > 0){
		goto err;
	}

	out_ptr = total_encode_data;
	while(left_data_size > 0){
		encode_size = (left_data_size < max_rsa_encode_size)
			? left_data_size : max_rsa_encode_size;
		encoded_size = RSA_public_encrypt(encode_size, in_ptr, out_ptr, rsa, RSA_PKCS1_PADDING);
		/* RSA_public_encrypt() reports failure with a negative value. */
		if(encoded_size < 0){
			goto err;
		}
		in_ptr += encode_size;
		out_ptr += encoded_size;
		left_data_size -= encode_size;
	}
	return oSSL_data_new(OSSL_RAW, total_encode_data, total_encode_len, TRUE, total_encode_len);
err:
	/* d_f_ree() is a project wrapper; do not rely on it accepting NULL. */
	if(total_encode_data != NULL){
		d_f_ree(total_encode_data);
	}
	return 0;
}

/*
 * Drop one reference held on an RSA key object.
 *
 * The count lives in obj->obj.rsa_key.refcnt and belongs to this object
 * alone. Once it reaches zero, the wrapped OpenSSL key is handed to
 * RSA_free() and the object itself is released with d_f_ree().
 */
void oSSL_RSA_free(oSSL_object *obj){
	if(--obj->obj.rsa_key.refcnt != 0){
		/* Other holders remain; nothing to release yet. */
		return;
	}
	RSA_free(obj->obj.rsa_key.rsa);
	d_f_ree(obj);
}

/*
 * DER-encode the RSA key held by obj, using i2d_RSAPublicKey() or
 * i2d_RSAPrivateKey() according to obj->type.
 *
 * out follows the i2d_* convention: pass NULL to query the encoded length
 * only, or a pointer to a buffer pointer, which is advanced past the
 * written bytes.
 *
 * Returns the i2d_* result (the encoded length, or <= 0 on failure).
 * Returns -1 if obj holds no key or obj->type is not an RSA key type.
 */
static int
oSSL_RSA_i2d_key(oSSL_object *obj, unsigned char **out)
{
	if(obj->obj.rsa_key.rsa == NULL){
		return -1;
	}
	if(obj->type == OSSL_RSA_PUBLICKEY){
		return i2d_RSAPublicKey(obj->obj.rsa_key.rsa, out);
	}
	if(obj->type == OSSL_RSA_PRIVATEKEY){
		return i2d_RSAPrivateKey(obj->obj.rsa_key.rsa, out);
	}
	return -1;
}

/*
 * Serialize an RSA key object into a new oSSL_data.
 *
 * obj:          RSA public- or private-key object.
 * convert_type: output format; only OSSL_DATA_FORMAT_DER is supported.
 *
 * Returns oSSL_data_new(obj->type, ...) wrapping the DER bytes. Returns
 * NULL for an unsupported format, a non-RSA or empty object, an encoding
 * failure, or an allocation failure.
 */
oSSL_data *oSSL_RSA_object2data(oSSL_object * obj, int convert_type)
{
	unsigned char *d;
	unsigned char *d2;
	int len;

	if(convert_type != OSSL_DATA_FORMAT_DER){
		/* not supported format */
		return NULL;
	}
	len = oSSL_RSA_i2d_key(obj, NULL);
	if(len <= 0){
		/* wrong object type, no key, or OpenSSL could not encode it */
		return NULL;
	}
	d = (unsigned char *)d_alloc(len);
	if(d == NULL){
		return NULL;
	}
	/* i2d_* advances the pointer it is given; keep d at the buffer start. */
	d2 = d;
	if(oSSL_RSA_i2d_key(obj, &d2) != len){
		d_f_ree(d);
		return NULL;
	}
	return oSSL_data_new(obj->type, d, len, TRUE, len);
}

extern oSSL_method oSSL_RSA_publickey_method;
extern oSSL_method oSSL_RSA_privatekey_method;

/*
 * Parse DER-encoded RSA key data into a new key object.
 *
 * data:         data->type selects the parser (OSSL_RSA_PUBLICKEY or
 *               OSSL_RSA_PRIVATEKEY); data->data/data->len hold the DER
 *               bytes. data is not modified.
 * convert_type: input format; only OSSL_DATA_FORMAT_DER is supported.
 *
 * Returns a new object with refcnt 1 that owns the parsed RSA key.
 * Returns NULL for an unsupported format or type, a parse failure, or an
 * allocation failure.
 */
oSSL_object *oSSL_RSA_data2object(oSSL_data * data, int convert_type)
{
	oSSL_object *obj;
	oSSL_method *method;
	RSA *rsa;
	/*
	 * d2i_* advances the pointer it is given, so parse through a local
	 * copy instead of &data->data; otherwise the caller's buffer pointer
	 * would be moved past the key.
	 */
	const unsigned char *p;

	if(convert_type != OSSL_DATA_FORMAT_DER){
		/* not supported format */
		return NULL;
	}

	p = (const unsigned char *)data->data;
	if(data->type == OSSL_RSA_PUBLICKEY){
		rsa = d2i_RSAPublicKey(NULL, &p, (long)data->len);
		method = &oSSL_RSA_publickey_method;
	}
	else if(data->type == OSSL_RSA_PRIVATEKEY){
		rsa = d2i_RSAPrivateKey(NULL, &p, (long)data->len);
		method = &oSSL_RSA_privatekey_method;
	}
	else{
		/* type mismatch */
		return NULL;
	}
	if(rsa == NULL){
		/* malformed DER data */
		return NULL;
	}

	obj = oSSL_object_new(data->type, method);
	if(obj == NULL){
		RSA_free(rsa);
		return NULL;
	}
	obj->obj.rsa_key.rsa = rsa;
	obj->obj.rsa_key.refcnt = 1;
	return obj;
}

/*
 * Method table for RSA public-key objects, registered by oSSL_RSA_init().
 * Initializer entries, in order: the OSSL_RSA_PUBLICKEY type id, the
 * release function, the object-to-DER serializer and the DER-to-object
 * parser.
 * NOTE(review): field names and order come from oSSL_method in ossl.h — confirm.
 */
oSSL_method oSSL_RSA_publickey_method = {
	OSSL_RSA_PUBLICKEY,
	oSSL_RSA_free,
	oSSL_RSA_object2data,
	oSSL_RSA_data2object
};

/*
 * Method table for RSA private-key objects, registered by oSSL_RSA_init().
 * It shares the release, serialize and parse functions with the public-key
 * table; those functions branch on the object/data type themselves.
 */
oSSL_method oSSL_RSA_privatekey_method = {
	OSSL_RSA_PRIVATEKEY,
	oSSL_RSA_free,
	oSSL_RSA_object2data,
	oSSL_RSA_data2object
};

/*
 * Install the RSA public- and private-key method tables in the global
 * oSSL_types table, indexed by their type ids.
 * NOTE(review): presumably generic oSSL code dispatches through oSSL_types
 * by type id — confirm against ossl.h.
 *
 * Declared with (void) so the definition is a real prototype; an empty ()
 * list declares a function with unspecified parameters in pre-C23 C.
 */
void oSSL_RSA_init(void){
	oSSL_types[OSSL_RSA_PUBLICKEY] = &oSSL_RSA_publickey_method;
	oSSL_types[OSSL_RSA_PRIVATEKEY] = &oSSL_RSA_privatekey_method;
}

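/*
 * Historical prototypes, kept for reference only. The decode/encode
 * prototypes no longer match the definitions in this file:
 * oSSL_RSA_decode/oSSL_RSA_encode now return a new oSSL_data * instead of
 * filling an out-parameter.
int oSSL_RSA_decode(oSSL_object * private_key, oSSL_data * decode, oSSL_data * org);
int oSSL_RSA_encode(oSSL_object * public_key, oSSL_data * encode, oSSL_data * org);
int oSSL_RSA_genkey(oSSL_object ** public_key, oSSL_object ** private_key, int size);
*/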
/*
 * Generate a new RSA key pair of `size` bits (public exponent RSA_F4) and
 * wrap it in a public-key object and a private-key object.
 *
 * Both objects wrap the same OpenSSL RSA structure. Each object's
 * obj.rsa_key.refcnt is private to that object, because oSSL_RSA_free()
 * decrements only the object it is given. Sharing is therefore expressed
 * through OpenSSL's own reference count: RSA_up_ref() gives each wrapper
 * one reference, and each wrapper drops it via RSA_free() in
 * oSSL_RSA_free().
 *
 * Returns 0 on success and -1 on failure. On failure no objects or key
 * references are left allocated; out-parameters that were already
 * assigned are reset to NULL.
 */
int
oSSL_RSA_genkey(oSSL_object ** public_key, oSSL_object ** private_key, int size)
{
	RSA *rsa;

	rsa = RSA_generate_key(size/* key bit size */, RSA_F4, NULL, NULL);
	if(rsa == NULL){
		return -1;
	}
	/* Second OpenSSL reference: one per wrapper object created below. */
	if(!RSA_up_ref(rsa)){
		RSA_free(rsa);
		return -1;
	}

	*public_key = oSSL_object_new(OSSL_RSA_PUBLICKEY, &oSSL_RSA_publickey_method);
	if(*public_key == NULL){
		/* drop both references taken above */
		RSA_free(rsa);
		RSA_free(rsa);
		return -1;
	}
	(*public_key)->obj.rsa_key.rsa = rsa;
	(*public_key)->obj.rsa_key.refcnt = 1;

	*private_key = oSSL_object_new(OSSL_RSA_PRIVATEKEY, &oSSL_RSA_privatekey_method);
	if(*private_key == NULL){
		/* releases the public wrapper together with its RSA reference */
		oSSL_RSA_free(*public_key);
		*public_key = NULL;
		/* drop the reference that was meant for the private wrapper */
		RSA_free(rsa);
		return -1;
	}
	(*private_key)->obj.rsa_key.rsa = rsa;
	(*private_key)->obj.rsa_key.refcnt = 1;

	return 0;
}




