/***********************************************************
	huf.c -- static Huffman
***********************************************************/
#include <stdlib.h>
#include "ar.h"

/* Number of position (p-table) symbols for the largest dictionary;
   the per-method count is np = dicbit + 1 (see init_parameter()). */
#define NP (MAXDICBIT + 1)
/* Number of t-table symbols: codes 0..2 encode runs of zero code lengths
   and code k + 2 encodes a nonzero c-table length k (see count_t_freq()). */
#define NT (CODE_BIT + 3)
#define PBIT 4                  /* smallest integer such that (1U << PBIT) > NP */
/* NOTE(review): PBIT is not referenced anywhere in this file -- the
   per-method pbit stored by init_parameter() is used instead.  Also,
   (1U << 4) > NP only holds when MAXDICBIT <= 14; confirm against ar.h
   before relying on this constant. */
#define TBIT 5                  /* smallest integer such that (1U << TBIT) > NT */
/* pt_len[] and pt_code[] are shared by the t-table and the p-table, so
   they are sized for the larger of the two symbol counts. */
#if NT > NP
#define NPT NT
#else
#define NPT NP
#endif

/* Child links of the Huffman tree nodes.  These are not static: the
   decoders below walk them, and they are presumably filled in by
   make_tree()/make_table() defined elsewhere -- confirm. */
ushort left[2 * NC - 1], right[2 * NC - 1];
/* buf:    encoder block buffer of bufsiz bytes (flag bytes + tokens).
   c_len:  code lengths of the character/length symbols.
   pt_len: code lengths of the t-table or p-table (whichever is current). */
static uchar *buf, c_len[NC], pt_len[NPT];
/* bufsiz:    size of buf, 0 until the first huf_encode_start() allocates it.
   blocksize: symbols remaining in the block being decoded. */
static uint bufsiz = 0, blocksize;
/* *_freq:   symbol frequencies fed to make_tree() (sized 2n-1, presumably so
             internal tree nodes can be stored too -- confirm).
   c_table:  4096-entry (12-bit) direct decode lookup for c codes.
   pt_table: 256-entry (8-bit) direct decode lookup for t/p codes.
   *_code:   codewords used by the encoder. */
static ushort c_freq[2 * NC - 1], c_table[4096], c_code[NC],
    p_freq[2 * NP - 1], pt_table[256], pt_code[NPT], t_freq[2 * NT - 1];

/* np:   number of p-table symbols for the current method (dicbit + 1).
   pbit: width of the p-table count field in each block header. */
static int np, pbit;

/* Load the per-method table parameters from *m. */
static void
init_parameter(struct lha_method *m)
{
    pbit = m->pbit;
    np = 1 + m->dicbit;
}

/***** encoding *****/

static void
count_t_freq(void)
{
    int i, k, n, count;

    for (i = 0; i < NT; i++)
        t_freq[i] = 0;
    n = NC;
    while (n > 0 && c_len[n - 1] == 0)
        n--;
    i = 0;
    while (i < n) {
        k = c_len[i++];
        if (k == 0) {
            count = 1;
            while (i < n && c_len[i] == 0) {
                i++;
                count++;
            }
            if (count <= 2)
                t_freq[0] += count;
            else if (count <= 18)
                t_freq[1]++;
            else if (count == 19) {
                t_freq[0]++;
                t_freq[1]++;
            }
            else
                t_freq[2]++;
        }
        else
            t_freq[k + 2]++;
    }
}

/*
 * Emit the code lengths pt_len[0..n-1] of a t- or p-table.
 * Trailing zero lengths are dropped, and the remaining count is written
 * in nbit bits.  A length k <= 6 is sent as 3 bits; a larger one is sent
 * as (k - 3) bits of the form 1...10.  Right after the entry at index
 * i_special (-1 disables this), zero lengths are skipped up to index 6,
 * and (index - 3) & 3 is written in 2 bits.
 */
static void
write_pt_len(struct lzh_ostream *wp, int n, int nbit, int i_special)
{
    int idx, len;

    for (; n > 0; n--) {
        if (pt_len[n - 1] != 0)
            break;
    }
    putbits(wp, nbit, n);

    for (idx = 0; idx < n; ) {
        len = pt_len[idx];
        idx++;
        if (len > 6)
            putbits(wp, len - 3, (1U << (len - 3)) - 2);
        else
            putbits(wp, 3, len);
        if (idx != i_special)
            continue;
        while (idx < 6 && pt_len[idx] == 0)
            idx++;
        putbits(wp, 2, (idx - 3) & 3);
    }
}

