#include "win32_transcoder.h"

#ifdef _WIN32

#include <assert.h>
#include "../exception.h"
#include "encoding_name.h"
#include "pivot_transcoder.h"


#include <algorithm>

namespace aka2 {
  // Declaration only: the win32_encodings object itself is defined elsewhere.
  // The transcoder factories in this file use it to resolve an encoding name
  // to a code page (get_cp) and to obtain IMultiLanguage instances (get_mlang).
  extern win32_encodings g_win32_encodings_;
}


using namespace aka2;

namespace {
  // Strict weak ordering on win32_encoding entries by their key_ member.
  // The mixed overloads let std::lower_bound / std::upper_bound search a
  // sorted table with a plain key string instead of a full entry.
  struct encoding_less {
    // entry < entry
    bool operator()(const win32_encoding &lhs, const win32_encoding &rhs) const {
      return key_before(lhs.key_, rhs.key_);
    }
    // key < entry
    bool operator()(const std::string &lhs, const win32_encoding &rhs) const {
      return key_before(lhs, rhs.key_);
    }
    // entry < key
    bool operator()(const win32_encoding &lhs, const std::string &rhs) const {
      return key_before(lhs.key_, rhs);
    }

  private:
    // Single place where the actual key comparison happens.
    static bool key_before(const std::string &a, const std::string &b) {
      return a < b;
    }
  };
}

// Definitions of the static members of win32_encodings.

// Known encodings with their code pages; filled by scan_codepages() or
// mlang_scan_codepages() and sorted by key_ in initialize().
win32_encodings::encodings win32_encodings::encodings_;
// Alias entries (name_ -> realname_) filled by scan_aliases() and sorted by
// key_ in initialize().
win32_encodings::encodings win32_encodings::aliases_;
// Class factory for CMultiLanguage, obtained in mlang_initialize() and
// released in uninitialize().
IClassFactory *win32_encodings::factory_ = NULL;
// Module handle of mlang.dll, loaded in mlang_initialize() and freed in
// uninitialize().
HMODULE win32_encodings::mlang_dll_ = NULL;


void win32_encodings::scan_aliases() {
  HKEY charsetKey;
  if (RegOpenKeyExA(HKEY_CLASSES_ROOT, 
                    "MIME\\Database\\Charset",
                    0,
                    KEY_READ,
                    &charsetKey)) {
    // Registry does not have MIME\Database\charset.
    // Nothing can be done.
    return;
  }

  DWORD subkeyIndex;
  for (subkeyIndex = 0; ; ++subkeyIndex) {
    win32_encoding encoding;
    
    char encoding_name[1024];
    DWORD name_size = 1024;
    int res = RegEnumKeyExA(charsetKey,
                            subkeyIndex, encoding_name, &name_size, 
                            0, 0, 0, 0);
    if (res == ERROR_NO_MORE_ITEMS) // enumeration finished.
      break;

    assert(res == ERROR_SUCCESS);

    if (encoding_name[0] == '_')
      continue;

    // Get encoding name subkey.
    encoding.name_ = encoding_name;
    encoding.key_ = create_encoding_key(encoding_name);

    HKEY encodingKey;
    // Must succeed.
    res = RegOpenKeyExA(charsetKey, encoding_name, 0, KEY_READ, &encodingKey);
    assert(res == ERROR_SUCCESS);

    // Check is this key is an alias.
    DWORD value_type = 0;

    UCHAR value_buffer[1024];
    DWORD value_buffer_length = 1024;
    res = RegQueryValueExA(encodingKey, "aliasForCharset", 0, &value_type, value_buffer, &value_buffer_length);
    if (res == ERROR_SUCCESS) {
      // This encoding has aliasForCharset subkey.
      // This is an alias encoding.  Get the real encoding name.
      encoding.realname_ = reinterpret_cast<char*>(value_buffer);
      aliases_.push_back(encoding);
    }
    RegCloseKey(encodingKey);
  }

  RegCloseKey(charsetKey);

}

