package jp.sfjp.armadillo.compression.lzhuf;

import java.io.*;
import jp.sfjp.armadillo.io.*;

/**
 * Huffman decoder for LH4, LH5, LH6, LH7.
 * <p>
 * The compressed data is a sequence of blocks. Each block starts with a 16-bit
 * symbol count followed by three code-length tables: a pre-table (used to decode
 * the symbol code lengths), the symbol table and the offset table. The block's
 * symbols are then decoded with lookup tables built from those code lengths.
 * </p>
 */
public final class LzhHuffmanDecoder implements LzssDecoderReadable {

    /** Work table bit length (16 bits and a guard) */
    private static final int WORK_TABLE_BITLENGTH = 16 + 1;

    /** Bit width of the offset-table entry count used by the two original constructors. */
    private static final int DEFAULT_OFFSET_COUNT_BITLENGTH = 4;

    private BitInputStream bin;
    private final int offsetCountBitLength;
    private int blockSize;
    private int symbolMaxBitLength;
    private short[] symbolLengthTable;
    private short[] symbolCodeTable;
    private int offsetMaxBitLength;
    private short[] offsetLengthTable;
    private short[] offsetCodeTable;

    /**
     * Creates a decoder that reads {@code in} until EOF.
     *
     * @param in InputStream
     */
    public LzhHuffmanDecoder(InputStream in) {
        this(in, -1);
    }

    /**
     * @param in InputStream
     * @param limit when zero or positive: number of bytes of {@code in} to decode
     *              when negative: infinite (until EOF)
     */
    public LzhHuffmanDecoder(final InputStream in, final long limit) {
        this(in, limit, DEFAULT_OFFSET_COUNT_BITLENGTH);
    }

    /**
     * @param in InputStream
     * @param limit when zero or positive: number of bytes of {@code in} to decode
     *              when negative: infinite (until EOF)
     * @param offsetCountBitLength bit width of the entry count (and of the explicitly
     *              stored single code) of the offset code-length table. The reference LHA
     *              implementation uses 4 for LH4/LH5 and 5 for LH6/LH7; the other
     *              constructors use 4.
     * @throws IllegalArgumentException if {@code offsetCountBitLength} is not in 1..16
     */
    public LzhHuffmanDecoder(final InputStream in, final long limit, final int offsetCountBitLength) {
        if (offsetCountBitLength < 1 || offsetCountBitLength > 16)
            throw new IllegalArgumentException("offsetCountBitLength=" + offsetCountBitLength);
        // NOTE: with a limit, the LimitedInputStream view does not override close(),
        // so closing this decoder presumably leaves {@code in} open in that case.
        this.bin = new BitInputStream((limit < 0) ? in : new LimitedInputStream(in, limit));
        this.offsetCountBitLength = offsetCountBitLength;
        this.blockSize = 0;
    }

    /**
     * Decodes the next symbol, reading a new block header and its tables when the
     * current block is exhausted.
     *
     * @return the decoded symbol, or -1 at EOF before a new block header
     * @throws IOException on I/O errors or invalid compressed data
     */
    @Override
    public int read() throws IOException {
        try {
            assert blockSize >= 0;
            if (blockSize == 0) {
                if (bin.prefetch() == -1)
                    return -1;
                final int n = bin.readBits(16);
                if (n == -1)
                    return -1;
                // a zero-sized block would leave blockSize negative and the next
                // block header unread, so reject it explicitly
                if (n == 0)
                    throw new LzhufException("invalid compressed data: block size=0");
                this.blockSize = n;
                createSymbolTables();
                createOffsetTables();
            }
            --blockSize;
            final int b = bin.prefetchBits(symbolMaxBitLength);
            assert b != -1;
            final int code = symbolCodeTable[b];
            final int codeLength = symbolLengthTable[code];
            // a table with a single code decodes it without consuming bits
            if (codeLength > 0)
                bin.readBits(codeLength);
            assert code >= 0 && code < 511;
            return code;
        }
        catch (final RuntimeException ex) {
            throw new LzhufException("decode error", ex);
        }
    }

