﻿/**
 *	Process, thread and memory-region management for the i386 kernel.
 *
 *	Version:
 *		$Revision$
 *	Date:
 *		$Date$
 *	License:
 *		MIT/X Consortium License
 *	History:
 *		$Log$
 */

module os.i386.process;

import std.stdint;
import std.c.string : memcpy;

import drt.list;

import os.i386.cache;
import os.i386.errorcode;
import os.i386.page;
import os.i386.segment;
import os.i386.sync;
import os.i386.systemcall;
import os.i386.tss;

private {

	/// Saved switch state of a suspended thread.
	/// initializeContext fills in esp (top of the prepared kernel stack) and
	/// eip (threadEntryPoint); fs/gs are not written anywhere in this module.
	/// NOTE(review): field order/sizes are presumably relied on by the external
	/// switchToNextThread routine — confirm before changing the layout.
	struct Context {
		uint32_t esp;
		uint32_t eip;
		uint16_t fs;
		uint16_t gs;
	}

	/// Saves the running context into `prev` and resumes `next`.
	/// Implemented outside this module (presumably in assembly — confirm).
	extern(C) void switchToNextThread(Context* prev, Context* next);
	/// First code a new thread runs; its address is stored in Context.eip by
	/// initializeContext. Implemented outside this module.
	extern(C) void threadEntryPoint();
	
	/// Head of the global, null-terminated process list.
	Process processes_;
	/// Head of the global, null-terminated thread list (walked by schedule()).
	Thread threads_;
}

/// Signature of a thread body: called with the argument pointer given when
/// the thread was created (initializeContext places it on the new stack).
alias void function(void*) ThreadEntryPoint;

/// A memory region (code, data, stack, ...) that can be mapped into one or
/// more processes, each at its own offset.
class MemoryRegion {
	
	enum Type {
		CODE,
		RO_DATA,
		DATA,
		STACK,
	}
	
	/// One process this region is mapped into, and the offset it is mapped at.
	struct MappedProcess {
		Process process;
		size_t offset;
		
		mixin ListMixture;
	}
	
	this(Type type, size_t len) {
		type_ = type;
		length_ = len;
	}
	
	/// Frees the mapping records (the processes themselves are untouched).
	~this() {
		for(auto p = mappedProcesses_; p !is null;) {
			auto tmp = p;
			p = p.next;
			deallocateStructFromCache(tmp);
		}
		mappedProcesses_ = null;
	}
	
	/// Current length of the region in bytes.
	size_t length() {return length_;}
	
	/**
	 * Grows the region by len bytes at its end.
	 *
	 * Returns false (and leaves the region unchanged) if, in any process the
	 * region is mapped into, the new bytes would wrap the address space or
	 * overlap another mapped region. Extending by 0 always succeeds.
	 */
	bool extend(size_t len) {
		if(len == 0) {
			return true;
		}
		
		foreach(mp; this) {
			auto begin = mp.offset + length_;
			// findRegion takes an inclusive upper bound, so pass the last new byte.
			auto last = begin + len - 1;
			if(begin < mp.offset || last < begin) {
				// the extension would wrap around the address space
				return false;
			}
			if(mp.process.findRegion(begin, last) !is null) {
				// already mapped
				return false;
			}
		}
		
		length_ += len;
		return true;
	}
	
	/**
	 * Records that this region is mapped into process p at offset off.
	 * Returns false if p is null or the bookkeeping record cannot be allocated.
	 */
	bool map(Process p, size_t off) {
		if(p is null) {
			return false;
		}
		
		auto mp = allocateStructFromCache!(MappedProcess)();
		if(mp is null) {
			return false;
		}
		
		mp.process = p;
		mp.offset = off;
		if(mappedProcesses_ is null) {
			mappedProcesses_ = mp;
		} else {
			mappedProcesses_.back.insertNext(mp);
		}
		
		return true;
	}
	
