//
// mathematical functions
//

#include "Expr.h"
#include "Module_math.h"
#include "Algorithm.h"

AScript_BeginModule(math)

//-----------------------------------------------------------------------------
// AScript module functions
//-----------------------------------------------------------------------------
// num = real(num):map
// Real part of the argument.
//   number  -> the same number
//   complex -> its real component (as a number)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(real)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(real)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(operand.GetNumber());
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetNumber(operand.GetComplex().real());
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = imag(num):map
// Imaginary part of the argument.
//   number  -> 0 (a real number has no imaginary component)
//   complex -> its imaginary component (as a number)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(imag)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(imag)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(0.);
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetNumber(operand.GetComplex().imag());
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = arg(num):map
// Argument (phase angle, in radians) of the value.
//   number  -> atan2(0, x): 0 for positive reals, pi for negative reals
//              (signed zeros follow the atan2 rules), the same result
//              std::arg gives for the complex number (x, 0)
//   complex -> std::arg(z)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(arg)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(arg)
{
	const Value &value = context.GetValue(0);
	Value result;
	if (value.IsNumber()) {
		// A negative real lies on the negative real axis, so its phase is pi,
		// not 0; atan2 with a zero imaginary part gives exactly that and keeps
		// this branch consistent with the complex branch below.
		result.SetNumber(::atan2(0., value.GetNumber()));
	} else if (value.IsComplex()) {
		result.SetNumber(std::arg(value.GetComplex()));
	} else if (value.IsValid()) {
		SetError_InvalidValType(sig, value);
	}
	return result;
}

// num = norm(num):map
// Squared magnitude of the value (std::norm semantics, not the Euclidean length).
//   number  -> x * x
//   complex -> std::norm(z) = re^2 + im^2
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(norm)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(norm)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		const Number x = operand.GetNumber();
		retval.SetNumber(x * x);
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetNumber(std::norm(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = conj(num):map
// Complex conjugate of the value.
//   number  -> the same number (a real is its own conjugate)
//   complex -> std::conj(z), still a complex value
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(conj)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(conj)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(operand.GetNumber());
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::conj(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = acos(num):map
// Arc cosine of a number, computed with ::acos. Only numbers are accepted;
// complex and other valid values raise an invalid-type error, and an argument
// whose IsValid() is false yields a default-constructed Value silently.
// NOTE(review): no range check; values outside [-1, 1] go straight to ::acos.
AScript_DeclareFunctionBegin(acos)
AScript_DeclareFunctionEnd(acos)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(acos)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (!operand.IsNumber()) {
		if (operand.IsValid()) {
			SetError_InvalidValType(sig, operand);
		}
		return retval;
	}
	retval.SetNumber(::acos(operand.GetNumber()));
	return retval;
}

// num = asin(num):map
// Arc sine of a number, computed with ::asin. Only numbers are accepted;
// complex and other valid values raise an invalid-type error, and an argument
// whose IsValid() is false yields a default-constructed Value silently.
// NOTE(review): no range check; values outside [-1, 1] go straight to ::asin.
AScript_DeclareFunctionBegin(asin)
AScript_DeclareFunctionEnd(asin)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(asin)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (!operand.IsNumber()) {
		if (operand.IsValid()) {
			SetError_InvalidValType(sig, operand);
		}
		return retval;
	}
	retval.SetNumber(::asin(operand.GetNumber()));
	return retval;
}

// num = atan(num):map
// Arc tangent of a number, computed with ::atan. Only numbers are accepted;
// complex and other valid values raise an invalid-type error, and an argument
// whose IsValid() is false yields a default-constructed Value silently.
AScript_DeclareFunctionBegin(atan)
AScript_DeclareFunctionEnd(atan)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(atan)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (!operand.IsNumber()) {
		if (operand.IsValid()) {
			SetError_InvalidValType(sig, operand);
		}
		return retval;
	}
	retval.SetNumber(::atan(operand.GetNumber()));
	return retval;
}

// num = atan2(num, num):map
// Two-argument arc tangent ::atan2(value1, value2): value1 is passed as the
// first (y) operand and value2 as the second (x) operand. Both must be
// numbers. If either argument fails IsValid(), a default-constructed Value is
// returned without error; otherwise a non-number raises an invalid-type error.
AScript_DeclareFunction(atan2)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value1", VTYPE_Any);
	DeclareArg(env, "value2", VTYPE_Any);
}

