#include "state_machine.hpp"

// Signed (two's complement) overflow test for result = a + b (+ carry).
// Overflow happens exactly when both operands have the same sign and the
// result's sign differs from it.
// @param a      first addend
// @param b      second addend
// @param result truncated 8-bit sum
// @return true if the V flag should be set
bool StateMachine::isOverflowAddtion(Byte a, Byte b, Byte result)
{
	const bool sa = bit8::signBit(a);
	const bool sb = bit8::signBit(b);
	const bool sr = bit8::signBit(result);
	// Logical operators instead of bitwise '~' on bool: '~' promotes the bool
	// to int (~1 == -2), so the old form only worked because bit 0 held the answer.
	return (sa == sb) && (sr != sa);
}
// Signed (two's complement) overflow test for result = a - b (- borrow).
// Overflow happens exactly when the operands have different signs and the
// result's sign differs from the minuend's.
// @param a      minuend
// @param b      subtrahend
// @param result truncated 8-bit difference
// @return true if the V flag should be set
bool StateMachine::isOverflowSubtraction(Byte a, Byte b, Byte result)
{
	const bool sa = bit8::signBit(a);
	const bool sb = bit8::signBit(b);
	const bool sr = bit8::signBit(result);
	// Logical form of (!sa & sb & sr) | (sa & !sb & !sr), without relying on
	// bitwise '~' applied to bool.
	return (sa != sb) && (sr != sa);
}
// Push a 16-bit value onto the emulated stack as two bytes:
// the high byte goes first, then the low byte (pullWord() reverses this).
void StateMachine::pushWord(Word word)
{
	const Byte hi = static_cast<Byte>(word >> 8);
	const Byte lo = static_cast<Byte>(word & 0xFF);
	push(hi);
	push(lo);
}
// Pop a 16-bit value pushed by pushWord(): the low byte comes off first,
// then the high byte. The pulls are separate statements so their order is fixed.
Word StateMachine::pullWord()
{
	const Byte lo = pull();
	const Byte hi = pull();
	return bit8::lowHigh(lo, hi);
}

// Return from a subroutine/interrupt: drop the (destination, return) pair
// that call() appended to m_call_stack, then pull the return address from
// the emulated stack and hand it back to the caller.
// @return the 16-bit word pulled from the stack (callers apply any +1/-1 fixup)
Word StateMachine::rts()
{
	// call() appends exactly two entries, so a single outstanding call leaves
	// size() == 2. The old '> 2' test never removed the last pair.
	if (m_call_stack.size() >= 2) {
		m_call_stack.pop_back();
		m_call_stack.pop_back();
	}
	return pullWord();
}
// Enter a subroutine: record the (destination, return) pair on m_call_stack
// (rts() pops two entries), push the return address onto the emulated stack,
// remember it in m_return_addr, and jump to the destination.
void StateMachine::call(Word dest_addr, Word ret_addr)
{
	// Bookkeeping pair, destination first.
	m_call_stack.push_back(dest_addr);
	m_call_stack.push_back(ret_addr);

	// Emulated stack side of the call.
	pushWord(ret_addr);

	m_return_addr = ret_addr;
	m_pc = dest_addr;
}
// Flag update shared by CMP/CPX/CPY: behaves like an 8-bit subtraction
// left - right whose result is discarded.
//   N = bit 7 of the truncated difference
//   Z = operands equal
//   C = no borrow (left >= right, unsigned)
void StateMachine::compare(Byte left, Byte right)
{
	// N is the sign bit of the wrapped difference, not the unsigned ordering.
	// For example, 0x01 - 0xFF = 0x02 leaves N clear even though 0x01 < 0xFF.
	const Byte diff = static_cast<Byte>(left - right);
	setPS(PS_N, bit8::signBit(diff));
	setPS(PS_Z, left == right);
	setPS(PS_C, left >= right);
}