	/// Iterates over every process this region is mapped into.
	int opApply(int delegate(inout MappedProcess) dg) {
		for(auto mp = mappedProcesses_; mp !is null; mp = mp.next) {
			if(auto result = dg(*mp)) {
				return result;
			}
		}
		return 0;
	}
	
private:
	
	MappedProcess* mappedProcesses_;
	// NOTE(review): not referenced anywhere in this file — confirm it is needed.
	Thread thread_;
	size_t length_;
	Type type_;
}

/// An address space: its page directory, the memory regions mapped into it
/// (code, static data, heap, ...) and a count of the threads running in it.
class Process : KernelObject {
	
	enum State {
		CREATED,
		RUNNING,
		TERMINATED,
	}
	
	/// One memory region mapped into this process at a given offset.
	struct MappedRegion {
		MemoryRegion region;
		size_t offset;
		
		mixin ListMixture;
	}
	
	/// Creates a process and appends it to the global process list.
	this() {
		if(processes_ is null) {
			processes_ = this;
		} else {
			// Append at the tail (as Thread does). Inserting before the head of
			// this null-terminated list would make the new process unreachable
			// when walking forward from processes_.
			processes_.back.insertNext(this);
		}
	}
	
	/// Terminates the process if still alive and unlinks it from the global list.
	~this() {
		if(state_ != State.TERMINATED) {
			terminate();
		}
		// erase() only unlinks this node; move the global head first so it
		// never refers to a destroyed process.
		if(processes_ is this) {
			processes_ = next;
		}
		erase();
	}
	
	/// Returns the page directory, creating it on first use (null if that fails).
	PageDirectory* pageDirectory() {
		initializePageDirectory();
		return dir_;
	}
	
	State state() {return state_;}
	void state(State s) {state_ = s;}
	
	/**
	 * Releases the region bookkeeping, this process's own page tables and its
	 * page directory, then marks it TERMINATED. Safe to call more than once.
	 *
	 * NOTE(review): thread stack page tables installed in this directory by
	 * Thread.extendStack/switchPageTable are also freed by Thread.terminate —
	 * confirm they cannot be freed twice.
	 */
	void terminate() {
		// free the mapping records (the MemoryRegions themselves are untouched).
		for(auto p = mappedRegions_; p !is null;) {
			auto tmp = p;
			p = p.next;
			deallocateStructFromCache(tmp);
		}
		mappedRegions_ = null;
		
		if(dir_ !is null) {
			// The directory started as a copy of the kernel directory, so any
			// entry still pointing at the same table as the kernel's entry is
			// a kernel page table and must not be freed here.
			auto kernelDir = cast(PageDirectory*) getKernelPageDirectory();
			size_t i = 0;
			foreach(e; *dir_) {
				auto k = (*kernelDir)[i];
				auto shared = k.present && k.baseAddress == e.baseAddress;
				if(e.present && !shared) {
					deallocatePages(e.baseAddress, 1);
				}
				++i;
			}
			
			// free page directory
			deallocatePages(dir_, 1);
			dir_ = null;
		}
		
		state_ = State.TERMINATED;
	}
	
	/**
	 * Maps region at offset off. Returns false if region is null, the range
	 * wraps the address space, it overlaps an already mapped region, or the
	 * bookkeeping record cannot be allocated.
	 */
	bool addRegion(MemoryRegion region, size_t off) {
		if(region is null) {
			return false;
		}
		
		// findRegion takes an inclusive upper bound: use the region's last byte.
		auto len = region.length();
		auto last = (len > 0) ? off + len - 1 : off;
		if(last < off) {
			// range wraps around the address space
			return false;
		}
		if(findRegion(off, last) !is null) {
			// already mapped region
			return false;
		}
		
		auto mr = allocateStructFromCache!(MappedRegion)();
		if(mr is null) {
			// memory not enough
			return false;
		}
		
		mr.region = region;
		mr.offset = off;
		if(mappedRegions_ is null) {
			mappedRegions_ = mr;
		} else {
			mappedRegions_.back.insertNext(mr);
		}
		
		return true;
	}
	
