// Copyright 2021 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// --------

// Filter 1: Sub.

// This (filter = 1, distance = 3) implementation doesn't actually bench faster
// than the non-SIMD one.
//
// pri func decoder.filter_1_distance_3_sse42!(curr: slice base.u8),
//     choose cpu_arch >= x86_sse42,
// {
//     var c    : slice base.u8
//     var x128 : base.x86_m128i
//     var a128 : base.x86_m128i
//
//     iterate (c = args.curr)(length: 4, advance: 3, unroll: 2) {
//         x128.load_u32!(a: c.peek_u32le())
//         x128 = x128._mm_add_epi8(b: a128)
//         a128 = x128
//         c.poke_u24le!(a: x128.truncate_u32())
//     } else (length: 3, advance: 3, unroll: 1) {
//         x128.load_u32!(a: c.peek_u24le_as_u32())
//         x128 = x128._mm_add_epi8(b: a128)
//         c.poke_u24le!(a: x128.truncate_u32())
//     }
// }
//
// Note that "more SIMD" doesn't always mean faster compute. See
// https://github.com/google/wuffs/commit/1660f9268621ed4415b3b363f0a0e1026d4aa83d
// "Have std/png filter_1_distance_? use more SIMD" for a pessimizing example.

// filter_1_distance_4_sse42 undoes the PNG "Sub" filter in place, for rows
// whose pixels are 4 bytes wide. Every 4-byte group of args.curr is replaced
// by the bytewise sum of itself and the reconstructed 4-byte group before it.
pri func decoder.filter_1_distance_4_sse42!(curr: slice base.u8),
	choose cpu_arch >= x86_sse42,
{
	var c    : slice base.u8
	var x128 : base.x86_m128i
	var a128 : base.x86_m128i

	// a128 carries the most recently reconstructed pixel from one iteration
	// to the next. It is not assigned before the loop, so the first pixel is
	// combined with a128's initial value.
	iterate (c = args.curr)(length: 4, advance: 4, unroll: 2) {
		// x128 holds just this pixel's filtered bytes.
		x128.load_u32!(a: c.peek_u32le())

		// Fold them into the running pixel and write the result back over
		// the bytes that were just read.
		a128 = a128._mm_add_epi8(b: x128)
		c.poke_u32le!(a: a128.truncate_u32())
	}
}

// --------

// Filter 3: Average.

// Similar to filter_1_distance_3_sse42, the SIMD implementation for (filter =
// 3, distance = 3) doesn't actually bench faster than the non-SIMD one.
//
// pri func decoder.filter_3_distance_3_sse42!(curr: slice base.u8, prev: slice base.u8),
//     choose cpu_arch >= x86_sse42,
// {
//     etc
// }

// filter_3_distance_4_sse42 undoes the PNG "Average" filter in place, for
// rows whose pixels are 4 bytes wide. args.curr is overwritten with the
// reconstructed bytes. args.prev is the row above; it is only read and may be
// empty, in which case the pixel above is treated as zero.
//
// In both branches, a128 is the previous (left) reconstructed pixel, b128 is
// the pixel above, p128 is the predictor and x128 is the filtered input.
pri func decoder.filter_3_distance_4_sse42!(curr: slice base.u8, prev: slice base.u8),
	choose cpu_arch >= x86_sse42,
{
	var c    : slice base.u8
	var p    : slice base.u8
	var x128 : base.x86_m128i
	var a128 : base.x86_m128i
	var b128 : base.x86_m128i
	var p128 : base.x86_m128i
	var k128 : base.x86_m128i

	if args.prev.length() > 0 {
		// k128 selects the low bit of every byte.
		k128 = k128.create_mm_set1_epi8(a: 0x01)
		iterate (c = args.curr, p = args.prev)(length: 4, advance: 4, unroll: 2) {
			// Fetch this pixel's filtered bytes and the pixel above it.
			x128.load_u32!(a: c.peek_u32le())
			b128.load_u32!(a: p.peek_u32le())

			// _mm_avg_epu8 rounds up but the PNG filter rounds down. The
			// two differ by one exactly when (a128 + b128) is odd, which is
			// the low bit of each byte of (a128 ^ b128). Compute that
			// correction first, then subtract it from the rounded-up
			// average to get the predictor.
			p128 = a128._mm_xor_si128(b: b128)._mm_and_si128(b: k128)
			p128 = a128._mm_avg_epu8(b: b128)._mm_sub_epi8(b: p128)

			// The reconstructed pixel is the filtered bytes plus the
			// predictor. It is also the next iteration's left pixel.
			a128 = x128._mm_add_epi8(b: p128)
			c.poke_u32le!(a: a128.truncate_u32())
		}

	} else {
		// k128 clears the low bit of every byte.
		k128 = k128.create_mm_set1_epi8(a: 0xFE)
		iterate (c = args.curr)(length: 4, advance: 4, unroll: 2) {
			// Fetch this pixel's filtered bytes.
			x128.load_u32!(a: c.peek_u32le())

			// b128 is never loaded in this branch, so it stays zero and the
			// predictor is half of a128, rounded down. Clearing a128's low
			// bits first makes the round-up in _mm_avg_epu8 harmless.
			p128 = a128._mm_and_si128(b: k128)
			p128 = p128._mm_avg_epu8(b: b128)

			// The reconstructed pixel is the filtered bytes plus the
			// predictor. It is also the next iteration's left pixel.
			a128 = x128._mm_add_epi8(b: p128)
			c.poke_u32le!(a: a128.truncate_u32())
		}
	}
}

