#include <stdio.h>
#include <malloc.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <unistd.h>
#include <getopt.h>

#include "seq_entry_array.h"
#include "fasta_util.h"

/*
 * Lookup table from an upper-case letter to the numeric code written to the
 * database: writeDatabase() indexes it with (letter - 'A'), so entry 0 is 'A'
 * and entry 25 is 'Z'.
 *
 * Entries of -1 (at 'J', 'O' and 'U') mark letters that have no code of their
 * own; writeDatabase() rewrites 'O' and 'U' to other letters before the lookup.
 *
 * NOTE(review): judging by the name, the codes are presumably row/column
 * positions in a BLOSUM substitution matrix ordered
 * A R N D C Q E G H I L K M F P S T W Y V B Z X -- confirm against the consumer.
 */
int blosum_index[] = {0,20,4,3,6,13,7,8,9,-1,11,10,12,2,-1,14,5,1,15,16,-1,19,17,22,18,21};

/*
 * Distributes the entries of src over n_block newly created arrays, handing
 * out n_set consecutive entries at a time to whichever array currently holds
 * the smallest total sequence length. Returns the table of n_block arrays, or
 * NULL if the table cannot be allocated.
 */
SeqEntryArray **devideSeqEntryArray(SeqEntryArray *src, int n_set, int n_block);

/*
 * Writes src to fp: the entry count, padded_length and original_length (one
 * per line), an "offset,id_num" line per entry, then every sequence encoded
 * through blosum_index, followed by a newline.
 */
void writeDatabase(SeqEntryArray *src, FILE *fp, long original_length, long padded_length);

/*
 * qsort() comparator over SeqEntry* elements; orders entries by descending
 * length.
 */
int cmp(const void *entry1, const void *entry2);

/*
 * Converts a FASTA file into the database files used by the search program.
 *
 * Command line: [-w set_size] [-g gpu_num] src_fasta_file dst_file_prefix
 *
 * Steps:
 *   1. read the FASTA file and sort the entries by descending length;
 *   2. number the entries and pad each one with '_' until it is at least as
 *      long as the first entry of its set of set_size entries;
 *   3. write "<prefix>.id" (entry count, total unpadded length, id table);
 *   4. split the entries into gpu_num groups and write each group to
 *      "<prefix>.kfasta<i>".
 *
 * Returns 0 on success and a non-zero status on any error.
 */
