// ComBasic.h
// (c) 2002-2005 exeal

#ifndef _COM_BASIC_H_
#define _COM_BASIC_H_

#include <objbase.h>
#include <objsafe.h>
#include <cassert>
#include <stdexcept>


namespace Armaiti {

// Helper macros
/////////////////////////////////////////////////////////////////////////////

// Returns `hr` from the enclosing function when FAILED(hr) is true.
// The argument is evaluated exactly once, so passing a call expression is safe,
// and the do/while(0) wrapper makes the macro behave as a single statement
// (it can be used as the body of an if/else without capturing the else).
// Use with a trailing semicolon: RETURN_IF_FAILED(p->DoSomething());
#define RETURN_IF_FAILED(hr)								\
	do {													\
		const HRESULT hrReturnIfFailed_ = (hr);				\
		if(FAILED(hrReturnIfFailed_))						\
			return hrReturnIfFailed_;						\
	} while(0)

// Returns E_POINTER from the enclosing function when `p` compares equal to 0.
// Wrapped in do/while(0) so that the macro is a single statement and a following
// `else` binds to the caller's `if`, not to the test hidden inside the macro.
// Use with a trailing semicolon: VERIFY_POINTER(ppOut);
#define VERIFY_POINTER(p)		\
	do {						\
		if((p) == 0)			\
			return E_POINTER;	\
	} while(0)

// Yields `bstr` when it is non-null, otherwise an empty OLESTR literal.
// The argument is fully parenthesized so that any expression may be passed.
// Note: the argument is evaluated twice when it is non-null.
#define SAFE_BSTR(bstr)	\
	(((bstr) != 0) ? (bstr) : OLESTR(""))

// Converts any value comparable with 0 into a C++ bool (true when non-zero).
// Defined only if no other header has already provided a macro of this name.
#if !defined(toBoolean)
#	define toBoolean(b)	((b) != 0)
#endif

// Maps a value comparable with 0 to VARIANT_TRUE (non-zero) or VARIANT_FALSE (zero).
#define toVariantBoolean(b)	((b) != 0 ? VARIANT_TRUE : VARIANT_FALSE)

// True when `bstr` is a null pointer or points at a terminating zero character.
// The argument is parenthesized at every use so that expressions such as
// `s + 1` are tested as a whole. Note: the argument is evaluated twice
// when it is non-null.
#define isEmptyBSTR(bstr)	\
	((bstr) == 0 || *(bstr) == 0)



/// Proxy type returned by CComPtr::operator->.
/// It derives from the interface T and re-declares AddRef and Release as
/// private pure virtuals, so calling them through the smart pointer's
/// operator-> is rejected at compile time while all other members of T
/// remain accessible.
template<class T>
class CComPtrProxy : public T {
	// Methods (hidden on purpose)
private:
	STDMETHOD_(ULONG, AddRef)() = 0;
	STDMETHOD_(ULONG, Release)() = 0;
//
//	// Operators (disabled experiment: would also hide operator delete)
//private:
//	void	operator delete(void*);
};

/// Conversion policy tag: the default third template argument of CComPtr.
/// Carries no members; it exists only to be named as a type.
struct AllowConversion {};

/// Conversion policy tag: the counterpart of AllowConversion.
/// Carries no members; it exists only to be named as a type.
struct DisallowConversion {};

/// Addressing policy: when CComPtr::operator& is applied to a smart pointer
/// that already holds an interface, that interface is released first.
struct ReleasePreviousPointer {
	enum { bRelease = true };
};

/// Addressing policy: when CComPtr::operator& is applied to a smart pointer
/// that already holds an interface, nothing is released and an assertion fires.
struct AssertPointerIsNull {
	enum { bRelease = false };
};

/**
 *	COM X}[g|C^
 *	@param T				C^[tFCX^
 *	@param AddressingPolicy	łɃ|C^ێĂƂ̃AhXZq̋B
 *							AhXZq̐Q
 *	@param ConversionPolicy	Öق̌^ϊs
 */
template<class T,
	class AddressingPolicy = ReleasePreviousPointer,
	class ConversionPolicy = AllowConversion>
class CComPtr {
	// RXgN^
public:
	///	RXgN^
	explicit CComPtr(T* p = 0) : m_pointee(p) {
		if(m_pointee != 0)
			m_pointee->AddRef();
	}
	///	Rs[RXgN^
	CComPtr(const CComPtr<T>& rhs) : m_pointee(rhs.m_pointee) {
		if(m_pointee != 0)
			m_pointee->AddRef();
	}
	///	fXgN^
	virtual ~CComPtr() {
		if(m_pointee != 0)
			m_pointee->Release();
	}

