/**********************************************************************
 * record.c                                                 August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/slab.h>

#include "app_data.h"
#include "conn.h"
#include "conn_list.h"
#include "record.h"
#include "record_list.h"
#include "message.h"
#include "socket.h"
#include "alert.h"
#include "log.h"

#include "types/change_cipher_spec_t.h"
#include "types/handshake_t.h"
#include "types/alert_t.h"

#include "crypto/ssl3mac.h"

#include "pk.h"
#include "kssl_alloc.h"
#include "daemon.h"

/* Smaller of two values.  Each argument may be evaluated twice, so do not
 * pass expressions with side effects. */
#ifndef MIN
#define MIN(a,b) ((a) < (b) ? (a) : (b))
#endif


/*
 * __kssl_record_create - allocate and initialise an empty record.
 * @conn:      connection the record belongs to, may be NULL
 * @conn_flag: side of @conn the record is for (stored only with @conn)
 * @function, @file, @line: call site, used only for debug output
 *
 * When @conn is given a reference is taken with kssl_conn_get(); it is
 * dropped again by __kssl_record_destroy().
 *
 * Returns the new record, or NULL if the allocation failed.
 */
kssl_record_t *
__kssl_record_create(kssl_conn_t *conn, u32 conn_flag,
		const char *function, const char *file, size_t line)
{
	kssl_record_t *cr;

	/* line is a size_t: cast so it matches the %u conversion */
	KSSL_DEBUG(12, "%s(%s:%u): kssl_record_create\n", 
			function, file, (unsigned int)line);

	cr = kssl_kmalloc(sizeof(*cr), GFP_KERNEL);
	if (!cr)
		return NULL;
	kssl_record_clear(cr);

	if (conn) {
		cr->conn_flag = conn_flag;
		cr->conn = conn;
		kssl_conn_get(conn);
	}

	return cr;
}


/*
 * Release everything referenced by the iovec array of @cr.  A slot that
 * has an skb in skbv gives up that skb; otherwise a non-NULL iov_base is
 * freed with kssl_kfree().  The skbv and iov arrays are then freed and
 * both pointers reset to NULL.
 */
static void 
kssl_record_destroy_data(kssl_record_t *cr)
{
	int idx;

	KSSL_DEBUG(12, "kssl_record_destroy_data enter %p\n", cr);

	for (idx = 0; idx < cr->iov_len; idx++) {
		struct sk_buff *skb = cr->skbv ? cr->skbv[idx] : NULL;

		if (skb) {
			kfree_skb(skb);
			continue;
		}
		if (cr->iov && cr->iov[idx].iov_base)
			kssl_kfree(cr->iov[idx].iov_base);
	}

	if (cr->skbv != NULL) {
		kssl_kfree(cr->skbv);
		cr->skbv = NULL;
	}
	if (cr->iov != NULL) {
		kssl_kfree(cr->iov);
		cr->iov = NULL;
	}
}


/*
 * __kssl_record_destroy - free a record and everything it holds.
 * @cr: record to free; it is unlinked from whatever list it is on
 * @function, @file, @line: call site, used only for debug output
 *
 * Frees the record's data, drops the connection reference taken by
 * __kssl_record_create() and destroys any attached message.
 */
void 
__kssl_record_destroy(kssl_record_t *cr, 
		const char *function, const char *file, size_t line)
{
	/* line is a size_t: cast so it matches the %u conversion */
	KSSL_DEBUG(12, "%s(%s:%u): kssl_record_destroy enter: cr=%p\n",
			function, file, (unsigned int)line, cr);

	kssl_record_destroy_data(cr);
	list_del(&cr->list);
	if (cr->conn)
		kssl_conn_put(cr->conn);
	if (cr->msg)
		kssl_message_destroy(cr->msg);
	kssl_kfree(cr);
}


/*
 * Debug helper: log the parsed header of @cr, prefixed by @tag, in the
 * SSLv2 or the SSLv3/TLS layout as appropriate.
 */
static inline void 
kssl_record_show_head(kssl_record_t *cr, const char *tag)
{
	if (cr->record.head.type != ct_ssl2) {
		KSSL_DEBUG(1, "%s: version=%u.%u type=%u len=%u\n",
				tag, cr->record.head.version.major, 
				cr->record.head.version.minor,
				cr->record.head.type, cr->record.head.len);
		return;
	}

	KSSL_DEBUG(1, "%s: ssl2 format len=%u pad=%u is_escape=%u\n",
			tag, cr->ssl2_head.len, cr->ssl2_head.pad, 
			cr->ssl2_head.is_escape);
}


/*
 * kssl_record_add_skb - append @len bytes of @skb, starting at @offset,
 * to the record @cr.
 *
 * A linearised clone of @skb is kept in cr->skbv and a new iovec entry
 * points at cloned_skb->data + @offset.  The clone is released by
 * kssl_record_destroy_data().  The iov/skbv arrays grow in steps of
 * KSSL_CONN_RECORD_IOV_BLOCK entries.
 *
 * Returns @len on success or a negative errno.  On failure @cr is left
 * exactly as it was.
 */
int 
kssl_record_add_skb(kssl_record_t *cr,
		struct sk_buff *skb, size_t offset, size_t len)
{
	int status;
	struct iovec *new_iov;
	struct sk_buff **new_skbv;
	struct sk_buff *cloned_skb;

	/* XXX: non-linear skbs are handled by linearising the clone */
	cloned_skb = skb_clone(skb, GFP_ATOMIC);
	if (!cloned_skb)
		return -ENOMEM;

	status = skb_linearize(cloned_skb, GFP_ATOMIC);
	if (status < 0) {
		kfree_skb(cloned_skb);
		return status;
	}

	if (cr->iov_len == cr->iov_alloc) {
		/* Compute the new capacity locally and only store it once both
		 * arrays exist.  Bumping iov_alloc before a failed allocation
		 * would let the next call write past the old arrays. */
		size_t new_alloc = cr->iov_alloc + KSSL_CONN_RECORD_IOV_BLOCK;

		new_iov = kssl_kmalloc(new_alloc * sizeof(*new_iov),
				GFP_KERNEL);
		if (!new_iov) {
			kfree_skb(cloned_skb);
			return -ENOMEM;
		}

		new_skbv = kssl_kmalloc(new_alloc * sizeof(*new_skbv),
				GFP_KERNEL);
		if (!new_skbv) {
			kfree_skb(cloned_skb);
			kssl_kfree(new_iov);
			return -ENOMEM;
		}

		memcpy(new_iov, cr->iov, cr->iov_len * sizeof(*new_iov));
		if (cr->iov)
			kssl_kfree(cr->iov);
		cr->iov = new_iov;

		memcpy(new_skbv, cr->skbv, cr->iov_len * sizeof(*new_skbv));
		if (cr->skbv)
			kssl_kfree(cr->skbv);
		cr->skbv = new_skbv;

		cr->iov_alloc = new_alloc;
	}

	KSSL_DEBUG(9, "kssl_record_add_skb: skb=%p skb->data+offset=%p"
			"len=%u\n", skb, skb->data + offset, (unsigned int)len);

	new_iov = cr->iov + cr->iov_len;
	new_iov->iov_base = (u8 *)cloned_skb->data + offset;
	new_iov->iov_len = len;

	/* skb_clone() already gave us our own reference: no skb_get() */
	cr->skbv[cr->iov_len] = cloned_skb;

	cr->iov_len++;
	cr->total_len += len;
	cr->content_len += len;

	KSSL_DEBUG(9, "kssl_record_add_skb: cr->skb=%p *(cr->skb)=%p\n",
			cr->skbv, *(cr->skbv));

	return len;
}


/*
 * kssl_record_start - parse the header of a record being received.
 * @cr:    record accumulating received bytes
 * @alert: set to a fatal illegal_parameter alert on failure
 *
 * The first byte decides between an SSLv3/TLS header (a known content
 * type) and an SSLv2 header (anything else).  The header may arrive split
 * across several segments, so too few bytes means "wait", not an error.
 *
 * Returns 1 once the header is parsed (state becomes
 * KSSL_CONN_RECORD_RCV_STARTED), 0 if more bytes are needed, or -EINVAL
 * if the advertised record length exceeds TLS_MAX_DATA.
 */