int main(int argc, char *argv[]){

  // -w: number of entries per set (original note: number of warps per thread block)
  int set_size = 8;

  // -g: number of GPUs, i.e. number of database files to produce
  int gpu_num = 1;

  const char *prog_name = argv[0];
  int bad_option = 0;
  int option;

  while((option = getopt(argc, argv, "w:g:")) != -1){
    switch(option){
    case 'w': set_size = atoi(optarg); break;
    case 'g': gpu_num = atoi(optarg); break;
    default:  bad_option = 1; break;
    }
  }

  // The positional arguments can only be counted once the options are
  // consumed; counting argc up front lets "-g 2 file" through with a missing
  // prefix.
  if(bad_option || argc - optind < 2){
    printf("usage : %s [-g #gpu] src_fasta_file dst_file_prefix \n", prog_name);
    return 1;
  }

  // atoi() yields 0 for non-numeric text, so this also rejects garbage.
  // set_size is used as a divisor below and gpu_num as an array size.
  if(set_size < 1 || gpu_num < 1){
    fprintf(stderr, "-w and -g require a positive integer\n");
    return 1;
  }

  char *src_fasta_file = argv[optind];
  char *dst_fasta_file_prefix = argv[optind + 1];

  // read FASTA format file from file
  SeqEntryArray *s_array = fasta_read( src_fasta_file );
  if(s_array == NULL){
    fprintf(stderr, "fasta file read error\n");
    return 1;
  }

  // longest entries first (see cmp)
  qsort(s_array->array, s_array->n, sizeof(SeqEntry*), cmp);

  int i;
  for(i=0; i<s_array->n; i++){
    SeqEntry *seq = getSeqEntryArray(s_array, i);
    seq->id_num = i;
  }

  // "<prefix>.id"; sizeof(".id") already counts the terminating NUL
  size_t id_file_size = strlen(dst_fasta_file_prefix) + sizeof(".id");
  char *id_file = malloc(id_file_size);
  if(id_file == NULL){
    fprintf(stderr, "memory allocate error\n");
    exit(1);
  }
  snprintf(id_file, id_file_size, "%s.id", dst_fasta_file_prefix);

  FILE *fp = fopen(id_file, "w");
  free(id_file);
  if(fp == NULL){
    fprintf(stderr, "file open error!!\n");
    exit(1);
  }

  fprintf(fp, "%d\n", s_array->n);

  long original_length = 0;

  for(i=0; i<s_array->n; i++){
    SeqEntry *seq = getSeqEntryArray(s_array, i);

    // First entry of the set that contains i. The index never exceeds i, so
    // it is always in range; after the descending sort it is the longest
    // entry of the set.
    SeqEntry *seq_refer = getSeqEntryArray(s_array, (i / set_size) * set_size);

    original_length += seq->length;

    // NOTE(review): presumably the first call rounds the length up to a
    // multiple of 32 and each further call appends another 32 pad characters
    // -- confirm against paddingSeqEntry().
    paddingSeqEntry(seq, 0, 32, 0, '_');
    while(seq_refer->length > seq->length){
      paddingSeqEntry(seq, 0, 32, 32, '_');
    }
  }

  fprintf(fp, "%ld\n", original_length);

  for(i=0; i<s_array->n; i++){
    SeqEntry *seq = getSeqEntryArray(s_array, i);
    fprintf(fp, "%d,%s\n", seq->id_num, seq->id);
  }

  // fclose() flushes; a failure here means the id file is incomplete
  if(fclose(fp) != 0){
    fprintf(stderr, "file write error!!\n");
    exit(1);
  }

  printf("ID file output finith\n");
  fflush(stdout);

  // divide into gpu_num SeqEntryArrays (database division)
  SeqEntryArray **div_array = devideSeqEntryArray(s_array, set_size, gpu_num);
  if(div_array == NULL){
    fprintf(stderr, "SeqEntryArray devide fault\n");
    return 1;
  }

  printf("Database divide finith\n");
  fflush(stdout);

  for(i=0; i<gpu_num; i++){
    // Size the file name exactly instead of assuming the index has at most
    // two digits.
    int name_len = snprintf(NULL, 0, "%s.kfasta%d", dst_fasta_file_prefix, i);
    if(name_len < 0){
      fprintf(stderr, "file name format error\n");
      exit(1);
    }

    char *buf = malloc((size_t)name_len + 1);
    if(buf == NULL){
      fprintf(stderr, "memory allocate error\n");
      exit(1);
    }
    snprintf(buf, (size_t)name_len + 1, "%s.kfasta%d", dst_fasta_file_prefix, i);

    fp = fopen(buf, "w");
    free(buf);
    if(fp == NULL){
      fprintf(stderr, "file open error!!\n");
      exit(1);
    }

    // total (padded) length of the entries assigned to this group
    long padded_length = 0;
    int j;
    for(j=0; j<div_array[i]->n; j++){
      SeqEntry *seq = getSeqEntryArray(div_array[i], j);
      padded_length += seq->length;
    }

    // the unpadded length is not tracked per group, hence -1
    writeDatabase(div_array[i], fp, -1, padded_length);

    if(fclose(fp) != 0){
      fprintf(stderr, "file write error!!\n");
      exit(1);
    }
  }

  return 0;
}

/*
 * Distributes the entries of src over n_block newly created arrays.
 *
 * Entries are taken from src in order, n_set at a time (the last set may be
 * shorter). Each set goes as a whole to the array whose accumulated sequence
 * length is currently the smallest; ties go to the lowest index.
 *
 * src     : source array; its entries are shared with the result, not copied
 * n_set   : number of consecutive entries kept together (must be >= 1)
 * n_block : number of arrays to produce (must be >= 1)
 *
 * Returns a malloc()ed table of n_block arrays, or NULL on an invalid
 * argument or an allocation failure.
 */
