//	Roast+ License

#ifndef __SFJP_ROAST_EX__graphics__directx__dx10__shader_HPP__
#define __SFJP_ROAST_EX__graphics__directx__dx10__shader_HPP__

#include <string>
#include "roast/graphics/directx/dx10/device.hpp"
#include "roast/graphics/directx/dx10/buffer.hpp"

namespace roast
{
	namespace directx
	{
		namespace dx10
		{
			namespace graphics
			{
				///////////////////////////////////////////////////////////////////////////////////////////////////////////////

				//	Thin, non-owning wrapper around an effect pass (d3dpass).
				//	The pass descriptor is queried on first use and then cached.
				class pass
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::pass,
						get_desc__GetDesc_Failed
					};
				protected:
					d3dpass* m_if;						//	wrapped pass; never Released by this class
					mutable D3D10_PASS_DESC m_desc;		//	cached result of GetDesc()
					mutable bool m_desc_got;			//	true once m_desc holds a successful GetDesc() result
				public:
					//	pass_in : effect pass to wrap; stored as-is (no AddRef).
					pass(d3dpass* pass_in) : m_if(pass_in), m_desc_got(false)
					{
					}

					//	Returns the pass descriptor, calling GetDesc() only on the first successful use.
					//	Throws api_error(get_desc__GetDesc_Failed) when GetDesc() does not return D3D_OK;
					//	in that case nothing is cached and the next call queries again.
					const D3D10_PASS_DESC& get_desc() const
					{
						if ( m_desc_got )
							return m_desc;

						const HRESULT hr = m_if->GetDesc(&m_desc);
						if ( hr != D3D_OK )
							throw api_error(get_desc__GetDesc_Failed, "ID3D10EffectPass::GetDesc() Failed.", hr);

						m_desc_got = true;
						return m_desc;
					}

					//	Raw access to the wrapped pass pointer.
					d3dpass* get_ipass_ptr(){ return m_if; }
					d3dpass* get_ipass_ptr() const { return m_if; }
				};

				///////////////////////////////////////////////////////////////////////////////////////////////////////////////

				//	Thin, non-owning wrapper around an effect technique (d3dtechnique).
				//	Passes are looked up by name or index and returned as `pass` values.
				class technique
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::technique,
						get_pass__Unknown_pass
					};
				protected:
					d3dtechnique* m_if;		//	wrapped technique; never Released by this class

				private:
					//	Shared implementation for all by-name lookups.
					//	It is const so that const and non-const overloads can both call it without
					//	recursing into themselves. m_if is a raw pointer, so calling the pointee's
					//	non-const methods is allowed here.
					pass _pass_by_name(const char* name) const
					{
						d3dpass* ps = m_if->GetPassByName(name);
						//	NOTE(review): D3D10 effect lookups may return an "invalid" object rather
						//	than NULL for unknown names, so IsValid() is checked as well.
						if ( ps == NULL || !ps->IsValid() ){
							throw exception(get_pass__Unknown_pass,
								::std::string("ID3D10EffectTechnique::GetPassByName() unknown pass ")+(name != NULL ? name : "(null)"));
						}
						return pass(ps);
					}

					//	Shared implementation for all by-index lookups (see _pass_by_name).
					pass _pass_by_index(int index) const
					{
						d3dpass* ps = m_if->GetPassByIndex(index);
						if ( ps == NULL || !ps->IsValid() ){
							char work[16];	//	large enough for any 32-bit int plus NUL
#pragma warning ( disable : 4996 )
							sprintf(work, "%d", index);
#pragma warning ( default : 4996 )
							throw exception(get_pass__Unknown_pass,
								::std::string("ID3D10EffectTechnique::GetPassByIndex() unknown pass index ")+work);
						}
						return pass(ps);
					}

				public:
					//	tech : technique to wrap; stored as-is (no AddRef).
					technique(d3dtechnique* tech) : m_if(tech){}
					/////

					//	Get Pass by Name. Throws exception(get_pass__Unknown_pass) if not found.
					pass get_pass(const char* name){ return _pass_by_name(name); }
					const pass get_pass(const char* name) const { return _pass_by_name(name); }
					pass operator[](const char* name){ return _pass_by_name(name); }
					const pass operator[](const char* name) const { return _pass_by_name(name); }

