/************************************************
 *						*
 *	Back Propagation Simulator (BPS)	*
 *	      subroutine package		*
 *	  	Version 4.0			*
 *	  coded		in May.17 1989		*
 *	  coded by 	Y.Okamura		*
 *	  last modified	in Nov.15 1990		*
 *         modified by   K.Kuroda		*
 *						*
 *************************************************
 *						*
 *	filename setlmain.c 			*
 *	    BP learning controller		*
 *						*
 ************************************************/
#include "BPS.h"
#include "learn.h"

/* Loop condition shared by every learning method in this file: keep
   iterating while learn_cnt has not reached MaxLearnCount and the summed
   error is still above MinError. */
#define CHECKEND  ((learn_cnt != MaxLearnCount) && (SumOfErr > MinError))

/* Learning-progress bookkeeping shared by the routines below. */
int     learn_cnt;        /* iterations performed; bumped at the top of each loop pass */
int     wgt_stor_cnt;     /* record counter handed to StoreWeight2() by store_weight() */
int     err_stor_cnt;     /* record number handed to the error-history writers */
int     err_datapoint;    /* next data-point index when errors are stored point by point */
int     display_cnt;      /* number of values currently held in buff_err */
int     buff_no;          /* buffer number read with GetScalar(0); 0 disables buffer output */
int     buff_write_flag;  /* nonzero when buff_no != 0 */


/************************************************
  set learning

  Presents every training pattern to the network once and returns the
  summed value reported by forward_learn().

  input:
    lrn_mode : when nonzero, backward_learn() is also run after each
               pattern's forward pass; when zero only the forward pass
               (error evaluation) is done.
  output:
    sum of forward_learn() over all NumOfPtrn patterns.

  OutCellErr is cleared and workspace_initialize() is called once,
  before the first pattern -- not once per pattern.
  ************************************************/
double
set_learn(int lrn_mode)
{
  double  total = 0.0;
  int     k;
  int     p;

  k = 0;
  while (k < NumOfCell[NumOfLayer-1]) {
    OutCellErr[k] = 0.0;
    k++;
  }

  workspace_initialize();

  for (p = 0; p < NumOfPtrn; p++) {
    total += forward_learn(InputData[p], TeachData[p]);
    if (lrn_mode != 0)
      backward_learn();
  }

  return total;
}


/*************************************************
  store weight

  Writes the current weights to WgtHistoryFile at index
  wgt_stor_cnt - 1, advances wgt_stor_cnt in APPEND mode, then rewrites
  the file header so its comment reads "learning rate = <value>", where
  <value> is CoefLearn of the link between BPNet[0][1] and BPNet[1][1].
  Exits with status 102 if the header cannot be loaded.
  *************************************************/
void
store_weight(void)
{
  Header   hdr;
  ilin_t  *probe;

  /* Index was changed from wgt_stor_cnt to wgt_stor_cnt - 1
     by T.Hayasaka, 1994-04-26. */
  StoreWeight2(WgtHistoryFile, wgt_stor_cnt - 1);

  /* Only APPEND mode moves on to a new index; otherwise the same
     index is reused on every call. */
  if (WgtStorMode == APPEND)
    wgt_stor_cnt++;

  if (LoadHeader(WgtHistoryFile, &hdr) == -1)
    exit(102);

  probe = inter_ilin(BPNet[0][1].CellNode, BPNet[1][1].CellNode);

  /* NOTE(review): sprintf is unbounded and probe is not NULL-checked;
     confirm hdr.comment is large enough and that this link always exists. */
  sprintf(hdr.comment, "%s%f", "learning rate = ", probe->CoefLearn);

  StoreHeader(WgtHistoryFile, &hdr);
}


/*************************************************
  store error

  Saves the error history to ErrHistoryFile at record err_stor_cnt.
  In RECORD direction the whole record is written via StoreErrRecord();
  otherwise only the current SumOfErr is written as data point
  err_datapoint, and err_datapoint is advanced.  In APPEND mode the
  record number is advanced after every store.
  *************************************************/
