#include "StdAfx.h"
#include "SSLProtocolSocket.h"
#include "openssl\err.h"

//!	Creates a BIO wrapping the given COverlappedSocket; SSLAccept() installs it as the
//!	SSL object's read/write BIO and treats a NULL return as failure.
//!	(Implementation is not in this section.)
BIO *BIO_new_ovrs(COverlappedSocket *ovrs);

/*!
	Construct a socket with no SSL session attached yet.
	SSLAccept() creates the BIO and SSL object later.
*/
CSSLProtocolSocket::CSSLProtocolSocket(void)
{
	m_sslInit = 0;
	m_ssl = NULL;
	//	Start the BIO handle in a known state as well, so cleanup code never
	//	reads an indeterminate pointer.
	m_bio = NULL;
}

/*!
	Destructor: releases the connection and any SSL session through Close().
*/
CSSLProtocolSocket::~CSSLProtocolSocket(void)
{
	//	Explicitly name this class's Close(); an unqualified call made inside
	//	this destructor resolves to the same function.
	CSSLProtocolSocket::Close();
}


//////////////////////////////////////////////////////////////////////////////////
//	Close
//////////////////////////////////////////////////////////////////////////////////
/*!
	
*/
void CSSLProtocolSocket::Close()
{
	//	eNX
	CProtocolSocket::Close();

	//	
	if(m_sslInit)
	{
		SSL_shutdown(m_ssl);
		SSL_free(m_ssl);
	}
	m_sslInit = 0;

	//	̃Xbh̃G[j
	ERR_remove_state(0);
}

//////////////////////////////////////////////////////////////////////////////////
//	Communication (SSL handshake)
//////////////////////////////////////////////////////////////////////////////////
/*!
	Start SSL communication on the socket (server side).

	Wraps m_socket in a BIO, creates an SSL object from \a sslContext, attaches
	the BIO, and runs the handshake with SSL_accept().

	\param sslContext  SSL context used to create the connection object.
	\return 0 when the handshake succeeded.
	        -1 on failure; CloseError() has then already been called with either
	        CSPS_ERROR_CREATE_SSL_CONNECTION or CSPS_ERROR_SSL_ACCEPT.
*/
int CSSLProtocolSocket::SSLAccept(SSL_CTX *sslContext)
{
	//	Must not be called twice without an intervening Close().
	ASSERT(m_sslInit == 0);

	//	Socket setup: a BIO on top of m_socket.
	m_bio = BIO_new_ovrs(&m_socket);
	if(m_bio == NULL)
	{
		//	m_sslInit is still 0, so Close() will not try to shut down m_ssl.
		CloseError(CSPS_ERROR_CREATE_SSL_CONNECTION, "SSLAccept");
		return(-1);
	}

	//	Create the SSL object.
	m_ssl = SSL_new(sslContext);
	if(m_ssl == NULL)
	{
		//	No SSL object owns the BIO yet, so release it here to avoid a leak.
		BIO_free(m_bio);
		m_bio = NULL;
		CloseError(CSPS_ERROR_CREATE_SSL_CONNECTION, "SSLAccept");
		return(-1);
	}

	//	Attach the BIO for reading and writing. From here on SSL_free() also
	//	releases it.
	SSL_set_bio(m_ssl, m_bio, m_bio);

	//	Only now is there SSL state for Close() to shut down and free.
	m_sslInit = 1;

	//	Negotiation (handshake).
	if(SSL_get_error(m_ssl, SSL_accept(m_ssl)) == SSL_ERROR_NONE)
		return(0);

	//	Every other result (WANT_READ, WANT_WRITE, ZERO_RETURN, ...) is
	//	treated as a failed handshake.
	CloseError(CSPS_ERROR_SSL_ACCEPT, "SSLAccept");
	return(-1);
}

