package jp.sfjp.armadillo.compression.lzhuf;

import java.io.*;
import java.util.*;
import jp.sfjp.armadillo.archive.lzh.*;
import jp.sfjp.armadillo.io.*;

/**
 * Huffman encoder for the LH4, LH5, LH6 and LH7 compression methods.
 * <p>
 * LZSS elements (literal bytes and matches) are buffered. When 4096 elements have
 * accumulated, or on {@link #flush()} / {@link #close()}, they are emitted as one block:
 * the element count, the code-length tables T2, T1 and T3, and then the Huffman-coded
 * elements themselves.
 * </p>
 * <p>
 * {@link #close()} never closes the stream passed to the constructor.
 * </p>
 */
public final class LzhHuffmanEncoder implements LzssEncoderWritable {

    /** Maximum number of LZSS elements per block (the count is written in 16 bits). */
    private static final int BUFFER_SIZE = 4096;

    /** Width of the T3 size field used by LH4 and LH5. */
    private static final int DEFAULT_OFFSET_TABLE_SIZE_BITS = 4;

    private final BitOutputStream out;
    private final int threshold;
    private final int offsetTableSizeBits;
    private int index;
    private final int[] symbolBuffer;
    private final int[] offsetBuffer;
    private final int[] frequencyTable;

    /*
     * Per-block Huffman tables; "c" arrays hold codes, "l" arrays hold code lengths.
     * T1: codes for literals (0x000-0x0FF) and match lengths (0x100 and up)
     * T2: codes for the tokens that transmit the code lengths of T1
     * T3: codes for the bit length of each (zero-based) match offset
     */
    private int[] ct1, lt1, ct2, lt2, ct3, lt3;
    private int t1size;

    /**
     * Creates an encoder that writes the T3 size field in 4 bits, as LH4 and LH5 do.
     *
     * @param out the destination; it is not closed by {@link #close()}
     * @param threshold the minimum match length; match lengths are encoded relative to it
     */
    public LzhHuffmanEncoder(OutputStream out, int threshold) {
        this(out, threshold, DEFAULT_OFFSET_TABLE_SIZE_BITS);
    }

    /**
     * Creates an encoder with an explicit width for the T3 size field.
     *
     * @param out the destination; it is not closed by {@link #close()}
     * @param threshold the minimum match length; match lengths are encoded relative to it
     * @param offsetTableSizeBits width in bits of the T3 size field: 4 for LH4/LH5,
     *        5 for LH6/LH7 (the "pbit" value of the LHA implementation)
     * @throws IllegalArgumentException if offsetTableSizeBits is neither 4 nor 5
     */
    public LzhHuffmanEncoder(OutputStream out, int threshold, int offsetTableSizeBits) {
        if (offsetTableSizeBits != 4 && offsetTableSizeBits != 5)
            throw new IllegalArgumentException("offsetTableSizeBits: " + offsetTableSizeBits);
        // the wrapper's no-op close() keeps the caller's stream open when this encoder is closed
        this.out = new BitOutputStream(new FilterOutputStream(out) {
            @Override
            public void close() throws IOException {
                //
            }
        });
        this.threshold = threshold;
        this.offsetTableSizeBits = offsetTableSizeBits;
        this.index = 0;
        this.symbolBuffer = new int[BUFFER_SIZE];
        this.offsetBuffer = new int[BUFFER_SIZE];
        this.frequencyTable = new int[512];
    }

    /**
     * Buffers a literal byte, encoding a block when the buffer becomes full.
     *
     * @param symbol the literal; values of 0x100 and above would be taken for match lengths
     * @throws IOException if a full buffer had to be encoded and that failed
     */
    @Override
    public void write(int symbol) throws IOException {
        ++frequencyTable[symbol];
        symbolBuffer[index++] = symbol;
        if (index >= BUFFER_SIZE)
            encode();
    }