static int 
kssl_record_start(kssl_record_t *cr, alert_t *alert)
{
	u8 buf[TLS_HEAD_NLEN];

	if (cr->total_len < 1)
		return 0;

	/* May copy fewer than TLS_HEAD_NLEN bytes; only buf[0] is used
	 * until the length checks below have passed. */
	kssl_record_vec_cpy(cr, buf, TLS_HEAD_NLEN, 0);

	switch (*buf) {
		case ct_change_cipher_spec:
		case ct_alert:
		case ct_handshake:
		case ct_application_data:
			if (cr->total_len < TLS_HEAD_NLEN)
				return 0;	/* partial header: wait */
			tls_head_from_buf(&(cr->record.head), buf);
			break;
		case ct_last:
		default:
			if (cr->total_len < SSL2_HEAD_NLEN_BUF(buf))
				return 0;	/* partial header: wait */
			kssl_record_vec_cpy(cr, buf, 
					SSL2_HEAD_NLEN_BUF(buf), 0);
			ssl2_head_from_buf(&(cr->ssl2_head), buf);
			cr->record.head.len = cr->ssl2_head.len;
			cr->record.head.type = ct_ssl2;
			break;
	}

	if (cr->record.head.len > TLS_MAX_DATA) {
		KSSL_DEBUG(3, "SSL record data is too large\n");
		goto leave;
	}

	cr->state = KSSL_CONN_RECORD_RCV_STARTED;

	return 1;

leave:
	alert->level = al_fatal;
	alert->description = ad_illegal_parameter;
	return -EINVAL;
}


/*
 * Check whether every byte of the record announced by the parsed header
 * has been received.  If so, mark @cr KSSL_CONN_RECORD_RCV_COMPLETE and
 * return 1; otherwise return 0.  @alert is never touched.
 */
static inline int 
kssl_record_complete(kssl_record_t *cr, alert_t *alert)
{
	if (cr->total_len >= KSSL_CONN_RECORD_NLEN(cr)) {
		KSSL_DEBUG(12, "record complete: total_len=%u record_nlen=%u\n",
				cr->total_len, KSSL_CONN_RECORD_NLEN(cr));
		cr->state = KSSL_CONN_RECORD_RCV_COMPLETE;
		return 1;
	}

	return 0;
}


/*
 * Cut an iovec at @offset: @prev_tail (if given) is shortened to the first
 * @offset bytes, and @next_head (if given) is advanced past them.  The
 * tail is trimmed first, which matters if both point at the same entry.
 */
static inline void 
kssl_iovec_split(struct iovec *next_head, 
		struct iovec *prev_tail, size_t offset) 
{
	if (prev_tail != NULL)
		prev_tail->iov_len = offset;

	if (next_head == NULL)
		return;

	next_head->iov_base = (u8 *)next_head->iov_base + offset;
	next_head->iov_len -= offset;
}


/*
 * kssl_record_vec_split - move the data following the first record of
 * @head into a new record.
 * @head:       record holding received bytes, starting at head->offset
 * @tail:       set to a new record holding everything after the first
 *              @record_len bytes, or to NULL if there is nothing more
 * @record_len: length of the first record, header included
 *
 * If the split point falls inside an iovec, that entry (and its skb) is
 * shared: @head keeps the bytes before it and *@tail the bytes after, and
 * the skb gets an extra reference so each record can release it.
 *
 * Returns 0 on success.  Returns -ENOMEM if an allocation fails; in that
 * case @head is unchanged and *@tail is NULL.
 */
int 
kssl_record_vec_split(kssl_record_t *head,
		kssl_record_t **tail, size_t record_len)
{
	size_t vecno;
	size_t vecoffset;
	size_t nvec;
	struct iovec *iov;
	struct sk_buff **skbv;

	KSSL_DEBUG(12, "kssl_record_vec_split: enter %d %d %d\n",
			(int)head->content_len, (int)head->offset,
			(int)record_len);

	/* If these lengths match, exactly one record is present.
	 * Otherwise split the data into the first record (head)
	 * and the remainder (tail). */
	if (head->content_len - head->offset == record_len) {
		KSSL_DEBUG(12, "kssl_record_vec_split: no split\n");
		*tail = NULL;
		return 0;
	}

	kssl_record_vec_seek(head, head->offset + record_len,
			&vecno, &vecoffset);

	/* The tail starts at iovec vecno whether or not the split point is
	 * inside that entry, so it always needs this many slots. */
	nvec = head->iov_len - vecno;

	/* Allocate everything before creating the tail record.  A failure
	 * then leaves head untouched and no connection reference to undo
	 * (kssl_record_create() takes one, which a bare kssl_kfree() of the
	 * record would leak). */
	iov = kssl_kmalloc(nvec * sizeof(*iov), GFP_KERNEL);
	if (!iov) {
		*tail = NULL;
		return -ENOMEM;
	}

	skbv = kssl_kmalloc(nvec * sizeof(*skbv), GFP_KERNEL);
	if (!skbv) {
		kssl_kfree(iov);
		*tail = NULL;
		return -ENOMEM;
	}

	*tail = kssl_record_create(head->conn, head->conn_flag);
	if (!*tail) {
		kssl_kfree(skbv);
		kssl_kfree(iov);
		return -ENOMEM;
	}

	(*tail)->iov = iov;
	(*tail)->skbv = skbv;
	(*tail)->iov_alloc = (*tail)->iov_len = nvec;

	/* Now set the rest of the lengths */
	(*tail)->content_len = (*tail)->total_len = 
		head->total_len - head->offset - record_len;
	head->content_len = head->total_len = record_len + head->offset;
	head->iov_len = vecoffset ? vecno + 1 : vecno;

	memcpy((*tail)->skbv, head->skbv + vecno, nvec * sizeof(*skbv));
	memcpy((*tail)->iov, head->iov + vecno, nvec * sizeof(*iov));

	if (vecoffset) {
		kssl_iovec_split((*tail)->iov, head->iov + head->iov_len - 1,
				vecoffset);
		/* This skb is now in both records, so take an extra
		 * reference for the tail */
		skb_get(*((*tail)->skbv));
	}

	return 0;
}

/*
 * kssl_record_split - take a fully received record out of the pool.
 * @pool: in/out.  The record receiving data; on success it is replaced by
 *        the record holding any bytes that followed (or NULL).
 * @head: out.  Always set to the original *@pool, even on error, so the
 *        caller can use it to send an alert.
 *
 * Returns 0 if the record is not complete yet, or 1 once it has been
 * queued on kssl_record_in_list with state KSSL_CONN_RECORD_IN_STARTED.
 * Returns a negative errno if kssl_record_vec_split() fails.
 */

static int 
kssl_record_split(kssl_record_t **pool, kssl_record_t **head)
{
	kssl_record_t *rest = NULL;
	int err;

	*head = *pool;

	if ((*head)->state != KSSL_CONN_RECORD_RCV_COMPLETE)
		return 0;

	err = kssl_record_vec_split(*head, &rest,
			KSSL_CONN_RECORD_NLEN(*head));
	if (err < 0)
		return err;

	*pool = rest;

	/* Hand the completed record over for processing */
	list_move_tail(&(*head)->list, &kssl_record_in_list);
	(*head)->state = KSSL_CONN_RECORD_IN_STARTED;

	return 1;
}


static inline int 
kssl_record_update_ssl(kssl_record_t **cr)
{
	int status = -EINVAL;
	kssl_record_t *head;
	alert_t alert = { al_warning, ad_close };

	head = *cr;
	if (!(*cr)->state) {
		status = kssl_record_start(head, &alert);
		if (status < 0)
			goto alert;
		if (!status)
			return 0;
	}

	status = kssl_record_complete(head, &alert);
	if (status < 0)
		goto alert;
	if (!status)
		return 0;

	status = kssl_record_split(cr, &head);
	if (status < 0)
		goto alert;
	if (!status)
		return 0;

	return 1;

alert:
	if (*cr == head)
		*cr = NULL;

	if (kssl_alert_send(head, &alert, 1) < 0) {
		KSSL_DEBUG(6, "kssl_record_update: kssl_alert_send failed");
		kssl_conn_close(head->conn, KSSL_CONN_SSL_CLOSE);
		kssl_record_destroy(head);
	}
	/* Implied by kssl_alert_send(..., 1) */
	/* kssl_record_destroy(head); */
	return status;
}


/*
 * Plain-text data needs no framing: queue the record in *@cr for
 * processing as it is, clear *@cr and report one record queued.
 */
static inline int 
kssl_record_update_pt(kssl_record_t **cr)
{
	kssl_record_t *rec = *cr;

	list_move_tail(&rec->list, &kssl_record_in_list);
	rec->state = KSSL_CONN_RECORD_IN_STARTED;
	*cr = NULL;

	return 1;
}


/*
 * kssl_record_update - account for newly received data in *@cr.  SSL-side
 * records go through record framing; plain-text records are queued as is.
 * Returns the result of the side-specific handler.
 */
