﻿module mmgenerator;
private import std.cstream;
private import std.file;
private import std.math;
private import coneneko.math;
private import skeleton;
private import mqo;
private import mkm;
private import coneneko.sjis2utf8;
private import std.stream;
private import std.random;

//void write(Stream outStream, Matrix m);
//Vector mul3(Vector a, Matrix b); // TODO: vec3.w issue
//int[] getKeyFrameList(Mkm.Motion motion);
//uint getEndFrame(Mkm.Motion motion);
//Vector linear(Vector previous, Vector next, float t); // linear interpolation
// Public surface of BoneMotion: per-frame, per-bone motion matrices.
interface _BoneMotion
{
	// Constructors (cannot be declared in an interface):
	//   BoneMotion.this(Mqo mqo, Mkm.Motion motion)
	//   BoneMotion.this(Skeleton skeleton, Mkm.Motion motion)
	// The three stored components of one bone's transform at one frame.
	struct BoneMatrix { Matrix lookAt, rotation, translation; }
	// Returns the stored components for frame t and bone boneIndex.
	BoneMatrix boneMatrix(uint t, uint boneIndex);
	Matrix boneMotionMatrix(uint t, uint boneIndex); // the composed matrix, ready to use as-is
}
// Public surface of KeyFrameIndexer: maps a frame number t to indices into a
// keyframe list (see KeyFrameIndexer for the exact counting rules).
interface _KeyFrameIndexer
{
	// Constructor (cannot be declared in an interface):
	//   KeyFrameIndexer.this(int[] keyFrameList, int endFrame)
	int endFrame();
	// Index of the last keyframe <= t, or 0 if there is none
	// (assumes keyFrameList is sorted ascending — TODO confirm at callers).
	int getPreviousFrameIndex(int t);
	// Index of the first keyframe >= t, or the last index if there is none
	// (same sorted-ascending assumption).
	int getNextFrameIndex(int t);
}
// Public surface shared by LocationMotionMatrix and QuaternionMotionMatrix:
// a keyframe track that yields an interpolated Matrix for a frame.
interface _MotionMatrix : _KeyFrameIndexer
{
	// Constructor (cannot be declared in an interface):
	//   LocationMotionMatrix.this(Mkm.Motion motion)
	// NOTE(review): this static declaration takes Mkm, but
	// QuaternionMotionMatrix.createList takes Mkm.Motion; it looks stale and has
	// no implementation with this exact signature.
	static QuaternionMotionMatrix[] createList(Mkm mkm);
	Matrix at(int t); // interpolated matrix at frame t
}

// Builds one MotionMatrixFormat per motion contained in mkm, in the order the
// motions appear in mkm.motion.
// Each motion's BoneMotion is built once, inside createMotionMatrix. It used to be
// built a second time here and then discarded.
MotionMatrixFormat[] createMotionMatrixFormatList(Skeleton skeleton, Mkm mkm)
{
	MotionMatrixFormat[] result;
	foreach (Mkm.Motion a; mkm.motion)
	{
		result ~= createMotionMatrix(skeleton, a);
	}
	return result;
}

// Converts one motion into a MotionMatrixFormat. The result stores the lookAt,
// rotation and translation matrices of every bone (in skeleton.boneList order)
// for every frame in [0, endFrame).
MotionMatrixFormat createMotionMatrix(Skeleton skeleton, Mkm.Motion motion)
{
	BoneMotion source = new BoneMotion(skeleton, motion);
	MotionMatrixFormat output = MotionMatrixFormat.createWriter(source.endFrame, source.boneList.length);

	uint frame = 0;
	while (frame < output.frameLength)
	{
		// boneLength is only queried when at least one frame exists.
		for (uint bone = 0; bone < output.boneLength; bone++)
		{
			BoneMotion.BoneMatrix parts = source.boneMatrix(frame, bone);
			output.setLookAt(frame, bone, parts.lookAt);
			output.setRotation(frame, bone, parts.rotation);
			output.setTranslation(frame, bone, parts.translation);
		}
		frame++;
	}
	return output;
}