void
store_error(void)
{
  if (ErrStorDirection != RECORD) {
    int point = err_datapoint++;

    WriteErrDataPoint(ErrHistoryFile, err_stor_cnt, point, SumOfErr);
  } else {
    StoreErrRecord(ErrHistoryFile, err_stor_cnt);
  }

  if (ErrStorMode == APPEND)
    err_stor_cnt++;
}


/*************************************************
  display iteration, error, comment

  Prints learn_cnt, SumOfErr, the change of SumOfErr relative to the
  previous value in buff_err, and Comment.  When buff_write_flag is set,
  buff_err is also passed to WriteBuffer() for buffer buff_no with
  display_cnt as its index; the program exits with status 3 if that
  call returns -1.

  Callers in this file append SumOfErr to buff_err (buff_err[display_cnt++])
  before calling, so buff_err[display_cnt-2] is the preceding value.  With
  fewer than two buffered values there is no predecessor.  The error itself
  is then reported as the difference (the convention main() uses for
  iteration 0), instead of reading before the start of buff_err.
  *************************************************/
void
display(void)
{
  int    index[10];

  printf("\n");
  printf("### ITERATION        ### = %d\n",   learn_cnt);
  printf("### SQUARE'S ERROR   ### = %14e\n", SumOfErr);
  if (display_cnt >= 2)
    printf("### DIFFERENCE       ### = %14e\n",
           SumOfErr - buff_err[display_cnt-2]);
  else
    printf("### DIFFERENCE       ### = %14e\n", SumOfErr);
  printf("### COMMENT          ### = %s\n",   Comment);

  if (buff_write_flag) {
    index[0] = display_cnt;   /* number of values held in buff_err */
    if (WriteBuffer(buff_no, 1, index, buff_err) == -1)
      exit(3);
  }
}


/************************************************
  steep method

  Learning loop for the steep method; the weight update itself is done
  by steep().  Runs until CHECKEND fails.

  SET_LEARN mode: steep() once per iteration, then the error is
                  re-evaluated with set_learn(1).
  other modes:    for each pattern, OutCellErr is cleared, the pattern is
                  propagated forward and backward, and steep() is applied;
                  the iteration error is the sum of forward_learn() results.

  After every iteration the error is appended to buff_err.  Weights and
  errors are stored, and progress displayed, whenever learn_cnt is a
  multiple of the respective interval.
  ************************************************/
void
Steep_method(void)
{
  int     p, u;
  double  epoch_err = 0.0;

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    if (LearnMode != SET_LEARN) {
      epoch_err = 0.0;
      for (p = 0; p < NumOfPtrn; p++) {
        for (u = 0; u < NumOfCell[NumOfLayer-1]; u++)
          OutCellErr[u] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(InputData[p], TeachData[p]);
        backward_learn();
        steep();
      }
    } else {
      steep();
    }

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    if (LearnMode == SET_LEARN)
      SumOfErr = set_learn(1);
    else
      SumOfErr = epoch_err;

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}

/************************************************
  momentum method ( set_learn )

  SET_LEARN-mode momentum learning: momentum1() is applied once per
  iteration, then the error is re-evaluated with set_learn(1).  Runs
  until CHECKEND fails.  After every iteration the error is appended
  to buff_err.  Weights and errors are stored, and progress displayed,
  whenever learn_cnt is a multiple of the respective interval.
  ************************************************/
void
setMomentum(void)
{
  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    momentum1();

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    SumOfErr = set_learn(1);

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}


/************************************************
  momentum method ( pattern_learn )

  Pattern-by-pattern momentum learning: for each pattern OutCellErr is
  cleared, the pattern is propagated forward and backward, and
  momentum1() is applied.  The iteration error is accumulated directly
  in SumOfErr.  Runs until CHECKEND fails.  After every iteration the
  error is appended to buff_err.  Weights and errors are stored, and
  progress displayed, whenever learn_cnt is a multiple of the
  respective interval.
  ************************************************/