int 
kssl_record_update(kssl_record_t **cr)
{
	return ((*cr)->conn_flag == KSSL_CONN_SSL) ?
		kssl_record_update_ssl(cr) : kssl_record_update_pt(cr);
}


/*
 * Undo record compression.  Only the NULL compression method (the only
 * one the specification defines) is supported, so the data is left as it
 * is.  The record is marked KSSL_CONN_RECORD_IN_UNCMP, @alert is never
 * set, and the function always succeeds.
 */
static inline int 
kssl_record_uncompress(kssl_record_t *cr, alert_t *alert) 
{
	cr->state = KSSL_CONN_RECORD_IN_UNCMP;
	return 0;
}


/*
 * Compress an outgoing record.  Only NULL compression is supported, so
 * there is nothing to do and this always succeeds.
 */
static inline int 
kssl_record_compress(kssl_record_t *cr) 
{
	(void)cr;
	return 0;
}


/*
 * Set the compression state flag of @cr.  With only NULL compression
 * supported, the record is simply flagged KSSL_CONN_RECORD_IN_UNCMP.
 * Always returns 0.
 */
static inline int 
kssl_record_compress_flags(kssl_record_t *cr) 
{
	cr->state = cr->state | KSSL_CONN_RECORD_IN_UNCMP;
	return 0;
}

/*
 * __kssl_record_digest_generate_tls - compute the TLS (HMAC) MAC of a record.
 * @cr:          record whose content starts at byte TLS_HEAD_NLEN
 * @content_len: number of content bytes to authenticate
 * @secret:      MAC secret, sec_param->hash_size bytes long
 * @seq:         record sequence number in host byte order
 * @sec_param:   selects MD5 or SHA1 (mac_algorithm) and gives hash_size
 * @output:      receives sec_param->hash_size bytes of MAC
 *
 * The HMAC input is seq (64-bit big endian), the record type, the record
 * version and content_len (16-bit big endian), followed by the content.
 * The connection's digest transform is allocated on first use and cached
 * in cr->conn->conn_state.mac.
 *
 * Returns 0 on success.  Returns -ENOMEM if the transform or the content
 * scatterlist cannot be obtained; @output is then left zeroed.
 */
static int 
__kssl_record_digest_generate_tls(kssl_record_t *cr, 
		u16 content_len, opaque_t *secret, uint64_t seq, 
		security_parameters_t *sec_param, u8 *output)
{
	struct scatterlist sg1[4];
	struct scatterlist *sg2 = NULL;
	int status = -ENOMEM;
	void *p;
	u32 len;
	u32 nvec;

	KSSL_DEBUG(12, "__kssl_record_digest_generate_tls: enter\n");

	memset(output, 0, sec_param->hash_size);

	/* Lazily allocate the per-connection digest transform */
	if (!cr->conn->conn_state.mac) {
		cr->conn->conn_state.mac = crypto_alloc_tfm(
				(sec_param->mac_algorithm == ma_md5) ?  
				"md5" : "sha1", 0);
		if (!cr->conn->conn_state.mac) {
			KSSL_DEBUG(6, "failed to load transform\n");
			goto leave;
		}
	}

	len = sec_param->hash_size;
	crypto_hmac_init(cr->conn->conn_state.mac, secret, &len);
	/* asym_print_char("sec", secret, len); */

	/* The scatterlist entries below point straight at these locals, so
	 * seq and content_len are converted to wire (big-endian) order in
	 * place.  NOTE(review): virt_to_page() on stack addresses assumes the
	 * kernel stack is directly mapped — confirm for the target kernels. */
	seq = __cpu_to_be64(seq);
	p = &(seq);
	sg1[0].page = virt_to_page(p);
	sg1[0].offset = ((long) p & ~PAGE_MASK);
	sg1[0].length = sizeof(uint64_t);
	/* asym_print_char("seq", p, sizeof(uint64_t)); */

	p = &(cr->record.head.type);
	sg1[1].page = virt_to_page(p);
	sg1[1].offset = ((long) p & ~PAGE_MASK);
	sg1[1].length = sizeof(u8);
	/* asym_print_char("type", p, sizeof(u8)); */

	/* Two bytes: major then minor version */
	p = &(cr->record.head.version);
	sg1[2].page = virt_to_page(p);
	sg1[2].offset = ((long) p & ~PAGE_MASK);
	sg1[2].length = sizeof(u16);
	/* asym_print_char("version", p, sizeof(u16)); */

	content_len = htons(content_len);
	p = &content_len;
	sg1[3].page = virt_to_page(p);
	sg1[3].offset = ((long) p & ~PAGE_MASK);
	sg1[3].length = sizeof(u16);
	/* asym_print_char("length", p, sizeof(u16)); */

	crypto_hmac_update(cr->conn->conn_state.mac, sg1, 4);

	/* content_len is big-endian by now: convert back for the length */
	if (kssl_record_to_sg(cr, TLS_HEAD_NLEN, 
			ntohs(content_len), &sg2, &nvec, &len) < 0) {
		KSSL_DEBUG(6, "__kssl_record_digest_generate_tls: "
				"kssl_record_to_sg\n");
		goto leave;
	}

	crypto_hmac_update(cr->conn->conn_state.mac, sg2, nvec);
	len = sec_param->hash_size;
	crypto_hmac_final(cr->conn->conn_state.mac, secret, &len, output);
	/* asym_print_char("mac", output, len); */

	status = 0;
leave:
	if (sg2)
		kssl_kfree(sg2);
	return status;
}


/*
 * __kssl_record_digest_generate_ssl3 - compute the SSLv3 MAC of a record.
 * @cr:          record whose content starts at byte TLS_HEAD_NLEN
 * @content_len: number of content bytes to authenticate
 * @secret:      MAC secret, sec_param->hash_size bytes long
 * @seq:         record sequence number in host byte order
 * @sec_param:   selects MD5 or SHA1 (mac_algorithm) and gives hash_size
 * @output:      receives sec_param->hash_size bytes of MAC
 *
 * The MAC input is seq (64-bit big endian), the record type and
 * content_len (16-bit big endian), followed by the content.  Unlike TLS,
 * SSLv3 does not include the version.
 *
 * Returns 0 on success or -ENOMEM; @output is left zeroed on failure.
 */
static int __kssl_record_digest_generate_ssl3(kssl_record_t *cr, 
		u16 content_len, opaque_t *secret, uint64_t seq, 
		security_parameters_t *sec_param, u8 *output)
{
	struct scatterlist sg1[3];
	struct scatterlist *sg2 = NULL;
	int status = -ENOMEM;
	void *p;
	u32 len;
	u32 nvec;
	u16 clen;

	KSSL_DEBUG(12, "__kssl_record_digest_generate_ssl3: enter\n");

	memset(output, 0, sec_param->hash_size);

	/* Lazily allocate the per-connection digest transform */
	if (!cr->conn->conn_state.mac) {
		cr->conn->conn_state.mac = crypto_alloc_tfm(
				(sec_param->mac_algorithm == ma_md5) ?  
				"md5" : "sha1", 0);
		if (!cr->conn->conn_state.mac) {
			KSSL_DEBUG(6, "failed to load transform\n");
			goto leave;
		}
	}

	len = sec_param->hash_size;
	crypto_ssl3mac_init(cr->conn->conn_state.mac, secret, &len);

	/* The scatterlist entries point at these locals, so they are
	 * converted to wire (big-endian) order in place */
	seq = __cpu_to_be64(seq);
	p = &(seq);
	sg1[0].page = virt_to_page(p);
	sg1[0].offset = ((long) p & ~PAGE_MASK);
	sg1[0].length = sizeof(uint64_t);

	p = &(cr->record.head.type);
	sg1[1].page = virt_to_page(p);
	sg1[1].offset = ((long) p & ~PAGE_MASK);
	sg1[1].length = sizeof(u8);

	clen = htons(content_len);
	p = &clen;
	sg1[2].page = virt_to_page(p);
	sg1[2].offset = ((long) p & ~PAGE_MASK);
	sg1[2].length = sizeof(u16);

	/* Use the ssl3mac helpers for every step (init/update/final) rather
	 * than mixing in crypto_hmac_update() for this one update */
	crypto_ssl3mac_update(cr->conn->conn_state.mac, sg1, 3);

	if (kssl_record_to_sg(cr, TLS_HEAD_NLEN, 
				content_len, &sg2, &nvec, &len) < 0) {
		KSSL_DEBUG(6, "__kssl_record_digest_generate_ssl3: "
				"kssl_record_to_sg\n");
		goto leave;
	}

	crypto_ssl3mac_update(cr->conn->conn_state.mac, sg2, nvec);
	len = sec_param->hash_size;
	crypto_ssl3mac_final(cr->conn->conn_state.mac, secret, &len, output);

	status = 0;
leave:
	if (sg2)
		kssl_kfree(sg2);
	return status;
}


