#pragma once

/////////////////////////////////////////////////////////////////////////////////
//
// The MIT License
// 
// Copyright (c) 2010 Yoshida Shoichi
// (yoshida.sho1@gmail.com)
// 
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// 
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// 
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//
/////////////////////////////////////////////////////////////////////////////////

#include <afx.h>
#include <afxwin.h>
#include <afxtempl.h>	// For CArray
#include <afxmt.h>		// For CCriticalSection

#include <WinSock2.h>
#include <MSTcpIP.h>	// For tcp_keepalive, promiscuous
#include <memory>		// For shard_ptr

#pragma comment(lib, "ws2_32")

// Private window message IDs. They are WM_APP-relative, which is the range
// Windows reserves for application-private messages.
// NOTE(review): presumably SM_EVENT is the uMsg handed to WSAAsyncSelect and
// SM_GETHOST the uMsg handed to WSAAsyncGetHostByName (see doAsyncSelect /
// doAsyncGetHostByName in IBaseSocket); confirm in the implementation.
#define SM_EVENT				(WM_APP + 100)
#define SM_GETHOST				(WM_APP + 101)

// Default timeouts (milliseconds) and buffer sizes (bytes).
// SX_KA_TIME / SX_KA_INTERVAL are the default arguments of IBaseSocket::SetKeepAlive.
#define SX_KA_TIME				30000	//  30 seconds
#define SX_KA_INTERVAL			1000	//   1 second
#define SX_SEND_TIMEOUT			20000	//  20 seconds
#define SX_RECV_TIMEOUT			20000	//  20 seconds
#define SX_TCP_MAXBUFSIZE		262144	// 256 KByte
#define SX_TCP_MINBUFSIZE		8192	//   8 KByte
#define SX_UDP_BUFSIZE			4096	//   4 KByte

// Event identifiers.
// NOTE(review): their consumer is not visible in this header; confirm usage.
#define SX_EVENT_SEND			1000
#define SX_EVENT_RECEIVE		2000
#define SX_EVENT_CONNECT		3000
#define SX_EVENT_ACCEPT			4000
#define SX_EVENT_CLOSE			5000

// Winsock FD_* network-event masks per socket role; only the server role
// also watches FD_ACCEPT.
// NOTE(review): presumably passed to IBaseSocket::AsyncSelect; confirm.
#define FD_SERVER				(FD_ACCEPT | FD_WRITE | FD_READ | FD_CLOSE)
#define FD_CLIENT				(FD_WRITE | FD_READ | FD_CLOSE)
#define FD_PEER					(FD_WRITE | FD_READ | FD_CLOSE)

//////////////////////////////////////////////////////////////////////////

// Library-wide initialization entry point; the definition lives in the
// implementation file.
// NOTE(review): presumably performs Winsock start-up (WSAStartup); confirm.
extern BOOL SocketXInitialize(void);

class ISocketMessage;
class IBaseSocket;
class CServerSocket;
class CClientSocket;
class CPeerSocket;
class CBcastSocket;
class CMcastSocket;
class CRawSocket;

// Shared-ownership handles for the concrete socket classes.
// std::tr1 is deprecated/removed in newer MSVC toolsets (and is not provided
// by <memory> in other standard libraries), so the standard std::shared_ptr
// spelling is used. The TR1 spelling is kept only for pre-VS2010 compilers,
// which lack std::shared_ptr.
#if defined(_MSC_VER) && (_MSC_VER < 1600)
#	define SX_SHARED_PTR std::tr1::shared_ptr
#else
#	define SX_SHARED_PTR std::shared_ptr
#endif

typedef SX_SHARED_PTR<CServerSocket>	CServerSocketPtr;
typedef SX_SHARED_PTR<CClientSocket>	CClientSocketPtr;
typedef SX_SHARED_PTR<CPeerSocket>		CPeerSocketPtr;
typedef SX_SHARED_PTR<CBcastSocket>		CBcastSocketPtr;
typedef SX_SHARED_PTR<CMcastSocket>		CMcastSocketPtr;
typedef SX_SHARED_PTR<CRawSocket>		CRawSocketPtr;

// Helper macro is header-private; do not leak it to includers.
#undef SX_SHARED_PTR

// Endpoint description handed to the ISocketMessage callbacks.
// NOTE(review): byte order of dwAddress / wPort is not visible in this
// header; confirm against IBaseSocket::ConvertSockAddrToNetAddr.
struct _NETADDR
{
	DWORD dwAddress;	// IPv4 address
	WORD  wPort;		// port number
	UINT  nAF;			// address family (e.g. AF_INET)
};
typedef _NETADDR  NETADDR;
typedef _NETADDR *LPNETADDR;