// Returns the frame numbers of the keyframes on the FIRST quaternion track of
// the motion, in track order.
// NOTE(review): motion.quaternion[0] is indexed unconditionally, so a motion
// with no quaternion tracks fails the array bounds check.
int[] getKeyFrameList(Mkm.Motion motion)
{
	int[] frames;
	foreach (Mkm.IntAndFloat4 key; motion.quaternion[0].keyFrame) frames ~= key.keyFrameNumber;
	return frames;
}

// Returns the motion's end frame as stored in Mkm.Motion.endframe.
uint getEndFrame(Mkm.Motion motion)
{
	uint last = motion.endframe;
	return last;
}

// Linear interpolation between previous (t = 0) and next (t = 1).
// Precondition: 0 <= t <= 1 (a NaN t also fails the check).
Vector linear(Vector previous, Vector next, float t)
in
{
	assert(t >= 0.0 && t <= 1.0);
}
body
{
	Vector fromPrevious = previous * (1 - t);
	Vector fromNext = next * t;
	return fromPrevious + fromNext;
}


// とりあえずmkmの最初のmotionのみを扱う
// Per-frame, per-bone transform data for a single Mkm.Motion on a skeleton.
// (Original note: "for now, only the first motion of the mkm is handled"; the
// constructors take one Mkm.Motion, and createMotionMatrixFormatList builds one
// BoneMotion per motion.)
class BoneMotion : _BoneMotion
{
	public final Skeleton skeleton;
	public final Bone[] boneList;
	public final BoneRelation[] boneRelationList;
	public final uint endFrame;
	public final int[] keyFrameList; // keyframe numbers of quaternion track 0 (see getKeyFrameList)
	public final LocationMotionMatrix locationMotion;
	public final QuaternionMotionMatrix[] quaternionMotionList; // indices correspond to boneList
	
	this(Mqo mqo, Mkm.Motion motion)
	{
		this(new Skeleton(mqo), motion);
	}
	
	// Precomputes the lookAt / rotation / translation matrices for every frame in
	// [0, endFrame) and every bone.
	this(Skeleton skeleton, Mkm.Motion motion)
	{
		this.skeleton = skeleton;
		this.boneList = skeleton.boneList;
		this.boneRelationList = skeleton.boneRelationList;

		this.endFrame = getEndFrame(motion);
		this.keyFrameList = getKeyFrameList(motion);
		this.locationMotion = new LocationMotionMatrix(motion);
		this.quaternionMotionList = QuaternionMotionMatrix.createList(motion);
		assert(boneList.length == quaternionMotionList.length);
		initializeBoneMatrix();
	}
	
	// Composed motion matrix for bone boneIndex at frame t (0 <= t < endFrame).
	Matrix boneMotionMatrix(uint t, uint boneIndex)
	{
		return getMatrix(boneList[boneIndex], t);
	}
	
	// Composed motion matrix for bone at frame t, using the non-recursive
	// getMatrix3. The result is cross-checked against the recursive getMatrix2 to
	// within 0.0001 per element.
	Matrix getMatrix(Bone bone, int t)
	{
		Matrix result = getMatrix3(bone, t); // computed once (was evaluated twice)
		assertEquals(getMatrix2(bone, t), result, 0.0001);
		return result;
	}
	
	// Asserts that every element of a and b differs by less than f.
	private static void assertEquals(Matrix a, Matrix b, float f)
	{
		for (int i = 0; i < 4; i++)
		{
			for (int j = 0; j < 4; j++)
			{
				assert(fabs(a.m[i][j] - b.m[i][j]) < f);
			}
		}
	}
	
	// Translation from bone boneIndex (the child) to its parent. It is the child's
	// startPosition transformed by the parent's lookAt matrix.
	// Only meaningful for non-root bones: a root has no parent to pass to lookAtMatrix.
	public Matrix translation(int boneIndex)
	{
		Bone parent = skeleton.getParent(boneList[boneIndex]);
		Vector moveVector = mul3(boneList[boneIndex].startPosition, lookAtMatrix(parent));
		return Matrix.translation(moveVector.x, moveVector.y, moveVector.z);
	}
	