/*
 * kssl_record_digest_generate - compute the MAC of a record using the
 * construction that matches the connection's protocol version: SSLv3 for
 * minor version 0, TLS HMAC for any other minor version.  The version is
 * expected to have been validated (major 3, minor 0 or 1) beforehand.
 */
static int 
kssl_record_digest_generate(kssl_record_t *cr,
		u16 content_len, opaque_t *secret, uint64_t seq,
		security_parameters_t *sec_param, u8 *output)
{
	if (!cr->conn->conn_state.version.minor)
		return __kssl_record_digest_generate_ssl3(cr, content_len, 
				secret, seq, sec_param, output);

	return __kssl_record_digest_generate_tls(cr, content_len, 
			secret, seq, sec_param, output);
}




/*
 * kssl_record_digest_verify - check the MAC of a decrypted record.
 * @cr:         record whose fragment starts at byte TLS_HEAD_NLEN
 * @record_len: length of the decrypted fragment: content, MAC and, for a
 *              non-null cipher, padding plus the trailing padding-length
 *              byte
 *
 * Returns the content length (fragment minus MAC and padding) if the MAC
 * matches.  Returns -EINVAL if the fragment is too short to hold its MAC
 * and padding, if the MAC cannot be computed, or if it does not match.
 */
static int 
kssl_record_digest_verify(kssl_record_t *cr, u32 record_len) 
{
	security_parameters_t *sp = cr->conn->sec_param_in_act;
	u32 overhead = sp->hash_size;
	int content_len;
	u8 pad_len = 0;
	opaque_t digest_out[20]; /* MD5 or SHA1 digest will fit */

	KSSL_DEBUG(12, "kssl_record_digest_verify: enter\n");

	/* digest_out is a fixed-size stack buffer */
	if (sp->hash_size > sizeof(digest_out))
		return -EINVAL;

	if (sp->bulk_cipher_algorithm != bca_null) {
		if (record_len < 1)
			return -EINVAL;
		/* The last byte of the fragment is the padding length */
		kssl_record_vec_cpy(cr, &pad_len, 1, 
				TLS_HEAD_NLEN + record_len - 1);
		overhead += pad_len + 1;
	}

	/* A fragment shorter than its MAC and padding would give a
	 * negative content length (and a huge u16 further down) */
	if (record_len < overhead) {
		KSSL_DEBUG(6, "kssl_record_digest_verify: "
				"record too short for MAC and padding\n");
		return -EINVAL;
	}
	content_len = record_len - overhead;

	if (kssl_record_digest_generate(cr, content_len, 
				cr->conn->conn_state.client_mac_secret, 
				cr->conn->conn_state.in_seq, 
				sp, digest_out) < 0) {
		/* Never accept a record whose MAC could not be checked */
		KSSL_DEBUG(6, "kssl_record_digest_verify: "
				"kssl_record_digest_generate\n");
		return -EINVAL;
	}

	return kssl_record_vec_cmp(cr, digest_out, sp->hash_size,
				TLS_HEAD_NLEN + content_len) ? 
			-EINVAL : content_len;
}


/*
 * kssl_record_decrypt - decrypt an incoming record in place and verify
 * its MAC.
 *
 * Nothing is done while no inbound security parameters are active.  With
 * a non-null bulk cipher, the record.head.len bytes after the header are
 * decrypted in place; then the MAC is verified.  On success
 * cr->content_len is set to header plus content length.  Whatever the
 * outcome, the record ends up in state KSSL_CONN_RECORD_IN_UNENC.
 * @alert is not modified.
 *
 * Returns 0 on success or the negative status of the failing step.
 */
static int 
kssl_record_decrypt(kssl_record_t *cr, alert_t *alert) 
{
	security_parameters_t *sp = cr->conn->sec_param_in_act;
	struct scatterlist *sg = NULL;
	u32 nvec;
	u32 len;
	int rc = 0;

	KSSL_DEBUG(12, "kssl_record_decrypt: enter: %p\n", sp);

	if (sp == NULL)
		goto out;

	KSSL_DEBUG(12, "kssl_record_decrypt: bulk_cipher_algorithm: %d\n",
			sp->bulk_cipher_algorithm);

	if (sp->bulk_cipher_algorithm == bca_null) {
		len = cr->record.head.len;
	} else {
		rc = kssl_record_to_sg(cr, TLS_HEAD_NLEN, 
				cr->record.head.len, &sg, &nvec, &len);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_record_decrypt: "
					"kssl_record_to_sg\n");
			goto out;
		}

		rc = crypto_cipher_decrypt(cr->conn->conn_state.dec, 
				sg, sg, len);
		if (rc < 0) {
			KSSL_DEBUG(6, "kssl_record_decrypt: "
					"crypto_cipher_decrypt\n");
			goto out;
		}
	}

	rc = kssl_record_digest_verify(cr, len);
	if (rc >= 0) {
		KSSL_DEBUG(9, "kssl_record_decrypt digest verified\n");
		cr->content_len = rc + TLS_HEAD_NLEN;
		rc = 0;
	} else {
		KSSL_DEBUG(6, "kssl_record_decrypt digest mismatch\n");
	}

out:
	if (sg)
		kssl_kfree(sg);
	cr->state = KSSL_CONN_RECORD_IN_UNENC;
	return rc;
}


/*
 * Mark which outbound protection steps @cr needs.  If outbound security
 * parameters are active, a MAC is always needed
 * (KSSL_CONN_RECORD_OUT_DIGEST).  Encryption (KSSL_CONN_RECORD_OUT_ENC)
 * is needed only when the bulk cipher is not null.
 */
static inline void 
kssl_record_encrypt_flags(kssl_record_t *cr) 
{
	security_parameters_t *sp = cr->conn->sec_param_out_act;

	if (sp != NULL) {
		cr->state |= KSSL_CONN_RECORD_OUT_DIGEST;
		if (sp->bulk_cipher_algorithm != bca_null)
			cr->state |= KSSL_CONN_RECORD_OUT_ENC;
	}
}



/*
 * kssl_record_encrypt - add the MAC to an outgoing record and, if
 * required, encrypt it in place.
 *
 * Does nothing unless KSSL_CONN_RECORD_OUT_DIGEST is set in cr->state.
 * The MAC is computed over the cr->content_len - TLS_HEAD_NLEN content
 * bytes, using cr->seq and the server MAC secret, and written directly
 * after the content.  The header length is set to
 * cr->total_len - TLS_HEAD_NLEN and the header is rewritten into
 * cr->iov[0].
 *
 * If KSSL_CONN_RECORD_OUT_ENC is also set, pad+1 bytes of value pad are
 * written after the MAC to fill the record up to cr->total_len.  Then
 * everything after the header is encrypted in chunks of at most
 * PAGE_SIZE bytes.  Each flag is cleared once its step is done, so
 * calling this again for the same record does not repeat finished steps.
 *
 * NOTE(review): the MAC and padding are written through the single iovec
 * that holds byte content_len, so they are assumed not to cross an iovec
 * boundary.  Confirm against the code that builds outgoing records.
 *
 * Returns 0 on success or a negative errno.
 */
