#include "sw_cuda.h"

// For Score Matrix (BLOSUM50 etc)
// 2D texture reference holding the substitution score matrix. It is bound to
// a cudaArray by setScoreMatrixToTexture() and sampled with tex2D().
texture<MATRIX_ELEMENT, 2, cudaReadModeElementType> texScoreMatrix;

// 1D texture reference for the subject sequences; bound to linear device
// memory by setSubjectSeqsToTexture().
texture<char, 1, cudaReadModeElementType> texSubjectSeqs;
// 1D texture reference for the query sequence. No binding for it is visible
// in this file -- presumably it is used by sw_kernel.cu; TODO confirm.
texture<char , 1, cudaReadModeElementType> texQuerySeq;

// Gap penalties in constant memory, uploaded by setGap():
//   Gap_init   = gap open + gap extend
//   Gap_extend = gap extend
__device__ __constant__ int Gap_init;
__device__ __constant__ int  Gap_extend;

#include "sw_kernel.cu"

// Returns (ceil(v / alignment) + offset) * alignment.
// For an integer T this rounds v up to the next multiple of `alignment` and
// then adds `offset` further alignment units.
//   v         : value to align
//   alignment : unit size (must be non-zero)
//   offset    : number of extra units appended (default 0)
template<class T>
T getAlignment(T v, int alignment, int offset=0){
  const T n_units = (T)((v + alignment - 1) / alignment);
  return (n_units + offset) * alignment;
}

// Scatters one batch of scores into the scratch tail of `result`
// (slots [num_answer, num_answer + n_scores)) and re-sorts the first
// num_answer + n_scores entries with cmpSWResult.
//   h_scores       : host copy of the batch scores, one int per sequence
//   h_seq_id_index : sequence-id lookup table, indexed by sequence position
//   first_seq      : position of the first sequence of this batch
//   n_scores       : number of valid scores in the batch
static void mergeBatchIntoResult(SWResult *result,
                                 int num_answer,
                                 const int *h_scores,
                                 const int *h_seq_id_index,
                                 int first_seq,
                                 int n_scores){
  for(int j=0; j<n_scores; j++){
    result[num_answer + j].score = h_scores[j];
    result[num_answer + j].index = h_seq_id_index[first_seq + j];
  }
  qsort(result, num_answer + n_scores, sizeof(SWResult), cmpSWResult);
}

