#include "stdafx.h"
#include "SocketX.h"

//////////////////////////////////////////////////////////////////////////

// Process-exit hook that releases the Winsock library.
// SocketXInitialize() registers this function with atexit().
void SocketXFinalize()
{
	// The result is intentionally discarded: nothing can react to it here.
	(void)WSACleanup();
}

// Starts Winsock 2.2 for the process and arranges for it to be released
// at exit via SocketXFinalize().
//
// Returns TRUE when WSAStartup succeeded and reported exactly version 2.2;
// FALSE otherwise. On FALSE no Winsock reference is left outstanding.
BOOL SocketXInitialize()
{
	const WORD wRequested = MAKEWORD(2, 2);
	WSADATA	   wsaData;

	if (WSAStartup(wRequested, &wsaData) != 0)
	{
		TRACE0("Failed to WSAStartup.\n");
		return FALSE;
	}

	if (wsaData.wVersion != wRequested)
	{
		TRACE0("Failed to WinhSocket version.\n");
		// WSAStartup succeeded, so it must be balanced by WSACleanup even
		// though the negotiated version is rejected.
		WSACleanup();
		return FALSE;
	}

	atexit(SocketXFinalize);

	return TRUE;
}

//////////////////////////////////////////////////////////////////////////

// Run-time class information for IBaseSocket, declared with CWnd as its base.
IMPLEMENT_DYNAMIC(IBaseSocket, CWnd)

// Message map: routes the two private window messages used by this class
// to their handlers.
//   SM_EVENT   - passed to WSAAsyncSelect (see AsyncSelect / doAsyncSelect)
//   SM_GETHOST - passed to WSAAsyncGetHostByName (see AsyncGetHostByName)
BEGIN_MESSAGE_MAP(IBaseSocket, CWnd)
	ON_MESSAGE(SM_EVENT, &IBaseSocket::OnSocketMessage)
	ON_MESSAGE(SM_GETHOST, &IBaseSocket::OnGetHostMessage)
END_MESSAGE_MAP()

// Constructs a socket object with no socket handle, no helper window,
// default send/receive timeouts and all address/message records zeroed.
IBaseSocket::IBaseSocket() :
	m_hSocket(INVALID_SOCKET),
	m_dwError(0),
	m_dwSendTimeout(SX_SEND_TIMEOUT),
	m_dwReceiveTimeout(SX_RECV_TIMEOUT),
	m_lEvent(0),
	m_wBindPort(0),
	m_fBlocking(TRUE),
	m_fWnd(FALSE),
	m_fInit(FALSE)
{
	ZeroMemory(&m_SockMsg, sizeof(m_SockMsg));
	ZeroMemory(&m_HostName.buf, sizeof(m_HostName.buf));
	// The task handle is compared in OnGetHostByName(); give it a defined
	// "no request pending" value instead of leaving it uninitialised.
	m_HostName.hTask = NULL;
	ZeroMemory(&m_LocalNet, sizeof(m_LocalNet));
	ZeroMemory(&m_LocalAddr, sizeof(m_LocalAddr));
	ZeroMemory(&m_PeerAddr, sizeof(m_PeerAddr));
}

// Releases the socket handle (if one is held) and destroys the helper
// window created by CreateSocketWindow() (if any).
IBaseSocket::~IBaseSocket()
{
	// Skip the close when no socket was ever opened: closesocket() on
	// INVALID_SOCKET only fails and overwrites m_dwError with a bogus error.
	if (m_hSocket != INVALID_SOCKET)
	{
		doCloseSocket(m_hSocket);
		m_hSocket = INVALID_SOCKET;
	}

	if (m_fWnd)
	{
		DestroyWindow();
		m_fWnd = FALSE;
	}
}

// Handler for SM_EVENT. wParam carries the socket handle; lParam packs the
// network event and its error code, which are unpacked and forwarded to the
// matching On*() callback. Unrecognised events are ignored. Always returns 0.
LRESULT IBaseSocket::OnSocketMessage(WPARAM wParam, LPARAM lParam)
{
	const SOCKET hSock = (SOCKET)wParam;
	const int nEvent   = WSAGETSELECTEVENT(lParam);
	const int nError   = WSAGETSELECTERROR(lParam);

	if (nEvent == FD_CONNECT)
	{
		OnConnect(hSock, nError);
	}
	else if (nEvent == FD_ACCEPT)
	{
		OnAccept(hSock, nError);
	}
	else if (nEvent == FD_WRITE)
	{
		OnSend(hSock, nError);
	}
	else if (nEvent == FD_READ)
	{
		OnReceive(hSock, nError);
	}
	else if (nEvent == FD_CLOSE)
	{
		OnClose(hSock, nError);
	}

	return 0;
}