static int 
kssl_record_encrypt(kssl_record_t *cr) 
{
	struct scatterlist *sg = NULL;
	int status = 0;
	size_t vecno;
	size_t vecoffset;
	u32 len;
	u32 nvec;
	u32 chunk;
	u8 pad = 0;

	KSSL_DEBUG(12, "kssl_record_encrypt: enter: %p "
			"total_len=%u content_len=%u\n", 
			cr, cr->total_len, cr->content_len);

	if (! (cr->state & KSSL_CONN_RECORD_OUT_DIGEST))
		return 0;

	/* Locate the end of the content: the MAC goes there */
	kssl_record_vec_seek(cr, cr->content_len, &vecno, &vecoffset);

	status = kssl_record_digest_generate(cr, 
			 cr->content_len - TLS_HEAD_NLEN, 
			 cr->conn->conn_state.server_mac_secret, cr->seq,
			 cr->conn->sec_param_out_act,
			 (cr->iov + vecno)->iov_base + vecoffset);
	if (status < 0)  {
		KSSL_DEBUG(6, "kssl_record_encrypt: "
				"kssl_record_digest_generate\n");
		return status;
	}

	/* The header now covers content, MAC and any padding */
	cr->record.head.len = cr->total_len - TLS_HEAD_NLEN;
	tls_head_to_buf(&(cr->record.head), cr->iov->iov_base);

	/* Known to be set (checked above): this clears it */
	cr->state ^= KSSL_CONN_RECORD_OUT_DIGEST;

	if (! (cr->state & KSSL_CONN_RECORD_OUT_ENC))
		return 0;

	/* Block-cipher padding: pad+1 bytes each holding the value pad,
	 * filling the space between the MAC and total_len */
	pad = cr->total_len - cr->content_len - 
		cr->conn->sec_param_out_act->hash_size - 1;
	memset((cr->iov + vecno)->iov_base + vecoffset +
			cr->conn->sec_param_out_act->hash_size, pad, pad+1);

	/* Encrypt everything after the header, at most a page at a time */
	for(chunk = TLS_HEAD_NLEN; chunk < cr->total_len;
			chunk += PAGE_SIZE) {
		len = cr->total_len - chunk;
		if (len > PAGE_SIZE)
			len = PAGE_SIZE;

		status = kssl_record_to_sg(cr, chunk, len, &sg, &nvec, &len);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_encrypt: "
					"kssl_record_to_sg\n");
			return status;
		}

		status = crypto_cipher_encrypt(cr->conn->conn_state.enc, 
				sg, sg, len);

		kssl_kfree(sg);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_encrypt: "
					"crypto_cipher_encrypt: %d\n", status);
			return status;
		}
	}

	/* Known to be set (checked above): this clears it */
	cr->state ^= KSSL_CONN_RECORD_OUT_ENC;

	return 0;
}

/* Verify Version
 * Only SSLv3 and TLSv1 are implemented, i.e. major version 3 with minor
 * version 0 or 1.  Returns 0 if @version is supported, otherwise -EINVAL. */

int kssl_version_verify(protocol_version_t *version) 
{
	KSSL_DEBUG(12, "kssl_version_verify: enter: %d.%d\n",
			version->major, version->minor);

	if (version->major == 3 &&
			(version->minor == 0 || version->minor == 1))
		return 0;

	/* Version 3.0 is SSLv3; the previous text wrongly said SSLv1 */
	KSSL_DEBUG(3, "Unsupported version %u.%u\n",
			version->major, version->minor);
	KSSL_DEBUG(6, "(0.2==SSLv2, 3.0==SSLv3, 3.1==TLSv1. "
			"Other values are unknown. SSLv2 is not "
			"supported)\n");
	return -EINVAL;
}


/*
 * kssl_record_verify_head - validate the parsed header of an incoming
 * record.  The length is expected to have been checked already.
 *
 * SSLv2 headers carry no version and are accepted as they are.  The four
 * SSLv3/TLS content types must also carry a supported version.  Anything
 * else is rejected.  On rejection @alert is set to a fatal
 * illegal_parameter alert and -EINVAL is returned; otherwise 0.
 */
int kssl_record_verify_head(kssl_record_t *cr, alert_t *alert) 
{
	int ok;

	KSSL_DEBUG(12, "kssl_version_head: enter: %d %d.%d %d\n",
			cr->record.head.type,
			cr->record.head.version.major, 
			cr->record.head.version.minor,
			cr->record.head.len);

	switch (cr->record.head.type) {
	case ct_ssl2:
		return 0;
	case ct_change_cipher_spec:
	case ct_alert:
	case ct_handshake:
	case ct_application_data:
		ok = kssl_version_verify(&cr->record.head.version) >= 0;
		break;
	default:
		ok = 0;
		break;
	}

	if (ok)
		return 0;

	alert->level = al_fatal;
	alert->description = ad_illegal_parameter;
	return -EINVAL;
}


/*
 * kssl_record_vec_seek - locate byte @offset in the iovec array of @cr.
 * @vecno:     set to the index of the iovec containing @offset
 * @vecoffset: set to the position of @offset inside that iovec
 *
 * An @offset that falls exactly on the end of an iovec is reported as the
 * start (offset 0) of the next one.  If @offset is at or beyond the end
 * of the data, *@vecno is cr->iov_len and *@vecoffset is 0.
 *
 * Returns the number of bytes skipped.  This is less than @offset only
 * when @offset lies beyond the end of the data.
 */
int kssl_record_vec_seek(kssl_record_t *cr, size_t offset,
		size_t *vecno, size_t *vecoffset)
{
	size_t vno;
	size_t voffset = 0;
	size_t len = 0;

	for (vno = 0; vno < cr->iov_len; vno++) {
		voffset = MIN((cr->iov+vno)->iov_len, offset-len);
		len += voffset;
		if (len >= offset)
			break;
	}

	if (vno >= cr->iov_len) {
		/* Ran off the end (offset beyond the data, or no iovecs at
		 * all): cr->iov[vno] does not exist and must not be read */
		voffset = 0;
	} else if (voffset == (cr->iov+vno)->iov_len) {
		/* Exactly at the end of this entry: use the next one's start */
		vno++;
		voffset = 0;
	}

	*vecno = vno;
	*vecoffset = voffset;

	return len;
}


/*
 * kssl_record_vec_cpy - copy up to @buf_len bytes of @cr, starting at
 * byte @offset, into the flat buffer @buf.
 *
 * Returns the number of bytes copied.  This is less than @buf_len when
 * fewer bytes remain in the record past @offset.
 */
int kssl_record_vec_cpy(kssl_record_t *cr,
		u8 *buf, size_t buf_len, size_t offset) 
{
	size_t vno;
	size_t chunk;
	size_t vecno = 0;
	size_t vecoffset = 0;
	size_t len = 0;

	/* size_t arguments are cast to match the %u conversions */
	KSSL_DEBUG(12, "kssl_record_vec_cpy: buf_len=%u offest=%u\n",
			(unsigned int)buf_len, (unsigned int)offset);

	if (offset)
		kssl_record_vec_seek(cr, offset, &vecno, &vecoffset);

	for (vno = vecno; vno < cr->iov_len && len < buf_len; vno++) {
		chunk = MIN((cr->iov+vno)->iov_len - vecoffset, buf_len-len);
		memcpy(buf+len, (cr->iov+vno)->iov_base + vecoffset, chunk);
		len += chunk;
		vecoffset = 0;	/* later iovecs are read from the start */
	}

	return len;
}


/*
 * kssl_record_vec_cpy_ptr - describe @len bytes of @cr, starting at byte
 * @offset, without copying the data.
 * @iov:  filled with iovecs pointing into the record's own buffers
 * @skbv: filled with the skb pointer of each of those entries (no
 *        reference is taken)
 *
 * No bounds are checked on @iov or @skbv: the caller must supply room for
 * every remaining iovec of @cr.
 *
 * Returns the number of entries filled in.
 */
int kssl_record_vec_cpy_ptr(kssl_record_t *cr,
		struct iovec *iov, struct sk_buff **skbv, 
		size_t len, size_t offset) 
{
	size_t vno;
	size_t chunk;
	size_t vecno = 0;
	size_t vecoffset = 0;
	size_t total_len = 0;

	/* size_t arguments are cast to match the %u conversions */
	KSSL_DEBUG(12, "kssl_record_vec_cpy_ptr len=%u offest=%u\n",
			(unsigned int)len, (unsigned int)offset);

	if (offset)
		kssl_record_vec_seek(cr, offset, &vecno, &vecoffset);

	for (vno = vecno; vno < cr->iov_len && total_len < len; vno++) {
		chunk = MIN((cr->iov+vno)->iov_len - vecoffset, 
				len-total_len);
		iov[vno - vecno].iov_base = (cr->iov + vno)->iov_base +
			vecoffset;
		iov[vno - vecno].iov_len = chunk;
		skbv[vno - vecno] = cr->skbv[vno];
		total_len += chunk;
		vecoffset = 0;	/* later iovecs are used from the start */
	}

	return vno - vecno;
}


/*
 * kssl_record_vec_cmp - compare @buf_len bytes of @cr, starting at byte
 * @offset, with @buf.
 *
 * Returns 0 only if all @buf_len bytes are present and equal.  Otherwise
 * returns non-zero: the memcmp() result for the first differing chunk, or
 * -1 if the record ends before @buf_len bytes could be compared.
 */