// Scores the query against n_seq sequences on the current device and returns
// the results ordered by cmpSWResult.
//
//   d_seq_data            : device buffer with the sequence data
//   d_seq_data_index      : device array of sequence offsets; every kernel
//                           call is given a full batch of entries, so the
//                           array must be padded to a whole number of batches
//   n_seq                 : number of sequences
//   d_query_seq, query_len: device query buffer and the length passed to the
//                           kernel
//   block_per_kernel_call : grid size of every kernel launch
//   thread_per_block, warp_size :
//                           block is launched as (warp_size, thread_per_block
//                           / warp_size). NOTE(review): the kernel is
//                           instantiated with the compile-time WARP_SIZE and
//                           SEQ_PER_BLOCK, so these presumably have to match
//                           -- confirm against sw_kernel.cu.
//   h_seq_id_index        : host lookup table from sequence position to id
//   num_answer            : number of leading result entries kept across
//                           batches
//
// Returns a pinned host array (cudaHostAlloc) of
// num_answer + seq_per_kernel_call entries; the first num_answer entries are
// the answer, the rest is scratch space. The caller owns the buffer and
// should release it with cudaFreeHost().
//
// The work is pipelined over three streams:
//   1. kernel call (calc the score)
//   2. memory copy (device to host)
//   3. sort (reduce the result to number of "num_answer")
SWResult *getSWResult(char *d_seq_data,
		      int *d_seq_data_index,
		      int n_seq,
		      char *d_query_seq,
		      int query_len,
		      int block_per_kernel_call,
		      int thread_per_block,
		      int warp_size, 
		      int *h_seq_id_index,
		      int num_answer){

  // Number of sequences processed per kernel call.
  const int seq_per_block = (thread_per_block/warp_size);
  const int seq_per_kernel_call = block_per_kernel_call * seq_per_block;

  // size_t: this product can exceed INT_MAX for long queries.
  const size_t temp_buf_size
    = sizeof(TempScore) * query_len * seq_per_kernel_call;
  const size_t batch_bytes = sizeof(int) * seq_per_kernel_call;

  // NOTE(review): this single scratch buffer is handed to every launch, and
  // consecutive launches go to different streams -- confirm that the kernels
  // cannot overlap, or give each stream its own buffer.
  TempScore *d_temp_HF = NULL;
  cutilSafeCall( cudaMalloc((void**)&d_temp_HF, temp_buf_size) );

  // One device score buffer per stream.
  int *d_result[3] = {NULL, NULL, NULL};
  for(int i=0; i<3; i++){
    cutilSafeCall( cudaMalloc((void**)&d_result[i], batch_bytes) );
    cutilSafeCall( cudaMemset(d_result[i], 0, batch_bytes) );
  }

  // Double-buffered pinned host copies of the scores. Only one batch of ints
  // is ever copied into or read from each buffer.
  int *h_H_result[2] = {NULL, NULL};
  for(int i=0; i<2; i++){
    cutilSafeCall( cudaHostAlloc((void**)&h_H_result[i], batch_bytes,
                                 cudaHostAllocDefault) );
  }

  // Host-side area for the alignment scores.
  SWResult *result = NULL;
  cutilSafeCall( cudaHostAlloc((void**)&result,
                               sizeof(SWResult) * (num_answer + seq_per_kernel_call),
                               cudaHostAllocDefault) );
  memset(result, 0, sizeof(SWResult) * (num_answer + seq_per_kernel_call));

  //////////////////////////////////////
  cudaStream_t stream[3];
  for(int i=0; i<3; i++){
    cutilSafeCall( cudaStreamCreate(stream + i) );
  }

  //////////////////////////////////////

  const int n_iteration = getAlignment(n_seq, seq_per_kernel_call) / seq_per_kernel_call;

  // Number of blocks. Every launch uses the full grid, including the last,
  // possibly partial, batch.
  int cuda_block_num;
  for(int i=0; i<n_iteration; i++){
    // First sequence number handled by this iteration.
    const int head_seq = i * seq_per_kernel_call;

    cuda_block_num = block_per_kernel_call;

    dim3 grid(cuda_block_num);
    dim3 block(warp_size, seq_per_block);

#ifdef DEBUG
    printf("%d th iteration of %d : process %d block\n", i, n_iteration, cuda_block_num);
#endif

    const int stream_id_kernel = i%3;

    // always launch kernel
    kernelSmithWaterman<WARP_SIZE, SEQ_PER_BLOCK>
      <<<grid, block, 0, stream[stream_id_kernel]>>>(d_seq_data,
                                                     d_seq_data_index + head_seq,
                                                     d_query_seq,
                                                     query_len,
                                                     d_temp_HF,
                                                     d_result[stream_id_kernel]);
    // A bad launch configuration is only reported through cudaGetLastError().
    cutilSafeCall( cudaGetLastError() );

    // Copy the scores of the previous batch from device to host.
    if(i >= 1) {
      const int stream_id_memcpy = (i-1)%3;
      cutilSafeCall( cudaStreamSynchronize(stream[stream_id_memcpy]) );
      cutilSafeCall( cudaMemcpyAsync(h_H_result[i%2],
                                     d_result[stream_id_memcpy],
                                     batch_bytes,
                                     cudaMemcpyDeviceToHost,
                                     stream[stream_id_memcpy]) );
    }

    // Gather the batch before that (its copy was queued one iteration ago)
    // and sort it into the result.
    if(i >= 2){
      const int stream_id_sort = (i-2)%3;
      cutilSafeCall( cudaStreamSynchronize(stream[stream_id_sort]) );
      mergeBatchIntoResult(result, num_answer,
                           h_H_result[(i-1)%2], h_seq_id_index,
                           (i-2)*seq_per_kernel_call, seq_per_kernel_call);
    }

  }

  // Drain the pipeline: wait for the last kernel and the last async copy.
  cutilSafeCall( cudaThreadSynchronize() );

  if(n_iteration > 0){
    // The last batch may be partial; only its valid scores are fetched.
    const int remainder = n_seq % seq_per_kernel_call;
    const int last_count = (remainder == 0) ? seq_per_kernel_call : remainder;

    cutilSafeCall( cudaMemcpy(h_H_result[n_iteration%2],
                              d_result[(n_iteration-1)%3],
                              sizeof(int) * last_count,
                              cudaMemcpyDeviceToHost) );

    // Second-to-last batch: copied inside the loop but not merged yet.
    // Its offset is a batch multiple; it used to be computed with a
    // hard-coded 8 sequences per block.
    if(n_iteration > 1){
      mergeBatchIntoResult(result, num_answer,
                           h_H_result[(n_iteration-1)%2], h_seq_id_index,
                           (n_iteration-2)*seq_per_kernel_call, seq_per_kernel_call);
    }

    // Last batch.
    mergeBatchIntoResult(result, num_answer,
                         h_H_result[n_iteration%2], h_seq_id_index,
                         (n_iteration-1)*seq_per_kernel_call, last_count);
  }

  //////////////////////////////////////

  for(int i=0; i<2; i++){
    cutilSafeCall( cudaFreeHost(h_H_result[i]) );
  }

  cutilSafeCall( cudaFree(d_temp_HF) );

  for(int i=0; i<3; i++){
    cutilSafeCall( cudaFree(d_result[i]) );
  }

  for(int i=0; i<3; i++){
    cutilSafeCall( cudaStreamDestroy(stream[i]) );
  }

  return result;
}

