/***********************************************************
	huf.c -- static Huffman
***********************************************************/
#include <stdlib.h>
#include "ar.h"

/***** encoding *****/

/*
 * Count how often each code-length symbol is needed to describe c_len[].
 *
 * Trailing zero lengths are ignored.  The resulting t_freq[] slots are:
 *   t_freq[0]      single zero length (runs of 1 or 2 count once per zero)
 *   t_freq[1]      run of 3..18 zeros
 *   t_freq[2]      run of 20 or more zeros
 *   t_freq[k + 2]  non-zero code length k
 * A run of exactly 19 zeros is counted as one single zero plus one
 * t_freq[1] run, matching what write_c_len() emits.
 *
 * wp is not used here.
 */
static void
count_t_freq(struct lzh_ostream *wp, uint8_t *c_len, uint16_t *t_freq)
{
    int pos, last, len, run;

    (void) wp;

    for (pos = 0; pos < NT; pos++)
        t_freq[pos] = 0;

    /* find the end of the used part of c_len[] */
    for (last = NC; last > 0; last--) {
        if (c_len[last - 1] != 0)
            break;
    }

    pos = 0;
    while (pos < last) {
        len = c_len[pos++];
        if (len != 0) {
            t_freq[len + 2]++;
            continue;
        }

        /* a zero was just consumed; extend over the zeros that follow */
        for (run = 1; pos < last && c_len[pos] == 0; pos++)
            run++;

        if (run >= 20) {
            t_freq[2]++;
        }
        else if (run == 19) {
            t_freq[0]++;
            t_freq[1]++;
        }
        else if (run >= 3) {
            t_freq[1]++;
        }
        else {
            t_freq[0] += run;
        }
    }
}

/*
 * Write a table of code lengths pt_len[0..n-1] to the output.
 *
 * Trailing zero lengths are dropped first, then the remaining count is
 * written in nbit bits.  Each length follows: values 0..6 as a plain 3-bit
 * number, larger values as (length - 4) one bits followed by a zero bit
 * (length - 3 bits in total).
 *
 * Right after entry number i_special has been written, a 2-bit field tells
 * how many of the following entries (looking no further than index 5) are
 * zero; those entries are skipped.  Pass -1 as i_special to disable this.
 */
static void
write_pt_len(struct lzh_ostream *wp, int n, int nbit, int i_special,
             uint8_t *pt_len)
{
    int idx, len;

    for (; n > 0; n--) {
        if (pt_len[n - 1] != 0)
            break;
    }
    putbits(wp, nbit, n);

    idx = 0;
    while (idx < n) {
        len = pt_len[idx];
        idx++;

        if (len > 6)
            putbits(wp, len - 3, (1U << (len - 3)) - 2);
        else
            putbits(wp, 3, len);

        if (idx != i_special)
            continue;

        /* idx is advanced past the zeros, so idx - 3 is the skip count */
        while (idx < 6 && pt_len[idx] == 0)
            idx++;
        putbits(wp, 2, (idx - 3) & 3);
    }
}

/*
 * Write the code lengths c_len[] of the character/length alphabet, encoded
 * with the code (t_len[], t_code[]).
 *
 * The number of entries (after trimming trailing zeros) is written in CBIT
 * bits.  Non-zero length k is sent as symbol k + 2.  Runs of zeros are sent
 * as:
 *   1..2 zeros    symbol 0, once per zero
 *   3..18 zeros   symbol 1 followed by 4 bits holding (run - 3)
 *   19 zeros      symbol 0, then symbol 1 with the 4-bit value 15
 *   20+ zeros     symbol 2 followed by CBIT bits holding (run - 20)
 */