	// Zq
public:
	/**
	 *	@brief	AhXZq
	 *	̉Zq̌ʂ͗Ⴆ
	 *	CreateObject(&p);
	 *	Ȃǂ̂悤ɏo͈ƂĂ̂ݎgp邱ƂOƂĂB
	 *	ZqĂяoɃIuWFNg null łȂꍇ̋
	 *	AddressingPolicy ɂ
	 */
	T** operator &() {
		if(m_pointee != 0) {
			if(AddressingPolicy::bRelease)
				m_pointee->Release();
			else
				assert(false);
		}
		return &m_pointee;
	}

	///	oANZXZq
	CComPtrProxy<T>* operator ->() const {
		assert(m_pointee != 0);
		return static_cast<CComPtrProxy<T>*>(m_pointee);
	}
//	T& operator *() const {
//		assert(m_pointee != 0);
//		return *m_pointee;
//	}

	/**
	 *	@brief	Zq
	 *	<var>p</var>  null ł悢
	 */
	CComPtr<T>& operator =(T* p) {
		if(m_pointee != p) {
			if(m_pointee != 0)
				m_pointee->Release();
			m_pointee = p;
			if(m_pointee != 0)
				m_pointee->AddRef();
		}
		return *this;
	}

	/// Zq
	template<class I>
	CComPtr<T>& operator =(const CComPtr<I>& rhs) {
		if(this != &rhs) {
			if(m_pointee != 0)
				m_pointee->Release();
			m_pointee = rhs.m_pointee;
			m_pointee->AddRef();
		}
		return *this;
	}

	/// Zq
	bool operator ==(const T* p) const {
		return m_pointee == p;
	}

	/// sZq
	bool operator !=(const T* p) const {
		return m_pointee != p;
	}

	/// ̃|C^^ւ̃LXg
	operator T*() const;

	// \bh
public:
	/// ::CoCreateInstance ɂIuWFNg
	HRESULT CreateInstance(REFCLSID rclsid,
			IUnknown* pUnkOuter = 0, DWORD dwClsContext = CLSCTX_ALL, REFIID riid = __uuidof(T)) {
		assert(m_pointee == 0);
		return ::CoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, reinterpret_cast<void**>(&m_pointee));
	}

	/// <var>p</var> ƓIuWFNgǂԂ
	bool IsEqualObject(IUnknown* p) const {
		if(m_pointee == 0 && p == 0)
			return true;
		else if(m_pointee == 0 || p == 0)
			return false;

		IUnknown*	p1 = 0;
		IUnknown*	p2 = 0;
		bool		b;

		m_pointee->QueryInterface(IID_IUnknown, reinterpret_cast<void**>(p1));
		p->QueryInterface(IID_IUnknown, reinterpret_cast<void**>(p2));
		b = p1 == p2;
		p1->Release();
		p2->Release();

		return b;
	}

	/**
	 *	@see		IUnknown::QueryInterface
	 *	@param pp	[out] LXg
	 */
	template<class I>
	HRESULT QueryInterface(I** pp) const {
		assert(m_pointee != 0);
		if(pp == 0)
			return E_POINTER;
		return m_pointee->QueryInterface(__uuidof(I), reinterpret_cast<void**>(pp));
	}

	/**
	 *	̃|C^ւ̖IȃANZX
	 *	@param p	X}[g|C^
	 *	@return		̃|C^
	 */
	static T* GetRowPointer(const CComPtr<T>& p) {
		return p.m_pointee;
	}

	// f[^o
private:
	T*	m_pointee;
};

/// Conversion to the raw pointer; no reference is added.
/// This single definition serves every ConversionPolicy argument, so the
/// policy does not currently restrict the conversion.
template<class T, class AddressingPolicy, class ConversionPolicy>
inline CComPtr<T, AddressingPolicy, ConversionPolicy>::operator T*() const {
	return this->m_pointee;
}


/**
 *	@brief	IErrorInfo  C++ OƂĈ߂̃bpNX
 *	o: Essential COM (Don Box)
 */