// qsort() comparator for ProcSeqArray: ascending order of total_len.
// Returns 1, -1 or 0 when a's total_len is greater than, less than or
// neither compared to b's.
int cmpProcSeqArray(const void *a, const void *b){
  const ProcSeqArray *lhs = (const ProcSeqArray*)a;
  const ProcSeqArray *rhs = (const ProcSeqArray*)b;

  return (lhs->total_len > rhs->total_len) - (lhs->total_len < rhs->total_len);
}


// qsort() comparator for SWResult: descending order of score, so the highest
// score sorts first. Returns -1, 1 or 0 when a's score is greater than, less
// than or neither compared to b's.
int cmpSWResult(const void *a, const void *b){
  const SWResult *lhs = (const SWResult*)a;
  const SWResult *rhs = (const SWResult*)b;

  return (rhs->score > lhs->score) - (rhs->score < lhs->score);
}

// Debug kernel: each thread fetches one element of texScoreMatrix and stores
// it at its global thread index, treating that index as a position in a
// 32-wide row-major grid (column = index % 32, row = index / 32).
// There is no bounds check: `result` must have room for one element per
// launched thread.
__global__
void kernelScoreMatrixTextureTest(MATRIX_ELEMENT *result){
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int col = gid % 32;
  const int row = gid / 32;

  result[gid] = tex2D(texScoreMatrix, col, row);
}

// Binds `size` bytes of linear device memory at d_score_index_block to the
// 1D texture reference texSubjectSeqs (clamped addressing, point sampling,
// unnormalized coordinates).
void setSubjectSeqsToTexture(char *d_score_index_block, int size){
  // Texture sampling parameters.
  texSubjectSeqs.normalized = false;
  texSubjectSeqs.filterMode = cudaFilterModePoint;
  texSubjectSeqs.addressMode[0] = cudaAddressModeClamp;
  texSubjectSeqs.addressMode[1] = cudaAddressModeClamp;

  // One signed channel as wide as a char.
  cudaChannelFormatDesc desc
    = cudaCreateChannelDesc(sizeof(char)*8, 0, 0, 0, cudaChannelFormatKindSigned);

  cutilSafeCall( cudaBindTexture(0, texSubjectSeqs, d_score_index_block, desc, size) );
}

// Uploads the host score matrix (h_matrix->size x h_matrix->size elements in
// h_matrix->data) into a newly allocated cudaArray and binds that array to
// the 2D texture reference texScoreMatrix. The cudaArray is not released
// here. In DEBUG builds the upload is verified with textureCheck().
void setScoreMatrixToTexture(AminoMatrix *h_matrix){
  // One signed channel as wide as a matrix element.
  cudaChannelFormatDesc desc
    = cudaCreateChannelDesc(sizeof(MATRIX_ELEMENT)*8, 0, 0, 0, cudaChannelFormatKindSigned);

  // Allocate the array and fill it from the host matrix.
  cudaArray *d_matrix_array = NULL;
  int n_bytes = sizeof(MATRIX_ELEMENT) * h_matrix->size * h_matrix->size;

  cutilSafeCall( cudaMallocArray(&d_matrix_array, &desc, h_matrix->size, h_matrix->size) );
  cutilSafeCall( cudaMemcpyToArray(d_matrix_array, 0, 0, h_matrix->data, n_bytes,
                                   cudaMemcpyHostToDevice) );

  // Texture sampling parameters.
  texScoreMatrix.normalized = false;
  texScoreMatrix.filterMode = cudaFilterModePoint;
  texScoreMatrix.addressMode[0] = cudaAddressModeClamp;
  texScoreMatrix.addressMode[1] = cudaAddressModeClamp;

  cutilSafeCall( cudaBindTextureToArray(texScoreMatrix, d_matrix_array, desc) );

  // Check that the matrix was copied to texture memory correctly.
#ifdef DEBUG
  textureCheck(h_matrix);
#endif

}

