// blob: a8209f2cc58801d3deb3f322c931879715a285db
// Copyright 2021 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// up_x86_avx2 folds the bytes of args.x into this.state, 32 bytes at a time.
//
// this.state packs two running sums: s1 in the low 16 bits and s2 in the high
// 16 bits. Per the comments below, s1 is the sum of the input bytes and s2 is
// the sum of the s1 values after each byte. Both are reduced modulo 65521
// after every chunk of at most 5536 input bytes.
//
// The choose clause restricts this implementation to cpu_arch >= x86_avx2.
pri func hasher.up_x86_avx2!(x: slice base.u8),
choose cpu_arch >= x86_avx2,
{
// These variables are the same as the non-SIMD version.
var s1 : base.u32
var s2 : base.u32
var remaining : slice base.u8
var p : slice base.u8
// The remaining variables are specific to the SIMD version.
var util : base.x86_avx2_utility
var zeroes : base.x86_m256i
var ones : base.x86_m256i
var weights : base.x86_m256i
var q : base.x86_m256i
var v1 : base.x86_m256i
var v2 : base.x86_m256i
var v2j : base.x86_m256i
var v2k : base.x86_m256i
var h1 : base.x86_m128i
var h2 : base.x86_m128i
var num_iterate_bytes : base.u32
var tail_index : base.u64
// zeroes and ones are uniform u16×16 vectors (256 bits of 16-bit lanes).
zeroes = util.make_m256i_repeat_u16(a: 0)
ones = util.make_m256i_repeat_u16(a: 1)
// weights form the sequence 32, 31, 30, ..., 1: one u8 weight per input
// byte of a 32-byte chunk, used for the s2 weighted sum (v2k) below.
weights = util.make_m256i_multiple_u8(
a00: 0x20, a01: 0x1F, a02: 0x1E, a03: 0x1D,
a04: 0x1C, a05: 0x1B, a06: 0x1A, a07: 0x19,
a08: 0x18, a09: 0x17, a10: 0x16, a11: 0x15,
a12: 0x14, a13: 0x13, a14: 0x12, a15: 0x11,
a16: 0x10, a17: 0x0F, a18: 0x0E, a19: 0x0D,
a20: 0x0C, a21: 0x0B, a22: 0x0A, a23: 0x09,
a24: 0x08, a25: 0x07, a26: 0x06, a27: 0x05,
a28: 0x04, a29: 0x03, a30: 0x02, a31: 0x01)
// Decompose this.state.
s1 = this.state.low_bits(n: 16)
s2 = this.state.high_bits(n: 16)
// Just like the non-SIMD version, loop over args.x up to almost-5552 bytes
// at a time. The slightly smaller 5536 is the largest multiple of 32 less
// than non-SIMD's 5552.
//
// NOTE(review): the sums below use wrapping (~mod) arithmetic and are only
// reduced by "%= 65521" once per chunk, so the chunk size presumably exists
// to keep the u32 sums from actually wrapping in between. Confirm against
// the non-SIMD version.
while args.x.length() > 0 {
// Default to an empty remainder: this chunk is all of args.x unless it
// is longer than 5536 bytes, in which case args.x is narrowed to the
// first 5536 bytes and the rest is deferred to the next iteration.
remaining = args.x[.. 0]
if args.x.length() > 5536 {
remaining = args.x[5536 ..]
args.x = args.x[.. 5536]
}
// The s1 state is the sum of the input bytes and the s2 state is the
// sum of the s1 state at each 1-byte step. Inside the iterate loop
// below, but starting fresh at each outer while loop iteration, s1
// consists of three parts (called s1i, s1j and s1k):
// - s1i: the initial value, before any 32-byte iterations.
// - s1j: the total contribution from previous 32-byte iterations.
// - s1k: the contribution due to the current 32-byte iteration.
//
// The upcoming iterate loop (at 32 bytes per iteration) encompasses
// num_iterate_bytes 1-byte steps. We hoist the total s1i contribution,
// (s1i * num_iterate_bytes) out here.
//
// The 0xFFFF_FFE0 mask clears the low five bits, rounding the chunk
// length (at most 5536 here) down to a multiple of 32.
num_iterate_bytes = (args.x.length() & 0xFFFF_FFE0) as base.u32
s2 ~mod+= (s1 ~mod* num_iterate_bytes)
// Zero-initialize some u32×8 vectors associated with the two state
// variables s1 and s2. The iterate loop accumulates eight parallel u32
// sums in each vector. A post-iterate step merges the eight u32 sums
// into a single u32 sum.
v1 = util.make_m256i_zeroes()
v2j = util.make_m256i_zeroes()
v2k = util.make_m256i_zeroes()
// The inner loop. Each pass sees p as the next 32 bytes of args.x; a
// trailing partial chunk (fewer than 32 bytes) is handled separately,
// after the loop.
iterate (p = args.x)(length: 32, advance: 32, unroll: 1) {
// AVX2 works with 32-byte registers.
//
// Let q = [u8×32: p00, p01, p02, ..., p31]
q = util.make_m256i_slice256(a: p)
// For v2j, we need to calculate the sums of the s1j terms for each
// of p's 32 elements. This is simply 32 times the same number,
// that number being the sum of v1's eight u32 accumulators. We add
// v1 now and multiply by 32 later, outside the inner loop.
//
// This must happen before v1 is updated below: s1j covers only the
// previous iterations, not the current one.
v2j = v2j._mm256_add_epi32(b: v1)
// For v1, we need to add the elements of p. Computing the sum of
// absolute differences (_mm256_sad_epu8) with zero just sums the
// elements. q._mm256_sad_epu8(b: zeroes) equals
// [u64×4: p00 + p01 + ... + p07, p08 + p09 + ... + p15,
// p16 + p17 + ... + p23, p24 + p25 + ... + p31]
// This is equivalent (little-endian) to:
// [u32×8: p00 + p01 + ... + p07, 0, p08 + p09 + ... + p15, 0,
// p16 + p17 + ... + p23, 0, p24 + p25 + ... + p31, 0]
// We accumulate those "sum of q's elements" in v1.
v1 = v1._mm256_add_epi32(b: q._mm256_sad_epu8(b: zeroes))
// For v2k, we need to calculate a weighted sum: ((32 * p00) + (31
// * p01) + (30 * p02) + ... + (1 * p31)).
//
// The _mm256_maddubs_epi16 call (vertically multiply u8 columns
// and then horizontally sum u16 pairs) produces:
// [u16×16: ((32*p00)+(31*p01)),
// ((30*p02)+(29*p03)),
// ...
// (( 2*p30)+( 1*p31))]
//
// NOTE(review): per Intel's documentation, this intrinsic treats
// its first operand (q) as unsigned bytes and its second (weights)
// as signed bytes, saturating each pair sum to i16. With weights
// of at most 32, a pair sum is at most (255*32)+(255*31) = 16065,
// so saturation should not occur. Confirm against Intel's guide.
//
// The ones._mm256_madd_epi16(b: etc) call is a multiply-add (note
// that it's "madd" not "add"). Multiplying by 1 is a no-op, so
// this sums u16 pairs to produce u32 values:
// [u32×8: ((32*p00)+(31*p01)+(30*p02)+(29*p03)),
// ((28*p04)+(27*p05)+(26*p06)+(25*p07)),
// ...
// (( 4*p28)+( 3*p29)+( 2*p30)+( 1*p31))]
v2k = v2k._mm256_add_epi32(b: ones._mm256_madd_epi16(
b: q._mm256_maddubs_epi16(b: weights)))
}
// Merge the eight parallel u32 sums (v1) into the single u32 sum (s1).
// First, merge the 256-bit (u32×8) v1 into the 128-bit (u32×4) h1 by
// adding its low (imm8: 0) and high (imm8: 1) 128-bit halves.
h1 = v1._mm256_extracti128_si256(imm8: 0)._mm_add_epi32(
b: v1._mm256_extracti128_si256(imm8: 1))
// Starting with a u32×4 vector [x0, x1, x2, x3]:
// - shuffling with 0b1011_0001 gives [x1, x0, x3, x2].
// - adding gives [x0+x1, x0+x1, x2+x3, x2+x3].
// - shuffling with 0b0100_1110 gives [x2+x3, x2+x3, x0+x1, x0+x1].
// - adding gives [x0+x1+x2+x3, ditto, ditto, ditto].
// The truncate_u32 call extracts the first u32: x0+x1+x2+x3.
h1 = h1._mm_add_epi32(b: h1._mm_shuffle_epi32(imm8: 0b1011_0001))
h1 = h1._mm_add_epi32(b: h1._mm_shuffle_epi32(imm8: 0b0100_1110))
s1 ~mod+= h1.truncate_u32()
// Combine v2j and v2k. The slli (shift logical left immediate) by 5
// multiplies v2j's eight u32 elements each by 32, alluded to earlier.
v2 = v2k._mm256_add_epi32(b: v2j._mm256_slli_epi32(imm8: 5))
// Similarly merge v2 (a u32×8 vector) into s2 (a u32 scalar).
h2 = v2._mm256_extracti128_si256(imm8: 0)._mm_add_epi32(
b: v2._mm256_extracti128_si256(imm8: 1))
h2 = h2._mm_add_epi32(b: h2._mm_shuffle_epi32(imm8: 0b1011_0001))
h2 = h2._mm_add_epi32(b: h2._mm_shuffle_epi32(imm8: 0b0100_1110))
s2 ~mod+= h2.truncate_u32()
// Handle the tail of args.x that wasn't a complete 32-byte chunk, one
// byte at a time: add the byte to s1, then add the updated s1 to s2.
tail_index = args.x.length() & 0xFFFF_FFFF_FFFF_FFE0 // Round down to a multiple of 32 (and-not 31).
if tail_index < args.x.length() {
iterate (p = args.x[tail_index ..])(length: 1, advance: 1, unroll: 1) {
s1 ~mod+= p[0] as base.u32
s2 ~mod+= s1
}
}
// The rest of this function is the same as the non-SIMD version.
// Reduce both sums modulo 65521, then move on to the deferred remainder
// (empty if this was the last chunk, which ends the while loop).
s1 %= 65521
s2 %= 65521
args.x = remaining
} endwhile
// Recompose this.state: s2 in the high 16 bits, s1 in the low 16 bits.
this.state = ((s2 & 0xFFFF) << 16) | (s1 & 0xFFFF)
}