blob: c883a708ecfbdf9ce1c91ab11e7c6abc2d03a777 [file] [log] [blame]
// Copyright 2024 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// SPDX-License-Identifier: Apache-2.0 OR MIT
// apply_transform_predictor undoes the predictor transform, in place.
//
// args.pix holds 4 bytes per pixel and this.width pixels per row. On entry
// each pixel holds a residual. This function adds (per byte, modulo 256) a
// prediction computed from neighboring pixels that were already processed:
//  - L is the pixel to the left: curr_row[0 .. 4].
//  - TL, T and TR are the pixels above-left, above and above-right:
//    prev_row[0 .. 4], prev_row[4 .. 8] and prev_row[8 .. 12].
// The pixel being reconstructed is curr_row[4 .. 8].
//
// args.tile_data holds 4 bytes per tile. The low 4 bits of byte 1 of each
// entry select the prediction mode for a tile that is (1 << tile_size_log2)
// pixels wide and tall.
//
// Modes 14 and 15 match no branch of the if-else chain below, so such pixels
// are left as they are.
// NOTE(review): confirm that this matches the reference decoder's treatment
// of modes 14 and 15.
pri func decoder.apply_transform_predictor!(pix: slice base.u8, tile_data: roslice base.u8) {
var w4 : base.u64[..= 0x1_0000]
var prev_row : roslice base.u8
var curr_row : slice base.u8
var tile_size_log2 : base.u32[..= 9]
var tiles_per_row : base.u32[..= 16895]
var mask : base.u32
var y : base.u32[..= 0x4000]
var x : base.u32[..= 0x4000]
var t : base.u64
var tile_data : roslice base.u8
var mode : base.u8[..= 0x0F]
var l0 : base.u32[..= 0xFF]
var l1 : base.u32[..= 0xFF]
var l2 : base.u32[..= 0xFF]
var l3 : base.u32[..= 0xFF]
var c0 : base.u32[..= 0xFF]
var c1 : base.u32[..= 0xFF]
var c2 : base.u32[..= 0xFF]
var c3 : base.u32[..= 0xFF]
var t0 : base.u32[..= 0xFF]
var t1 : base.u32[..= 0xFF]
var t2 : base.u32[..= 0xFF]
var t3 : base.u32[..= 0xFF]
var sum_l : base.u32
var sum_t : base.u32
// An empty image has nothing to reconstruct.
if (this.width <= 0) or (this.height <= 0) {
return nothing
}
// w4 is the number of bytes per row (4 bytes per pixel).
w4 = (this.width * 4) as base.u64
// If args.pix is shorter than one row, curr_row stays empty and the
// first-row processing below does nothing.
curr_row = this.util.empty_slice_u8()
if w4 <= args.pix.length() {
curr_row = args.pix[.. w4]
}
// The first pixel's predictor is mode 0 (opaque black).
if curr_row.length() >= 4 {
curr_row[3] ~mod+= 0xFF
}
// The rest of the first row's predictor is mode 1 (L).
// Each iteration adds the pixel at curr_row[0 .. 4] to the one at
// curr_row[4 .. 8] and then advances curr_row by one pixel.
while curr_row.length() >= 8 {
curr_row[4] ~mod+= curr_row[0]
curr_row[5] ~mod+= curr_row[1]
curr_row[6] ~mod+= curr_row[2]
curr_row[7] ~mod+= curr_row[3]
curr_row = curr_row[4 ..]
}
tile_size_log2 = this.transform_tile_size_log2[0] as base.u32
// tiles_per_row is this.width divided by the tile width, rounded up.
tiles_per_row = (this.width + (((1 as base.u32) << tile_size_log2) - 1)) >> tile_size_log2
// (x & mask) == 0 exactly when x is the first column of a tile.
mask = ((1 as base.u32) << tile_size_log2) - 1
// Row 0 was handled above, so start at row 1.
y = 1
while y < this.height {
assert y < 0x4000 via "a < b: a < c; c <= b"(c: this.height)
// t is the byte offset, within args.tile_data, of the first tile entry for
// the tile row that contains pixel row y.
t = (4 * (y >> tile_size_log2) * tiles_per_row) as base.u64
tile_data = this.util.empty_slice_u8()
if t <= args.tile_data.length() {
tile_data = args.tile_data[t ..]
// Load the mode of the row's first tile here: the inner loop starts at
// x = 1, so its tile-boundary check never fires for column 0.
if tile_data.length() >= 4 {
mode = tile_data[1] & 0x0F
tile_data = tile_data[4 ..]
}
}
// Advance by one row. prev_row is not truncated to a single row: it starts
// at the previous row and runs on into the current row, so that the TR
// neighbor of the last column is the first pixel of the current row.
if w4 <= args.pix.length() {
prev_row = args.pix
args.pix = args.pix[w4 ..]
curr_row = args.pix
}
// The first column's predictor is mode 2 (T).
if (prev_row.length() >= 4) and (curr_row.length() >= 4) {
curr_row[0] ~mod+= prev_row[0]
curr_row[1] ~mod+= prev_row[1]
curr_row[2] ~mod+= prev_row[2]
curr_row[3] ~mod+= prev_row[3]
}
x = 1
while x < this.width,
inv y < 0x4000,
{
assert x < 0x4000 via "a < b: a < c; c <= b"(c: this.width)
// Entering a new tile: load its mode and step to the next tile entry.
if ((x & mask) == 0) and (tile_data.length() >= 4) {
mode = tile_data[1] & 0x0F
tile_data = tile_data[4 ..]
}
// Stop processing this row if the slices are too short. Past this point,
// prev_row[0 ..= 11] and curr_row[0 ..= 7] are in bounds.
if (prev_row.length() < 12) or (curr_row.length() < 8) {
break
}
if mode == 0 { // Opaque black.
curr_row[7] ~mod+= 0xFF
} else if mode == 1 { // L
curr_row[4] ~mod+= curr_row[0]
curr_row[5] ~mod+= curr_row[1]
curr_row[6] ~mod+= curr_row[2]
curr_row[7] ~mod+= curr_row[3]
} else if mode == 2 { // T
curr_row[4] ~mod+= prev_row[4]
curr_row[5] ~mod+= prev_row[5]
curr_row[6] ~mod+= prev_row[6]
curr_row[7] ~mod+= prev_row[7]
} else if mode == 3 { // TR
curr_row[4] ~mod+= prev_row[8]
curr_row[5] ~mod+= prev_row[9]
curr_row[6] ~mod+= prev_row[10]
curr_row[7] ~mod+= prev_row[11]
} else if mode == 4 { // TL
curr_row[4] ~mod+= prev_row[0]
curr_row[5] ~mod+= prev_row[1]
curr_row[6] ~mod+= prev_row[2]
curr_row[7] ~mod+= prev_row[3]
} else if mode == 5 { // Average2(Average2(L, TR), T).
// l0 .. l3 hold the inner average, Average2(L, TR), per byte.
l0 = ((curr_row[0] as base.u32) + (prev_row[8] as base.u32)) / 2
l1 = ((curr_row[1] as base.u32) + (prev_row[9] as base.u32)) / 2
l2 = ((curr_row[2] as base.u32) + (prev_row[10] as base.u32)) / 2
l3 = ((curr_row[3] as base.u32) + (prev_row[11] as base.u32)) / 2
curr_row[4] ~mod+= ((l0 + (prev_row[4] as base.u32)) / 2) as base.u8
curr_row[5] ~mod+= ((l1 + (prev_row[5] as base.u32)) / 2) as base.u8
curr_row[6] ~mod+= ((l2 + (prev_row[6] as base.u32)) / 2) as base.u8
curr_row[7] ~mod+= ((l3 + (prev_row[7] as base.u32)) / 2) as base.u8
} else if mode == 6 { // Average2(L, TL).
curr_row[4] ~mod+= (((curr_row[0] as base.u32) + (prev_row[0] as base.u32)) / 2) as base.u8
curr_row[5] ~mod+= (((curr_row[1] as base.u32) + (prev_row[1] as base.u32)) / 2) as base.u8
curr_row[6] ~mod+= (((curr_row[2] as base.u32) + (prev_row[2] as base.u32)) / 2) as base.u8
curr_row[7] ~mod+= (((curr_row[3] as base.u32) + (prev_row[3] as base.u32)) / 2) as base.u8
} else if mode == 7 { // Average2(L, T).
curr_row[4] ~mod+= (((curr_row[0] as base.u32) + (prev_row[4] as base.u32)) / 2) as base.u8
curr_row[5] ~mod+= (((curr_row[1] as base.u32) + (prev_row[5] as base.u32)) / 2) as base.u8
curr_row[6] ~mod+= (((curr_row[2] as base.u32) + (prev_row[6] as base.u32)) / 2) as base.u8
curr_row[7] ~mod+= (((curr_row[3] as base.u32) + (prev_row[7] as base.u32)) / 2) as base.u8
} else if mode == 8 { // Average2(TL, T).
curr_row[4] ~mod+= (((prev_row[0] as base.u32) + (prev_row[4] as base.u32)) / 2) as base.u8
curr_row[5] ~mod+= (((prev_row[1] as base.u32) + (prev_row[5] as base.u32)) / 2) as base.u8
curr_row[6] ~mod+= (((prev_row[2] as base.u32) + (prev_row[6] as base.u32)) / 2) as base.u8
curr_row[7] ~mod+= (((prev_row[3] as base.u32) + (prev_row[7] as base.u32)) / 2) as base.u8
} else if mode == 9 { // Average2(T, TR).
curr_row[4] ~mod+= (((prev_row[4] as base.u32) + (prev_row[8] as base.u32)) / 2) as base.u8
curr_row[5] ~mod+= (((prev_row[5] as base.u32) + (prev_row[9] as base.u32)) / 2) as base.u8
curr_row[6] ~mod+= (((prev_row[6] as base.u32) + (prev_row[10] as base.u32)) / 2) as base.u8
curr_row[7] ~mod+= (((prev_row[7] as base.u32) + (prev_row[11] as base.u32)) / 2) as base.u8
} else if mode == 10 { // Average2(Average2(L, TL), Average2(T, TR)).
// l0 .. l3 hold Average2(L, TL) and t0 .. t3 hold Average2(T, TR).
l0 = ((curr_row[0] as base.u32) + (prev_row[0] as base.u32)) / 2
l1 = ((curr_row[1] as base.u32) + (prev_row[1] as base.u32)) / 2
l2 = ((curr_row[2] as base.u32) + (prev_row[2] as base.u32)) / 2
l3 = ((curr_row[3] as base.u32) + (prev_row[3] as base.u32)) / 2
t0 = ((prev_row[4] as base.u32) + (prev_row[8] as base.u32)) / 2
t1 = ((prev_row[5] as base.u32) + (prev_row[9] as base.u32)) / 2
t2 = ((prev_row[6] as base.u32) + (prev_row[10] as base.u32)) / 2
t3 = ((prev_row[7] as base.u32) + (prev_row[11] as base.u32)) / 2
curr_row[4] ~mod+= ((l0 + t0) / 2) as base.u8
curr_row[5] ~mod+= ((l1 + t1) / 2) as base.u8
curr_row[6] ~mod+= ((l2 + t2) / 2) as base.u8
curr_row[7] ~mod+= ((l3 + t3) / 2) as base.u8
} else if mode == 11 { // Select(L, T, TL).
// l is the left pixel, c is the top-left (corner) pixel and t is the top
// pixel, one variable per byte.
l0 = curr_row[0] as base.u32
l1 = curr_row[1] as base.u32
l2 = curr_row[2] as base.u32
l3 = curr_row[3] as base.u32
c0 = prev_row[0] as base.u32
c1 = prev_row[1] as base.u32
c2 = prev_row[2] as base.u32
c3 = prev_row[3] as base.u32
t0 = prev_row[4] as base.u32
t1 = prev_row[5] as base.u32
t2 = prev_row[6] as base.u32
t3 = prev_row[7] as base.u32
// With the estimate p = L + T - TL (per byte), the distance from p to L is
// |T - TL| and the distance from p to T is |L - TL|. sum_l and sum_t are
// those distances summed over the 4 bytes.
sum_l = this.absolute_difference(a: c0, b: t0) +
this.absolute_difference(a: c1, b: t1) +
this.absolute_difference(a: c2, b: t2) +
this.absolute_difference(a: c3, b: t3)
sum_t = this.absolute_difference(a: c0, b: l0) +
this.absolute_difference(a: c1, b: l1) +
this.absolute_difference(a: c2, b: l2) +
this.absolute_difference(a: c3, b: l3)
// Predict with L if it is strictly closer to the estimate, otherwise with T.
if sum_l < sum_t {
curr_row[4] ~mod+= l0 as base.u8
curr_row[5] ~mod+= l1 as base.u8
curr_row[6] ~mod+= l2 as base.u8
curr_row[7] ~mod+= l3 as base.u8
} else {
curr_row[4] ~mod+= t0 as base.u8
curr_row[5] ~mod+= t1 as base.u8
curr_row[6] ~mod+= t2 as base.u8
curr_row[7] ~mod+= t3 as base.u8
}
} else if mode == 12 { // ClampAddSubtractFull(L, T, TL).
curr_row[4] ~mod+= this.mode12(l: curr_row[0], t: prev_row[4], tl: prev_row[0])
curr_row[5] ~mod+= this.mode12(l: curr_row[1], t: prev_row[5], tl: prev_row[1])
curr_row[6] ~mod+= this.mode12(l: curr_row[2], t: prev_row[6], tl: prev_row[2])
curr_row[7] ~mod+= this.mode12(l: curr_row[3], t: prev_row[7], tl: prev_row[3])
} else if mode == 13 { // ClampAddSubtractHalf(Average2(L, T), TL).
curr_row[4] ~mod+= this.mode13(l: curr_row[0], t: prev_row[4], tl: prev_row[0])
curr_row[5] ~mod+= this.mode13(l: curr_row[1], t: prev_row[5], tl: prev_row[1])
curr_row[6] ~mod+= this.mode13(l: curr_row[2], t: prev_row[6], tl: prev_row[2])
curr_row[7] ~mod+= this.mode13(l: curr_row[3], t: prev_row[7], tl: prev_row[3])
}
// Step both views one pixel to the right.
curr_row = curr_row[4 ..]
prev_row = prev_row[4 ..]
x += 1
}
y += 1
}
}
// absolute_difference returns the distance between args.a and args.b, that
// is, the larger of the two minus the smaller. Both arguments and the result
// lie in the range 0 ..= 0xFF.
pri func decoder.absolute_difference(a: base.u32[..= 0xFF], b: base.u32[..= 0xFF]) base.u32[..= 0xFF] {
    var delta : base.u32[..= 0xFF]

    if args.a < args.b {
        // Restate the condition in the form that shows the subtraction cannot
        // underflow.
        assert args.b > args.a via "a > b: b < a"()
        delta = args.b - args.a
    } else {
        delta = args.a - args.b
    }
    return delta
}
// mode12 computes one byte of the ClampAddSubtractFull predictor: the value
// (args.l + args.t - args.tl), clamped to the range 0 ..= 255.
pri func decoder.mode12(l: base.u8, t: base.u8, tl: base.u8) base.u8 {
    var sum : base.u32

    // The exact result lies in -255 ..= 510. It is held in a u32, so a
    // negative result wraps around to a value far above 511.
    sum = (args.l as base.u32) + (args.t as base.u32)
    sum ~mod-= args.tl as base.u32

    if sum >= 512 {
        // Wrapped around: the exact result is negative, so clamp it to 0.
        return 0
    } else if sum >= 256 {
        // Too large for a byte: clamp it to 255.
        return 255
    }
    return sum as base.u8
}
// mode13 computes one byte of the ClampAddSubtractHalf predictor, applied to
// Average2(args.l, args.t) and args.tl: with avg = (l + t) / 2, the result
// is (avg + ((avg - tl) / 2)), clamped to the range 0 ..= 255.
pri func decoder.mode13(l: base.u8, t: base.u8, tl: base.u8) base.u8 {
    var avg    : base.u32[..= 0xFF]
    var corner : base.u32[..= 0xFF]
    var diff   : base.u32
    var half   : base.u32
    var pred   : base.u32

    avg = ((args.l as base.u32) + (args.t as base.u32)) / 2
    corner = args.tl as base.u32

    // The arithmetic below is meant to act on signed (i32) values, even
    // though they are stored in unsigned (u32) variables. diff is (avg -
    // corner), wrapping around if negative.
    diff = avg ~mod- corner
    // (diff >> 31) is 1 exactly when diff's top bit is set (a negative i32).
    // Adding it before the shift makes the halving round towards zero.
    // NOTE(review): this presumes that sign_extend_rshift_u32 is an
    // arithmetic (sign-preserving) right shift - confirm against its
    // definition.
    half = this.util.sign_extend_rshift_u32(a: diff ~mod+ (diff >> 31), n: 1)
    pred = avg ~mod+ half

    // Return clamp(pred), where pred is interpreted as a signed i32. A
    // negative pred appears here as a u32 value of 512 or more.
    if pred >= 512 {
        return 0
    } else if pred >= 256 {
        return 255
    }
    return pred as base.u8
}
// apply_transform_cross_color undoes the cross-color transform, in place.
//
// args.pix holds 4 bytes per pixel; bytes 0, 1 and 2 are handled as blue,
// green and red. args.tile_data holds 4 bytes per tile, whose bytes 0, 1 and
// 2 are the green-to-red, green-to-blue and red-to-blue multipliers for a
// tile that is (1 << tile_size_log2) pixels wide and tall.
//
// For every pixel, red gains ((green * g2r) >> 5), blue gains
// ((green * g2b) >> 5) and then blue gains ((red * r2b) >> 5), where that
// last red is the already-updated red. Sums are modulo 256.
pri func decoder.apply_transform_cross_color!(pix: slice base.u8, tile_data: roslice base.u8) {
    var tile_size_log2 : base.u32[..= 9]
    var tiles_per_row  : base.u32[..= 16895]
    var tile_mask      : base.u32
    var row            : base.u32[..= 0x4000]
    var col            : base.u32[..= 0x4000]
    var tile_offset    : base.u64
    var tiles          : roslice base.u8
    var g2r            : base.u32
    var g2b            : base.u32
    var r2b            : base.u32
    var blue           : base.u8
    var green          : base.u8
    var red            : base.u8
    var green_ext      : base.u32
    var red_ext        : base.u32

    tile_size_log2 = this.transform_tile_size_log2[1] as base.u32
    // (col & tile_mask) == 0 exactly when col is the first column of a tile.
    tile_mask = ((1 as base.u32) << tile_size_log2) - 1
    // this.width divided by the tile width, rounded up.
    tiles_per_row = (this.width + (((1 as base.u32) << tile_size_log2) - 1)) >> tile_size_log2

    row = 0
    while row < this.height {
        assert row < 0x4000 via "a < b: a < c; c <= b"(c: this.height)

        // Find the first tile entry of the tile row containing this pixel
        // row. If the offset is out of range, tiles stays empty.
        tile_offset = (4 * (row >> tile_size_log2) * tiles_per_row) as base.u64
        tiles = this.util.empty_slice_u8()
        if tile_offset <= args.tile_data.length() {
            tiles = args.tile_data[tile_offset ..]
        }

        col = 0
        while col < this.width,
                inv row < 0x4000,
        {
            assert col < 0x4000 via "a < b: a < c; c <= b"(c: this.width)

            // Entering a new tile: load its three multipliers and step to
            // the next tile entry.
            if ((col & tile_mask) == 0) and (tiles.length() >= 4) {
                r2b = this.util.sign_extend_convert_u8_u32(a: tiles[2])
                g2b = this.util.sign_extend_convert_u8_u32(a: tiles[1])
                g2r = this.util.sign_extend_convert_u8_u32(a: tiles[0])
                tiles = tiles[4 ..]
            }

            if args.pix.length() >= 4 {
                red = args.pix[2]
                green = args.pix[1]
                blue = args.pix[0]

                // Green is unchanged by this transform, so convert it once
                // and use it for both the red and the blue delta.
                green_ext = this.util.sign_extend_convert_u8_u32(a: green)
                red ~mod+= (((green_ext ~mod* g2r) >> 5) & 0xFF) as base.u8
                blue ~mod+= (((green_ext ~mod* g2b) >> 5) & 0xFF) as base.u8

                // The red-to-blue delta uses the red value updated above.
                red_ext = this.util.sign_extend_convert_u8_u32(a: red)
                blue ~mod+= (((red_ext ~mod* r2b) >> 5) & 0xFF) as base.u8

                args.pix[2] = red
                args.pix[0] = blue
                args.pix = args.pix[4 ..]
            }
            col += 1
        }
        row += 1
    }
}
// apply_transform_subtract_green undoes the subtract-green transform, in
// place: for every complete 4-byte pixel of args.pix, byte 1 is added
// (modulo 256) to byte 0 and to byte 2. Bytes 1 and 3 are left as they are.
pri func decoder.apply_transform_subtract_green!(pix: slice base.u8) {
    var quad  : slice base.u8
    var green : base.u8

    iterate (quad = args.pix)(length: 4, advance: 4, unroll: 1) {
        green = quad[1]
        quad[2] ~mod+= green
        quad[0] ~mod+= green
    }
}
// apply_transform_color_indexing undoes the color indexing (palette)
// transform, in place: every 4-byte pixel of args.pix is overwritten with
// the 4 bytes of the this.palette entry that its index selects.
//
// When this.transform_tile_size_log2[3] is zero, each pixel carries its own
// index in its byte 1. Otherwise, (1 << tile_size_log2) indexes, each
// bits_per_pixel bits wide, are packed into one source byte, with the lowest
// bits holding the first index. Those source bytes are read from args.pix
// itself, starting at this.workbuf_offset_for_color_indexing.
pri func decoder.apply_transform_color_indexing!(pix: slice base.u8) {
var tile_size_log2 : base.u32[..= 9]
var bits_per_pixel : base.u32[..= 8]
var x_mask : base.u32
var s_mask : base.u32[..= 0xFF]
var src_index : base.u64
var y : base.u32[..= 0x4000]
var di : base.u64
var dj : base.u64
var dst : slice base.u8
var x : base.u32
var s : base.u32[..= 0xFF]
var p : base.u32[..= 0x3FC]
var p0 : base.u8
var p1 : base.u8
var p2 : base.u8
var p3 : base.u8
tile_size_log2 = this.transform_tile_size_log2[3] as base.u32
// Unpacked case: one index per pixel, so no bit extraction is needed.
if tile_size_log2 == 0 {
iterate (dst = args.pix)(length: 4, advance: 4, unroll: 1) {
// p is the byte offset of the palette entry: 4 bytes per entry.
p = (dst[1] as base.u32) * 4
p0 = this.palette[p + 0]
p1 = this.palette[p + 1]
p2 = this.palette[p + 2]
p3 = this.palette[p + 3]
dst[0] = p0
dst[1] = p1
dst[2] = p2
dst[3] = p3
}
return nothing
}
// Packed case. Each source byte holds (1 << tile_size_log2) indexes.
bits_per_pixel = (8 as base.u32) >> tile_size_log2
// (x & x_mask) == 0 exactly when a new source byte is needed.
x_mask = ((1 as base.u32) << tile_size_log2) - 1
// s_mask keeps the low bits_per_pixel bits of s.
s_mask = ((1 as base.u32) << bits_per_pixel) - 1
// The "+ 1" selects the green pixel of the BGRA 4-byte group.
src_index = (this.workbuf_offset_for_color_indexing + 1) as base.u64
y = 0
while y < this.height {
assert y < 0x4000 via "a < b: a < c; c <= b"(c: this.height)
// args.pix[di .. dj] is the destination row y.
di = (4 * (y + 0) * this.width) as base.u64
dj = (4 * (y + 1) * this.width) as base.u64
// Stop if the row bounds are inconsistent or if the row extends past the
// end of args.pix. This also makes the slice expression below in bounds.
if (di > dj) or (dj > args.pix.length()) {
break
}
dst = args.pix[di .. dj]
// x restarts at zero for each row, so each row starts on a fresh source
// byte.
x = 0
while dst.length() >= 4,
inv y < 0x4000,
{
// Fetch the next packed source byte, which is 4 bytes after the previous
// one. If src_index is out of range, s keeps its previous, already
// shifted, value.
if ((x & x_mask) == 0) and (src_index < args.pix.length()) {
s = args.pix[src_index] as base.u32
src_index ~mod+= 4
}
// Take the next index from the low bits of s, then discard those bits.
p = (s & s_mask) * 4
s >>= bits_per_pixel
p0 = this.palette[p + 0]
p1 = this.palette[p + 1]
p2 = this.palette[p + 2]
p3 = this.palette[p + 3]
dst[0] = p0
dst[1] = p1
dst[2] = p2
dst[3] = p3
dst = dst[4 ..]
x ~mod+= 1
}
y += 1
}
}