/* Copyright 2013,2014 Akira Ohta (akohta001@gmail.com)
    This file is part of ntch.

    The ntch is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    The ntch is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with ntch.  If not, see <http://www.gnu.org/licenses/>.
    
*/
#include <assert.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <wchar.h>

#include "error.h"
#include "utils/nt_std_t.h"
#include "utils/nt_pthread.h"

/* Tag written into nt_pthread_handle_t.chk_sum by nt_pthread_alloc() and
   asserted by every function that receives a handle -- a sanity check
   against bad pointers (only active when NDEBUG is not defined). */
#define NT_PTHREAD_CHK_SUM (1278428)

/* Global state of the single worker pool, set up by nt_pthread_lib_init()
   and torn down by nt_pthread_lib_finish().
   NOTE(review): the tag name begins with an underscore at file scope
   (reserved identifier), and g_sem has external linkage; consider renaming
   the tag and making g_sem static if nothing else references it. */
struct _semaphore_for_thread_pool{
	sem_t sem;	/* posted once per queued job, and once per worker on shutdown */
	pthread_mutex_t mutex;	/* guards h_que and h_que_result */
	nt_queue_handle h_que;	/* jobs waiting to be run by a worker */
	nt_queue_handle h_que_result;	/* finished jobs that have a result_func */
	int thread_size;	/* number of worker threads */
	int queue_size;	/* nt_pthread_put_que() refuses jobs once sem's value reaches this */
	pthread_t *pthreads;	/* worker thread ids, thread_size entries */
}g_sem;

/* Serializes nt_pthread_increment_int()/nt_pthread_decrement_int(), which
   maintain the jobs' reference counts. */
static pthread_mutex_t g_counter_lock_mutex = PTHREAD_MUTEX_INITIALIZER;

/* Internal representation of a job.  Callers only ever see a pointer to the
   embedded 'handle' member; because it is the first member, that pointer is
   cast straight back to nt_pthread_tp throughout this file. */
typedef struct tag_nt_pthread_t *nt_pthread_tp;
typedef struct tag_nt_pthread_t{
	nt_pthread_handle_t handle;	/* must remain the first member (see above) */
	int ref_count;	/* changed only via nt_pthread_increment_int/decrement_int */
	nt_pthread_fn func;	/* job body, called on a worker thread with param */
	void *param;	/* argument passed to func */
	nt_pthread_result_fn result_func;	/* optional; when set the finished job goes to h_que_result */
	nt_pthread_result_t result;	/* value returned by func, later passed to result_func */
}nt_pthread_t;


/* Internal helpers, defined below. */
static void* thread_pool_func(void *data);
static BOOL nt_pthread_exec(nt_pthread_handle h_pthread);
static nt_pthread_result_t nt_pthread_call_result_func(nt_pthread_handle h_pthread);
static BOOL nt_pthread_has_result(nt_pthread_handle h_pthread);

/*
 * Increment *val while holding g_counter_lock_mutex.
 * Returns the value *val holds immediately after the increment.
 */
int nt_pthread_increment_int(int *val)
{
	int updated;

	pthread_mutex_lock(&g_counter_lock_mutex);
	updated = ++(*val);
	pthread_mutex_unlock(&g_counter_lock_mutex);
	return updated;
}

/*
 * Decrement *val while holding g_counter_lock_mutex.
 * Returns the value *val holds immediately after the decrement.
 */
int nt_pthread_decrement_int(int *val)
{
	int updated;

	pthread_mutex_lock(&g_counter_lock_mutex);
	updated = --(*val);
	pthread_mutex_unlock(&g_counter_lock_mutex);
	return updated;
}

/*
 * Initialize the global worker pool.
 *
 * thread_pool_size: number of worker threads to start (must be > 0).
 * queue_size:       bound on pending jobs checked by nt_pthread_put_que()
 *                   (must be >= 0).
 * error_msg:        on failure, set to a static description of the error.
 *
 * Returns TRUE on success.  On failure, everything acquired so far is
 * released and FALSE is returned.  This covers the semaphore, the mutex,
 * both queues, the thread array, and any workers already started, which
 * are woken and joined.
 */
