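///     make_me_params
//
// From the empirical transition probabilities between bunsetsu (phrase)
// classes, compute the parameters of a maximum-entropy model whose
// features are the individual class-to-class transitions.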
/*
Copyright (C) 2006  hanaoka

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/

#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <vector>
// Uncomment to trace progress (function entry, per-iteration change) on stdout.
// #define DEBUG
// Magnitude below which a probability, expectation or normalizer is
// treated as zero.
#define EPS 1e-10
// make_params stops once the largest relative parameter change drops
// below this value.
#define DELTA 1e-1
// Historical fixed class count; SEGSIZE is now a variable read by input().
//#define SEGSIZE 39

using namespace std;

// What main() does with the fitted model; chosen by parse_args().
enum MODE{
  PRINT_PD,      // -p: print the model distribution pd
  PRINT_PARAMS,  // -l (default): print parameters and normalizers
  WRITE_PARAMS   // -f: write the results to the files named on the command line
};
MODE mode = PRINT_PARAMS;

int fn; // = SEGSIZE*SEGSIZE; number of features, set by input()
vector<vector<double> > epd; // empirical joint distribution of the training data: epd[a][b] = P(a, b)
vector<vector<vector<int> > > features; // features[k][x][y]: 0/1 value of feature k on the pair (x, y)
vector<double> epd_expectations; // expectation of each feature under the empirical distribution
vector<double> marginal_prob; // empirical marginal P(a) of the training data (row sums of epd)
int SEGSIZE; // number of bunsetsu (phrase) classes, read by input()

// Reads the model input from stdin and builds the global tables.
//
// Input format (whitespace separated): the number of bunsetsu classes,
// SEGSIZE, followed by SEGSIZE*SEGSIZE values; the j-th value of row i is
// the empirical joint probability P(i, j).
//
// Sets SEGSIZE, fn (= SEGSIZE*SEGSIZE), epd, marginal_prob (row sums of
// epd) and features (one 0/1 indicator feature per class pair).
// Exits with status 1 if the class count is missing or non-positive, or if
// fewer than fn probabilities can be read.
void input(){
#ifdef DEBUG
  cout<<"input"<<endl;
#endif
  // Number of bunsetsu classes. Every table below is sized from it, so a
  // failed read or a non-positive value must not be used.
  if(scanf("%d", &SEGSIZE) != 1 || SEGSIZE <= 0){
    fprintf(stderr, "make_me_params: missing or invalid number of classes\n");
    exit(1);
  }
  fn = SEGSIZE*SEGSIZE;

  // Empirical joint distribution: epd[a][b] holds P(a, b).
  // marginal_prob[a] accumulates the row sum, i.e. the marginal P(a).
  epd = vector<vector<double> >(SEGSIZE, vector<double>(SEGSIZE));
  marginal_prob = vector<double>(SEGSIZE);
  for(int i=0; i<SEGSIZE; ++i){
    double d=0;
    for(int j=0; j<SEGSIZE; ++j){
      if(scanf("%lf", &epd[i][j]) != 1){
        fprintf(stderr,
                "make_me_params: expected %d probabilities, could not read row %d column %d\n",
                fn, i, j);
        exit(1);
      }
      d+=epd[i][j];
    }
    marginal_prob[i]=d;
  }

  // Feature functions: feature k = i*SEGSIZE + j takes the value 1 on the
  // pair (i, j) and 0 everywhere else.
  features = vector<vector<vector<int> > >(fn, vector<vector<int> >(SEGSIZE, vector<int>(SEGSIZE, 0)));
  int cnt = 0;
  for(int i=0; i<SEGSIZE; ++i)
    for(int j=0; j<SEGSIZE; ++j){
      features[cnt][i][j]=1;
      ++cnt;
    }
}

// Recomputes the model distribution from the current parameters:
//   pd[x][y] = exp(sum_k params[k] * f_k(x, y)) / zs[x]
// A row whose normalizer zs[x] is not at least EPS in magnitude (including
// a NaN normalizer) is filled with zeros instead.
void update_pd(vector<vector<double> > &pd, vector<double> &params, vector<double> &zs){
  for(int x=0; x<SEGSIZE; ++x){
    vector<double> &row = pd[x];
    const double z = zs[x];

    // Negated ">=" so that a NaN normalizer also takes the zero path.
    if(!(abs(z) >= EPS)){
      for(int y=0; y<SEGSIZE; ++y)
        row[y] = 0;
      continue;
    }

    for(int y=0; y<SEGSIZE; ++y){
      double score = 0;
      for(int k=0; k<fn; ++k)
        score += params[k] * features[k][x][y];
      row[y] = exp(score) / z;
    }
  }
}

void make_epd_expectations(){
  //ؽǡФ
  epd_expectations = vector<double>(fn);
  for(int i=0; i<fn; ++i){
    double d=0;
    for(int j=0; j<SEGSIZE; ++j)
      for(int k=0; k<SEGSIZE; ++k)
	d+=epd[j][k]*features[i][j][k];
    epd_expectations[i]=d;
  }
}

double calc_pd_expectation(int k, vector<vector<double> > &pd){
  //kФ븽ߤγΨʬpdδͤ׻
  double ret = 0;
  for(int i=0; i<SEGSIZE; ++i){
    if(abs(marginal_prob[i]) < EPS)
      continue;
    else{
      double d=0;
      for(int j=0; j<SEGSIZE; ++j)
	d+=pd[i][j]*features[k][i][j];
      ret += marginal_prob[i]*d;
    }
  }
  return ret;
}

// Computes the parameter change for one iterative-scaling step and stores
// it in ds (which must already hold fn entries). For each feature k:
//   - empirical expectation below EPS in magnitude: ds[k] = -1, so the
//     weight of a feature that never occurs keeps decreasing;
//   - model expectation below EPS in magnitude: ds[k] = 0 (no update);
//   - otherwise: ds[k] = log(E_emp[f_k]) - log(E_model[f_k]).
void calc_delta(vector<double> &ds, vector<vector<double> > &pd){
#ifdef DEBUG
  cout<<"calc_delta"<<endl;
#endif
  for(int k=0; k<fn; ++k){
    const double model = calc_pd_expectation(k, pd);
    const double empirical = epd_expectations[k];

    double step;
    if(abs(empirical) < EPS)
      step = -1;
    else if(abs(model) < EPS)
      step = 0;
    else
      step = log(empirical) - log(model);

    ds[k] = step;
  }
}

// Prints the first fn parameters on one line, each followed by a space,
// then ends the line. Numbers use cout's current formatting (main switches
// cout to scientific notation with 5 digits). Not called anywhere in this file.
void print_params(vector<double> &params){
  int k = 0;
  while(k < fn)
    cout << params[k++] << " ";
  cout << endl;
}

// Recomputes the per-class normalizers from the current parameters:
//   zs[x] = sum_y exp(sum_k params[k] * f_k(x, y))
void update_zs(vector<double> &zs, vector<double> &params){
  for(int x=0; x<SEGSIZE; ++x){
    double z = 0;
    for(int y=0; y<SEGSIZE; ++y){
      double score = 0;
      for(int k=0; k<fn; ++k)
        score += params[k] * features[k][x][y];
      z += exp(score);
    }
    zs[x] = z;
  }
}

// Prints pd as SEGSIZE lines of values in %.5e format. Every value except
// the very last one is followed by ", ", so the rows together form a single
// comma-separated list.
void output(vector<vector<double> > &pd){
  const int last = SEGSIZE - 1;
  for(int x=0; x<SEGSIZE; ++x){
    for(int y=0; y<SEGSIZE; ++y){
      printf("%.5e", pd[x][y]);
      const bool is_final = (x == last) && (y == last);
      if(!is_final)
        printf(", ");
    }
    putchar('\n');
  }
}

void output(vector<double> params){
  for(int i=0; i<fn; ++i)
    cout<<params[i]<<" ";
  cout<<endl;
}

// Prints the fitted model in three parts:
//   - all fn parameters on one line;
//   - a blank line;
//   - the SEGSIZE normalizers zs on one line.
// Each value is in %.5e format and followed by a space.
void output_params(vector<double> &params, vector<double> &zs){
  for(int k=0; k<fn; ++k)
    printf("%.5e ", params[k]);
  fputs("\n\n", stdout);

  for(int x=0; x<SEGSIZE; ++x)
    printf("%.5e ", zs[x]);
  fputs("\n", stdout);
}

// Fits the maximum-entropy model by iterative scaling.
//
// Procedure:
//   1. Start with every parameter at 10.0.
//   2. Each iteration, add the step from calc_delta to every parameter,
//      then recompute zs and pd.
//   3. Stop once the largest relative change |ds[k] / params[k]| falls
//      below DELTA. The ratio uses the already-updated parameter.
// On return, params, zs and pd describe the fitted model.
// NOTE: there is no iteration limit; termination relies solely on the
// DELTA criterion.
void make_params(vector<double> &params, vector<double> &zs, vector<vector<double> > &pd){
  params.assign(fn, 10.0);
  zs.assign(SEGSIZE, 0.0);
  pd.assign(SEGSIZE, vector<double>(SEGSIZE, 0.0));

  update_zs(zs, params);
  update_pd(pd, params, zs);

  vector<double> ds(fn);
  for(;;){
    calc_delta(ds, pd);

    double max_rel = 0;
    for(int k=0; k<fn; ++k){
      params[k] += ds[k];
      max_rel = max(max_rel, abs(ds[k] / params[k]));
    }

    update_zs(zs, params);
    update_pd(pd, params, zs);
#ifdef DEBUG
    cout<<max_rel<<endl;
#endif
    // Converged once every relative change is below DELTA.
    if(max_rel < DELTA)
      return;
  }
}

// Prints the command-line usage and terminates the program.
//   -l (default)        print the parameters and normalizers to stdout
//   -p                  print the fitted probability distribution to stdout
//   -f [lambdas] [zs]   write the results to the two named files
// Every caller (parse_args) reaches this only on an invalid command line,
// so the process exits with a failure status rather than 0.
void print_usage(){
  printf("make_me_params\n");
  printf(" $ ./make_me_params -l < [input-file]\n");
  printf("//(default) ѥ᡼ɸ\n");
  printf("\n");
  printf(" $ ./make_me_params -p < [input-file]\n");
  printf("//ѥ᡼ˤΨʬۤɸ\n");
  printf("\n");
  printf(" $ ./make_me_params -f [lambdas] [zs] < [input-file]\n");
  printf("//ѥ᡼եؽ\n");
  printf("//ܤ˻ꤷեνŤߥꥹȡ\n");
  printf("//ܤꥹȤ񤭹\n");
  exit(1);
}

// Parses the command-line options into the global `mode`.
//   -l          print parameters and normalizers (default)
//   -p          print the fitted distribution
//   -f F1 F2    write the results to the files F1 and F2
// Rules:
//   - Single-letter options may be combined in one argument (e.g. "-lp");
//     the last letter wins.
//   - "-f" is accepted only as the last letter of argv[1], and only when
//     exactly two further arguments (the file names) follow.
//   - Any other option, or any stray non-option argument, calls
//     print_usage(), which terminates the program.
void parse_args(int argc, char **argv){
  for(int i=1; i<argc; ++i){
    char *arg = argv[i];
    if(*arg=='-'){
      ++arg;
      while(*arg != '\0'){
        switch(*arg){
        case 'p':
          mode = PRINT_PD;
          break;
        case 'l':
          mode = PRINT_PARAMS;
          break;
        case 'f':
          if(i==1 && argc==4 && arg[1]=='\0'){
            mode = WRITE_PARAMS;
            // argv[2] and argv[3] are file names. Stop parsing so they are
            // never mistaken for options, even if they begin with '-'.
            return;
          }
          print_usage();
          break;
        default :
          print_usage();
          break;
        }
        ++arg;
      }
    }else{
      // Non-option arguments are only valid as the -f file names, which
      // are consumed by the early return above.
      print_usage();
    }
  }
}

void write_params(vector<double> params, vector<double> zs, char**argv){
  ofstream st1(argv[2]);
  st1<<setiosflags(ios::scientific)<<setprecision(5);
  for(int i=0; i<fn; ++i){
    if(i)st1<<" , ";
    st1<<params[i];
  }

  ofstream st2(argv[3]);
  st2<<setiosflags(ios::scientific)<<setprecision(5);
  for(int i=0; i<SEGSIZE; ++i){
    if(i)st2<<" , ";
    st2<<params[i];
  }
}

// Entry point:
//   1. Parse the options.
//   2. Read the empirical distribution from stdin.
//   3. Fit the maximum-entropy model.
//   4. Emit the result as selected by `mode`.
int main(int argc, char **argv)
{
  cout<<setiosflags(ios::scientific)<<setprecision(5);
  parse_args(argc, argv);

  input();
  make_epd_expectations();

  vector<double> params;       // feature weights, one per class pair
  vector<double> zs;           // per-class normalizers
  vector<vector<double> > pd;  // model distribution pd[x][y] = exp(...) / zs[x]

  make_params(params, zs, pd);

  switch(mode){
  case PRINT_PD:
    output(pd);
    break;
  case PRINT_PARAMS:
    output_params(params, zs);
    break;
  default:
    write_params(params, zs, argv);
    break;
  }
}