// Description of the local network a socket belongs to
// (see IBaseSocket::SetLocalNet / GetLocalNet).
// NOTE(review): dwNet is presumably dwAddress & dwMask; confirm in the .cpp.
struct _LOCALNET
{
	DWORD dwAddress;	// local interface address
	DWORD dwMask;		// subnet mask
	DWORD dwNet;		// network address
};
typedef _LOCALNET  LOCALNET;
typedef _LOCALNET *LPLOCALNET;

// One IPv4 interface entry as filled in by CNetwork::GetIPv4List.
struct _IPv4
{
	DWORD dwIndex;			// presumably the interface/adapter index — confirm
	DWORD dwAddress;		// interface address
	DWORD dwNetMask;		// subnet mask
	DWORD dwBcastAddress;	// broadcast address (cf. CNetwork::MakeBcastAddress)
};
typedef _IPv4  IPv4;
typedef _IPv4 *LPIPv4;

// Routing information for socket notifications: who receives the
// ISocketMessage callbacks and under which ID (see IBaseSocket::SetTargetWnd).
struct _SOCKMSG
{
	ISocketMessage *pTgtObj;	// callback target (not owned here)
	DWORD		   dwSocketID;	// presumably echoed back as dwSocketID in callbacks — confirm
};
typedef _SOCKMSG  SOCKMSG;
typedef _SOCKMSG *LPSOCKMSG;

//////////////////////////////////////////////////////////////////////////

// Callback interface for socket events. A listener derives from this class,
// overrides the handlers it needs, and is registered through
// SOCKMSG::pTgtObj (IBaseSocket::SetTargetWnd). Every handler has an empty
// default implementation.
class ISocketMessage
{
public:
	ISocketMessage() {}

	// Virtual because the class is a polymorphic interface: deleting a
	// derived listener through an ISocketMessage* must run the derived
	// destructor (with a non-virtual destructor that is undefined behavior).
	virtual ~ISocketMessage() {}

	// Message function
	// dwSocketID identifies the originating socket (presumably the value set
	// in SOCKMSG::dwSocketID — confirm in the implementation).
	virtual void OnSocketSendMessage(DWORD dwSocketID) {}
	virtual void OnSocketReceiveMessage(DWORD dwSocketID, NETADDR NetAddr, const CByteArray &data) {}
	virtual void OnSocketReceiveFromMessage(DWORD dwSocketID, NETADDR NetAddr, const CByteArray &data) {}
	virtual void OnSocketAcceptMessage(DWORD dwSocketID, NETADDR NetAddr) {}
	virtual void OnSocketConnectMessage(DWORD dwSocketID, NETADDR NetAddr) {}
	virtual void OnSocketCloseMessage(DWORD dwSocketID, NETADDR NetAddr) {}

	// Note: `const LPHOSTENT` is `hostent* const` (the pointer is const, not
	// the pointee). Left unchanged because overrides must match exactly.
	virtual void OnGetHostByNameMessage(const LPHOSTENT lpHostEnt) {}
};

//////////////////////////////////////////////////////////////////////////

// Queries about the local host's IPv4 network configuration.
class CNetwork
{
public:
	CNetwork();
	virtual ~CNetwork();

	// Fills lpList (capacity dwMax) with interface entries; dwSize receives
	// the count. NOTE(review): contract inferred from the parameter names;
	// confirm in the implementation.
	BOOL  GetIPv4List(DWORD &dwSize, DWORD dwMax, LPIPv4 lpList);

	// NOTE(review): presumably reports whether a usable network exists.
	BOOL  IsAvailable();

	// Derives a broadcast address from an address and its netmask.
	DWORD MakeBcastAddress(DWORD dwAddress, DWORD dwNetMask);
};

//////////////////////////////////////////////////////////////////////////

// Conversion helpers between MFC strings / raw memory and CByteArray payloads.
class CSocketHelper
{
public:
	CSocketHelper();
	virtual ~CSocketHelper();

	// CString <-> CByteArray.
	// NOTE(review): character encoding (TCHAR bytes vs. narrow/UTF-8) is not
	// visible here; confirm in the implementation.
	void StringToByteArray(CByteArray &dst, const CString &src);
	void ByteArrayToString(CString &dst, const CByteArray &src);