static void
write_c_len(struct lzh_ostream *wp)
{
    int i, k, n, count;

    n = NC;
    while (n > 0 && c_len[n - 1] == 0)
        n--;
    putbits(wp, CBIT, n);
    i = 0;
    while (i < n) {
        k = c_len[i++];
        if (k == 0) {
            count = 1;
            while (i < n && c_len[i] == 0) {
                i++;
                count++;
            }
            if (count <= 2) {
                for (k = 0; k < count; k++)
                    putbits(wp, pt_len[0], pt_code[0]);
            }
            else if (count <= 18) {
                putbits(wp, pt_len[1], pt_code[1]);
                putbits(wp, 4, count - 3);
            }
            else if (count == 19) {
                putbits(wp, pt_len[0], pt_code[0]);
                putbits(wp, pt_len[1], pt_code[1]);
                putbits(wp, 4, 15);
            }
            else {
                putbits(wp, pt_len[2], pt_code[2]);
                putbits(wp, CBIT, count - 20);
            }
        }
        else
            putbits(wp, pt_len[k + 2], pt_code[k + 2]);
    }
}

/* Emit the Huffman codeword for character/length symbol c. */
static void
encode_c(struct lzh_ostream *wp, int c)
{
    int nbits = c_len[c];

    putbits(wp, nbits, c_code[c]);
}

/*
 * Emit match position p.  First comes the p-table codeword for its bit
 * length (0 when p == 0).  When that length is at least 2, the low
 * (length - 1) bits of p follow; the leading 1 bit is implied.
 */
static void
encode_p(struct lzh_ostream *wp, uint p)
{
    uint nbits = 0;
    uint rest;

    for (rest = p; rest != 0; rest >>= 1)
        nbits++;
    putbits(wp, pt_len[nbits], pt_code[nbits]);
    if (nbits >= 2)
        putbits(wp, nbits - 1, p & (0xFFFFU >> (17 - nbits)));
}

static void
send_block(struct lzh_ostream *wp)
{
    uint i, k, flags, root, pos, size;

    root = make_tree(NC, c_freq, c_len, c_code);
    size = c_freq[root];
    putbits(wp, 16, size);
    if (root >= NC) {
        count_t_freq();
        root = make_tree(NT, t_freq, pt_len, pt_code);
        if (root >= NT) {
            write_pt_len(wp, NT, TBIT, 3);
        }
        else {
            putbits(wp, TBIT, 0);
            putbits(wp, TBIT, root);
        }
        write_c_len(wp);
    }
    else {
        putbits(wp, TBIT, 0);
        putbits(wp, TBIT, 0);
        putbits(wp, CBIT, 0);
        putbits(wp, CBIT, root);
    }
    root = make_tree(np, p_freq, pt_len, pt_code);
    if (root >= np) {
        write_pt_len(wp, np, pbit, -1);
    }
    else {
        putbits(wp, pbit, 0);
        putbits(wp, pbit, root);
    }
    pos = 0;
    for (i = 0; i < size; i++) {
        if (i % CHAR_BIT == 0)
            flags = buf[pos++];
        else
            flags <<= 1;
        if (flags & (1U << (CHAR_BIT - 1))) {
            encode_c(wp, buf[pos++] + (1U << CHAR_BIT));
            k = buf[pos++] << CHAR_BIT;
            k += buf[pos++];
            encode_p(wp, k);
        }
        else
            encode_c(wp, buf[pos++]);
        if (wp->unpackable)
            return;
    }
    for (i = 0; i < NC; i++)
        c_freq[i] = 0;
    for (i = 0; i < np; i++)
        p_freq[i] = 0;
}

/* output_pos:  next free byte in buf[].
   output_mask: flag bit for the current token (0 = start a new group). */
static uint output_pos, output_mask;