// Debug helper: reads the score matrix back through texScoreMatrix with
// kernelScoreMatrixTextureTest and compares it element by element with
// h_matrix->data, reporting every mismatch on stderr.
//
// NOTE: the test kernel is launched with a fixed 32*32 threads and uses a row
// stride of 32, so the element-wise comparison is only meaningful when
// h_matrix->size is 32.
void textureCheck(AminoMatrix *h_matrix){

  int matrix_size = sizeof(MATRIX_ELEMENT) * h_matrix->size * h_matrix->size;

  // The kernel always writes 32*32 elements; make the device buffer at least
  // that large so a smaller matrix cannot cause an out-of-bounds write.
  int kernel_out_size = sizeof(MATRIX_ELEMENT) * 32 * 32;
  int d_result_size = (matrix_size > kernel_out_size) ? matrix_size : kernel_out_size;

  MATRIX_ELEMENT *d_result = NULL;
  MATRIX_ELEMENT *h_result = NULL;
  cutilSafeCall( cudaMalloc((void**)&d_result, d_result_size) );
  // Elements the kernel does not write get a defined value.
  cutilSafeCall( cudaMemset(d_result, 0, d_result_size) );

  // There is MAGIC NUMBER.
  kernelScoreMatrixTextureTest<<<32*32/256, 256>>>(d_result);
  cutilSafeCall( cudaGetLastError() );

  cutilSafeCall( cudaHostAlloc((void**)&h_result, matrix_size, cudaHostAllocDefault) );
  cutilSafeCall( cudaMemcpy(h_result, d_result, matrix_size, cudaMemcpyDeviceToHost) );

  int flag = 0;
  for(int i=0; i<(h_matrix->size)*(h_matrix->size); i++){
    if(h_result[i] != h_matrix->data[i]){
      fprintf(stderr, "score matrix: unmutch %d[%d:%d]\n", i, h_result[i], h_matrix->data[i]);
      flag = 1;
    }
  }
  if(flag){
    fprintf(stderr, "score matrix unmutch\n");
  }else{
    fprintf(stderr, "score matrix mutch\n");
  }

  // Both buffers were previously leaked on every call.
  cutilSafeCall( cudaFreeHost(h_result) );
  cutilSafeCall( cudaFree(d_result) );

}
  

// Converts seq_len characters of h_src into index codes written to h_dst:
//   '*'        -> 26
//   padding    -> 27
//   any other  -> h_index[c - 'A']
// '*' is tested before the padding character. No range check is made on
// c - 'A', so h_index must cover every other character that can occur.
// Returns 0 on success, -1 when h_src or h_dst is NULL.
int getIndexedSeq(char *h_dst, char *h_src, int seq_len, int *h_index, char padding){  
  if(h_src == NULL || h_dst == NULL){
    return -1;
  }

  for(int pos=0; pos<seq_len; pos++){
    const char residue = h_src[pos];
    const int code = (residue == '*')     ? 26
                   : (residue == padding) ? 27
                   : h_index[residue - 'A'];
    h_dst[pos] = code;
  }

  return 0;
}

// Builds a padded copy of the indexed query sequence in device memory.
//
// Layout of the new buffer (lengths in bytes):
//   [ head_padding_size * alignment bytes of padding_index ]
//   [ qseq_len bytes copied from h_index_qseq              ]
//   [ padding_index up to the end of the buffer            ]
// The total length is
//   (ceil(qseq_len / alignment) + head_padding_size + tail_padding_size)
//   * alignment.
//
//   d_seq : in/out. A non-NULL *d_seq is cudaFree'd first; on return it
//           points to the new device buffer, which the caller owns.
// Returns the padded length in bytes.
int makePaddedQueryIndexSeq(char **d_seq,
			    char *h_index_qseq, int qseq_len,char padding_index,
			    int head_padding_size, int tail_padding_size, int alignment){
  int padded_seq_length = (int)((qseq_len + alignment -1) / alignment + head_padding_size + tail_padding_size) * alignment;

  if(*d_seq != NULL){
    cutilSafeCall( cudaFree(*d_seq) );
    *d_seq = NULL;
  }

  cutilSafeCall( cudaMalloc((void**)d_seq, padded_seq_length) );

  char *d_padded_index_seq = *d_seq;

  cutilSafeCall( cudaMemcpy(d_padded_index_seq + head_padding_size * alignment,
                            h_index_qseq, qseq_len, cudaMemcpyHostToDevice) );

  if(head_padding_size > 0){
    cutilSafeCall( cudaMemset(d_padded_index_seq, padding_index, head_padding_size * alignment) );
  }

  int seq_length_without_tail_padding = head_padding_size * alignment + qseq_len;
  int tail_padding_length = padded_seq_length - seq_length_without_tail_padding;

  // Test the actual byte count rather than tail_padding_size: when qseq_len
  // is not a multiple of alignment the round-up bytes must be filled even if
  // tail_padding_size is 0, otherwise they are left uninitialized.
  if(tail_padding_length > 0){
    cutilSafeCall( cudaMemset(d_padded_index_seq + seq_length_without_tail_padding,
                              padding_index, tail_padding_length) );
  }

#ifdef DEBUG

  // Dump the padded sequence for inspection.
  char *h_test = NULL;
  cutilSafeCall( cudaHostAlloc((void**)&h_test, padded_seq_length, cudaHostAllocDefault) );
  cutilSafeCall( cudaMemcpy(h_test, d_padded_index_seq, padded_seq_length, cudaMemcpyDeviceToHost) );
  for(int i=0; i<padded_seq_length; i++){
    fprintf(stdout, "%d,", h_test[i]);
  }
  fprintf(stdout, "\n");
  // The debug buffer was previously leaked.
  cutilSafeCall( cudaFreeHost(h_test) );

#endif

  return padded_seq_length;
}