class CComException {
	// RXgN^
public:
	/**
	 *	RXgN^
	 *	@param hResult			SCODE
	 *	@param riid				IID
	 *	@param pwszSource		̗O𓊂NX
	 *	@param pwszDescription	O̐Bnull ̏ꍇ <var>hResult</var> 擾
	 *	@param pwszHelpFile		wvt@C̃pX
	 *	@param dwHelpContext	wvgsbN̔ԍ
	 */
	CComException(
			HRESULT hResult,
			REFIID riid,
			const OLECHAR* pwszSource,			// class threw this exception
			const OLECHAR* pwszDescription = 0,	// description of this error (set null to obtain from hResult)
			const OLECHAR* pwszHelpFile = 0,
			DWORD dwHelpContext = 0) {
		HRESULT				hr;
		ICreateErrorInfo*	pcei = 0;

		assert(FAILED(hResult));

		hr = ::CreateErrorInfo(&pcei);
		assert(SUCCEEDED(hr));

		hr = pcei->SetGUID(riid);
		assert(SUCCEEDED(hr));
		if(pwszSource != 0) {
			hr = pcei->SetSource(const_cast<OLECHAR*>(pwszSource));
			assert(SUCCEEDED(hr));
		}
		if(pwszDescription != 0) {
			hr = pcei->SetDescription(const_cast<OLECHAR*>(pwszDescription));
			assert(SUCCEEDED(hr));
		} else {
			BSTR	bstrDescription = 0;
			CComException::GetDescriptionOfSCode(hResult, bstrDescription);
			hr = pcei->SetDescription(bstrDescription);
			::SysFreeString(bstrDescription);
			assert(SUCCEEDED(hr));
		}
		if(pwszHelpFile != 0) {
			hr = pcei->SetHelpFile(const_cast<OLECHAR*>(pwszHelpFile));
			assert(SUCCEEDED(hr));
		}
		hr = pcei->SetHelpContext(dwHelpContext);
		assert(SUCCEEDED(hr));

		m_hResult = hResult;
		hr = pcei->QueryInterface(IID_IErrorInfo, reinterpret_cast<void**>(&m_pErrorInfo));
		assert(SUCCEEDED(hr));
		pcei->Release();
	}
	/// fXgN^
	virtual ~CComException() {
		if(m_pErrorInfo != 0)
			m_pErrorInfo->Release();
	}

	// \bh
public:
	/// G[ IErrorInfo Ԃ
	void GetErrorInfo(IErrorInfo** pErrorInfo) const {
		assert(pErrorInfo != 0);
		*pErrorInfo = m_pErrorInfo;
		(*pErrorInfo)->AddRef();
	}

	/// G[ HRESULT Ԃ
	HRESULT GetSCode() const {
		return m_hResult;
	}

	/// OIuWFNg_XbhOƂē
	void ThrowLogicalThreadError() {
		::SetErrorInfo(0, m_pErrorInfo);
	}

	/**
	 *	HRESULT ɑΉG[bZ[WԂ
	 *	@param hResult			[in] HRESULT
	 *	@param bstrDescription	[out] G[bZ[W
	 *	@param dwLanguage		[in]  ID
	 */
	static void GetDescriptionOfSCode(HRESULT hResult,
			BSTR& bstrDescription, DWORD dwLanguageId = MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT)) {
		void*	pwszDescription = 0;

		FormatMessageW(
			FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
			0, hResult, dwLanguageId, reinterpret_cast<wchar_t*>(&pwszDescription), 0, 0);
		bstrDescription = ::SysAllocString(reinterpret_cast<OLECHAR*>(pwszDescription));
		::LocalFree(pwszDescription);
	}

	// f[^o
private:
	HRESULT		m_hResult;
	IErrorInfo*	m_pErrorInfo;

};


/**
 *	Thin wrapper around a Win32 CRITICAL_SECTION (modelled on the ATL class).
 *	@param bAuto	when true the constructor initializes and the destructor
 *					deletes the critical section. When false the owner must
 *					call Init() / Term() explicitly; those two members are
 *					only defined for CComCriticalSection<false>.
 */
