blob: 71da8b68a3063267ca9e4edbb790c92793ab979a [file] [log] [blame]
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// The reduction phase for the stack monoid
#version 450
// Number of consecutive input elements each invocation processes serially.
#define N_SEQ 1
// log2 of the workgroup size; the workgroup scan in main() runs this many rounds.
#define LG_WG_SIZE 9
// Invocations per workgroup.
#define WG_SIZE (1 << LG_WG_SIZE)
// Input elements covered by one workgroup (one "partition").
#define PART_SIZE (N_SEQ * WG_SIZE)

// One-dimensional workgroup; local_size_y and local_size_z default to 1.
layout(local_size_x = WG_SIZE) in;
// The bicyclic monoid. A value summarizes a sequence of inputs as `a`
// unmatched closing elements (input value 0) followed by `b` unmatched
// opening elements (input value 1). Adjacent summaries cancel the left
// operand's trailing opens against the right operand's leading closes
// (see bic_combine).
struct Bic {
    uint a, b;
};
// Monoid operation: the summary of sequence x followed by sequence y.
// x's trailing unmatched opens (x.b) pair off with y's leading unmatched
// closes (y.a); the smaller of the two counts is consumed entirely, and
// whatever is left over survives in the result.
// Identity element: Bic(0, 0).
Bic bic_combine(Bic x, Bic y) {
    uint cancelled = min(x.b, y.a);
    uint closes = x.a + (y.a - cancelled);
    uint opens = (x.b - cancelled) + y.b;
    return Bic(closes, opens);
}
// Input: one uint per element. main() maps a value v to Bic(1 - v, v) and
// records elements with v == 1 into `stack` when they stay unmatched.
// NOTE(review): presumably every value is 0 or 1 -- any other value wraps
// in `1 - v`; confirm against the producer of this buffer.
layout(binding = 0) readonly buffer InBuf {
uint[] inbuf;
};
// Output: one Bic per workgroup, the reduction of that workgroup's
// PART_SIZE input elements (written by invocation 0 of each workgroup).
layout(binding = 1) buffer OutBuf {
Bic[] outbuf;
};
// Output: for workgroup w, slots [w * PART_SIZE, w * PART_SIZE + size) receive
// the global input indices of the partition's unmatched 1-elements, in input
// order (size = outbuf[w].b). Slots beyond that range are not written here.
layout(binding = 2) buffer StackBuf {
uint[] stack;
};
// Workgroup scratch for the suffix scan in main(): one Bic per invocation.
shared Bic sh_bic[WG_SIZE];
// Per workgroup: reduce PART_SIZE inputs with the bicyclic monoid, publish the
// reduction to outbuf[gl_WorkGroupID.x], and compact the global indices of the
// partition's unmatched 1-elements into this workgroup's region of `stack`.
void main() {
uint inp[N_SEQ];
inp[0] = inbuf[gl_GlobalInvocationID.x * N_SEQ];
// Sequential (in-order) reduction of this invocation's N_SEQ inputs into one
// Bic; input value v maps to Bic(1 - v, v).
Bic bic = Bic(1 - inp[0], inp[0]);
for (uint i = 1; i < N_SEQ; i++) {
inp[i] = inbuf[gl_GlobalInvocationID.x * N_SEQ + i];
Bic other = Bic(1 - inp[i], inp[i]);
bic = bic_combine(bic, other);
}
// Inclusive suffix (reverse) scan across the workgroup, log-step (Hillis-Steele
// style): in round i each invocation x combines its value with that of x + 2^i.
// After LG_WG_SIZE rounds, bic and sh_bic[x] hold the combination of the
// per-invocation reductions of invocations x .. WG_SIZE-1.
sh_bic[gl_LocalInvocationID.x] = bic;
for (uint i = 0; i < LG_WG_SIZE; i++) {
// Make the previous round's writes visible before reading a neighbor.
barrier();
if (gl_LocalInvocationID.x + (1u << i) < WG_SIZE) {
Bic other = sh_bic[gl_LocalInvocationID.x + (1u << i)];
bic = bic_combine(bic, other);
}
// Every neighbor read must finish before any slot is overwritten.
barrier();
sh_bic[gl_LocalInvocationID.x] = bic;
}
// Invocation 0's suffix covers the whole partition: that is the reduction.
if (gl_LocalInvocationID.x == 0) {
outbuf[gl_WorkGroupID.x] = bic;
}
barrier();
// Number of unmatched 1-elements in the whole partition.
uint size = sh_bic[0].b;
// Exclusive suffix: the combination of invocations strictly after this one.
// The last invocation has none, so it starts from the identity Bic(0, 0).
bic = Bic(0, 0);
if (gl_LocalInvocationID.x + 1 < WG_SIZE) {
bic = sh_bic[gl_LocalInvocationID.x + 1];
}
// Stream compaction based on the exclusive suffix scan. The unmatched
// 1-elements of later invocations occupy the top bic.b slots of this
// partition's [0, size) range, so this invocation's survivors are placed
// directly below them while walking its inputs right-to-left.
uint out_ix = gl_WorkGroupID.x * PART_SIZE + size - bic.b;
for (uint i = N_SEQ; i > 0; i--) {
// A 1-element stays unmatched iff no unmatched 0-elements follow it in the
// partition, i.e. the suffix to its right has a == 0.
if (inp[i - 1] == 1 && bic.a == 0) {
out_ix--;
// Record the global input index of this element.
stack[out_ix] = gl_GlobalInvocationID.x * N_SEQ + i - 1;
}
// Prepend this element to the running suffix.
Bic other = Bic(1 - inp[i - 1], inp[i - 1]);
bic = bic_combine(other, bic);
}
}