void
patternMomentum(void)
{
  int  p, u;

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    SumOfErr = 0.0;
    for (p = 0; p < NumOfPtrn; p++) {
      for (u = 0; u < NumOfCell[NumOfLayer-1]; u++)
        OutCellErr[u] = 0.0;

      workspace_initialize();
      SumOfErr += forward_learn(InputData[p], TeachData[p]);
      backward_learn();
      momentum1();
    }

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}


/************************************************
  Vogl method

  First sets CoefLearn of every link reached from each cell
  (via Getintoplist / Getinfwdlist) to LearnRate.  Then runs the
  learning loop until CHECKEND fails, with the update done by vogl():

  SET_LEARN mode: vogl() once per iteration, then the error is
                  re-evaluated with set_learn(1).
  other modes:    for each pattern, OutCellErr is cleared, the pattern is
                  propagated forward and backward, and vogl() is applied;
                  the iteration error is the sum of forward_learn() results.

  After every iteration the error is appended to buff_err.  Weights and
  errors are stored, and progress displayed, whenever learn_cnt is a
  multiple of the respective interval.
  ************************************************/
void
Vogl_method(void)
{
  int      p, u, l;
  double   epoch_err = 0.0;
  ilin_t  *lk;

  /* Cells are numbered from 1 in BPNet. */
  for (l = 0; l < NumOfLayer; l++)
    for (u = 1; u <= NumOfCell[l]; u++)
      for (lk = Getintoplist(BPNet[l][u].CellNode); lk != NULL;
           lk = Getinfwdlist(lk))
        lk->CoefLearn = LearnRate;

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    if (LearnMode != SET_LEARN) {
      epoch_err = 0.0;
      for (p = 0; p < NumOfPtrn; p++) {
        for (u = 0; u < NumOfCell[NumOfLayer-1]; u++)
          OutCellErr[u] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(InputData[p], TeachData[p]);
        backward_learn();
        vogl();
      }
    } else {
      vogl();
    }

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    if (LearnMode == SET_LEARN)
      SumOfErr = set_learn(1);
    else
      SumOfErr = epoch_err;

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}


/************************************************
  Jacob's method

  Learning loop whose update is done by jacobs().  Runs until CHECKEND
  fails.

  SET_LEARN mode: jacobs() once per iteration, then the error is
                  re-evaluated with set_learn(1).
  other modes:    for each pattern, OutCellErr is cleared, the pattern is
                  propagated forward and backward, and jacobs() is applied;
                  the iteration error is the sum of forward_learn() results.

  After every iteration the error is appended to buff_err.  Weights and
  errors are stored, and progress displayed, whenever learn_cnt is a
  multiple of the respective interval.
  ************************************************/
void
Jacob_method(void)
{
  int     p, u;
  double  epoch_err = 0.0;

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    if (LearnMode != SET_LEARN) {
      epoch_err = 0.0;
      for (p = 0; p < NumOfPtrn; p++) {
        for (u = 0; u < NumOfCell[NumOfLayer-1]; u++)
          OutCellErr[u] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(InputData[p], TeachData[p]);
        backward_learn();
        jacobs();
      }
    } else {
      jacobs();
    }

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    if (LearnMode == SET_LEARN)
      SumOfErr = set_learn(1);
    else
      SumOfErr = epoch_err;

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}


/************************************************
  momentum Vogl's coefficient method

  Calls vgl2_coe() once, then runs the learning loop until CHECKEND
  fails, with the update done by momentum2():

  SET_LEARN mode: momentum2() once per iteration, then the error is
                  re-evaluated with set_learn(1).
  other modes:    for each pattern, OutCellErr is cleared, the pattern is
                  propagated forward and backward, and momentum2() is
                  applied; the iteration error is the sum of
                  forward_learn() results.

  After every iteration the error is appended to buff_err.  Weights and
  errors are stored, and progress displayed, whenever learn_cnt is a
  multiple of the respective interval.
  ************************************************/