// Converts the raw sequence block h_block->seq_block (h_block->size bytes)
// into index codes on the device, using h_matrix->row_index
// (h_matrix->size ints) as the lookup table, via kernelScoreIndex.
//
// Returns a device buffer of h_block->size bytes owned by the caller
// (release with cudaFree).
//
// NOTE(review): the kernel is given h_block->size / 256 whole chunks, so when
// h_block->size is not a multiple of 256 the trailing bytes of the returned
// buffer are not written. Presumably the block size is always padded to a
// multiple of 256 -- confirm with the code that builds SeqBlock.
char *makeScoreIndexBlock(SeqBlock *h_block, AminoMatrix *h_matrix, char padding){
  char *d_index_block = NULL;
  char *d_seq_block = NULL;
  int *d_amino_index = NULL;

  int n_threads = 256;
  int n_block_total = h_block->size / n_threads;

  cutilSafeCall( cudaMalloc( (void**)&d_index_block, h_block->size) );
  cutilSafeCall( cudaMalloc( (void**)&d_seq_block, h_block->size) );
  cutilSafeCall( cudaMalloc( (void**)&d_amino_index, sizeof(int) * h_matrix->size) );

  cutilSafeCall( cudaMemcpy(d_seq_block, h_block->seq_block, h_block->size,
                            cudaMemcpyHostToDevice) );
  cutilSafeCall( cudaMemcpy(d_amino_index, h_matrix->row_index, sizeof(int) * h_matrix->size,
                            cudaMemcpyHostToDevice) );

  // A single block of n_threads threads walks the whole buffer chunk by chunk.
  kernelScoreIndex<<<1, n_threads>>>(d_seq_block, d_index_block,
                                     d_amino_index, padding,
                                     n_block_total);
  // Launch-configuration errors are only visible through cudaGetLastError().
  cutilSafeCall( cudaGetLastError() );

  // Make sure the kernel has finished (and surface any execution error)
  // before its inputs are released.
  cutilSafeCall( cudaThreadSynchronize() );

  cutilSafeCall( cudaFree(d_seq_block) );
  // The lookup table is only needed by the kernel above; it used to be leaked.
  cutilSafeCall( cudaFree(d_amino_index) );

  return d_index_block;

}

// Copies the h_block->n_block + 1 entries of h_block->block_idx into a newly
// allocated device array and returns it; the caller owns the device buffer.
// In DEBUG builds the entry count and the difference between consecutive
// entries are printed first.
int *makeSequenceIndex(SeqBlock *h_block){
  int index_bytes = sizeof(int) * (h_block->n_block + 1);

#ifdef DEBUG
  printf("# of index : %d\n", h_block->n_block + 1);

  for(int blk=0; blk<h_block->n_block; blk++){
    const int *entry = h_block->block_idx + blk;
    printf("%d : %d\n", blk, entry[1] - entry[0]);
  }
#endif

  int *d_index = NULL;
  cutilSafeCall( cudaMalloc((void**)&d_index, index_bytes) );
  cutilSafeCall( cudaMemcpy(d_index, h_block->block_idx, index_bytes, cudaMemcpyHostToDevice) );

  return d_index;
}

// Allocates `size` bytes of host memory with cudaHostAlloc
// (cudaHostAllocDefault) and returns the pointer. Allocation errors are
// handled by cutilSafeCall.
void *cudaHostAllocPinned(size_t size){
  void *pinned;
  cutilSafeCall( cudaHostAlloc(&pinned, size, cudaHostAllocDefault) );
  return pinned;
}

