/**********************************************************************
 * message_mask.c                                           August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file is based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/version.h>

#include "record.h"
#include "message_mask.h"
#include "log.h"

#include "types/change_cipher_spec_t.h"
#include "types/handshake_t.h"
#include "types/alert_t.h"


/*
 * kssl_message_mask_check_hanshake - test whether the handshake message
 * carried by a record is currently expected on its connection.
 * @cr: record whose cr->msg holds a parsed handshake message
 *
 * Returns the KSSL_CONN_M_H_* bit of cr->conn->msg_mask that corresponds
 * to the handshake msg_type (non-zero if that message is expected, 0 if
 * not). An unrecognised msg_type is logged and yields 0.
 */
static u32 kssl_message_mask_check_hanshake(const kssl_record_t *cr)
{
	u32 mask = cr->conn->msg_mask;

	/* Is this message expected ? */
	switch (cr->msg->data.handshake.msg_type) {
	case ht_hello_request:
		return mask & KSSL_CONN_M_H_HELLO_REQUEST;
	case ht_client_hello:
		return mask & KSSL_CONN_M_H_CLIENT_HELLO;
	case ht_server_hello:
		return mask & KSSL_CONN_M_H_SERVER_HELLO;
	case ht_certificate:
		return mask & KSSL_CONN_M_H_CERTIFICATE;
	case ht_server_key_exchange:
		return mask & KSSL_CONN_M_H_SERVER_KEY_EXCHANGE;
	case ht_certificate_request:
		return mask & KSSL_CONN_M_H_CERTIFICATE_REQUEST;
	case ht_server_hello_done:
		return mask & KSSL_CONN_M_H_SERVER_HELLO_DONE;
	case ht_certificate_verify:
		return mask & KSSL_CONN_M_H_CERTIFICATE_VERIFY;
	case ht_client_key_exchange:
		return mask & KSSL_CONN_M_H_CLIENT_KEY_EXCHANGE;
	case ht_finished:
		return mask & KSSL_CONN_M_H_FINISHED;
	default:
		break;
	}

	/*
	 * Unknown handshake type: report it under this helper's own name
	 * (not the caller's) and include the value so it can be diagnosed.
	 */
	KSSL_DEBUG(3, "kssl_message_mask_check_hanshake: "
			"Invalid handshake.msg_type %d\n",
			cr->msg->data.handshake.msg_type);
	return 0;
}


/*
 * kssl_message_mask_check - test whether a record is currently acceptable.
 * @cr: record to test; its connection's msg_mask holds the set of
 *      KSSL_CONN_M_* bits describing which messages are expected
 *
 * For change-cipher-spec, alert and application-data records, returns the
 * corresponding KSSL_CONN_M_* bit of msg_mask (non-zero means expected).
 * For handshake and SSLv2 records, returns 1 only when KSSL_CONN_M_HANDSHAKE
 * is set and the specific handshake message is also expected, else 0.
 * Unknown record types are logged and yield 0.
 */
u32 kssl_message_mask_check(const kssl_record_t *cr)
{
	u32 mask = cr->conn->msg_mask;

	KSSL_DEBUG(12, "kssl_message_mask_check: "
			"head.type=%d msg_mask=0x%x\n",
			cr->record.head.type, mask);

	switch (cr->record.head.type) {
	case ct_change_cipher_spec:
		return mask & KSSL_CONN_M_CHANGE_CIPHER_SPEC;
	case ct_alert:
		return mask & KSSL_CONN_M_ALERT;
	case ct_application_data:
		return mask & KSSL_CONN_M_APPLICATION_DATA;
	case ct_handshake:
	case ct_ssl2:
		/* Handshakes as a class must be allowed before the
		 * per-message check is even consulted. */
		if (!(mask & KSSL_CONN_M_HANDSHAKE))
			return 0;
		return kssl_message_mask_check_hanshake(cr) ? 1 : 0;
	default:
		KSSL_DEBUG(3, "kssl_message_mask_check: "
				"Invalid head.type\n");
		return 0;
	}
}


/*
 * kssl_message_mask_update_hanshake_out - recompute the expected-message
 * mask after a handshake message has been sent.
 * @cr: record holding the handshake message that was just sent
 *
 * ServerHello or Certificate     -> ALERT only
 * ServerHelloDone                -> ALERT | ClientKeyExchange
 * Finished                       -> ALERT | APPLICATION_DATA when the
 *                                   active inbound and outbound security
 *                                   parameters compare equal, otherwise
 *                                   ALERT | CHANGE_CIPHER_SPEC
 * anything else                  -> logged, mask cleared
 */
static void kssl_message_mask_update_hanshake_out(kssl_record_t *cr)
{
	u32 next;

	switch (cr->msg->data.handshake.msg_type) {
	case ht_server_hello:
	case ht_certificate:
		next = KSSL_CONN_M_ALERT;
		break;
	case ht_server_hello_done:
		next = KSSL_CONN_M_ALERT | KSSL_CONN_M_H_CLIENT_KEY_EXCHANGE;
		break;
	case ht_finished:
		next = KSSL_CONN_M_ALERT;
		if (cr->conn->sec_param_in_act == cr->conn->sec_param_out_act)
			next |= KSSL_CONN_M_APPLICATION_DATA;
		else
			next |= KSSL_CONN_M_CHANGE_CIPHER_SPEC;
		break;
	case ht_hello_request:
	case ht_client_hello:
	case ht_certificate_request:
	case ht_certificate_verify:
	case ht_server_key_exchange:
	case ht_client_key_exchange:
	default:
		/* Not a message this side is expected to send. */
		KSSL_DEBUG(3, "kssl_message_mask_update_hanshake_out"
				": Invalid handshake.msg_type 0x%x\n",
				cr->msg->data.handshake.msg_type);
		next = 0;
		break;
	}

	cr->conn->msg_mask = next;
}


