// Copyright 2017 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// "#bad code" is returned by decoder.transform_io when read_from reports a
// decoded code that is greater than save_code.
pub status "#bad code"
// "#truncated input" is returned by decoder.transform_io when read_from
// reports that the source has no more bytes and is closed.
pub status "#truncated input"

// "#internal error: inconsistent I/O" is returned by decoder.transform_io for
// any read_from result other than the expected ones, and by decoder.write_to
// when output_ri is greater than output_wi.
pri status "#internal error: inconsistent I/O"

// TODO: move bulk data buffers like decoder.suffixes or decoder.output into
// the workbuf? The first attempt at this was a performance regression for
// decoding all but the smallest GIFs. See these git commits for numbers:
//  - 49627b4 Flatten the lzw.decoder.suffixes array
//  - f877fb2 Use the workbuf instead of lzw.decoder.suffixes
//  - 85be5b9 Delete the obsolete lzw.decoder.suffixes array
// and the roll back has combined numbers:
//  - 3056a84 Roll back 3 recent lzw.decoder.suffixes commits

// DECODER_WORKBUF_LEN_MAX_INCL_WORST_CASE is 0. This is consistent with
// decoder.workbuf_len returning the range 0 ..= 0 and with
// decoder.transform_io never reading or writing its workbuf argument.
pub const DECODER_WORKBUF_LEN_MAX_INCL_WORST_CASE : base.u64 = 0

// decoder holds the state of an LZW decoder. It implements
// base.io_transformer: transform_io reads LZW codes from a source and writes
// the decoded bytes to a destination, buffering them in the output array.
pub struct decoder? implements base.io_transformer(
        // set_literal_width_arg is 1 plus the saved argument passed to
        // set_literal_width. This is assigned to the literal_width field at the
        // start of transform_io. During that method, calling set_literal_width
        // will change set_literal_width_arg but not literal_width.
        //
        // A value of 0 means that set_literal_width has not supplied a value,
        // in which case transform_io uses a literal width of 8.
        set_literal_width_arg : base.u32[..= 9],

        // read_from state that does not change during a decode call.
        //
        // transform_io sets clear_code to (1 << literal_width) and end_code to
        // (clear_code + 1).
        literal_width : base.u32[..= 8],
        clear_code    : base.u32[..= 256],
        end_code      : base.u32[..= 257],

        // read_from state that does change during a decode call.
        //
        //  - save_code is the index of the next table entry (in lm1s, prefixes
        //    and suffixes) that read_from fills in.
        //  - prev_code is the code that the next table entry extends.
        //  - width is the number of bits that read_from takes for each code.
        //  - bits holds source bits not yet consumed; its low n_bits bits are
        //    the ones read_from counts as available.
        //  - output_ri and output_wi are the read and write indexes into the
        //    output array.
        save_code : base.u32[..= 4096],
        prev_code : base.u32[..= 4095],
        width     : base.u32[..= 12],
        bits      : base.u32,
        n_bits    : base.u32[..= 31],
        output_ri : base.u32[..= 8191],
        output_wi : base.u32[..= 8191],

        // read_from return value. The read_from method effectively returns a
        // base.u32 to show how decode should continue after calling write_to. That
        // value needs to be saved across write_to's possible suspension, so we
        // might as well save it explicitly as a decoder field.
        //
        // As interpreted by transform_io: 0 means stop (the end code was seen),
        // 1 means call read_from again, 2 means yield "$short read", 3 means
        // "#truncated input", 4 means "#bad code" and anything else means
        // "#internal error: inconsistent I/O".
        read_from_return_value : base.u32,

        // read_from per-code state.
        //
        // When a code's value is longer than 8 bytes, prefixes[code] is the
        // code whose suffixes entry holds the preceding 8 bytes.
        prefixes : array[4096] base.u16[..= 4095],

        util : base.utility,
) + (
        // read_from per-code state.
        //
        // suffixes[code] holds the final (up to 8 byte) chunk of the code's
        // value. read_from always copies all 8 bytes of a chunk to the output,
        // even when only some of them belong to the value.
        suffixes : array[4096] array[8] base.u8,
        // lm1s is the "length minus 1"s of the values for the implicit key-value
        // table in this decoder. See std/lzw/README.md for more detail.
        lm1s : array[4096] base.u16,

        // output[output_ri:output_wi] is the buffered output, connecting read_from
        // with write_to and flush.
        //
        // The extra 7 bytes beyond 8192 leave room for read_from's 8-byte chunk
        // copies, which can start at any index up to 8191.
        output : array[8192 + 7] base.u8,
)

// set_quirk_enabled has an empty body: both arguments are ignored and no
// decoder state is changed.
pub func decoder.set_quirk_enabled!(quirk: base.u32, enabled: base.bool) {
}