BOOL nt_pthread_lib_init(int thread_pool_size, int queue_size,
		const char **error_msg)
{
	pthread_t *pthreads = NULL;
	int i, created = 0;
	
	assert(thread_pool_size > 0);
	assert(queue_size >= 0);
	
	g_sem.h_que = NULL;
	g_sem.h_que_result = NULL;
	g_sem.pthreads = NULL;
	
	if(0 != sem_init(&g_sem.sem, 0, 0)){
		*error_msg = "Semaphore initialization failed";
		return FALSE;
	}
	g_sem.queue_size = queue_size;
	g_sem.thread_size = thread_pool_size;
	
	if(0 != pthread_mutex_init(&(g_sem.mutex), NULL)){
		sem_destroy(&g_sem.sem);
		g_sem.thread_size = 0;
		*error_msg = "Mutex initialization failed";
		return FALSE;
	}
	
	g_sem.h_que = nt_queue_alloc();
	g_sem.h_que_result = nt_queue_alloc();
	if(!g_sem.h_que || !g_sem.h_que_result){
		*error_msg = "Memory allocation failed";
		goto ERROR_QUEUES;
	}
	pthreads = calloc(thread_pool_size, sizeof *pthreads);
	if(!pthreads){
		*error_msg = "Memory allocation failed";
		goto ERROR_QUEUES;
	}
	for(created = 0; created < thread_pool_size; created++){
		if(0 != pthread_create(&pthreads[created], NULL, 
				thread_pool_func, NULL)){
			*error_msg = "Thread creation failed";
			goto ERROR_THREADS;
		}
	}
	g_sem.pthreads = pthreads;
	return TRUE;

ERROR_THREADS:
	/* Stop the workers that did start, the same way nt_pthread_lib_finish()
	   does: one wake-up each.  With no jobs queued, each one dequeues NULL
	   and returns (assumes nt_queue_shift yields NULL on an empty queue, as
	   the shutdown path already relies on). */
	for(i = 0; i < created; i++)
		sem_post(&g_sem.sem);
	for(i = 0; i < created; i++)
		pthread_join(pthreads[i], NULL);
	free(pthreads);
ERROR_QUEUES:
	if(g_sem.h_que_result)
		nt_queue_free(g_sem.h_que_result, NULL);
	if(g_sem.h_que)
		nt_queue_free(g_sem.h_que, NULL);
	g_sem.h_que = NULL;
	g_sem.h_que_result = NULL;
	g_sem.thread_size = 0;
	pthread_mutex_destroy(&(g_sem.mutex));
	sem_destroy(&g_sem.sem);
	return FALSE;
}


/*
 * Shut down the worker pool started by nt_pthread_lib_init().
 *
 * Posts one extra wake-up per worker and then joins every worker.
 * Jobs already queued ahead of those wake-ups are still run.  After that,
 * each worker presumably dequeues NULL from the empty queue and returns.
 * Finally the thread array, mutex, semaphore and both queues are released.
 *
 * NOTE(review): jobs still waiting in h_que_result are not released here
 * (nt_queue_free gets NULL as its second argument) -- confirm whether
 * their references are meant to leak at shutdown.
 */
void nt_pthread_lib_finish()
{
	int idx;
	int workers = g_sem.thread_size;

	for(idx = 0; idx < workers; idx++)
		sem_post(&g_sem.sem);

	for(idx = 0; idx < workers; idx++)
		pthread_join(g_sem.pthreads[idx], NULL);

	free(g_sem.pthreads);
	pthread_mutex_destroy(&(g_sem.mutex));
	sem_destroy(&g_sem.sem);

	nt_queue_free(g_sem.h_que, NULL);
	nt_queue_free(g_sem.h_que_result, NULL);
}

/*
 * Submit a job to the worker pool.
 *
 * Under g_sem.mutex, the job is rejected if sem's current value (the
 * wake-ups not yet consumed by workers) is already >= queue_size.
 * Otherwise it is appended to h_que, gains one reference that the
 * worker side will release, and a worker is woken via sem_post.
 *
 * Returns TRUE only if every step, including the final unlock, succeeded.
 */
BOOL nt_pthread_put_que(nt_pthread_handle h_pthread)
{
	int pending;
	int unlock_err;
	BOOL queued = FALSE;

	if(0 != pthread_mutex_lock(&(g_sem.mutex)))
		return FALSE;

	if(0 == sem_getvalue(&g_sem.sem, &pending) &&
			pending < g_sem.queue_size &&
			nt_queue_push(g_sem.h_que, h_pthread)){
		nt_pthread_add_ref(h_pthread);
		if(0 == sem_post(&g_sem.sem))
			queued = TRUE;
	}

	unlock_err = pthread_mutex_unlock(&(g_sem.mutex));
	return (queued && 0 == unlock_err) ? TRUE : FALSE;
}

static void* thread_pool_func(void *data)
{
	nt_pthread_handle h_pthread;
	int err;
	
	while(1){
		if(0 != sem_wait(&g_sem.sem)){
			return NULL;
		}
		err = pthread_mutex_lock(&(g_sem.mutex));
		if(err != 0)
			return NULL;
			
		h_pthread = (nt_pthread_handle)nt_queue_shift(g_sem.h_que);
		
		err = pthread_mutex_unlock(&(g_sem.mutex));
		if(err != 0)
			return NULL;
			
		if(!h_pthread)
			return NULL;
		
		if(!nt_pthread_exec(h_pthread))
			return NULL;
		
		if(!nt_pthread_has_result(h_pthread)){
			nt_pthread_release_ref(h_pthread);
			continue;
		}
		
		err = pthread_mutex_lock(&(g_sem.mutex));
		if(err != 0)
			return NULL;
			
		if(!nt_queue_push(g_sem.h_que_result, h_pthread))
			return NULL;
		
		err = pthread_mutex_unlock(&(g_sem.mutex));
		if(err != 0)
			return NULL;
			
	}
	return NULL;
}