	// Copies dwSize bytes starting at src into dst (presumably).
	void Serialize(CByteArray &dst, const LPVOID src, DWORD dwSize);
};
//////////////////////////////////////////////////////////////////////////

// Interface class
// Abstract base class shared by all socket wrappers in this header.
// It derives from CWnd. CreateSocketWindow() and the OnSocketMessage /
// OnGetHostMessage handlers suggest that a hidden window receives the
// asynchronous Winsock notifications and dispatches them to the On*() hooks
// (NOTE(review): confirm in the implementation).
// Concrete classes must implement CreateSocket().
class IBaseSocket : public CWnd
{
	DECLARE_DYNAMIC(IBaseSocket)

public:
	IBaseSocket();
	virtual ~IBaseSocket();

	// Notification hooks overridden by concrete sockets; the base versions
	// ignore the event.
	virtual void OnAccept(SOCKET hSocket, int nErrorCode) {}
	virtual void OnConnect(SOCKET hSocket, int nErrorCode) {}
	virtual void OnClose(SOCKET hSocket, int nErrorCode) {}
	virtual void OnSend(SOCKET hSocket, int nErrorCode) {}
	virtual void OnReceive(SOCKET hSocket, int nErrorCode) {}
	void OnGetHostByName(HANDLE hTask, int nErrorCode);

	// Stores the callback target / socket ID used for ISocketMessage callbacks.
	void SetTargetWnd(SOCKMSG SockMsg) { m_SockMsg = SockMsg; }
	void ConvertSockAddrToNetAddr(const LPSOCKADDR lpSockAddr, NETADDR &NetAddr);