// set_literal_width records the literal width that the next transform_io call
// should use.
//
// The argument is stored biased by one (lw + 1), so that a stored value of 0
// can stand for "no width was set", which transform_io treats as a width of 8.
// The new width takes effect only at the start of the next transform_io call.
pub func decoder.set_literal_width!(lw: base.u32[..= 8]) {
    var biased : base.u32[..= 9]

    biased = args.lw + 1
    this.set_literal_width_arg = biased
}

// workbuf_len returns the inclusive range 0 ..= 0: both the minimum and the
// maximum work buffer length reported by this decoder are zero.
pub func decoder.workbuf_len() base.range_ii_u64 {
    return this.util.make_range_ii_u64(min_incl: 0, max_incl: 0)
}

// transform_io decodes LZW-compressed data from args.src and writes the
// decoded bytes to args.dst. The args.workbuf argument is not used.
//
// It first resets the decoding state, then alternates between read_from (which
// decodes into this.output) and write_to (which drains this.output to
// args.dst) until read_from reports a reason to stop.
//
// It returns "#truncated input" if the source is closed before the end code is
// seen, "#bad code" if an out-of-range code is decoded, and
// "#internal error: inconsistent I/O" for any unexpected read_from result. It
// yields "$short read" when the source is empty but not closed.
pub func decoder.transform_io?(dst: base.io_writer, src: base.io_reader, workbuf: slice base.u8) {
    var i : base.u32[..= 8191]

    // Initialize read_from state.
    //
    // The literal width defaults to 8. A non-zero set_literal_width_arg
    // overrides that default, after removing the "plus 1" bias that
    // set_literal_width applied.
    this.literal_width = 8
    if this.set_literal_width_arg > 0 {
        this.literal_width = this.set_literal_width_arg - 1
    }
    this.clear_code = (1 as base.u32) << this.literal_width
    this.end_code = this.clear_code + 1
    this.save_code = this.end_code
    this.prev_code = this.end_code
    this.width = this.literal_width + 1
    this.bits = 0
    this.n_bits = 0
    this.output_ri = 0
    this.output_wi = 0
    // Seed the table entries for the literal codes: each one is a single byte
    // (a "length minus 1" of 0) whose value equals the code itself.
    i = 0
    while i < this.clear_code {
        assert i < 256 via "a < b: a < c; c <= b"(c: this.clear_code)
        this.lm1s[i] = 0
        this.suffixes[i][0] = i as base.u8
        i += 1
    } endwhile

    while true {
        this.read_from!(src: args.src)

        // Drain whatever read_from produced before acting on its result. The
        // result is kept in a field, so it is still available after write_to.
        if this.output_wi > 0 {
            this.write_to?(dst: args.dst)
        }

        if this.read_from_return_value == 0 {
            // The end code was decoded.
            break
        } else if this.read_from_return_value == 1 {
            // The output buffer needed draining; decode some more.
            continue
        } else if this.read_from_return_value == 2 {
            // The source is empty but not closed. After the yield, the loop
            // goes around and calls read_from again.
            yield? base."$short read"
        } else if this.read_from_return_value == 3 {
            return "#truncated input"
        } else if this.read_from_return_value == 4 {
            return "#bad code"
        } else {
            return "#internal error: inconsistent I/O"
        }
    } endwhile
}