void win32_encodings::scan_codepages() {
  HKEY charsetKey;
  if (RegOpenKeyExA(HKEY_CLASSES_ROOT, 
                    "MIME\\Database\\Charset",
                    0,
                    KEY_READ,
                    &charsetKey)) {
    // Registry does not have MIME\Database\charset.
    // Nothing can be done.
    return;
  }

  DWORD subkeyIndex;
  for (subkeyIndex = 0; ; ++subkeyIndex) {
    win32_encoding encoding;
    
    char encoding_name[1024];
    DWORD name_size = 1024;
    int res = RegEnumKeyExA(charsetKey,
                            subkeyIndex, encoding_name, &name_size, 
                            0, 0, 0, 0);
    if (res == ERROR_NO_MORE_ITEMS) // enumeration finished.
      break;

    assert(res == ERROR_SUCCESS);

    if (encoding_name[0] == '_')
      continue;

    // Get encoding name subkey.
    encoding.name_ = encoding_name;
    encoding.key_ = create_encoding_key(encoding_name);

    HKEY encodingKey;
    // Must succeed.
    res = RegOpenKeyExA(charsetKey, encoding_name, 0, KEY_READ, &encodingKey);
    assert(res == ERROR_SUCCESS);

    // Check is this key is an alias.
    DWORD value_type = 0;

    UCHAR value_buffer[1024];
    DWORD value_buffer_length = 1024;
    value_type = 0;
    value_buffer_length = sizeof(DWORD);
    res = RegQueryValueExA(encodingKey, "InternetEncoding", 0, &value_type, value_buffer, &value_buffer_length);
    if (res == ERROR_SUCCESS) {
      // This encoding has InternetEncoding subkey.
      // This is an alias encoding.  Get the real encoding name.
      encoding.cp_= *reinterpret_cast<UINT*>(value_buffer);
    }
    else {
      //Internet encoding not found, therefore use encoding.
      value_buffer_length = sizeof(DWORD);
      res = RegQueryValueExA(encodingKey, "Codepage", 0, &value_type, value_buffer, &value_buffer_length);
      if (res == ERROR_SUCCESS) {
        // This encoding has Codepage subkey.
        // This is an alias encoding.  Get the real encoding name.
        encoding.cp_= *reinterpret_cast<UINT*>(value_buffer);
      }
    }
    if (res == ERROR_SUCCESS) {
      encodings_.push_back(encoding);
    }
    RegCloseKey(encodingKey);
  }

  RegCloseKey(charsetKey);
}


void win32_encodings::initialize() {
  encodings_.clear();
  aliases_.clear();
  factory_ = 0;

  scan_aliases();
  if (!mlang_initialize()) {
    assert(!is_mlang_available());
    // mlang not available, use Win32 transcoding functions.
    scan_codepages();
  }
  std::sort(encodings_.begin(), encodings_.end(), encoding_less());
  std::sort(aliases_.begin(), aliases_.end(), encoding_less());
}


extern "C" {
  // Pointer type for the "DllGetClassObject" entry point that
  // mlang_initialize() looks up in mlang.dll with GetProcAddress.
  typedef HRESULT (STDAPICALLTYPE *DllGetClassObjectFunc)(const CLSID &,  const IID &, void **); 
}


int win32_encodings::mlang_initialize() {
  assert(mlang_dll_ == NULL);
  mlang_dll_ = LoadLibrary("mlang.dll");
  if (mlang_dll_ == NULL) {
    return 0; // mlang.dll is not available.
  }

  DllGetClassObjectFunc get_class_object = NULL;
  get_class_object = (DllGetClassObjectFunc)GetProcAddress(mlang_dll_, "DllGetClassObject");
  if (get_class_object == NULL) {
    throw aka2::error("Failed to get DllGetClassObject for mlang.dll", __FILE__, __LINE__);
  }
  HRESULT res = get_class_object(CLSID_CMultiLanguage, IID_IClassFactory, (void**)&factory_);
  if (res != S_OK) {
    throw aka2::error("Failed to create IClassFactory for CMultiLanguage.", __FILE__, __LINE__);
  }
  mlang_scan_codepages();
  return 1;
}

// Drops the encoding tables and releases the MLang resources.
//
// The class factory is released before the DLL that implements it is
// unloaded.  Both handles are reset to NULL afterwards so that calling
// uninitialize() twice is harmless and a later initialize() finds
// mlang_dll_ == NULL as mlang_initialize() expects.
void win32_encodings::uninitialize() {
  encodings_.clear();
  aliases_.clear();
  if (factory_ != NULL) {
    factory_->Release();
    factory_ = NULL;
  }
  if (mlang_dll_ != NULL) {
    FreeLibrary(mlang_dll_);
    mlang_dll_ = NULL;
  }
}

