//======================================================================
//-----------------------------------------------------------------------
/**
 * @file		WXProcHook.cpp
 * @brief		Function hook implementation file
 *
 * @author		t.sirayanagi
 * @version		1.0
 *
 * @par			copyright
 * Copyright (C) 2010-2012 Takazumi Shirayanagi\n
 * The new BSD License is applied to this software.
 * see iris_LICENSE.txt
*/
//-----------------------------------------------------------------------
//======================================================================
#define INCG_IRIS_WXProcHook_CPP_

//======================================================================
// include
#include "WXProcHook.h"
#include "WXImageDirectoryEntry.h"
#include "WXDebugHelp.h"

namespace iris {
namespace wx {
namespace dbg
{

//======================================================================
// class
/**********************************************************************//**
 *
 * Constructor.
 *
 * Creates a hook object with no target: every stored pointer starts null.
 *
*//***********************************************************************/
CProcHook::CProcHook(void)
	: m_pfnOrigin (nullptr)
	, m_pfnCurrent(nullptr)
	, m_ppAddress (nullptr)
	, m_pfnPrev   (nullptr)
{
	// Nothing else to set up; a target is attached later through Open().
}

/**********************************************************************//**
 *
 * Destructor.
 *
 * Delegates to Close(), which calls Reset() before clearing the stored
 * pointers, so an open target gets its recorded original function back.
 *
*//***********************************************************************/
CProcHook::~CProcHook(void)
{
	this->Close();
}

/**********************************************************************//**
 *
 * Open a hook target (module name given as a narrow string).
 *
 * Walks the import descriptors of hModule until one whose module name
 * matches pszModuleName (case-insensitive, lstrcmpiA), then scans that
 * descriptor's FirstThunk table for an entry equal to pfnTarget.
 * Any previously opened target is closed first.
 *
 -----------------------------------------------------------------------
 * @param [in]	pszModuleName	= name of the imported module that provides the target function
 * @param [in]	pfnTarget		= pointer to the target function
 * @param [in]	hModule			= module whose import table is searched
 * @return	true when the matching table entry was found and recorded, false otherwise
*//***********************************************************************/
template<>
bool CProcHook::Open<CHAR>(LPCSTR  pszModuleName, PROC pfnTarget, HMODULE hModule)
{
	Close();

	// Offsets stored in the import data are added to the module handle.
	PBYTE const pbImageBase = (PBYTE)hModule;

	CImageDirectoryEntry<IMAGE_DIRECTORY_ENTRY_IMPORT> importDesc((PVOID)hModule, TRUE);
	if( importDesc == nullptr )
	{
		return false;
	}

	// Find the descriptor of the requested module; a zero Name ends the list.
	for( ; importDesc->Name; ++importDesc )
	{
		LPCSTR pszImported = pointer_cast<LPCSTR>(pbImageBase + importDesc->Name);
		if( lstrcmpiA(pszModuleName, pszImported) == 0 )
		{
			break;
		}
	}
	if( importDesc->Name == 0 )
	{
		// Reached the terminating descriptor: the module is not imported.
		return false;
	}

	// Look through the module's thunk table for the slot holding pfnTarget.
	PIMAGE_THUNK_DATA pThunk = pointer_cast<PIMAGE_THUNK_DATA>(pbImageBase + importDesc->FirstThunk);
	for( ; pThunk->u1.Function; ++pThunk )
	{
		PROC* ppfnSlot = pointer_cast<PROC*>(&pThunk->u1.Function);
		if( *ppfnSlot != pfnTarget )
		{
			continue;
		}
		// Remember where the pointer lives and what it originally was.
		m_ppAddress = ppfnSlot;
		m_pfnOrigin = *ppfnSlot;
		return true;
	}
	return false;
}
/// Wide-character variant of CProcHook::Open (see the CHAR specialization).
/// The CHAR specialization compares import module names as narrow strings,
/// so the wide name is converted to the ANSI code page and the call is
/// delegated to it.
/// Returns false when pszModuleName is null, does not fit in MAX_PATH
/// characters, cannot be represented in the ANSI code page, or when the
/// CHAR specialization fails to find the target.
template<>
bool CProcHook::Open<WCHAR>(LPCWSTR pszModuleName, PROC pfnTarget, HMODULE hModule)
{
	Close();
	if( pszModuleName == nullptr ) return false;

	CHAR szModuleName[MAX_PATH];
	BOOL bUsedDefaultChar = FALSE;
	// -1: convert up to and including the terminating null.
	// WC_NO_BEST_FIT_CHARS + bUsedDefaultChar: refuse lossy conversions
	// instead of comparing against a silently altered name.
	const int cchConverted = WideCharToMultiByte(CP_ACP, WC_NO_BEST_FIT_CHARS
		, pszModuleName, -1
		, szModuleName, MAX_PATH
		, nullptr, &bUsedDefaultChar);
	if( cchConverted <= 0 || bUsedDefaultChar ) return false;

	return Open<CHAR>(szModuleName, pfnTarget, hModule);
}

/**********************************************************************//**
 *
 * Open a hook target without naming the providing module.
 *
 * @note	The lookup is delegated to CImageDirectoryEntry::FindProcAddress,
 *			first on the import directory and, if that finds nothing,
 *			on the IAT directory.
 *
 -----------------------------------------------------------------------
 * @param [in]	pfnTarget		= pointer to the target function
 * @param [in]	hModule			= module whose tables are searched
 * @return	true when an entry was found and recorded, false otherwise
*//***********************************************************************/
bool CProcHook::Open(PROC pfnTarget, HMODULE hModule)
{
	Close();

	// First attempt: search through the import directory.
	{
		CImageDirectoryEntry<IMAGE_DIRECTORY_ENTRY_IMPORT> importEntry((PVOID)hModule, TRUE);
		if( PROC* ppfnSlot = importEntry.FindProcAddress(pfnTarget, hModule) )
		{
			m_ppAddress = ppfnSlot;
			m_pfnOrigin = *ppfnSlot;
			return true;
		}
	}

	// Second attempt: search through the IAT directory.
	{
		CImageDirectoryEntry<IMAGE_DIRECTORY_ENTRY_IAT> iatEntry((PVOID)hModule, TRUE);
		if( PROC* ppfnSlot = iatEntry.FindProcAddress(pfnTarget, hModule) )
		{
			m_ppAddress = ppfnSlot;
			m_pfnOrigin = *ppfnSlot;
			return true;
		}
	}

	// Neither table contains the target.
	return false;
}

/**********************************************************************//**
 *
 * Report whether a target has been opened.
 *
 -----------------------------------------------------------------------
 * @return	true when a target slot address is recorded (m_ppAddress is non-null)
*//***********************************************************************/
bool CProcHook::IsOpen(void)
{
	const bool bHasTarget = (m_ppAddress != nullptr);
	return bHasTarget;
}

/**********************************************************************//**
 *
 * 
 *
*//***********************************************************************/
void CProcHook::Close(void)
{
	Reset();
	m_ppAddress = nullptr;
	m_pfnOrigin  = nullptr;
	m_pfnCurrent = nullptr;
}

/**********************************************************************//**
 *
 * Replace the function pointer stored in the opened slot.
 *
 * The slot's page protection is relaxed for the write and restored
 * afterwards. The previous/current pointers are only updated when the
 * write actually succeeded.
 *
 -----------------------------------------------------------------------
 * @param [in]	pfnNew			= new function pointer to install
 * @return	true when the slot was rewritten; false when no target is open
 *			or the write failed
*//***********************************************************************/
bool CProcHook::Replace(PROC pfnNew)
{
	if( !IsOpen() ) return false;

	PROC* const ppfn = m_ppAddress;
	// Value currently in the slot; becomes m_pfnPrev if the write succeeds.
	const PROC pfnBefore = *ppfn;

	DWORD dwOldProtect = 0;
	const BOOL bUnprotected = VirtualProtect(ppfn, sizeof(*ppfn), PAGE_EXECUTE_READWRITE, &dwOldProtect);

	// The write is attempted even if the protection change failed; its own
	// result decides whether the replacement took place.
	const BOOL bWritten = WriteProcessMemory(GetCurrentProcess(), ppfn, &pfnNew, sizeof(pfnNew), nullptr);

	if( bUnprotected )
	{
		// Restore only what was really changed; dwOldProtect is meaningless
		// when the first VirtualProtect call failed.
		DWORD dwIgnored = 0;
		VirtualProtect(ppfn, sizeof(*ppfn), dwOldProtect, &dwIgnored);
	}

	if( !bWritten ) return false;

	m_pfnPrev    = pfnBefore;
	m_pfnCurrent = pfnNew;
	return true;
}

/**********************************************************************//**
 *
 * Put back the function pointer recorded by the most recent Replace().
 *
 -----------------------------------------------------------------------
 * @return	true when no target is open or no previous pointer is recorded;
 *			otherwise the result of Replace(m_pfnPrev)
*//***********************************************************************/
bool CProcHook::Revert(void)
{
	// Nothing to undo: not open, or Replace() has not recorded a previous value.
	if( !IsOpen() || m_pfnPrev == nullptr )
	{
		return true;
	}
	return Replace(m_pfnPrev);
}

/**********************************************************************//**
 *
 * Put back the function pointer that was recorded when the target was opened.
 *
 -----------------------------------------------------------------------
 * @return	true when no target is open or no original pointer is recorded;
 *			otherwise the result of Replace(m_pfnOrigin)
*//***********************************************************************/
bool CProcHook::Reset(void)
{
	// Nothing to restore: not open, or no original pointer was captured.
	if( !IsOpen() || m_pfnOrigin == nullptr )
	{
		return true;
	}
	return Replace(m_pfnOrigin);
}


}	// end of namespace dbg
}	// end of namespace wx
}	// end of namespace iris