SeqEntryArray **devideSeqEntryArray(SeqEntryArray *src, int n_set, int n_block){
  int i, j;

  // n_set is a divisor and n_block an array size below
  if(src == NULL || n_set < 1 || n_block < 1){
    fprintf(stderr, "devideSeqEntryArray(): invalid argument\n");
    return NULL;
  }

  // Accumulated sequence length per block; long so that large databases do
  // not overflow the running totals.
  long *length = calloc((size_t)n_block, sizeof *length);
  SeqEntryArray **dst = malloc(sizeof *dst * (size_t)n_block);
  if(length == NULL || dst == NULL){
    fprintf(stderr, "devideSeqEntryArray(): memory allocate error\n");
    free(length);
    free(dst);
    return NULL;
  }

  for(i=0; i<n_block; i++){
    // NOTE(review): the balancing below is by sequence length, not by entry
    // count, so a block can receive more than src->n / n_block + 1 entries;
    // this relies on setSeqEntryArray() growing the array -- confirm.
    dst[i] = makeSeqEntryArray(src->n / n_block + 1);

    // NOTE(review): assumes makeSeqEntryArray() reports failure with NULL.
    if(dst[i] == NULL){
      fprintf(stderr, "devideSeqEntryArray(): memory allocate error\n");
      free(length);
      free(dst);
      return NULL;
    }
  }

  int n_sets = (src->n + n_set - 1) / n_set;

  for(i=0; i<n_sets; i++){
    // pick the block with the smallest accumulated length so far
    int min_idx = 0;
    for(j=1; j<n_block; j++){
      if(length[j] < length[min_idx]){
        min_idx = j;
      }
    }

    SeqEntryArray *array = dst[min_idx];

    // hand the whole set to that block; the last set may be incomplete
    for(j=0; j<n_set && i*n_set + j < src->n; j++){
      SeqEntry *seq = getSeqEntryArray(src, i*n_set + j);
      setSeqEntryArray(array, seq);
      length[min_idx] += seq->length;
    }
  }

  free(length);
  return dst;
}

/*
 * Writes one database file.
 *
 * Layout written to fp:
 *   - the number of entries, padded_length and original_length, one per line;
 *   - one "offset,id_num" line per entry, where offset is the sum of the
 *     lengths of all preceding entries;
 *   - every sequence, one byte per character, encoded as follows:
 *       'A'..'Z' -> blosum_index[letter - 'A'] ('O' is first rewritten to
 *                   'K' and 'U' to 'C', in place in the entry's buffer)
 *       '*'      -> 26
 *       '_'      -> 27
 *   - a final newline.
 *
 * Any other character, or a letter whose table entry is negative, terminates
 * the program with exit(1) after reporting it on stderr.
 */
void writeDatabase(SeqEntryArray *src, FILE *fp, long original_length, long padded_length){

  // long, like padded_length: the running offset reaches the total length
  long index = 0;

  fprintf(fp, "%d\n", src->n);
  fprintf(fp, "%ld\n", padded_length);
  fprintf(fp, "%ld\n", original_length);

  int i, j;
  for(i=0; i<src->n; i++){
    SeqEntry *seq = getSeqEntryArray(src, i);
    fprintf(fp, "%ld,%d\n", index, seq->id_num);
    index += seq->length;
  }

  for(i=0; i<src->n; i++){
    SeqEntry *seq = getSeqEntryArray(src, i);
    char *buf = seq->seq;
    for(j=0; j<seq->length; j++){
      int code;

      if(buf[j] >= 'A' && buf[j] <= 'Z'){
        if(buf[j] == 'O') buf[j] = 'K';
        if(buf[j] == 'U') buf[j] = 'C';
        code = blosum_index[buf[j] - 'A'];
      }else if(buf[j] == '*'){
        code = 'Z' - 'A' + 1;
      }else if(buf[j] == '_'){
        code = 'Z' - 'A' + 2;
      }else{
        code = -1;
      }

      // A negative code means "no encoding"; writing it would emit byte 0xFF
      // into the database, so reject it like any other undefined character.
      if(code < 0){
        fprintf(stderr, "undefined char %c\n", buf[j]);
        exit(1);
      }

      fputc(code, fp);
    }
  }

  fprintf(fp, "\n");

}

int cmp(const void *entry1, const void *entry2){
  SeqEntry *e1, *e2;
  e1 = *((SeqEntry**)entry1);
  e2 = *((SeqEntry**)entry2);

  if(e1->length < e2->length) return 1;
  if(e1->length > e2->length) return -1;
  return 0;
}