/*
 * Buffer one token of the current block.
 * c < 256 is a literal byte.  Otherwise c is a match-length symbol, and p
 * (the match position) is stored big-endian in the next two bytes.  Every
 * CHAR_BIT tokens are preceded by a flag byte whose bits (MSB first) mark
 * the match tokens.  When output_pos has reached bufsiz - 3 * CHAR_BIT, the
 * pending block is sent first; if that leaves the stream unpackable, the
 * token is dropped.  Also updates c_freq[] and p_freq[].
 */
void
output(struct lzh_ostream *wp, uint c, uint p)
{
    static uint flagpos;
    uint nbits;

    output_mask >>= 1;
    if (output_mask == 0) {
        output_mask = 1U << (CHAR_BIT - 1);
        if (output_pos >= bufsiz - 3 * CHAR_BIT) {
            send_block(wp);
            if (wp->unpackable)
                return;
            output_pos = 0;
        }
        flagpos = output_pos;
        buf[output_pos++] = 0;
    }

    buf[output_pos++] = (uchar) c;
    c_freq[c]++;
    if (c < (1U << CHAR_BIT))
        return;

    buf[flagpos] |= output_mask;
    buf[output_pos++] = (uchar) (p >> CHAR_BIT);
    buf[output_pos++] = (uchar) p;
    for (nbits = 0; p != 0; p >>= 1)
        nbits++;
    p_freq[nbits]++;
}

/*
 * Prepare the encoder for method *m.
 * The block buffer is allocated only on the first call.  It starts at 16 KiB;
 * on malloc failure it is shrunk by 10% and retried, and error() is called
 * once it drops below 4 KiB.  Frequency counters and output state are reset,
 * then the bit writer is initialised.
 */
void
huf_encode_start(struct lzh_ostream *wp, struct lha_method *m)
{
    int i;

    init_parameter(m);

    if (bufsiz == 0) {
        for (bufsiz = 16 * 1024U; (buf = malloc(bufsiz)) == NULL; ) {
            bufsiz = (bufsiz / 10U) * 9U;
            if (bufsiz < 4 * 1024U)
                error("Out of memory.");
        }
    }
    buf[0] = 0;

    for (i = 0; i < NC; i++)
        c_freq[i] = 0;
    for (i = 0; i < np; i++)
        p_freq[i] = 0;

    output_pos = 0;
    output_mask = 0;
    init_putbits(wp);
}

/* Send the final block and flush the pending bits, unless the stream has
   already been marked unpackable. */
void
huf_encode_end(struct lzh_ostream *wp)
{
    if (wp->unpackable)
        return;
    send_block(wp);
    putbits(wp, CHAR_BIT - 1, 0);       /* flush remaining bits */
}

/***** decoding *****/

/*
 * Read the code lengths of a t-table or p-table from the block header and
 * rebuild the 8-bit direct lookup table pt_table[] from them.
 *
 *   nn        number of symbols in the table (NT or np)
 *   nbit      width of the count field (TBIT or pbit)
 *   i_special index after which a 2-bit count of zero lengths follows
 *             (3 for the t-table, -1 to disable)
 *
 * The count read from the stream is nbit bits wide, and the special zero
 * run adds up to 3 more entries.  A corrupt archive can therefore describe
 * more than nn lengths.  Such extra lengths are still consumed from the
 * bitstream, so the read position is unchanged, but they are not stored.
 * pt_len[] is never written out of bounds.
 */
static void
read_pt_len(struct lzh_istream *rp, int nn, int nbit, int i_special)
{
    int i, c, n;
    uint mask;

    n = getbits(rp, nbit);
    if (n == 0) {
        /* Only one symbol: every lookup yields c, and all lengths are 0.
           NOTE(review): c is not checked against nn here. */
        c = getbits(rp, nbit);
        for (i = 0; i < nn; i++)
            pt_len[i] = 0;
        for (i = 0; i < 256; i++)
            pt_table[i] = c;
    }
    else {
        i = 0;
        while (i < n) {
            /* Lengths 0..6 take 3 bits.  Length 7 and above is 111
               followed by one extra 1 bit per additional length,
               terminated by a 0. */
            c = rp->bitbuf >> (BITBUFSIZ - 3);
            if (c == 7) {
                mask = 1U << (BITBUFSIZ - 1 - 3);
                while (mask & rp->bitbuf) {
                    mask >>= 1;
                    c++;
                }
            }
            fillbuf(rp, (c < 7) ? 3 : c - 3);
            if (i < nn)             /* drop excess entries of a bad header */
                pt_len[i] = c;
            i++;
            if (i == i_special) {
                c = getbits(rp, 2);
                while (--c >= 0) {
                    if (i < nn)
                        pt_len[i] = 0;
                    i++;
                }
            }
        }
        while (i < nn)
            pt_len[i++] = 0;
        make_table(nn, pt_len, 8, pt_table);
    }
}