    /**
     * Buffers a match, encoding a block when the buffer becomes full.
     *
     * @param offset the (one-based) match distance; must be at least 1
     * @param length the match length; must be at least the threshold
     * @throws IOException if a full buffer had to be encoded and that failed
     */
    @Override
    public void writeMatched(int offset, int length) throws IOException {
        assert length >= threshold;
        assert offset >= 1;
        // match lengths share the T1 alphabet with literals, starting at 0x100
        final int symbol = length - threshold + 0x100;
        ++frequencyTable[symbol];
        symbolBuffer[index] = symbol;
        // offsets are stored zero-based
        offsetBuffer[index++] = offset - 1;
        if (index >= BUFFER_SIZE)
            encode();
    }

    /**
     * Encodes any buffered elements as a block; does nothing when the buffer is empty.
     *
     * @throws IOException if encoding or writing fails
     */
    @Override
    public void flush() throws IOException {
        encode();
    }

    /**
     * Encodes the buffered elements as one block. The buffer, frequencies and tables are
     * reset whether or not encoding succeeds.
     *
     * @throws IOException if writing fails, or an {@link LzhufException} wrapping any
     *         unexpected runtime failure ({@link LzhQuit} is rethrown unchanged)
     */
    private void encode() throws IOException {
        if (index == 0)
            return;
        try {
            createTables();
            outputCodes();
        }
        catch (final LzhQuit ex) {
            throw ex;
        }
        catch (final RuntimeException ex) {
            final IOException exception = new LzhufException("huffman encoding error");
            exception.initCause(ex);
            throw exception;
        }
        finally {
            index = 0;
            ct1 = null;
            lt1 = null;
            ct2 = null;
            lt2 = null;
            ct3 = null;
            lt3 = null;
            Arrays.fill(frequencyTable, 0);
        }
    }

    /** Builds T1, T2 and T3 from the buffered elements. */
    private void createTables() {
        // T1: literals and match lengths
        final LzhHuffmanTable table1 = LzhHuffmanTable.build(frequencyTable);
        ct1 = table1.codeTable;
        lt1 = table1.codeLengthTable;
        t1size = getTrimmedSize(lt1);
        // T2: frequencies of the tokens outputCodes() uses to transmit lt1:
        // 0 = a single zero length, 1 = a run of 3-18 zeros, 2 = a run of 20 or more zeros,
        // k + 2 = the non-zero length k
        final int[] ft = new int[512];
        for (int i = 0; i < t1size;) {
            final int length = lt1[i++];
            if (length == 0) {
                int count = 1;
                while (i < t1size && lt1[i] == 0) {
                    ++i;
                    ++count;
                }
                if (count <= 2)
                    ft[0] += count;
                else if (count <= 18)
                    ++ft[1];
                else if (count == 19) {
                    // 19 zeros are sent as one single zero plus a run of 18
                    ++ft[0];
                    ++ft[1];
                }
                else
                    ++ft[2];
            }
            else
                ++ft[length + 2];
        }
        final LzhHuffmanTable table2 = LzhHuffmanTable.build(ft);
        ct2 = table2.codeTable;
        lt2 = table2.codeLengthTable;
        // T3: bit lengths of the zero-based match offsets
        final int[] ft3 = new int[17];
        for (int i = 0; i < index; i++)
            if (symbolBuffer[i] >= 0x100)
                ++ft3[bitLength(offsetBuffer[i])];
        final LzhHuffmanTable table3 = LzhHuffmanTable.build(ft3);
        ct3 = table3.codeTable;
        lt3 = table3.codeLengthTable;
    }

