package traffic;

import java.io.*;
import java.net.*;
import java.util.*;
import traffic.object.*;

public class IO implements Constants {
  /** Size in bytes of the LongUDP header that precedes every UDP fragment body. */
  private static final int LUDP_HEADER_SIZE = 8;
  /** Value written into bytes [0..1] of every outgoing fragment (called the "magic number" by this protocol). */
  private static final short LUDP_MAGIC = 0x0008;

  private InetAddress m_address;
  private int m_port;
  private short m_ludpId = 0;
  private DatagramSocket m_socket;
  private final byte[] m_buf = new byte[PACKET_SIZE];

  /**
   * Opens an unbound-port UDP socket that talks to the kernel.
   * Any failure while creating the socket terminates the process.
   *
   * @param kernelAddress initial destination address for outgoing packets
   * @param kernelPort    initial destination port for outgoing packets
   */
  public IO(InetAddress kernelAddress, int kernelPort) {
    try{
      m_address = kernelAddress;
      m_port    = kernelPort;
      m_socket  = new DatagramSocket();
    } catch (Exception e) { e.printStackTrace();  System.exit(1); }
  }

  /**
   * Blocks until one complete LongUDP message has been reassembled from
   * incoming UDP fragments, and returns its body decoded as big-endian ints.
   *
   * CAUTION: this function assumes that every UDP packet sent by the kernel is
   * surely received. A message with a lost fragment is never completed, and
   * partial messages are discarded when this call returns.
   *
   * Datagrams that cannot belong to a well-formed LongUDP message (shorter than
   * the header, fragment index out of range, or a fragment count that disagrees
   * with earlier fragments of the same ID) are reported on stderr and ignored.
   *
   * When the reassembled message starts with KS_CONNECT_OK, the sender of the
   * last fragment becomes the destination of all subsequent sends.
   *
   * @return the reassembled message; element 0 is the message header
   */
  public int[] receive() {
    // Partially reassembled messages keyed by LongUDP ID. Each value holds one
    // slot per UDP fragment, null until that fragment has arrived.
    HashMap idLudpMap = new HashMap();
    while (true) {
      try {
        DatagramPacket pkt = new DatagramPacket(m_buf, m_buf.length);
        m_socket.receive(pkt);
        int len = pkt.getLength();
        if (len < LUDP_HEADER_SIZE) {
          System.err.println("IO.receive: ignoring " + len
                             + "-byte UDP packet shorter than the LongUDP header");
          continue;
        }

        // [0..1]: magic number (not checked)
        // [2..3]: LongUDP packet ID
        // [4..5]: order of this UDP packet in the LongUDP packet
        // [6..7]: total UDP packet number
        // [8.. ]: body of LongUDP packet data
        // Decoded as unsigned so large indices/counts cannot turn negative.
        int ludpID = readUnsignedShort(m_buf, 2);
        int nth    = readUnsignedShort(m_buf, 4);
        int total  = readUnsignedShort(m_buf, 6);
        if (nth >= total) {
          System.err.println("IO.receive: ignoring UDP packet " + nth + " of LongUDP packet "
                             + ludpID + " that declares only " + total + " packets");
          continue;
        }

        Integer key = new Integer(ludpID);
        int[][] data = (int[][]) idLudpMap.get(key);
        if (data == null) {
          data = new int[total][];
          idLudpMap.put(key, data);
        } else if (data.length != total) {
          System.err.println("IO.receive: ignoring UDP packet of LongUDP packet " + ludpID
                             + " declaring " + total + " packets; earlier ones declared " + data.length);
          continue;
        }
        data[nth] = getIntArray(m_buf, LUDP_HEADER_SIZE, len);

        // Only the entry just updated can have become complete: any other
        // complete entry would already have been returned by an earlier iteration.
        if (!isComplete(data)) {
          continue;
        }
        idLudpMap.remove(key);
        int[] result = udpsToLudp(data);
        if (result.length > 0 && result[0] == KS_CONNECT_OK) {
          // Reply to wherever the connect acknowledgement came from.
          m_address = pkt.getAddress();
          m_port    = pkt.getPort();
        }
        return result;
      } catch (Exception e) { e.printStackTrace();  System.exit(1); }
    }
  }

  /** @return the big-endian unsigned 16-bit value stored in src[pos..pos+1] */
  private static int readUnsignedShort(byte[] src, int pos) {
    return ((src[pos] & 0xff) << 8) | (src[pos + 1] & 0xff);
  }

  /** @return true when every fragment slot of a pending LongUDP message is filled */
  private static boolean isComplete(int[][] udps) {
    for (int i = 0;  i < udps.length;  i ++) {
      if (udps[i] == null) {
        return false;
      }
    }
    return true;
  }

  /**
   * Decodes src[from..to) (end exclusive) as consecutive big-endian 32-bit ints.
   * Trailing bytes that do not fill a whole int are dropped; an empty or
   * inverted range yields an empty array.
   *
   * @return int[] casted from src[from..to)
   */
  private int[] getIntArray(byte[] src, int from, int to) {
    int[] result = new int[Math.max(0, (to - from) / 4)];
    for (int i = 0, b = from;  i < result.length;  i ++, b += 4) {
      result[i] = ((src[b  ] & 0xff) << 24)
                | ((src[b+1] & 0xff) << 16)
                | ((src[b+2] & 0xff) <<  8)
                |  (src[b+3] & 0xff);
    }
    return result;
  }

