| // SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense |
| |
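// The reduction phase for the stack monoid: each workgroup reduces its
// partition of the input to one bicyclic-monoid element (written to outbuf)
// and writes the indices of the partition's unmatched pushes to the stack buffer.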
| |
| #version 450 |
| |
// Number of input elements each invocation loads and folds sequentially.
#define N_SEQ 1
// log2 of the workgroup size; the suffix scan in main() runs LG_WG_SIZE steps.
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
// Input elements covered by one workgroup (one partition). This is also the
// size of each workgroup's region in the stack buffer.
#define PART_SIZE (N_SEQ * WG_SIZE)

// One-dimensional workgroup; local_size_y and local_size_z default to 1.
layout(local_size_x = WG_SIZE) in;
| |
// An element of the bicyclic monoid: a summary of a run of stack operations
// after every matching push/pop pair inside the run has cancelled.
// In this shader an input of 1 maps to Bic(0, 1) (a push) and an input of 0
// maps to Bic(1, 0) (a pop).
//   a: pops left unmatched within the run
//   b: pushes left unmatched within the run
struct Bic {
    uint a, b;
};
| |
// Concatenate two bicyclic elements: the run summarized by x followed by the
// run summarized by y. The pushes x leaves open (x.b) are closed by the pops
// y brings (y.a). The smaller of the two counts cancels out of both sides.
// The operation is associative with identity Bic(0, 0).
Bic bic_combine(Bic x, Bic y) {
    uint matched = min(x.b, y.a);
    Bic r;
    r.a = x.a + (y.a - matched);
    r.b = y.b + (x.b - matched);
    return r;
}
| |
// Input sequence, one value per element; main() treats 1 as a push and 0 as
// a pop. NOTE(review): values other than 0/1 wrap in uint arithmetic when
// converted to Bic. Confirm that callers only upload 0/1.
layout(binding = 0) readonly buffer InBuf {
    uint inbuf[];
};

// One reduced Bic per workgroup, indexed by gl_WorkGroupID.x.
layout(binding = 1) buffer OutBuf {
    Bic outbuf[];
};

// Per-workgroup stack output. The region [wg * PART_SIZE, (wg + 1) * PART_SIZE)
// receives, in input order, the global indices of that partition's unmatched
// pushes. Slots past the unmatched count are left unwritten.
layout(binding = 2) buffer StackBuf {
    uint stack[];
};

// Workgroup-shared scratch for the suffix scan, one entry per invocation.
shared Bic sh_bic[WG_SIZE];
| |
// Bicyclic value of a single input element: 1 -> Bic(0, 1), 0 -> Bic(1, 0).
// Assumes elem is 0 or 1; other values wrap in uint arithmetic.
Bic bic_from_elem(uint elem) {
    return Bic(1 - elem, elem);
}

// Reduce one partition of PART_SIZE inputs to a single Bic, which is written
// to outbuf. Also write the global indices of the partition's unmatched pushes
// to the partition's region of the stack buffer, bottom of the stack first.
void main() {
    uint th = gl_LocalInvocationID.x;
    uint first_ix = gl_GlobalInvocationID.x * N_SEQ;

    // Load this invocation's N_SEQ elements and fold them in input order,
    // starting from the identity Bic(0, 0).
    uint elems[N_SEQ];
    Bic agg = Bic(0, 0);
    for (uint j = 0; j < N_SEQ; j++) {
        elems[j] = inbuf[first_ix + j];
        agg = bic_combine(agg, bic_from_elem(elems[j]));
    }

    // Inclusive suffix scan over the workgroup (Hillis-Steele style). After
    // the step with stride s, sh_bic[th] combines invocations th through
    // min(th + 2s, WG_SIZE) - 1. The two barriers keep every read of a step
    // ahead of that step's writes.
    sh_bic[th] = agg;
    for (uint k = 0; k < LG_WG_SIZE; k++) {
        barrier();
        uint stride = 1u << k;
        if (th + stride < WG_SIZE) {
            agg = bic_combine(agg, sh_bic[th + stride]);
        }
        barrier();
        sh_bic[th] = agg;
    }

    // Invocation 0's inclusive suffix is the reduction of the whole partition.
    if (th == 0) {
        outbuf[gl_WorkGroupID.x] = agg;
    }
    barrier();

    // n_open: pushes left unmatched by the whole partition.
    // suffix: exclusive suffix, i.e. everything after this invocation's elements.
    uint n_open = sh_bic[0].b;
    Bic suffix = Bic(0, 0);
    if (th + 1 < WG_SIZE) {
        suffix = sh_bic[th + 1];
    }

    // Stream compaction: walk this invocation's elements from last to first,
    // prepending each one to the suffix. A push survives iff nothing after it
    // has an unmatched pop (suffix.a == 0). It takes offset
    // n_open - (surviving pushes after it) - 1 within the partition's region,
    // so slots are handed out counting down.
    uint slot = gl_WorkGroupID.x * PART_SIZE + n_open - suffix.b;
    for (uint r = 0; r < N_SEQ; r++) {
        uint j = N_SEQ - 1u - r;
        uint elem = elems[j];
        if (elem == 1 && suffix.a == 0) {
            slot--;
            stack[slot] = first_ix + j;
        }
        suffix = bic_combine(bic_from_elem(elem), suffix);
    }
}