// Translates raw sequence characters into index codes.
//
// Only threadIdx.x is used for addressing: the threads walk the buffers in
// n_block_total chunks of blockDim.x bytes, thread t handling byte t of every
// chunk, so n_block_total * blockDim.x bytes are converted in total.
//
// Mapping (the slot numbers are magic numbers):
//   '*'       -> amino_index[26]
//   padding   -> amino_index[27]
//   any other -> amino_index[c - 'A']   (no range check)
__global__
void kernelScoreIndex(char *seq_block, char *index_block, 
		      int *amino_index, char padding,
		      int n_block_total){

  const int lane = threadIdx.x;

  for(int chunk = 0; chunk < n_block_total; chunk++){
    const unsigned int pos = chunk*blockDim.x + lane;
    const char residue = seq_block[pos];

    // Pick the lookup slot first, then do a single table read.
    const int slot = (residue == '*')     ? 26
                   : (residue == padding) ? 27
                   : residue - 'A';

    index_block[pos] = (char)amino_index[slot];
  }
}

// Uploads the gap penalties to constant memory on the current device:
//   Gap_init   = *open + *extend
//   Gap_extend = *extend
// Both pointers must be non-NULL. Copy errors are reported through
// cutilSafeCall, like every other CUDA call in this file, instead of being
// silently ignored.
void setGap(int *open, int *extend){
  const int init = *open + *extend;
  cutilSafeCall( cudaMemcpyToSymbol(Gap_init, &init, sizeof(int)) );
  cutilSafeCall( cudaMemcpyToSymbol(Gap_extend, extend, sizeof(int)) );
}

// Builds one TaskPlan per processor, using one OpenMP thread per processor.
// Thread `id` selects CUDA device `id` and then:
//   - reads the sequence file "<seq_file_prefix><id>" with readKFasta(),
//   - uploads the gap penalties and the score matrix,
//   - copies the sequence data and the sequence index to device memory.
//
//   proc            : processor description; proc->n_proc plans are created
//   seq_file_prefix : path prefix of the per-device sequence files
//   amino_mat       : score matrix shared by all tasks
//   gap_open, gap_extend : gap penalties
//
// Returns a malloc'ed array of proc->n_proc TaskPlan entries owned by the
// caller. The process exits if a sequence file cannot be read.
TaskPlan *makeTaskPlan(Processors *proc,
		       char *seq_file_prefix,
		       AminoMatrix *amino_mat,
		       int gap_open,
		       int gap_extend){

  TaskPlan *tp;
  tp = (TaskPlan*) malloc (sizeof(TaskPlan) * proc->n_proc);
  if(tp == NULL && proc->n_proc > 0){
    fprintf(stderr, "memory allocation error! (task plan)\n");
    exit(1);
  }

  omp_set_num_threads(proc->n_proc);
#pragma omp parallel
  {
    int id = omp_get_thread_num();
    cutilSafeCall( cudaSetDevice(id) );

    TaskPlan *task = tp + id;

    task->proc = proc;

    task->a_matrix = amino_mat;

    // Room for the prefix, the decimal device id and the terminator.
    char *seq_file = (char*)malloc(sizeof(char) * (strlen(seq_file_prefix) + 16));
    if(seq_file == NULL){
      fprintf(stderr, "memory allocation error! (sequence file name)\n");
      exit(1);
    }

    sprintf(seq_file, "%s%d", seq_file_prefix, id);

    int read_status = readKFasta(seq_file,
                                 &(task->h_seq_data),
                                 &(task->h_seq_data_index),
                                 &(task->h_seq_id_index),
                                 &(task->n_seq),
                                 &(task->n_cell),
                                 &(task->n_padded_cell)
                                 );
    // The file name is only needed for the read; it used to be leaked.
    free(seq_file);

    // readKFasta has already reported the problem. Without the data the
    // fields used below would be uninitialized, so stop here.
    if(read_status != 0){
      exit(1);
    }

    task->d_seq_data = NULL;
    task->d_seq_data_index = NULL;

    task->gap_open = gap_open;
    task->gap_extend = gap_extend;
    // set gap penalty
    setGap(&(task->gap_open), &(task->gap_extend));

    // Set the scoring matrix to texture named "texScoreMatrix".
    // The "texScoreMatrix" is global variable. 
    setScoreMatrixToTexture(task->a_matrix);

    // Transfer the database sequences. The whole buffer is overwritten by
    // the copy, so no prior memset is needed (the old fixed 100-byte memset
    // could run past a buffer shorter than 100 bytes).
    cutilSafeCall( cudaMalloc((void**)&(task->d_seq_data), sizeof(char) * (task->n_padded_cell)) );
    cutilSafeCall( cudaMemcpy(task->d_seq_data,
                              task->h_seq_data,
                              sizeof(char) * (task->n_padded_cell),
                              cudaMemcpyHostToDevice) );

    // Transfer the database sequence index (padded to a whole number of
    // kernel batches plus one extra batch, the same count readKFasta
    // allocates on the host).
    const size_t index_bytes
      = sizeof(int) * getAlignment(task->n_seq, BLOCK_PER_KERNEL_CALL*SEQ_PER_BLOCK, 1);
    cutilSafeCall( cudaMalloc((void**)&(task->d_seq_data_index), index_bytes) );
    cutilSafeCall( cudaMemcpy(task->d_seq_data_index,
                              task->h_seq_data_index,
                              index_bytes,
                              cudaMemcpyHostToDevice) );
  }

  return tp;
}