// --------

// Filter 4: Paeth.

// filter_4_distance_3_sse42 undoes the PNG Paeth filter in place, for rows
// whose pixels are 3 bytes wide. args.curr holds the filtered bytes on entry
// and is overwritten with the reconstructed bytes. args.prev is the row above
// and is only read.
pri func decoder.filter_4_distance_3_sse42!(curr: slice base.u8, prev: slice base.u8),
	choose cpu_arch >= x86_sse42,
{
	// See the comments in filter_4_distance_4_sse42 for an explanation of how
	// this works. That function's single loop is copied twice here, once with
	// "length: 4" and once with "length: 3". It's generally faster to load 4
	// bytes at a time instead of 3.
	//
	// Differences between that function and this one are marked with a §.

	// Roles of the vector variables, as used below: x128 is the current
	// pixel, a128 the previous (left) pixel, b128 the pixel above, c128 the
	// previous pixel above (upper-left) and p128 the chosen predictor. z128 is
	// never assigned, so it keeps its initial all-zeroes value for the unpack
	// steps.
	var c           : slice base.u8
	var p           : slice base.u8
	var x128        : base.x86_m128i
	var a128        : base.x86_m128i
	var b128        : base.x86_m128i
	var c128        : base.x86_m128i
	var p128        : base.x86_m128i
	var pa128       : base.x86_m128i
	var pb128       : base.x86_m128i
	var pc128       : base.x86_m128i
	var smallest128 : base.x86_m128i
	var z128        : base.x86_m128i

	// § The advance is 3, not 4.
	//
	// Each iteration therefore loads one byte more than the 3-byte pixel it
	// is working on. Every operation below acts on each lane separately, so
	// that extra lane does not influence the 3 bytes that are written back.
	iterate (c = args.curr, p = args.prev)(length: 4, advance: 3, unroll: 2) {
		// Load the bytes from the row above and widen them from u8 to i16.
		b128.load_u32!(a: p.peek_u32le())
		b128 = b128._mm_unpacklo_epi8(b: z128)

		// The three Paeth distances, then their absolute values and minimum.
		pa128 = b128._mm_sub_epi16(b: c128)
		pb128 = a128._mm_sub_epi16(b: c128)
		pc128 = pa128._mm_add_epi16(b: pb128)
		pa128 = pa128._mm_abs_epi16()
		pb128 = pb128._mm_abs_epi16()
		pc128 = pc128._mm_abs_epi16()
		smallest128 = pc128._mm_min_epi16(b: pb128._mm_min_epi16(b: pa128))

		// Start from c128, override with b128 where pb128 is the smallest,
		// then override with a128 where pa128 is the smallest. Applying a128
		// last is what makes it win ties.
		p128 = c128._mm_blendv_epi8(
			b: b128,
			mask: smallest128._mm_cmpeq_epi16(b: pb128))._mm_blendv_epi8(
			b: a128,
			mask: smallest128._mm_cmpeq_epi16(b: pa128))

		// Load and widen the filtered bytes, add the predictor, remember this
		// pixel and the one above for the next iteration, then narrow back to
		// bytes.
		x128.load_u32!(a: c.peek_u32le())
		x128 = x128._mm_unpacklo_epi8(b: z128)
		x128 = x128._mm_add_epi8(b: p128)
		a128 = x128
		c128 = b128
		x128 = x128._mm_packus_epi16(b: x128)
		// § poke_u24le replaces poke_u32le.
		c.poke_u24le!(a: x128.truncate_u32())

		// § The length and advance are both 3, not 4.
	} else (length: 3, advance: 3, unroll: 1) {
		// This body only needs 3 readable bytes per step, where the body
		// above needs 4. The arithmetic is the same as above; only the loads
		// and the trailing bookkeeping differ.

		// § peek_u24le_as_u32 replaces peek_u32le.
		b128.load_u32!(a: p.peek_u24le_as_u32())
		b128 = b128._mm_unpacklo_epi8(b: z128)
		pa128 = b128._mm_sub_epi16(b: c128)
		pb128 = a128._mm_sub_epi16(b: c128)
		pc128 = pa128._mm_add_epi16(b: pb128)
		pa128 = pa128._mm_abs_epi16()
		pb128 = pb128._mm_abs_epi16()
		pc128 = pc128._mm_abs_epi16()
		smallest128 = pc128._mm_min_epi16(b: pb128._mm_min_epi16(b: pa128))
		p128 = c128._mm_blendv_epi8(
			b: b128,
			mask: smallest128._mm_cmpeq_epi16(b: pb128))._mm_blendv_epi8(
			b: a128,
			mask: smallest128._mm_cmpeq_epi16(b: pa128))
		// § peek_u24le_as_u32 replaces peek_u32le.
		x128.load_u32!(a: c.peek_u24le_as_u32())
		x128 = x128._mm_unpacklo_epi8(b: z128)
		x128 = x128._mm_add_epi8(b: p128)
		// § These assignments are unnecessary; this is the last iteration.
		// a128 = x128
		// c128 = b128
		x128 = x128._mm_packus_epi16(b: x128)
		// § poke_u24le replaces poke_u32le.
		c.poke_u24le!(a: x128.truncate_u32())
	}
}

