/* Copyright (c) 1991-2002 Doshita Lab. Speech Group, Kyoto University */
/* Copyright (c) 2000-2002 Speech and Acoustics Processing Lab., NAIST */
/*   All rights reserved   */

/* gms_gprune.c --- compute GMS HMM with Gaussian pruning */

/* $Id: gms_gprune.c,v 1.2 2002/09/11 22:01:50 ri Exp $ */

#include <sent/stddefs.h>
#include <sent/htk_hmm.h>
#include <sent/htk_param.h>
#include <sent/hmm.h>
#include <sent/gprune.h>
#include "globalvars.h"

/* Compile-time switches for the experimental methods in this file.
   The #ifdef chain below enforces BEAM => LAST_BEST => GS_MAX_PROB. */
#define GS_MAX_PROB		/* score a GS state by its best mixture component only */
#define LAST_BEST		/* evaluate the previous call's best component first */
#undef BEAM			/* beam pruning is OFF; change to #define to enable it */
#define BEAM_OFFSET 10.0	/* margin added to every per-dimension beam threshold */

#ifdef BEAM
#define LAST_BEST
#endif
#ifdef LAST_BEST
#define GS_MAX_PROB
#endif

/* local cache */
static int my_gsset_num;	/* number of GS states, copied in gms_gprune_init() */
static int *last_max_id;	/* per GS state: best mixture index of the previous call, -1 = none */
#ifdef BEAM
static VECT *dimthres;	/* per-dimension thresholds on the accumulated distance, vector order, index 0 first */
static int dimthres_num;	/* number of entries in dimthres */
#endif

/************************************************************************/
/* initialization */
/* Allocate the per-state caches used by compute_gs_scores().
   hmminfo   : HMM definition; only read when BEAM is enabled (vector size).
   gsset_num : number of GS states later passed to compute_gs_scores().
   last_max_id[] is set to -1 ("no previous best") right away, so a
   compute_gs_scores() call made before gms_gprune_prepare() does not depend
   on the allocator's initial memory contents for a mixture index. */
void
gms_gprune_init(HTK_HMM_INFO *hmminfo, int gsset_num)
{
  int i;

  my_gsset_num = gsset_num;
  last_max_id = (int *)mybmalloc(sizeof(int) * gsset_num);
  for (i = 0; i < gsset_num; i++) {
    last_max_id[i] = -1;
  }
#ifdef BEAM
  dimthres_num = hmminfo->opt.vec_size;
  /* dimthres is declared as VECT *, so allocate VECT elements */
  dimthres = (VECT *)mybmalloc(sizeof(VECT) * dimthres_num);
#else
  (void)hmminfo;		/* only needed when BEAM is enabled */
#endif
}

/* Forget the per-state best-mixture history: every entry of last_max_id[]
   becomes -1, so the next compute_gs_scores() call runs the full search
   (starting from the last component) for every GS state.
   NOTE(review): presumably called at the start of each new input — confirm
   with the caller. */
void
gms_gprune_prepare(void)
{
  int i;

  for (i = 0; i < my_gsset_num; i++) {
    last_max_id[i] = -1;
  }
}

/**********************************************************************/
/* compute only max by (safe|beam) pruning */
/* LAST_BEST ... compute the maximum component in last frame first */
/* BEAM ... use beam pruning */
/* Log density of one diagonal Gaussian (binfo) at the observation OP_vec,
   with early termination.
   Returns LOG_ZERO when binfo is NULL, or as soon as the partial sum
   gconst + sum((x-mean)^2/var) exceeds -2*thres, i.e. once the result is
   known to fall below thres (safe as long as variances are positive, since
   later terms can only lower the result).  Otherwise returns
   -(gconst + sum) / 2. */
static LOGPROB
calc_contprob_with_safe_pruning(HTK_HMM_Dens *binfo, LOGPROB thres)
{
  VECT *obs, *mu, *sigma2;
  LOGPROB limit, acc, diff;
  short len;
  int k;

  limit = thres * (-2.0);
  obs = OP_vec;
  len = OP_veclen;

  if (binfo == NULL) {
    return(LOG_ZERO);
  }
  mu = binfo->mean;
  sigma2 = binfo->var->vec;

  acc = binfo->gconst;
  for (k = 0; k < len; k++) {
    diff = obs[k] - mu[k];
    acc += diff * diff / sigma2[k];
    if (acc > limit) {
      return LOG_ZERO;		/* cannot beat thres any more */
    }
  }
  return(acc / -2.0);
}

#ifdef BEAM

/* Beam pass 1, applied to the seed component.
   Returns the log density of binfo at OP_vec (LOG_ZERO if binfo is NULL).
   While accumulating, dimthres[d] is raised to the running sum of
   (x-mean)^2/var over dimensions 0..d if that sum is larger than the value
   already stored.  gconst is added only to the returned value. */