	// Public operations; presumably applied to m_hSocket via the protected
	// do*() helpers, which take an explicit handle.
	BOOL CloseSocket();
	BOOL Send(const CByteArray &data);
	BOOL Receive(CByteArray &data);
	BOOL SendTo(const CByteArray &data, const LPSOCKADDR lpSockAddr, int nSockAddrLen);
	BOOL ReceiveFrom(CByteArray &data, LPSOCKADDR lpSockAddr, int *lpSockAddrLen);
	BOOL Shutdown(int nFlag);
	BOOL AsyncSelect(LONG lEvent);
	BOOL AsyncGetHostByName(const LPCSTR szHostName);
	BOOL SetBlocking();
	BOOL Bind(WORD wPort, DWORD dwAddress = INADDR_ANY);
	BOOL GetSockName(LPSOCKADDR lpSockAddr);
	BOOL GetPeerName(LPSOCKADDR lpSockAddr);
	BOOL SetSockOpt(int nOptionName, LPCSTR lpOptionValue, int nOptionLen, int nLevel = SOL_SOCKET);
	BOOL GetSockOpt(int nOptionName, LPSTR lpOptionValue, int *lpOptionLen, int nLevel = SOL_SOCKET);
	BOOL SetSendTimeout(DWORD dwTimeout);
	// Historical (misspelled) name, kept for source compatibility.
	BOOL SetRecieveTimeout(DWORD dwTimeout);
	// Correctly spelled alias of SetRecieveTimeout(); matches GetReceiveTimeout().
	BOOL SetReceiveTimeout(DWORD dwTimeout) { return SetRecieveTimeout(dwTimeout); }
	BOOL SetSendBufferSize(DWORD dwBufSize);
	BOOL SetReceiveBufferSize(DWORD dwBufSize);
	BOOL SetKeepAlive(DWORD dwTime = SX_KA_TIME, DWORD dwInterval = SX_KA_INTERVAL);
	BOOL SetReuseAddr(BOOL flag);
	BOOL IOCtlSocket(LONG lCommand, DWORD *lpArgument);
	BOOL WSAIOCtl(DWORD dwIoCtlCode, LPVOID lpInBuf, DWORD dwInBuf, LPVOID lpOutBuf, 
				  DWORD dwOutBuf, LPDWORD dwByteReturned, LPWSAOVERLAPPED lpOverlapped, 
				  LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

	void  GetLocalAddress(NETADDR &NetAddr);
	void  GetPeerAddress(NETADDR &NetAddr);
	void  SetLocalNet(LOCALNET LocalNet) { m_LocalNet = LocalNet; }

	// Read-only accessors (const so they are usable on const objects).
	void  GetLocalNet(LOCALNET &LocalNet) const { LocalNet = m_LocalNet; }
	BOOL  IsBlocking() const { return m_fBlocking; }
	LONG  GetEvent() const { return m_lEvent; }
	DWORD GetError() const { return m_dwError; }
	DWORD GetSendTimeout() const { return m_dwSendTimeout; }
	DWORD GetReceiveTimeout() const { return m_dwReceiveTimeout; }

protected:
	DECLARE_MESSAGE_MAP()
	afx_msg LRESULT OnSocketMessage(WPARAM wParam, LPARAM lParam);
	afx_msg LRESULT OnGetHostMessage(WPARAM wParam, LPARAM lParam);

	// Implemented by each concrete socket type.
	virtual BOOL CreateSocket(SOCKMSG SockMsg) = 0;

	BOOL CreateSocketWindow();
	BOOL IsLANConnection(const LPSOCKADDR lpSockAddr);
	void MakeSockAddrIN(LPSOCKADDR_IN lpSockAddrIN, DWORD dwAddress, WORD wPort, UINT nAF = AF_INET);

	// Handle-based implementations behind the public operations.
	BOOL doCloseSocket(SOCKET hSocket);
	BOOL doSend(SOCKET hSocket, const CByteArray &data);
	BOOL doSendTo(SOCKET hSocket, const CByteArray &data, const LPSOCKADDR lpSockAddr, int nSockAddrLen);
	BOOL doReceive(SOCKET hSocket, CByteArray &data);
	BOOL doReceiveFrom(SOCKET hSocket, CByteArray &data, LPSOCKADDR lpSockAddr, int *lpSockAddrLen);
	BOOL doSetSendTimeout(SOCKET hSocket, DWORD dwTimeout);
	BOOL doSetReceiveTimeout(SOCKET hSocket, DWORD dwTimeout);
	BOOL doSetSendBufferSize(SOCKET hSocket, DWORD dwBufSize);
	BOOL doSetReceiveBufferSize(SOCKET hSocket, DWORD dwBufSize);
	BOOL doShutdown(SOCKET hSocket, int nFlag);
	BOOL doAsyncSelect(SOCKET hSocket, UINT uMsg, LONG lEvent);
	BOOL doAsyncGetHostByName(UINT uMsg, const LPCSTR szHostName, LPSTR szHostEnt, int nBufLen);
	BOOL doSetSockOpt(SOCKET hSocket, int nLevel, int nOptionName, LPCSTR lpOptionValue, int nOptionLen);
	BOOL doGetSockOpt(SOCKET hSocket, int nLevel, int nOptionName, LPSTR lpOptionValue, int *lpOptionLen);
	BOOL doSetBlocking(SOCKET hSocket);
	BOOL doSetKeepAlive(SOCKET hSocket, DWORD dwKaTime, DWORD dwInterval);
	BOOL doBind(SOCKET hSocket, WORD wPort, DWORD dwAddress);
	BOOL doGetSockName(SOCKET hSocket, LPSOCKADDR lpSockAddr);
	BOOL doGetPeerName(SOCKET hSocket, LPSOCKADDR lpSockAddr);
	BOOL doSetReuseAddr(SOCKET hSocket, BOOL flag);
	BOOL doIOCtlSocket(SOCKET hSocket, LONG lCommand, DWORD *lpArgument);
	BOOL doWSAIOCtl(SOCKET hSocket, DWORD dwIoCtlCode, LPVOID lpInBuf, DWORD dwInBuf, LPVOID lpOutBuf, 
					DWORD dwOutBuf, LPDWORD dwByteReturned, LPWSAOVERLAPPED lpOverlapped, 
					LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine);

#ifdef _DEBUG
	// Debug helper: formats an IPv4 address held in network byte order as
	// dotted-quad text "a.b.c.d".
	CString DwToIPAddress(DWORD dwAddress) const
	{
		// Network -> host order, so the first octet ends up in the top byte.
		const DWORD dwHost = ntohl(dwAddress);
		CString cs;
		// Octets are unsigned; pass them as UINT to match %u exactly.
		cs.Format(_T("%u.%u.%u.%u"),
			static_cast<UINT>((dwHost >> 24) & 0xff),
			static_cast<UINT>((dwHost >> 16) & 0xff),
			static_cast<UINT>((dwHost >> 8)  & 0xff),
			static_cast<UINT>(dwHost & 0xff));
		return cs;
	}
#endif

protected:
	// NOTE(review): size_t is 4 bytes on Win32 and 8 on Win64. If this header
	// is written to the wire, 32- and 64-bit peers will disagree on framing;
	// confirm its use in the implementation.
	typedef struct _HEADER
	{
		size_t size;
	}
	HEADER, *LPHEADER;

	// A connected socket handle plus its remote endpoint.
	typedef struct _SOCKETDATA
	{
		SOCKET  hSocket;
		NETADDR NetAddr;
	}
	SOCKETDATA, *LPSOCKETDATA;

	// Async task handle and result buffer for a host-name lookup.
	typedef struct _HOSTNAME
	{
		HANDLE hTask;
		CHAR   buf[MAXGETHOSTSTRUCT];
	}
	HOSTNAME, *LPHOSTNAME;

protected:
	LOCALNET m_LocalNet;
	HOSTNAME m_HostName;
	SOCKADDR m_PeerAddr;
	SOCKADDR m_LocalAddr;
	SOCKMSG  m_SockMsg;
	SOCKET   m_hSocket;
	DWORD    m_dwError;
	DWORD    m_dwSendTimeout;
	DWORD    m_dwReceiveTimeout;
	LONG     m_lEvent;
	WORD     m_wBindPort;
	BOOL     m_fBlocking;
	BOOL     m_fWnd;
	BOOL     m_fInit;
};

//////////////////////////////////////////////////////////////////////////

// Listening socket that accepts incoming connections and keeps a list of
// them (m_sdList) for per-client sends and broadcasts.
class CServerSocket : public IBaseSocket
{
	DECLARE_DYNAMIC(CServerSocket)

public:
	CServerSocket();
	virtual ~CServerSocket();

