// blob: 47d7eabb6b0b5285aefd03b90561f8e045eae94e [file] [log] [blame]
// Copyright 2021 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// up_arm_neon folds the bytes of args.x into this.state using ARM NEON vector
// operations. this.state packs two running sums: s1 in the low 16 bits and s2
// in the high 16 bits, each reduced modulo 65521 before being re-packed.
// NOTE(review): the s1/s2 recurrence and the 65521 modulus match the Adler-32
// checksum; confirm against the enclosing package.
//
// Per the "choose" clause below, this body is only eligible when cpu_arch is
// at least arm_neon. There is no return value: the result is the updated
// this.state. args.x is consumed (re-sliced) as the function progresses.
//
// Outline:
//  1. Consume leading bytes one at a time until args.x is 16-byte aligned.
//  2. Process the rest in chunks of at most 5536 bytes. Each chunk is handled
//     32 bytes per vector iteration, followed by a byte-wise tail of fewer
//     than 32 bytes, followed by a single modulo reduction of s1 and s2.
pri func hasher.up_arm_neon!(x: slice base.u8),
choose cpu_arch >= arm_neon,
{
// These variables are the same as the non-SIMD version.
var s1 : base.u32
var s2 : base.u32
var remaining : slice base.u8
var p : slice base.u8
// The remaining variables are specific to the SIMD version.
var util : base.arm_neon_utility
var p__left : base.arm_neon_u8x16
var p_right : base.arm_neon_u8x16
var v1 : base.arm_neon_u32x4
var v2 : base.arm_neon_u32x4
var col0 : base.arm_neon_u16x8
var col1 : base.arm_neon_u16x8
var col2 : base.arm_neon_u16x8
var col3 : base.arm_neon_u16x8
var sum1 : base.arm_neon_u32x2
var sum2 : base.arm_neon_u32x2
var sum12 : base.arm_neon_u32x2
var num_iterate_bytes : base.u32
var tail_index : base.u64
// Decompose this.state.
s1 = this.state.low_bits(n: 16)
s2 = this.state.high_bits(n: 16)
// Align to a 16-byte boundary.
//
// NOTE(review): uintptr_low_12_bits presumably yields the low 12 bits of the
// slice's start address, so masking with 15 tests 16-byte alignment - confirm.
// If so, this loop consumes at most 15 bytes, which is few enough that the
// wrapping (~mod) additions cannot overflow a u32 before the reductions that
// follow the loop.
while (args.x.length() > 0) and ((15 & args.x.uintptr_low_12_bits()) <> 0) {
s1 ~mod+= args.x[0] as base.u32
s2 ~mod+= s1
args.x = args.x[1 ..]
} endwhile
s1 %= 65521
s2 %= 65521
// Just like the non-SIMD version, loop over args.x up to almost-5552 bytes
// at a time. The slightly smaller 5536 is the largest multiple of 32 less
// than non-SIMD's 5552.
while args.x.length() > 0 {
// Default to an empty remainder, so that a final chunk of 5536 bytes or
// fewer ends the outer loop once args.x is set to remaining at the bottom.
remaining = args.x[.. 0]
if args.x.length() > 5536 {
remaining = args.x[5536 ..]
args.x = args.x[.. 5536]
}
// The s1 state is the sum of the input bytes and the s2 state is the
// sum of the s1 state at each 1-byte step. Inside the iterate loop
// below, but starting fresh at each outer while loop iteration, s1
// consists of three parts (called s1i, s1j and s1k):
//  - s1i: the initial value, before any 32-byte iterations.
//  - s1j: the total contribution from previous 32-byte iterations.
//  - s1k: the contribution due to the current 32-byte iteration.
//
// The upcoming iterate loop (at 32 bytes per iteration) encompasses
// num_iterate_bytes 1-byte steps. We hoist the total s1i contribution,
// (s1i * num_iterate_bytes) out here.
//
// The 0xFFFF_FFE0 mask rounds the length down to a multiple of 32. The
// length is at most 5536 at this point, so the conversion to u32 loses
// nothing.
num_iterate_bytes = (args.x.length() & 0xFFFF_FFE0) as base.u32
s2 ~mod+= (s1 ~mod* num_iterate_bytes)
// Zero-initialize some u32×4 vectors associated with the two state
// variables s1 and s2. The iterate loop accumulates four parallel u32
// sums in each vector. A post-iterate step merges the four u32 sums
// into a single u32 sum.
v1 = util.make_u32x4_repeat(a: 0)
v2 = util.make_u32x4_repeat(a: 0)
// Also zero-initialize four u16×8 vectors that hold one running sum per
// byte position (0 ..= 31) within the 32-byte iterations. A chunk has at
// most 5536 / 32 = 173 iterations, so each u16 lane is at most
// 173 * 255 = 44115, which fits in 16 bits.
col0 = util.make_u16x8_repeat(a: 0)
col1 = util.make_u16x8_repeat(a: 0)
col2 = util.make_u16x8_repeat(a: 0)
col3 = util.make_u16x8_repeat(a: 0)
// The inner loop.
iterate (p = args.x)(length: 32, advance: 32, unroll: 1) {
// Split the 32-byte p into left and right halves. NEON works with
// 16-byte registers.
//
// Let p__left = [u8×16: p00, p01, p02, ..., p15]
// Let p_right = [u8×16: p16, p17, p18, ..., p31]
p__left = util.make_u8x16_slice128(a: p[.. 16])
p_right = util.make_u8x16_slice128(a: p[16 .. 32])
// For v2j, we need to calculate the sums of the s1j terms for each
// of p's 32 elements. This is simply 32 times the same number,
// that number being the sum of v1's four u32 accumulators. We add
// v1 now and multiply by 32 later, outside the inner loop.
//
// This must happen before v1 is updated below: it needs v1 as it stood
// after the previous iterations only.
v2 = v2.vaddq_u32(b: v1)
// For v1, we need to add the elements of p.
//
// p__left.vpaddlq_u8() is:
//   [u16×8: p00 + p01, p02 + p03, ..., p14 + p15]
//
// Combining (vpadalq_u8) that with p_right gives:
//   [u16×8: p00 + p01 + p16 + p17,
//           p02 + p03 + p18 + p19,
//           ...
//           p14 + p15 + p30 + p31]
//
// Pair-wise summing (and widening) again sets v1 to:
//   [u32×4: Σ (p00 + p01 + p02 + p03 + p16 + p17 + p18 + p19),
//           Σ (p04 + p05 + p06 + p07 + p20 + p21 + p22 + p23),
//           ...
//           Σ (p12 + p13 + p14 + p15 + p28 + p29 + p30 + p31)]
v1 = v1.vpadalq_u16(b: p__left.vpaddlq_u8().vpadalq_u8(b: p_right))
// Accumulate column sums:
//   col0 = [u16×8: Σ p00, Σ p01, ..., Σ p07]
//   col1 = [u16×8: Σ p08, Σ p09, ..., Σ p15]
//   col2 = [u16×8: Σ p16, Σ p17, ..., Σ p23]
//   col3 = [u16×8: Σ p24, Σ p25, ..., Σ p31]
col0 = col0.vaddw_u8(b: p__left.vget_low_u8())
col1 = col1.vaddw_u8(b: p__left.vget_high_u8())
col2 = col2.vaddw_u8(b: p_right.vget_low_u8())
col3 = col3.vaddw_u8(b: p_right.vget_high_u8())
}
// Multiply v2j's four u32 elements each by 32, alluded to earlier.
// A left shift by 5 is a multiplication by 32.
v2 = v2.vshlq_n_u32(b: 5)
// Add the v2k contributions in eight u32×4 vectors:
//   [u32×4: 0x20 * (Σ p00), 0x1F * (Σ p01), ..., 0x1D * (Σ p03)]
//   [u32×4: 0x1C * (Σ p04), 0x1B * (Σ p05), ..., 0x19 * (Σ p07)]
//   ...
//   [u32×4: 0x04 * (Σ p28), 0x03 * (Σ p29), ..., 0x01 * (Σ p31)]
//
// The weight for byte position i (0 ..= 31) is (32 - i): the number of
// 1-byte steps, within its own 32-byte iteration, at which that byte is
// part of s1 and therefore gets added to s2.
v2 = v2.vmlal_u16(
b: col0.vget_low_u16(),
c: util.make_u16x4_multiple(a00: 0x20, a01: 0x1F, a02: 0x1E, a03: 0x1D))
v2 = v2.vmlal_u16(
b: col0.vget_high_u16(),
c: util.make_u16x4_multiple(a00: 0x1C, a01: 0x1B, a02: 0x1A, a03: 0x19))
v2 = v2.vmlal_u16(
b: col1.vget_low_u16(),
c: util.make_u16x4_multiple(a00: 0x18, a01: 0x17, a02: 0x16, a03: 0x15))
v2 = v2.vmlal_u16(
b: col1.vget_high_u16(),
c: util.make_u16x4_multiple(a00: 0x14, a01: 0x13, a02: 0x12, a03: 0x11))
v2 = v2.vmlal_u16(
b: col2.vget_low_u16(),
c: util.make_u16x4_multiple(a00: 0x10, a01: 0x0F, a02: 0x0E, a03: 0x0D))
v2 = v2.vmlal_u16(
b: col2.vget_high_u16(),
c: util.make_u16x4_multiple(a00: 0x0C, a01: 0x0B, a02: 0x0A, a03: 0x09))
v2 = v2.vmlal_u16(
b: col3.vget_low_u16(),
c: util.make_u16x4_multiple(a00: 0x08, a01: 0x07, a02: 0x06, a03: 0x05))
v2 = v2.vmlal_u16(
b: col3.vget_high_u16(),
c: util.make_u16x4_multiple(a00: 0x04, a01: 0x03, a02: 0x02, a03: 0x01))
// Merge the four parallel u32 sums (v1) into the single u32 sum (s1)
// and ditto for v2. Starting with [u32×4: xx_0, xx_1, xx_2, xx_3]:
//   sum1  = [u32×2: v1_0 + v1_1, v1_2 + v1_3]
//   sum2  = [u32×2: v2_0 + v2_1, v2_2 + v2_3]
//   sum12 = [u32×2: v1_0 + v1_1 + v1_2 + v1_3,
//                   v2_0 + v2_1 + v2_2 + v2_3]
sum1 = v1.vget_low_u32().vpadd_u32(b: v1.vget_high_u32())
sum2 = v2.vget_low_u32().vpadd_u32(b: v2.vget_high_u32())
sum12 = sum1.vpadd_u32(b: sum2)
// Lane 0 carries the s1 contribution and lane 1 the s2 contribution.
s1 ~mod+= sum12.vget_lane_u32(b: 0)
s2 ~mod+= sum12.vget_lane_u32(b: 1)
// Handle the tail of args.x that wasn't a complete 32-byte chunk.
tail_index = args.x.length() & 0xFFFF_FFFF_FFFF_FFE0 // Round down to a multiple of 32 (and-not 31).
if tail_index < args.x.length() {
// Fewer than 32 bytes are left; fold them in with the scalar recurrence.
iterate (p = args.x[tail_index ..])(length: 1, advance: 1, unroll: 1) {
s1 ~mod+= p[0] as base.u32
s2 ~mod+= s1
}
}
// The rest of this function is the same as the non-SIMD version.
s1 %= 65521
s2 %= 65521
// Move on to the bytes beyond this chunk (empty if this was the last one).
args.x = remaining
} endwhile
// Re-pack the two sums: s2 in the high 16 bits, s1 in the low 16 bits.
this.state = ((s2 & 0xFFFF) << 16) | (s1 & 0xFFFF)
}