// Execute one decoded instruction.
// The PC is first advanced past the instruction, then the operand is resolved
// to a byte reference via target(), and finally the opcode's effect on
// registers, memory and status flags is applied.
// Conditional branches only record their condition in m_cond_flag; the PC
// change for a taken branch presumably happens in the caller — TODO confirm.
void StateMachine::execute(PInstruction instr)
{
	m_pc = instr->next();
	Byte& m = target(instr->mode(), instr->operand());

	Byte& a = m_reg[R_A];
	Byte& x = m_reg[R_X];
	Byte& y = m_reg[R_Y];

	bool temp_flag = false;
	Byte result;

	switch (instr->type()) {
	case opcode::LDA:
		a = m;
		cond(a);
		return;
	case opcode::LDX: x = m; cond(x); return;
	case opcode::LDY: y = m; cond(y); return;
	case opcode::ADC: {
		// Sum in a wider type so the carry out of bit 7 is visible. Comparing
		// the truncated byte with 'a' misses e.g. m == 0xFF with C set.
		const unsigned sum = static_cast<unsigned>(a) + static_cast<unsigned>(m)
		                   + (ps(PS_C) ? 1u : 0u);
		result = static_cast<Byte>(sum);
		setPS(PS_C, sum > 0xFF);
		setPS(PS_V, isOverflowAddtion(a, m, result));
		a = result;
		cond(a);
		return;
	}
	case opcode::SBC: {
		// A - M - (1 - C). On the 6502, C after SBC means "no borrow occurred".
		const int diff = static_cast<int>(a) - static_cast<int>(m)
		               - (ps(PS_C) ? 0 : 1);
		result = static_cast<Byte>(diff);
		setPS(PS_C, diff >= 0);
		setPS(PS_V, isOverflowSubtraction(a, m, result));
		a = result; // the operand must be subtracted only once
		cond(a);
		return;
	}
	case opcode::AND: a &= m; cond(a); return;
	case opcode::EOR: a ^= m; cond(a); return;
	case opcode::ORA: a |= m; cond(a); return;
	case opcode::BIT:
		setPS(PS_N, bit8::signBit(m));
		setPS(PS_Z, (a & m) == 0); // Z reflects A AND M, not equality
		setPS(PS_V, bit8::byteBit(m,6));
		return;
	case opcode::CMP: compare(a,m); return;
	case opcode::CPX: compare(x,m); return;
	case opcode::CPY: compare(y,m); return;

	case opcode::STA: m = a; return;
	case opcode::STX: m = x; return;
	case opcode::STY: m = y; return;
	case opcode::INC: m++; cond(m); return;
	case opcode::INX: x++; cond(x); return;
	case opcode::INY: y++; cond(y); return;
	case opcode::DEC: m--; cond(m); return;
	case opcode::DEX: x--; cond(x); return;
	case opcode::DEY: y--; cond(y); return;
	case opcode::ASL:
		setPS(PS_C, bit8::byteBit(m, HIGHEST_BIT)); // carry bit
		m <<= 1; cond(m);
		return;
	case opcode::LSR:
		setPS(PS_C,bit8::byteBit(m, 0));
		m >>= 1; cond(m);
		return;
	case opcode::ROL:
		temp_flag = ps(PS_C);
		setPS(PS_C,bit8::byteBit(m, HIGHEST_BIT)); // carry bit
		m = m << 1 | (temp_flag?1:0);
		cond(m);
		return;
	case opcode::ROR:
		temp_flag = ps(PS_C);
		setPS(PS_C,bit8::byteBit(m, 0));
		m = m >> 1 | (temp_flag?(1<<7):0);
		cond(m);
		return;

	// NOTE(review): real hardware pushes P with the B bit (and bit 5) set for
	// PHP/BRK. That is not modelled here; confirm whether it matters.
	case opcode::PHA: push(a); return;
	case opcode::PHP: push(m_reg[R_PS]); return;
	case opcode::PLA: a = pull(); cond(a); return;
	case opcode::PLP: m_reg[R_PS] = pull();return;

	case opcode::CLC: setPS(PS_C, false); return;
	case opcode::CLD: setPS(PS_D, false); return;
	case opcode::CLI: setPS(PS_I, false); return;
	case opcode::CLV: setPS(PS_V, false); return;
	case opcode::SEC: setPS(PS_C, true); return;
	case opcode::SED: setPS(PS_D, true); return;
	case opcode::SEI: setPS(PS_I, true); return;

	case opcode::TAX: x = a; cond(x); return;
	case opcode::TAY: y = a; cond(y); return;
	case opcode::TSX: x = m_reg[R_SP]; cond(x); return;
	case opcode::TXA: a = x; cond(a); return;
	case opcode::TXS: m_reg[R_SP] = x; return; // TXS leaves all flags unchanged
	case opcode::TYA: a = y; cond(a); return;

	case opcode::JMP:
		if (instr->mode()==opcode::IND) {
			MemoryID id = m_mapper.id(instr->operand());
			m_pc = m_mapper.wordData(id);
		} else {
			m_pc = instr->operand();
		}
		return;

	case opcode::JSR:
		// The pushed return address is next()-1; RTS adds the 1 back.
		if (instr->mode()==opcode::IND) {
			MemoryID id = m_mapper.id(instr->operand());
			call(m_mapper.wordData(id), instr->next()-1);
		} else {
			call(instr->operand(), instr->next()-1);
		}
		return;

	case opcode::RTS: m_pc = rts()+1; return;

	case opcode::BRK:
		call(m_mapper.irq(), instr->next()+1);//increment (not byte length)
		push(m_reg[R_PS]);
		setPS(PS_I, true); // BRK masks further IRQs after saving P
		return;
	case opcode::RTI:
		m_reg[R_PS] = pull();
		// NOTE(review): the -1 presumably pairs with BRK's next()+1 above.
		// Confirm against the hardware IRQ/NMI entry path, which is not visible here.
		m_pc = rts() - 1;
		return;

	case opcode::BEQ: m_cond_flag = ps(PS_Z); return;
	case opcode::BNE: m_cond_flag = !ps(PS_Z); return;
	case opcode::BCS: m_cond_flag = ps(PS_C); return;
	case opcode::BCC: m_cond_flag = !ps(PS_C); return;
	case opcode::BVC: m_cond_flag = !ps(PS_V); return; // branch if V clear
	case opcode::BVS: m_cond_flag = ps(PS_V); return;  // branch if V set
	case opcode::BMI: m_cond_flag = ps(PS_N); return;
	case opcode::BPL: m_cond_flag = !ps(PS_N); return;

	case opcode::NOP:
		return;//do nothing
	default:
		return;
	}
}

