#include <stdio.h>
#include <string.h>
#include <ctype.h>
#include <stdlib.h>

#include "matrix.h"

/*
 * Map an amino-acid label character to its slot in a label index table.
 *
 *   'A'..'Z' and 'a'..'z' (case-insensitive) -> 0..25
 *   '*'                                      -> 26
 *   '_'                                      -> 27
 *   anything else                            -> -1
 *
 * The mapping is done with explicit character ranges rather than
 * isalpha()/toupper(): passing a negative char (a byte >= 0x80 where char is
 * signed) to <ctype.h> functions is undefined behavior, and in non-"C"
 * locales extra bytes may count as alphabetic, which would produce indices
 * outside 0..27 that callers then use to index fixed-size tables.
 */
int getLabelIndex(char c){
  if(c >= 'a' && c <= 'z') return c - 'a';
  if(c >= 'A' && c <= 'Z') return c - 'A';
  if(c == '*') return 26;
  if(c == '_') return 27;
  return -1;
}

/*
 * Parse the header line of a matrix file and build the column label index.
 *
 * buf               - header line: whitespace-separated labels, one per
 *                     column. Modified in place by strtok().
 * index_size        - number of int slots to allocate for the index; must be
 *                     greater than getLabelIndex('_').
 * mem_allocate_func - allocator used for the returned array.
 *
 * Returns an array of index_size ints in which entry getLabelIndex(label)
 * holds the 0-based column position of that label in the header, or -1 if
 * the label does not appear. The '_' entry is set to the number of header
 * labels, i.e. the column just past the last labelled one.
 *
 * Returns NULL (with a message on stderr for parse errors) on allocation
 * failure, an unknown or duplicated label, or a too-small index_size. No
 * deallocator is passed in, so an array allocated before a parse error is
 * not released.
 *
 * Uses strtok(), so it is not reentrant.
 */
int *makeColumnLabelIndex(char *buf, int index_size,
			  void *(* mem_allocate_func)(size_t size)){
  // '\r' is included so that CRLF line endings do not produce a bogus token.
  const char delim[] = " \t\r\n";
  const int gap_index = getLabelIndex('_');
  char *token;
  int *column_index;
  int cnt;
  int i;

  // The '_' slot is always written below, so the table must be able to hold it.
  if(index_size <= gap_index){
    fprintf(stderr, "Label index size %d is too small.\n", index_size);
    return NULL;
  }

  // memory allocate and initialize (-1 = label not present)
  column_index = (int*) mem_allocate_func (sizeof(int) * (size_t)index_size);
  if(column_index == NULL) return NULL;
  for(i=0; i<index_size; i++){
    column_index[i] = -1;
  }

  cnt = 0;
  for(token = strtok(buf, delim); token != NULL; token = strtok(NULL, delim)){
    char acid = token[0];
    int index = getLabelIndex(acid);

    if(index < 0 || index >= index_size){
      fprintf(stderr, "Unknown amino acid %s in Matrix file.\n", token);
      return NULL;
    }

    // Column 0 is a valid position, so "already seen" means >= 0, not > 0.
    if(column_index[index] >= 0){
      fprintf(stderr, "There are multiple %c in Matrix file.\n", acid);
      return NULL;
    }
    column_index[index] = cnt;
    cnt++;
  }

  column_index[gap_index] = cnt;

  return column_index;
}

/*
 * Parse one data row of a matrix file.
 *
 * buf        - row text: a label followed by whitespace-separated integers.
 *              Modified in place by strtok().
 * index_size - number of slots in row_index, and the number of
 *              MATRIX_ELEMENT slots available at `matrix`.
 * matrix     - destination for this row's values, stored in file order
 *              starting at matrix[0].
 * row_index  - label index; entries must be -1 for labels not yet seen.
 *              On success row_index[getLabelIndex(label)] is set to `row`.
 * row        - 0-based position of this row among the data rows.
 *
 * A blank line is ignored. An unknown or duplicated label is reported on
 * stderr and nothing is stored. Values beyond index_size are reported and
 * dropped instead of being written past the end of `matrix`.
 *
 * Uses strtok(), so it is not reentrant.
 */
void makeRowLabelIndexAndMatrix(char *buf, int index_size, MATRIX_ELEMENT *matrix, int *row_index, int row){
  // '\r' is included so that CRLF line endings do not produce a bogus token.
  const char delim[] = " \t\r\n";
  char *token;
  char acid;
  int index;
  int cnt;

  token = strtok(buf, delim);
  if(token == NULL) return;  // blank line: no label, nothing to store

  acid = token[0];
  index = getLabelIndex(acid);
  if(index < 0 || index >= index_size){
    fprintf(stderr, "Unknown amino acid %s in Matrix file.\n", token);
    return;
  }

  // Row 0 is a valid position, so "already seen" means >= 0, not > 0.
  if(row_index[index] >= 0){
    fprintf(stderr, "There are multiple %c in Matrix file.\n", acid);
    return;
  }
  row_index[index] = row;

  cnt = 0;
  for(token = strtok(NULL, delim); token != NULL; token = strtok(NULL, delim)){
    // Never write past the caller-provided row buffer.
    if(cnt >= index_size){
      fprintf(stderr, "Too many values for %c in Matrix file.\n", acid);
      return;
    }
    matrix[cnt] = atoi(token);
    cnt++;
  }
}