	private final BoneMatrix[][] _boneMatrix; // [frame 0 .. endFrame-1][bone index]
	BoneMatrix boneMatrix(uint t, uint boneIndex) { return _boneMatrix[t][boneIndex]; }
	
	// Fills _boneMatrix. The lookAt matrix and the non-root translation do not
	// depend on t, so they are computed once per bone rather than once per frame.
	void initializeBoneMatrix()
	out
	{
		assert(endFrame == _boneMatrix.length);
		foreach (BoneMatrix[] a; _boneMatrix) assert(boneList.length == a.length);
	}
	body
	{
		_boneMatrix.length = endFrame;
		foreach (inout BoneMatrix[] a; _boneMatrix) a.length = boneList.length;
		if (endFrame == 0) return; // no frames to fill; skip the per-bone constants
		
		for (int b = 0; b < boneList.length; b++)
		{
			Matrix lookAt = lookAtMatrix(boneList[b]);
			bool root = isRoot(b);
			Matrix fixedTranslation;
			if (!root) fixedTranslation = translation(b);
			
			for (int t = 0; t < endFrame; t++)
			{
				_boneMatrix[t][b].lookAt = lookAt;
				_boneMatrix[t][b].rotation = quaternionMotionList[b].at(t);
				_boneMatrix[t][b].translation = root ? locationMotion.at(t) : fixedTranslation;
			}
		}
	}
	
	// True when the bone has no parent in the skeleton.
	bit isRoot(uint boneIndex)
	{
		return skeleton.getParent(boneList[boneIndex]) is null;
	}
	
	// Index (in the skeleton) of the parent of bone a.
	uint parentIndex(uint a)
	{
		return skeleton.indexOf(skeleton.getParent(boneList[a]));
	}
	
	// lookAt * rotation * translation * rotation * translation * ... * translation
	// Every bone has a lookAt, a rotation and a translation.
	// lookAt is fixed per bone.
	// rotation varies per bone and per frame.
	// The root's translation varies per frame; every other translation is fixed per bone.
	// Non-recursive version: walks from bone up to the root.
	Matrix getMatrix3(Bone bone, int t)
	in
	{
		// _boneMatrix holds exactly endFrame frames, so t == endFrame is out of range.
		assert(0 <= t && t < endFrame);
	}
	body
	{
		uint currentIndex = skeleton.indexOf(bone);
		Matrix result = _boneMatrix[t][currentIndex].lookAt;
		while (true)
		{
			result *= _boneMatrix[t][currentIndex].rotation;
			result *= _boneMatrix[t][currentIndex].translation;
			if (isRoot(currentIndex)) return result;
			currentIndex = parentIndex(currentIndex);
		}
		throw new Error("BoneMotion.getMatrix3");
	}
	
	// Recursive version of getMatrix3. It evaluates the tracks directly instead of
	// reading _boneMatrix, and serves as the reference for the check in getMatrix.
	Matrix getMatrix2(Bone bone, int t)
	in
	{
		assert(0 <= t && t <= endFrame);
	}
	body
	{
		// rotation * translation for bone, then recurse into its parent.
		Matrix getMatrixR(Bone bone, int t)
		{
			Matrix result = Matrix.identity;

			result *= quaternionMotionList[skeleton.indexOf(bone)].at(t);

			Bone parentBone = skeleton.getParent(bone);
			if (parentBone is null)
			{
				result *= locationMotion.at(t);
			}
			else
			{
				Vector moveVector = mul3(bone.startPosition, lookAtMatrix(parentBone));
				result *= Matrix.translation(moveVector.x, moveVector.y, moveVector.z);
				result *= getMatrixR(parentBone, t);
			}

			return result;
		}
		
		return lookAtMatrix(bone) * getMatrixR(bone, t);
	}
	
	// Transform into the bone's local frame. It is a left-handed lookAt from
	// startPosition toward endPosition, with "up" along (hPosition - startPosition).
	private static Matrix lookAtMatrix(Bone bone)
	{
		Vector up = normalize(bone.hPosition - bone.startPosition);
		return Matrix.lookAtLH(
			bone.startPosition.x, bone.startPosition.y, bone.startPosition.z,
			bone.endPosition.x, bone.endPosition.y, bone.endPosition.z,
			up.x, up.y, up.z
		);
	}
	
