//	Roast+ License

#ifndef __SFJP_ROAST_EX__graphics__directx__dx9__shader_HPP__
#define __SFJP_ROAST_EX__graphics__directx__dx9__shader_HPP__

#include <string>
#include "roast/graphics/directx/dx9/device.hpp"
#include "roast/graphics/directx/dx9/buffer.hpp"

namespace roast
{
	namespace directx
	{
		namespace dx9
		{
			namespace graphics
			{
				///////////////////////////////////////////////////////////////////////////////////////////////////////////////

				/// Lightweight handle to a single pass of an effect technique.
				///
				/// Stores the effect interface pointer and the pass handle it was constructed
				/// with. Nothing is released here; the object only carries the two values.
				class pass
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::pass,
						get_desc__GetDesc_Failed
					};
				protected:
					d3deffect *m_ieffect;		///< effect interface the pass was obtained from
					::D3DXHANDLE m_handle;		///< handle identifying the pass
				public:
					/// @param ieffect effect interface the pass was obtained from
					/// @param hpass   handle of the pass
					///
					/// The handle member is initialised from @p hpass (it was previously
					/// initialised from itself, leaving it indeterminate).
					pass(d3deffect* ieffect, ::D3DXHANDLE hpass) : m_ieffect(ieffect), m_handle(hpass){}

					/// Currently a no-op.
					void begin() const {}
					/// Currently a no-op.
					void end() const {}

					/// @return the pass handle given at construction.
					::D3DXHANDLE get_hpass(){ return m_handle; }
					/// @return the pass handle given at construction.
					const ::D3DXHANDLE get_hpass() const { return m_handle; }
				};

				///////////////////////////////////////////////////////////////////////////////////////////////////////////////

				/// Handle to one technique of an effect; hands out pass objects by name or index.
				///
				/// Stores the effect interface pointer and the technique handle it was
				/// constructed with. Nothing is released here.
				class technique
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::technique,
						get_pass__Unknown_pass
					};
				protected:
					d3deffect *m_ieffect;		///< effect interface the technique was obtained from
					::D3DXHANDLE m_handle;		///< handle identifying the technique

				private:
					/// Shared lookup-by-name used by every const and non-const accessor.
					/// Being const itself, it can be called from the const overloads without
					/// those overloads resolving back to themselves.
					/// @throws exception (get_pass__Unknown_pass) when the lookup returns NULL.
					pass _find_pass(const char* name) const
					{
						::D3DXHANDLE found = m_ieffect->GetPassByName(m_handle, name);
						if ( found == NULL ){
							// Guard the message against a NULL name (appending NULL to std::string is undefined).
							throw exception(get_pass__Unknown_pass,
								::std::string("ID3DXEffect::GetPassByName() unknown pass ")+(name != NULL ? name : "(null)"));
						}
						return pass(m_ieffect, found);
					}

					/// Shared lookup-by-index used by every const and non-const accessor.
					/// @throws exception (get_pass__Unknown_pass) when the lookup returns NULL.
					pass _find_pass(int index) const
					{
						::D3DXHANDLE found = m_ieffect->GetPass(m_handle, index);
						if ( found == NULL ){
							char work[16];	// large enough for any 32-bit int in decimal plus terminator
#pragma warning ( disable : 4996 )
							sprintf(work, "%d", index);
							throw exception(get_pass__Unknown_pass,
								::std::string("ID3DXEffect::GetPassByIndex() unknown pass index ")+work);
#pragma warning ( default : 4996 )
						}
						return pass(m_ieffect, found);
					}

				public:
					/// @param ieffect    effect interface the technique was obtained from
					/// @param htechnique handle of the technique
					technique(d3deffect* ieffect, ::D3DXHANDLE htechnique) : m_ieffect(ieffect), m_handle(htechnique){}

					//	Get Pass by Name
					/// @throws exception (get_pass__Unknown_pass) if no pass has that name.
					pass get_pass(const char* name){ return _find_pass(name); }
					const pass get_pass(const char* name) const { return _find_pass(name); }
					pass operator[](const char* name){ return _find_pass(name); }
					const pass operator[](const char* name) const { return _find_pass(name); }

					//	Get Pass by Index
					/// @throws exception (get_pass__Unknown_pass) if the index yields no pass.
					pass get_pass(int index){ return _find_pass(index); }
					const pass get_pass(int index) const { return _find_pass(index); }
					pass operator[](int index){ return _find_pass(index); }
					const pass operator[](int index) const { return _find_pass(index); }