/*
 * Load an amino-acid substitution matrix from a text file.
 *
 * File layout:
 *   - lines starting with '#' are comments and are skipped;
 *   - blank lines are skipped;
 *   - the first remaining line is the column header (one label per column);
 *   - every following line is a row label followed by its values.
 *
 * The table is a fixed matrix_size x matrix_size array. The row and column
 * just past the labelled ones are reserved for the '_' label. They are never
 * written by the parser, so they stay zero from the memset below.
 *
 * All memory comes from mem_allocate_func. No matching free function is
 * available here, so buffers obtained before an error are not released. The
 * file handle is always closed.
 *
 * Returns the new matrix, or NULL if any of the following happens:
 *   - the file cannot be opened;
 *   - an allocation fails;
 *   - the header is missing or invalid;
 *   - a line is longer than the read buffer;
 *   - the file has more rows or columns than fit in the table.
 */
AminoMatrix *readMatrixFromFile(char *file_path, void *(* mem_allocate_func)(size_t size)){
  const int matrix_size = 32;
  const int gap_index = getLabelIndex('_');
  char buf[256];

  FILE *fp = NULL;
  int *column_index = NULL;
  int *row_index = NULL;
  MATRIX_ELEMENT *matrix_data = NULL;
  AminoMatrix *matrix = NULL;
  int row_cnt = 0;
  int i;

  if((fp = fopen(file_path, "r")) == NULL){
    fprintf(stderr, "Matrix file open error:%s\n", file_path);
    return NULL;
  }

  matrix_data = (MATRIX_ELEMENT*) mem_allocate_func (sizeof(MATRIX_ELEMENT) * matrix_size * matrix_size);
  if(matrix_data == NULL) goto fail;
  memset(matrix_data, 0, sizeof(MATRIX_ELEMENT) * matrix_size * matrix_size);

  row_index = (int*) mem_allocate_func (sizeof(int) * matrix_size);
  if(row_index == NULL) goto fail;
  for(i=0; i<matrix_size; i++){
    row_index[i] = -1;
  }

  while(fgets(buf, sizeof(buf), fp) != NULL){
    // A line with no '\n' that is not the last line did not fit in buf;
    // parsing the remainder as a separate line would corrupt the matrix.
    if(strchr(buf, '\n') == NULL && !feof(fp)){
      fprintf(stderr, "Too long line in Matrix file:%s\n", file_path);
      goto fail;
    }
    if(buf[0] == '#') continue;                        // comment line
    if(buf[strspn(buf, " \t\r\n")] == '\0') continue;  // blank line

    if(column_index == NULL){
      column_index = makeColumnLabelIndex(buf, matrix_size, mem_allocate_func);
      if(column_index == NULL) goto fail;
      // '_' maps to the column just past the labelled ones; it must lie
      // inside a row of matrix_size elements.
      if(column_index[gap_index] >= matrix_size){
        fprintf(stderr, "Too many columns in Matrix file:%s\n", file_path);
        goto fail;
      }
    }else{
      // Keep room for the all-zero '_' row after the last data row.
      if(row_cnt >= matrix_size - 1){
        fprintf(stderr, "Too many rows in Matrix file:%s\n", file_path);
        goto fail;
      }
      makeRowLabelIndexAndMatrix(buf, matrix_size,
				 matrix_data + row_cnt * matrix_size,
				 row_index, row_cnt);
      row_cnt++;
    }
  }

  if(column_index == NULL){
    fprintf(stderr, "No column header in Matrix file:%s\n", file_path);
    goto fail;
  }

  fclose(fp);
  fp = NULL;

  row_index[gap_index] = row_cnt;

  matrix = (AminoMatrix*) mem_allocate_func (sizeof(AminoMatrix));
  if(matrix == NULL) return NULL;

  matrix->data = matrix_data;
  matrix->size = matrix_size;
  matrix->column_index = column_index;
  matrix->row_index = row_index;

  return matrix;

fail:
  fclose(fp);
  return NULL;
}

/*
 * Look up the matrix value for a pair of amino-acid labels.
 *
 * matrix - a matrix built by readMatrixFromFile (must not be NULL).
 * c1     - label selecting the column (through matrix->column_index).
 * c2     - label selecting the row (through matrix->row_index).
 *
 * Returns matrix->data[row * size + column]. Returns 0 and reports on
 * stderr in two cases: either character is not a valid label, or it is a
 * valid label that does not appear in the matrix file.
 */
int getMatrixValue(AminoMatrix *matrix, char c1, char c2){
  int c_idx = getLabelIndex(c1);
  int r_idx = getLabelIndex(c2);

  int size = matrix->size;

  if(c_idx < 0){
    fprintf(stderr, "c1 is illegal amino acid: %c\n", c1);
    return 0;
  }

  if(r_idx < 0){
    fprintf(stderr, "c2 is illegal amino acid: %c\n", c2);
    return 0;
  }

  c_idx = matrix->column_index[c_idx];
  r_idx = matrix->row_index[r_idx];

  // -1 marks a label absent from the file. Using it as an index would read
  // the previous row's last element, or before the start of data.
  if(c_idx < 0){
    fprintf(stderr, "c1 is not in matrix: %c\n", c1);
    return 0;
  }

  if(r_idx < 0){
    fprintf(stderr, "c2 is not in matrix: %c\n", c2);
    return 0;
  }

  return matrix->data[r_idx * size + c_idx];
}
