// blob: 3ab0c2a899dab5821c8228cfdbe4108b7ffa52ce
// Copyright 2017 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// TODO: describe how the xxx_fastxx version differs from the xxx_slow one, the
// assumptions that xxx_fastxx makes, and how that makes it fast.
// decode_huffman_fast32 decodes Huffman-coded literal and length-distance
// symbols from args.src and writes the decompressed bytes to args.dst.
//
// The main loop only runs while args.dst.length() >= 266 and
// args.src.length() >= 12. Those two up-front checks cover the worst case of
// one whole iteration, so the loop body uses the unchecked "fast" I/O methods
// and never suspends. When either check fails, the loop simply stops.
//
// The loop also stops at an end-of-block code, after setting
// this.end_of_block to true.
//
// On the way out, any whole bytes still held in the local bit buffer are
// handed back to args.src, and the remaining (fewer than 8) bits are saved in
// this.bits and this.n_bits.
//
// Error statuses that this function returns:
//  - "#internal error: inconsistent n_bits"
//  - base."#bad I/O position"
//  - "#bad Huffman code"
//  - "#internal error: inconsistent Huffman decoder state"
//  - "#bad distance"
//  - "#internal error: inconsistent distance"
//  - "#internal error: inconsistent I/O"
// Otherwise, control reaches the end of the function without an error return.
pri func decoder.decode_huffman_fast32!(dst: base.io_writer, src: base.io_reader) base.status {
// When editing this function, consider making the equivalent change to the
// decode_huffman_slow function. Keep the diff between the two
// decode_huffman_*.wuffs files as small as possible, while retaining both
// correctness and performance.
//
// bits is the bit buffer: new input bytes are OR-ed in above the n_bits bits
// that are already valid, and consumed bits are shifted out of the low end.
var bits : base.u32
var n_bits : base.u32
var table_entry : base.u32
var table_entry_n_bits : base.u32[..= 15]
// lmask and dmask select the low bits of the bit buffer that index the first
// level of this.huffs[0] (H-L) and this.huffs[1] (H-D) respectively.
var lmask : base.u32[..= 511]
var dmask : base.u32[..= 511]
var redir_top : base.u32[..= 0xFFFF]
var redir_mask : base.u32[..= 0x7FFF]
var length : base.u32[..= 258]
var dist_minus_1 : base.u32[..= 0x7FFF]
// hlen and hdist are the length and distance of the portion of a copy that is
// served from this.history instead of from args.dst.
var hlen : base.u32[..= 0x7FFF]
var hdist : base.u32
var hdist_adjustment : base.u32
// Sanity check the saved state: fewer than 8 pending bits, and no stray bits
// set above them.
// NOTE(review): the "& 7" looks like it is there only to give the shift amount
// a provable upper bound - confirm.
if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) <> 0) {
return "#internal error: inconsistent n_bits"
}
bits = this.bits
n_bits = this.n_bits
lmask = ((1 as base.u32) << this.n_huffs_bits[0]) - 1
dmask = ((1 as base.u32) << this.n_huffs_bits[1]) - 1
// hdist_adjustment is the amount by which this.transformed_history_count
// exceeds args.dst.history_position(), truncated to 32 bits. It is added to
// hdist below, before indexing this.history.
if this.transformed_history_count < args.dst.history_position() {
return base."#bad I/O position"
}
hdist_adjustment = ((this.transformed_history_count - args.dst.history_position()) & 0xFFFF_FFFF) as base.u32
// Check up front, on each iteration, that we have enough buffer space to
// both read (12 bytes) and write (266 bytes) as much as we need to. Doing
// this check once (per iteration), up front, removes the need to check
// multiple times inside the loop body, so it's faster overall.
//
// For writing, a literal code obviously corresponds to writing 1 byte, and
// 258 is the maximum length in a length-distance pair, as specified in the
// RFC section 3.2.5. Compressed blocks (length and distance codes).
//
// We adjust 258 up to (258 + 8) = 266 since it can be faster to overshoot
// a little and make multiple-of-8-byte copies even when the length in the
// length-distance pair isn't an exact multiple-of-8. Strictly speaking,
// 264 (an exact multiple-of-8) is the tight bound but (258 + 8) is easier
// for the Wuffs proof system, as length's type is refined to [..= 258].
//
// For reading, strictly speaking, we only need 6 bytes (48 bits) of
// available input, because the H-L Literal/Length code is up to 15 bits
// plus up to 5 extra bits, the H-D Distance code is up to 15 bits plus up
// to 13 extra bits and 15 + 5 + 15 + 13 == 48. However,
// args.src.length() assertions are made in terms of bytes, not bits.
//
// In theory, we could write assertions that mixed bytes and bits, such as:
//
// assert ((args.src.length() >= 6) and (n_bits >= 5)) or
// ((args.src.length() >= 4) and (n_bits >= 21))
//
// but that looks clumsy and makes the proofs more complicated. Instead, we
// are conservative and work in terms of bytes only. In bytes, each code
// requires either one (direct) or two (redirect) lookups, up to 2 bytes
// for each. There are also up to 13 extra bits, or 2 bytes. Thus, each
// code requires up to 2 + 2 + 2 == 6 bytes, and there are two codes (one
// from H-L and one from H-D). Therefore, we check for at least 12 bytes.
//
// Inside the loop body there are six refill points. Each one reads either
// zero or two bytes, which is why the args.src.length() assertions below
// count down 10, 8, 6, 4, 2 from the initial 12.
while.loop(args.dst.length() >= 266) and (args.src.length() >= 12) {
// Ensure that we have at least 15 bits of input.
//
// A refill only happens when n_bits is in the range 0 ..= 14. It loads
// two bytes (16 bits), leaving n_bits in the range 16 ..= 30.
if n_bits < 15 {
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert n_bits >= 15
} else {
assert args.src.length() >= 10
}
// These assertions are redundant, but are listed explicitly for
// clarity. They must hold regardless of which branch was taken for the
// if statement above.
assert args.src.length() >= 10
assert n_bits >= 15
// Decode an lcode symbol from H-L.
//
// As used below, an H-L table entry's low 4 bits hold the number of input
// bits that this lookup consumes. The entry's kind is given by its highest
// set bit among bits 31 ..= 27:
//  - bit 31: literal, with the byte value in bits 8 ..= 15.
//  - bit 30: length code, with (base length - 3) in bits 8 ..= 15 and the
//    number of extra bits in bits 4 ..= 7.
//  - bit 29: end of block.
//  - bit 28: redirect to a second-level table, with the second level's
//    base index in bits 8 ..= 23 and the number of bits that index into
//    it in bits 4 ..= 7.
//  - bit 27: bad Huffman code.
// An entry with none of those bits set is an internal error.
table_entry = this.huffs[0][bits & lmask]
table_entry_n_bits = table_entry & 0x0F
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
if (table_entry >> 31) <> 0 {
// Literal.
args.dst.write_u8_fast!(a: ((table_entry >> 8) & 0xFF) as base.u8)
continue.loop
} else if (table_entry >> 30) <> 0 {
// No-op; code continues past the if-else chain.
assert args.src.length() >= 8
} else if (table_entry >> 29) <> 0 {
// End of block.
this.end_of_block = true
break.loop
} else if (table_entry >> 28) <> 0 {
// Redirect.
// Ensure that we have at least 15 bits of input.
if n_bits < 15 {
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert n_bits >= 15
} else {
assert args.src.length() >= 8
}
// Once again, redundant but explicit assertions.
assert args.src.length() >= 8
assert n_bits >= 15
// Second-level lookup: redir_top is the base index and the low bits of
// the bit buffer select the entry relative to that base. The final
// "& HUFFS_TABLE_MASK" keeps the computed index within the table.
redir_top = (table_entry >> 8) & 0xFFFF
redir_mask = ((1 as base.u32) << ((table_entry >> 4) & 0x0F)) - 1
table_entry = this.huffs[0][(redir_top + (bits & redir_mask)) & HUFFS_TABLE_MASK]
table_entry_n_bits = table_entry & 0x0F
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
// Same dispatch as for the first-level entry, except that a second
// redirect (bit 28) is treated as an internal error.
if (table_entry >> 31) <> 0 {
// Literal.
args.dst.write_u8_fast!(a: ((table_entry >> 8) & 0xFF) as base.u8)
continue.loop
} else if (table_entry >> 30) <> 0 {
// No-op; code continues past the if-else chain.
} else if (table_entry >> 29) <> 0 {
// End of block.
this.end_of_block = true
break.loop
} else if (table_entry >> 28) <> 0 {
return "#internal error: inconsistent Huffman decoder state"
} else if (table_entry >> 27) <> 0 {
return "#bad Huffman code"
} else {
return "#internal error: inconsistent Huffman decoder state"
}
// Once again, redundant but explicit assertions.
assert args.src.length() >= 8
} else if (table_entry >> 27) <> 0 {
return "#bad Huffman code"
} else {
return "#internal error: inconsistent Huffman decoder state"
}
// Only length codes (bit 30) reach this point. Every other kind of entry
// has already continued, broken out of the loop or returned.
//
// length = base_number_minus_3 + 3 + extra_bits.
//
// The -3 is from the bias in script/print-deflate-magic-numbers.go.
// That bias makes the "& 0xFF" 1 and 15-ish lines below correct.
length = ((table_entry >> 8) & 0xFF) + 3
table_entry_n_bits = (table_entry >> 4) & 0x0F
if table_entry_n_bits > 0 {
// Ensure that we have at least 15 bits of input.
if n_bits < 15 {
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert n_bits >= 15
} else {
assert args.src.length() >= 6
}
// Once again, redundant but explicit assertions.
assert args.src.length() >= 6
assert n_bits >= 15
// The "+ 253" is the same as "- 3", after the "& 0xFF", but the
// plus form won't require an underflow check.
length = ((length + 253 + bits.low_bits(n: table_entry_n_bits)) & 0xFF) + 3
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
} else {
assert args.src.length() >= 6
}
// Ensure that we have at least 15 bits of input.
if n_bits < 15 {
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert n_bits >= 15
} else {
assert args.src.length() >= 4
}
// Once again, redundant but explicit assertions.
assert args.src.length() >= 4
assert n_bits >= 15
// Decode a dcode symbol from H-D.
//
// As for H-L, the low 4 bits of the entry hold the number of input bits
// consumed by this lookup.
table_entry = this.huffs[1][bits & dmask]
table_entry_n_bits = table_entry & 15
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
// Check for a redirect.
if (table_entry >> 28) == 1 {
// Ensure that we have at least 15 bits of input.
if n_bits < 15 {
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert n_bits >= 15
} else {
assert args.src.length() >= 2
}
// Once again, redundant but explicit assertions.
assert args.src.length() >= 2
assert n_bits >= 15
// Second-level lookup, computed the same way as the H-L redirect but
// indexing this.huffs[1].
redir_top = (table_entry >> 8) & 0xFFFF
redir_mask = ((1 as base.u32) << ((table_entry >> 4) & 0x0F)) - 1
table_entry = this.huffs[1][(redir_top + (bits & redir_mask)) & HUFFS_TABLE_MASK]
table_entry_n_bits = table_entry & 0x0F
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
} else {
assert args.src.length() >= 2
}
// For H-D, all symbols should be base_number + extra_bits.
//
// That means the entry's top byte must be exactly 0x40 (only bit 30 set).
// A top byte of exactly 0x08 (only bit 27 set) is reported as a bad
// Huffman code. Anything else is an internal error.
if (table_entry >> 24) <> 0x40 {
if (table_entry >> 24) == 0x08 {
return "#bad Huffman code"
}
return "#internal error: inconsistent Huffman decoder state"
}
// dist_minus_1 = base_number_minus_1 + extra_bits.
// distance = dist_minus_1 + 1.
//
// The -1 is from the bias in script/print-deflate-magic-numbers.go.
// That bias makes the "& 0x7FFF" 2 and 15-ish lines below correct and
// undoing that bias makes proving (dist_minus_1 + 1) > 0 trivial.
dist_minus_1 = (table_entry >> 8) & 0x7FFF
table_entry_n_bits = (table_entry >> 4) & 0x0F
// Ensure that we have at least table_entry_n_bits bits of input.
//
// Unlike the refills above, this one is conditional on the number of extra
// bits actually needed, not on a fixed 15. If taken, it reads the last 2
// of the 12 bytes that the loop condition guaranteed.
if n_bits < table_entry_n_bits {
assert n_bits < 15 via "a < b: a < c; c <= b"(c: table_entry_n_bits)
assert n_bits >= 0 // TODO: this shouldn't be necessary.
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
bits |= args.src.peek_u8_as_u32() << n_bits
args.src.skip_u32_fast!(actual: 1, worst_case: 1)
n_bits += 8
assert table_entry_n_bits <= n_bits via "a <= b: a <= c; c <= b"(c: 16)
assert n_bits >= table_entry_n_bits via "a >= b: b <= a"()
}
// Once again, redundant but explicit assertions.
assert n_bits >= table_entry_n_bits
dist_minus_1 = (dist_minus_1 + bits.low_bits(n: table_entry_n_bits)) & 0x7FFF
bits >>= table_entry_n_bits
n_bits -= table_entry_n_bits
// The "while true { etc; break }" is a redundant version of "etc", but
// its presence minimizes the diff between decode_huffman_fastxx and
// decode_huffman_slow.
//
// Note that "continue.loop" inside this inner loop targets the outer
// labelled loop, whereas the plain "break" at the bottom only leaves this
// inner loop.
while true,
pre args.dst.length() >= 266,
{
// We can therefore prove:
assert (length as base.u64) <= args.dst.length() via "a <= b: a <= c; c <= b"(c: 266)
assert ((length + 8) as base.u64) <= args.dst.length() via "a <= b: a <= c; c <= b"(c: 266)
// Copy from this.history.
//
// This branch is taken when the distance is greater than
// args.dst.history_length(), so the start of the copy has to come
// from this.history.
if ((dist_minus_1 + 1) as base.u64) > args.dst.history_length() {
// Set (hlen, hdist) to be the length-distance pair to copy
// from this.history, and (length, distance) to be the
// remaining length-distance pair to copy from args.dst.
hlen = 0
hdist = (((dist_minus_1 + 1) as base.u64) - args.dst.history_length()) as base.u32
// In effect, hlen = min(length, hdist), and length is reduced by
// hlen.
if length > hdist {
assert hdist < length via "a < b: b > a"()
assert hdist < 0x8000 via "a < b: a < c; c <= b"(c: length)
length -= hdist
hlen = hdist
} else {
hlen = length
length = 0
}
// Apply hdist_adjustment (computed before the main loop) with the
// "~mod+=" operator, then reject a distance that reaches back
// further than this.history_index.
hdist ~mod+= hdist_adjustment
if this.history_index < hdist {
return "#bad distance"
}
// Copy from this.history[(this.history_index - hdist) ..].
//
// This copying is simpler than the decode_huffman_slow version
// because it cannot yield. We have already checked that
// args.dst.length() is large enough.
args.dst.limited_copy_u32_from_slice!(
up_to: hlen, s: this.history[(this.history_index - hdist) & 0x7FFF ..])
if length == 0 {
// No need to copy from args.dst.
continue.loop
}
// Check the conditions that the copy-from-args.dst code below
// relies on (see the assertions that follow this if-block).
// NOTE(review): presumably an explicit re-check is needed because
// the copy above modified args.dst - confirm.
if (((dist_minus_1 + 1) as base.u64) > args.dst.history_length()) or
((length as base.u64) > args.dst.length()) or
(((length + 8) as base.u64) > args.dst.length()) {
return "#internal error: inconsistent distance"
}
}
// Once again, redundant but explicit assertions.
assert (dist_minus_1 + 1) >= 1
assert ((dist_minus_1 + 1) as base.u64) <= args.dst.history_length()
assert (length as base.u64) <= args.dst.length()
assert ((length + 8) as base.u64) <= args.dst.length()
// Copy from args.dst.
//
// For short distances, less than 8 bytes, copying atomic 8-byte
// chunks can result in incorrect output, so we fall back to a
// slower 1-byte-at-a-time copy. Deflate's copy-from-history can
// pick up freshly written bytes. A length = 5, distance = 2 copy
// starting with "abc" should give "abcbcbcb" but the 8-byte chunk
// technique would give "abcbc???", the exact output depending on
// what was previously in the writer buffer.
if (dist_minus_1 + 1) >= 8 {
args.dst.limited_copy_u32_from_history_8_byte_chunks_fast!(
up_to: length, distance: (dist_minus_1 + 1))
} else {
args.dst.limited_copy_u32_from_history_fast!(
up_to: length, distance: (dist_minus_1 + 1))
}
break
} endwhile
} endwhile.loop
// Ensure n_bits < 8 by rewinding args.src, if we loaded too many of its
// bytes into the bits variable.
//
// Note that we can unconditionally call undo_read (without resulting in an
// "invalid I/O operation" error code) only because this whole function can
// never suspend, as all of its I/O operations were checked beforehand for
// sufficient buffer space. Otherwise, resuming from the suspension could
// mean that the (possibly different) args.src is no longer rewindable,
// even if conceptually, this function was responsible for reading the
// bytes we want to rewind.
while n_bits >= 8,
post n_bits < 8,
{
n_bits -= 8
if args.src.can_undo_byte() {
args.src.undo_byte!()
} else {
return "#internal error: inconsistent I/O"
}
} endwhile
// Save the bit buffer state. The mask keeps only the low n_bits bits, so any
// bits belonging to the bytes that were just handed back are dropped.
this.bits = bits & (((1 as base.u32) << n_bits) - 1)
this.n_bits = n_bits
// Re-check, on the way out, the same invariant that was checked on entry.
if (this.n_bits >= 8) or ((this.bits >> this.n_bits) <> 0) {
return "#internal error: inconsistent n_bits"
}
}