int kssl_record_vec_cmp(kssl_record_t *cr,
		u8 *buf, size_t buf_len, size_t offset) 
{
	size_t vno;
	size_t chunk;
	size_t vecno = 0;
	size_t vecoffset = 0;
	size_t len = 0;
	int status;

	if (offset)
		kssl_record_vec_seek(cr, offset, &vecno, &vecoffset);

	for (vno = vecno; vno < cr->iov_len && len < buf_len; vno++) {
		chunk = MIN((cr->iov+vno)->iov_len - vecoffset, buf_len-len);
		status = memcmp(buf+len, (cr->iov+vno)->iov_base + vecoffset, 
				chunk);
		if (status)
			return status;
		len += chunk;
		vecoffset = 0;
	}

	/* Running out of data is a mismatch, not a match: otherwise a short
	 * record would compare equal to any longer buffer (e.g. a MAC) */
	return (len < buf_len) ? -1 : 0;
}


/*
 * kssl_record_to_sg - build a scatterlist covering @size bytes of @cr,
 * starting at byte @offset.
 * @sg_r:   set to the scatterlist, allocated with kssl_kmalloc(); the
 *          caller releases it with kssl_kfree()
 * @nvec_r: set to the number of scatterlist entries used
 * @len_r:  set to the number of bytes covered; less than @size if the
 *          record ends first
 *
 * Returns 0 on success or -ENOMEM.
 */
int 
kssl_record_to_sg(kssl_record_t *cr, u32 offset, u32 size,
		struct scatterlist **sg_r, u32 *nvec_r, u32 *len_r)
{
	struct scatterlist *sg;
	/* kssl_record_vec_seek() writes through size_t pointers; u32
	 * variables here would be overrun on 64-bit kernels */
	size_t vecno;
	size_t vecoffset;
	size_t vno;
	u32 total_len = 0;
	u32 nsg = 0;

	KSSL_DEBUG(12, "kssl_record_to_sg: enter\n");

	kssl_record_vec_seek(cr, offset, &vecno, &vecoffset);

	sg = kssl_kmalloc((cr->iov_len - vecno) * sizeof(*sg), GFP_KERNEL);
	if (!sg) {
		KSSL_DEBUG(6, "kssl_record_to_sg: kssl_kmalloc\n");
		return -ENOMEM;
	}

	for (vno = vecno; vno < cr->iov_len && total_len < size; vno++) {
		void *base = (cr->iov + vno)->iov_base + vecoffset;

		sg[nsg].page = virt_to_page(base);
		sg[nsg].offset = ((long) base & ~PAGE_MASK);
		sg[nsg].length = (cr->iov + vno)->iov_len - vecoffset;
		if (sg[nsg].length > size - total_len)
			sg[nsg].length = size - total_len;
		vecoffset = 0;	/* only the first entry starts mid-iovec */
		total_len += sg[nsg].length;
		nsg++;
	}

	*sg_r = sg;
	*nvec_r = nsg;
	*len_r = total_len;

	return 0;
}


/*
 * kssl_record_to_msg - build a msghdr whose iovecs cover @size bytes of
 * @cr, starting at byte @offset, for sending with sock_sendmsg().
 * @msg_r:  set to the msghdr.  It and its msg_iov array are allocated
 *          with kssl_kmalloc(); the caller releases both with kssl_kfree()
 * @nvec_r: set to the number of iovecs used (also stored in msg_iovlen)
 * @len_r:  set to the number of bytes covered
 *
 * Returns 0 on success or -ENOMEM.
 */
static int 
kssl_record_to_msg(kssl_record_t *cr, u32 offset, u32 size,
		struct msghdr **msg_r, u32 *nvec_r, u32 *len_r)
{
	struct msghdr *msg;
	/* kssl_record_vec_seek() writes through size_t pointers */
	size_t vecno;
	size_t vecoffset;
	u32 total_len = 0;
	u32 nvec = 0;

	kssl_record_vec_seek(cr, offset, &vecno, &vecoffset);

	msg = kssl_kmalloc(sizeof(*msg), GFP_KERNEL);
	if (!msg) {
		KSSL_DEBUG(6, "kssl_record_to_msg: kssl_kmalloc msg\n");
		return -ENOMEM;
	}
	memset(msg, 0, sizeof(*msg));

	msg->msg_iov = kssl_kmalloc((cr->iov_len - vecno) * 
			sizeof(struct iovec), GFP_KERNEL);
	if (!msg->msg_iov) {
		kssl_kfree(msg);
		KSSL_DEBUG(6, "kssl_record_to_msg: kssl_kmalloc msg.msg_iov\n");
		return -ENOMEM;
	}

	/* Check the bounds before touching cr->iov, and stop once @size
	 * bytes are covered */
	while (vecno + nvec < cr->iov_len && total_len < size) {
		struct iovec *src = cr->iov + vecno + nvec;
		struct iovec *dst = msg->msg_iov + nvec;

		dst->iov_base = src->iov_base + vecoffset;
		dst->iov_len = src->iov_len - vecoffset;
		if (dst->iov_len > size - total_len)
			dst->iov_len = size - total_len;
		vecoffset = 0;	/* only the first entry starts mid-iovec */
		total_len += dst->iov_len;
		nvec++;
	}

	/* Expose only the entries that were actually filled in */
	msg->msg_iovlen = nvec;

	*msg_r = msg;
	*nvec_r = nvec;
	*len_r = total_len;

	return 0;
}


/*
 * kssl_record_process_in_ssl - process one record received on the SSL
 * side of a connection.
 *
 * If the connection is already closing on the SSL side, the record is
 * destroyed and 0 returned.  Otherwise the steps are chosen by cr->state:
 *  - KSSL_CONN_RECORD_IN_STARTED: header check, then decryption and MAC
 *    check.  kssl_record_decrypt() leaves the state at
 *    KSSL_CONN_RECORD_IN_UNENC.
 *  - KSSL_CONN_RECORD_IN_UNCMP: decompression.
 *  - KSSL_CONN_RECORD_IN_UNENC: the contained messages are passed to
 *    kssl_message_in(), which takes &cr and so may update cr.
 *
 * NOTE(review): after decryption the state is IN_UNENC, so the IN_UNCMP
 * step only runs for records that arrive here already in that state.
 * Confirm this ordering is intended.
 *
 * On failure, @alert is sent with kssl_alert_send(..., 1), which is
 * relied on to dispose of the record.  If sending the alert fails, the
 * SSL side is closed and the record destroyed here.
 *
 * Returns 0, or the negative status of the step that failed.
 */
static int 
kssl_record_process_in_ssl(kssl_record_t *cr) 
{
	int status = 0;
	alert_t alert = { al_warning, ad_close };

	KSSL_DEBUG(12, "kssl_record_process_in_ssl: enter\n");

	if (cr->conn->flag & KSSL_CONN_SSL_CLOSE) {
		kssl_record_destroy(cr);
		return 0;
	}

	if (cr->state == KSSL_CONN_RECORD_IN_STARTED) {
		status = kssl_record_verify_head(cr, &alert);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_process_in_ssl: "
					"kssl_record_verify_head\n");
			goto alert;
		}

		status = kssl_record_decrypt(cr, &alert);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_process_in_ssl: "
					"kssl_record_decrypt\n");
			goto alert;
		}
	}

	if (cr->state == KSSL_CONN_RECORD_IN_UNCMP) {
		status = kssl_record_uncompress(cr, &alert);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_process_in_ssl: "
					"kssl_record_uncompress\n");
			goto alert;
		}
	}

	if (cr->state == KSSL_CONN_RECORD_IN_UNENC) {
		/* Process the messages inside the record */
		status = kssl_message_in(&cr, &alert);
		if (status < 0) {
			KSSL_DEBUG(6, "kssl_record_process_in_ssl: "
					"kssl_message_in\n");
			goto alert;
		}
	}

	return 0;

alert:
	if (kssl_alert_send(cr, &alert, 1) < 0) {
		KSSL_DEBUG(6, "kssl_record_process_in_ssl: "
				"kssl_alert_send failed\n");
		kssl_conn_close(cr->conn, KSSL_CONN_SSL_CLOSE);
		kssl_record_destroy(cr);
	}
	/* Implied by kssl_alert_send(..., 1) */
	/* kssl_record_destroy(cr); */
	return status;
}


/*
 * kssl_record_process_in_pt - pass a plain-text record to
 * kssl_application_data_send_from_pt() for sending on the SSL side.
 *
 * The bytes from cr->offset to cr->total_len are handed over with a final
 * argument of 1.  As the original comments note, that call also disposes
 * of @cr.  If it fails, an alert is sent.  If even the alert cannot be
 * sent, the plain-text side is closed and @cr destroyed here.
 *
 * Returns the status of kssl_application_data_send_from_pt().
 */