std::string win32_encodings::get_default_encoding() {
  UINT acp = GetACP();
  for (encodings::const_iterator it = encodings_.begin();
    it != encodings_.end(); ++it) {
      if (!it->realname_.empty())
        continue; // ignore alias name.
      if (it->cp_ == acp)
        return it->name_;
    }
  throw error("win32_transcoder: Cannot find default encoding.", __FILE__, __LINE__);
  return std::string();
}

// Resolves an encoding name to a code page.
//
// The name is normalized with create_encoding_key() and looked up in
// encodings_.  If it is not found there, aliases_ is consulted and the
// alias' real name is looked up in encodings_ instead.
// Both tables must be sorted by key_ (done in initialize()).
//
// Returns the code page, or 0 when the name is unknown or when it is an
// alias whose real encoding is not present in encodings_.
UINT win32_encodings::get_cp(const std::string &encoding_name) {
  encoding_less key_less;
  std::string key = create_encoding_key(encoding_name);

  // lower_bound yields the first entry not less than key; it is a match
  // exactly when key is not less than that entry either.
  encodings::const_iterator it =
    std::lower_bound(encodings_.begin(), encodings_.end(), key, key_less);
  if (it != encodings_.end() && !key_less(key, *it))
    return it->cp_;

  it = std::lower_bound(aliases_.begin(), aliases_.end(), key, key_less);
  if (it == aliases_.end() || key_less(key, *it))
    return 0; // neither an encoding nor an alias.

  key = create_encoding_key(it->realname_);
  it = std::lower_bound(encodings_.begin(), encodings_.end(), key, key_less);
  if (it == encodings_.end() || key_less(key, *it))
    return 0; // alias refers to an encoding that is not in the table.
  return it->cp_;
}

bool win32_transcoder::from_inbuf() {

  for (const char *inbuf_current = inbuf_.get_ptr(); inbuf_current < inbuf_.get_end_ptr(); ) {

    int char_len = IsDBCSLeadByteEx(cp_from_, *inbuf_current) ? 2 : 1;
    const char *next_char = inbuf_current + char_len;
    if (inbuf_.get_end_ptr() < next_char) {
      // incomplete input.
      inbuf_.erase_by(inbuf_current);
      return true;
    }

    aka2::uchar_t *pivot_current = ubuffer_.get_end_ptr(1);
    int pivot_len = ubuffer_.get_remain_length();

    int length = MultiByteToWideChar(cp_from_,
                                     MB_ERR_INVALID_CHARS,
                                     inbuf_current, char_len,
                                     pivot_current, 1);
    if (length == 0) {
      assert(GetLastError() == ERROR_NO_UNICODE_TRANSLATION);
      return false; // ill sequence.
    }
    assert(length == 1);
    ubuffer_.commit_additional_length(length);
    inbuf_current += char_len;
  }
  inbuf_.clear();
  return true;
}
  

// Converts the pivot characters from pivot_current up to the end of ubuffer_
// into outbuf_, one character at a time, using WideCharToMultiByte with code
// page cp_to_.  Characters that cannot be converted are handed to
// ucs2_escape().  Clears ubuffer_ when done and always returns true.
bool win32_transcoder::to_outbuf(const aka2::uchar_t *pivot_current) {
  // 'from UTF' part.
  aka2::uchar_t *pivot_end = ubuffer_.get_end_ptr();
  int pivot_length = pivot_end - pivot_current;
  assert(pivot_length != 0);

  while (pivot_current < pivot_end) {
    BOOL used_default = 0;
    // Ask the output buffer for room for one converted character.
    char *dest = outbuf_.get_end_ptr(6);
    const int dest_room = outbuf_.get_remain_length();

    const int written = WideCharToMultiByte(cp_to_,
                                            0,
                                            pivot_current, 1,
                                            dest, dest_room,
                                            0, &used_default);
    if ((written != 0) && !used_default) {
      outbuf_.commit_additional_length(written);
    }
    else {
      // Conversion failed or a default character was substituted.
      ucs2_escape(*pivot_current);
    }
    ++pivot_current;
  }

  ubuffer_.clear();
  return true;
}

// Discards any pivot characters still held in ubuffer_.
void win32_transcoder::reset() {
  ubuffer_.clear();
}