// Resolve an instruction operand to a reference to the byte it designates.
// The memory modes return a reference obtained from m_mapper.
// IMM stores the operand in the m_temp scratch byte and returns that.
// REG_A returns the accumulator register.
// IND (whose target is a word, not a byte) and unknown modes return m_temp
// untouched; execute() resolves JMP/JSR indirect targets on its own.
Byte& StateMachine::target(AddressingMode mode, Word operand)
{
	MemoryID base = m_mapper.id(operand);
	switch (mode) {
	case opcode::IMM:
		m_temp = operand;
		return m_temp;
	case opcode::REG_A:
		return m_reg[R_A];
	case opcode::ZERO:
	case opcode::ABS:
		return m_mapper.data(base);
	case opcode::ZERO_X:
	case opcode::ABS_X:
		return m_mapper.indexedData(base, m_reg[R_X]);
	case opcode::ZERO_Y:
	case opcode::ABS_Y:
		return m_mapper.indexedData(base, m_reg[R_Y]);
	case opcode::PRE_IND:  // e.g. lda ($50, x)
		return m_mapper.preindirectData(base, m_reg[R_X]);
	case opcode::POST_IND: // e.g. lda ($50), y
		return m_mapper.postindirectData(base, m_reg[R_Y]);
	case opcode::IND:
	default:
		return m_temp;
	}
}
// Resolve an instruction operand to the MemoryID it designates, mirroring
// target() but returning an ID instead of a byte reference.
// IND resolves through preindirectID with a zero index.
// IMM, REG_A and unknown modes return the raw ID of the operand.
MemoryID StateMachine::targetID(AddressingMode mode, Word operand)
{
	MemoryID base = m_mapper.id(operand);
	switch (mode) {
	case opcode::ZERO_X:
	case opcode::ABS_X:
		return m_mapper.index(base, m_reg[R_X]);
	case opcode::ZERO_Y:
	case opcode::ABS_Y:
		return m_mapper.index(base, m_reg[R_Y]);
	case opcode::PRE_IND:  // e.g. lda ($50, x)
		return m_mapper.preindirectID(base, m_reg[R_X]);
	case opcode::POST_IND: // e.g. lda ($50), y
		return m_mapper.postindirectID(base, m_reg[R_Y]);
	case opcode::IND:
		return m_mapper.preindirectID(base, 0);
	case opcode::ZERO:
	case opcode::ABS:
	case opcode::IMM:
	case opcode::REG_A:
	default:
		return base;
	}
}