	/// Unmaps the first mapping of region (no-op if it is not mapped).
	void removeRegion(MemoryRegion region) {
		for(auto mr = mappedRegions_; mr !is null; mr = mr.next) {
			if(mr.region is region) {
				// keep the list head valid when removing the first mapping
				if(mappedRegions_ is mr) {
					mappedRegions_ = mr.next;
				}
				mr.erase();
				deallocateStructFromCache(mr);
				return;
			}
		}
	}
	
	/// find region from an address
	MemoryRegion findRegion(uintptr_t p) {return findRegion(p, p);}
	
	/// Returns the first mapped region overlapping the inclusive address
	/// range [low, high], or null if there is none.
	MemoryRegion findRegion(uintptr_t low, uintptr_t high) {
		for(auto mr = mappedRegions_; mr !is null; mr = mr.next) {
			auto begin = mr.offset;
			auto end = mr.offset + mr.region.length();
			// [low, high] overlaps [begin, end) unless it lies wholly before or
			// after it; this also catches a region lying entirely inside the range.
			if(begin < end && begin <= high && low < end) {
				return mr.region;
			}
		}
		
		// not found
		return null;
	}
	
	bool hasThread() {return threads_ > 0;}
	
	mixin ListMixture;
	
private:
	
	void addThread(Thread t) {++threads_;}
	void removeThread(Thread t) {--threads_;}
	
	/// Lazily creates the page directory as a copy of the kernel directory.
	bool initializePageDirectory() {
		if(dir_ is null) {
			dir_ = cast(PageDirectory*) allocatePages(1, Page.Purpose.KERNEL);
			if(dir_ is null) {
				return false;
			}
			memcpy(dir_, getKernelPageDirectory(), PageDirectory.sizeof);
		}
		return true;
	}
	
	/// For page faults: ensures directory entry i has a page table.
	bool enableDirectoryEntry(size_t i) {
		if(!initializePageDirectory()) {
			return false;
		}
		
		if(!(*dir_)[i].present) {
			// allocate a new page table. table is empty.
			auto table = cast(PageTable*) allocatePages(1, Page.Purpose.KERNEL);
			if(table is null) {
				return false;
			}
			
			// update directory entry
			(*dir_)[i] = PageEntry(table, true, true, true);
		}
		
		return true;
	}
	
	PageDirectory* dir_;
	MappedRegion* mappedRegions_;
	State state_ = State.CREATED;
	size_t threads_;
}

/// A thread of execution in a Process: a one-page kernel stack, a saved
/// switch context, and a user stack that grows downward on demand.
class Thread : KernelObject {
	
	enum State {
		CREATED,
		RUNNING,
		BLOCKED,
		TERMINATED,
	}
	
	/// Highest user-stack address; the user stack grows down from here.
	const void* USER_STACK_BOTTOM = cast(void*) 0xFFFF_FFFF;
	const size_t DEFAULT_STACK_SIZE = 2 * 1024 * 1024;
	
	/// Creates a thread in process and appends it to the global thread list.
	this(Process process, size_t stackSize = DEFAULT_STACK_SIZE) {
		process_ = process;
		stackSize_ = stackSize;
		process.addThread(this);
		
		if(threads_ is null) {
			threads_ = this;
		} else {
			threads_.back.insertNext(this);
		}
	}
	
	/// Terminates the thread if needed and unlinks it from the global list.
	~this() {
		if(state_ != State.TERMINATED) {
			terminate();
		}
		// erase() only unlinks this node; advance the global head first so
		// threads_ never refers to a destroyed thread.
		if(threads_ is this) {
			threads_ = next;
		}
		erase();
		process_.removeThread(this);
	}
	
