//
// mathematical functions
//

#include "Expr.h"
#include "Module_math.h"

AScript_BeginModule(math)

// num = real(num):map
// Real component of a number or complex value.
AScript_DeclareFunction(real)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> the number itself
//   complex        -> its real() component
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(real)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(input.GetNumber());
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetNumber(input.GetComplex().real());
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = imag(num):map
// Imaginary component of a number or complex value.
AScript_DeclareFunction(imag)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> 0 (a plain number has no imaginary part)
//   complex        -> its imag() component
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(imag)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(0.);
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetNumber(input.GetComplex().imag());
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = arg(num):map
// Phase angle (argument) of a number or complex value, in radians.
AScript_DeclareFunction(arg)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> phase of the point (x, 0): 0 for positive x, pi for negative x
//   complex        -> std::arg() of the value
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(arg)
{
	const Value &value = context.GetValue(0);
	Value result;
	if (value.IsNumber()) {
		// A real number sits on the real axis, so its phase is atan2(0, x).
		// Returning a constant 0 here would be wrong for negative numbers,
		// whose argument is pi.
		result.SetNumber(::atan2(0., value.GetNumber()));
	} else if (value.IsComplex()) {
		result.SetNumber(std::arg(value.GetComplex()));
	} else if (value.IsExprOrSymbol()) {
		result.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(value)));
	} else if (value.IsValid()) {
		SetError_InvalidValType(sig, value);
	}
	return result;
}

// num = norm(num):map
// Squared magnitude of a number or complex value.
AScript_DeclareFunction(norm)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> the number multiplied by itself
//   complex        -> std::norm() of the value
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(norm)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(input.GetNumber() * input.GetNumber());
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetNumber(std::norm(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = conj(num):map
// Complex conjugate of a number or complex value.
AScript_DeclareFunction(conj)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> the number itself (stored as a number)
//   complex        -> std::conj() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(conj)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(input.GetNumber());
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::conj(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = acos(num):map
// Arc cosine of a number.
AScript_DeclareFunctionBegin(acos)
AScript_DeclareFunctionEnd(acos)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument (no complex branch exists here):
//   number         -> ::acos() of the number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(acos)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::acos(input.GetNumber()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = asin(num):map
// Arc sine of a number.
AScript_DeclareFunctionBegin(asin)
AScript_DeclareFunctionEnd(asin)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument (no complex branch exists here):
//   number         -> ::asin() of the number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(asin)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::asin(input.GetNumber()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = atan(num):map
// Arc tangent of a number.
AScript_DeclareFunctionBegin(atan)
AScript_DeclareFunctionEnd(atan)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument (no complex branch exists here):
//   number         -> ::atan() of the number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(atan)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::atan(input.GetNumber()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = atan2(num, num):map
// Two-argument arc tangent; value1 is forwarded as the first argument of
// ::atan2() and value2 as the second.
AScript_DeclareFunction(atan2)
{
	// Two arguments of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value1", VTYPE_AnyType);
	DeclareArg(env, "value2", VTYPE_AnyType);
}

// Dispatch on the pair of arguments:
//   both numbers             -> ::atan2(value1, value2)
//   either is expr or symbol -> an Expr_Caller built from this function and both arguments
//   both valid otherwise     -> SetError_InvalidValType() is invoked with sig
//   any other combination    -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(atan2)
{
	const Value &first = context.GetValue(0);
	const Value &second = context.GetValue(1);
	Value ret;
	if (first.IsNumber() && second.IsNumber()) {
		ret.SetNumber(::atan2(first.GetNumber(), second.GetNumber()));
		return ret;
	}
	if (first.IsExprOrSymbol() || second.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this,
						new Expr_Value(first), new Expr_Value(second)));
		return ret;
	}
	// An invalid operand on either side yields the untouched default result
	// without reporting an error.
	if (!first.IsValid() || !second.IsValid()) {
		return ret;
	}
	SetError_InvalidValType(sig, first, second);
	return ret;
}

// num = ceil(num):map
// Rounds a number up via ::ceil().
AScript_DeclareFunction(ceil)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument (no complex branch exists here):
//   number         -> ::ceil() of the number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(ceil)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::ceil(input.GetNumber()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = cos(num):map
// Cosine of a number or complex value.
AScript_DeclareFunctionBegin(cos)
AScript_DeclareFunctionEnd(cos)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::cos() of the number
//   complex        -> std::cos() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(cos)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::cos(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::cos(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = cosh(num):map
// Hyperbolic cosine of a number or complex value.
AScript_DeclareFunction(cosh)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::cosh() of the number
//   complex        -> std::cosh() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(cosh)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::cosh(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::cosh(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = exp(num):map
// Exponential function of a number or complex value.
AScript_DeclareFunctionBegin(exp)
AScript_DeclareFunctionEnd(exp)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::exp() of the number
//   complex        -> std::exp() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(exp)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::exp(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::exp(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = abs(num):map
// Absolute value of a number, or magnitude of a complex value.
AScript_DeclareFunction(abs)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::fabs() of the number
//   complex        -> std::abs() of the value, stored as a number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(abs)
{
	const Value &value = context.GetValue(0);
	Value result;
	if (value.IsNumber()) {
		result.SetNumber(::fabs(value.GetNumber()));
	} else if (value.IsComplex()) {
		// std::abs() of a complex is a real magnitude, so the result is
		// stored with SetNumber(), the same way arg() and norm() store
		// their real-valued results.
		result.SetNumber(std::abs(value.GetComplex()));
	} else if (value.IsExprOrSymbol()) {
		result.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(value)));
	} else if (value.IsValid()) {
		SetError_InvalidValType(sig, value);
	}
	return result;
}