  /** Concatenates the fragment bodies of a LongUDP message, in fragment order. */
  private int[] udpsToLudp(int[][] udps) {
    int size = 0;
    for (int i = 0;  i < udps.length;  i ++) {
      size += udps[i].length;
    }

    int[] result = new int[size];
    for (int i = 0, pos = 0;  i < udps.length;  i ++) {
      int len = udps[i].length;
      System.arraycopy(udps[i], 0, result, pos, len);
      pos += len;
    }
    return result;
  }

  /**
   * Sends one LongUDP message to the current destination.
   * The message body is header, body.length, body bytes, HEADER_NULL. It is
   * split into UDP fragments of at most PACKET_SIZE bytes, each prefixed with
   * the 8-byte LongUDP header. Any I/O failure terminates the process.
   *
   * @param header message type written as the first int
   * @param body   raw message payload
   */
  private void send(int header, byte[] body) {
    final int MAX_BODY_SIZE = PACKET_SIZE - LUDP_HEADER_SIZE;

    byte[] ludpBody = null;
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(baos);
    try {
      dos.writeInt(header);
      dos.writeInt(body.length);
      dos.write(body, 0, body.length);
      dos.writeInt(HEADER_NULL);
      dos.close();
      ludpBody = baos.toByteArray();
      baos.close();
    } catch (Exception e) { e.printStackTrace();  System.exit(1); }

    final short LUDPID = m_ludpId ++;  // wraps around after 32767, as a short does
    final short NUM = (short) ((ludpBody.length + MAX_BODY_SIZE - 1) / MAX_BODY_SIZE);
    int offset = 0;
    for (short i = 0;  i < NUM;  i ++) {
      int size = Math.min(ludpBody.length - offset, MAX_BODY_SIZE);
      baos = new ByteArrayOutputStream();
      dos = new DataOutputStream(baos);
      try {
        // header for assembling LongUDP packets
        dos.writeShort(LUDP_MAGIC);  // magic number 0x0008
        dos.writeShort(LUDPID);      // LongUDP packet ID
        dos.writeShort(i);           // order of this UDP packet
        dos.writeShort(NUM);         // total UDP packet number

        dos.write(ludpBody, offset, size);
        dos.close();

        byte[] buf = baos.toByteArray();
        baos.close();
        m_socket.send(new DatagramPacket(buf, buf.length, m_address, m_port));
      } catch (Exception e) { e.printStackTrace();  System.exit(1); }
      offset += size;
    }
    if(ASSERT)Util.myassert(offset == ludpBody.length);
  }

  /** Sends SK_CONNECT with a 4-byte zero body (protocol version 0). */
  public void sendConnect() {
    System.out.println("connecting");
    send(SK_CONNECT, new byte[] {0, 0, 0, 0});  // version must be 0
  }

  /**
   * Waits for KS_CONNECT_OK and initializes WORLD from it, using
   * data[2..] as the body and INITIALIZING_TIME as the time.
   */
  public void receiveConnectOk() {
    int[] data = receive();
    if(ASSERT)Util.myassert(data[0] == KS_CONNECT_OK, "received Long UDP packet isn't KS_CONNECT_OK");
    System.out.println("initializing");
    WORLD.update(data, 2, INITIALIZING_TIME);  // data[2..]: body
    WORLD.initialize();
  }

  /** Sends SK_ACKNOWLEDGE with an empty body. */
  public void sendAcknowledge() {
    send(SK_ACKNOWLEDGE, new byte[0]);
  }

  /**
   * Waits for KS_COMMANDS and hands the whole message to WORLD.parseCommands.
   * data[2] is checked against the expected next simulation time.
   */
  public void receiveCommands() {
    int[] data = receive();
    if(ASSERT)Util.myassert(data[0] == KS_COMMANDS, "received Long UDP packet isn't KS_COMMANDS");
    if(ASSERT)Util.myassert(data[2] == WORLD.time() + 1  ||  WORLD.time() == INITIALIZING_TIME, "wrong simulation time; lost UDP packet of KS_COMMANDS");
    WORLD.parseCommands(data);
  }

  /**
   * Serializes every moving object (last array element first), terminates the
   * list with TYPE_NULL and sends the result as SK_UPDATE.
   */
  public void sendUpdate() {
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(baos);

    for (int i = WORLD.movingObjectArray().length - 1;  i >= 0;  i --) {
      WORLD.movingObjectArray()[i].output(dos);
    }
    try {
      dos.writeInt(TYPE_NULL);
      dos.close();
      send(SK_UPDATE, baos.toByteArray());
      baos.close();
    } catch (IOException e) { e.printStackTrace();  System.exit(1); }
  }

  /**
   * Waits for KS_UPDATE for the current simulation time and applies it to
   * WORLD. data[2] is the time and data[3..] is the body.
   */
  public void receiveUpdate() {
    int[] data = receive();
    if(ASSERT)Util.myassert(data[0] == KS_UPDATE, "received Long UDP packet isn't KS_UPDATE");
    if(ASSERT)Util.myassert(data[2] == WORLD.time(), "wrong simulation time; lost UDP packet of KS_UPDATE");
    WORLD.update(data, 3, data[2]);  // data[2]: time, data[3..]: body
  }
}