    /**
     * Decodes an offset using the current block's offset table.
     * The decoded offset code {@code c} maps to 0 or 1 for {@code c <= 1}, otherwise to
     * {@code 2^(c-1)} combined with the next {@code c-1} bits.
     *
     * @return the decoded offset
     * @throws IOException on I/O errors or invalid compressed data
     */
    @Override
    public int readOffset() throws IOException {
        try {
            final int b = bin.prefetchBits(offsetMaxBitLength);
            assert b != -1;
            final int code = offsetCodeTable[b];
            final int codeLength = offsetLengthTable[code];
            if (codeLength > 0)
                bin.readBits(codeLength);
            assert code >= 0 && codeLength >= 0;
            int offset;
            if (code > 1)
                offset = (1 << (code - 1)) | bin.readBits(code - 1);
            else
                offset = code;
            assert offset >= 0;
            return offset;
        }
        catch (final RuntimeException ex) {
            throw new LzhufException("decode error", ex);
        }
    }

    /**
     * Reads the pre-table and the symbol code-length table of a block, then initializes
     * symbolMaxBitLength, symbolLengthTable and symbolCodeTable.
     */
    private void createSymbolTables() throws IOException {
        // pre-table: 5-bit count, zero-run field after index 2
        short[] lengthList = readCodeLengthList(5, 3);
        final int blength;
        final short[] table;
        if (lengthList.length == 0) {
            // count 0: the only pre-table code is stored explicitly and takes no bits to decode
            final int single = readSingleCode(5);
            lengthList = new short[single + 1];
            blength = 1;
            table = createSingleCodeTable(single);
        }
        else {
            blength = getMaxBitSize(lengthList);
            table = createCodeTable(lengthList, blength);
        }
        final int n = bin.readBits(9);
        if (n == 0) {
            // count 0: the block uses a single symbol, stored explicitly, decoded with no bits
            final int single = readSingleCode(9);
            this.symbolMaxBitLength = 1;
            this.symbolLengthTable = new short[single + 1];
            this.symbolCodeTable = createSingleCodeTable(single);
            return;
        }
        if (n < 1)
            throw new LzhufException("invalid compressed data: number of code lengths=" + n);
        final short[] codeLengthList = new short[n];
        for (int i = 0; i < codeLengthList.length;) {
            final int code = bin.prefetchBits(blength);
            if (code == -1)
                throw new LzhufException("EOF appeared while reading symbol length list");
            final int symbol = table[code];
            final int bitLength = lengthList[symbol];
            if (bitLength > 0)
                bin.readBits(bitLength);
            switch (symbol) {
                case 0: // a single zero length
                    ++i;
                    break;
                case 1: // a run of 3..18 zero lengths
                    i += bin.readBits(4) + 3;
                    break;
                case 2: // a run of 20..531 zero lengths
                    i += bin.readBits(9) + 20;
                    break;
                default: // code length = symbol - 2
                    codeLengthList[i++] = (short)(symbol - 2);
            }
        }
        final int maxBitLength = getMaxBitSize(codeLengthList);
        this.symbolMaxBitLength = maxBitLength;
        this.symbolLengthTable = codeLengthList;
        this.symbolCodeTable = createCodeTable(codeLengthList, maxBitLength);
    }

    /**
     * Reads the offset code-length table of a block, then initializes
     * offsetMaxBitLength, offsetLengthTable and offsetCodeTable.
     */
    private void createOffsetTables() throws IOException {
        final short[] codeLengthList = readCodeLengthList(offsetCountBitLength, -1);
        if (codeLengthList.length == 0) {
            // count 0: the only offset code is stored explicitly and takes no bits to decode
            final int code = readSingleCode(offsetCountBitLength);
            this.offsetMaxBitLength = 1;
            this.offsetLengthTable = new short[code + 1];
            this.offsetCodeTable = createSingleCodeTable(code);
        }
        else {
            final int maxBitLength = getMaxBitSize(codeLengthList);
            this.offsetMaxBitLength = maxBitLength;
            this.offsetLengthTable = codeLengthList;
            this.offsetCodeTable = createCodeTable(codeLengthList, maxBitLength);
        }
    }

    /**
     * Reads a code-length list. The list starts with an {@code nBits}-bit entry count.
     * Each entry is a 3-bit length; the value 7 is extended by one for each following
     * 1 bit, and a 0 bit terminates the extension. When index {@code special} is reached,
     * a 2-bit count of zero lengths to skip is read first.
     *
     * @param nBits bit width of the entry count
     * @param special index at which the zero-run field appears, or -1 for none
     * @return the code lengths; an empty array when the stored count is 0
     * @throws IOException on EOF or I/O errors
     */
    private short[] readCodeLengthList(int nBits, int special) throws IOException {
        final int n = bin.readBits(nBits);
        if (n < 0)
            throw new LzhufException("EOF appeared while reading code length list");
        final short[] list = new short[n];
        for (int i = 0; i < n; i++) {
            if (i == special) {
                final int zeros = bin.readBits(2);
                if (zeros < 0)
                    throw new LzhufException("EOF appeared while reading code length list");
                i += zeros;
                // the zero run may reach the end of the list: no length follows then
                if (i >= n)
                    break;
            }
            int length = bin.readBits(3);
            if (length < 0)
                throw new LzhufException("EOF appeared while reading code length list");
            if (length == 7)
                while (bin.readBit() == 1)
                    ++length;
            list[i] = (short)length;
        }
        return list;
    }