	/**
	 * FORK_THREAD system call handler: a1 is the entry point, a2 its argument.
	 * Returns the new RUNNING thread cast to SysRes, or an Errno on failure.
	 */
	static SysRes onForkThread(SysArg n, SysArg a1, SysArg a2, SysArg a3, SysArg a4, SysArg a5) {
		auto cur = getCurrentThread();
		if(cur is null || cur.process is null) {
			return Errno.NOT_INITIALIZED_SYSTEM;
		}
		
		auto thread = new Thread(cur.process);
		if(!thread.initializeContext(cast(ThreadEntryPoint) a1, cast(void*) a2)) {
			delete thread;
			return Errno.NOT_ENOUGH_MEMORY;
		}
		
		thread.state_ = State.RUNNING;
		
		return cast(SysRes) thread;
	}
	
	/**
	 * Starts this thread directly (no previous thread to save) by jumping to fn
	 * on its freshly prepared kernel stack. Returns false only if the context
	 * cannot be initialized; on success the jump does not come back here.
	 */
	bool firstBegin(ThreadEntryPoint fn) {
		if(!initializeContext(fn, null)) {
			return false;
		}
		
		setupNextThreadState(this);
		
		// Skip the eip, cs and eflags slots written by initializeContext so ESP
		// points at the (zero) return address with fn's argument above it.
		auto stack = kernelStackPointer() + uint32_t.sizeof * 3;
		asm {
			mov EAX, fn;
			mov ESP, stack;
			jmp EAX;
		}
		
		return true;	// not reached
	}
	
	Process process() {return process_;}
	
	State state() {return state_;}
	
	/// Switches from prev to next; returns false if next is not RUNNING.
	static bool switchTo(Thread prev, Thread next) {
		if(next.state != State.RUNNING) {
			return false;
		}
		
		switchToNextThread(&prev.context_, &next.context_);
		
		return true;
	}
	
	/**
	 * Frees the user stack pages, their page tables and the kernel stack, then
	 * marks the thread TERMINATED.
	 *
	 * NOTE(review): the process page directory may still reference the freed
	 * stack page tables — confirm they are cleared before the directory is reused.
	 */
	void terminate() {
		for(auto te = stackTables_; te !is null;) {
			auto tmp = te;
			te = te.next;
			
			// free stack pages
			foreach(e; *tmp.table) {
				if(e.present) {
					deallocatePages(e.baseAddress, 1);
				}
			}
			
			// free a page table
			deallocatePages(tmp.table, 1);
			
			// free a stack page table entry
			deallocateStructFromCache(tmp);
		}
		
		stackTables_ = null;
		
		if(kernelStack_ !is null) {
			deallocatePages(kernelStack_, 1);
			kernelStack_ = null;
		}
		
		state_ = State.TERMINATED;
	}
	
	size_t userStackSize() {return stackSize_;}
	
	void* kernelStack() {return kernelStack_;}
	void* kernelStackBottom() {return kernelStack_ + PAGE_LENGTH;}
	void* kernelStackPointer() {return cast(void*) context_.esp;}
	
	/**
	 * Maps one more user stack page below the current ones.
	 * Returns false if the stack already reached stackSize_ or memory runs out.
	 */
	bool extendStack() {
		if(process_ is null) {
			return false;
		}
		
		auto p = cast(void*)(USER_STACK_BOTTOM - stackPages_ * PAGE_LENGTH);
		
		// check stack size
		if(p < (USER_STACK_BOTTOM - stackSize_ + 1)) {
			return false;
		}
		
		auto dir = process_.pageDirectory;
		if(dir is null) {
			return false;
		}
		
		// Use this thread's own stack table for this directory slot. Threads of
		// one process share the directory, so a present entry may belong to
		// another thread and must not receive our stack pages.
		auto i = getPageDirectoryIndex(p);
		auto te = stackPageTable(getPageDirectoryIndex(USER_STACK_BOTTOM) - i);
		if(te is null) {
			// allocate a page table
			te = extendStackPageTable();
			if(te is null) {
				return false;
			}
			(*dir)[i] = PageEntry(te.table);
		} else if(!(*dir)[i].present || cast(PageTable*) (*dir)[i].baseAddress !is te.table) {
			// slot currently holds another table: install ours
			(*dir)[i] = PageEntry(te.table);
		}
		auto table = te.table;
		
		// Is a stack page present?
		auto j = getPageTableIndex(p);
		if(!(*table)[j].present) {
			// allocate a new stack page
			auto page = allocatePages(1, Page.Purpose.USER);
			if(page is null) {
				return false;
			}
			
			++stackPages_;
			
			// user mode r/w page
			(*table)[j] = PageEntry(page, true, true, true);
		}
		
		return true;
	}
	