/*
 * Deliver one finished job, if any, to its result callback.
 *
 * Takes the oldest job from g_sem.h_que_result under g_sem.mutex and calls
 * its result_func on the calling thread.  It then releases the reference
 * that the queue held and returns whatever result_func returned.
 *
 * Returns a result with code NT_PTHREAD_RESULT_NONE and data NULL when
 * nothing is available, or when locking or unlocking the mutex fails.
 */
nt_pthread_result_t nt_pthread_get_result_from_que()
{
	nt_pthread_result_t result;
	nt_pthread_handle h_done = NULL;

	result.code = NT_PTHREAD_RESULT_NONE;
	result.data = NULL;

	if(0 != pthread_mutex_lock(&(g_sem.mutex)))
		return result;
	if(NULL != g_sem.h_que_result)
		h_done = (nt_pthread_handle)nt_queue_shift(g_sem.h_que_result);
	if(0 != pthread_mutex_unlock(&(g_sem.mutex)) || NULL == h_done)
		return result;

	result = nt_pthread_call_result_func(h_done);
	nt_pthread_release_ref(h_done);
	return result;
}

/*
 * Report whether a job was created with a result callback, i.e. whether
 * its completion must be routed through g_sem.h_que_result.
 */
static BOOL nt_pthread_has_result(nt_pthread_handle h_pthread)
{
	nt_pthread_tp job = (nt_pthread_tp)h_pthread;

	assert(h_pthread);
	assert(h_pthread->chk_sum == NT_PTHREAD_CHK_SUM);
	assert(job->ref_count > 0);
	if(job->result_func)
		return TRUE;
	return FALSE;
}

/*
 * Run a job's func with its param and store the return value in the job's
 * result field.  Always returns TRUE.
 */
static BOOL nt_pthread_exec(nt_pthread_handle h_pthread)
{
	nt_pthread_tp job = (nt_pthread_tp)h_pthread;
	nt_pthread_fn body;

	assert(h_pthread);
	assert(h_pthread->chk_sum == NT_PTHREAD_CHK_SUM);
	assert(job->ref_count > 0);
	body = job->func;
	job->result = body(job->param);
	return TRUE;
}
/*
 * Pass a finished job's stored result to its result_func and return what
 * that callback returns.  The job must have a result_func (asserted).
 */
static nt_pthread_result_t nt_pthread_call_result_func(
			nt_pthread_handle h_pthread)
{
	nt_pthread_tp job = (nt_pthread_tp)h_pthread;
	nt_pthread_result_fn callback;

	assert(h_pthread);
	assert(h_pthread->chk_sum == NT_PTHREAD_CHK_SUM);
	assert(job->ref_count > 0);
	callback = job->result_func;
	assert(callback);
	return callback(job->result);
}

/*
 * Create a job that runs func(param) on a pool worker.  If result_func is
 * non-NULL, it is later called with func's result by
 * nt_pthread_get_result_from_que().
 *
 * The returned handle starts with one reference owned by the caller
 * (drop it with nt_pthread_release_ref()).  Returns NULL if allocation
 * fails.
 */
nt_pthread_handle nt_pthread_alloc(
			nt_pthread_fn func, void *param, 
			nt_pthread_result_fn result_func)
{
	nt_pthread_tp job = malloc(sizeof *job);

	if(NULL == job)
		return NULL;

	job->func = func;
	job->param = param;
	job->result_func = result_func;
	job->result.code = NT_PTHREAD_RESULT_NONE;
	job->result.data = NULL;
	job->handle.chk_sum = NT_PTHREAD_CHK_SUM;
	job->ref_count = 1;

	/* 'handle' is the first member, so this pointer casts back to job. */
	return &job->handle;
}

/*
 * Take an additional reference on a live job.
 * Returns the reference count after the increment.
 */
int nt_pthread_add_ref(nt_pthread_handle h_pthread)
{
	nt_pthread_tp job = (nt_pthread_tp)h_pthread;

	assert(h_pthread);
	assert(h_pthread->chk_sum == NT_PTHREAD_CHK_SUM);
	assert(job->ref_count > 0);
	return nt_pthread_increment_int(&job->ref_count);
}

/*
 * Drop one reference on a job, freeing it when the count reaches zero.
 * Returns the remaining reference count (0 means the job was freed and
 * h_pthread must not be used again).
 */
int nt_pthread_release_ref(nt_pthread_handle h_pthread)
{
	nt_pthread_tp job = (nt_pthread_tp)h_pthread;
	int remaining;

	assert(h_pthread);
	assert(h_pthread->chk_sum == NT_PTHREAD_CHK_SUM);
	assert(job->ref_count > 0);

	remaining = nt_pthread_decrement_int(&job->ref_count);
	if(remaining == 0)
		free(job);	/* same address as h_pthread: handle is the first member */
	return remaining;
}