template<bool bAuto = true>
class CComCriticalSection {
public:
	/// Constructor
	/// @throw std::runtime_error	if bAuto is true and initialization fails
	CComCriticalSection() {
		if(bAuto)
			if(FAILED(_Init()))
				throw std::runtime_error("Failed to initialize critical section!");
	}
	/// Destructor: deletes the critical section when bAuto is true.
	~CComCriticalSection() {
		if(bAuto)
			_Term();
	}
	/// Enters the critical section.
	void Lock() {
		::EnterCriticalSection(&m_cs);
	}
	/// Leaves the critical section.
	void Unlock() {
		::LeaveCriticalSection(&m_cs);
	}
	/// Initializes the instance (manual mode).
	HRESULT Init();
	/// Tears the instance down (manual mode).
	void Term();

private:
	// Not copyable: declared private and intentionally left undefined.
	CComCriticalSection(const CComCriticalSection&);
	CComCriticalSection& operator =(const CComCriticalSection&);
	// Initializes m_cs, translating a structured exception into an HRESULT.
	HRESULT _Init() {
		__try {
			::InitializeCriticalSection(&m_cs);
		} __except(EXCEPTION_EXECUTE_HANDLER) {
			return (STATUS_NO_MEMORY == ::GetExceptionCode()) ? E_OUTOFMEMORY : E_FAIL;
		}
		return S_OK;
	}
	// Releases the resources held by m_cs.
	void _Term() {
		::DeleteCriticalSection(&m_cs);
	}
private:
	CRITICAL_SECTION	m_cs;
};

/// Manual-mode initialization: forwards to the private _Init helper and
/// returns its HRESULT unchanged.
template<>
inline HRESULT CComCriticalSection<false>::Init() {
	return this->_Init();
}

/// Manual-mode teardown: forwards to the private _Term helper.
template<>
inline void CComCriticalSection<false>::Term() {
	this->_Term();
}


/// Standard implementation of ISupportErrorInfo for a single interface.
/// @param piid	address of the one IID for which S_OK is reported
template<const IID* piid>
class ISupportErrorInfoImpl : virtual public ISupportErrorInfo {
	/// Reports S_OK when <var>riid</var> equals *piid, S_FALSE otherwise.
	STDMETHODIMP InterfaceSupportsErrorInfo(REFIID riid) {
		if(riid == *piid)
			return S_OK;
		return S_FALSE;
	}
};


/**
 *	Simple implementation of IObjectSafety
 *
 *	One set of safety options is kept for the whole object; it is not tracked
 *	per interface.
 *	@param dwSupportedSafety	bit mask of the options this object supports
 */
template<DWORD dwSupportedSafety>
class IObjectSafetyImpl : virtual public IObjectSafety {
public:
	/// Starts with the requested options, limited to the supported ones.
	IObjectSafetyImpl(DWORD dwEnabledOptions = 0)
		: m_dwEnabledSafety(dwEnabledOptions & dwSupportedSafety) {}
	virtual ~IObjectSafetyImpl() {}

public:
	/// Reports the supported and currently enabled options when the object
	/// answers QueryInterface for <var>riid</var>; otherwise zeroes both
	/// outputs and returns E_NOINTERFACE.
	STDMETHODIMP GetInterfaceSafetyOptions(REFIID riid, DWORD* pdwSupportedOptions, DWORD* pdwEnabledOptions) {
		VERIFY_POINTER(pdwSupportedOptions);
		VERIFY_POINTER(pdwEnabledOptions);

		// Probe only: the interface pointer itself is not needed.
		IUnknown*	pProbe;
		if(FAILED(QueryInterface(riid, reinterpret_cast<void**>(&pProbe)))) {
			*pdwEnabledOptions = 0;
			*pdwSupportedOptions = 0;
			return E_NOINTERFACE;
		}
		pProbe->Release();
		*pdwSupportedOptions = dwSupportedSafety;
		*pdwEnabledOptions = m_dwEnabledSafety;
		return S_OK;
	}

	/// Updates the bits selected by <var>dwOptionSetMask</var> to the values
	/// given in <var>dwEnabledOptions</var>. Returns E_NOINTERFACE when
	/// <var>riid</var> is not supported and E_FAIL when the mask names an
	/// option outside dwSupportedSafety.
	STDMETHODIMP SetInterfaceSafetyOptions(REFIID riid, DWORD dwOptionSetMask, DWORD dwEnabledOptions) {
		IUnknown*	pProbe;
		if(FAILED(QueryInterface(riid, reinterpret_cast<void**>(&pProbe))))
			return E_NOINTERFACE;
		pProbe->Release();

		const DWORD	dwUnsupported = dwOptionSetMask & ~dwSupportedSafety;
		if(toBoolean(dwUnsupported))
			return E_FAIL;

		// Keep the bits outside the mask, take the masked bits from the caller.
		const DWORD	dwKept = m_dwEnabledSafety & ~dwOptionSetMask;
		const DWORD	dwRequested = dwOptionSetMask & dwEnabledOptions;
		m_dwEnabledSafety = dwKept | dwRequested;
		return S_OK;
	}

protected:
	DWORD	m_dwEnabledSafety;
};

} // namespace Armaiti

#endif /* _COM_BASIC_H_ */

/* [EOF] */