	mixin ListMixture;
	
private:
	
	/// Singly linked record of a page table that backs part of the user stack.
	struct StackPageTableEntry {
		StackPageTableEntry* next;
		PageTable* table;
	}
	
	/**
	 * Installs this thread's stack page tables into its process's directory.
	 * The n-th table (0 = first allocated) backs the directory slot n below the
	 * one containing USER_STACK_BOTTOM, because the stack grows downward.
	 */
	bool switchPageTable() {
		// get page directory
		auto dir = process_.pageDirectory();
		if(dir is null) {
			return false;
		}
		
		// every table, including the last one, gets installed
		auto top = getPageDirectoryIndex(USER_STACK_BOTTOM);
		size_t n = 0;
		for(auto p = stackTables_; p !is null; p = p.next, ++n) {
			(*dir)[top - n] = PageEntry(p.table, true, true);
		}
		
		return true;
	}
	
	/// Returns this thread's n-th stack page table record, or null if it has
	/// n or fewer tables.
	StackPageTableEntry* stackPageTable(size_t n) {
		auto p = stackTables_;
		for(; p !is null && n > 0; p = p.next, --n) {}
		return p;
	}
	
	/// Allocates a new empty stack page table and appends it to stackTables_.
	/// NOTE(review): appending assumes the stack grows one directory slot at a
	/// time, so the new table is always the next slot down.
	StackPageTableEntry* extendStackPageTable() {
		// allocate entry
		auto entry = allocateStructFromCache!(StackPageTableEntry)();
		if(entry is null) {
			return null;
		}
		
		// allocate page table
		entry.table = cast(PageTable*) allocatePages(1, Page.Purpose.KERNEL);
		if(entry.table is null) {
			deallocateStructFromCache(entry);
			return null;
		}
		
		// insert to last
		if(auto p = stackTables_) {
			for(; p.next !is null; p = p.next) {}
			p.next = entry;
		} else {
			stackTables_ = entry;
		}
		
		return entry;
	}
	
	/**
	 * On first call, allocates the one-page kernel stack, stores this Thread at
	 * its base (read back by getCurrentThread), and builds an initial frame
	 * (top to bottom: arg, return address 0, eflags with IF set, kernel CS,
	 * eip = fn). context_ then resumes at threadEntryPoint with that frame.
	 */
	bool initializeContext(ThreadEntryPoint fn, void* arg) {
		if(kernelStack_ is null) {
			// allocate stack page
			kernelStack_ = allocatePages(1, Page.Purpose.KERNEL);
			if(kernelStack_ is null) {
				return false;
			}
			
			// save thread information
			*(cast(Thread*)kernelStack_) = this;
			
			// setup stack
			
			// fn argument;
			auto fnArgument = cast(uint32_t*)(kernelStackBottom() - uint32_t.sizeof);
			*fnArgument = cast(uint32_t) arg;
			
			// fn return address
			auto fnReturn = fnArgument - 1;
			*fnReturn = 0;
			
			// eflags (bit 9 = IF, interrupts enabled)
			auto eflags = fnReturn - 1;
			*eflags = getTaskStateSegment().eflags | (1U<<9);
			
			// code segment
			auto cs = eflags - 1;
			*cs = GdtSelector.KERNEL_CODE;
			
			// return address
			auto eip = cs - 1;
			*eip = cast(uint32_t) fn;
			
			// stack top
			context_.esp = cast(uint32_t) eip;
			
			// entry point
			context_.eip = cast(uint32_t) &threadEntryPoint;
		}
		return true;
	}
	
