// Copyright 2017 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// SPDX-License-Identifier: Apache-2.0 OR MIT

// decode_huffman_fast64 decodes Huffman-coded symbols from args.src (using
// the this.huffs tables) and writes the resulting literal bytes and
// length-distance copies to args.dst.
//
// It loops only while args.dst has at least 266 bytes of space and args.src
// has at least 8 bytes available, so that no I/O operation inside the loop
// body needs its own bounds check. It leaves the loop when either condition
// no longer holds or when an end-of-block symbol is decoded (in which case
// this.end_of_block is set to true).
//
// On entry, the pending bits are taken from this.bits and this.n_bits (with
// this.n_bits required to be less than 8). After the loop, args.src is rewound
// by whole bytes until fewer than 8 bits are pending, and those are stored
// back into this.bits and this.n_bits.
//
// It returns an error status for a bad Huffman code, a bad distance, a bad
// I/O position or an internal inconsistency.
//
// TODO: describe how the xxx_fastxx version differs from the xxx_slow one, the
// assumptions that xxx_fastxx makes, and how that makes it fast.
pri func decoder.decode_huffman_fast64!(dst: base.io_writer, src: base.io_reader) base.status,
        choosy,
{
    // When editing this function, consider making the equivalent change to the
    // decode_huffman_slow function. Keep the diff between the two
    // decode_huffman_*.wuffs files as small as possible, while retaining both
    // correctness and performance.

    // bits and n_bits are local working copies of this.bits and this.n_bits.
    // lmask and dmask select the low bits used to index the first-level
    // entries of this.huffs[0] (H-L) and this.huffs[1] (H-D) respectively.
    var bits               : base.u64
    var n_bits             : base.u32
    var table_entry        : base.u32
    var table_entry_n_bits : base.u32[..= 15]
    var lmask              : base.u64[..= 511]
    var dmask              : base.u64[..= 511]
    var redir_top          : base.u32[..= 0xFFFF]
    var redir_mask         : base.u32[..= 0x7FFF]
    var length             : base.u32[..= 258]
    var dist_minus_1       : base.u32[..= 0x7FFF]
    var hlen               : base.u32[..= 0x7FFF]
    var hdist              : base.u32
    var hdist_adjustment   : base.u32

    // Entry invariant: fewer than 8 pending bits, and no stray bits set above
    // the low n_bits of this.bits. The "& 7" does not change the shift amount
    // whenever the first clause is false (n_bits < 8).
    if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) <> 0) {
        return "#internal error: inconsistent n_bits"
    }

    bits = this.bits as base.u64
    n_bits = this.n_bits

    lmask = ((1 as base.u64) << this.n_huffs_bits[0]) - 1
    dmask = ((1 as base.u64) << this.n_huffs_bits[1]) - 1

    // hdist_adjustment is how far this.transformed_history_count is ahead of
    // args.dst.history_position(). It is added to hdist below, before hdist
    // is used to index backwards from this.history_index.
    if this.transformed_history_count < args.dst.history_position() {
        return base."#bad I/O position"
    }
    hdist_adjustment = ((this.transformed_history_count - args.dst.history_position()) & 0xFFFF_FFFF) as base.u32

    // Check up front, on each iteration, that we have enough buffer space to
    // both read (8 bytes) and write (266 bytes) as much as we need to. Doing
    // this check once (per iteration), up front, removes the need to check
    // multiple times inside the loop body, so it's faster overall.
    //
    // For writing, a literal code obviously corresponds to writing 1 byte, and
    // 258 is the maximum length in a length-distance pair, as specified in the
    // RFC section 3.2.5. Compressed blocks (length and distance codes).
    //
    // We adjust 258 up to (258 + 8) = 266 since it can be faster to overshoot
    // a little and make multiple-of-8-byte copies even when the length in the
    // length-distance pair isn't an exact multiple-of-8. Strictly speaking,
    // 264 (an exact multiple-of-8) is the tight bound but (258 + 8) is easier
    // for the Wuffs proof system, as length's type is refined to [..= 258].
    //
    // For reading, strictly speaking, we only need 6 bytes (48 bits) of
    // available input, because the H-L Literal/Length code is up to 15 bits
    // plus up to 5 extra bits, the H-D Distance code is up to 15 bits plus up
    // to 13 extra bits and 15 + 5 + 15 + 13 == 48. However, it's faster to
    // read 64 bits than 48 bits or (48 - n_bits) bits.
    while.loop(args.dst.length() >= 266) and (args.src.length() >= 8) {
        // Ensure that we have at least 56 bits of input.
        //
        // This is "Variant 4" of
        // https://fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/
        //
        // 56, the number of bits in 7 bytes, is a property of that "Variant 4"
        // bit-reading technique, and not related to the DEFLATE format per se.
        //
        // The "& 63" is a no-op, and not part of the original "Variant 4"
        // technique, but satisfies Wuffs' overflow/underflow checks.
        //
        // The three statements below: OR the next 8 source bytes in above the
        // bits we already hold, advance the reader by ((63 - n_bits) >> 3)
        // bytes and then mark n_bits as being at least 56.
        bits |= args.src.peek_u64le() ~mod<< (n_bits & 63)
        args.src.skip_u32_fast!(actual: (63 - (n_bits & 63)) >> 3, worst_case: 8)
        n_bits |= 56

        // Decode an lcode symbol from H-L.
        //
        // As used below, a table_entry's low 4 bits are the number of input
        // bits to consume and its high bits flag the kind of entry:
        //  - bit 31: literal (the byte value is in bits 8 ..= 15).
        //  - bit 30: base number plus extra bits (bits 8 and up hold the
        //    base number, bits 4 ..= 7 hold the extra bit count).
        //  - bit 29: end of block.
        //  - bit 28: redirect to a second-level table (bits 8 ..= 23 hold the
        //    second-level offset, bits 4 ..= 7 hold its index bit count).
        //  - bit 27: bad Huffman code.
        table_entry = this.huffs[0][bits & lmask]
        table_entry_n_bits = table_entry & 0x0F
        bits >>= table_entry_n_bits
        n_bits ~mod-= table_entry_n_bits

        if (table_entry >> 31) <> 0 {
            // Literal.
            args.dst.write_u8_fast!(a: ((table_entry >> 8) & 0xFF) as base.u8)
            continue.loop
        } else if (table_entry >> 30) <> 0 {
            // No-op; code continues past the if-else chain, where this entry
            // is decoded as a length (base number plus extra bits).
        } else if (table_entry >> 29) <> 0 {
            // End of block.
            this.end_of_block = true
            break.loop
        } else if (table_entry >> 28) <> 0 {
            // Redirect. Look up a second-level entry at (redir_top + the next
            // few input bits), then classify it just like a first-level
            // entry, except that a second redirect is an internal error.
            redir_top = (table_entry >> 8) & 0xFFFF
            redir_mask = ((1 as base.u32) << ((table_entry >> 4) & 0x0F)) - 1
            table_entry = this.huffs[0][
                    (redir_top + (((bits & 0xFFFF_FFFF) as base.u32) & redir_mask)) & HUFFS_TABLE_MASK]
            table_entry_n_bits = table_entry & 0x0F
            bits >>= table_entry_n_bits
            n_bits ~mod-= table_entry_n_bits

            if (table_entry >> 31) <> 0 {
                // Literal.
                args.dst.write_u8_fast!(a: ((table_entry >> 8) & 0xFF) as base.u8)
                continue.loop
            } else if (table_entry >> 30) <> 0 {
                // No-op; code continues past the if-else chain.
            } else if (table_entry >> 29) <> 0 {
                // End of block.
                this.end_of_block = true
                break.loop
            } else if (table_entry >> 28) <> 0 {
                // A redirect must not lead to another redirect.
                return "#internal error: inconsistent Huffman decoder state"
            } else if (table_entry >> 27) <> 0 {
                return "#bad Huffman code"
            } else {
                return "#internal error: inconsistent Huffman decoder state"
            }

        } else if (table_entry >> 27) <> 0 {
            return "#bad Huffman code"
        } else {
            return "#internal error: inconsistent Huffman decoder state"
        }

        // length = base_number_minus_3 + 3 + extra_bits.
        //
        // The -3 is from the bias in script/print-deflate-magic-numbers.go.
        // That bias makes the "& 0xFF" 1 and 15-ish lines below correct.
        length = ((table_entry >> 8) & 0xFF) + 3
        table_entry_n_bits = (table_entry >> 4) & 0x0F
        if table_entry_n_bits > 0 {
            // The "+ 253" is the same as "- 3", after the "& 0xFF", but the
            // plus form won't require an underflow check.
            length = ((length + 253 + (bits.low_bits(n: table_entry_n_bits) as base.u32)) & 0xFF) + 3
            bits >>= table_entry_n_bits
            n_bits ~mod-= table_entry_n_bits
        }
        // Either assignment to length above ends in "+ 3", so this holds. It
        // is stated explicitly because the inner while loop below lists it as
        // a precondition.
        assert length >= 1

        // Decode a dcode symbol from H-D.
        table_entry = this.huffs[1][bits & dmask]
        table_entry_n_bits = table_entry & 15
        bits >>= table_entry_n_bits
        n_bits ~mod-= table_entry_n_bits

        // Check for a redirect (bit 28 set and bits 29 ..= 31 all clear).
        if (table_entry >> 28) == 1 {
            redir_top = (table_entry >> 8) & 0xFFFF
            redir_mask = ((1 as base.u32) << ((table_entry >> 4) & 0x0F)) - 1
            table_entry = this.huffs[1][
                    (redir_top + (((bits & 0xFFFF_FFFF) as base.u32) & redir_mask)) & HUFFS_TABLE_MASK]
            table_entry_n_bits = table_entry & 0x0F
            bits >>= table_entry_n_bits
            n_bits ~mod-= table_entry_n_bits
        }

        // For H-D, all symbols should be base_number + extra_bits.
        //
        // A top byte of 0x40 means that bit 30 is the only flag bit set. A top
        // byte of 0x08 means that bit 27 (bad Huffman code) is the only one.
        if (table_entry >> 24) <> 0x40 {
            if (table_entry >> 24) == 0x08 {
                return "#bad Huffman code"
            }
            return "#internal error: inconsistent Huffman decoder state"
        }

        // dist_minus_1 = base_number_minus_1 + extra_bits.
        // distance     = dist_minus_1 + 1.
        //
        // The -1 is from the bias in script/print-deflate-magic-numbers.go.
        // That bias makes the "& 0x7FFF" 2 and 15-ish lines below correct and
        // undoing that bias makes proving (dist_minus_1 + 1) > 0 trivial.
        dist_minus_1 = (table_entry >> 8) & 0x7FFF
        table_entry_n_bits = (table_entry >> 4) & 0x0F

        dist_minus_1 = (dist_minus_1 + (bits.low_bits(n: table_entry_n_bits) as base.u32)) & 0x7FFF
        bits >>= table_entry_n_bits
        n_bits ~mod-= table_entry_n_bits

        // The "while true { etc; break }" is a redundant version of "etc", but
        // its presence minimizes the diff between decode_huffman_fastxx and
        // decode_huffman_slow.
        while true,
                pre args.dst.length() >= 266,
                pre length >= 1,
        {
            // We can therefore prove:
            assert (length as base.u64) <= args.dst.length() via "a <= b: a <= c; c <= b"(c: 266)
            assert ((length + 8) as base.u64) <= args.dst.length() via "a <= b: a <= c; c <= b"(c: 266)

            // Copy from this.history.
            //
            // This branch is taken when the distance reaches further back
            // than args.dst.history_length().
            if ((dist_minus_1 + 1) as base.u64) > args.dst.history_length() {
                // Set (hlen, hdist) to be the length-distance pair to copy
                // from this.history, and (length, distance) to be the
                // remaining length-distance pair to copy from args.dst.
                hlen = 0
                hdist = (((dist_minus_1 + 1) as base.u64) - args.dst.history_length()) as base.u32
                if length > hdist {
                    // The copy straddles this.history and args.dst: take
                    // hdist bytes from this.history and leave the remaining
                    // (length - hdist) bytes to come from args.dst.
                    assert hdist < length via "a < b: b > a"()
                    assert hdist < 0x8000 via "a < b: a < c; c <= b"(c: length)
                    length -= hdist
                    hlen = hdist
                } else {
                    // The whole copy comes from this.history.
                    hlen = length
                    length = 0
                }
                hdist ~mod+= hdist_adjustment
                if this.history_index < hdist {
                    return "#bad distance"
                }

                // Copy from this.history[(this.history_index - hdist) ..].
                //
                // This copying is simpler than the decode_huffman_slow version
                // because it cannot yield. We have already checked that
                // args.dst.length() is large enough.
                args.dst.limited_copy_u32_from_slice!(
                        up_to: hlen, s: this.history[(this.history_index - hdist) & 0x7FFF ..])
                if length == 0 {
                    // No need to copy from args.dst.
                    continue.loop
                }
                assert length >= 1

                // NOTE(review): presumably this run-time re-check restores,
                // for the proof system, the facts that the copy above
                // invalidated (the assertions after this if-block rely on
                // them) - confirm.
                if (((dist_minus_1 + 1) as base.u64) > args.dst.history_length()) or
                        ((length as base.u64) > args.dst.length()) or
                        (((length + 8) as base.u64) > args.dst.length()) {
                    return "#internal error: inconsistent distance"
                }
            }
            // Once again, redundant but explicit assertions.
            assert (dist_minus_1 + 1) >= 1
            assert ((dist_minus_1 + 1) as base.u64) <= args.dst.history_length()
            assert length >= 1
            assert (length as base.u64) <= args.dst.length()
            assert ((length + 8) as base.u64) <= args.dst.length()

            // Copy from args.dst.
            //
            // For short distances, less than 8 bytes, copying atomic 8-byte
            // chunks can result in incorrect output, so we fall back to a
            // slower 1-byte-at-a-time copy. Deflate's copy-from-history can
            // pick up freshly written bytes. A length = 5, distance = 2 copy
            // starting with "abc" should give "abcbcbcb" but the 8-byte chunk
            // technique would give "abcbc???", the exact output depending on
            // what was previously in the writer buffer.
            if (dist_minus_1 + 1) >= 8 {
                args.dst.limited_copy_u32_from_history_8_byte_chunks_fast!(
                        up_to: length, distance: (dist_minus_1 + 1))
            } else if (dist_minus_1 + 1) == 1 {
                // (distance == 1) is essentially RLE (Run Length Encoding). It
                // happens often enough that it's worth special-casing.
                args.dst.limited_copy_u32_from_history_8_byte_chunks_distance_1_fast!(
                        up_to: length, distance: (dist_minus_1 + 1))
            } else {
                args.dst.limited_copy_u32_from_history_fast!(
                        up_to: length, distance: (dist_minus_1 + 1))
            }
            break
        } endwhile
    } endwhile.loop

    // Ensure n_bits < 8 by rewinding args.src, if we loaded too many of its
    // bytes into the bits variable.
    //
    // Note that we can unconditionally call undo_read (without resulting in an
    // "invalid I/O operation" error code) only because this whole function can
    // never suspend, as all of its I/O operations were checked beforehand for
    // sufficient buffer space. Otherwise, resuming from the suspension could
    // mean that the (possibly different) args.src is no longer rewindable,
    // even if conceptually, this function was responsible for reading the
    // bytes we want to rewind.
    if n_bits > 63 {
        return "#internal error: inconsistent n_bits"
    }
    // Give back one source byte for every 8 surplus bits.
    while n_bits >= 8,
            post n_bits < 8,
    {
        n_bits -= 8
        if args.src.can_undo_byte() {
            args.src.undo_byte!()
        } else {
            return "#internal error: inconsistent I/O"
        }
    } endwhile

    // Keep only the low n_bits bits, discarding the bits that belong to the
    // bytes that were just given back to args.src.
    this.bits = (bits & (((1 as base.u64) << n_bits) - 1)) as base.u32
    this.n_bits = n_bits

    // Exit invariant, mirroring the check made on entry.
    if (this.n_bits >= 8) or ((this.bits >> this.n_bits) <> 0) {
        return "#internal error: inconsistent n_bits"
    }
}