// num = floor(num):map
// Rounds a number down via ::floor().
AScript_DeclareFunction(floor)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument (no complex branch exists here):
//   number         -> ::floor() of the number
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(floor)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::floor(input.GetNumber()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = log(num):map
// Natural logarithm of a number or complex value.
AScript_DeclareFunctionBegin(log)
AScript_DeclareFunctionEnd(log)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number > 0     -> ::log() of the number
//   other number   -> SetError_InvalidValue() is invoked with sig
//   complex        -> std::log() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(log)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		// Written as a negated "> 0." test so that every number failing the
		// positive check (including NaN) takes the error path.
		if (!(input.GetNumber() > 0.)) {
			SetError_InvalidValue(sig, input);
			return ret;
		}
		ret.SetNumber(::log(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::log(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = log10(num):map
// Base-10 logarithm of a number or complex value.
AScript_DeclareFunctionBegin(log10)
AScript_DeclareFunctionEnd(log10)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number > 0     -> ::log10() of the number
//   other number   -> SetError_InvalidValue() is invoked with sig
//   complex        -> std::log10() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(log10)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		// Written as a negated "> 0." test so that every number failing the
		// positive check (including NaN) takes the error path.
		if (!(input.GetNumber() > 0.)) {
			SetError_InvalidValue(sig, input);
			return ret;
		}
		ret.SetNumber(::log10(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::log10(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = sin(num):map
// Sine of a number or complex value.
AScript_DeclareFunctionBegin(sin)
AScript_DeclareFunctionEnd(sin)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::sin() of the number
//   complex        -> std::sin() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(sin)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::sin(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::sin(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = sinh(num):map
// Hyperbolic sine of a number or complex value.
AScript_DeclareFunction(sinh)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::sinh() of the number
//   complex        -> std::sinh() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(sinh)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::sinh(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::sinh(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = sqrt(num):map
// Square root of a number or complex value; a number that fails the ">= 0"
// test produces a purely imaginary complex result instead of an error.
AScript_DeclareFunctionBegin(sqrt)
AScript_DeclareFunctionEnd(sqrt)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number >= 0    -> ::sqrt() of the number (stored as a number)
//   other number   -> Complex(0, ::sqrt(-number)) (stored as a complex)
//   complex        -> std::sqrt() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(sqrt)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		// Negated ">= 0" keeps the branch selection identical for every
		// input, including values that compare false both ways.
		if (!(input.GetNumber() >= 0)) {
			ret.SetComplex(Complex(0, ::sqrt(-input.GetNumber())));
		} else {
			ret.SetNumber(::sqrt(input.GetNumber()));
		}
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::sqrt(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = tan(num):map
// Tangent of a number or complex value.
AScript_DeclareFunctionBegin(tan)
AScript_DeclareFunctionEnd(tan)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::tan() of the number
//   complex        -> std::tan() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(tan)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::tan(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::tan(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// num = tanh(num):map
// Hyperbolic tangent of a number or complex value.
AScript_DeclareFunction(tanh)
{
	// Single argument of any type; MAP_On matches the ":map" marker above.
	SetMode(RSLTMODE_Normal, MAP_On, FLAT_Off);
	DeclareArg(env, "value", VTYPE_AnyType);
}

// Type dispatch on the argument:
//   number         -> ::tanh() of the number
//   complex        -> std::tanh() of the value (stored as a complex)
//   expr or symbol -> an Expr_Caller built from this function and the argument
//   other valid    -> SetError_InvalidValType() is invoked with sig
//   not valid      -> the default-constructed Value is returned unchanged
AScript_ImplementFunction(tanh)
{
	const Value &input = context.GetValue(0);
	Value ret;
	if (input.IsNumber()) {
		ret.SetNumber(::tanh(input.GetNumber()));
		return ret;
	}
	if (input.IsComplex()) {
		ret.SetComplex(std::tanh(input.GetComplex()));
		return ret;
	}
	if (input.IsExprOrSymbol()) {
		ret.InitAsExpr(env, new Expr_Caller(this, new Expr_Value(input)));
		return ret;
	}
	if (input.IsValid()) {
		SetError_InvalidValType(sig, input);
	}
	return ret;
}

// expr = least_square(symbol:symbol, x[]:number, y[]:number, n?:number)
// Intended to fit a straight line to the (x, y) samples and return it as an
// expression in `symbol`. The implementation is currently disabled (see the
// "#if 0" block below), so every call returns Value::Null.
AScript_DeclareFunction(least_square)
{
	// MAP_Off: unlike the scalar functions in this module, no ":map" marker
	// appears in the signature comment.
	SetMode(RSLTMODE_Normal, MAP_Off, FLAT_Off);
	DeclareArg(env, "symbol", VTYPE_Symbol);
	// The trailing `true` presumably marks a list parameter (matches "x[]" /
	// "y[]" in the signature comment) -- confirm against DeclareArg().
	DeclareArg(env, "x", VTYPE_Number, true);
	DeclareArg(env, "y", VTYPE_Number, true);
	// Optional argument; NOTE(review): "n" is never read, not even by the
	// disabled implementation below.
	DeclareArg(env, "n", VTYPE_Number, false, OCCUR_ZeroOrOnce);
}

AScript_ImplementFunction(least_square)
{
#if 0
	// Disabled implementation, kept for reference.
	const Symbol *pSymbol = context.GetSymbol(0);
	const ValueList &valList1 = context.GetList(1);
	const ValueList &valList2 = context.GetList(2);
	// Both sample lists must be non-empty and of equal length.
	if (valList1.size() != valList2.size()) {
		sig.SetError(ERR_ValueError, "lists have different length");
		return Value::Null;
	}
	if (valList1.empty()) {
		sig.SetError(ERR_ValueError, "lists are empty");
		return Value::Null;
	}
	Number n = static_cast<Number>(valList1.size());
	// Accumulate sum(x), sum(y), sum(x*x) and sum(x*y) in a single pass.
	Number xSum = 0., ySum = 0., xxSum = 0., xySum = 0.;
	ValueList::const_iterator pValue1 = valList1.begin();
	ValueList::const_iterator pValue2 = valList2.begin();
	for ( ; pValue1 != valList1.end(); pValue1++, pValue2++) {
		if (!(pValue1->IsNumber() && pValue2->IsNumber())) {
			sig.SetError(ERR_ValueError, "non-number value is contained in the lists");
			return Value::Null;
		}
		xSum += pValue1->GetNumber();
		ySum += pValue2->GetNumber();
		xxSum += pValue1->GetNumber() * pValue1->GetNumber();
		xySum += pValue1->GetNumber() * pValue2->GetNumber();
	}
	// Slope (a) and intercept (b) of the fitted line y = a * x + b.
	// NOTE(review): the shared denominator n * xxSum - xSum * xSum is zero
	// when all x values are equal; no guard exists for that case.
	Number a = (n * xySum - xSum * ySum) / (n * xxSum - xSum * xSum);
	Number b = (xxSum * ySum - xySum * xSum) / (n * xxSum - xSum * xSum);
	// Build the expression a * symbol + b through the Plus/Multiply functions.
	Expr *pExpr = env.GetFunc_Plus().Optimize(env, sig, ExprList(
		env.GetFunc_Multiply().Optimize(env, sig, ExprList(
			new Expr_Value(a), new Expr_Symbol(pSymbol))),
		new Expr_Value(b)));
	Value value;
	value.InitAsExpr(env, pExpr);
	return value;
#endif
	// Stub result while the implementation above is compiled out.
	return Value::Null;
}

// list = integral(func, sequence)
// Placeholder: the signature comment above names two parameters, but no
// arguments are declared and the implementation always returns Value::Null.
AScript_DeclareFunction(integral)
{
	// NOTE(review): no DeclareArg() calls yet -- the "func" and "sequence"
	// parameters from the signature comment are not registered.
	SetMode(RSLTMODE_Normal, MAP_Off, FLAT_Off);
}

AScript_ImplementFunction(integral)
{
	// Not implemented; always yields Value::Null.
	return Value::Null;
}

// Module entry
// Registers the constants `e` and `pi` and every function declared above
// with the math module.
AScript_ModuleEntry()
{
	// Mathematical constants exposed as module values.
	AScript_AssignValue(AScript_Symbol(e),	Value(2.71828182845904523536));
	AScript_AssignValue(AScript_Symbol(pi),	Value(3.14159265358979323846));
	// Complex-number helpers.
	AScript_AssignFunction(real);
	AScript_AssignFunction(imag);
	AScript_AssignFunction(arg);
	AScript_AssignFunction(norm);
	AScript_AssignFunction(conj);
	// Elementary functions.
	AScript_AssignFunction(acos);
	AScript_AssignFunction(asin);
	AScript_AssignFunction(atan);
	AScript_AssignFunction(atan2);
	AScript_AssignFunction(ceil);
	AScript_AssignFunction(cos);
	AScript_AssignFunction(cosh);
	AScript_AssignFunction(exp);
	AScript_AssignFunction(abs);
	AScript_AssignFunction(floor);
	AScript_AssignFunction(log);
	AScript_AssignFunction(log10);
	AScript_AssignFunction(sin);
	AScript_AssignFunction(sinh);
	AScript_AssignFunction(sqrt);
	AScript_AssignFunction(tan);
	AScript_AssignFunction(tanh);
	// Stubs: both currently return Value::Null unconditionally.
	AScript_AssignFunction(integral);
	AScript_AssignFunction(least_square);
}

// Module termination hook; the body is empty, so nothing is done here.
AScript_ModuleTerminate()
{
}

AScript_EndModule(math)