void
Vogl_coef_method(void)
{
  int     p, u;
  double  epoch_err = 0.0;

  vgl2_coe();

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    if (LearnMode != SET_LEARN) {
      epoch_err = 0.0;
      for (p = 0; p < NumOfPtrn; p++) {
        for (u = 0; u < NumOfCell[NumOfLayer-1]; u++)
          OutCellErr[u] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(InputData[p], TeachData[p]);
        backward_learn();
        momentum2();
      }
    } else {
      momentum2();
    }

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    if (LearnMode == SET_LEARN)
      SumOfErr = set_learn(1);
    else
      SumOfErr = epoch_err;

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}


/************************************************
  An accelerated learning method to reduce
  the oscillation of weights for neural networks

  For every link reached from each cell (via Getintoplist /
  Getinfwdlist), CoefLearn is set to LearnRate and wgtworkold to
  -WgtWork.  The loop then calls Ochi() once per iteration and
  re-evaluates the error with set_learn(1), until CHECKEND fails.
  After every iteration the error is appended to buff_err.  Weights and
  errors are stored, and progress displayed, whenever learn_cnt is a
  multiple of the respective interval.
  ************************************************/
void
Ochi_method(void)
{
  int      u, l;
  ilin_t  *lk;

  /* Cells are numbered from 1 in BPNet. */
  for (l = 0; l < NumOfLayer; l++)
    for (u = 1; u <= NumOfCell[l]; u++)
      for (lk = Getintoplist(BPNet[l][u].CellNode); lk != NULL;
           lk = Getinfwdlist(lk)) {
        lk->CoefLearn  = LearnRate;
        lk->wgtworkold = -lk->WgtWork;
      }

  for (;;) {
    if (!CHECKEND)
      break;
    ++learn_cnt;

    Ochi();

    if (learn_cnt % WgtStorInterval == 0)
      store_weight();

    SumOfErr = set_learn(1);

    if (learn_cnt % ErrStorInterval == 0)
      store_error();

    buff_err[display_cnt] = (double)SumOfErr;
    display_cnt++;

    if (learn_cnt % DisplayInterval == 0)
      display();
  }
}

/************************************************
  learning main routine

  Reads the system command and parameter settings, builds the network,
  loads the initial weights, and prepares the weight/error history files.
  It then records the error of the initial weights (iteration 0) and runs
  the learning method selected by LearnAlgo.  Afterwards it stores and
  displays whatever the last iteration did not already cover, and tears
  everything down.

  Exit status: 17 if the buffer number from GetScalar(0) is negative,
  3 if WriteBuffer() fails, 0 on normal completion.  Routines called
  from here may exit with other codes (e.g. store_weight() uses 102).
  ************************************************/