static void
write_c_len(struct lzh_ostream *wp, uint8_t *c_len, uint8_t *t_len, uint16_t *t_code)
{
    int pos, last, len, run, rep;

    for (last = NC; last > 0; last--) {
        if (c_len[last - 1] != 0)
            break;
    }
    putbits(wp, CBIT, last);

    for (pos = 0; pos < last; ) {
        len = c_len[pos++];
        if (len != 0) {
            putbits(wp, t_len[len + 2], t_code[len + 2]);
            continue;
        }

        /* one zero already consumed; count the ones that follow */
        run = 1;
        while (pos < last && c_len[pos] == 0) {
            pos++;
            run++;
        }

        if (run >= 20) {
            putbits(wp, t_len[2], t_code[2]);
            putbits(wp, CBIT, run - 20);
        }
        else if (run == 19) {
            putbits(wp, t_len[0], t_code[0]);
            putbits(wp, t_len[1], t_code[1]);
            putbits(wp, 4, 15);
        }
        else if (run >= 3) {
            putbits(wp, t_len[1], t_code[1]);
            putbits(wp, 4, run - 3);
        }
        else {
            for (rep = 0; rep < run; rep++)
                putbits(wp, t_len[0], t_code[0]);
        }
    }
}

/*
 * Write the code for character/length symbol c: c_len[c] bits taken from
 * c_code[c].
 */
static void
encode_c(struct lzh_ostream *wp, int c, uint8_t *c_len, uint16_t *c_code)
{
    const uint8_t nbits = c_len[c];
    const uint16_t bits = c_code[c];

    putbits(wp, nbits, bits);
}

/*
 * Write match offset p.
 *
 * The number of significant bits in p selects the symbol written with
 * (p_len[], p_code[]).  When that bit count is greater than 1, the bits of p
 * below its leading one bit (bit count - 1 of them) are written verbatim
 * afterwards.
 */
static void
encode_p(struct lzh_ostream *wp, uint32_t p, uint8_t *p_len, uint16_t *p_code)
{
    uint16_t nbits = 0;
    uint32_t rest;

    for (rest = p; rest != 0; rest >>= 1)
        nbits++;

    putbits(wp, p_len[nbits], p_code[nbits]);
    if (nbits <= 1)
        return;

    /* the mask keeps the low (nbits - 1) bits of p */
    putbits(wp, nbits - 1, p & (0xFFFFU >> (17 - nbits)));
}

/*
    size        frequency       bitlength       Huffman coding
   -----------------------------------------------------------
     NC         c_freq          c_len           c_code
     NT         t_freq          t_len           t_code
     np         p_freq          p_len           p_code

  output format for a block.

    +-----------+
    | blocksize |
    +-----------+
     16bit

    +-----+--------------------+
    | len |       t_len        | Huffman tree for c_len[]
    +-----+--------------------+
      5bit        ?? bit
      TBIT

    +-------+------------------+
    |  len  |     c_len        | Huffman tree for characters and length
    +-------+------------------+
      9bit        ?? bit
      CBIT

    +---------+--------------------+
    |   len   |   p_len            | Huffman tree for offset
    +---------+--------------------+
     pbit         ?? bit
                               (pbit=4bit(lh4,5) or 5bit(lh6,7))

    +---------------------+
    |  encoding text      |
    +---------------------+


  In the special case where only one kind of character occurs in a block:

                  TBIT: 5 bits
                  CBIT: 9 bits

    +-----------+
    | blocksize |
    +-----------+
     16bit

    +-----+-----+
    |  0  |  0  | Huffman tree for c_len[]
    +-----+-----+
      TBIT  TBIT

    +-------+-------+
    |  0    |  0    | Huffman tree for characters and length
    +-------+-------+
      CBIT    CBIT

    +---------+--------------+
    |  0      | offset value | Huffman tree for offset
    +---------+--------------+
     pbit         pbit
                               (pbit=4bit(lh4,5) or 5bit(lh6,7))

    +---------------------+
    |  encoding text      |
    +---------------------+
 */