	// Text dump: the bones, the bone relations (as index pairs), endFrame, the
	// keyframe list, the location track and each quaternion track, with sections
	// separated by blank lines.
	override char[] toString()
	{
		MemoryStream result = new MemoryStream();

		foreach (Bone bone; boneList) result.writeLine(bone.toString());
		result.writeLine("");

		foreach (BoneRelation boneRelation; boneRelationList)
		{
			result.printf("%d %d\r\n",
				skeleton.indexOf(boneRelation.first), skeleton.indexOf(boneRelation.second));
		}
		result.writeLine("");
		
		result.printf("%d\r\n", endFrame);
		
		foreach (int i, int keyFrame; keyFrameList)
		{
			result.printf("%d", keyFrame);
			if (i != keyFrameList.length - 1) result.writeString(" ");
		}
		result.writeLine("");
		result.writeLine("");

		result.writeLine(locationMotion.toString());
		result.writeLine("");
		
		foreach (QuaternionMotionMatrix quaternionMotion; quaternionMotionList)
		{
			result.writeLine(quaternionMotion.toString());
		}

		return result.toString();
	}
}

// Maps a frame number t (0 <= t <= endFrame) to indices into keyFrameList.
// The lookups count matching entries, so they give "previous/next keyframe"
// semantics only if keyFrameList is sorted ascending (as in the unittest).
class KeyFrameIndexer : _KeyFrameIndexer
{
	unittest
	{
		const int[] keyFrameList1 = [ 0, 10, 20 ];
		KeyFrameIndexer mm1 = new KeyFrameIndexer(keyFrameList1, 20);
		assert(0 == mm1.getPreviousFrameIndex(0));
		assert(0 == mm1.getPreviousFrameIndex(5));
		assert(1 == mm1.getPreviousFrameIndex(10));
		assert(1 == mm1.getPreviousFrameIndex(15));
		assert(2 == mm1.getPreviousFrameIndex(20));
		assert(0 == mm1.getNextFrameIndex(0));
		assert(1 == mm1.getNextFrameIndex(5));
		assert(1 == mm1.getNextFrameIndex(10));
		assert(2 == mm1.getNextFrameIndex(15));
		assert(2 == mm1.getNextFrameIndex(20));

		const int[] keyFrameList2 = [ 1, 10, 19 ];
		KeyFrameIndexer mm2 = new KeyFrameIndexer(keyFrameList2, 20);
		assert(0 == mm2.getPreviousFrameIndex(0));
		assert(0 == mm2.getPreviousFrameIndex(1));
		assert(0 == mm2.getPreviousFrameIndex(5));
		assert(1 == mm2.getPreviousFrameIndex(10));
		assert(1 == mm2.getPreviousFrameIndex(15));
		assert(2 == mm2.getPreviousFrameIndex(19));
		assert(2 == mm2.getPreviousFrameIndex(20));
		assert(0 == mm2.getNextFrameIndex(0));
		assert(0 == mm2.getNextFrameIndex(1));
		assert(1 == mm2.getNextFrameIndex(5));
		assert(1 == mm2.getNextFrameIndex(10));
		assert(2 == mm2.getNextFrameIndex(15));
		assert(2 == mm2.getNextFrameIndex(19));
		assert(2 == mm2.getNextFrameIndex(20));
	}

	private final int[] keyFrameList;
	private final int _endFrame;

	this(int[] keyFrameList, int endFrame)
	{
		this._endFrame = endFrame;
		this.keyFrameList = keyFrameList;
	}
	
	int endFrame() { return _endFrame; }

	// (number of keyframes <= t) - 1, clamped to 0 when no keyframe is <= t.
	int getPreviousFrameIndex(int t)
	in
	{
		assert(0 <= t && t <= _endFrame);
	}
	out (result)
	{
		assert(0 <= result && result < keyFrameList.length);
	}
	body
	{
		int atOrBefore = 0;
		foreach (int frame; keyFrameList)
		{
			if (frame <= t) atOrBefore++;
		}
		return (atOrBefore > 0) ? atOrBefore - 1 : 0;
	}