					//	Get Pass by Index. Throws exception(get_pass__Unknown_pass) if out of range.
					pass get_pass(int index){ return _pass_by_index(index); }
					const pass get_pass(int index) const { return _pass_by_index(index); }
					pass operator[](int index){ return _pass_by_index(index); }
					const pass operator[](int index) const { return _pass_by_index(index); }

					//	Raw access to the wrapped technique pointer.
					d3dtechnique* get_itechnique_ptr(){ return m_if; }
					d3dtechnique* get_itechnique_ptr() const { return m_if; }
					//	Historical (misnamed) const accessor, kept for source compatibility;
					//	prefer get_itechnique_ptr().
					d3dtechnique* get_ieffect_ptr() const { return m_if; }
				};

				///////////////////////////////////////////////////////////////////////////////////////////////////////////////
				
				//	Owns an ID3D10Effect (through iunknown_<d3deffect>) created from an .fx file.
				//	Techniques are looked up by name and returned as `technique` values.
				class effect : protected iunknown_<d3deffect>
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::effect,

						effect_D3DX10CreateEffectFromFile_Failed,
						get_technique__Unknown_technique
					};

					/// Constructor Param //////////////////////////////////////
					//	Narrow file name, copied into a std::string; loaded with the ANSI API.
					//	A NULL pointer is treated as an empty name (creation then fails and throws).
					struct construct_by_file
					{
						::std::string filename;
						construct_by_file(const char* filename){ this->filename = (filename != NULL ? filename : ""); }
						construct_by_file(const ::std::string& filename){ this->filename = filename; }
					};
					//	Narrow file name kept as a raw pointer (not copied); loaded with the ANSI API.
					struct construct_by_file_lpcstr
					{
						LPCSTR filename;
						construct_by_file_lpcstr(LPCSTR filename){ this->filename = filename; }
					};
					//	TCHAR file name kept as a raw pointer; loaded with the TCHAR-dependent
					//	D3DX10CreateEffectFromFile.
					struct construct_by_file_lpctstr
					{
						LPCTSTR filename;
						construct_by_file_lpctstr(LPCTSTR filename){ this->filename = filename; }
					};
				protected:
					device& m_dev;

				private:
					//	Shared tail of both creation paths.
					//	Releases the error/warning blob if one was returned, whether or not creation
					//	succeeded. When hr != D3D_OK, throws api_error with base_msg followed by any
					//	compiler output.
					static void _finish_create(::HRESULT hr, d3dblob *error_out, const char* base_msg)
					{
						::std::string msg;
						if ( hr != D3D_OK ){
							msg = base_msg;
							if ( error_out != NULL ){
								const char* text = (const char*)error_out->GetBufferPointer();
								const SIZE_T size = error_out->GetBufferSize();
								if ( text != NULL && size > 0 ){
									//	Bounded copy: do not rely on the blob being NUL-terminated.
									::std::string detail(text, size);
									const ::std::string::size_type nul = detail.find('\0');
									if ( nul != ::std::string::npos )
										detail.erase(nul);
									msg += detail;
								}
							}
						}
						if ( error_out != NULL )
							error_out->Release();
						if ( hr != D3D_OK )
							throw api_error(effect_D3DX10CreateEffectFromFile_Failed, msg, hr);
					}

					//	Creates m_if from an ANSI file name via D3DX10CreateEffectFromFileA.
					void _cunstruct_a(
						LPCSTR filename,
						CONST D3D10_SHADER_MACRO *defines,
						ID3D10Include *include,
						LPCSTR profile,
						UINT hlsl_flags,
						UINT fx_flags,
						/*ID3D10Device *pDevice,*/
						ID3D10EffectPool *effect_pool,
						ID3DX10ThreadPump *pump,
						/*ID3D10Effect **ppEffect,
						ID3D10Blob **ppErrors,*/
						HRESULT *p_hresult)
					{
						d3dblob *error_out = NULL;
						::HRESULT hr = ::D3DX10CreateEffectFromFileA(
							filename, defines, include, profile, hlsl_flags, fx_flags, m_dev.get_d3ddevice_ptr(),
							effect_pool, pump, &m_if, &error_out, p_hresult);
						_finish_create(hr, error_out, "D3DX10CreateEffectFromFileA() Failed.");
					}

					//	Creates m_if from a TCHAR file name via D3DX10CreateEffectFromFile.
					void _cunstruct_aw(
						LPCTSTR filename,
						CONST D3D10_SHADER_MACRO *defines,
						ID3D10Include *include,
						LPCSTR profile,
						UINT hlsl_flags,
						UINT fx_flags,
						/*ID3D10Device *pDevice,*/
						ID3D10EffectPool *effect_pool,
						ID3DX10ThreadPump *pump,
						/*ID3D10Effect **ppEffect,
						ID3D10Blob **ppErrors,*/
						HRESULT *p_hresult)
					{
						d3dblob *error_out = NULL;
						::HRESULT hr = ::D3DX10CreateEffectFromFile(
							filename, defines, include, profile, hlsl_flags, fx_flags, m_dev.get_d3ddevice_ptr(),
							effect_pool, pump, &m_if, &error_out, p_hresult);
						_finish_create(hr, error_out, "D3DX10CreateEffectFromFile() Failed.");
					}

					//	Shared implementation of every technique lookup.
					//	It is const so that const overloads do not recurse into themselves.
					technique _find_technique(const char* name) const
					{
						d3dtechnique* d3dtech = m_if->GetTechniqueByName(name);
						//	NOTE(review): D3D10 effect lookups may return an "invalid" object rather
						//	than NULL for unknown names, so IsValid() is checked as well.
						if ( d3dtech == NULL || !d3dtech->IsValid() ){
							throw exception(get_technique__Unknown_technique,
								::std::string("GetTechniqueByName() unknown technique ")+(name != NULL ? name : "(null)"));
						}
						return technique(d3dtech);
					}

				public:
					//	Loads an effect from f.filename (ANSI) compiled against `model`.
					//	Throws api_error(effect_D3DX10CreateEffectFromFile_Failed) on failure.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						CONST D3D10_SHADER_MACRO *defines=NULL,
						ID3D10Include *include=NULL,
						ID3D10EffectPool *effect_pool=NULL,
						UINT fx_flags=0,
						UINT hlsl_flags=0,
						ID3DX10ThreadPump *pump=NULL,
						HRESULT *p_hresult=NULL) : m_dev(d)
					{
						_cunstruct_a(
							f.filename.c_str(), defines, include, model.to_string(), hlsl_flags, fx_flags/*, m_dev.get_d3ddevice_ptr()*/,
							effect_pool, pump, /*m_if, error_out,*/ p_hresult);
					}

					//	Same as above, for a raw LPCSTR file name.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file_lpcstr &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						CONST D3D10_SHADER_MACRO *defines=NULL,
						ID3D10Include *include=NULL,
						ID3D10EffectPool *effect_pool=NULL,
						UINT fx_flags=0,
						UINT hlsl_flags=0,
						ID3DX10ThreadPump *pump=NULL,
						HRESULT *p_hresult=NULL) : m_dev(d)
					{
						_cunstruct_a(
							f.filename, defines, include, model.to_string(), hlsl_flags, fx_flags/*, m_dev.get_d3ddevice_ptr()*/,
							effect_pool, pump, /*m_if, error_out,*/ p_hresult);
					}

					//	Same as above, for a TCHAR (LPCTSTR) file name.
					template <shader_model_type_e SHADER_MODEL_TYPE, unsigned char SHADER_MAJOR, unsigned char SHADER_MINOR>
					effect(device &d, const construct_by_file_lpctstr &f,
						const shader_model<SHADER_MODEL_TYPE, SHADER_MAJOR, SHADER_MINOR> &model,
						CONST D3D10_SHADER_MACRO *defines=NULL,
						ID3D10Include *include=NULL,
						ID3D10EffectPool *effect_pool=NULL,
						UINT fx_flags=0,
						UINT hlsl_flags=0,
						ID3DX10ThreadPump *pump=NULL,
						HRESULT *p_hresult=NULL) : m_dev(d)
					{
						_cunstruct_aw(
							f.filename, defines, include, model.to_string(), hlsl_flags, fx_flags/*, m_dev.get_d3ddevice_ptr()*/,
							effect_pool, pump, /*m_if, error_out,*/ p_hresult);
					}

					/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

					//	Get Technique by Name. Throws exception(get_technique__Unknown_technique) if not found.
					technique get_technique(const char* name){ return _find_technique(name); }
					const technique get_technique(const char* name) const { return _find_technique(name); }
					technique operator[](const char* name){ return _find_technique(name); }
					const technique operator[](const char* name) const { return _find_technique(name); }

					//	Raw access to the owned effect pointer.
					d3deffect* get_ieffect_ptr(){ return m_if; }
					d3deffect* get_ieffect_ptr() const { return m_if; }
				};
				

				///////////////////////////////////////////////////////////////////////////////////

				class render_device
				{
				public:
					enum exception_codes
					{
						exception_codes__head = exception_code_root::shader,

						buffer__CreateBuffer_Failed,
						setup_input_layout__CreateInputLayout_Failed,
						apply_pass__ID3D10EffectPass_Apply_Failed,
						get_pass_Failed,
						get_technique_Failed
					};
				protected:
					device& m_dev;

				public:
					render_device(device &d) : m_dev(d)
					{
					}

					///////////////////////////////////////////////////////

					template <typename VERTEX_TYPE>
					void set_vertex_buffer(const input_vertex_buffer<VERTEX_TYPE> &vb)// const OK...?
					{
						typename const input_vertex_buffer<VERTEX_TYPE>::_IfType *buffer_ptr = vb.get_buffer_ptr();
						UINT stride = sizeof(VERTEX_TYPE);
						UINT offset = 0;
						m_dev.get_d3ddevice_ptr()->IASetVertexBuffers( 0, 1, (ID3D10Buffer *const *)&buffer_ptr, &stride, &offset );
					}

					///////

					template <typename VERTEX_TYPE>
					void setup_input_layout(const ::D3D10_INPUT_ELEMENT_DESC *input_desc, int num_elements, const pass &pss)
					{
						::ID3D10InputLayout* input_layout;

						//	Create Input Layout
						HRESULT hr = m_dev.get_d3ddevice_ptr()->CreateInputLayout(input_desc, num_elements,
							pss.get_desc().pIAInputSignature, pss.get_desc().IAInputSignatureSize, &input_layout );
						if ( hr != D3D_OK ){
							throw api_error(setup_input_layout__CreateInputLayout_Failed,
								"ID3D10Device::CreateInputLayout() Failed.", hr);
						}

						//	Bind Input Layout
						m_dev.get_d3ddevice_ptr()->IASetInputLayout(input_layout);
					}

					///////

					void apply_pass(const pass& pss)
					{
						HRESULT hr = pss.get_ipass_ptr()->Apply(0);
						if ( hr != D3D_OK ){
							throw api_error(apply_pass__ID3D10EffectPass_Apply_Failed,
								"ID3D10EffectPass::Apply() Failed.", hr);
						}
					}

					void set_primitive_topology(::D3D10_PRIMITIVE_TOPOLOGY topology)
					{
						m_dev.get_d3ddevice_ptr()->IASetPrimitiveTopology(topology);
					}

					void draw(UINT vertex_count, UINT vertex_offset=0)
					{
						m_dev.get_d3ddevice_ptr()->Draw(vertex_count, vertex_offset);
					}

					void draw_primitive(::D3D10_PRIMITIVE_TOPOLOGY topology, UINT vertex_count, UINT vertex_offset=0)
					{
						set_primitive_topology(topology);
						draw(vertex_count, vertex_offset);
					}
				};
				typedef render_device renderer;

				///////////////////////////////////////////////////////////////////////////
			}
		}
	}
}

#endif//__SFJP_ROAST_EX__graphics__directx__dx10__shader_HPP__