/*
 * Emit one block: the 16-bit entry count, the three code-length tables and
 * then the buffered entries of wp->buf encoded with the codes just built.
 *
 * When make_tree() reports that an alphabet holds only one kind of symbol
 * (its return value is below the alphabet size), a zero length field
 * followed by that symbol is written in place of a table.
 *
 * The frequency tables are cleared on normal completion.  If wp->unpackable
 * becomes set while entries are being written, the function returns at once
 * and leaves them untouched.
 */
static void
send_block(struct lzh_ostream *wp)
{
    unsigned int idx, offset, flagbits, root, rd, count;

    uint8_t c_len[NC];
    uint16_t c_code[NC];

    uint8_t t_len[NT];
    uint16_t t_code[NT];

    uint8_t p_len[NP];
    uint16_t p_code[NP];

    /* code for characters and match lengths */
    root = make_tree(NC, wp->c_freq, c_len, c_code);
    count = wp->c_freq[root];
    putbits(wp, 16, count);
    if (root < NC) {
        /* only one kind */
        putbits(wp, TBIT, 0);
        putbits(wp, TBIT, 0);
        putbits(wp, CBIT, 0);
        putbits(wp, CBIT, root);
    }
    else {
        uint16_t t_freq[2 * NT - 1];

        /* code for the entries of c_len[] */
        count_t_freq(wp, c_len, t_freq);
        root = make_tree(NT, t_freq, t_len, t_code);
        if (root < NT) {
            /* only one kind */
            putbits(wp, TBIT, 0);
            putbits(wp, TBIT, root);
        }
        else {
            write_pt_len(wp, NT, TBIT, 3, t_len);
        }
        write_c_len(wp, c_len, t_len, t_code);
    }

    /* code for match offsets */
    root = make_tree(wp->np, wp->p_freq, p_len, p_code);
    if (root < wp->np) {
        putbits(wp, wp->pbit, 0);
        putbits(wp, wp->pbit, root);
    }
    else {
        write_pt_len(wp, wp->np, wp->pbit, -1, p_len);
    }

    /*
     * Encode the buffered entries.  Every group of CHAR_BIT entries is
     * preceded by one flag byte; its bits, most significant first, tell
     * whether the matching entry is a <length, offset> pair or a character.
     */
    rd = 0;
    flagbits = 0;
    for (idx = 0; idx < count; idx++) {
        if (idx % CHAR_BIT != 0)
            flagbits <<= 1;
        else
            flagbits = wp->buf[rd++];

        if ((flagbits & (1U << (CHAR_BIT - 1))) == 0) {
            /* plain character */
            encode_c(wp, wp->buf[rd++], c_len, c_code);
        }
        else {
            /* match length, stored without its 1 << CHAR_BIT bias */
            encode_c(wp, wp->buf[rd++] + (1U << CHAR_BIT), c_len, c_code);
            /* match offset, stored high byte first */
            offset = wp->buf[rd++] << CHAR_BIT;
            offset += wp->buf[rd++];
            encode_p(wp, offset, p_len, p_code);
        }

        if (wp->unpackable)
            return;
    }

    /* start the next block with empty statistics */
    for (idx = 0; idx < NC; idx++)
        wp->c_freq[idx] = 0;
    for (idx = 0; idx < wp->np; idx++)
        wp->p_freq[idx] = 0;
}

/*
  call with output(wp, c, 0)        c <  256
         or output(wp, len, off)  len >= 256


one block:

output_mask
             128  64   32             16   8              4    2    1
        +----+----+----+----+----+----+----+----+----+----+----+----+----+
buf     | 40 | c1 | c2 |len |   off   | c4 |len |   off   | c6 | c7 | c8 |
        +----+----+----+----+----+----+----+----+----+----+----+----+----+
        cpos                                                            /
                                                                       /
                                                               output_pos
   buf[cpos] = 32 + 8 (the set bits mark the <len, off> entries)

c_freq[] has frequency of cN and len.
p_freq[] has frequency of length of off bits.
*/
/*
 * Append one entry to the block buffer and update the statistics.
 *
 * c below 1 << CHAR_BIT is a literal character and p is ignored; otherwise
 * c is a match length code and p the match offset, stored as two bytes
 * (high byte first) after the low byte of c.
 *
 * Entries are grouped CHAR_BIT at a time behind a flag byte at
 * wp->buf[wp->cpos].  Before a new group is opened the buffer is flushed
 * with send_block() if a full group might not fit; the function returns
 * early if that flush leaves wp->unpackable set.
 */
