/**********************************************************************
 * conn.c                                                   August 2005
 *
 * TCPS: TCP Splicing Module
 * This code is based on ip_vs_conn.c of ipvs-1.0.9.
 *
 * IPVS         An implementation of the IP virtual server support for the
 *              LINUX operating system.  IPVS is now implemented as a module
 *              over the Netfilter framework. IPVS can be used to build a
 *              high-performance and highly available server based on a
 *              cluster of servers.
 *
 * Authors:     Wensong Zhang <wensong@linuxvirtualserver.org>
 *              Peter Kese <peter.kese@ijs.si>
 *              Julian Anastasov <ja@ssi.bg>
 *
 * The IPVS code for kernel 2.2 was done by Wensong Zhang and Peter Kese,
 * with changes/fixes from Julian Anastasov, Lars Marowsky-Bree, Horms
 * and others. Much of the code here is taken from the IP MASQ code of kernel 2.2.
 *
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/config.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/errno.h>
#include <linux/vmalloc.h>
#include <linux/init.h>
#include <linux/fs.h>
#include <linux/sysctl.h>
#include <linux/proc_fs.h>
#include <linux/timer.h>
#include <linux/swap.h>
#include <linux/proc_fs.h>
#include <linux/file.h>
#include <linux/skbuff.h>               /* for struct sk_buff */
#include <linux/ip.h>                   /* for struct iphdr */
#include <net/tcp.h>                    /* for csum_tcpudp_magic */
#include <net/udp.h>
#include <net/icmp.h>                   /* for icmp_send */

#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>

#include <net/ip.h>
#include <net/sock.h>

#include <asm/uaccess.h>

#ifdef TCPS_INTEGRATED_SOURCE
#include <net/ipv4/tcps/tcps.h>
#else
#include "tcps.h"
#endif

/*
 *  Lock striping for the connection hash table: rather than one lock per
 *  bucket, CT_LOCKARRAY_SIZE rwlocks are shared by all buckets, and the
 *  bucket with hash key "key" is guarded by lock (key & CT_LOCKARRAY_MASK)
 *  (see the ct_*lock* helpers below).
 */
#define CT_LOCKARRAY_BITS  4
#define CT_LOCKARRAY_SIZE  (1<<CT_LOCKARRAY_BITS)
#define CT_LOCKARRAY_MASK  (CT_LOCKARRAY_SIZE-1)

/*
 *  Connection hash table used for the packet-path lookups of tcps:
 *  TCPS_CONN_TAB_SIZE list heads, vmalloc()ed by tcps_conn_init().
 *  Entries are chained through tcps_conn.c_list and bucketed by
 *  tcps_conn_hashkey(caddr, cport).
 */
static struct list_head *tcps_conn_tab;

/*  SLAB cache from which struct tcps_conn entries are allocated */
static kmem_cache_t *tcps_conn_cachep;

/* Number of live entries: incremented in tcps_conn_new(),
 * decremented in tcps_conn_free(). */
static atomic_t tcps_conn_count;

/* One rwlock padded out to SMP_CACHE_BYTES alignment -- presumably so
 * that CPUs taking different stripe locks do not share a cache line. */
struct tcps_aligned_lock {
	rwlock_t	l;
} __attribute__((__aligned__(SMP_CACHE_BYTES)));

/* Stripe locks for tcps_conn_tab; initialised in tcps_conn_init().
 * NOTE(review): this symbol has external linkage; confirm no other file
 * references it before making it static. */
struct tcps_aligned_lock
__tcps_conntbl_lock_array[CT_LOCKARRAY_SIZE] __cacheline_aligned;

