#include "stdafx.h"

#include "Som.h"


// Scale applied to the decaying schedule term (eta) when the learning rate
// alpha is computed as eta * S_ALPHA + S_B_UDW.
const double Som::S_ALPHA = 0.5;
// Decay constant of the schedule S_A_UDW / (S_A_UDW + train_count); the same
// constant is used to derive M_c_udw in the constructor.
const double Som::S_A_UDW = 8.0;
// Constant offset added to the learning rate alpha, i.e. the value alpha
// keeps when the schedule term eta has decayed to zero.
const double Som::S_B_UDW = 0.2;


// Builds a map sized relative to the number of inputs.
//
// input_size : number of input points; stored in M_input_size.
// size_rate  : multiplier applied to input_size to obtain the number of
//              cells (the product is truncated toward zero).
//
// The cells are then laid out by reset() on a circle of radius 30 centred
// on the origin.
Som::Som(const int input_size,
         const double &size_rate)
    : M_input_size(input_size),
      M_train_count(0)
{
    // Fixed training budget. An earlier formula that scaled with the input
    // size, 90 + 10 * (M_input_size / 30), is no longer used.
    M_max_train = 6000;

    // Value of S_A_UDW / (S_A_UDW + t) at the last step t = M_max_train - 1.
    M_c_udw = S_A_UDW / (S_A_UDW + M_max_train - 1);

    M_cell_size = static_cast<int>(input_size * size_rate);

    // Initial size of the neighborhood whose weights get corrected:
    // half of the cell count.
    M_neighborhood_size = M_cell_size / 2;

    reset(0.0, 0.0, 30.0);
}



void Som::reset(const double &center_x,
                const double &center_y,
                const double &radius)
{
    double d_rad = (2.0 * 3.1415926535) / (double)M_cell_size;
    int i;

    M_cells.clear();

    double rad = 0.0;
    for ( i = 0; i < M_cell_size; i++ )
    {
        M_cells.push_back(SomCell(radius * cos(rad) + center_x,
                                  radius * sin(rad) + center_y));
        rad += d_rad;
    }
    M_train_count = 0;
}


#if 0
// Earlier variant of the training step, compiled out by the enclosing #if 0.
//
// Moves the cell nearest to the input (x, y) toward it with learning rate
// alpha, then moves its neighbors on both sides (by index, wrapping around
// the ends of the cell list) with a strength that falls off linearly with
// index distance.
//
// x, y : coordinates of the input point.
void Som::update(const double &x,
                 const double &y)
{
    // Stop once the training budget is used up, or when there are no cells.
    if ( M_train_count >= M_max_train
         || M_cells.empty() )
    {
        return;
    }

    int min_id = 0;
    // NOTE(review): fixed sentinel; if every cell is farther than this from
    // the input, min_id stays 0 regardless of which cell is nearest.
    double min_dist2 = 100000;

    // Find the winner: the cell with the smallest squared distance to the input.
    int i = 0;
    std::vector<SomCell>::const_iterator cell_end(M_cells.end());
    for ( std::vector<SomCell>::const_iterator it = M_cells.begin();
          it != cell_end;
          ++it, ++i )
    {
        double d2 = it->dist2(x, y);
        if ( d2 < min_dist2 )
        {
            min_dist2 = d2;
            min_id = i;
        }
    }

    // Schedule term shared by the learning rate and the neighborhood size.
    // It decreases as M_train_count grows and reaches zero at
    // M_train_count == M_max_train - 1 (see M_c_udw in the constructor).
    const double eta
        = S_A_UDW / (S_A_UDW + (double)M_train_count)
        - M_c_udw;

    // Learning rate: decaying part plus the constant floor S_B_UDW.
    const double alpha
        = eta * S_ALPHA + S_B_UDW;

    // Correct the winner's weights: move it toward the input by alpha.
    double x_diff = x - M_cells[min_id].x();
    double y_diff = y - M_cells[min_id].y();
    M_cells[min_id].add(alpha * x_diff,
                        alpha * y_diff);

    // Correct the neighbors' weights, walking outward from the winner in
    // both index directions.
    int left = min_id;
    int right = min_id;
        
    const int neighborhood
        = (int)(eta * M_neighborhood_size) + 2;//1;
    for ( int dist = 1; dist < neighborhood; dist++ )
    {
        // Strength decreases linearly from alpha as dist approaches neighborhood.
        double h = alpha * (1.0 - dist / (double)neighborhood);
        // Step one cell to the left, wrapping past index 0 to the end.
        if ( --left < 0 )
        {
            left += M_cell_size;
        }
        x_diff = x - M_cells[left].x();
        y_diff = y - M_cells[left].y();
        M_cells[left].add(h * x_diff, h * y_diff);
        // Step one cell to the right, wrapping past the end to index 0.
        if ( ++right >= M_cell_size )
        {
            right -= M_cell_size;
        }
        x_diff = x - M_cells[right].x();
        y_diff = y - M_cells[right].y();
        M_cells[right].add(h * x_diff, h * y_diff);
    }
}
#else





void Som::update(const double &x,
                 const double &y)
{
    if ( M_train_count >= M_max_train
         || M_cells.empty() )
    {
        return;
    }

    const double rate = 0.01;
    const double sigma_base = 120.0;

    const double sigma
        = sigma_base
        * exp(-2.0
              * sigma_base
              * ((double)M_train_count / (double)M_max_train));
TRACE2("train %d.  sigma = %f\n", M_train_count, sigma);
    double min_dist2 = 100000;
    SomCell min_cell = M_cells.front();

    // tf[^ɍł߂҃j[߂
    const std::vector<SomCell>::const_iterator cell_end(M_cells.end());
    for ( std::vector<SomCell>::const_iterator it = M_cells.begin();
          it != cell_end;
          ++it )
    {
        double d2 = it->dist2(x, y);
        if ( d2 < min_dist2 )
        {
            min_dist2 = d2;
            min_cell = *it;
        }
    }
TRACE2("champion cell (%.1f %.1f)\n", min_cell.x(), min_cell.y());
    // ߖTj[̏d݂̏C
    const std::vector<SomCell>::iterator cell_normal_end(M_cells.end());
    for ( std::vector<SomCell>::iterator nit = M_cells.begin();
          nit != cell_normal_end;
          ++nit )
    {
        const double neighborhood
            = exp(50.0 * nit->dist2(min_cell) / (-2.0 * sigma * sigma));
TRACE2("dist to champ %.2f,   neighborhood = %f\n",
       nit->dist(min_cell), neighborhood);
        double x_diff = x - nit->x();
        double y_diff = y - nit->y();
        nit->add(rate * neighborhood * x_diff,
                 rate * neighborhood * y_diff);
    }
}



#endif