/*
 * Read the c-table code lengths from the block header and rebuild the
 * 12-bit direct lookup table c_table[] from them.
 * Each entry is a t-table symbol (decoded via pt_table/left/right):
 *   0      one zero length
 *   1      3..18 zero lengths (4-bit count follows)
 *   2      20 or more zero lengths (CBIT-bit count follows)
 *   k + 2  a nonzero length k
 *
 * The CBIT-bit count and the zero runs can describe more than NC entries
 * in a corrupt archive.  Such entries are still consumed from the
 * bitstream but are not stored, so c_len[] is never written out of bounds.
 */
static void
read_c_len(struct lzh_istream *rp)
{
    int i, c, n;
    uint mask;

    n = getbits(rp, CBIT);
    if (n == 0) {
        /* Only one symbol: every lookup yields c, and all lengths are 0.
           NOTE(review): c is not checked against NC here. */
        c = getbits(rp, CBIT);
        for (i = 0; i < NC; i++)
            c_len[i] = 0;
        for (i = 0; i < 4096; i++)
            c_table[i] = c;
    }
    else {
        i = 0;
        while (i < n) {
            c = pt_table[rp->bitbuf >> (BITBUFSIZ - 8)];
            if (c >= NT) {
                /* code longer than 8 bits: walk the tree bit by bit */
                mask = 1U << (BITBUFSIZ - 1 - 8);
                do {
                    if (rp->bitbuf & mask)
                        c = right[c];
                    else
                        c = left[c];
                    mask >>= 1;
                } while (c >= NT);
            }
            fillbuf(rp, pt_len[c]);
            if (c <= 2) {
                if (c == 0)
                    c = 1;
                else if (c == 1)
                    c = getbits(rp, 4) + 3;
                else
                    c = getbits(rp, CBIT) + 20;
                while (--c >= 0) {
                    if (i < NC)     /* drop excess entries of a bad header */
                        c_len[i] = 0;
                    i++;
                }
            }
            else {
                if (i < NC)
                    c_len[i] = c - 2;
                i++;
            }
        }
        while (i < NC)
            c_len[i++] = 0;
        make_table(NC, c_len, 12, c_table);
    }
}

uint
decode_c(struct lzh_istream *rp)
{
    uint j, mask;

    if (blocksize == 0) {
        blocksize = getbits(rp, 16);
        read_pt_len(rp, NT, TBIT, 3);
        read_c_len(rp);
        read_pt_len(rp, np, pbit, -1);
    }
    blocksize--;
    j = c_table[rp->bitbuf >> (BITBUFSIZ - 12)];
    if (j >= NC) {
        mask = 1U << (BITBUFSIZ - 1 - 12);
        do {
            if (rp->bitbuf & mask)
                j = right[j];
            else
                j = left[j];
            mask >>= 1;
        } while (j >= NC);
    }
    fillbuf(rp, c_len[j]);
    return j;
}

/*
 * Decode a match position.  The p-table symbol j gives the bit length of
 * the position: 0 means position 0.  Otherwise the position is
 * 2^(j-1) plus j-1 further raw bits.
 */
uint
decode_p(struct lzh_istream *rp)
{
    uint sym, bit;

    sym = pt_table[rp->bitbuf >> (BITBUFSIZ - 8)];
    for (bit = 1U << (BITBUFSIZ - 1 - 8); sym >= np; bit >>= 1)
        sym = (rp->bitbuf & bit) ? right[sym] : left[sym];
    fillbuf(rp, pt_len[sym]);
    if (sym == 0)
        return 0;
    return (1U << (sym - 1)) + getbits(rp, sym - 1);
}

/* Prepare the decoder for method *m.  blocksize is cleared so that the
   first decode_c() call reads a block header. */
void
huf_decode_start(struct lzh_istream *rp, struct lha_method *m)
{
    init_parameter(m);
    blocksize = 0;
    init_getbits(rp);
}