// read_from decodes LZW codes from args.src into this.output, stopping when it
// cannot or should not continue. The reason for stopping is stored in
// this.read_from_return_value:
//  - 0: the end code was decoded.
//  - 1: output_wi exceeded 4095, so the output buffer should be drained
//       before decoding more.
//  - 2: args.src has no more bytes but is not closed.
//  - 3: args.src has no more bytes and is closed.
//  - 4: a code greater than save_code was decoded.
//  - 5: an internal inconsistency (a bit count that should be impossible, or
//       a surplus byte that could not be given back to args.src).
//
// The decoder fields that change during decoding are copied into local
// variables on entry and written back on exit.
pri func decoder.read_from!(src: base.io_reader) {
    var clear_code : base.u32[..= 256]
    var end_code   : base.u32[..= 257]

    var save_code : base.u32[..= 4096]
    var prev_code : base.u32[..= 4095]
    var width     : base.u32[..= 12]
    var bits      : base.u32
    var n_bits    : base.u32[..= 31]
    var output_wi : base.u32[..= 8191]

    var code       : base.u32[..= 4095]
    var c          : base.u32[..= 4095]
    var o          : base.u32[..= 8191]
    var steps      : base.u32
    var first_byte : base.u8
    var lm1_b      : base.u16[..= 4095]
    var lm1_a      : base.u16[..= 4095]

    clear_code = this.clear_code
    end_code = this.end_code

    save_code = this.save_code
    prev_code = this.prev_code
    width = this.width
    bits = this.bits
    n_bits = this.n_bits
    output_wi = this.output_wi

    while true {
        // Refill the bit buffer until it holds at least one whole code.
        if n_bits < width {
            assert n_bits < 12 via "a < b: a < c; c <= b"(c: width)
            if args.src.length() >= 4 {
                // Read 4 bytes, using the "Variant 4" technique of
                // https://fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/
                //
                // With n_bits in 0 ..= 7, this consumes 3 bytes and the
                // "|= 24" adds 24 bits. With n_bits in 8 ..= 11, it consumes 2
                // bytes and the "|= 24" adds 16 bits.
                bits |= args.src.peek_u32le() ~mod<< n_bits
                args.src.skip_u32_fast!(actual: (31 - n_bits) >> 3, worst_case: 3)
                n_bits |= 24
                assert width <= n_bits via "a <= b: a <= c; c <= b"(c: 12)
                assert n_bits >= width via "a >= b: b <= a"()
            } else if args.src.length() <= 0 {
                if args.src.is_closed() {
                    this.read_from_return_value = 3
                } else {
                    this.read_from_return_value = 2
                }
                break
            } else {
                // Between 1 and 3 bytes are available: read one byte at a
                // time, at most twice.
                bits |= args.src.peek_u8_as_u32() << n_bits
                args.src.skip_u32_fast!(actual: 1, worst_case: 1)
                n_bits += 8
                if n_bits >= width {
                    // No-op.
                } else if args.src.length() <= 0 {
                    if args.src.is_closed() {
                        this.read_from_return_value = 3
                    } else {
                        this.read_from_return_value = 2
                    }
                    break
                } else {
                    bits |= args.src.peek_u8_as_u32() << n_bits
                    args.src.skip_u32_fast!(actual: 1, worst_case: 1)
                    n_bits += 8
                    assert width <= n_bits via "a <= b: a <= c; c <= b"(c: 12)
                    assert n_bits >= width via "a >= b: b <= a"()

                    // This if condition is always false, but for some unknown
                    // reason, removing it worsens the benchmarks slightly.
                    if n_bits < width {
                        this.read_from_return_value = 5
                        break
                    }
                }
            }
        }

        // Take the next code from the low bits of the bit buffer.
        code = bits.low_bits(n: width)
        bits >>= width
        n_bits -= width

        if code < clear_code {
            // A literal code: it decodes to the single byte equal to the code.
            assert code < 256 via "a < b: a < c; c <= b"(c: clear_code)
            this.output[output_wi] = code as base.u8
            output_wi = (output_wi + 1) & 8191
            if save_code <= 4095 {
                // Fill in table entry save_code: it is one byte longer than
                // prev_code's entry, ending with this literal byte.
                lm1_a = (this.lm1s[prev_code] ~mod+ 1) & 4095
                this.lm1s[save_code] = lm1_a

                if (lm1_a % 8) <> 0 {
                    // The new byte fits in prev_code's final chunk: share its
                    // prefix, copy its chunk and fill in the next slot.
                    this.prefixes[save_code] = this.prefixes[prev_code]
                    this.suffixes[save_code] = this.suffixes[prev_code]
                    this.suffixes[save_code][lm1_a % 8] = code as base.u8
                } else {
                    // prev_code's final chunk is full: start a new chunk whose
                    // prefix is prev_code.
                    this.prefixes[save_code] = prev_code as base.u16
                    this.suffixes[save_code][0] = code as base.u8
                }

                save_code += 1
                // This adds 1 to width exactly when bit number width of
                // save_code is set.
                if width < 12 {
                    width += 1 & (save_code >> width)
                }
                prev_code = code
            }

        } else if code <= end_code {
            if code == end_code {
                this.read_from_return_value = 0
                break
            }
            // Otherwise code == clear_code: go back to the initial table
            // state and code width.
            save_code = end_code
            prev_code = end_code
            width = this.literal_width + 1

        } else if code <= save_code {
            // c is the table entry whose bytes are emitted: the code itself
            // or, when code == save_code, the previous code.
            c = code
            if code == save_code {
                c = prev_code
            }

            // Letting old_wi and new_wi denote the values of output_wi before
            // and after these two lines of code, the decoded bytes will be
            // written to output[old_wi:new_wi]. They will be written
            // back-to-front, 8 bytes at a time, starting by writing
            // output[o:o + 8], which will contain output[new_wi - 1].
            //
            // In the special case that code == save_code, the decoded bytes
            // contain an extra copy (at the end) of the first byte, and will
            // be written to output[old_wi:new_wi + 1].
            o = (output_wi + ((this.lm1s[c] as base.u32) & 0xFFFF_FFF8)) & 8191
            output_wi = (output_wi + 1 + (this.lm1s[c] as base.u32)) & 8191

            // steps is the number of 8-byte chunks that precede the final one.
            steps = (this.lm1s[c] as base.u32) >> 3
            while true {
                assert o <= (o + 8) via "a <= (a + b): 0 <= b"(b: 8)

                // The final "8" is redundant semantically, but helps the
                // wuffs-c code generator recognize that both slices have the
                // same constant length, and hence produce efficient C code.
                this.output[o .. o + 8].copy_from_slice!(s: this.suffixes[c][.. 8])

                if steps <= 0 {
                    break
                }
                steps -= 1

                // This line is essentially "o -= 8". The "& 8191" is a no-op
                // in practice, but is necessary for the overflow checker.
                o = (o ~mod- 8) & 8191
                c = this.prefixes[c] as base.u32
            } endwhile
            // After the loop, c refers to the first chunk of the decoded
            // bytes, so its first element is their first byte.
            first_byte = this.suffixes[c][0]

            if code == save_code {
                this.output[output_wi] = first_byte
                output_wi = (output_wi + 1) & 8191
            }

            if save_code <= 4095 {
                // Fill in table entry save_code, as in the literal case above,
                // but ending with first_byte.
                lm1_b = (this.lm1s[prev_code] ~mod+ 1) & 4095
                this.lm1s[save_code] = lm1_b

                if (lm1_b % 8) <> 0 {
                    this.prefixes[save_code] = this.prefixes[prev_code]
                    this.suffixes[save_code] = this.suffixes[prev_code]
                    this.suffixes[save_code][lm1_b % 8] = first_byte
                } else {
                    this.prefixes[save_code] = prev_code as base.u16
                    this.suffixes[save_code][0] = first_byte as base.u8
                }

                save_code += 1
                if width < 12 {
                    width += 1 & (save_code >> width)
                }
                prev_code = code
            }

        } else {
            // code > save_code: there is no table entry it could refer to.
            this.read_from_return_value = 4
            break
        }

        // Flush the output if it could be too full to contain the entire
        // decoding of the next code. The longest possible decoding is slightly
        // less than 4096 and output's length is 8192, so a conservative
        // threshold is ensuring that output_wi <= 4095.
        if output_wi > 4095 {
            this.read_from_return_value = 1
            break
        }
    } endwhile

    // Rewind args.src, if we're not in "$short read" and we've read too many
    // bits.
    //
    // Each whole byte still counted in n_bits is given back to args.src. If a
    // byte cannot be given back, the result becomes 5.
    if this.read_from_return_value <> 2 {
        while n_bits >= 8 {
            n_bits -= 8
            if args.src.can_undo_byte() {
                args.src.undo_byte!()
            } else {
                this.read_from_return_value = 5
                break
            }
        } endwhile
    }

    // Save the local copies back to the decoder fields.
    this.save_code = save_code
    this.prev_code = prev_code
    this.width = width
    this.bits = bits
    this.n_bits = n_bits
    this.output_wi = output_wi
}