void
output(struct lzh_ostream *wp, uint16_t c, uint32_t p)
{
    int nbits;

    wp->output_mask >>= 1;
    if (wp->output_mask == 0) {
        /* open a new group */
        wp->output_mask = 128;
        /* a group needs at most 3 * CHAR_BIT + 1 bytes */
        if (wp->output_pos > wp->bufsiz - (3 * CHAR_BIT + 1)) {
            send_block(wp);
            if (wp->unpackable)
                return;
            wp->output_pos = 0;
        }
        wp->cpos = wp->output_pos;
        wp->output_pos++;
        wp->buf[wp->cpos] = 0;
    }

    /* character, or low byte of the length code */
    wp->buf[wp->output_pos++] = (uint8_t) c;
    wp->c_freq[c]++;

    if (c < (1U << CHAR_BIT))
        return;

    /* flag this entry as a <length, offset> pair and store the offset */
    wp->buf[wp->cpos] |= wp->output_mask;
    wp->buf[wp->output_pos++] = (uint8_t) (p >> CHAR_BIT);
    wp->buf[wp->output_pos++] = (uint8_t) p;

    /* offsets are counted by their bit length (at most np - 1) */
    for (nbits = 0; p != 0; p >>= 1)
        nbits++;
    wp->p_freq[nbits]++;
}

/*
 * Copy the method-dependent offset coding parameters into the encoder
 * state: pbit is taken from the method as is, np is dicbit + 1.
 */
static void
init_enc_parameter(struct lzh_ostream *wp, struct lha_method *m)
{
    wp->pbit = m->pbit;
    wp->np = m->dicbit + 1;
}

/*
 * Prepare wp for encoding with method m.
 *
 * Resets the buffer positions, allocates the block buffer if wp->buf is not
 * set yet (starting at 16 KiB and retrying with 9/10 of the size after each
 * failed malloc; error() is called once the size falls below 4 KiB), clears
 * the frequency tables and initialises the bit writer.
 */
void
huf_encode_start(struct lzh_ostream *wp, struct lha_method *m)
{
    int k;

    init_enc_parameter(wp, m);
    wp->cpos = 0;
    wp->output_mask = 0;
    wp->output_pos = 0;

    if (wp->buf == NULL) {
        wp->bufsiz = 16 * 1024U;
        for (;;) {
            wp->buf = malloc(wp->bufsiz);
            if (wp->buf != NULL)
                break;
            wp->bufsiz = (wp->bufsiz / 10U) * 9U;
            if (wp->bufsiz < 4 * 1024U)
                error("Out of memory.");
        }
    }
    wp->buf[0] = 0;

    for (k = 0; k < NC; k++)
        wp->c_freq[k] = 0;
    for (k = 0; k < wp->np; k++)
        wp->p_freq[k] = 0;

    init_putbits(wp);
}

/*
 * Finish encoding: unless wp->unpackable is set, write the last block and
 * pad the output with CHAR_BIT - 1 zero bits so the final partial byte is
 * flushed.  The block buffer is released in either case.
 */
void
huf_encode_end(struct lzh_ostream *wp)
{
    if (!wp->unpackable) {
        send_block(wp);
        putbits(wp, CHAR_BIT - 1, 0);
    }

    /* free(NULL) is a no-op, so no guard is needed */
    free(wp->buf);
    wp->buf = NULL;
}

/***** decoding *****/