	// IBaseSocket notification overrides.
	virtual void OnAccept(SOCKET hSocket, int nErrorCode);
	virtual void OnClose(SOCKET hSocket, int nErrorCode);
	virtual void OnSend(SOCKET hSocket, int nErrorCode);
	virtual void OnReceive(SOCKET hSocket, int nErrorCode);

	// NOTE(review): presumably creates, binds (dwBindAddress:wPort) and
	// listens with nBacklog; confirm in the implementation.
	BOOL Initialize(SOCKMSG SockMsg, WORD wPort, DWORD dwBindAddress = INADDR_ANY, int nBacklog = 5);
	BOOL Accept(LPSOCKADDR lpSockAddr, int *lpSockAddrLen);
	BOOL Listen(int nBacklog = 5);
	// Presumably sends data to every connection in m_sdList.
	BOOL Broadcast(const CByteArray &data);
	// Sends data to the connection whose remote endpoint is dwAddress:wPort.
	BOOL SendToClient(const CByteArray &data, DWORD dwAddress, WORD wPort);

	void SetAccept(BOOL fAccept);
	BOOL GetAccept() const { return m_fAccept; }
	// NOTE(review): presumably FALSE restricts clients to the LAN
	// (cf. IsLANConnection); confirm in the implementation.
	void SetConnectionMode(BOOL fInternet) { m_fInternet = fInternet; }
	BOOL GetConnectionMode() const { return m_fInternet; }

protected:
	virtual BOOL CreateSocket(SOCKMSG SockMsg);

	// Connection-list management.
	int  SearchSocketConnection(SOCKET hSocket);
	void AddConnection(SOCKET hSocket, const LPSOCKADDR lpSockAddr, int nSockAddrLen);
	void CloseAllSocketConnection();
	BOOL ServerReceive(SOCKET hSocket);
	BOOL CloseSocketConnection(SOCKET hSocket, LPNETADDR lpNetAddr);
	SOCKET GetSocketHandle(DWORD dwAddress, WORD wPort);

protected:
	typedef CArray<SOCKETDATA> SocketDataArray;

	// Connection list paired with a critical section (presumably the lock
	// that guards it).
	typedef struct _AtmSDArray
	{
		SocketDataArray  list;
		CCriticalSection cs;
	}
	AtmSDArray, *LPAtmSDArray;

	AtmSDArray m_sdList;
	BOOL	   m_fAccept;
	BOOL	   m_fInternet;
};

//////////////////////////////////////////////////////////////////////////

// Client socket offering asynchronous and synchronous (timed) connects.
class CClientSocket : public IBaseSocket
{
	DECLARE_DYNAMIC(CClientSocket)

public:
	CClientSocket();
	virtual ~CClientSocket();

	// --- IBaseSocket notification overrides -------------------------------
	virtual void OnConnect(SOCKET hSocket, int nErrorCode);
	virtual void OnClose(SOCKET hSocket, int nErrorCode);
	virtual void OnSend(SOCKET hSocket, int nErrorCode);
	virtual void OnReceive(SOCKET hSocket, int nErrorCode);

	// --- Setup --------------------------------------------------------------
	BOOL Initialize(SOCKMSG SockMsg);