int
main(void)
{
  int   idx[10];

  read_syscom();
  rebps();

  /* SYSTEM INITIALIZE */
  GetStructureParameters();
  PrintStructureParameters();
  GetLearningParameters();
  PrintLearningParameters();

  buff_no = (int)GetScalar(0);
  if (buff_no < 0) exit(17);

  buff_write_flag = (buff_no != 0) ? 1 : 0;
  learn_cnt       = 0;
  display_cnt     = 0;

  /* NOTE(review): the original comment on this call was mis-encoded
     Japanese mentioning "ErrBuffer"; it presumably says the error
     buffer (buff_err) is allocated here -- confirm. */
  system_initialize();
  printf("system_initialize is OK\n");

  /* MAKE NETWORK */
  MakeNetwork();

  /*    ReadWeight2(LrnInitWgtFile, LAST);  */
  ReadWeight2(LrnInitWgtFile, LAST + 1);

  /*
   * CreateFile2( WgtHistoryFile, "weight history", WgtStorMode );
   * CreateFile2( ErrHistoryFile, "error history", ErrStorMode );
   */
  CreateFile3(WgtHistoryFile, "weight history", WgtStorMode);
  CreateFile3(ErrHistoryFile, "error history",  ErrStorMode);

  wgt_stor_cnt = NextWgtHistory(WgtHistoryFile);

  /****** Modified by dora 1995/5/30 ******/
  printf("wgt_stor_cnt   = %d\n", wgt_stor_cnt);
  if (wgt_stor_cnt == 0)
    wgt_stor_cnt = 1;
  /****** Modified by dora 1995/5/30 ******/

#if 0
  if (ErrStorDirection == RECORD)
    err_stor_cnt = NextErrHistory(ErrHistoryFile);
  else {
    err_stor_cnt = GetNumOfRecord(ErrHistoryFile);

    printf("ErrHistoryFile = %s\n", ErrHistoryFile);
    printf("err_stor_cnt   = %d\n", err_stor_cnt);

    if (err_stor_cnt == 0) {
      err_stor_cnt  = 1;
      err_datapoint = 0;
    } else {
      err_datapoint = StoreOldError(ErrHistoryFile, err_stor_cnt);
      if (err_datapoint == 0)
	err_stor_cnt++;
    }
  }
#else
  /* modified by take */
  if ( ErrStorMode == APPEND ) {
    if (ErrStorDirection == RECORD)
      err_stor_cnt  = NextErrHistory(ErrHistoryFile);
    else {
      err_stor_cnt  = GetNumOfRecord(ErrHistoryFile) +1;
      err_datapoint = StoreOldError(ErrHistoryFile, err_stor_cnt);
    }
    /* NOTE(review): err_datapoint is only assigned in the non-RECORD
       branch above.  In RECORD direction it still holds its initial 0
       (unless set elsewhere), so this test always resets err_stor_cnt
       to 1 and the NextErrHistory() result is discarded.  Confirm
       whether this check was meant for the non-RECORD branch only. */
    if (err_datapoint == 0)
      err_stor_cnt = 1;
  } else {
    err_datapoint = 0;
    err_stor_cnt  = 1;
  }
  printf("ErrHistoryFile = %s\n", ErrHistoryFile);
  printf("err_stor_cnt   = %d\n", err_stor_cnt);
  /* modified by take */
#endif

  write_syscom();

  /* LEARNING */

  /* Iteration 0: evaluate the initial weights with set_learn(1). */
  SumOfErr = set_learn(1);
  if ((err_stor_cnt  == 1) &&
      (err_datapoint == 0))
    store_error();

  buff_err[display_cnt++] = (double)SumOfErr;

  /* Reported inline rather than via display(): only one value is
     buffered yet, so there is no previous error to subtract and the
     difference is shown as the error itself. */
  printf("\n");
  printf("### ITERATION        ### = %d\n",   learn_cnt);
  printf("### SQUARE'S ERROR   ### = %14e\n", SumOfErr);
  printf("### DIFFERENCE       ### = %14e\n", SumOfErr);
  printf("### COMMENT          ### = %s\n",   Comment);

  if (buff_write_flag){
    idx[0] = display_cnt;
    if(WriteBuffer(buff_no, 1, idx, buff_err) == -1)
      exit(3);
  }

  switch (LearnAlgo) {
  case STEEP:      Steep_method();     break;
  case MOMENTUM:
    if (LearnMode == SET_LEARN)
      setMomentum();
    else
      patternMomentum();
    break;
  case VOGL:       Vogl_method();      break;
  case JACOB:      Jacob_method();     break;
  case MOMENTUM2:  Vogl_coef_method(); break;
  case OCHI:       Ochi_method();      break;
  default:
    /* Previously an unrecognised value fell through silently and the
       run still reported "Learning is done" after zero iterations. */
    printf("\n\t*** Unknown learning algorithm (LearnAlgo = %d): "
           "no learning performed ***\n", (int)LearnAlgo);
    break;
  }

  /* Store/display the final state unless the last iteration already did
     so because learn_cnt was a multiple of the interval. */
  if ((learn_cnt % ErrStorInterval) != 0) store_error();
  if ((learn_cnt % WgtStorInterval) != 0) store_weight();
  if ((learn_cnt % DisplayInterval) != 0) display();

  printf("\n\t*** Learning is done ! ***\n");

  BreakNetwork();
  system_end();

  wrbps();
  write_syscom();
  exit(0);
}
