// -*-c++-*-

/*!
  \file mark_analyzer.cpp
  \brief mark target analyzer Source File
*/

/*
 *Copyright:

 Copyright (C) Hidehisa AKIYAMA

 This code is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 3, or (at your option)
 any later version.

 This code is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.

 You should have received a copy of the GNU General Public License
 along with this code; see the file COPYING.  If not, write to
 the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.

 *EndCopyright:
 */

/////////////////////////////////////////////////////////////////////

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "mark_analyzer.h"

#include "strategy.h"

#include <rcsc/player/player_agent.h>
#include <rcsc/player/player_predicate.h>
#include <rcsc/player/world_model.h>
#include <rcsc/common/logger.h>
#include <rcsc/timer.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <sstream>
#include <vector>

using namespace rcsc;

// #define DEBUG_PRINT
// #define DEBUG_PROFILE

// #define DEBUG_PRINT_LEVEL_1
// #define DEBUG_PRINT_LEVEL_2
// #define DEBUG_PRINT_LEVEL_3
// #define DEBUG_PRINT_COMBINATION
// #define DEBUG_EVAL

/*-------------------------------------------------------------------*/
/*!
  \brief compare the distance from our goal (with scaled x)
*/
struct TargetPositionCmp {
private:
    static const double goal_x;
    static const double x_scale;
public:
    bool operator()( const MarkAnalyzer::Target * lhs,
                     const MarkAnalyzer::Target * rhs ) const
      {
          //return lhs->pos().x < rhs->pos().x;
          return ( std::pow( lhs->player_->pos().x - goal_x, 2.0 ) * x_scale
                   + std::pow( lhs->player_->pos().y, 2.0 ) )
              < ( std::pow( rhs->player_->pos().x - goal_x, 2.0 ) * x_scale
                  + std::pow( rhs->player_->pos().y, 2.0 ) );
      }
};

const double TargetPositionCmp::goal_x = -52.0;
const double TargetPositionCmp::x_scale = 1.2 * 1.2;

/*-------------------------------------------------------------------*/
/*!
  \brief order opponents by squared distance to a marker's reference position

  The reference point is the marker's pos_ (not the teammate's current
  position), nearest opponent first.
*/
struct OpponentDistCmp {
private:
    //! marker providing the reference point; held by reference
    const MarkAnalyzer::Marker & marker_;
public:
    /*!
      \param m marker whose pos_ is used; it must outlive this comparator
    */
    explicit
    OpponentDistCmp( const MarkAnalyzer::Marker & m )
        : marker_( m )
      { }

    bool operator()( const AbstractPlayerObject * lhs,
                     const AbstractPlayerObject * rhs ) const
      {
          return ( lhs->pos().dist2( marker_.pos_ )
                   < rhs->pos().dist2( marker_.pos_ ) );
      }
};

/*-------------------------------------------------------------------*/
/*!
  \brief compare the distance from our goal (with scaled x)
*/
struct MarkerPositionCmp {
private:
    static const double goal_x;
    static const double x_scale;
public:
    bool operator()( const MarkAnalyzer::Marker & lhs,
                     const MarkAnalyzer::Marker & rhs ) const
      {
          return ( std::pow( lhs.player_->pos().x - goal_x, 2.0 ) * x_scale
                   + std::pow( lhs.player_->pos().y, 2.0 ) )
              < ( std::pow( rhs.player_->pos().x - goal_x, 2.0 ) * x_scale
                  + std::pow( rhs.player_->pos().y, 2.0 ) );
      }
};

const double MarkerPositionCmp::goal_x = -47.0;
const double MarkerPositionCmp::x_scale = 1.5 * 1.5;

/*-------------------------------------------------------------------*/
/*!
  \brief predicate: true if a Marker refers to the given teammate object

  Players are compared by address, not by uniform number.
*/
struct MarkerMatch {
    //! teammate object to look for
    const AbstractPlayerObject * marker_;

    explicit
    MarkerMatch( const AbstractPlayerObject * p )
        : marker_( p )
      { }