#if	(defined(_IRIS_SUPPORT_GOOGLETEST) || defined(_IRIS_UNITTEST) || defined(_IRIS_MULTI_UNITTEST))

//======================================================================
// include
#include "unit/gt/gt_inchead.h"
#include "iris_using.h"

// Stand-in for MessageBoxA used by the hook tests: ignores every argument
// and returns -1, a value the tests check for to prove the hook was called.
static int WINAPI _User_MessageBoxA(HWND hWnd, LPCSTR lpText, LPCSTR lpCaption, UINT uType)
{
	// All parameters are intentionally unused.
	IRIS_UNUSED_VARIABLE(uType);
	IRIS_UNUSED_VARIABLE(lpCaption);
	IRIS_UNUSED_VARIABLE(lpText);
	IRIS_UNUSED_VARIABLE(hWnd);

	const int nHookedResult = -1;
	return nHookedResult;
}

// Hooks MessageBoxA by module/function name, checks the replacement is the
// one that gets called, then opens the same target again with a fresh object.
TEST(ProcHookTest, MessageBox)
{
	// Phase 1: install the replacement and call through it.
	{
		::iris::wx::dbg::CProcHook hook;
		ASSERT_TRUE( hook.Open("user32.dll", "MessageBoxA", ::GetModuleHandle(nullptr)) );
		ASSERT_TRUE( hook.Replace((PROC)_User_MessageBoxA) );

		// _User_MessageBoxA returns -1.
		const int nResult = MessageBoxA(nullptr, "ProcHookTest", "test", 0);
		ASSERT_TRUE( nResult == -1 );
	}	// hook goes out of scope here; its destructor calls Close()

	// Phase 2: opening by name must succeed again
	// (presumably this confirms the slot was restored -- verify against Open's name-based overload).
	{
		::iris::wx::dbg::CProcHook hookAgain;
		ASSERT_TRUE( hookAgain.Open("user32.dll", "MessageBoxA", ::GetModuleHandle(nullptr)) );
	}
}

// Hooks MessageBoxA through the overload that takes only the function pointer
// (no module name), checks the replacement is the one that gets called, then
// opens the target by name with a fresh object.
TEST(ProcHookTest, SearchAll)
{
	// Phase 1: locate the slot from the function pointer alone and hook it.
	{
		::iris::wx::dbg::CProcHook hook;
		ASSERT_TRUE( hook.Open((PROC)MessageBoxA, ::GetModuleHandle(nullptr)) );
		ASSERT_TRUE( hook.Replace((PROC)_User_MessageBoxA) );

		// _User_MessageBoxA returns -1.
		const int nResult = MessageBoxA(nullptr, "ProcHookTest", "test", 0);
		ASSERT_TRUE( nResult == -1 );
	}	// hook goes out of scope here; its destructor calls Close()

	// Phase 2: opening by name must succeed afterwards
	// (presumably this confirms the slot was restored -- verify against Open's name-based overload).
	{
		::iris::wx::dbg::CProcHook hookAgain;
		ASSERT_TRUE( hookAgain.Open("user32.dll", "MessageBoxA", ::GetModuleHandle(nullptr)) );
	}
}

#endif