/*
 * kssl_message_mask_update_alert - adjust the expected-message mask after
 * an alert record has been sent or received.
 * @cr: record holding the alert
 *
 * An alert whose level is not al_warning, or whose description is
 * ad_close, clears the connection's msg_mask. With a zero mask,
 * kssl_message_mask_check() returns 0 for every record type, so nothing
 * further is accepted. Any other warning-level alert leaves the mask
 * unchanged.
 *
 * Called from both the inbound and the outbound update paths.
 */
static void kssl_message_mask_update_alert(kssl_record_t *cr)
{
	if (cr->msg->data.alert.level != al_warning ||
			cr->msg->data.alert.description == ad_close) {
		cr->conn->msg_mask = 0;
	}
}


/*
 * kssl_message_mask_update_out - recompute which messages the connection
 * will accept next, after a record has been sent.
 * @cr: the record that was just sent
 *
 * ChangeCipherSpec sent     -> ALERT | Finished
 * alert sent                -> delegated to the alert rule
 * handshake / SSLv2 sent    -> delegated to the outbound handshake rule
 * application data sent     -> ALERT | APPLICATION_DATA
 * unknown record type       -> logged, mask cleared
 *
 * The resulting mask is traced at debug level 12.
 */
void kssl_message_mask_update_out(kssl_record_t *cr)
{
	if (cr->record.head.type == ct_change_cipher_spec) {
		cr->conn->msg_mask = KSSL_CONN_M_ALERT |
			KSSL_CONN_M_H_FINISHED;
	} else if (cr->record.head.type == ct_alert) {
		kssl_message_mask_update_alert(cr);
	} else if (cr->record.head.type == ct_handshake ||
			cr->record.head.type == ct_ssl2) {
		kssl_message_mask_update_hanshake_out(cr);
	} else if (cr->record.head.type == ct_application_data) {
		cr->conn->msg_mask = KSSL_CONN_M_ALERT |
			KSSL_CONN_M_APPLICATION_DATA;
	} else {
		KSSL_DEBUG(3, "kssl_message_mask_update_out: "
				"Invalid head.type: %d\n",
				cr->record.head.type);
		cr->conn->msg_mask = 0;
	}

	KSSL_DEBUG(12, "kssl_message_mask_update_out: "
			"head.type=%d msg_mask=0x%x\n",
			cr->record.head.type, cr->conn->msg_mask);
}


/*
 * kssl_message_mask_update_hanshake_in - recompute the expected-message
 * mask after a handshake message has been received.
 * @cr: record holding the handshake message that was just received
 *
 * ClientHello or Finished   -> ALERT, plus APPLICATION_DATA when the active
 *                              inbound and outbound security parameters
 *                              compare equal
 * ClientKeyExchange         -> ALERT | CHANGE_CIPHER_SPEC
 * anything else             -> logged, mask cleared
 */
static void kssl_message_mask_update_hanshake_in(kssl_record_t *cr)
{
	u32 next;

	switch (cr->msg->data.handshake.msg_type) {
	case ht_client_hello:
	case ht_finished:
		next = KSSL_CONN_M_ALERT;
		if (cr->conn->sec_param_in_act == cr->conn->sec_param_out_act)
			next |= KSSL_CONN_M_APPLICATION_DATA;
		break;
	case ht_client_key_exchange:
		next = KSSL_CONN_M_ALERT | KSSL_CONN_M_CHANGE_CIPHER_SPEC;
		break;
	case ht_hello_request:
	case ht_server_hello:
	case ht_certificate:
	case ht_certificate_request:
	case ht_server_hello_done:
	case ht_certificate_verify:
	case ht_server_key_exchange:
	default:
		/* Not a message this side is expected to receive. */
		KSSL_DEBUG(3,
				"kssl_message_mask_update_hanshake_in: "
				"Invalid handshake.msg_type %d\n",
				cr->msg->data.handshake.msg_type);
		next = 0;
		break;
	}

	cr->conn->msg_mask = next;
}


/*
 * kssl_message_update_mask_in - recompute which messages the connection
 * will accept next, after a record has been received.
 * @cr: the record that was just received
 *
 * ChangeCipherSpec received   -> ALERT | Finished
 * alert received              -> delegated to the alert rule
 * handshake / SSLv2 received  -> delegated to the inbound handshake rule
 * application data received   -> ALERT | APPLICATION_DATA
 * unknown record type         -> logged (with the offending type), mask
 *                                cleared
 */
void kssl_message_update_mask_in(kssl_record_t *cr)
{
	/* what messages are now expected */
	switch (cr->record.head.type) {
	case ct_change_cipher_spec:
		cr->conn->msg_mask = KSSL_CONN_M_ALERT |
			KSSL_CONN_M_H_FINISHED;
		break;
	case ct_alert:
		kssl_message_mask_update_alert(cr);
		break;
	case ct_handshake:
	case ct_ssl2:
		kssl_message_mask_update_hanshake_in(cr);
		break;
	case ct_application_data:
		cr->conn->msg_mask = KSSL_CONN_M_ALERT |
			KSSL_CONN_M_APPLICATION_DATA;
		break;
	default:
		/* Report the offending type, as the outbound path does. */
		KSSL_DEBUG(3, "kssl_message_update_mask_in: "
				"Invalid head.type: %d\n",
				cr->record.head.type);
		cr->conn->msg_mask = 0;
		break;
	}
}