static int 
kssl_record_process_in_pt(kssl_record_t *cr) 
{
	int status;
	alert_t alert = { al_warning, ad_close };

	status = kssl_application_data_send_from_pt(cr, 
			cr->total_len - cr->offset, cr->offset, 1);
	if (status < 0) {
		if (kssl_alert_send(cr, &alert, 1) < 0) {
			KSSL_DEBUG(6, "kssl_record_process_in: "
					"kssl_alert_send failed\n");
			kssl_conn_close(cr->conn, KSSL_CONN_PT_CLOSE);
			kssl_record_destroy(cr);
		}
		/* Implied by kssl_alert_send(..., 1) */
		/* kssl_record_destroy(cr); */
		return status;
	}

	/* Only report the send once it has actually succeeded */
	KSSL_NOTICE(3, "ISSL014: Send(ssl): send encrypted data to client\n");

	/* Implied by kssl_application_data_send_from_pt(..., 1) */
	/* kssl_record_destroy(cr); */
	return status;
}


/*
 * kssl_record_process_in - process a queued incoming record using the
 * handler for the side it arrived on: plain text or SSL.
 */
int 
kssl_record_process_in(kssl_record_t *cr) 
{
	KSSL_DEBUG(12, "kssl_record_process_in: total_len=%u offset=%u\n", 
			cr->total_len, cr->offset);

	if (cr->conn_flag != KSSL_CONN_SSL)
		return kssl_record_process_in_pt(cr);

	return kssl_record_process_in_ssl(cr);
}


/*
 * Send @msg (@size bytes) without blocking on the plain-text socket when
 * @flag is KSSL_CONN_PT, otherwise on the SSL socket.  Bytes actually sent
 * are added to the matching per-connection output counter.  Returns the
 * sock_sendmsg() result.
 */
static inline int 
kssl_conn_sendmsg(kssl_conn_t *conn, struct msghdr *msg, int size, u32 flag)
{
	int to_pt = (flag == KSSL_CONN_PT);
	struct socket *sock;
	int sent;

	msg->msg_flags |= MSG_DONTWAIT;
	sock = to_pt ? conn->pt_sock : conn->ssl_sock;

	sent = sock_sendmsg(sock, msg, size);

	if (sent > 0) {
		if (to_pt)
			conn->pt_out_bytes += sent;
		else
			conn->ssl_out_bytes += sent;
	}

	KSSL_DEBUG(4, "Send (%s):  %p: %d %d\n", 
			to_pt ? "pt" : "ssl", conn, size, sent);

	return sent;
}


/*
 * kssl_record_actually_send - compress and protect @cr (a no-op for steps
 * already done), then send its bytes from cr->offset to cr->total_len.
 *
 * Returns the number of bytes sent (possibly fewer than requested), or a
 * negative errno.  -EAGAIN, -EWOULDBLOCK and -EINTR are transient.
 */
static int 
kssl_record_actually_send(kssl_record_t *cr) 
{
	int status;
	struct msghdr *msg = NULL;
	u32 nvec; 
	u32 len;

	KSSL_DEBUG(12, "kssl_record_actually_send: %p offset=%d len=%d\n",
			cr, cr->offset, cr->total_len - cr->offset);

	status = kssl_record_compress(cr);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_record_actually_send: "
				"kssl_record_compress\n");
		goto leave;
	}

	status = kssl_record_encrypt(cr);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_record_actually_send: "
				"kssl_record_encrypt\n");
		goto leave;
	}

	status = kssl_record_to_msg(cr, cr->offset, cr->total_len - cr->offset,
			&msg, &nvec, &len);
	if (status < 0) {
		KSSL_DEBUG(6, "kssl_record_actually_send: "
				"kssl_record_to_msg: %d\n", status);
		goto leave;
	}

	status = kssl_conn_sendmsg(cr->conn, msg, len, cr->conn_flag);
	if (status < 0 && status != -EAGAIN && status != -EWOULDBLOCK && 
			status != -EINTR) {
		/* Hard error.  The test used to read "== -EINTR", so it
		 * matched only EINTR and never reported real failures. */
		KSSL_DEBUG(6, "kssl_record_actually_send: "
				"kssl_conn_sendmsg\n");
	}
	else if ((u32)status != len) {
		/* Transient error or partial send */
		KSSL_DEBUG(9, "kssl_record_actually_send: "
				"kssl_conn_sendmsg: %d != %d\n", status, len);
	}

leave:
	if (msg) {
		kssl_kfree(msg->msg_iov);
		kssl_kfree(msg);
	}
	return status;
}


/*
 * kssl_record_send - queue @cr on kssl_record_out_list for sending.
 *
 * cr->seq is reset to -1.  For SSL-side records the state is rebuilt from
 * the compression and encryption flags.  If a MAC will be needed, the
 * record takes the connection's next outbound sequence number.  Then the
 * outgoing message mask is updated.  Always returns 0.
 */
int 
kssl_record_send(kssl_record_t *cr)
{
	KSSL_DEBUG(12, "kssl_record_send enter: %p "
			"total_len=%u content_len=%u\n",
			cr, cr->total_len, cr->content_len);

	cr->seq = -1;

	if (cr->conn_flag == KSSL_CONN_SSL) {
		cr->state = 0;
		kssl_record_compress_flags(cr);
		kssl_record_encrypt_flags(cr);

		/* Consume a sequence number only if a MAC will be added */
		if (cr->state & KSSL_CONN_RECORD_OUT_DIGEST)
			cr->seq = cr->conn->conn_state.out_seq++;

		kssl_message_mask_update_out(cr);
	}

	list_move_tail(&cr->list, &kssl_record_out_list);
	return 0;
}


/*
 * kssl_record_process_out - transmit one record from the out queue.
 * @cr: record to send; cr->conn_flag selects the plain-text or SSL side.
 *
 * Outcomes:
 *  - A transient send error (-EAGAIN, -EWOULDBLOCK, -EINTR) is returned
 *    as-is and @cr is left untouched for a retry.
 *  - On a short write, cr->offset is advanced past the bytes already
 *    written, @cr is kept, and -EAGAIN is returned.
 *  - Otherwise @cr is destroyed and the status is returned. A negative
 *    status also closes the side of the connection named by
 *    cr->conn_flag.
 */
int 
kssl_record_process_out(kssl_record_t *cr) 
{
	int rc = 0;

	/* The plain-text socket is opened on first use. */
	if (cr->conn_flag == KSSL_CONN_PT && !cr->conn->pt_sock)
		rc = kssl_conn_open_pt(cr->conn);

	if (rc < 0) {
		KSSL_DEBUG(6, "kssl_record_process_out: "
				"kssl_conn_open_pt\n");
	} else {
		rc = kssl_record_actually_send(cr);

		/* Retryable failure: keep the record as it is. */
		if (rc == -EAGAIN || rc == -EWOULDBLOCK || rc == -EINTR)
			return rc;

		/* Partial write: remember progress, ask to be called again. */
		if (rc > 0 && rc < cr->total_len - cr->offset) {
			cr->offset += rc;
			return -EAGAIN;
		}

		/*
		 * Perform a pending full close once fewer than 3 users remain.
		 * NOTE(review): the meaning of the threshold 3 is not visible
		 * here -- confirm against the conn reference counting.
		 */
		if ((cr->conn->flag & KSSL_CONN_ALL_CLOSE) &&
				cr->conn->users < 3)
			kssl_conn_close(cr->conn, KSSL_CONN_ALL_CLOSE);
	}

	if (rc < 0) {
		if (cr->conn_flag == KSSL_CONN_PT)
			kssl_conn_close(cr->conn, KSSL_CONN_PT_CLOSE);
		else
			kssl_conn_close(cr->conn, KSSL_CONN_SSL_CLOSE);
	}
	kssl_record_destroy(cr);
	return rc; 
}


/*
 * kssl_record_head_set - fill in and serialise a record's TLS header.
 * @cr:   record whose header is set; cr->iov must already have a buffer.
 * @type: content type to store.
 * @len:  length to store in the header.
 *
 * The protocol version is copied from the connection state. The header
 * is then written into the first iovec's buffer with tls_head_to_buf().
 */
void 
kssl_record_head_set(kssl_record_t *cr, content_type_t type,
		size_t len)
{
	cr->record.head.type = type;
	cr->record.head.len = len;
	cr->record.head.version.major = cr->conn->conn_state.version.major;
	cr->record.head.version.minor = cr->conn->conn_state.version.minor;

	KSSL_DEBUG(12, "kssl_record_head_set: %02x %02x %02x %04x\n",
			cr->record.head.type,
			cr->record.head.version.major,
			cr->record.head.version.minor,
			cr->record.head.len);

	/* Encode the in-memory header into the start of the record buffer. */
	tls_head_to_buf(&cr->record.head, cr->iov->iov_base);
}