					/// @return the technique handle given at construction.
					::D3DXHANDLE get_htechnique(){ return m_handle; }
					/// @return the technique handle given at construction.
					const ::D3DXHANDLE get_htechnique() const { return m_handle; }
				};

				///////////////////////////////////////////////////////////////////////////////////////////////////////////////
				
				/// Empty tag type accepted by the effect constructors as their
				/// `effect_pool` argument. It has no members and carries no state.
				class effect_shared_memory {};

				/// Effect loaded from a file through D3DXCreateEffectFromFile{A,W}.
				///
				/// The created interface is stored in the inherited `m_if` member and
				/// techniques are looked up through it by name.
				class effect : protected iunknown_<d3deffect>, public effect_base
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::effect,

						effect_D3DXCreateEffectFromFileA_Failed,
						effect_D3DXCreateEffectFromFile_Failed,
						get_technique__Unknown_technique
					};

					/// Constructor Param //////////////////////////////////////

					/// File name held by value (copied into a std::string).
					struct construct_by_file
					{
						::std::string filename;
						construct_by_file(const char* filename) : filename(filename) {}
						construct_by_file(const ::std::string& filename) : filename(filename) {}
					};
					/// File name as a borrowed narrow C string (the pointer is stored, not copied).
					struct construct_by_file_lpcstr
					{
						LPCSTR filename;
						construct_by_file_lpcstr(LPCSTR filename) : filename(filename) {}
					};
					/// File name as a borrowed TCHAR string (the pointer is stored, not copied).
					struct construct_by_file_lpctstr
					{
						LPCTSTR filename;
						construct_by_file_lpctstr(LPCTSTR filename) : filename(filename) {}
					};
				protected:
					device& m_dev;	///< device the effect is created on

				private:
					// Character-type dispatch for D3DXCreateEffectFromFile.
					// The generic template forwards to the A entry point; the two wide-string
					// overloads forward to the W entry point. These are plain overloads rather
					// than in-class explicit specialisations, which not every compiler accepts
					// at class scope.
					template <typename Str>
					::HRESULT _D3DXCreateEffectFromFile_(
						LPDIRECT3DDEVICE9               pDevice,
						Str                             pSrcFile,
						CONST D3DXMACRO*                pDefines,
						LPD3DXINCLUDE                   pInclude,
						DWORD                           Flags,
						LPD3DXEFFECTPOOL                pPool,
						LPD3DXEFFECT*                   ppEffect,
						LPD3DXBUFFER*                   ppCompilationErrors)
					{ return ::D3DXCreateEffectFromFileA(pDevice,pSrcFile,pDefines,pInclude,Flags,pPool,ppEffect,ppCompilationErrors); }

					::HRESULT _D3DXCreateEffectFromFile_(
						LPDIRECT3DDEVICE9               pDevice,
						LPCWSTR                         pSrcFile,
						CONST D3DXMACRO*                pDefines,
						LPD3DXINCLUDE                   pInclude,
						DWORD                           Flags,
						LPD3DXEFFECTPOOL                pPool,
						LPD3DXEFFECT*                   ppEffect,
						LPD3DXBUFFER*                   ppCompilationErrors)
					{ return ::D3DXCreateEffectFromFileW(pDevice,pSrcFile,pDefines,pInclude,Flags,pPool,ppEffect,ppCompilationErrors); }

					::HRESULT _D3DXCreateEffectFromFile_(
						LPDIRECT3DDEVICE9               pDevice,
						LPWSTR                          pSrcFile,
						CONST D3DXMACRO*                pDefines,
						LPD3DXINCLUDE                   pInclude,
						DWORD                           Flags,
						LPD3DXEFFECTPOOL                pPool,
						LPD3DXEFFECT*                   ppEffect,
						LPD3DXBUFFER*                   ppCompilationErrors)
					{ return ::D3DXCreateEffectFromFileW(pDevice,pSrcFile,pDefines,pInclude,Flags,pPool,ppEffect,ppCompilationErrors); }

					///////////////////

					/// Common tail of both creation helpers.
					///
					/// Releases the error buffer whenever one was returned (also when the call
					/// succeeded, so it is not leaked) and throws api_error when hr != D3D_OK.
					/// @param hr        result of the creation call
					/// @param error_out error buffer written by the creation call; may be NULL
					/// @param code      exception code to report on failure
					/// @param head      leading text of the failure message
					static void _finish_create(::HRESULT hr, LPD3DXBUFFER error_out, exception_codes code, const char* head)
					{
						if ( hr == D3D_OK ){
							if ( error_out != NULL )
								error_out->Release();
							return;
						}

						::std::string msg = head;
						if ( error_out != NULL )
						{
							// The buffer content is treated as a NUL-terminated narrow string.
							const char* error_out_msg = (const char*)error_out->GetBufferPointer();
							if ( error_out_msg != NULL )
								msg += error_out_msg;
							error_out->Release();
						}
						throw api_error(code, msg, hr);
					}

					/// Creates the effect from a narrow-string file name, storing it in m_if.
					/// @throws api_error (effect_D3DXCreateEffectFromFileA_Failed) when hr != D3D_OK.
					void _cunstruct_a(
						LPCSTR filename,
						CONST ::D3DXMACRO *defines,
						::LPD3DXINCLUDE include,
						DWORD flags,
						LPD3DXEFFECTPOOL effect_pool)
					{
						LPD3DXBUFFER error_out = NULL;
						::HRESULT hr = ::D3DXCreateEffectFromFileA(m_dev.get_d3ddevice_ptr(),
							filename, defines, include, flags, effect_pool, &m_if, &error_out);

						_finish_create(hr, error_out, effect_D3DXCreateEffectFromFileA_Failed,
							"D3DXCreateEffectFromFileA() Failed.");
					}

					/// Creates the effect from a TCHAR file name, storing it in m_if.
					/// @throws api_error (effect_D3DXCreateEffectFromFile_Failed) when hr != D3D_OK.
					void _cunstruct_aw(
						LPCTSTR filename,
						CONST ::D3DXMACRO *defines,
						::LPD3DXINCLUDE include,
						DWORD flags,
						LPD3DXEFFECTPOOL effect_pool)
					{
						LPD3DXBUFFER error_out = NULL;
						::HRESULT hr = ::D3DXCreateEffectFromFile(m_dev.get_d3ddevice_ptr(),
							filename, defines, include, flags, effect_pool, &m_if, &error_out);

						// error_out may be NULL on failure; _finish_create checks before use.
						_finish_create(hr, error_out, effect_D3DXCreateEffectFromFile_Failed,
							"D3DXCreateEffectFromFile() Failed.");
					}

					/// Shared const lookup behind get_technique()/operator[].
					/// Being const itself, it can be called from the const overloads without
					/// those overloads resolving back to themselves.
					/// @throws exception (get_technique__Unknown_technique) when the lookup returns NULL.
					technique _find_technique(const char* name) const
					{
						::D3DXHANDLE d3dtech = m_if->GetTechniqueByName(name);
						if ( d3dtech == NULL ){
							// Guard the message against a NULL name (appending NULL to std::string is undefined).
							throw exception(get_technique__Unknown_technique,
								::std::string("ID3DXEffect::GetTechniqueByName() unknown technique ")+(name != NULL ? name : "(null)"));
						}
						return technique(m_if, d3dtech);
					}

				public:
					// Constructors.
					//
					// All three forward to the five-parameter creation helpers above:
					//  - `defines` is passed straight through, as before.
					//    NOTE(review): assumes defines_t converts to `CONST D3DXMACRO*` - confirm.
					//  - no include handler is supplied (NULL).
					//  - `flags` is passed as the creation flags.
					//  - `model` and `effect_pool` are accepted for interface compatibility but
					//    are not forwarded: the helpers have no profile parameter and
					//    effect_shared_memory is an empty type, so NULL is passed as the pool.
					//    NOTE(review): confirm this is the intended D3D9 behaviour.

					/// Creates the effect from a file name held in a std::string.
					/// @throws api_error when creation fails.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						const defines_t defines=defines_t(),
						DWORD flags=0,
						effect_shared_memory effect_pool=effect_shared_memory()) : m_dev(d)
					{
						(void)model;
						(void)effect_pool;
						_cunstruct_a(f.filename.c_str(), defines, NULL, flags, NULL);
					}

					/// Creates the effect from a narrow C-string file name.
					/// @throws api_error when creation fails.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file_lpcstr &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						const defines_t defines=defines_t(),
						DWORD flags=0,
						effect_shared_memory effect_pool=effect_shared_memory()) : m_dev(d)
					{
						(void)model;
						(void)effect_pool;
						_cunstruct_a(f.filename, defines, NULL, flags, NULL);
					}

					/// Creates the effect from a TCHAR file name.
					/// @throws api_error when creation fails.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file_lpctstr &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						const defines_t defines=defines_t(),
						DWORD flags=0,
						effect_shared_memory effect_pool=effect_shared_memory()) : m_dev(d)
					{
						(void)model;
						(void)effect_pool;
						_cunstruct_aw(f.filename, defines, NULL, flags, NULL);
					}

					/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

					/// Looks up a technique by name.
					/// @throws exception (get_technique__Unknown_technique) if no technique has that name.
					technique get_technique(const char* name){ return _find_technique(name); }
					const technique get_technique(const char* name) const { return _find_technique(name); }
					technique operator[](const char* name){ return _find_technique(name); }
					const technique operator[](const char* name) const { return _find_technique(name); }

					/// @return the stored effect interface pointer (m_if).
					d3deffect* get_ieffect_ptr(){ return m_if; }
					d3deffect* get_ieffect_ptr() const { return m_if; }
				};
				

				///////////////////////////////////////////////////////////////////////////////////

				/// Helper bound to a device reference that groups rendering-related calls:
				/// binding a vertex buffer, setting up an input layout and applying a pass.
				///
				/// NOTE(review): the two template members below use names that are not
				/// declared anywhere in this file (see the notes on each); they look like
				/// leftovers from a port and presumably fail to compile once instantiated.
				class render_device
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::shader,

						buffer__CreateBuffer_Failed,
						setup_input_layout__CreateInputLayout_Failed,
						apply_pass__ID3D9EffectPass_Apply_Failed,
						get_pass_Failed,
						get_technique_Failed
					};
				protected:
					device& m_dev;	// device every call below is issued on

				public:
					/// Stores a reference to the device; the device must outlive this object.
					render_device(device &d) : m_dev(d)
					{
					}

					///////////////////////////////////////////////////////

					/// Binds one vertex buffer to slot 0 with stride sizeof(VERTEX_TYPE) and offset 0.
					///
					/// NOTE(review): IASetVertexBuffers / ID3D9Buffer look like Direct3D 10-style
					/// names - confirm they exist for the device type returned by
					/// get_d3ddevice_ptr(). `typename const` is also not standard ordering
					/// (`const typename` is).
					template <typename VERTEX_TYPE>
					void set_vertex_buffer(const input_vertex_buffer<VERTEX_TYPE> &vb)// const OK...?
					{
						typename const input_vertex_buffer<VERTEX_TYPE>::_IfType *buffer_ptr = vb.get_buffer_ptr();
						UINT stride = sizeof(VERTEX_TYPE);
						UINT offset = 0;
						// The C-style cast drops the const qualifier from buffer_ptr.
						m_dev.get_d3ddevice_ptr()->IASetVertexBuffers( 0, 1, (ID3D9Buffer *const *)&buffer_ptr, &stride, &offset );
					}

					///////

					/// Creates an input layout for the given pass and binds it on the device.
					/// @throws api_error (setup_input_layout__CreateInputLayout_Failed) when hr != D3D_OK.
					///
					/// NOTE(review): `input_desc` and `num_elements` are used in the body but
					/// their parameters are commented out in the signature, and `pass` as
					/// defined in this file has no get_desc() member. VERTEX_TYPE is not used
					/// in the body and cannot be deduced from the argument.
					template <typename VERTEX_TYPE>
					void setup_input_layout(/*const ::D3D9_INPUT_ELEMENT_DESC *input_desc, int num_elements, */const pass &pss)
					{
						::ID3D9InputLayout* input_layout;

						//	Create Input Layout
						HRESULT hr = m_dev.get_d3ddevice_ptr()->CreateInputLayout(input_desc, num_elements,
							pss.get_desc().pIAInputSignature, pss.get_desc().IAInputSignatureSize, &input_layout );
						if ( hr != D3D_OK ){
							throw api_error(setup_input_layout__CreateInputLayout_Failed,
								"ID3D9Device::CreateInputLayout() Failed.", hr);
						}

						//	Bind Input Layout
						// NOTE(review): input_layout is never released in this function - confirm ownership.
						m_dev.get_d3ddevice_ptr()->IASetInputLayout(input_layout);
					}

					///////

					/// Calls end() and then begin() on the pass, in that order.
					/// Both are empty functions in the pass class defined in this file.
					void apply_pass(const pass& pss)
					{
						pss.end();
						pss.begin();
					}

					/*
					void set_primitive_topology(::D3D9_PRIMITIVE_TOPOLOGY topology)
					{
						m_dev.get_d3ddevice_ptr()->IASetPrimitiveTopology(topology);
					}

					void draw(UINT vertex_count, UINT vertex_offset=0)
					{
						m_dev.get_d3ddevice_ptr()->Draw(vertex_count, vertex_offset);
					}

					void draw_primitive(::D3D9_PRIMITIVE_TOPOLOGY topology, UINT vertex_count, UINT vertex_offset=0)
					{
						set_primitive_topology(topology);
						draw(vertex_count, vertex_offset);
					}
					*/
				};
				/// Alias: `renderer` is the same type as `render_device`.
				typedef render_device renderer;

				///////////////////////////////////////////////////////////////////////////
			}
		}
	}
}

#endif//__SFJP_ROAST_EX__graphics__directx__dx9__shader_HPP__