    /**
     * Writes one block: element count, T2, T1, T3, then the coded elements.
     *
     * @throws IOException if writing fails
     */
    private void outputCodes() throws IOException {
        // number of LZSS elements in this block
        out.writeBits(index, 16);
        // T2, with a zero-run count after its third entry
        final int t2size = getTrimmedSize(lt2);
        out.writeBits(t2size, 5);
        writeLengthTable(lt2, t2size, true);
        // T1, transmitted with the T2 tokens counted in createTables()
        out.writeBits(t1size, 9);
        for (int i = 0; i < t1size;) {
            final int length = lt1[i++];
            if (length == 0) {
                int count = 1;
                while (i < t1size && lt1[i] == 0) {
                    ++i;
                    ++count;
                }
                if (count <= 2)
                    for (int j = 0; j < count; j++)
                        out.writeBits(ct2[0], lt2[0]);
                else if (count <= 18) {
                    out.writeBits(ct2[1], lt2[1]);
                    out.writeBits(count - 3, 4);
                }
                else if (count == 19) {
                    out.writeBits(ct2[0], lt2[0]);
                    out.writeBits(ct2[1], lt2[1]);
                    out.writeBits(15, 4);
                }
                else {
                    out.writeBits(ct2[2], lt2[2]);
                    out.writeBits(count - 20, 9);
                }
            }
            else
                out.writeBits(ct2[length + 2], lt2[length + 2]);
        }
        // T3; an empty table is followed by a zero field of the same width
        final int t3size = getTrimmedSize(lt3);
        out.writeBits(t3size, offsetTableSizeBits);
        if (t3size == 0)
            out.writeBits(0, offsetTableSizeBits);
        else
            writeLengthTable(lt3, t3size, false);
        // LZSS elements
        for (int i = 0; i < index; i++) {
            final int symbol = symbolBuffer[i];
            out.writeBits(ct1[symbol], lt1[symbol]);
            if (symbol >= 0x100) {
                final int offset = offsetBuffer[i];
                final int offsetBitLength = bitLength(offset);
                out.writeBits(ct3[offsetBitLength], lt3[offsetBitLength]);
                // the top bit is implied by the bit length; only the bits below it are written
                if (offsetBitLength > 1)
                    out.writeBits(offset, offsetBitLength - 1);
            }
        }
        // NOTE(review): if BitOutputStream.flush() pads to a byte boundary, blocks after the
        // first would be misaligned, since LHA blocks are bit-contiguous -- confirm.
        out.flush();
    }

    /**
     * Writes the first size entries of a code-length table (T2 or T3). Lengths up to 6 are
     * written in 3 bits; longer lengths as (length - 4) one bits followed by a zero bit
     * (presuming writeBits emits the low bits of its value).
     *
     * @param lengths the code-length table
     * @param size the number of entries to write
     * @param zeroRunAfterThird if true, a 2-bit count of the zero entries following the
     *        third entry (up to index 6) is written after it, and those entries are skipped
     * @throws IOException if writing fails
     */
    private void writeLengthTable(int[] lengths, int size, boolean zeroRunAfterThird)
            throws IOException {
        for (int i = 0; i < size;) {
            final int length = lengths[i++];
            if (length <= 6)
                out.writeBits(length, 3);
            else {
                out.writeBits(0xFFFFFFFF, length - 4);
                out.writeBits(0, 1);
            }
            if (zeroRunAfterThird && i == 3) {
                while (i < 6 && lengths[i] == 0)
                    ++i;
                out.writeBits(i - 3, 2);
            }
        }
    }

    /**
     * Returns the number of significant bits of value: 0 for 0, 1 for 1, 2 for 2-3, and so on.
     * Negative values give 32, which is outside T3, so such an offset makes encoding fail
     * instead of looping or producing a stream inconsistent with its table.
     */
    private static int bitLength(int value) {
        return Integer.SIZE - Integer.numberOfLeadingZeros(value);
    }

    /** Returns the length of the array without its trailing zero entries. */
    private static int getTrimmedSize(int[] a) {
        int i = a.length;
        while (i > 0 && a[i - 1] == 0)
            --i;
        return i;
    }

    /**
     * Encodes any buffered elements and closes the bit stream; the stream given to the
     * constructor stays open.
     *
     * @throws IOException if encoding or writing fails
     */
    @Override
    public void close() throws IOException {
        try {
            flush();
        }
        finally {
            out.close();
        }
    }

}