// Handler for SM_GETHOST. wParam carries the asynchronous task handle and
// lParam the error code; both are forwarded to OnGetHostByName().
// Always returns 0.
LRESULT IBaseSocket::OnGetHostMessage(WPARAM wParam, LPARAM lParam)
{
	const HANDLE hRequest = (HANDLE)wParam;
	const int nError	  = WSAGETSELECTERROR(lParam);

	OnGetHostByName(hRequest, nError);
	return 0;
}

// Completion callback for an asynchronous host lookup.
//
// hTask      - task handle delivered with the completion message; must match
//              the handle stored by doAsyncGetHostByName().
// nErrorCode - error code delivered with the completion message (0 = none).
//
// Forwards the result buffer, viewed as a HOSTENT, to the registered target
// object. Mismatched handles and a missing target are traced and ignored.
void IBaseSocket::OnGetHostByName(HANDLE hTask, int nErrorCode)
{
	if (m_HostName.hTask != hTask)
	{
		TRACE0("***** ERROR: GetHostByName(Handle) *****\n");
		return;
	}

	if (nErrorCode != 0)
	{
		// Keep the failure visible through the usual error member; the
		// target is still notified, as before.
		m_dwError = static_cast<DWORD>(nErrorCode);
		TRACE1("***** ERROR: GetHostByName(%d) *****\n", nErrorCode);
	}

	// m_SockMsg is zero-filled by the constructor, so the target pointer is
	// NULL until something registers one; do not dereference it blindly.
	if (m_SockMsg.pTgtObj == NULL)
	{
		TRACE0("***** ERROR: GetHostByName(no target) *****\n");
		return;
	}

	LPHOSTENT lpHostEnt = reinterpret_cast<LPHOSTENT>(m_HostName.buf);
	m_SockMsg.pTgtObj->OnGetHostByNameMessage(lpHostEnt);
}

// Creates the window that receives this object's socket messages.
// Any window created by a previous call is destroyed first.
//
// Returns TRUE and sets m_fWnd when the window exists afterwards. Returns
// FALSE when class registration or window creation fails; in the two
// API-failure paths m_dwError receives GetLastError().
BOOL IBaseSocket::CreateSocketWindow()
{
	// Start from a clean slate if a window already exists.
	if (m_fWnd)
	{
		DestroyWindow();
		m_fWnd = FALSE;
	}

	CString strWndClass;

	try
	{
		strWndClass = AfxRegisterWndClass(0);
	}
	catch (CResourceException *pEx)
	{
		TRACE0("***** ERROR: AfxRegisterWndClass *****\n");
		m_dwError = GetLastError();
		pEx->Delete();
		return FALSE;
	}

	if (strWndClass.IsEmpty())
	{
		TRACE0("***** ERROR: CreateSocketWindow *****\n");
		return FALSE;
	}

	CRect rcWnd;
	rcWnd.SetRect(0, 0, 100, 100);

	if (!CreateEx(0, strWndClass, NULL, WS_OVERLAPPEDWINDOW, rcWnd, NULL, 0))
	{
		TRACE0("***** ERROR: CreateEx *****\n");
		m_dwError = GetLastError();
		return FALSE;
	}

	m_fWnd = TRUE;
	return TRUE;
}

// Closes the owned socket.
//
// Returns TRUE when the socket is closed afterwards (including the case
// where there was nothing to close), FALSE when closesocket() failed.
// m_fInit is cleared on success and left set on failure.
//
// NOTE: the previous implementation returned the new m_fInit value, i.e.
// FALSE on success, which contradicts every other BOOL-returning method of
// this class.
BOOL IBaseSocket::CloseSocket()
{
	BOOL fClosed = TRUE;

	if (m_hSocket != INVALID_SOCKET)
	{
		fClosed = doCloseSocket(m_hSocket);

		if (fClosed)
		{
			// Forget the handle so it cannot be closed a second time later
			// (the value may have been reused by another socket by then).
			m_hSocket = INVALID_SOCKET;
		}
	}

	m_fInit = !fClosed;
	return fClosed;
}

// Tells whether the IPv4 address in lpSockAddr lies in the local network
// described by m_LocalNet (address masked with dwMask equals dwNet).
// lpSockAddr is read as a SOCKADDR_IN.
BOOL IBaseSocket::IsLANConnection(const LPSOCKADDR lpSockAddr)
{
	SOCKADDR_IN sinPeer;
	CopyMemory(&sinPeer, lpSockAddr, sizeof(sinPeer));

	if ((m_LocalNet.dwMask & sinPeer.sin_addr.s_addr) == m_LocalNet.dwNet)
	{
		return TRUE;
	}

	return FALSE;
}