	// length - (number of keyframes >= t); when no keyframe is >= t, the last index.
	int getNextFrameIndex(int t)
	in
	{
		assert(0 <= t && t <= _endFrame);
	}
	out (result)
	{
		assert(0 <= result && result < keyFrameList.length);
	}
	body
	{
		int atOrAfter = 0;
		foreach (int frame; keyFrameList)
		{
			if (frame >= t) atOrAfter++;
		}
		int total = keyFrameList.length;
		return (atOrAfter == 0) ? total - 1 : total - atOrAfter;
	}
}

// Root-bone location track: one position per keyframe, linearly interpolated
// into a translation matrix by at().
// NOTE(review): locationList is built from motion.vector.keyFrame, but the frame
// numbers passed to KeyFrameIndexer come from getKeyFrameList(motion), which uses
// quaternion track 0. If the two tracks differ in count or timing, indices do not
// line up. Consider using the vector track's own frame numbers (field name of
// Mkm.IntAndFloat3 not visible here — TODO confirm).
class LocationMotionMatrix : KeyFrameIndexer, _MotionMatrix
{
	private final Vector[] locationList; // indices correspond to keyFrameList
	
	this(Mkm.Motion motion)
	{
		foreach (Mkm.IntAndFloat3 key; motion.vector.keyFrame)
		{
			locationList ~= Vector.create(key.position.x, key.position.y, key.position.z);
		}
		super(getKeyFrameList(motion), getEndFrame(motion));
	}
	
	// One "x y z" line per keyframe location.
	override char[] toString()
	{
		MemoryStream text = new MemoryStream();
		foreach (Vector p; locationList)
		{
			double x = cast(double)p.x;
			double y = cast(double)p.y;
			double z = cast(double)p.z;
			text.printf("%f %f %f\r\n", x, y, z);
		}
		return text.toString();
	}
	
	// Translation at frame t, linearly interpolated between the surrounding keyframes.
	Matrix at(int t)
	{
		int before = getPreviousFrameIndex(t);
		int after = getNextFrameIndex(t);
		Vector location;
		if (before == after)
		{
			location = locationList[before];
		}
		else
		{
			float elapsed = cast(float)(t - keyFrameList[before]);
			float span = cast(float)(keyFrameList[after] - keyFrameList[before]);
			float ratio = elapsed / span;
			location = linear(locationList[before], locationList[after], ratio);
		}
		return createTranslationMatrix(location);
	}

	private Matrix createTranslationMatrix(Vector v)
	{
		return Matrix.translation(v.x, v.y, v.z);
	}
}

// One bone's rotation track: one quaternion per keyframe, spherically
// interpolated (slerp) into a rotation matrix by at().
class QuaternionMotionMatrix : KeyFrameIndexer, _MotionMatrix
{
	private final Vector[] quaternionList; // indices correspond to keyFrameList
	
	// Builds one QuaternionMotionMatrix per quaternion track of the motion, in
	// track order.
	// Each track is indexed by the frame numbers of ITS OWN keyframes. Previously,
	// every track reused quaternion track 0's frame numbers. That paired
	// quaternionList[i] with the wrong frame, and could make the two lists differ in
	// length (which the old stopgap in at() papered over).
	static QuaternionMotionMatrix[] createList(Mkm.Motion motion)
	{
		QuaternionMotionMatrix[] result;
		uint lastFrame = getEndFrame(motion);
		foreach (Mkm.Quaternion a; motion.quaternion)
		{
			Vector[] rotations;
			int[] frames;
			foreach (Mkm.IntAndFloat4 b; a.keyFrame)
			{
				Mkm.float4 r = b.rotation;
				rotations ~= Vector.create(r.a, r.b, r.c, r.d);
				frames ~= b.keyFrameNumber;
			}
			result ~= new QuaternionMotionMatrix(rotations, frames, lastFrame);
		}
		return result;
	}
	
	// quaternionList[i] is the rotation at frame keyFrameList[i]; both lists are
	// expected to have the same length.
	this(Vector[] quaternionList, int[] keyFrameList, int endFrame)
	{
		super(keyFrameList, endFrame);
		this.quaternionList = quaternionList;
	}
	