    bool operator()( const MarkAnalyzer::Marker * val ) const
      {
          return val->player_ == marker_;
      }
};

/*-------------------------------------------------------------------*/
/*!
  \brief order combinations by ascending score_

  Intended for std::max_element, which then yields the best-scored combination.
*/
struct CombinationCmp {
    bool operator()( const MarkAnalyzer::Combination & lhs,
                     const MarkAnalyzer::Combination & rhs ) const
      {
          return rhs.score_ > lhs.score_;
      }
};

/*-------------------------------------------------------------------*/
/*!

 */


/*-------------------------------------------------------------------*/
/*!
  \brief construct an analyzer that holds no assignment yet

  The pair containers start empty; update() fills them.
*/
MarkAnalyzer::MarkAnalyzer()
{
}

/*-------------------------------------------------------------------*/
/*!
  \brief access the shared analyzer object
  \return reference to a function-local static instance, created on first call
*/
MarkAnalyzer &
MarkAnalyzer::instance()
{
    static MarkAnalyzer the_analyzer;
    return the_analyzer;
}

/*-------------------------------------------------------------------*/
/*!
  \brief check whether a teammate's role is one of the marking roles
  \param unum uniform number of the teammate
  \return true if Strategy maps unum to role number 2, 3, 4 or 5
*/
bool
MarkAnalyzer::isMarkerType( const int unum ) const
{
    const int role = Strategy::i().roleNumber( unum );
    return ( 2 <= role && role <= 5 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief find the opponent assigned to the teammate with the given unum
  \param marker_unum uniform number of the marking teammate
  \return the assigned opponent, or NULL if that teammate has no pair
*/
const
rcsc::AbstractPlayerObject *
MarkAnalyzer::getTargetOf( const int marker_unum ) const
{
    const Pair * pair = getPairOfMarker( marker_unum );
    if ( ! pair )
    {
        return static_cast< rcsc::AbstractPlayerObject * >( 0 );
    }

    return pair->target_;
}

/*-------------------------------------------------------------------*/
/*!
  \brief find the opponent assigned to the given teammate object
  \param marker teammate object (compared by address)
  \return the assigned opponent, or NULL if the teammate has no pair
*/
const
rcsc::AbstractPlayerObject *
MarkAnalyzer::getTargetOf( const rcsc::AbstractPlayerObject * marker ) const
{
    const std::vector< Pair >::const_iterator end = M_pairs.end();
    for ( std::vector< Pair >::const_iterator it = M_pairs.begin();
          it != end;
          ++it )
    {
        if ( it->marker_ == marker )
        {
            return it->target_;
        }
    }

    return static_cast< rcsc::AbstractPlayerObject * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief find the teammate assigned to mark the opponent with the given unum
  \param target_unum uniform number of the marked opponent
  \return the marking teammate, or NULL if that opponent has no pair
*/
const
rcsc::AbstractPlayerObject *
MarkAnalyzer::getMarkerOf( const int target_unum ) const
{
    const Pair * pair = getPairOfTarget( target_unum );
    if ( ! pair )
    {
        return static_cast< rcsc::AbstractPlayerObject * >( 0 );
    }

    return pair->marker_;
}

/*-------------------------------------------------------------------*/
/*!
  \brief find the teammate assigned to mark the given opponent object
  \param target opponent object (compared by address)
  \return the marking teammate, or NULL if the opponent has no pair
*/
const
rcsc::AbstractPlayerObject *
MarkAnalyzer::getMarkerOf( const rcsc::AbstractPlayerObject * target ) const
{
    for ( std::size_t i = 0; i < M_pairs.size(); ++i )
    {
        const Pair & pair = M_pairs[i];
        if ( pair.target_ == target )
        {
            return pair.marker_;
        }
    }

    return static_cast< rcsc::AbstractPlayerObject * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief look up the pair whose marker has the given uniform number
  \param unum uniform number of the marking teammate
  \return pointer into M_pairs (invalidated when update() rebuilds it),
          or NULL if there is no such pair
*/
const
MarkAnalyzer::Pair *
MarkAnalyzer::getPairOfMarker( const int unum ) const
{
    for ( std::size_t i = 0; i < M_pairs.size(); ++i )
    {
        if ( M_pairs[i].marker_->unum() == unum )
        {
            return &M_pairs[i];
        }
    }

    return static_cast< const Pair * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief look up the pair whose target has the given uniform number
  \param unum uniform number of the marked opponent
  \return pointer into M_pairs (invalidated when update() rebuilds it),
          or NULL if there is no such pair
*/
const
MarkAnalyzer::Pair *
MarkAnalyzer::getPairOfTarget( const int unum ) const
{
    for ( std::size_t i = 0; i < M_pairs.size(); ++i )
    {
        if ( M_pairs[i].target_->unum() == unum )
        {
            return &M_pairs[i];
        }
    }

    return static_cast< const Pair * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief look up the uniform-number pair whose marker has the given unum
  \param unum uniform number of the marking teammate
  \return pointer into M_unum_pairs, or NULL if there is no such pair
*/
const
MarkAnalyzer::UnumPair *
MarkAnalyzer::getUnumPairOfMarker( const int unum ) const
{
    for ( std::size_t i = 0; i < M_unum_pairs.size(); ++i )
    {
        if ( M_unum_pairs[i].marker_ == unum )
        {
            return &M_unum_pairs[i];
        }
    }

    return static_cast< const UnumPair * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief look up the uniform-number pair whose target has the given unum
  \param unum uniform number of the marked opponent
  \return pointer into M_unum_pairs, or NULL if there is no such pair
*/
const
MarkAnalyzer::UnumPair *
MarkAnalyzer::getUnumPairOfTarget( const int unum ) const
{
    for ( std::size_t i = 0; i < M_unum_pairs.size(); ++i )
    {
        if ( M_unum_pairs[i].target_ == unum )
        {
            return &M_unum_pairs[i];
        }
    }

    return static_cast< const UnumPair * >( 0 );
}

/*-------------------------------------------------------------------*/
/*!
  \brief recompute the marker/target assignment for the current cycle

  The analysis runs at most once per game time. It builds the mark
  candidates (createMarkTargets), enumerates marker-to-target combinations
  (createCombination), scores them (evaluate) and stores the highest scoring
  one in M_pairs. Pairs whose two players both have a known uniform number
  are also stored in M_unum_pairs. Both containers are left empty before
  kick off, after a goal, in penalty kick modes, when there is no target
  opponent, and when no combination was produced.

  \param wm world model of the current cycle
*/
void
MarkAnalyzer::update( const WorldModel & wm )
{
    static GameTime update_time( 0, 0 );

    // NOTE(review): these work containers are static, presumably to reuse
    // their capacity between cycles; they are cleared at the end of a pass.
    static std::vector< Target > target_opponents;
    static std::vector< Combination > combinations;
    static std::vector< const Marker * > combination_stack;

    // already analyzed in this cycle
    if ( update_time == wm.time() )
    {
        return;
    }

    update_time = wm.time();

    //
    // clear old variables
    //
    M_pairs.clear();

    if ( wm.gameMode().type() == GameMode::BeforeKickOff
         || wm.gameMode().type() == GameMode::AfterGoal_
         || wm.gameMode().isPenaltyKickMode() )
    {
        M_unum_pairs.clear();
        return;
    }

#ifdef DEBUG_PROFILE
    MSecTimer timer;
#endif

    //
    // create mark targets and their candidate markers
    //

    createMarkTargets( wm, target_opponents );

#ifdef DEBUG_PROFILE
    dlog.addText( Logger::MARK,
                  "MarkAnalyzer: create target elapsed %.3f [ms]",
                  timer.elapsedReal() );
#endif

    if ( target_opponents.empty() )
    {
        dlog.addText( Logger::MARK,
                      "MarkAnalyzer: no target opponent" );
        M_unum_pairs.clear();
        return;
    }

    //
    // create the pointer container
    // (the pointers stay valid: target_opponents is not resized until it is
    //  cleared at the end of this function)
    //
    std::vector< Target * > ptr_target_opponents;
    ptr_target_opponents.reserve( target_opponents.size() );
    for ( std::vector< Target >::iterator it = target_opponents.begin();
          it != target_opponents.end();
          ++it )
    {
        ptr_target_opponents.push_back( &(*it) );
    }

    // sort by distance from our goal
    std::sort( ptr_target_opponents.begin(),
               ptr_target_opponents.end(),
               TargetPositionCmp() );

#ifdef DEBUG_PRINT
    for ( std::vector< Target * >::const_iterator o = ptr_target_opponents.begin();
          o != ptr_target_opponents.end();
          ++o )
    {
        for ( std::vector< Marker >::const_iterator m = (*o)->markers_.begin();
              m != (*o)->markers_.end();
              ++m )
        {
            dlog.addText( Logger::MARK,
                          "_ opp %d (%.1f %.1f) : marker=%d (%.1f %.1f) dist=%.2f",
                          (*o)->player_->unum(),
                          (*o)->player_->pos().x, (*o)->player_->pos().y,
                          m->player_->unum(),
                          m->player_->pos().x, m->player_->pos().y,
                          std::sqrt( m->dist2_ ) );
        }
    }
#endif

    //
    // create combinations
    //
#ifdef DEBUG_PROFILE
    // same timer type as above (the original used an inconsistent 'Timer')
    MSecTimer comb_timer;
#endif

    createCombination( ptr_target_opponents.begin(), ptr_target_opponents.end(),
                       combination_stack,
                       combinations );
#ifdef DEBUG_PROFILE
    dlog.addText( Logger::MARK,
                  "MarkAnalyzer: create combination elapsed %.3f [ms]",
                  comb_timer.elapsedReal() );
#endif

    //
    // evaluate all combinations
    //
#ifdef DEBUG_PROFILE
    MSecTimer eval_timer;
#endif

    evaluate( combinations );

#ifdef DEBUG_PROFILE
    dlog.addText( Logger::MARK,
                  "MarkAnalyzer: evaluate elapsed %.3f [ms]",
                  eval_timer.elapsedReal() );
#endif

    // highest score_ wins (CombinationCmp orders by ascending score_)
    std::vector< Combination >::iterator it
        = std::max_element( combinations.begin(),
                            combinations.end(),
                            CombinationCmp() );

    //
    // clear old variables
    //
    M_unum_pairs.clear();

    if ( it != combinations.end() )
    {
        dlog.addText( Logger::MARK,
                      "MarkAnalyzer: total combination= %d, best assignment score= %f",
                      static_cast< int >( combinations.size() ),
                      it->score_ );

        const std::vector< const Marker * >::iterator end = it->markers_.end();
        for ( std::vector< const Marker * >::iterator m = it->markers_.begin();
              m != end;
              ++m )
        {
            M_pairs.push_back( Pair( (*m)->player_, (*m)->target_ ) );
            // uniform-number pairs only for fully identified players
            if ( (*m)->player_->unum() != Unum_Unknown
                 && (*m)->target_->unum() != Unum_Unknown )
            {
                M_unum_pairs.push_back( UnumPair( (*m)->player_->unum(),
                                                  (*m)->target_->unum() ) );
            }

            dlog.addText( Logger::MARK,
                          "<<< pair (%d, %d) marker(%.1f %.1f)(%.1f %.1f) - target(%.1f %.1f)",
                          (*m)->player_->unum(),
                          (*m)->target_->unum(),
                          (*m)->player_->pos().x, (*m)->player_->pos().y,
                          (*m)->pos_.x, (*m)->pos_.y,
                          (*m)->target_->pos().x, (*m)->target_->pos().y );
        }
    }

    target_opponents.clear();
    combinations.clear();
    combination_stack.clear();

#ifdef DEBUG_PROFILE
    dlog.addText( Logger::MARK,
                  "MarkAnalyzer: elapsed %.3f [ms]",
                  timer.elapsedReal() );
#endif
}

/*-------------------------------------------------------------------*/
/*!
  \brief build the mark target opponents and their candidate markers

  Marker teammates are the non-goalie teammates whose role passes
  isMarkerType() and that are not more than 3.0 behind our defense line.
  Each marker's reference position pos_ blends its formation position
  (its current position when the unum is unknown) with its current position.
  For every marker, its nearest opponents (at most max_partial_size, measured
  from pos_) are tested against the x/y/distance thresholds. Every opponent
  that passes receives a copy of the marker, with target_ and dist2_ set, in
  its markers_ list.

  \param wm world model of the current cycle
  \param target_opponents [out] one Target per opponent selected by the
         player predicate; update() passes it in empty
*/
void
MarkAnalyzer::createMarkTargets( const rcsc::WorldModel & wm,
                                 std::vector< Target > & target_opponents )
{
    const double dist_thr2 = 20.0 * 20.0; // squared max marker-opponent distance
    const double x_thr = 12.0; // opponent x must be below marker x + x_thr
    const double y_thr = 10.0; // max |y| difference between marker and opponent
    //const double opponent_max_x = 25.0;
    // lines far apart: their midpoint; otherwise 20m ahead of the defense
    // line, capped at the offside line
    const double opponent_max_x = ( std::fabs( wm.offsideLineX() - wm.defenseLineX() ) > 40.0
                                    ? wm.offsideLineX() * 0.5 + wm.defenseLineX() * 0.5
                                    : std::min( wm.offsideLineX(), wm.defenseLineX() + 20.0 ) );
    // weight of the formation position in a marker's reference position;
    // the current position gets ( 1.0 - weight )
    const double formation_position_weight = 0.5;
    const size_t max_partial_size = 3; // nearest opponents examined per marker

    dlog.addText( Logger::MARK,
                  "MarkAnalyzer::createMarkTargets() opponent_max_x = %.2f",
                  opponent_max_x );

    const Strategy & strategy = Strategy::i();

    //
    // create marker teammates
    //
    std::vector< Marker > markers;
    markers.reserve( wm.allTeammates().size() );
    {
        const AbstractPlayerCont::const_iterator end = wm.allTeammates().end();
        for ( AbstractPlayerCont::const_iterator t = wm.allTeammates().begin();
              t != end;
              ++t )
        {
            if ( (*t)->goalie() ) continue;
            if ( (*t)->pos().x < wm.defenseLineX() - 3.0 ) continue;
            if ( ! isMarkerType( (*t)->unum() ) ) continue;

            Vector2D pos = (*t)->unum() != Unum_Unknown
                ? strategy.getPosition( (*t)->unum() )
                : (*t)->pos();
            pos = ( pos * formation_position_weight )
                + ( (*t)->pos() * ( 1.0 - formation_position_weight ) );

            markers.push_back( Marker( *t, pos ) );
        }

        if ( markers.empty() )
        {
            dlog.addText( Logger::MARK,
                          "MarkAnalyzer::createMarkTargets() no marker" );
            return;
        }
    }
    // sort by distance from our goal
    std::sort( markers.begin(), markers.end(), MarkerPositionCmp() );

    const std::vector< Marker >::iterator m_end = markers.end();

    //
    // create mark target opponents
    //
    AbstractPlayerCont wm_opponents;
    wm_opponents.reserve( 10 );
    wm.getPlayerCont( wm_opponents,
                      new AndPlayerPredicate
                      ( new OpponentOrUnknownPlayerPredicate( wm ),
                        //new FieldPlayerPredicate(),
                        //new NoGhostPlayerPredicate( 20 ),
                        new XCoordinateBackwardPlayerPredicate( opponent_max_x ) ) );

    if ( wm_opponents.empty() )
    {
        dlog.addText( Logger::MARK,
                      "MarkAnalyzer::createMarkTargets() no opponent" );
        return;
    }

    target_opponents.reserve( wm_opponents.size() );
    {
        const AbstractPlayerCont::const_iterator end = wm_opponents.end();
        for ( AbstractPlayerCont::const_iterator o = wm_opponents.begin();
              o != end;
              ++o )
        {
            target_opponents.push_back( Target( *o ) );
        }
    }

    // no further push_back below, so this end iterator stays valid
    const std::vector< Target >::iterator o_end = target_opponents.end();

    if ( target_opponents.empty() )
    {
        // size_t arguments must not be passed to %d: cast explicitly
        dlog.addText( Logger::MARK,
                      "MarkAnalyzer::createMarkTargets() marker_size=%d opponent_size=%d",
                      static_cast< int >( markers.size() ),
                      static_cast< int >( wm_opponents.size() ) );
        return;
    }

    //
    // create mark candidates for each marker teammate
    //
    const size_t partial_size = std::min( max_partial_size, wm_opponents.size() );
    double d2 = 0.0;
    for ( std::vector< Marker >::iterator m = markers.begin();
          m != m_end;
          ++m )
    {
        // bring this marker's nearest opponents to the front
        std::partial_sort( wm_opponents.begin(),
                           wm_opponents.begin() + partial_size,
                           wm_opponents.end(),
                           OpponentDistCmp( *m ) );
#ifdef DEBUG_PRINT_LEVEL_2
        dlog.addText( Logger::MARK,
                      "__  marker %d real_pos=(%.1f %.1f) pos=(%.2f %.2f)",
                      m->player_->unum(),
                      m->player_->pos().x, m->player_->pos().y,
                      m->pos_.x, m->pos_.y );
#endif
        for ( size_t i = 0; i < partial_size; ++i )
        {
#ifdef DEBUG_PRINT_LEVEL_3
            dlog.addText( Logger::MARK,
                          "____  target %d (%.1f %.1f) address=%p",
                          wm_opponents[i]->unum(),
                          wm_opponents[i]->pos().x,
                          wm_opponents[i]->pos().y,
                          static_cast< const void * >( wm_opponents[i] ) );
#endif
            // find the Target entry of this opponent
            for ( std::vector< Target >::iterator o = target_opponents.begin();
                  o != o_end;
                  ++o )
            {
                if ( o->player_ == wm_opponents[i] )
                {
                    if ( o->player_->pos().x < m->pos_.x + x_thr
                         && std::fabs( o->player_->pos().y - m->pos_.y ) < y_thr
                         && ( d2 = o->player_->pos().dist2( m->pos_ ) ) < dist_thr2 )
                    {
#ifdef DEBUG_PRINT_LEVEL_3
                        dlog.addText( Logger::MARK,
                                      "______  mark candidate x_diff=%.3f y_diff=%.3f dist=%.3f",
                                      o->player_->pos().x - m->pos_.x,
                                      o->player_->pos().y - m->pos_.y,
                                      std::sqrt( d2 ) );
#endif
                        // the marker is copied, so each candidate keeps its own target_/dist2_
                        m->target_ = o->player_;
                        m->dist2_ = d2;
                        o->markers_.push_back( *m );
                    }
#ifdef DEBUG_PRINT_LEVEL_3
                    else
                    {
                        dlog.addText( Logger::MARK,
                                      "______  no mark candidate x_diff=%.3f y_diff=%.3f dist=%.3f",
                                      o->player_->pos().x - m->pos_.x,
                                      o->player_->pos().y - m->pos_.y,
                                      o->player_->pos().dist( m->pos_ ) );
                    }
#endif
                    break;
                }
            }
        }
    }
}

/*-------------------------------------------------------------------*/
/*!
  \brief enumerate marker-to-target assignments by depth-first search

  Targets in [first, last) are visited in order. For the current target,
  every candidate marker that is not already on combination_stack is tried
  in turn, and the search recurses into the remaining targets. If no
  candidate marker is free for the current target (including a target with
  no candidates at all), that target is left unmarked and the search
  continues with the next one. A single unmarkable target therefore no
  longer cuts off the assignment of every target after it, and no longer
  yields zero combinations when it comes first. Every complete, non-empty
  assignment is appended to combinations.

  \param first current target
  \param last end of the target range
  \param combination_stack markers chosen so far (restored before returning)
  \param combinations [out] receives the enumerated assignments
*/
void
MarkAnalyzer::createCombination( std::vector< Target * >::const_iterator first,
                                 std::vector< Target * >::const_iterator last,
                                 std::vector< const Marker * > & combination_stack,
                                 std::vector< Combination > & combinations )
{
    if ( first == last )
    {
        // an assignment that marks nobody is not worth evaluating
        if ( combination_stack.empty() )
        {
            return;
        }

        combinations.push_back( Combination() );
        combinations.back().markers_ = combination_stack;

#ifdef DEBUG_PRINT_COMBINATION
        std::ostringstream os;
        for ( std::vector< const Marker * >::const_iterator p = combination_stack.begin();
              p != combination_stack.end();
              ++p )
        {
            os << (*p)->player_->unum() << ' ';
        }
        dlog.addText( Logger::MARK,
                      "-> add combination: %s",
                      os.str().c_str() );
#endif
        return;
    }

    bool assigned = false;

    const std::vector< Marker >::const_iterator m_end = (*first)->markers_.end();
    for ( std::vector< Marker >::const_iterator m = (*first)->markers_.begin();
          m != m_end;
          ++m )
    {
        // a teammate can mark only one opponent
        if ( std::find_if( combination_stack.begin(), combination_stack.end(), MarkerMatch( m->player_ ) )
             == combination_stack.end() )
        {
            assigned = true;
            combination_stack.push_back( &(*m) );
            createCombination( first + 1, last, combination_stack, combinations );
            combination_stack.pop_back();
        }
#ifdef DEBUG_PRINT_COMBINATION
        else
        {
            std::ostringstream os;
            for ( std::vector< const Marker * >::const_iterator p = combination_stack.begin();
                  p != combination_stack.end();
                  ++p )
            {
                os << (*p)->player_->unum() << ' ';
            }
            dlog.addText( Logger::MARK,
                          "xxx cancel: %s (%d)",
                          os.str().c_str(),
                          m->player_->unum() );
        }
#endif
    }

    if ( ! assigned )
    {
#ifdef DEBUG_PRINT_COMBINATION
        dlog.addText( Logger::MARK,
                      "--- skip target %d",
                      (*first)->player_->unum() );
#endif
        // no free candidate marker: leave this target unmarked and go on
        createCombination( first + 1, last, combination_stack, combinations );
    }
}

/*-------------------------------------------------------------------*/
/*!
  \brief score every combination

  score_ = sum_i k_i / dist2_i with k_0 = 1.0 and k_{i+1} = k_i * decay, so
  close pairs contribute more and earlier markers in the combination weigh
  slightly more. dist2_ is clamped to a small positive minimum. Without the
  clamp, a marker standing exactly on its target makes the score +inf; every
  such combination then ties, and the differences between their remaining
  pairs are lost.

  \param combinations [in,out] combinations whose score_ is overwritten
*/
void
MarkAnalyzer::evaluate( std::vector< Combination > & combinations )
{
    const double decay = 0.99;
    // lower bound for a pair's squared distance (avoids k / 0 == +inf)
    const double min_dist2 = 1.0e-6;

    const std::vector< Combination >::iterator end = combinations.end();
    for ( std::vector< Combination >::iterator it = combinations.begin();
          it != end;
          ++it )
    {
        double k = 1.0;
        double score = 0.0;
        const std::vector< const Marker * >::const_iterator m_end = it->markers_.end();
        for ( std::vector< const Marker * >::const_iterator m = it->markers_.begin();
              m != m_end;
              ++m )
        {
            const double d2 = (*m)->dist2_;
            score += k / ( d2 < min_dist2 ? min_dist2 : d2 );
            k *= decay;
        }

        it->score_ = score;
#ifdef DEBUG_EVAL
        std::ostringstream os;
        for ( std::vector< const Marker * >::const_iterator m = it->markers_.begin();
              m != m_end;
              ++m )
        {
            os << '(' << (*m)->player_->unum() << ',' << (*m)->target_->unum() << ')';
        }
        dlog.addText( Logger::MARK,
                      "** eval: %s : %lf",
                      os.str().c_str(),
                      it->score_ );
#endif
    }
}