// write_to copies the buffered bytes this.output[output_ri .. output_wi] to
// args.dst.
//
// When all of them have been copied, both indexes are reset to 0 and the
// function returns ok. Otherwise output_ri advances by the number of bytes
// copied and the function yields "$short write" before trying again.
//
// It returns "#internal error: inconsistent I/O" if output_ri is greater than
// output_wi.
pri func decoder.write_to?(dst: base.io_writer) {
    var s : slice base.u8
    var n : base.u64

    while this.output_wi > 0 {
        // This guard also establishes output_ri <= output_wi for the slice
        // expression below.
        if this.output_ri > this.output_wi {
            return "#internal error: inconsistent I/O"
        }
        s = this.output[this.output_ri .. this.output_wi]
        n = args.dst.copy_from_slice!(s: s)
        if n == s.length() {
            // Everything was copied: the buffer is empty again.
            this.output_ri = 0
            this.output_wi = 0
            return ok
        }
        // Only n bytes were copied: skip past them and wait for more room in
        // args.dst. The masks keep the arithmetic within the declared bounds.
        this.output_ri = (this.output_ri ~mod+ ((n & 0xFFFF_FFFF) as base.u32)) & 8191
        yield? base."$short write"
    } endwhile
}

// flush hands back the decoded bytes that are still buffered, which are
// this.output[output_ri .. output_wi], and marks the buffer as empty by
// resetting both indexes to 0.
//
// If output_ri is greater than output_wi, the returned slice is the local
// variable as declared, without anything having been assigned to it.
pub func decoder.flush!() slice base.u8 {
    var pending : slice base.u8

    // Take the slice before the indexes are cleared.
    if this.output_ri <= this.output_wi {
        pending = this.output[this.output_ri .. this.output_wi]
    }

    this.output_ri = 0
    this.output_wi = 0
    return pending
}