	// One "x y z w" line per keyframe quaternion.
	override char[] toString()
	{
		MemoryStream result = new MemoryStream();
		foreach (Vector quaternion; quaternionList)
		{
			result.printf("%f %f %f %f\r\n",
				cast(double)quaternion.x, cast(double)quaternion.y, cast(double)quaternion.z,
				cast(double)quaternion.w);
		}
		return result.toString();
	}
	
	// Rotation at frame t, slerped between the surrounding keyframes.
	Matrix at(int t)
	{
		int previousIndex = getPreviousFrameIndex(t);
		int nextIndex = getNextFrameIndex(t);
		if (previousIndex == nextIndex) return createQuaternionMatrix(quaternionList[previousIndex]);
		
		// KeyFrameIndexer guarantees both indices are < keyFrameList.length; with
		// per-track frame lists they are valid quaternionList indices too.
		assert(previousIndex < quaternionList.length);
		assert(nextIndex < quaternionList.length);
		return createQuaternionMatrix(
			slerp(quaternionList[previousIndex], quaternionList[nextIndex],
				cast(float)(t - keyFrameList[previousIndex]) / 
				cast(float)(keyFrameList[nextIndex] - keyFrameList[previousIndex])
				)
			);
	}
	
	// Spherical linear interpolation between two quaternions.
	private static Vector slerp(Vector previous, Vector next, float t)
	{
		return createSlerpQuaternion(previous, next, t);
	}

	private Matrix createQuaternionMatrix(Vector quaternion)
	{
		return createMatrixFromQuaternion(quaternion);
	}
}


// Public surface of MotionMatrixFormat: a frameLength x boneLength table of
// lookAt / rotation / translation matrices, serializable with toArray() and
// readable back with createReader().
private interface _MotionMatrixFormat
{
	static MotionMatrixFormat createWriter(uint frameLength, uint boneLength); // empty table to fill
	static MotionMatrixFormat createReader(void[] data); // parses data produced by toArray()
	uint frameLength();
	uint boneLength();
	Matrix getLookAt(uint t, uint boneIndex);
	Matrix getRotation(uint t, uint boneIndex);
	Matrix getTranslation(uint t, uint boneIndex);
	void setLookAt(uint t, uint boneIndex, Matrix lookAt);
	void setRotation(uint t, uint boneIndex, Matrix rotation);
	void setTranslation(uint t, uint boneIndex, Matrix translation);
	void[] toArray(); // serialized form: header, frameLength, boneLength, then matrices
}

// A frameLength x boneLength table of lookAt / rotation / translation matrices
// with a simple binary serialization:
//   HEADER (written with Stream.write(char[])), frameLength (uint), boneLength (uint),
//   then for each frame, for each bone: lookAt, rotation, translation as raw Matrix bytes.
class MotionMatrixFormat : _MotionMatrixFormat
{
	unittest
	{
		Matrix mrand()
		{
			Matrix result;
			for (int i = 0; i < 4; i++)
			{
				for (int j = 0; j < 4; j++) result[i, j] = cast(float)rand();
			}
			return result;
		}
		MotionMatrixFormat writer = MotionMatrixFormat.createWriter(30, 20);
		assert(30 == writer.frameLength);
		assert(20 == writer.boneLength);
		rand_seed(0, 0);
		for (int i = 0; i < writer.frameLength; i++)
		{
			for (int j = 0; j < writer.boneLength; j++)
			{
				writer.setLookAt(i, j, mrand());
				writer.setRotation(i, j, mrand());
				writer.setTranslation(i, j, mrand());
			}
		}
		MotionMatrixFormat reader = MotionMatrixFormat.createReader(writer.toArray());
		assert(30 == reader.frameLength);
		assert(20 == reader.boneLength);
		rand_seed(0, 0);
		for (int i = 0; i < reader.frameLength; i++)
		{
			for (int j = 0; j < reader.boneLength; j++)
			{
				assert(mrand() == reader.getLookAt(i, j));
				assert(mrand() == reader.getRotation(i, j));
				assert(mrand() == reader.getTranslation(i, j));
			}
		}
		
		// A table with zero frames keeps its bone count and round-trips.
		MotionMatrixFormat empty = MotionMatrixFormat.createWriter(0, 7);
		assert(0 == empty.frameLength);
		assert(7 == empty.boneLength);
		MotionMatrixFormat emptyReader = MotionMatrixFormat.createReader(empty.toArray());
		assert(0 == emptyReader.frameLength);
		assert(7 == emptyReader.boneLength);
	}
	