// Fills *lpSockAddrIN from its parts.
//
// lpSockAddrIN - structure to fill; fully zeroed first.
// dwAddress    - stored unchanged in sin_addr.s_addr.
// wPort        - port in host byte order; stored via htons().
// nAF          - address family stored in sin_family.
void IBaseSocket::MakeSockAddrIN(LPSOCKADDR_IN lpSockAddrIN, DWORD dwAddress, 
								 WORD wPort, UINT nAF)
{
	// Clear the whole structure. sizeof(LPSOCKADDR_IN) is only the size of a
	// pointer and would leave sin_zero (and more on some targets) untouched.
	ZeroMemory(lpSockAddrIN, sizeof(*lpSockAddrIN));
	lpSockAddrIN->sin_family	   = nAF;
	lpSockAddrIN->sin_port		   = htons(wPort);
	lpSockAddrIN->sin_addr.s_addr  = dwAddress;
}

void IBaseSocket::ConvertSockAddrToNetAddr(const LPSOCKADDR lpSockAddr, NETADDR &NetAddr)
{
	SOCKADDR_IN siAddr;
	CopyMemory(&siAddr, lpSockAddr, sizeof(siAddr));
	NetAddr.wPort	  = ntohs(siAddr.sin_port);
	NetAddr.nAF		  = siAddr.sin_family;
	NetAddr.dwAddress = siAddr.sin_addr.s_addr;
}

// Applies doShutdown() with nFlag to the owned socket and returns its result.
BOOL IBaseSocket::Shutdown(int nFlag)
{
	const BOOL fDone = doShutdown(m_hSocket, nFlag);
	return fDone;
}

// Requests SM_EVENT notifications for the events in lEvent on the owned
// socket. On success the mask is remembered in m_lEvent and the object is
// marked non-blocking; on failure nothing is changed and FALSE is returned.
BOOL IBaseSocket::AsyncSelect(LONG lEvent)
{
	if (!doAsyncSelect(m_hSocket, SM_EVENT, lEvent))
	{
		return FALSE;
	}

	m_fBlocking = FALSE;
	m_lEvent	= lEvent;

	return TRUE;
}

// Starts an asynchronous lookup of szHostName whose result is written into
// m_HostName.buf and announced with an SM_GETHOST message.
// Returns the result of doAsyncGetHostByName().
BOOL IBaseSocket::AsyncGetHostByName(const LPCSTR szHostName)
{
	const int cbResult = sizeof(m_HostName.buf);
	return doAsyncGetHostByName(SM_GETHOST, szHostName, m_HostName.buf, cbResult);
}

// Puts the owned socket back into blocking mode.
//
// Returns TRUE when the socket is in blocking mode afterwards (including
// when it already was), FALSE when the switch failed. On a successful
// switch m_fBlocking is set and the remembered event mask is cleared.
//
// The previous implementation only acted when the socket was ALREADY
// flagged as blocking and returned FALSE unconditionally.
BOOL IBaseSocket::SetBlocking()
{
	if (m_fBlocking)
	{
		// Nothing to do.
		return TRUE;
	}

	if (!doSetBlocking(m_hSocket))
	{
		return FALSE;
	}

	m_fBlocking = TRUE;
	m_lEvent	= 0;

	return TRUE;
}

// Binds the owned socket to dwAddress:wPort via doBind() and returns its result.
BOOL IBaseSocket::Bind(WORD wPort, DWORD dwAddress)
{
	const BOOL fBound = doBind(m_hSocket, wPort, dwAddress);
	return fBound;
}

// Retrieves the local address of the owned socket into *lpSockAddr via
// doGetSockName() and returns its result.
BOOL IBaseSocket::GetSockName(LPSOCKADDR lpSockAddr)
{
	const BOOL fOk = doGetSockName(m_hSocket, lpSockAddr);
	return fOk;
}

// Retrieves the peer address of the owned socket into *lpSockAddr via
// doGetPeerName() and returns its result.
BOOL IBaseSocket::GetPeerName(LPSOCKADDR lpSockAddr)
{
	const BOOL fOk = doGetPeerName(m_hSocket, lpSockAddr);
	return fOk;
}

// Sets a socket option on the owned socket. Note that the public parameter
// order (option, value, length, level) differs from doSetSockOpt(), which
// takes the level first.
BOOL IBaseSocket::SetSockOpt(int nOptionName, LPCSTR lpOptionValue, 
							 int nOptionLen, int nLevel)
{
	const BOOL fSet = doSetSockOpt(m_hSocket, nLevel, nOptionName,
								   lpOptionValue, nOptionLen);
	return fSet;
}

// Reads a socket option of the owned socket. Note that the public parameter
// order (option, value, length, level) differs from doGetSockOpt(), which
// takes the level first.
BOOL IBaseSocket::GetSockOpt(int nOptionName, LPSTR lpOptionValue, 
							 int *lpOptionLen, int nLevel)
{
	const BOOL fGot = doGetSockOpt(m_hSocket, nLevel, nOptionName,
								   lpOptionValue, lpOptionLen);
	return fGot;
}