/*
 * __kssl_record_build_send - fill an allocated record and queue it.
 * @cr:         record whose buffer is already sized.
 * @write_func: callback that writes the record contents; receives @data.
 * @data:       opaque argument for @write_func.
 *
 * Returns cr->total_len once the record has been handed to
 * kssl_record_send(). Otherwise it returns the negative value from
 * @write_func or kssl_record_send().
 */
static inline int 
__kssl_record_build_send(kssl_record_t *cr,
		record_write_func_t write_func, void *data)
{
	int rc;

	rc = write_func(cr, data);
	if (rc < 0) {
		KSSL_DEBUG(6, "__kssl_record_build_send: "
				"write_func\n");
	} else {
		rc = kssl_record_send(cr);
		if (rc < 0) {
			KSSL_DEBUG(6, "__kssl_record_build_send: "
					"kssl_record_send\n");
		} else {
			rc = cr->total_len;
		}
	}

	return rc;
}


/*
 * kssl_record_build_send_new - build and queue a fresh record modelled on cr.
 * @cr:          template record. Its conn, conn_flag, TLS header and
 *               SSLv2 header are copied into the new record; @cr itself
 *               is not modified.
 * @write_func:  callback that fills the new record's buffer; gets @data.
 * @content_len: stored as the new record's content_len.
 * @total_len:   size of the single buffer allocated for the new record.
 * @data:        opaque argument for @write_func.
 *
 * Returns the new record's total_len once it has been queued through
 * kssl_record_send(). On failure a negative error code is returned and
 * the new record is released; nothing is leaked.
 */
static int 
kssl_record_build_send_new(kssl_record_t *cr,
		record_write_func_t write_func,
		size_t content_len, size_t total_len, void *data)
{
	kssl_record_t *new_cr;
	int status;

	KSSL_DEBUG(12, "kssl_record_build_send_new: %u\n",
			(unsigned int)total_len);

	new_cr = kssl_record_create(cr->conn, cr->conn_flag);
	if (!new_cr) {
		KSSL_DEBUG(6, "kssl_record_build_send_new: "
				"kssl_record_create\n");
		return -ENOMEM;
	}

	memcpy(&(new_cr->record.head), &(cr->record.head), sizeof(tls_head_t));
	memcpy(&(new_cr->ssl2_head), &(cr->ssl2_head), sizeof(ssl2_head_t));

	new_cr->iov = kssl_kmalloc(sizeof(struct iovec), GFP_KERNEL);
	if (!new_cr->iov) {
		KSSL_DEBUG(6, "kssl_record_build_send_new: "
				"kssl_kmalloc 1\n");
		kssl_record_destroy(new_cr);
		return -ENOMEM;
	}
	/* Only advertise the iovec once it actually exists. */
	new_cr->iov_len = new_cr->iov_alloc = 1; 

	new_cr->total_len = new_cr->iov->iov_len = total_len;
	new_cr->content_len = content_len;
	new_cr->iov->iov_base = kssl_kmalloc(new_cr->total_len, GFP_KERNEL);
	if (!new_cr->iov->iov_base) {
		KSSL_DEBUG(6, "kssl_record_build_send_new: "
				"kssl_kmalloc 2\n");
		/*
		 * Free the iovec array here and detach it, so that
		 * kssl_record_destroy() never sees a dangling pointer.
		 * Otherwise it could free the array a second time.
		 */
		kssl_kfree(new_cr->iov);
		new_cr->iov = NULL;
		new_cr->iov_len = new_cr->iov_alloc = 0;
		kssl_record_destroy(new_cr);
		return -ENOMEM;
	}

	status = __kssl_record_build_send(new_cr, write_func, data);
	if (status < 0) {
		/*
		 * kssl_record_send() currently always returns 0, so a failure
		 * here comes from write_func(). That happens before the record
		 * is queued, so nobody else references new_cr: release it
		 * instead of leaking it.
		 */
		kssl_record_destroy(new_cr);
	}
	return status;
}


/*
 * kssl_record_tail_len - bytes appended after a record's content when sent.
 * @cr:          record whose connection's active outgoing parameters apply.
 * @content_len: record length the tail is computed for. The subtraction
 *               of TLS_HEAD_NLEN suggests it includes the record header
 *               (TODO confirm with callers).
 *
 * Returns:
 *  - 0 when no outgoing security parameters are active;
 *  - hash_size when the bulk cipher is bca_null;
 *  - otherwise hash_size + 1 plus the padding needed to make
 *    (content_len - TLS_HEAD_NLEN + 1 + hash_size) a multiple of
 *    block_size. The extra byte is presumably the block-cipher
 *    padding-length byte.
 */
size_t 
kssl_record_tail_len(kssl_record_t *cr, size_t content_len)
{
	int fill;
	int unpadded;

	if (cr->conn->sec_param_out_act == NULL)
		return 0;

	/* No bulk cipher: only the MAC follows the content. */
	if (cr->conn->sec_param_out_act->bulk_cipher_algorithm == bca_null)
		return cr->conn->sec_param_out_act->hash_size;

	unpadded = content_len - TLS_HEAD_NLEN + 1 + 
		cr->conn->sec_param_out_act->hash_size;

	/* Round up to the next block boundary; exact multiples need none. */
	fill = unpadded % cr->conn->sec_param_out_act->block_size;
	if (fill != 0)
		fill = cr->conn->sec_param_out_act->block_size - fill;

	return fill + cr->conn->sec_param_out_act->hash_size + 1;
}


/*
 * kssl_record_build_send - write a record with @write_func and queue it.
 * @cr:          record to reuse or use as a template.
 * @write_func:  callback that fills the record contents; gets @data.
 * @content_len: content length; kssl_record_tail_len() is added to size
 *               the buffer.
 * @data:        opaque argument for @write_func.
 * @reuse:       non-zero lets @cr itself be rewritten and resent.
 *
 * With @reuse, @cr is rewritten in place when it already has an iovec
 * whose first buffer is large enough. Otherwise a new record is built
 * from @cr. When @reuse is set and the new record was queued
 * successfully, @cr is destroyed.
 *
 * Returns the queued record's total_len, or a negative error code.
 */
int 
kssl_record_build_send(kssl_record_t *cr, record_write_func_t write_func, 
		size_t content_len, void *data, int reuse)
{
	int needed;
	int in_place;
	int rc;

	needed = content_len + kssl_record_tail_len(cr, content_len);
	in_place = reuse && cr && cr->iov_len > 0 &&
		cr->iov->iov_len >= needed;

	if (!in_place) {
		rc = kssl_record_build_send_new(cr, write_func, content_len,
				needed, data);
		if (rc >= 0 && reuse)
			kssl_record_destroy(cr);
		return rc;
	}

	/* Rewrite cr in place: set the new lengths and rewind the offset. */
	cr->content_len = content_len;
	cr->iov->iov_len = cr->total_len = needed;
	cr->offset = 0;
	return __kssl_record_build_send(cr, write_func, data);
}


/*
 * kssl_record_pt_send - queue part of a record for the plain-text side.
 * @cr:     record whose buffer is forwarded.
 * @len:    maximum number of bytes to forward after the new offset.
 * @offset: amount added to cr->offset before queueing.
 *
 * The record is retargeted at KSSL_CONN_PT and cr->offset is advanced by
 * @offset. total_len is clipped so that at most @len bytes follow the
 * offset, and the record is handed to kssl_record_send(). If that call
 * fails, offset, total_len and conn_flag are restored to their previous
 * values.
 *
 * Returns the result of kssl_record_send().
 */
int 
kssl_record_pt_send(kssl_record_t *cr, size_t len, size_t offset)
{
	size_t prev_offset = cr->offset;
	size_t prev_total_len = cr->total_len;
	u32 prev_conn_flag = cr->conn_flag;
	int rc;

	KSSL_DEBUG(12, "kssl_record_pt_send enter: len=%d offset=%d\n",
			len, cr->offset);

	cr->offset += offset;
	if (len + cr->offset < cr->total_len)
		cr->total_len = len + cr->offset;
	cr->conn_flag = KSSL_CONN_PT;

	rc = kssl_record_send(cr);
	if (rc >= 0)
		return rc;

	/* Queueing failed: hand the record back unchanged. */
	cr->offset = prev_offset;
	cr->total_len = prev_total_len;
	cr->conn_flag = prev_conn_flag;
	return rc;
}