// Creates a win32_transcoder converting from 'fromcode' to 'tocode'.
//
// For UCS-2 and UTF-8 names the corresponding code page is left at 0.  Any
// other name is resolved through g_win32_encodings_.get_cp() and must map to
// the current ANSI code page (GetACP()); otherwise aka2::error is thrown.
// 'tocode' is checked before 'fromcode'.
// The returned object is allocated with new.
transcoder *win32_transcoder::create(const std::string &tocode, const std::string &fromcode) {
  UINT cp_to = 0;
  if (!is_ucs2(tocode) && !is_utf8(tocode)) {
    cp_to = g_win32_encodings_.get_cp(tocode);
    if (cp_to != GetACP()) {
      throw error("win32_transcoder: Cannot handle encoding other than LCP.", __FILE__, __LINE__);
    }
  }

  UINT cp_from = 0;
  if (!is_ucs2(fromcode) && !is_utf8(fromcode)) {
    cp_from = g_win32_encodings_.get_cp(fromcode);
    if (cp_from != GetACP()) {
      throw error("win32_transcoder: Cannot handle encoding other than LCP.", __FILE__, __LINE__);
    }
  }

  return new win32_transcoder(cp_to, cp_from);
}




#include <iostream>


void win32_encodings::mlang_scan_codepages() {
  IMultiLanguage* mlang;
  get_mlang(&mlang);
  IEnumCodePage* enumcp = 0;
  HRESULT hr = mlang->EnumCodePages(0, &enumcp);
  mlang->Release();
  assert(hr == S_OK);

  while (true) {
    MIMECPINFO cpinfo;
    ULONG num_fetched = 0;
    hr = enumcp->Next(1, &cpinfo, &num_fetched);
    if (hr == S_FALSE)
      break;
    assert(hr == S_OK);
    assert(num_fetched == 1);
    win32_encoding encoding;
    encoding.cp_ = cpinfo.uiCodePage;
    encoding.name_ = uniconv::ucs2_to_utf8(cpinfo.wszWebCharset);
    encoding.key_ = create_encoding_key(encoding.name_);      
    //std::cerr << encoding.name_ << " " << encoding.cp_ << std::endl;
    encodings_.push_back(encoding);
  }
  enumcp->Release();
}


// Creates a new IMultiLanguage instance through the class factory obtained
// in mlang_initialize() and stores it in *mlang.  The caller owns the
// returned reference and must Release() it.
//
// Throws aka2::error when the factory is not available or the instance
// cannot be created; *mlang is NULL in that case.
void win32_encodings::get_mlang(IMultiLanguage **mlang) {
  assert(mlang != 0);
  *mlang = 0;
  if (factory_ == NULL) {
    // MLang was never initialized (or has been uninitialized).
    throw error("Cannot create MultiLanguage object.", __FILE__, __LINE__);
  }
  HRESULT hr = factory_->CreateInstance(0, IID_IMultiLanguage,
                                        reinterpret_cast<void**>(mlang));
  if (hr != S_OK || *mlang == 0) {
    *mlang = 0;
    throw error("Cannot create MultiLanguage object.", __FILE__, __LINE__);
  }
}

// Factory entry point: delegates to mlang_transcoder::create() when MLang is
// available and to win32_transcoder::create() otherwise.
transcoder *win32_encodings::create(const std::string &tocode, const std::string &fromcode) {
  if (!win32_encodings::is_mlang_available()) {
    return win32_transcoder::create(tocode, fromcode);
  }
  return mlang_transcoder::create(tocode, fromcode);
}


// Releases the IMultiLanguage reference held by this transcoder.
// A null mlang_ is still reported by the assertion in debug builds, but no
// longer dereferenced when assertions are compiled out.
mlang_transcoder::~mlang_transcoder() {
  assert(mlang_ != 0);
  if (mlang_ != 0) {
    mlang_->Release();
    mlang_ = 0;
  }
}