	private Matrix[][] _lookAt, _rotation, _translation; // [frame][bone]
	// Stored explicitly: deriving it from _lookAt[0] fails when frameLength == 0.
	private uint _boneLength;
	uint frameLength() { return _lookAt.length; }
	uint boneLength() { return _boneLength; }
	const char[] HEADER = "MotionMatrixFormat";
	
	private this(uint frameLength, uint boneLength)
	{
		_boneLength = boneLength;
		_lookAt.length = frameLength;
		foreach (inout Matrix[] a; _lookAt) a.length = boneLength;
		_rotation.length = frameLength;
		foreach (inout Matrix[] a; _rotation) a.length = boneLength;
		_translation.length = frameLength;
		foreach (inout Matrix[] a; _translation) a.length = boneLength;
	}
	
	// Creates an empty (default-initialized) table to be filled with the setters.
	static MotionMatrixFormat createWriter(uint frameLength, uint boneLength)
	{
		return new MotionMatrixFormat(frameLength, boneLength);
	}
	
	// Parses data produced by toArray().
	// Throws Error if the header does not match HEADER. Truncated data surfaces
	// as whatever exception the stream reads raise.
	static MotionMatrixFormat createReader(void[] data)
	{
		MemoryStream ds = new MemoryStream(cast(ubyte[])data);
		char[] header = new char[HEADER.length];
		ds.read(header); // mirrors result.write(HEADER) in toArray
		if (header != HEADER) throw new Error("MotionMatrixFormat.createReader");
		uint frameLength, boneLength;
		ds.read(frameLength);
		ds.read(boneLength);
		MotionMatrixFormat result = new MotionMatrixFormat(frameLength, boneLength);
		for (int i = 0; i < frameLength; i++)
		{
			for (int j = 0; j < boneLength; j++)
			{
				result.setLookAt(i, j, readMatrix(ds));
				result.setRotation(i, j, readMatrix(ds));
				result.setTranslation(i, j, readMatrix(ds));
			}
		}
		return result;
	}
	
	// Serializes the table in the layout described on the class.
	void[] toArray()
	{
		MemoryStream result = new MemoryStream();
		result.write(HEADER);
		result.write(frameLength);
		result.write(boneLength);
		for (int i = 0; i < frameLength; i++)
		{
			for (int j = 0; j < boneLength; j++)
			{
				writeMatrix(result, getLookAt(i, j));
				writeMatrix(result, getRotation(i, j));
				writeMatrix(result, getTranslation(i, j));
			}
		}
		return result.data;
	}
	
	// Writes the raw bytes of one Matrix.
	private static void writeMatrix(Stream s, Matrix m)
	{
		s.writeExact(&m, Matrix.sizeof * 1);
	}
	
	// Reads the raw bytes of one Matrix.
	private static Matrix readMatrix(Stream s)
	{
		Matrix result;
		s.readExact(&result, Matrix.sizeof * 1);
		return result;
	}
	
	Matrix getLookAt(uint t, uint boneIndex) { return _lookAt[t][boneIndex]; }
	Matrix getRotation(uint t, uint boneIndex) { return _rotation[t][boneIndex]; }
	Matrix getTranslation(uint t, uint boneIndex) { return _translation[t][boneIndex]; }
	void setLookAt(uint t, uint boneIndex, Matrix lookAt) { _lookAt[t][boneIndex] = lookAt; }
	void setRotation(uint t, uint boneIndex, Matrix rotation) { _rotation[t][boneIndex] = rotation; }
	void setTranslation(uint t, uint boneIndex, Matrix translation) { _translation[t][boneIndex] = translation; }
}