/*
 * Read a table of code lengths written by write_pt_len() and build the
 * decoding table for it.
 *
 *   nn         number of entries in pt_len[]
 *   nbit       width of the entry-count field (and of the lone symbol)
 *   i_special  after this many entries a 2-bit zero-run count follows;
 *              -1 disables it
 *   pt_len     receives the nn code lengths
 *   pt_table   256-entry lookup table indexed by the next 8 input bits
 *   left/right tree links filled by make_table() for longer codes
 *
 * An entry count of 0 means the alphabet has a single symbol; it is read
 * from the next nbit bits and stored in every pt_table[] slot.
 *
 * The entry count and the zero-run count come from the compressed stream,
 * so both are limited to nn here: a corrupt or hostile archive must not be
 * able to write past the end of pt_len[].
 *
 * NOTE(review): the lone symbol read in the single-symbol case is stored
 * without a range check; confirm callers cope with values >= nn.
 */
static void
read_pt_len(struct lzh_istream *rp, int nn, int nbit,
            int i_special, uint8_t *pt_len, uint16_t *pt_table,
            uint16_t *left, uint16_t *right)
{
    int i, c, n;
    unsigned int mask;

    n = getbits(rp, nbit);
    if (n == 0) {
        c = getbits(rp, nbit);
        for (i = 0; i < nn; i++)
            pt_len[i] = 0;
        for (i = 0; i < 256; i++)
            pt_table[i] = c;
    }
    else {
        /* never trust the stream for more entries than pt_len[] holds */
        if (n > nn)
            n = nn;
        i = 0;
        while (i < n) {
            /* lengths 0..6 are 3 bits; 7 and up continue with one bits */
            c = rp->bitbuf >> (BITBUFSIZ - 3);
            if (c == 7) {
                mask = 1U << (BITBUFSIZ - 1 - 3);
                while (mask & rp->bitbuf) {
                    mask >>= 1;
                    c++;
                }
            }
            fillbuf(rp, (c < 7) ? 3 : c - 3);
            pt_len[i++] = c;
            if (i == i_special) {
                /* zero run; stop at the end of the array */
                c = getbits(rp, 2);
                while (--c >= 0 && i < nn)
                    pt_len[i++] = 0;
            }
        }
        while (i < nn)
            pt_len[i++] = 0;
        make_table(nn, pt_len, 8, pt_table, left, right);
    }
}

/*
 * Read the code lengths of the character/length alphabet into rp->c_len[]
 * and build rp->c_table (indexed by the next 12 input bits) together with
 * the rp->c_left / rp->c_right tree links.
 *
 * The lengths are themselves coded with the code described by t_len[],
 * t_table[] (indexed by the next 8 input bits) and t_left[] / t_right[]:
 *   symbol 0      one zero length
 *   symbol 1      4 more bits + 3 zero lengths
 *   symbol 2      CBIT more bits + 20 zero lengths
 *   symbol k > 2  code length k - 2
 *
 * An entry count of 0 means the alphabet has a single symbol; it is read
 * from the next CBIT bits and stored in every rp->c_table[] slot.
 *
 * The entry count and the run lengths come from the compressed stream, so
 * both are limited to NC here: a corrupt or hostile archive must not be
 * able to write past the end of rp->c_len[].
 */