// Allocates a Processors descriptor and fills in the number of GPUs to use.
//   n_gpu_limit : upper bound on the number of GPUs; a negative value means
//                 "no limit".
// n_proc is set equal to n_gpu. If the device count cannot be queried, the
// error is reported and both counts are 0 (previously n_gpu was left
// uninitialized in that case).
// Returns a malloc'ed descriptor owned by the caller, or NULL if the
// allocation fails.
Processors *getProcessors(int n_gpu_limit){
  Processors *proc = (Processors*)malloc(sizeof(Processors));
  if(proc == NULL){
    fprintf(stderr, "memory allocation error! (processors)\n");
    return NULL;
  }

  proc->n_gpu = 0;
  cudaError_t status = cudaGetDeviceCount(&(proc->n_gpu));
  if(status != cudaSuccess){
    fprintf(stderr, "cudaGetDeviceCount error! %s\n", cudaGetErrorString(status));
    proc->n_gpu = 0;
  }

  if((n_gpu_limit >= 0) &&(proc->n_gpu > n_gpu_limit)){
    proc->n_gpu = n_gpu_limit;
  }

  proc->n_proc = proc->n_gpu;

  return proc;
}

// Runs the search on every processor in parallel (one OpenMP thread per
// processor, thread `id` driving CUDA device `id`) and merges the
// per-device answers.
//
//   tp               : array of proc->n_proc task plans from makeTaskPlan()
//   h_indexed_qseq   : host query sequence, already converted to index codes
//   indexed_qseq_len : its length
//   num_answer       : number of hits taken from every device
//   block_per_kernel_call, thread_per_block, warp_size :
//                      currently unused -- the compile-time
//                      BLOCK_PER_KERNEL_CALL, THREAD_PER_BLOCK and WARP_SIZE
//                      are passed to getSWResult() instead.
//
// Returns a malloc'ed array of num_answer * proc->n_proc + 1 entries owned by
// the caller. Device `id` fills slots [num_answer*id, num_answer*(id+1));
// when there is more than one processor the first num_answer * proc->n_proc
// entries are then sorted with cmpSWResult. Returns NULL if the allocation
// fails.
SWResult *getSWResultWithMultiThreads(TaskPlan *tp, 
				      Processors *proc,
				      char *h_indexed_qseq,
				      int indexed_qseq_len,
				      int block_per_kernel_call,
				      int thread_per_block,
				      int warp_size,
				      int num_answer){
  const size_t n_result = (size_t)num_answer * proc->n_proc + 1;
  SWResult *result = (SWResult*) malloc (sizeof(SWResult) * n_result);
  if(result == NULL){
    fprintf(stderr, "memory allocation error! (result)\n");
    return NULL;
  }
  // Slots that no device fills must not hold garbage when they are sorted.
  memset(result, 0, sizeof(SWResult) * n_result);

  omp_set_num_threads(proc->n_proc);
#pragma omp parallel
  {
    unsigned int id = omp_get_thread_num();
    TaskPlan *task = tp + id;

    if(id < (proc->n_gpu)){
      cutilSafeCall( cudaSetDevice(id) );

      // Query padded with 1 alignment unit in front and 4 behind.
      char *d_padded_query_seq = NULL;
      int padded_query_len = makePaddedQueryIndexSeq(&d_padded_query_seq,
                                                     h_indexed_qseq, indexed_qseq_len,
                                                     PADDING_INDEX, 1, 4, WARP_SIZE);
      // The length given to the kernel excludes 3 of the tail units.
      padded_query_len -= WARP_SIZE * 3;

      SWResult *each_result;
      each_result = getSWResult(task->d_seq_data,
                                task->d_seq_data_index,
                                task->n_seq, 
                                d_padded_query_seq, padded_query_len,
                                BLOCK_PER_KERNEL_CALL, THREAD_PER_BLOCK, WARP_SIZE,
                                task->h_seq_id_index,
                                num_answer);

      for(int i=0; i<num_answer; i++){
        result[num_answer * id + i] = each_result[i];
      }

      // Both buffers were previously leaked on every call: the per-device
      // result is pinned host memory allocated by getSWResult(), the padded
      // query is device memory allocated by makePaddedQueryIndexSeq().
      cutilSafeCall( cudaFreeHost(each_result) );
      cutilSafeCall( cudaFree(d_padded_query_seq) );
    }

  }

  if(proc->n_proc > 1){
    qsort(result, num_answer * proc->n_proc, sizeof(SWResult), cmpSWResult);
  }

  return result;
}