// Applies dwTimeout as the send timeout of the owned socket via
// doSetSendTimeout() and returns its result.
BOOL IBaseSocket::SetSendTimeout(DWORD dwTimeout)
{
	const BOOL fSet = doSetSendTimeout(m_hSocket, dwTimeout);
	return fSet;
}

// Applies dwTimeout as the receive timeout of the owned socket via
// doSetReceiveTimeout() and returns its result.
BOOL IBaseSocket::SetRecieveTimeout(DWORD dwTimeout)
{
	const BOOL fSet = doSetReceiveTimeout(m_hSocket, dwTimeout);
	return fSet;
}

// Applies dwBufSize as the send buffer size of the owned socket via
// doSetSendBufferSize() and returns its result.
BOOL IBaseSocket::SetSendBufferSize(DWORD dwBufSize)
{
	const BOOL fSet = doSetSendBufferSize(m_hSocket, dwBufSize);
	return fSet;
}

// Applies dwBufSize as the receive buffer size of the owned socket via
// doSetReceiveBufferSize() and returns its result.
BOOL IBaseSocket::SetReceiveBufferSize(DWORD dwBufSize)
{
	const BOOL fSet = doSetReceiveBufferSize(m_hSocket, dwBufSize);
	return fSet;
}

// Enables keep-alive on the owned socket with the given time and interval
// via doSetKeepAlive() and returns its result.
BOOL IBaseSocket::SetKeepAlive(DWORD dwTime, DWORD dwInterval)
{
	const BOOL fSet = doSetKeepAlive(m_hSocket, dwTime, dwInterval);
	return fSet;
}

// Sets or clears SO_REUSEADDR on the owned socket via doSetReuseAddr()
// and returns its result.
BOOL IBaseSocket::SetReuseAddr(BOOL flag)
{
	const BOOL fSet = doSetReuseAddr(m_hSocket, flag);
	return fSet;
}

// Issues the ioctl lCommand with *lpArgument on the owned socket via
// doIOCtlSocket() and returns its result.
BOOL IBaseSocket::IOCtlSocket(LONG lCommand, DWORD *lpArgument)
{
	const BOOL fOk = doIOCtlSocket(m_hSocket, lCommand, lpArgument);
	return fOk;
}

// Issues a WSAIoctl request on the owned socket. All arguments are passed
// through unchanged to doWSAIOCtl(), whose result is returned.
BOOL IBaseSocket::WSAIOCtl(DWORD dwIoCtlCode, LPVOID lpInBuf, DWORD dwInBuf, 
						   LPVOID lpOutBuf, DWORD dwOutBuf, LPDWORD dwByteReturned, 
						   LPWSAOVERLAPPED lpOverlapped, 
						   LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine)
{
	const BOOL fOk = doWSAIOCtl(m_hSocket, dwIoCtlCode,
								lpInBuf, dwInBuf,
								lpOutBuf, dwOutBuf,
								dwByteReturned, lpOverlapped, lpCompletionRoutine);
	return fOk;
}

// Converts the stored local address (m_LocalAddr) into NetAddr.
void IBaseSocket::GetLocalAddress(NETADDR &NetAddr)
{
	const LPSOCKADDR lpLocal = &m_LocalAddr;
	ConvertSockAddrToNetAddr(lpLocal, NetAddr);
}

// Converts the stored peer address (m_PeerAddr) into NetAddr.
void IBaseSocket::GetPeerAddress(NETADDR &NetAddr)
{
	const LPSOCKADDR lpPeer = &m_PeerAddr;
	ConvertSockAddrToNetAddr(lpPeer, NetAddr);
}

// Sends data, framed by doSend(), over the owned socket and returns
// doSend()'s result.
BOOL IBaseSocket::Send(const CByteArray &data)
{
	const BOOL fSent = doSend(m_hSocket, data);
	return fSent;
}

// Receives one framed message from the owned socket into data via
// doReceive() and returns its result.
BOOL IBaseSocket::Receive(CByteArray &data)
{
	const BOOL fReceived = doReceive(m_hSocket, data);
	return fReceived;
}

// Sends data as a datagram from the owned socket to lpSockAddr via
// doSendTo() and returns its result.
BOOL IBaseSocket::SendTo(const CByteArray &data, const LPSOCKADDR lpSockAddr, int nSockAddrLen)
{
	const BOOL fSent = doSendTo(m_hSocket, data, lpSockAddr, nSockAddrLen);
	return fSent;
}

// Receives a datagram on the owned socket into data, reporting the sender
// through lpSockAddr / lpSockAddrLen, via doReceiveFrom(); returns its result.
BOOL IBaseSocket::ReceiveFrom(CByteArray &data, LPSOCKADDR lpSockAddr, int *lpSockAddrLen)
{
	const BOOL fReceived = doReceiveFrom(m_hSocket, data, lpSockAddr, lpSockAddrLen);
	return fReceived;
}