//////////////////////////////////////////////////////////////////////////////////
//	Communication (send / receive)
//////////////////////////////////////////////////////////////////////////////////
/*!
	Send (blocking).

	Writes the readable part of m_sendBuf through SSL_write() and marks the
	written bytes as consumed.

	\param timeOut  Timeout passed to m_socket.SetTimeOut() before writing.
	\return 0 if there was nothing to send.
	        Otherwise m_sendBuf.GetInBuf() after the written bytes are consumed
	        (presumably the bytes still pending).
	        -1 on error; CloseError() has then already been called.
*/
int CSSLProtocolSocket::BlockingSend(int timeOut)
{
	//	Check that an SSL session exists. Report under this function's own name,
	//	like the other CloseError() calls below.
	if(!m_sslInit)
	{
		CloseError(CBS_ERROR_CLOSE, "BlockingSend");
		return(-1);
	}

	//	Nothing to send?
	if(m_sendBuf.GetInBuf() == 0)
		return(0);

	//	Apply the timeout.
	m_socket.SetTimeOut(timeOut);

	//	Get the send buffer.
	int		bufLen;
	char	*sendBuf = m_sendBuf.GetReadBuffer(&bufLen);

	//	Send.
	int sended = SSL_write(m_ssl, sendBuf, bufLen);
	if(SSL_get_error(m_ssl, sended) != SSL_ERROR_NONE)
	{
		//	Translate the underlying socket error into this class's error codes.
		switch(m_socket.GetLastError())
		{
		case COverlappedSocket::OVRS_ERROR_TIMEOUT:
			CloseError(CBS_ERROR_TIMEOUT,"BlockingSend");
			break;

		case COverlappedSocket::OVRS_ERROR_CLOSE:
			CloseError(CBS_ERROR_CLOSE,"BlockingSend");
			break;

		case COverlappedSocket::OVRS_ERROR_BREAK:
			CloseError(CBS_ERROR_BREAK,"BlockingSend");
			break;

		case COverlappedSocket::OVRS_ERROR_OTHER:
		default:
			CloseError(CBS_ERROR_OTHER,"BlockingSend");
			break;
		}
		return(-1);
	}

	//	Record how much was sent.
	m_sendBuf.SetReadedLen(sended);
	return(m_sendBuf.GetInBuf());
}


/*!
	M
*/
int CSSLProtocolSocket::BlockingRecv(int timeOut)
{
	//	`FbN
	if(!m_sslInit)
	{
		CloseError(CBS_ERROR_CLOSE, "TryRecv");
		return(-1);
	}

	//	obt@邩H
	if(m_recvBuf.GetFreeBuf() == 0)
		return(m_recvBuf.GetInBuf());	//	f[^Mς݂ȂI

	//	^CAEgw
	m_socket.SetTimeOut(timeOut);

	//	Mpobt@擾
	int		bufLen;
	char	*recvBuf = m_recvBuf.GetWriteBuffer(&bufLen);

	//	M
	int recved =SSL_read(m_ssl, recvBuf, bufLen);
	if(SSL_get_error(m_ssl, recved) != SSL_ERROR_NONE)
	{
		switch(m_socket.GetLastError())
		{
		case COverlappedSocket::OVRS_ERROR_TIMEOUT:
			CloseError(CBS_ERROR_TIMEOUT,"BlockingRecv");
			break;

		case COverlappedSocket::OVRS_ERROR_CLOSE:
			CloseError(CBS_ERROR_CLOSE,"BlockingSend");
			break;

		case COverlappedSocket::OVRS_ERROR_BREAK:
			CloseError(CBS_ERROR_BREAK,"BlockingRecv");
			break;

		case COverlappedSocket::OVRS_ERROR_OTHER:
		default:
			CloseError(CBS_ERROR_OTHER,"BlockingRecv");
			break;
		}
		return(-1);
	}

	//	Mf[^ʂݒ
	m_recvBuf.SetWritedLen(recved);
	return(m_recvBuf.GetInBuf());
}


//////////////////////////////////////////////////////////////////////////////////
//	Error messages
//////////////////////////////////////////////////////////////////////////////////
/*!
	Map an error code to its message text.

	Handles the two SSL-specific codes defined by this class and defers every
	other code to CProtocolSocket::GetErrorString().

	\param errorCode  Error code to describe.
	\return The message for \a errorCode.
*/
CString CSSLProtocolSocket::GetErrorString(int errorCode)
{
	//	Failed to set up the SSL connection.
	if(errorCode == CSPS_ERROR_CREATE_SSL_CONNECTION)
		return(_T("SSLڑ̏Ɏs܂"));

	//	SSL handshake (negotiation) failed.
	if(errorCode == CSPS_ERROR_SSL_ACCEPT)
		return(_T("SSLڑ̃lSVG[VɎs܂"));

	//	Not an SSL-specific code: let the base class describe it.
	return(CProtocolSocket::GetErrorString(errorCode));
}