// Reads one header line of a sequence file into buf; reports the error and
// terminates the process if the line is missing.
static void readKFastaLine(char *buf, int buf_len, FILE *fp, const char *path){
  if(fgets(buf, buf_len, fp) == NULL){
    fprintf(stderr, "file read error! %s\n", path);
    exit(1);
  }
}

// Loads a sequence file.
//
// File layout, as read here:
//   line 1        : number of sequences                 -> *n_seq
//   line 2        : number of padded cells              -> *n_padded_cell
//   line 3        : number of cells                     -> *n_cell
//   next n_seq lines : "<data index><separator><sequence id>"
//   remainder     : *n_padded_cell raw bytes of sequence data
//
// Outputs (all malloc'ed here, owned by the caller):
//   *seq_data       : *n_padded_cell bytes of sequence data
//   *seq_data_index : per-sequence data index, padded to
//                     getAlignment(n_seq, BLOCK_PER_KERNEL_CALL*SEQ_PER_BLOCK, 1)
//                     entries; padding entries are set to *n_padded_cell
//   *seq_id_index   : per-sequence id, same length; padding entries are -1
//
// Returns 0 on success and 1 if the file cannot be opened. Any later read,
// format or allocation error is reported on stderr and terminates the
// process.
int readKFasta(char *path,
	       char **seq_data,
	       int **seq_data_index,
	       int **seq_id_index,
	       int *n_seq,
	       int *n_cell,
	       int *n_padded_cell){
  // Binary mode: the header is followed by raw bytes fetched with fread(),
  // which text-mode translation must not alter.
  FILE *fp;
  if((fp = fopen(path, "rb")) == NULL){
    fprintf(stderr, "file open error! %s\n", path);
    return 1;
  }

  // Line buffer on the stack (the heap buffer used before was never freed).
  char buf[128+1];

  readKFastaLine(buf, 128, fp, path);
  *n_seq = strtol(buf, NULL, 10);

  readKFastaLine(buf, 128, fp, path);
  *n_padded_cell = strtol(buf, NULL, 10);

  readKFastaLine(buf, 128, fp, path);
  *n_cell = strtol(buf, NULL, 10);

  // Negative counts would turn into huge allocation sizes below.
  if(*n_seq < 0 || *n_padded_cell < 0){
    fprintf(stderr, "file read error! %s\n", path);
    exit(1);
  }

  int n_data = getAlignment(*n_seq, BLOCK_PER_KERNEL_CALL*SEQ_PER_BLOCK, 1);

  *seq_data_index = (int*) malloc (sizeof(int) * n_data);
  *seq_id_index = (int*) malloc (sizeof(int) * n_data);
  *seq_data = (char*)malloc(sizeof(char)*(*n_padded_cell));
  if(*seq_data_index == NULL || *seq_id_index == NULL
     || (*seq_data == NULL && *n_padded_cell > 0)){
    fprintf(stderr, "memory allocation error! %s\n", path);
    exit(1);
  }

  // One "<data index><separator><sequence id>" line per sequence.
  for(int i=0; i<(*n_seq); i++){
    char *rest;
    readKFastaLine(buf, 128, fp, path);
    (*seq_data_index)[i] = strtol(buf, &rest, 10);
    // Skip the separator, but never step past the end of the line.
    if(*rest != '\0'){
      rest++;
    }
    (*seq_id_index)[i] = strtol(rest, NULL, 10);
  }

  // Padding entries: point at the end of the data, with an invalid id.
  for(int i=(*n_seq); i<n_data; i++){
    (*seq_data_index)[i] = *n_padded_cell;
    (*seq_id_index)[i] = -1;
  }

  if(fread(*seq_data, sizeof(char), *n_padded_cell, fp) != (size_t)(*n_padded_cell)){
    fprintf(stderr, "file read error! %s\n", path);
    exit(1);
  }

  // The stream was previously never closed.
  fclose(fp);

  return 0;

}