//////////////////////////////////////////////////////////////////////////

// Closes hSocket.
//
// Returns TRUE on success. On failure returns FALSE and stores the error
// code in m_dwError. On success the remembered event mask is cleared and,
// when hSocket is the owned socket, m_hSocket is reset to INVALID_SOCKET.
BOOL IBaseSocket::doCloseSocket(SOCKET hSocket)
{
	if (closesocket(hSocket) != 0)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: closesocket(%d) *****\n", m_dwError);
		return FALSE;
	}

	// hSocket is a by-value copy; assigning to it (as the old code did) has
	// no effect. Invalidate the member instead so the stale handle is never
	// closed or used again.
	if (hSocket == m_hSocket)
	{
		m_hSocket = INVALID_SOCKET;
	}

	m_lEvent = 0;

	return TRUE;
}

// Calls shutdown() on hSocket with nFlag.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doShutdown(SOCKET hSocket, int nFlag)
{
	if (shutdown(hSocket, nFlag) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: shutdown(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Calls WSAAsyncSelect() so that the events in lEvent on hSocket are
// reported to this object's window as message uMsg. Passing lEvent == 0
// cancels notification.
//
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doAsyncSelect(SOCKET hSocket, UINT uMsg, LONG lEvent)
{
	// Honour the caller's message id; the old code ignored uMsg and always
	// registered SM_EVENT.
	if (WSAAsyncSelect(hSocket, m_hWnd, uMsg, lEvent) != 0)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: WSAAsyncSelect(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Starts an asynchronous host-name lookup.
//
// uMsg       - window message posted to this object's window on completion.
// szHostName - name to resolve.
// szHostEnt  - result buffer; zeroed here before the request is issued.
// nBufLen    - size of szHostEnt in bytes.
//
// The task handle is kept in m_HostName.hTask so the completion handler can
// match it. Returns TRUE when the request was started; on failure returns
// FALSE and stores the error code in m_dwError.
BOOL IBaseSocket::doAsyncGetHostByName(UINT uMsg, const LPCSTR szHostName, LPSTR szHostEnt, int nBufLen)
{
	ZeroMemory(szHostEnt, nBufLen);

	m_HostName.hTask = WSAAsyncGetHostByName(m_hWnd, uMsg, szHostName, szHostEnt, nBufLen);

	// WSAAsyncGetHostByName returns a non-zero task handle on success and
	// zero on failure; the old test was inverted.
	if (m_HostName.hTask == 0)
	{
		m_dwError = GetLastError();
		TRACE1("***** ERROR: WSAAsyncGetHostByName(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Switches hSocket to blocking mode: first cancels asynchronous event
// notification, then clears FIONBIO.
// Returns TRUE only when both steps succeed.
BOOL IBaseSocket::doSetBlocking(SOCKET hSocket)
{
	// The two steps must run in this order: clearing FIONBIO is rejected
	// while WSAAsyncSelect notification is still active. The old code joined
	// the calls with bitwise '&', whose operand evaluation order is
	// unspecified.
	if (!doAsyncSelect(hSocket, 0, 0))
	{
		return FALSE;
	}

	DWORD dwNonBlocking = 0;
	return doIOCtlSocket(hSocket, FIONBIO, &dwNonBlocking);
}

// Binds hSocket to dwAddress:wPort (wPort in host byte order; see
// MakeSockAddrIN). On success the port is remembered in m_wBindPort.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doBind(SOCKET hSocket, WORD wPort, DWORD dwAddress)
{
	SOCKADDR_IN siAddr;

	MakeSockAddrIN(&siAddr, dwAddress, wPort);

	if (bind(hSocket, (LPSOCKADDR)&siAddr, sizeof(siAddr)) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: bind(%d) *****\n", m_dwError);
		return FALSE;
	}

	m_wBindPort = wPort;

	return TRUE;
}

// Retrieves the local address of hSocket into *lpSockAddr, which is
// treated as a buffer of sizeof(SOCKADDR) bytes.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doGetSockName(SOCKET hSocket, LPSOCKADDR lpSockAddr)
{
	int nLenSA = sizeof(SOCKADDR);

	if (getsockname(hSocket, lpSockAddr, &nLenSA) != 0)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: getsockname(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Retrieves the peer address of hSocket into *lpSockAddr, which is
// treated as a buffer of sizeof(SOCKADDR) bytes.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doGetPeerName(SOCKET hSocket, LPSOCKADDR lpSockAddr)
{
	int nLenSA = sizeof(SOCKADDR);

	if (getpeername(hSocket, lpSockAddr, &nLenSA) != 0)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: getpeername(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Calls setsockopt() on hSocket with the given level, option, value and
// value length.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doSetSockOpt(SOCKET hSocket, int nLevel, int nOptionName,
							   LPCSTR lpOptionValue, int nOptionLen)
{
	if (setsockopt(hSocket, nLevel, nOptionName, lpOptionValue, nOptionLen) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE2("***** ERROR: setsockopt(%d) option(%d) *****\n", m_dwError, nOptionName);
		return FALSE;
	}

	return TRUE;
}

// Calls getsockopt() on hSocket with the given level and option; the value
// is written to lpOptionValue and its length to *lpOptionLen.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doGetSockOpt(SOCKET hSocket, int nLevel, int nOptionName, 
							   LPSTR lpOptionValue, int *lpOptionLen)
{
	if (getsockopt(hSocket, nLevel, nOptionName, lpOptionValue, lpOptionLen) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE2("***** ERROR: getsockopt(%d) option(%d) *****\n", m_dwError, nOptionName);
		return FALSE;
	}

	return TRUE;
}

// Sets SO_SNDTIMEO on hSocket to dwTimeout through doSetSockOpt().
BOOL IBaseSocket::doSetSendTimeout(SOCKET hSocket, DWORD dwTimeout)
{
	const int cbValue = sizeof(dwTimeout);
	return doSetSockOpt(hSocket, SOL_SOCKET, SO_SNDTIMEO,
						reinterpret_cast<LPSTR>(&dwTimeout), cbValue);
}

// Sets SO_RCVTIMEO on hSocket to dwTimeout through doSetSockOpt().
BOOL IBaseSocket::doSetReceiveTimeout(SOCKET hSocket, DWORD dwTimeout)
{
	const int cbValue = sizeof(dwTimeout);
	return doSetSockOpt(hSocket, SOL_SOCKET, SO_RCVTIMEO,
						reinterpret_cast<LPSTR>(&dwTimeout), cbValue);
}

// Sets SO_SNDBUF on hSocket to dwBufSize through doSetSockOpt().
BOOL IBaseSocket::doSetSendBufferSize(SOCKET hSocket, DWORD dwBufSize)
{
	const int cbValue = sizeof(dwBufSize);
	return doSetSockOpt(hSocket, SOL_SOCKET, SO_SNDBUF,
						reinterpret_cast<LPSTR>(&dwBufSize), cbValue);
}

// Sets SO_RCVBUF on hSocket to dwBufSize through doSetSockOpt().
BOOL IBaseSocket::doSetReceiveBufferSize(SOCKET hSocket, DWORD dwBufSize)
{
	const int cbValue = sizeof(dwBufSize);
	return doSetSockOpt(hSocket, SOL_SOCKET, SO_RCVBUF,
						reinterpret_cast<LPSTR>(&dwBufSize), cbValue);
}

// Sets SO_REUSEADDR on hSocket to flag through doSetSockOpt().
BOOL IBaseSocket::doSetReuseAddr(SOCKET hSocket, BOOL flag)
{
	const int cbValue = sizeof(flag);
	return doSetSockOpt(hSocket, SOL_SOCKET, SO_REUSEADDR,
						reinterpret_cast<LPSTR>(&flag), cbValue);
}

// Turns keep-alive on for hSocket: first enables SO_KEEPALIVE, then sends
// dwKaTime / dwInterval as the keep-alive time and interval with the
// SIO_KEEPALIVE_VALS ioctl.
// Returns TRUE when both steps succeed, FALSE as soon as one fails.
BOOL IBaseSocket::doSetKeepAlive(SOCKET hSocket, DWORD dwKaTime, DWORD dwInterval)
{
	BOOL fEnable = TRUE;

	if (!doSetSockOpt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (LPSTR)&fEnable, sizeof(fEnable)))
	{
		return FALSE;
	}

	struct tcp_keepalive kaRequest;
	struct tcp_keepalive kaReply;
	DWORD cbReturned;

	ZeroMemory(&kaRequest, sizeof(kaRequest));
	ZeroMemory(&kaReply, sizeof(kaReply));

	kaRequest.onoff				= 1;
	kaRequest.keepalivetime		= dwKaTime;
	kaRequest.keepaliveinterval = dwInterval;

	const BOOL fApplied = doWSAIOCtl(hSocket, SIO_KEEPALIVE_VALS,
									 &kaRequest, sizeof(kaRequest),
									 &kaReply, sizeof(kaReply),
									 &cbReturned, NULL, NULL);

	return fApplied ? TRUE : FALSE;
}

// Calls ioctlsocket() on hSocket with lCommand and *lpArgument.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doIOCtlSocket(SOCKET hSocket, LONG lCommand, DWORD *lpArgument)
{
	if (ioctlsocket(hSocket, lCommand, lpArgument) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: ioctlsocket(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Calls WSAIoctl() on hSocket, passing every argument through unchanged.
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doWSAIOCtl(SOCKET hSocket, DWORD dwIoCtlCode, LPVOID lpInBuf, DWORD dwInBuf, LPVOID lpOutBuf, 
							 DWORD dwOutBuf, LPDWORD dwByteReturned, LPWSAOVERLAPPED lpOverlapped, 
							 LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine)
{
	if (WSAIoctl(hSocket, dwIoCtlCode, lpInBuf, dwInBuf, lpOutBuf, 
		dwOutBuf, dwByteReturned, lpOverlapped, lpCompletionRoutine) == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: WSAIoctl(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Sends data over hSocket as one framed message: a HEADER whose size field
// holds the payload length, followed by the payload bytes.
//
// While sending, FD_WRITE notification is switched off (when it is part of
// the active event mask) and switched back on afterwards. For a socket other
// than m_hSocket the mask FD_READ | FD_WRITE | FD_CLOSE is assumed.
// WSAEWOULDBLOCK is retried immediately; any other send() error stops the
// transfer. Returns TRUE when the whole frame was sent, FALSE otherwise
// (m_dwError holds the error code).
BOOL IBaseSocket::doSend(SOCKET hSocket, const CByteArray &data)
{
	const LONG lMask = (hSocket == m_hSocket) ? m_lEvent : (FD_READ | FD_WRITE | FD_CLOSE);
	const bool fToggleWrite = (lMask != 0) && ((lMask & FD_WRITE) != 0);

	if (fToggleWrite)
	{
		// Suspend FD_WRITE while this call drives the socket itself.
		doAsyncSelect(hSocket, SM_EVENT, (lMask ^ FD_WRITE));
	}

	// Build the frame: header first, payload right behind it.
	const size_t cbPayload = data.GetSize();
	size_t cbRemaining	   = cbPayload + sizeof(HEADER);

	CByteArray frame;
	frame.SetSize(cbRemaining);
	((LPHEADER)frame.GetData())->size = cbPayload;
	CopyMemory((frame.GetData() + sizeof(HEADER)), data.GetData(), cbPayload);
	frame.FreeExtra();

	BOOL   fOk	   = TRUE;
	size_t cbSent  = 0;

	do
	{
		const int nStep = send(hSocket, reinterpret_cast<const LPSTR>(frame.GetData() + cbSent), cbRemaining, 0);

		if (nStep != SOCKET_ERROR)
		{
			cbSent		+= nStep;
			cbRemaining -= nStep;
			continue;
		}

		m_dwError = GetLastError();

		if (m_dwError != WSAEWOULDBLOCK)
		{
			TRACE1("***** ERROR: send(%d) *****\n", m_dwError);
			fOk = FALSE;
			break;
		}
		// WSAEWOULDBLOCK: fall through to the loop condition and try again.
	}
	while (cbRemaining > 0);

	if (fToggleWrite)
	{
		// Restore the original notification mask.
		doAsyncSelect(hSocket, SM_EVENT, lMask);
	}

	return fOk;
}

// Receives one framed message (HEADER + payload, as produced by doSend())
// from hSocket and leaves only the payload in data.
//
// While receiving, FD_READ notification is switched off (when it is part of
// the active event mask) and switched back on afterwards. For a socket other
// than m_hSocket the mask FD_READ | FD_WRITE | FD_CLOSE is assumed.
//
// Returns TRUE when a complete frame was read. Returns FALSE when the
// header could not be peeked in full, when recv() fails with anything other
// than WSAEWOULDBLOCK, or when the peer closes the connection mid-frame;
// m_dwError then holds the error code (WSAEDISCON for a peer close).
BOOL IBaseSocket::doReceive(SOCKET hSocket, CByteArray &data)
{
	HEADER header;
	size_t recvSize;
	size_t dataSize;
	size_t bufPtr;
	BOOL   fResult;
	LONG   lEvent;
	int	   nRet;

	lEvent = (hSocket == m_hSocket) ? m_lEvent : (FD_READ | FD_WRITE | FD_CLOSE);

	if ((lEvent != 0) && (lEvent & FD_READ))
	{
		// FD_READ cancel
		doAsyncSelect(hSocket, SM_EVENT, (lEvent ^ FD_READ));
	}

	// Peek the header without consuming it; it is read again together with
	// the payload below. The whole HEADER is the destination, since
	// sizeof(HEADER) bytes are requested.
	nRet = recv(hSocket, reinterpret_cast<LPSTR>(&header), sizeof(HEADER), MSG_PEEK);

	if (nRet == sizeof(HEADER))
	{
		// NOTE(review): header.size comes straight from the peer and is not
		// range-checked before it drives the allocation below - consider
		// enforcing a protocol maximum.
		dataSize = header.size + sizeof(HEADER);
		recvSize = dataSize;
		bufPtr   = 0;

		data.RemoveAll();
		data.SetSize(dataSize);
		fResult = TRUE;

		do
		{
			// Receive data
			nRet = recv(hSocket, reinterpret_cast<LPSTR>(data.GetData() + bufPtr), recvSize, 0);

			if (nRet == SOCKET_ERROR)
			{
				m_dwError = GetLastError();

				if (m_dwError == WSAEWOULDBLOCK)
				{
					continue;
				}
				else
				{
					TRACE1("***** ERROR: recv - Data(%d) *****\n", m_dwError);
					fResult = FALSE;
					break;
				}
			}
			else if (nRet == 0)
			{
				// recv() returns 0 when the peer has closed the connection.
				// No more bytes will ever arrive, so without this branch the
				// loop would spin forever.
				m_dwError = WSAEDISCON;
				TRACE1("***** ERROR: recv - Data(%d) *****\n", m_dwError);
				fResult = FALSE;
				break;
			}
			else
			{
				bufPtr   += nRet;
				recvSize -= nRet;
			}
		}
		while (recvSize > 0);

		// Header remove
		if (data.GetSize() >= sizeof(HEADER))
		{
			data.RemoveAt(0, sizeof(HEADER));
			data.FreeExtra();
		}
	}
	else
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value. A return of 0 means the peer closed the
		// connection, for which the last-error value is not meaningful.
		m_dwError = (nRet == 0) ? WSAEDISCON : GetLastError();
		TRACE1("***** ERROR: recv - Header(%d) *****\n", m_dwError);
		fResult = FALSE;
	}

	if ((lEvent != 0) && (lEvent & FD_READ))
	{
		// FD_READ setting
		doAsyncSelect(hSocket, SM_EVENT, lEvent);
	}

	return fResult;
}

// Sends a datagram of exactly SX_UDP_BUFSIZE bytes from hSocket to
// lpSockAddr.
//
// When data holds fewer than SX_UDP_BUFSIZE bytes the datagram is padded
// with zero bytes; when it holds more, only the first SX_UDP_BUFSIZE bytes
// are sent. Returns TRUE on success; on failure returns FALSE and stores
// the error code in m_dwError.
BOOL IBaseSocket::doSendTo(SOCKET hSocket, const CByteArray &data, const LPSOCKADDR lpSockAddr, int nSockAddrLen)
{
	const BYTE *lpPayload = data.GetData();
	CByteArray	padded;

	// The old code always handed SX_UDP_BUFSIZE bytes of data's buffer to
	// sendto(), reading past the end of shorter arrays. Keep the datagram
	// size unchanged on the wire but source the bytes from a zero-padded copy.
	if (data.GetSize() < SX_UDP_BUFSIZE)
	{
		padded.SetSize(SX_UDP_BUFSIZE);
		ZeroMemory(padded.GetData(), SX_UDP_BUFSIZE);

		if (data.GetSize() > 0)
		{
			CopyMemory(padded.GetData(), data.GetData(), data.GetSize());
		}

		lpPayload = padded.GetData();
	}

	const int nRet = sendto(hSocket, reinterpret_cast<const char *>(lpPayload),
							SX_UDP_BUFSIZE, 0, lpSockAddr, nSockAddrLen);

	if (nRet == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: sendto(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Receives one datagram of up to SX_UDP_BUFSIZE bytes on hSocket.
//
// data          - resized to SX_UDP_BUFSIZE if it has any other size; it is
//                 NOT shrunk to the number of bytes actually received.
// lpSockAddr    - receives the sender's address (may be NULL).
// lpSockAddrLen - in/out length of *lpSockAddr. When NULL while lpSockAddr
//                 is given, sizeof(SOCKADDR) is assumed.
//
// Returns TRUE on success; on failure returns FALSE and stores the error
// code in m_dwError.
BOOL IBaseSocket::doReceiveFrom(SOCKET hSocket, CByteArray &data, LPSOCKADDR lpSockAddr, int *lpSockAddrLen)
{
	int	 nRet;
	int	 nLenSA	   = sizeof(SOCKADDR);
	int *lpFromLen = lpSockAddrLen;

	// recvfrom() rejects an address buffer without a length; fall back to
	// the local default (the same size doGetSockName() assumes).
	if ((lpSockAddr != NULL) && (lpFromLen == NULL))
	{
		lpFromLen = &nLenSA;
	}

	if (data.GetSize() != SX_UDP_BUFSIZE)
	{
		data.RemoveAll();
		data.SetSize(SX_UDP_BUFSIZE);
	}

	nRet = recvfrom(hSocket, reinterpret_cast<LPSTR>(data.GetData()),
					SX_UDP_BUFSIZE, 0, lpSockAddr, lpFromLen);

	if (nRet == SOCKET_ERROR)
	{
		// Read the error once, before tracing, so the stored and the logged
		// code are the same value (same order as doSend()).
		m_dwError = GetLastError();
		TRACE1("***** ERROR: recvfrom(%d) *****\n", m_dwError);
		return FALSE;
	}

	data.FreeExtra();

	return TRUE;
}

//////////////////////////////////////////////////////////////////////////