AScript_ImplementFunction(atan2)
{
	const Value &lhs = context.GetValue(0);
	const Value &rhs = context.GetValue(1);
	Value retval;
	const bool bothNumbers = lhs.IsNumber() && rhs.IsNumber();
	if (bothNumbers) {
		retval.SetNumber(::atan2(lhs.GetNumber(), rhs.GetNumber()));
		return retval;
	}
	if (lhs.IsValid() && rhs.IsValid()) {
		SetError_InvalidValType(sig, lhs, rhs);
	}
	return retval;
}

// num = ceil(num):map
// Smallest integral value not less than the number (::ceil), returned as a
// number. Non-number valid values raise an invalid-type error; an argument
// whose IsValid() is false yields a default-constructed Value silently.
AScript_DeclareFunction(ceil)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(ceil)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (!operand.IsNumber()) {
		if (operand.IsValid()) {
			SetError_InvalidValType(sig, operand);
		}
		return retval;
	}
	retval.SetNumber(::ceil(operand.GetNumber()));
	return retval;
}

// num = cos(num):map
// Cosine: ::cos for numbers, std::cos for complex values (complex result).
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(cos)
AScript_DeclareFunctionEnd(cos)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(cos)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::cos(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::cos(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = cosh(num):map
// Hyperbolic cosine: ::cosh for numbers, std::cosh for complex values.
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(cosh)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(cosh)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::cosh(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::cosh(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = exp(num):map
// Natural exponential: ::exp for numbers, std::exp for complex values.
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(exp)
AScript_DeclareFunctionEnd(exp)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(exp)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::exp(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::exp(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = abs(num):map
// Absolute value / magnitude, always returned as a number.
//   number  -> ::fabs(x)
//   complex -> std::abs(z), the modulus sqrt(re^2 + im^2)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(abs)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(abs)
{
	const Value &value = context.GetValue(0);
	Value result;
	if (value.IsNumber()) {
		result.SetNumber(::fabs(value.GetNumber()));
	} else if (value.IsComplex()) {
		// std::abs of a complex is a real magnitude. Store it as a number, as
		// norm() does, rather than letting it convert implicitly into a
		// complex value with a zero imaginary part.
		result.SetNumber(std::abs(value.GetComplex()));
	} else if (value.IsValid()) {
		SetError_InvalidValType(sig, value);
	}
	return result;
}

// num = floor(num):map
// Largest integral value not greater than the number (::floor), returned as a
// number. Non-number valid values raise an invalid-type error; an argument
// whose IsValid() is false yields a default-constructed Value silently.
AScript_DeclareFunction(floor)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(floor)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (!operand.IsNumber()) {
		if (operand.IsValid()) {
			SetError_InvalidValType(sig, operand);
		}
		return retval;
	}
	retval.SetNumber(::floor(operand.GetNumber()));
	return retval;
}

// num = log(num):map
// Natural logarithm.
//   number  -> ::log(x) when x > 0; otherwise (zero, negative or NaN) an
//              invalid-value error is raised
//   complex -> std::log(z) (principal branch, complex result)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(log)
AScript_DeclareFunctionEnd(log)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(log)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		const Number x = operand.GetNumber();
		if (x > 0.) {
			retval.SetNumber(::log(x));
		} else {
			SetError_InvalidValue(sig, operand);
		}
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::log(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = log10(num):map
// Base-10 logarithm.
//   number  -> ::log10(x) when x > 0; otherwise (zero, negative or NaN) an
//              invalid-value error is raised
//   complex -> std::log10(z) (complex result)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(log10)
AScript_DeclareFunctionEnd(log10)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(log10)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		const Number x = operand.GetNumber();
		if (x > 0.) {
			retval.SetNumber(::log10(x));
		} else {
			SetError_InvalidValue(sig, operand);
		}
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::log10(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = sin(num):map
// Sine: ::sin for numbers, std::sin for complex values (complex result).
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(sin)
AScript_DeclareFunctionEnd(sin)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(sin)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::sin(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::sin(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = sinh(num):map
// Hyperbolic sine: ::sinh for numbers, std::sinh for complex values.
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(sinh)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(sinh)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::sinh(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::sinh(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = sqrt(num):map
// Square root.
//   number x >= 0 -> ::sqrt(x) as a number
//   number x <  0 -> the complex value (0, sqrt(-x)), i.e. i * sqrt(|x|)
//                    (NaN also takes this branch because `x >= 0` is false)
//   complex       -> std::sqrt(z)
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(sqrt)
AScript_DeclareFunctionEnd(sqrt)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(sqrt)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		const Number x = operand.GetNumber();
		if (x >= 0) {
			retval.SetNumber(::sqrt(x));
		} else {
			retval.SetComplex(Complex(0, ::sqrt(-x)));
		}
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::sqrt(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = tan(num):map
// Tangent: ::tan for numbers, std::tan for complex values (complex result).
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunctionBegin(tan)
AScript_DeclareFunctionEnd(tan)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(tan)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::tan(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::tan(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// num = tanh(num):map
// Hyperbolic tangent: ::tanh for numbers, std::tanh for complex values.
// An argument whose IsValid() is false yields a default-constructed Value
// without raising an error; any other type raises an invalid-type error.
AScript_DeclareFunction(tanh)
{
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_Any);
}

AScript_ImplementFunction(tanh)
{
	const Value &operand = context.GetValue(0);
	Value retval;
	if (operand.IsNumber()) {
		retval.SetNumber(::tanh(operand.GetNumber()));
		return retval;
	}
	if (operand.IsComplex()) {
		retval.SetComplex(std::tanh(operand.GetComplex()));
		return retval;
	}
	if (operand.IsValid()) {
		SetError_InvalidValType(sig, operand);
	}
	return retval;
}

// expr = least_square(x:iterator, y:iterator, n:number = 1, var:symbol = `x)
// Fits a polynomial of degree n to the sample points (x[i], y[i]) by least
// squares and returns it as an anonymous function of `var`:
//   f(var) = a0 + a1 * var + a2 * var ** 2 + ... + an * var ** n
// The coefficients a solve the normal equations M * a = b where
//   M[r][c] = sum(x^(r + c)),  b[r] = sum(x^r * y),  0 <= r, c <= n.
// Raises ERR_ValueError when n is 0, when a sample is not a number, when x and
// y yield different numbers of samples, or when InvertMatrix fails
// (presumably because M is singular, e.g. too few distinct x values).
AScript_DeclareFunction(least_square)
{
	SetMode(RSLTMODE_Normal, MAP_Off, FLAT_Off);
	DeclareArg(env, "x", VTYPE_Iterator);
	DeclareArg(env, "y", VTYPE_Iterator);
	DeclareArg(env, "n", VTYPE_Number, OCCUR_Once, false, new Expr_Value(1));
	DeclareArg(env, "var", VTYPE_Symbol, OCCUR_Once, false,
											new Expr_Symbol(AScript_Symbol(x)));
}

AScript_ImplementFunction(least_square)
{
	size_t nDim = context.GetSizeT(2);
	if (nDim == 0) {
		sig.SetError(ERR_ValueError, "invalid dimension");
		return Value::Null;
	}
	size_t nCols = nDim + 1;
	size_t nRows = nCols;
	Iterator *pIterX = context.GetIterator(0);
	Iterator *pIterY = context.GetIterator(1);
	const Symbol *pSymbolVar = context.GetSymbol(3);
	// sumListXX[k] = sum(x^k) for k = 0 .. 2 * nDim (the highest power that M
	// references); sumListXY[k] = sum(x^k * y) for k = 0 .. nDim.
	NumberList sumListXX(nDim * 2 + 1, 0), sumListXY(nCols, 0);
	bool flagX = false, flagY = false;
	for (;;) {
		Value valueX, valueY;
		flagX = pIterX->Next(sig, valueX);
		if (sig.IsSignalled()) return Value::Null;
		flagY = pIterY->Next(sig, valueY);
		if (sig.IsSignalled()) return Value::Null;
		if (!(flagX && flagY)) break;
		if (!valueX.IsNumber()) {
			sig.SetError(ERR_ValueError, "cannot calculate non-number value");
			return Value::Null;
		}
		if (!valueY.IsNumber()) {
			sig.SetError(ERR_ValueError, "cannot calculate non-number value");
			return Value::Null;
		}
		Number numX = valueX.GetNumber(), numY = valueY.GetNumber();
		Number productX = 1;	// x^k for the current k
		NumberList::iterator pSumXX = sumListXX.begin();
		NumberList::iterator pSumXY = sumListXY.begin();
		for ( ; pSumXY != sumListXY.end(); pSumXX++, pSumXY++) {
			*pSumXX += productX;
			*pSumXY += productX * numY;
			productX *= numX;
		}
		// Remaining higher powers are only needed for M.
		for ( ; pSumXX != sumListXX.end(); pSumXX++) {
			*pSumXX += productX;
			productX *= numX;
		}
	}
	// The loop stops when either iterator is exhausted; if the other one still
	// produced a value, the two sequences had different lengths.
	if (flagX || flagY) {
		sig.SetError(ERR_ValueError, "number of x and y must be the same");
		return Value::Null;
	}
	// Augmented matrix [M | I] stored row-major: nRows rows of rowStride
	// entries, with M[r][c] = sumListXX[r + c] and I the identity.
	const size_t rowStride = nCols * 2;
	NumberList mat;
	mat.reserve(rowStride * nRows);
	NumberList::iterator pSumXXBase = sumListXX.begin();
	for (size_t iRow = 0; iRow < nRows; iRow++, pSumXXBase++) {
		NumberList::iterator pSumXX = pSumXXBase;
		for (size_t iCol = 0; iCol < nCols; iCol++, pSumXX++) {
			mat.push_back(*pSumXX);
		}
		for (size_t iCol = 0; iCol < nCols; iCol++) {
			mat.push_back((iCol == iRow)? 1. : 0.);
		}
	}
	// The right half of each row is read back as inv(M) below.
	// NOTE(review): relies on InvertMatrix reducing [M | I] in place; confirm
	// against Algorithm.h.
	Number det;
	if (!AScript::InvertMatrix(mat, nCols, det)) {
		sig.SetError(ERR_ValueError, "failed to calculate inverse matrix");
		return Value::Null;
	}
	// a = inv(M) * b. Each row's inverse part is addressed from its own offset
	// so no iterator is ever advanced beyond mat.end() (stepping past the end
	// of a vector is undefined behaviour).
	NumberList alphaList;
	alphaList.reserve(nCols);
	for (size_t iRow = 0; iRow < nRows; iRow++) {
		NumberList::iterator pMat = mat.begin() + (iRow * rowStride + nCols);
		NumberList::iterator pSumXY = sumListXY.begin();
		Number alpha = 0;
		for (size_t iCol = 0; iCol < nCols; iCol++, pMat++, pSumXY++) {
			alpha += *pMat * *pSumXY;
		}
		alphaList.push_back(alpha);
	}
	// Build a0 + a1 * var + a2 * var ** 2 + ... + an * var ** n.
	// nDim >= 1 here, so alphaList holds at least a0 and a1.
	NumberList::iterator pAlpha = alphaList.begin();
	Expr *pExpr = new Expr_Value(*pAlpha);
	pAlpha++;
	Expr *pExprLeft = new Expr_BinaryOp(env.GetFunc_Multiply(),
		new Expr_Value(*pAlpha),
		new Expr_Symbol(pSymbolVar));
	pAlpha++;
	pExpr = new Expr_BinaryOp(env.GetFunc_Plus(), pExpr, pExprLeft);
	for ( ; pAlpha != alphaList.end(); pAlpha++) {
		size_t n = pAlpha - alphaList.begin();
		pExprLeft = new Expr_BinaryOp(env.GetFunc_Multiply(),
			new Expr_Value(*pAlpha),
			new Expr_BinaryOp(env.GetFunc_Power(),
				new Expr_Symbol(pSymbolVar),
				new Expr_Value(n)));
		pExpr = new Expr_BinaryOp(env.GetFunc_Plus(), pExpr, pExprLeft);
	}
	// Wrap the expression in an anonymous function taking `var` as its only
	// (number) argument.
	Function *pFunc = new FunctionCustom(env,
						AScript_Symbol(_anonymous_), pExpr, FUNCTYPE_Function);
	pFunc->SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	pFunc->DeclareArg(env, pSymbolVar, VTYPE_Number);
	Value result;
	result.InitAsFunction(env, pFunc);
	return result;
}

// fft(seq[])
// NOTE(review): placeholder, not implemented. One argument "seq" is declared
// (the trailing `true` passed to DeclareArg presumably marks it as a list
// argument, matching the `seq[]` notation above), but the implementation
// ignores its arguments and always returns Value::Null.
AScript_DeclareFunction(fft)
{
	SetMode(RSLTMODE_Normal, MAP_Off, FLAT_Off);
	DeclareArg(env, "seq", VTYPE_Any, OCCUR_Once, true);
}

AScript_ImplementFunction(fft)
{
	// TODO: not implemented; always returns Value::Null.
	return Value::Null;
}

// list = integral(func, sequence)
// NOTE(review): placeholder, not implemented. The signature documented above
// is not reflected in the declaration (no DeclareArg calls are made), and the
// implementation always returns Value::Null.
AScript_DeclareFunction(integral)
{
	SetMode(RSLTMODE_Normal, MAP_Off, FLAT_Off);
}

AScript_ImplementFunction(integral)
{
	// TODO: not implemented; always returns Value::Null.
	return Value::Null;
}

// Module entry
// Populates the module: binds the constants `e` and `pi`, then registers every
// function defined above. For the elementary functions (acos .. tanh), the
// Function pointer returned by AScript_AssignFunction is also cached in the
// matching Environment::Global::_pFunc_* field. IncRef() takes an extra
// reference for that cached pointer (NOTE(review): presumably released by
// Environment::Global; confirm). real, imag, arg, norm, conj, least_square,
// fft and integral are registered without being cached.
AScript_ModuleEntry()
{
	// value assignment
	AScript_AssignValue(AScript_Symbol(e),	Value(2.71828182845904523536));
	AScript_AssignValue(AScript_Symbol(pi),	Value(3.14159265358979323846));
	// function assignment
	Environment::Global *pGlobal = env.GetGlobal();
	AScript_AssignFunction(real);
	AScript_AssignFunction(imag);
	AScript_AssignFunction(arg);
	AScript_AssignFunction(norm);
	AScript_AssignFunction(conj);
	// functions whose pointers are also cached in the global environment
	pGlobal->_pFunc_acos = AScript_AssignFunction(acos)->IncRef();
	pGlobal->_pFunc_asin = AScript_AssignFunction(asin)->IncRef();
	pGlobal->_pFunc_atan = AScript_AssignFunction(atan)->IncRef();
	pGlobal->_pFunc_atan2 = AScript_AssignFunction(atan2)->IncRef();
	pGlobal->_pFunc_ceil = AScript_AssignFunction(ceil)->IncRef();
	pGlobal->_pFunc_cos = AScript_AssignFunction(cos)->IncRef();
	pGlobal->_pFunc_cosh = AScript_AssignFunction(cosh)->IncRef();
	pGlobal->_pFunc_exp = AScript_AssignFunction(exp)->IncRef();
	pGlobal->_pFunc_abs = AScript_AssignFunction(abs)->IncRef();
	pGlobal->_pFunc_floor = AScript_AssignFunction(floor)->IncRef();
	pGlobal->_pFunc_log = AScript_AssignFunction(log)->IncRef();
	pGlobal->_pFunc_log10 = AScript_AssignFunction(log10)->IncRef();
	pGlobal->_pFunc_sin = AScript_AssignFunction(sin)->IncRef();
	pGlobal->_pFunc_sinh = AScript_AssignFunction(sinh)->IncRef();
	pGlobal->_pFunc_sqrt = AScript_AssignFunction(sqrt)->IncRef();
	pGlobal->_pFunc_tan = AScript_AssignFunction(tan)->IncRef();
	pGlobal->_pFunc_tanh = AScript_AssignFunction(tanh)->IncRef();
	// higher-level functions (fft and integral are unimplemented stubs)
	AScript_AssignFunction(least_square);
	AScript_AssignFunction(fft);
	AScript_AssignFunction(integral);
}

// Module termination hook; currently empty. The Function references cached
// in Environment::Global by the module entry are not released here
// (NOTE(review): confirm they are released elsewhere).
AScript_ModuleTerminate()
{
}


AScript_EndModule(math)

AScript_RegisterModule(math)