	// --- Connecting ---------------------------------------------------------
	// Asynchronous: completion is presumably reported through OnConnect.
	BOOL AsyncConnect(const LPSOCKADDR lpDstAddr, int nLength);
	BOOL AsyncConnect(const LPCSTR lpDstAddr, WORD wPort, int nLength);

	// Synchronous with a timeout (presumably milliseconds, like SX_* timeouts).
	BOOL SyncConnect(const LPSOCKADDR lpDstAddr, int nLength, DWORD dwTimeout);
	BOOL SyncConnect(const LPCSTR lpDstAddr, WORD wPort, int nLength, DWORD dwTimeout);

protected:
	virtual BOOL CreateSocket(SOCKMSG SockMsg);
};

//////////////////////////////////////////////////////////////////////////

// Connectionless peer socket bound to a local address/port; base class of
// the broadcast, multicast and raw sockets.
// NOTE(review): presumably a UDP socket (cf. SX_UDP_BUFSIZE); confirm.
class CPeerSocket : public IBaseSocket
{
	DECLARE_DYNAMIC(CPeerSocket)

public:
	CPeerSocket();
	virtual ~CPeerSocket();

	// --- IBaseSocket notification overrides -------------------------------
	virtual void OnClose(SOCKET hSocket, int nErrorCode);
	virtual void OnSend(SOCKET hSocket, int nErrorCode);
	virtual void OnReceive(SOCKET hSocket, int nErrorCode);

	// Creates the socket and binds it to dwLocalAddress:wPort (presumably).
	// Virtual so that CRawSocket can customize set-up.
	virtual BOOL Initialize(SOCKMSG SockMsg, DWORD dwLocalAddress, WORD wPort);

protected:
	virtual BOOL CreateSocket(SOCKMSG SockMsg);
};

//////////////////////////////////////////////////////////////////////////

// Peer socket that can send to a broadcast destination.
class CBcastSocket : public CPeerSocket
{
	DECLARE_DYNAMIC(CBcastSocket)

public:
	CBcastSocket();
	virtual ~CBcastSocket();

	// Configure the broadcast destination either from address + netmask
	// (broadcast address derived from them) or from an explicit broadcast address.
	BOOL SetBroadcast(DWORD dwLocalAddress, DWORD dwMaskAddress, WORD wPort);
	BOOL SetBroadcast(DWORD dwBcastAddress, WORD wPort);
	BOOL ResetBroadcast();

	// Sends data to the configured broadcast destination (m_BcastAddr).
	BOOL Broadcast(const CByteArray &data);

protected:
	SOCKADDR m_BcastAddr;	// broadcast destination
	BOOL	 m_fBcast;		// presumably TRUE once SetBroadcast succeeded
};

//////////////////////////////////////////////////////////////////////////

// Peer socket with IP multicast group support.
class CMcastSocket : public CPeerSocket
{
	DECLARE_DYNAMIC(CMcastSocket)

public:
	CMcastSocket();
	virtual ~CMcastSocket();

	// --- Group membership -------------------------------------------------
	BOOL SetReceiveMulticast(DWORD dwMcastAddress);
	BOOL JoinMulticastGroup(DWORD dwMcastAddress, WORD wPort);
	BOOL LeaveMulticastGroup();

	// --- Sending ------------------------------------------------------------
	BOOL SetTTL(DWORD dwTTL);
	// Sends data to the joined group address (m_McastAddr).
	BOOL Multicast(const CByteArray &data);

protected:
	SOCKADDR m_McastAddr;	// multicast group destination
	BOOL	 m_fMcast;		// presumably TRUE while a group is joined
};

//////////////////////////////////////////////////////////////////////////

// Raw socket with optional promiscuous (receive-all) mode; MSTcpIP.h is
// included for the promiscuous-mode support.
class CRawSocket : public CPeerSocket
{
	DECLARE_DYNAMIC(CRawSocket)

public:
	CRawSocket();
	virtual ~CRawSocket();

	// Overrides CPeerSocket::Initialize for raw-socket set-up.
	virtual BOOL Initialize(SOCKMSG SockMsg, DWORD dwLocalAddress, WORD wPort);

	// Turns promiscuous mode on or off (presumably via SIO_RCVALL — confirm).
	BOOL SetPromiscuous(BOOL flag);

protected:
	virtual BOOL CreateSocket(SOCKMSG SockMsg);

	BOOL m_fPromiscuous;	// current promiscuous-mode state
};

//////////////////////////////////////////////////////////////////////////