static inline void ct_read_lock(unsigned key)
{
	read_lock(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_read_unlock(unsigned key)
{
	read_unlock(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_write_lock(unsigned key)
{
	write_lock(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_write_unlock(unsigned key)
{
	write_unlock(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_read_lock_bh(unsigned key)
{
	read_lock_bh(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_read_unlock_bh(unsigned key)
{
	read_unlock_bh(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_write_lock_bh(unsigned key)
{
	write_lock_bh(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}

static inline void ct_write_unlock_bh(unsigned key)
{
	write_unlock_bh(&__tcps_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
}


/*
 *	Bucket index for a connection entry, derived from the client
 *	address and port (both in network byte order).  The host-order
 *	address is folded onto itself by TCPS_CONN_TAB_BITS, mixed with the
 *	port, and masked down to the table size.
 */
static inline unsigned
tcps_conn_hashkey(__u32 addr, __u16 port)
{
	unsigned host_addr = ntohl(addr);
	unsigned folded = host_addr ^ (host_addr >> TCPS_CONN_TAB_BITS);

	return (folded ^ ntohs(port)) & TCPS_CONN_TAB_MASK;
}

/*
 * tcps_state_name: TCP-like label for a tcps state bitmask, used in the
 * /proc/net/tcps_conn listing.
 *
 * Any FIN flag wins ("TIME_WAIT"); otherwise any SENT flag or the SPLICED
 * flag gives "ESTABLISHED"; everything else is "SYN_SENT".
 * XXX the mapping is coarse and may need more work.
 */
static const char *
tcps_state_name(int state)
{
	const char *name = "SYN_SENT";

	if (state & (TCPS_CONN_S_RS_FIN | TCPS_CONN_S_CL_FIN))
		name = "TIME_WAIT";
	else if (state & (TCPS_CONN_S_RS_SENT | TCPS_CONN_S_CL_SENT |
			  TCPS_CONN_S_SPLICED))
		name = "ESTABLISHED";

	return name;
}

/*
 * tcps_conn_getinfo: get_info handler for /proc/net/tcps_conn.
 *
 * Every record, the header included, is padded to exactly 128 bytes
 * ("%-127s\n").  This lets a file offset be mapped to a record without
 * formatting the earlier ones: "pos" is the end offset of the current
 * record, and records ending at or before @offset are skipped.
 *
 * @buffer: output buffer supplied by the proc layer
 * @start:  set to the position in @buffer that corresponds to @offset
 * @offset: file offset requested by the reader
 * @length: maximum number of bytes wanted
 *
 * Returns the number of bytes available at *start (0 at end of file).
 */
static int
tcps_conn_getinfo(char *buffer, char **start, off_t offset, int length)
{
	off_t pos = 0;
	int idx, len = 0;
	char temp[70];
	struct tcps_conn *tc;
	struct list_head *l, *e;
	long remaining;

	/* The header is record 0 and occupies bytes [0, 128). */
	pos = 128;
	if (pos > offset) {
		len += sprintf(buffer+len, "%-127s\n",
			       "Pro FromIP   FPrt ToIP     TPrt DestIP   DPrt State       Expires");
	}

	for (idx = 0; idx < TCPS_CONN_TAB_SIZE; idx++) {
		/*
		 * The bucket lock is only needed for the walk below.  We
		 * run on behalf of a user-space read, so bottom halves must
		 * be disabled while it is held.
		 */
		ct_read_lock_bh(idx);

		l = &tcps_conn_tab[idx];
		for (e = l->next; e != l; e = e->next) {
			tc = list_entry(e, struct tcps_conn, c_list);
			pos += 128;
			if (pos <= offset)
				continue;

			/*
			 * Time left before the timer fires.  If the timer
			 * has already fired, or has not been armed yet
			 * (tcps_conn_new() zeroes the entry and does not arm
			 * it), the unsigned difference may wrap to a huge
			 * value.  On 64-bit that value needs ~17 digits and
			 * overran temp[].  Report 0 instead.
			 */
			remaining = (long)(tc->timer.expires - jiffies);
			if (remaining < 0)
				remaining = 0;

			/* Bounded, so temp[] can never be overrun. */
			snprintf(temp, sizeof(temp),
				 "TCP %08X %04X %08X %04X %08X %04X %-11s %7lu",
				 ntohl(tc->caddr), ntohs(tc->cport),
				 ntohl(tc->vaddr), ntohs(tc->vport),
				 ntohl(tc->raddr), ntohs(tc->rport),
				 tcps_state_name(tc->state),
				 (unsigned long)remaining / HZ);
			len += sprintf(buffer + len, "%-127s\n", temp);
			if (pos >= offset + length) {
				ct_read_unlock_bh(idx);
				goto done;
			}
		}
		ct_read_unlock_bh(idx);
	}

  done:
	/* Start of the data the reader asked for */
	*start = buffer + len - (pos - offset);
	len = pos - offset;
	if (len > length)
		len = length;
	if (len < 0)
		len = 0;
	return len;
}


/*
 * tcps_conn_hash: link @tc into its tcps_conn_tab bucket.
 *
 * The bucket is selected by tcps_conn_hashkey(caddr, cport).  Membership
 * in the table holds one reference on @tc.  refcnt is incremented here,
 * and tcps_conn_unhash() decrements it.
 *
 * Returns 1 when @tc was linked, or 0 when it was already hashed (an
 * error is logged in that case).
 */
int
tcps_conn_hash(struct tcps_conn *tc)
{
	unsigned hash;
	int ret = 0;

	/* Hash by client address and port */
	hash = tcps_conn_hashkey(tc->caddr, tc->cport);

	ct_write_lock_bh(hash);
	/*
	 * Test and set TCPS_CONN_F_HASHED while holding the bucket lock.
	 * The check and the insertion are then one step with respect to
	 * any other hash/unhash of @tc.  Those calls take the same lock as
	 * long as caddr/cport do not change while @tc is hashed.
	 */
	if (!(tc->flags & TCPS_CONN_F_HASHED)) {
		list_add(&tc->c_list, &tcps_conn_tab[hash]);
		tc->flags |= TCPS_CONN_F_HASHED;
		atomic_inc(&tc->refcnt);
		ret = 1;
	}
	ct_write_unlock_bh(hash);

	if (!ret) {
		TCPS_ERR("tcps_conn_hash(): request for already hashed, "
			 "called from %p\n", __builtin_return_address(0));
	}

	return ret;
}

/*
 * tcps_conn_unhash: unlink @tc from its tcps_conn_tab bucket.
 *
 * Drops the reference that tcps_conn_hash() took on behalf of the table.
 *
 * Returns 1 when @tc was unlinked, or 0 when it was not hashed (an error
 * is logged in that case).
 */
int
tcps_conn_unhash(struct tcps_conn *tc)
{
	unsigned hash;
	int ret = 0;

	/* Same bucket (and therefore same lock) that tcps_conn_hash() used */
	hash = tcps_conn_hashkey(tc->caddr, tc->cport);

	ct_write_lock_bh(hash);
	/*
	 * Test and clear TCPS_CONN_F_HASHED while holding the bucket lock,
	 * so the check and the list removal cannot race with another
	 * hash/unhash of @tc.
	 */
	if (tc->flags & TCPS_CONN_F_HASHED) {
		list_del(&tc->c_list);
		tc->flags &= ~TCPS_CONN_F_HASHED;
		atomic_dec(&tc->refcnt);
		ret = 1;
	}
	ct_write_unlock_bh(hash);

	if (!ret) {
		TCPS_ERR("tcps_conn_unhash(): request for unhash flagged, "
			 "called from %p\n", __builtin_return_address(0));
	}

	return ret;
}

/*
 * __tcps_conn_get_c_v: find a live entry by (caddr, cport, vaddr, vport).
 *
 * Only the (caddr, cport) bucket is searched, under its read lock.
 * Entries flagged TCPS_CONN_S_EXPIRED are ignored.  On a match, refcnt is
 * incremented before the lock is dropped and the entry is returned;
 * otherwise NULL is returned.
 */
static inline struct tcps_conn *
__tcps_conn_get_c_v(__u32 caddr, __u16 cport, __u32 vaddr, __u16 vport)
{
	unsigned int key = tcps_conn_hashkey(caddr, cport);
	struct list_head *head = &tcps_conn_tab[key];
	struct list_head *node;
	struct tcps_conn *hit = NULL;

	ct_read_lock_bh(key);
	for (node = head->next; hit == NULL && node != head;
	     node = node->next) {
		struct tcps_conn *cand =
			list_entry(node, struct tcps_conn, c_list);

		if (cand->caddr != caddr || cand->cport != cport ||
		    cand->vaddr != vaddr || cand->vport != vport)
			continue;
		/* expired entries may still be hashed; never hand them out */
		if (cand->state & TCPS_CONN_S_EXPIRED)
			continue;

		atomic_inc(&cand->refcnt);
		hit = cand;
	}
	ct_read_unlock_bh(key);

	return hit;
}

/*
 * tcps_conn_in_get: look up a live entry by (caddr, cport, vaddr, vport).
 * Returns it with an extra reference taken, or NULL if there is no match.
 */
struct tcps_conn *
tcps_conn_in_get(__u32 caddr, __u16 cport, __u32 vaddr, __u16 vport)
{
	struct tcps_conn *tc;

	tc = __tcps_conn_get_c_v(caddr, cport, vaddr, vport);
	return tc;
}

/*
 * tcps_conn_out_rs_get: find a live entry by (caddr, cport, laddr, lport).
 *
 * Only the (caddr, cport) bucket is searched, under its read lock.
 * Entries flagged TCPS_CONN_S_EXPIRED are ignored.  On a match, refcnt is
 * incremented before the lock is dropped and the entry is returned;
 * otherwise NULL is returned.
 */
struct tcps_conn *
tcps_conn_out_rs_get(__u32 caddr, __u16 cport, __u32 laddr, __u16 lport)
{
	unsigned int key = tcps_conn_hashkey(caddr, cport);
	struct list_head *head = &tcps_conn_tab[key];
	struct list_head *node;
	struct tcps_conn *hit = NULL;

	ct_read_lock_bh(key);
	for (node = head->next; hit == NULL && node != head;
	     node = node->next) {
		struct tcps_conn *cand =
			list_entry(node, struct tcps_conn, c_list);

		if (cand->caddr != caddr || cand->cport != cport ||
		    cand->laddr != laddr || cand->lport != lport)
			continue;
		/* expired entries may still be hashed; never hand them out */
		if (cand->state & TCPS_CONN_S_EXPIRED)
			continue;

		atomic_inc(&cand->refcnt);
		hit = cand;
	}
	ct_read_unlock_bh(key);

	return hit;
}

/*
 * tcps_conn_out_cl_get: same lookup as tcps_conn_in_get(), keyed by
 * (caddr, cport, vaddr, vport).  Returns the entry with an extra
 * reference taken, or NULL if there is no match.
 */
struct tcps_conn *
tcps_conn_out_cl_get(__u32 caddr, __u16 cport, __u32 vaddr, __u16 vport)
{
	struct tcps_conn *tc;

	tc = __tcps_conn_get_c_v(caddr, cport, vaddr, vport);
	return tc;
}


/*
 * tcps_conn_preroute_get: find a live entry by (caddr, cport, raddr, rport).
 *
 * Only the (caddr, cport) bucket is searched, under its read lock.
 * Entries flagged TCPS_CONN_S_EXPIRED are ignored.  On a match, refcnt is
 * incremented before the lock is dropped and the entry is returned;
 * otherwise NULL is returned.
 */
struct tcps_conn *
tcps_conn_preroute_get(__u32 caddr, __u16 cport, __u32 raddr, __u16 rport)
{
	unsigned int key = tcps_conn_hashkey(caddr, cport);
	struct list_head *head = &tcps_conn_tab[key];
	struct list_head *node;
	struct tcps_conn *hit = NULL;

	ct_read_lock_bh(key);
	for (node = head->next; hit == NULL && node != head;
	     node = node->next) {
		struct tcps_conn *cand =
			list_entry(node, struct tcps_conn, c_list);

		if (cand->caddr != caddr || cand->cport != cport ||
		    cand->raddr != raddr || cand->rport != rport)
			continue;
		/* expired entries may still be hashed; never hand them out */
		if (cand->state & TCPS_CONN_S_EXPIRED)
			continue;

		atomic_inc(&cand->refcnt);
		hit = cand;
	}
	ct_read_unlock_bh(key);

	return hit;
}

/*
 * tcps_conn_set_state: store @state in @tc and choose the matching timeout.
 *
 * If either FIN flag (TCPS_CONN_S_RS_FIN / TCPS_CONN_S_CL_FIN) is set in
 * @state, tc->timeout becomes TCPS_CONN_TIMEOUT_FIN; otherwise it becomes
 * TCPS_CONN_TIMEOUT.  The timer itself is not touched here.  The new
 * timeout is applied the next time the timer is re-armed, e.g. by
 * tcps_conn_put().
 */
void
tcps_conn_set_state(struct tcps_conn *tc, u32 state)
{
	const u32 fin_mask = TCPS_CONN_S_RS_FIN | TCPS_CONN_S_CL_FIN;

	tc->state   = state;
	tc->timeout = (state & fin_mask) ? TCPS_CONN_TIMEOUT_FIN
					 : TCPS_CONN_TIMEOUT;
}

/*
 * tcps_conn_put: drop one reference on @tc and restart its timer so that
 * it fires tc->timeout jiffies from now.
 */
void
tcps_conn_put(struct tcps_conn *tc)
{
	/*
	 * Re-arm before dropping the reference: once it is gone,
	 * tcps_conn_expire() may free @tc if it holds the only remaining
	 * reference.
	 */
	mod_timer(&tc->timer, jiffies + tc->timeout);
	atomic_dec(&tc->refcnt);
}

/*
 * tcps_conn_expire: timer callback for a connection entry.
 *
 * @data: the struct tcps_conn pointer stored in timer.data by
 *        tcps_conn_new().
 *
 * Marks the entry TCPS_CONN_S_EXPIRED, which makes the lookup functions
 * skip it.  If nobody else holds a reference, the entry is freed.
 * Otherwise it is put back into the hash table and the timer is re-armed
 * via tcps_conn_put() with TCPS_CONN_TIMEOUT_FIN, so the check is
 * retried later.
 */
static void
tcps_conn_expire(unsigned long data)
{
	struct tcps_conn *tc = (struct tcps_conn *)data;

	/* timeout used when the timer is re-armed by tcps_conn_put() below */
	tc->timeout = TCPS_CONN_TIMEOUT_FIN;
	/*
	 *	hey, I'm using it: take a temporary reference for the
	 *	duration of this handler
	 */
	atomic_inc(&tc->refcnt);

	/*
	 *	unhash it if it is hashed in the conn table; this drops the
	 *	reference held by the table
	 */
	tcps_conn_unhash(tc);

	tc->state |= TCPS_CONN_S_EXPIRED;

	/*
	 *	refcnt==1 implies I'm the only one referrer: only the
	 *	temporary reference taken above is left, so free the entry
	 */
	if (atomic_read(&tc->refcnt) == 1) {
		/* make sure that there is no timer on it now */
		if (timer_pending(&tc->timer))
			del_timer(&tc->timer);

		TCPS_DBG("tcps_conn_expire: expire %p\n", tc);
		tcps_conn_free(tc);
		return;
	}

	/* still referenced elsewhere: hash it back to the table
	 * (re-taking the table's reference) */
	tcps_conn_hash(tc);

	TCPS_DBG("tcps_conn_expire: delayed %p: refcnt-1=%d\n",
		 tc, atomic_read(&tc->refcnt)-1);

	/* drop the temporary reference and re-arm the timer */
	tcps_conn_put(tc);
}

/*
 * tcps_conn_expire_now: make @tc expire as soon as possible.
 *
 * Flags @tc TCPS_CONN_S_EXPIRED so the lookup functions stop returning
 * it, and schedules its timer for the current jiffy.  It then drops one
 * reference, presumably the caller's (from a *_get() lookup or
 * tcps_conn_new()); confirm at the call sites.
 *
 * Note that tcps_conn_expire() overwrites tc->timeout with
 * TCPS_CONN_TIMEOUT_FIN when it runs, so the 0 set here is not what a
 * later re-arm uses.
 */
void
tcps_conn_expire_now(struct tcps_conn *tc)
{
	tc->timeout = 0;
	tc->state |= TCPS_CONN_S_EXPIRED;
	mod_timer(&tc->timer, jiffies);
	atomic_dec(&tc->refcnt);
}

/*
 *  Create a new connection entry and hash it into the tcps_conn_tab.
 *
 *  The address/port pairs are stored verbatim in tc->laddr/lport,
 *  tc->raddr/rport and tc->caddr/cport.  tc->vaddr/vport are left zero;
 *  NOTE(review): __tcps_conn_get_c_v() matches on them, so callers
 *  presumably fill them in -- confirm.
 *
 *  The timer is initialised but not armed; tcps_conn_put() arms it.
 *
 *  Returns the entry holding one reference for the caller (plus the one
 *  owned by the hash table), or NULL if the slab allocation fails.
 */
struct tcps_conn *
tcps_conn_new(__u32 laddr, __u16 lport, __u32 raddr, __u16 rport,
	      __u32 caddr, __u16 cport)
{
	struct tcps_conn *tc;

	tc = kmem_cache_alloc(tcps_conn_cachep, GFP_ATOMIC);
	if (tc == NULL) {
		/* newline-terminated so it is not merged with the next message */
		printk(KERN_INFO "No memory available\n");
		return NULL;
	}

	memset(tc, 0, sizeof(*tc));
	INIT_LIST_HEAD(&tc->c_list);

	init_timer(&tc->timer);
	tc->timer.data     = (unsigned long)tc;
	tc->timer.function = tcps_conn_expire;

	tc->caddr	   = caddr;
	tc->cport	   = cport;
	tc->laddr	   = laddr;
	tc->lport	   = lport;
	tc->raddr          = raddr;
	tc->rport          = rport;
	tc->lock           = SPIN_LOCK_UNLOCKED;

	atomic_inc(&tcps_conn_count);

	/* Sets both tc->state and the matching tc->timeout */
	tcps_conn_set_state(tc, TCPS_CONN_S_NEW);

	/*
	 * Set the entry is referenced by the current thread before hashing
	 * it in the table; tcps_conn_hash() adds the table's own reference.
	 */
	atomic_set(&tc->refcnt, 1);

	tcps_conn_hash(tc);

	return tc;
}

/*
 * tcps_conn_free: detach an entry's sockets and return it to the slab.
 *
 * Under tc->lock, each attached socket (csk first, then rsk) is passed to
 * tcps_reset_sock(), its reference is dropped with sock_put(), and the
 * field is cleared.  The entry is then freed and tcps_conn_count is
 * decremented.  @tc must not be used afterwards.
 */
void
tcps_conn_free(struct tcps_conn *tc)
{
	struct sock **slot[2];
	int i;

	slot[0] = &tc->csk;
	slot[1] = &tc->rsk;

	spin_lock_bh(&tc->lock);
	for (i = 0; i < 2; i++) {
		struct sock *sk = *slot[i];

		if (sk == NULL)
			continue;
		tcps_reset_sock(sk);
		sock_put(sk);
		*slot[i] = NULL;
	}
	spin_unlock_bh(&tc->lock);

	kmem_cache_free(tcps_conn_cachep, tc);
	atomic_dec(&tcps_conn_count);
}

/*
 * tcps_conn_init: set up the connection hash table, the tcps_conn slab
 * cache, the stripe locks and the /proc/net/tcps_conn entry.
 *
 * Returns 0 on success.  Returns -ENOMEM if the table or the slab cache
 * cannot be allocated; anything already allocated is released first.
 */
int
tcps_conn_init(void)
{
	int idx;

	/* Allocate the connection hash table (list heads set up below) */
	tcps_conn_tab = vmalloc(TCPS_CONN_TAB_SIZE*sizeof(struct list_head));
	if (!tcps_conn_tab)
		return -ENOMEM;

	/* Allocate tcps_conn slab cache */
	tcps_conn_cachep = kmem_cache_create("tcps_conn",
					      sizeof(struct tcps_conn), 0,
					      SLAB_HWCACHE_ALIGN, NULL, NULL);
	if (!tcps_conn_cachep) {
		vfree(tcps_conn_tab);
		tcps_conn_tab = NULL;	/* do not leave a dangling pointer */
		return -ENOMEM;
	}

	TCPS_INFO("Connection hash table configured "
		  "(size=%d, memory=%ldKbytes)\n",
		  TCPS_CONN_TAB_SIZE,
		  (long)(TCPS_CONN_TAB_SIZE*sizeof(struct list_head))/1024);
	/* sizeof yields size_t, wider than int on 64-bit: cast to match %d */
	TCPS_DBG("Each connection entry needs %d bytes at least\n",
		 (int)sizeof(struct tcps_conn));

	for (idx = 0; idx < TCPS_CONN_TAB_SIZE; idx++) {
		INIT_LIST_HEAD(&tcps_conn_tab[idx]);
	}

	for (idx = 0; idx < CT_LOCKARRAY_SIZE; idx++)  {
		__tcps_conntbl_lock_array[idx].l = RW_LOCK_UNLOCKED;
	}

	/* Finish initialising module state before publishing the proc entry */
	atomic_set(&tcps_conn_count, 0);

	/*
	 * Best effort: the entry only reports the table contents, so a
	 * failure to create it is not treated as fatal.
	 */
	proc_net_create("tcps_conn", 0, tcps_conn_getinfo);

	return 0;
}

/*
 * tcps_conn_fini: undo tcps_conn_init() in reverse order.  The
 * /proc/net/tcps_conn entry goes first, presumably so no reader can walk
 * the table while it is being released.  Then the slab cache and the
 * hash table are freed.
 *
 * NOTE(review): no connection entries are flushed and no timers are
 * stopped here.  This presumably relies on the caller having released
 * every entry beforehand (kmem_cache_destroy() expects an empty cache);
 * confirm at the call site.
 */
void
tcps_conn_fini(void)
{
	proc_net_remove("tcps_conn");
	kmem_cache_destroy(tcps_conn_cachep);
	vfree(tcps_conn_tab);
}