	Context context_;
	State state_ = State.CREATED;
	Process process_;
	StackPageTableEntry* stackTables_;
	void* kernelStack_;
	size_t stackSize_;
	size_t stackPages_;
}

/// Returns the thread that owns the stack currently in use.
///
/// Rounds ESP down to a 4 KiB boundary (mask 0xffff_f000) and reads the
/// Thread reference stored at that address; Thread.initializeContext writes
/// the owning Thread at the base of its one-page kernel stack.
/// NOTE(review): only meaningful while running on such a kernel stack, and
/// assumes PAGE_LENGTH == 0x1000 with page-aligned stacks — confirm.
Thread getCurrentThread() {
	Thread result;
	asm {
		mov ECX, ESP;
		and ECX, 0xffff_f000;	// base of the current stack page
		mov EAX, [ECX];	// Thread stored there by initializeContext
		mov result, EAX;
	}
	return result;
}

/**
 * Round-robin scheduler: switches from the current thread to the next RUNNING
 * thread in the global list, wrapping around to the head.
 *
 * Non-RUNNING threads are skipped, so one blocked successor no longer stops
 * scheduling. If no other thread can run, it falls back to the current thread
 * (as the original single-thread case did). Returns false if there is no
 * current thread or the chosen thread cannot be switched to.
 */
bool schedule() {
	auto cur = getCurrentThread();
	if(cur is null) {
		return false;
	}
	
	// threads after the current one
	for(Thread t = cur.next; t !is null; t = t.next) {
		if(t.state == Thread.State.RUNNING) {
			return Thread.switchTo(cur, t);
		}
	}
	
	// wrap around: from the head up to (and including) the current thread
	for(Thread t = threads_; t !is null; t = t.next) {
		if(t is cur || t.state == Thread.State.RUNNING) {
			return Thread.switchTo(cur, t);
		}
	}
	
	return false;
}

/// Registers Thread.onForkThread as the kernel handler for FORK_THREAD.
void initializeProcessSubsystem() {
	auto handler = &Thread.onForkThread;
	setSystemCall(SystemCallNumber.FORK_THREAD, handler);
}

/**
 * Requests a new thread in the current process that runs fn(arg), via the
 * FORK_THREAD system call, and returns the result as a Thread.
 *
 * NOTE(review): the kernel handler (Thread.onForkThread) returns an Errno on
 * failure; presumably that value comes back here and is cast to an invalid
 * Thread reference, which callers cannot distinguish — confirm.
 */
Thread forkThread(void function(void*) fn, void* arg = null) {
	auto entry = cast(SysArg) fn;
	auto param = cast(SysArg) arg;
	auto res = callSystemCall(SystemCallNumber.FORK_THREAD, entry, param);
	return cast(Thread) res;
}

/// Prepares the CPU for running `next`. It:
///  - installs next's stack page tables in its process's page directory;
///  - loads that directory;
///  - points the TSS esp0 at the top of next's kernel stack;
///  - stores the directory address in the TSS cr3 field;
///  - marks the thread RUNNING.
/// Called by Thread.firstBegin. NOTE(review): extern(C) presumably so the
/// assembly context-switch code can call it too — confirm.
/// NOTE(review): switchPageTable()'s result is ignored; if the page directory
/// cannot be allocated, pageDirectory() returns null and that null is loaded.
extern(C) void setupNextThreadState(Thread next) {
	next.switchPageTable();
	auto nextDir = next.process_.pageDirectory();
	loadPageDirectory(nextDir);
	// kernel stack used on the next privilege-level switch into the kernel
	getTaskStateSegment().esp0 = cast(uint32_t) next.kernelStackBottom();
	getTaskStateSegment().cr3 = cast(uint32_t) nextDir;
	next.state_ = Thread.State.RUNNING;
}