// Converts the bytes held in inbuf_ to the pivot buffer (ubuffer_), one
// character at a time, using IMultiLanguage::ConvertStringToUnicode with
// code page cp_from_.
//
// Returns false when a character cannot be converted (any result other than
// S_OK, or a conversion that consumes no input).  Returns true otherwise; if
// the input ends in the middle of a double-byte character,
// inbuf_.erase_by() is called at that character and the function returns
// early.  When everything was converted, inbuf_ is cleared.
// The conversion context (from_context_) is only updated after a successful
// conversion.
bool mlang_transcoder::from_inbuf() {

  for (const char *inbuf_current = inbuf_.get_ptr(); inbuf_current < inbuf_.get_end_ptr(); ) {

    UINT char_len = IsDBCSLeadByteEx(cp_from_, static_cast<BYTE>(*inbuf_current)) ? 2 : 1;
    const char *next_char = inbuf_current + char_len;
    if (inbuf_.get_end_ptr() < next_char) {
      // incomplete input.
      inbuf_.erase_by(inbuf_current);
      return true;
    }

    aka2::uchar_t *pivot_current = ubuffer_.get_end_ptr(1);
    UINT pivot_len = ubuffer_.get_remain_length();

    // Work on a copy of the context so that a failed call leaves
    // from_context_ untouched.
    DWORD context_new = from_context_;
    HRESULT res = mlang_->ConvertStringToUnicode(&context_new, cp_from_,
                                                 const_cast<char*>(inbuf_current), &char_len,
                                                 pivot_current, &pivot_len);
    if (res != S_OK) {
      // S_FALSE as well as failure codes such as E_FAIL: the output values
      // are not usable.
      return false; // ill sequence.
    }
    if (char_len == 0) {
      // Nothing was consumed; continuing would never terminate.
      return false;
    }
    assert(pivot_len == 1);
    from_context_ = context_new;
    ubuffer_.commit_additional_length(pivot_len);
    inbuf_current += char_len;
  }
  inbuf_.clear();
  return true;
}
  

// Converts the pivot characters from pivot_current up to the end of ubuffer_
// into outbuf_, one character at a time, using
// IMultiLanguage::ConvertStringFromUnicode with code page cp_to_.
// Characters for which the call does not return S_OK are handed to
// ucs2_escape().  The conversion context (to_context_) is only updated after
// a successful conversion.  Clears ubuffer_ when done and always returns true.
bool mlang_transcoder::to_outbuf(const aka2::uchar_t *pivot_current) {
  // 'from UTF' part.
  aka2::uchar_t *pivot_end = ubuffer_.get_end_ptr();
  UINT pivot_length = pivot_end - pivot_current;
  assert(pivot_length != 0);

  while (pivot_current < pivot_end) {
    // Ask the output buffer for room for one converted character.
    char *dest = outbuf_.get_end_ptr(6);
    UINT dest_size = outbuf_.get_remain_length();

    UINT src_count = 1; // exactly one pivot character per call.
    DWORD pending_context = to_context_;
    const HRESULT hr = mlang_->ConvertStringFromUnicode(&pending_context, cp_to_,
                                                        const_cast<aka2::uchar_t*>(pivot_current),
                                                        &src_count, dest, &dest_size);
    if (hr == S_OK) {
      to_context_ = pending_context;
      outbuf_.commit_additional_length(dest_size);
    }
    else {
      ucs2_escape(*pivot_current);
    }
    ++pivot_current;
  }

  ubuffer_.clear();
  return true;
}

// Returns the transcoder to its initial state: both MLang conversion
// contexts are zeroed and pending pivot characters are discarded.
void mlang_transcoder::reset() {
  from_context_ = 0;
  to_context_ = 0;
  ubuffer_.clear();
}


// Creates an mlang_transcoder converting from 'fromcode' to 'tocode'.
//
// For UCS-2 and UTF-8 names the corresponding code page is left at 0.  Any
// other name is resolved through g_win32_encodings_.get_cp(); aka2::error is
// thrown when the name is unknown.  'tocode' is checked before 'fromcode'.
// A fresh IMultiLanguage instance is obtained for the transcoder, which
// releases it in its destructor.  The returned object is allocated with new.
transcoder *mlang_transcoder::create(const std::string &tocode, const std::string &fromcode) {
  UINT cp_to = 0;
  UINT cp_from = 0;

  if (!is_ucs2(tocode) && !is_utf8(tocode)) {
    cp_to = g_win32_encodings_.get_cp(tocode);
    if (cp_to == 0)
      throw error("mlang_transcoder: Cannot find encoding.", __FILE__, __LINE__);
  }

  if (!is_ucs2(fromcode) && !is_utf8(fromcode)) {
    cp_from = g_win32_encodings_.get_cp(fromcode);
    if (cp_from == 0)
      throw error("mlang_transcoder: Cannot find encoding.", __FILE__, __LINE__);
  }

  IMultiLanguage *mlang = 0;
  g_win32_encodings_.get_mlang(&mlang);

  try {
    return new mlang_transcoder(cp_to, cp_from, mlang);
  }
  catch (...) {
    // The transcoder never took ownership; drop the reference here.
    mlang->Release();
    throw;
  }
}


#endif