static void
read_c_len(struct lzh_istream *rp, uint8_t *t_len, uint16_t *t_table,
           uint16_t *t_left, uint16_t *t_right)
{
    int i, c, n;
    unsigned int mask;

    n = getbits(rp, CBIT);
    if (n == 0) {
        c = getbits(rp, CBIT);
        for (i = 0; i < NC; i++)
            rp->c_len[i] = 0;
        for (i = 0; i < 4096; i++)
            rp->c_table[i] = c;
    }
    else {
        /* never trust the stream for more entries than c_len[] holds */
        if (n > NC)
            n = NC;
        i = 0;
        while (i < n) {
            c = t_table[rp->bitbuf >> (BITBUFSIZ - 8)];
            if (c >= NT) {
                /* code longer than 8 bits: follow the tree bit by bit */
                mask = 1U << (BITBUFSIZ - 1 - 8);
                do {
                    if (rp->bitbuf & mask)
                        c = t_right[c];
                    else
                        c = t_left[c];
                    mask >>= 1;
                } while (c >= NT);
            }
            fillbuf(rp, t_len[c]);
            if (c <= 2) {
                if (c == 0)
                    c = 1;
                else if (c == 1)
                    c = getbits(rp, 4) + 3;
                else
                    c = getbits(rp, CBIT) + 20;
                /* zero run; stop at the end of the array */
                while (--c >= 0 && i < NC)
                    rp->c_len[i++] = 0;
            }
            else
                rp->c_len[i++] = c - 2;
        }
        while (i < NC)
            rp->c_len[i++] = 0;
        make_table(NC, rp->c_len, 12, rp->c_table, rp->c_left, rp->c_right);
    }
}

/*
 * Decode and return the next character/length symbol.
 *
 * When the current block is exhausted (rp->blocksize is 0) the next block
 * header is read first: the 16-bit entry count followed by the three
 * code-length tables.  Every call consumes one entry of the block.
 */
uint16_t
decode_c(struct lzh_istream *rp)
{
    uint16_t sym, bit;

    if (rp->blocksize == 0) {
        /* the code for c_len[] is only needed while reading the header */
        uint8_t t_len[NT];
        uint16_t t_table[256];
        uint16_t t_left[2*NT-1], t_right[2*NT-1];

        rp->blocksize = getbits(rp, 16);
        read_pt_len(rp, NT, TBIT, 3, t_len, t_table, t_left, t_right);
        read_c_len(rp, t_len, t_table, t_left, t_right);
        read_pt_len(rp, rp->np, rp->pbit, -1, rp->p_len, rp->p_table,
                    rp->p_left, rp->p_right);
    }
    rp->blocksize--;

    /* look up the first 12 bits, then walk the tree for longer codes */
    sym = rp->c_table[rp->bitbuf >> (BITBUFSIZ - 12)];
    for (bit = 1U << (BITBUFSIZ - 1 - 12); sym >= NC; bit >>= 1)
        sym = (rp->bitbuf & bit) ? rp->c_right[sym] : rp->c_left[sym];

    fillbuf(rp, rp->c_len[sym]);
    return sym;
}

/*
 * Decode and return the next match offset.
 *
 * The decoded symbol is the bit length of the offset.  Length 0 yields
 * offset 0; otherwise the offset is a leading one bit followed by
 * (length - 1) bits read verbatim from the input.
 */
uint16_t
decode_p(struct lzh_istream *rp)
{
    uint16_t nbits, bit;

    /* look up the first 8 bits, then walk the tree for longer codes */
    nbits = rp->p_table[rp->bitbuf >> (BITBUFSIZ - 8)];
    for (bit = 1U << (BITBUFSIZ - 1 - 8); nbits >= rp->np; bit >>= 1)
        nbits = (rp->bitbuf & bit) ? rp->p_right[nbits] : rp->p_left[nbits];

    fillbuf(rp, rp->p_len[nbits]);
    if (nbits == 0)
        return nbits;

    nbits = (1U << (nbits - 1)) + getbits(rp, nbits - 1);
    return nbits;
}

/*
 * Copy the method-dependent offset coding parameters into the decoder
 * state: pbit is taken from the method as is, np is dicbit + 1.
 */
static void
init_dec_parameter(struct lzh_istream *rp, struct lha_method *m)
{
    rp->pbit = m->pbit;
    rp->np = m->dicbit + 1;
}

/*
 * Prepare rp for decoding with method m: set the offset coding parameters
 * and initialise the bit reader.
 */
void
huf_decode_start(struct lzh_istream *rp, struct lha_method *m)
{
    init_dec_parameter(rp, m);
    init_getbits(rp);
    rp->blocksize = 0;          /* first decode_c() call reads a block header */
}