    /** Reads a code stored explicitly in {@code nBits} bits (used by single-code tables). */
    private int readSingleCode(int nBits) throws IOException {
        final int code = bin.readBits(nBits);
        if (code < 0)
            throw new LzhufException("EOF appeared while reading single code");
        return code;
    }

    /** Builds a 1-bit lookup table whose both entries decode to {@code code}. */
    private static short[] createSingleCodeTable(int code) {
        return new short[]{(short)code, (short)code};
    }

    /** Returns the largest value in {@code bitLengthList}, or 0 when it is empty. */
    private static int getMaxBitSize(short[] bitLengthList) {
        int max = 0;
        for (int i = 0; i < bitLengthList.length; i++)
            if (bitLengthList[i] > max)
                max = bitLengthList[i];
        return max;
    }

    /**
     * Builds a lookup table of {@code 2^maxBitLength} entries, indexed by the next
     * {@code maxBitLength} bits. Entry {@code i} of {@code lengthList} has length {@code L}
     * and code {@code c}. It fills the index range
     * {@code [c << (maxBitLength - L), (c + 1) << (maxBitLength - L))} with {@code i}.
     */
    private static short[] createCodeTable(short[] lengthList, int maxBitLength) {
        final int[] codeList = createCodeList(lengthList);
        final int tableSize = (1 << maxBitLength);
        final short[] table = new short[tableSize];
        for (int i = 0; i < lengthList.length; i++)
            if (lengthList[i] > 0) {
                final int rangeBits = maxBitLength - lengthList[i];
                final int start = codeList[i] << rangeBits;
                final int next = start + (1 << rangeBits);
                for (int index = start; index < next; index++)
                    table[index] = (short)i;
            }
        return table;
    }

    /**
     * Assigns canonical Huffman codes. Shorter lengths receive smaller codes, and codes
     * of the same length are assigned in index order. Entries of length 0 get code 0.
     */
    private static int[] createCodeList(short[] codeLengthList) {
        assert codeLengthList.length > 0;
        if (codeLengthList.length == 1)
            return new int[1];
        final int[] counts = new int[WORK_TABLE_BITLENGTH];
        for (int i = 0; i < codeLengthList.length; i++)
            ++counts[codeLengthList[i]];
        final int[] baseCodes = new int[WORK_TABLE_BITLENGTH];
        // i = bit length - 1: first code of length i+2 = (first code of length i+1 + count) << 1
        for (int i = 0; i < WORK_TABLE_BITLENGTH - 1; i++)
            baseCodes[i + 1] = (baseCodes[i] + counts[i + 1]) << 1;
        final int[] codeList = new int[codeLengthList.length];
        for (int i = 0; i < codeList.length; i++) {
            final int codeLength = codeLengthList[i];
            if (codeLength > 0)
                codeList[i] = baseCodes[codeLength - 1]++;
        }
        return codeList;
    }

    /**
     * Closes the bit stream and releases the tables.
     * Calling this method on an already closed decoder has no effect.
     */
    @Override
    public void close() throws IOException {
        if (bin == null)
            return;
        try {
            bin.close();
        }
        finally {
            bin = null;
            symbolLengthTable = null;
            symbolCodeTable = null;
            offsetLengthTable = null;
            offsetCodeTable = null;
        }
    }

    /** View of a stream that reports EOF after {@code limit} bytes have been read. */
    private static final class LimitedInputStream extends InputStream {

        private final InputStream in;
        private long remaining;

        LimitedInputStream(InputStream in, long limit) {
            this.in = in;
            this.remaining = limit;
        }

        @Override
        public int read() throws IOException {
            if (remaining <= 0)
                return -1;
            final int read = in.read();
            if (read != -1)
                --remaining;
            return read;
        }

        // close() is not overridden: InputStream.close() does nothing, so the
        // underlying stream is left open when this view is closed.
    }

}