static LOGPROB
calc_contprob_with_beam_pruning_pre(HTK_HMM_Dens *binfo)
{
  VECT *obs, *mu, *sigma2, *limit;
  LOGPROB acc, diff;
  short n;

  if (binfo == NULL) {
    return(LOG_ZERO);
  }

  obs = OP_vec;
  mu = binfo->mean;
  sigma2 = binfo->var->vec;
  limit = dimthres;

  acc = 0.0;
  for (n = OP_veclen; n > 0; n--, obs++, mu++, sigma2++, limit++) {
    diff = *obs - *mu;
    acc += diff * diff / *sigma2;
    if (acc > *limit) {
      *limit = acc;
    }
  }
  return((acc + binfo->gconst) / -2.0);
}
/* Beam pass 2, applied to every non-seed component.
   Returns the log density of binfo at OP_vec, or LOG_ZERO if binfo is NULL
   or if the running sum of (x-mean)^2/var over dimensions 0..d exceeds
   dimthres[d] for some d.  mean/var/th are VECT * to match their sources
   (binfo->mean, binfo->var->vec, dimthres), as in the _pre variant. */
static LOGPROB
calc_contprob_with_beam_pruning_post(HTK_HMM_Dens *binfo)
{
  LOGPROB tmp, x;
  VECT *mean;
  VECT *var;
  VECT *th = dimthres;
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x / *(var++);
    if ( tmp > *(th++)) {
      return LOG_ZERO;		/* outside the beam at this dimension */
    }
  }
  return((tmp + binfo->gconst) / -2.0);
}

#endif /* BEAM */

#ifdef LAST_BEST
/* Score a state by its single best mixture component.
   Returns (best component log density + bweight[best]) / LOGTEN
   (presumably a natural-log to log10 conversion — confirm LOGTEN's value),
   and stores the best component index in *maxi_ret.
   last_maxi: best index from the previous call for this state, or -1.
     != -1 : that component is scored first and its result bounds the others
             (beam thresholds if BEAM, otherwise safe pruning);
     == -1 : the search starts at the last component, using safe pruning. */
static LOGPROB
compute_g_max(HTK_HMM_State *stateinfo, int last_maxi, int *maxi_ret)
{
  int m, best, seed;
  LOGPROB score, best_score;
#ifdef BEAM
  int d;
#endif

  /* score the seed component first; it provides the initial bound */
  if (last_maxi == -1) {
    seed = stateinfo->mix_num - 1;
    best_score = calc_contprob_with_safe_pruning(stateinfo->b[seed], LOG_ZERO);
  } else {
    seed = last_maxi;
#ifdef BEAM
    for (d = 0; d < dimthres_num; d++) dimthres[d] = 0.0;
    best_score = calc_contprob_with_beam_pruning_pre(stateinfo->b[seed]);
    for (d = 0; d < dimthres_num; d++) dimthres[d] += BEAM_OFFSET;
#else
    best_score = calc_contprob_with_safe_pruning(stateinfo->b[seed], LOG_ZERO);
#endif
  }
  best = seed;

  /* scan the remaining components from the last one down to 0 */
  for (m = stateinfo->mix_num - 1; m >= 0; m--) {
    if (m == seed) continue;
#ifdef BEAM
    if (last_maxi != -1) {
      score = calc_contprob_with_beam_pruning_post(stateinfo->b[m]);
    } else {
      score = calc_contprob_with_safe_pruning(stateinfo->b[m], best_score);
    }
#else
    score = calc_contprob_with_safe_pruning(stateinfo->b[m], best_score);
#endif
    if (score > best_score) {
      best_score = score;
      best = m;
    }
  }

  *maxi_ret = best;
  return((best_score + stateinfo->bweight[best]) / LOGTEN);
}
  
#else  /* ~LAST_BEST */
  
/* Score a state by its single best mixture component, scanning from the
   last component down to 0 with safe pruning against the best so far.
   Returns (best log density + bweight[best]) / LOGTEN. */
static LOGPROB
compute_g_max(HTK_HMM_State *stateinfo)
{
  int m, best;
  LOGPROB score, best_score;

  best = stateinfo->mix_num - 1;
  best_score = LOG_ZERO;
  for (m = best; m >= 0; m--) {
    score = calc_contprob_with_safe_pruning(stateinfo->b[m], best_score);
    if (score > best_score) {
      best_score = score;
      best = m;
    }
  }
  return((best_score + stateinfo->bweight[best]) / LOGTEN);
}
#endif

/**********************************************************************/
/* main function: compute all gshmm scores */
/* *** assumes calls come for consecutive frames (each call reuses the previous call's best mixture) */
/* Fill scores_ret[0..gsset_num-1] with one score per GS state.
   When LAST_BEST is defined, each state's best mixture index is stored in
   last_max_id[] and used to seed the next call, so last_max_id must have at
   least gsset_num entries (it is sized in gms_gprune_init()). */
void
compute_gs_scores(GS_SET *gsset, int gsset_num, LOGPROB *scores_ret)
{
  int n;
#ifdef LAST_BEST
  int best;
#endif

  for (n = 0; n < gsset_num; n++) {
#if defined(GS_MAX_PROB) && defined(LAST_BEST)
    /* max-only with pruning, seeded by the previous call's best component */
    scores_ret[n] = compute_g_max(gsset[n].state, last_max_id[n], &best);
    last_max_id[n] = best;
#elif defined(GS_MAX_PROB)
    /* max-only with pruning, no history */
    scores_ret[n] = compute_g_max(gsset[n].state);
#else
    /* full mixture computation */
    scores_ret[n] = compute_g_base(gsset[n].state);
#endif
  }
}