// filter_4_distance_4_sse42 undoes the PNG Paeth filter in place, for rows
// whose pixels are 4 bytes wide. args.curr holds the filtered bytes on entry
// and is overwritten with the reconstructed bytes. args.prev is the row above
// and is only read.
//
// Naming, as used in the loop below: x128 is the current pixel, a128 the
// previous (left) pixel, b128 the pixel above, c128 the previous pixel above
// (upper-left) and p128 the chosen predictor.
pri func decoder.filter_4_distance_4_sse42!(curr: slice base.u8, prev: slice base.u8),
	choose cpu_arch >= x86_sse42,
{
	var c           : slice base.u8
	var p           : slice base.u8
	var x128        : base.x86_m128i
	var a128        : base.x86_m128i
	var b128        : base.x86_m128i
	var c128        : base.x86_m128i
	var p128        : base.x86_m128i
	var pa128       : base.x86_m128i
	var pb128       : base.x86_m128i
	var pc128       : base.x86_m128i
	var smallest128 : base.x86_m128i
	// z128 is never assigned. It is the all-zeroes operand that the unpack
	// steps below interleave with.
	var z128        : base.x86_m128i

	// a128 and c128 are not assigned before the loop, so the first pixel is
	// predicted from their initial values.
	iterate (c = args.curr, p = args.prev)(length: 4, advance: 4, unroll: 2) {
		// Load the pixel from the row above.
		b128.load_u32!(a: p.peek_u32le())

		// Convert from u8 to i16 by unpacking it with zeroes.
		b128 = b128._mm_unpacklo_epi8(b: z128)

		// Compute:
		//  - pa128 = (m128 - a128)
		//  - pb128 = (m128 - b128)
		//  - pc128 = (m128 - c128)
		// where m128 = (a128 + b128 - c128).
		pa128 = b128._mm_sub_epi16(b: c128)
		pb128 = a128._mm_sub_epi16(b: c128)
		pc128 = pa128._mm_add_epi16(b: pb128)

		// Compute the smallest absolute value of pa128, pb128 and pc128.
		pa128 = pa128._mm_abs_epi16()
		pb128 = pb128._mm_abs_epi16()
		pc128 = pc128._mm_abs_epi16()
		smallest128 = pc128._mm_min_epi16(b: pb128._mm_min_epi16(b: pa128))

		// The predictor, p128, is whichever of a128, b128 or c128 such that
		// pa128, pb128 or pc128 matches this smallest absolute value. Ties are
		// broken in favor of a128 then b128 then c128.
		//
		// The a._mm_blendv_epi8(b, mask) method picks b when mask is true.
		//
		// The chain starts from c128 and applies the b128 blend before the
		// a128 blend. Later blends override earlier ones, which is what
		// produces the tie-break order above.
		p128 = c128._mm_blendv_epi8(
			b: b128,
			mask: smallest128._mm_cmpeq_epi16(b: pb128))._mm_blendv_epi8(
			b: a128,
			mask: smallest128._mm_cmpeq_epi16(b: pa128))

		// Add the predictor to the residual and, for the next iteration, set
		// its previous pixels, a128 and c128, to x128 and b128.
		//
		// Again, convert between u8 and i16 by unpacking and re-packing.
		//
		// The add is bytewise (epi8) even though the lanes are 16 bits wide.
		// The high byte of every lane of x128 and p128 is zero, so each lane
		// ends up holding the sum modulo 256 and the saturating re-pack
		// leaves that value unchanged.
		x128.load_u32!(a: c.peek_u32le())
		x128 = x128._mm_unpacklo_epi8(b: z128)
		x128 = x128._mm_add_epi8(b: p128)
		a128 = x128
		c128 = b128
		x128 = x128._mm_packus_epi16(b: x128)
		c.poke_u32le!(a: x128.truncate_u32())
	}
}
