| // SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense |
| |
| // The main phase for the stack monoid |
| |
| #version 450 |
| |
// Compile-time sizing. The two tunables are given as base-2 logarithms:
//   LG_N_SEQ   - log2 of the number of input elements each invocation
//                processes sequentially.
//   LG_WG_SIZE - log2 of the number of invocations in a workgroup.
#define LG_N_SEQ 0
#define LG_WG_SIZE 9

// Derived sizes.
#define N_SEQ (1 << LG_N_SEQ)
#define WG_SIZE (1 << LG_WG_SIZE)
// Number of input elements covered by one workgroup (one partition).
#define PART_SIZE (WG_SIZE * N_SEQ)

// One-dimensional workgroup of WG_SIZE invocations.
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
| |
// Element of the bicyclic monoid, used to summarize a run of stack operations.
// In this shader an input value of 1 contributes Bic(0, 1) and an input value
// of 0 contributes Bic(1, 0); Bic(0, 0) is the identity for bic_combine.
struct Bic {
    // Count contributed by 0-inputs that found nothing to cancel against
    // earlier in the run.
    uint a;
    // Count contributed by 1-inputs that were not cancelled by a later 0-input
    // in the run.
    uint b;
};
| |
// Monoid product of two Bic values, with x on the left and y on the right.
// As many of x.b and y.a as possible cancel pairwise; whatever remains on
// either side is carried into the result.
Bic bic_combine(Bic x, Bic y) {
    uint cancelled = min(x.b, y.a);
    Bic result;
    result.a = x.a + y.a - cancelled;
    result.b = x.b + y.b - cancelled;
    return result;
}
| |
// Input sequence, one uint per element. main() turns each value v into
// Bic(1 - v, v) and branches on v == 1 / v == 0.
// NOTE(review): presumably only the values 0 and 1 occur - confirm with the host code.
layout(binding = 0) readonly buffer InBuf {
    uint inbuf[];
};

// Bic values indexed by partition number; main() only reads entries whose
// index is below gl_WorkGroupID.x.
// NOTE(review): presumably written by an earlier reduction pass - confirm.
layout(binding = 1) readonly buffer BicBuf {
    Bic bicbuf[];
};

// Stack contents, read at offsets of the form partition * PART_SIZE + slot.
layout(binding = 2) readonly buffer StackBuf {
    uint stack[];
};

// Result buffer; main() writes one uint per input element.
layout(binding = 3) buffer OutBuf {
    uint outbuf[];
};
| |
// Workgroup-shared scratch space.
//
// sh_bic is used twice by main(): first, entries 0..WG_SIZE-1 hold a reverse
// scan over per-partition Bic values; afterwards the whole array holds a
// reduction tree with the leaves at 0..WG_SIZE-1 and each higher level stored
// after the previous one.
shared Bic sh_bic[2 * WG_SIZE - 2];
// Stack values materialized for this workgroup (and, for N_SEQ > 1,
// temporarily the per-partition b counts used by the binary search).
shared uint sh_stack[PART_SIZE];
#if N_SEQ > 1
// Not referenced by the code visible in this file.
shared uint sh_next[WG_SIZE];
// Per-invocation link values: first a forward max-scan of links to nonempty
// bitmaps, later the result of search_link for each invocation.
shared uint sh_link[WG_SIZE];
// Per-invocation bitmap with one bit per element of the invocation's chunk.
shared uint sh_bitmaps[WG_SIZE];
#endif
| |
// Find the position of the ix'th (0-based) set bit in bitmask.
// More precisely: returns the largest i in 0..31 such that
// bitmask & ((1 << i) - 1) has at most ix bits set. When bitmask has more
// than ix bits set, that i is the position of the ix'th set bit.
uint search_bit_set(uint bitmask, uint ix) {
    uint pos = 0;
    // Binary search over bit positions with strides 16, 8, 4, 2, 1.
    // TODO: instead of a fixed 5 steps, derive the step count from LG_N_SEQ
    for (uint stride = 16u; stride != 0u; stride >>= 1) {
        uint candidate = pos + stride;
        uint bits_below = bitmask & ((1u << candidate) - 1);
        // Advance only while the prefix below `candidate` still holds no more
        // than ix set bits.
        if (bitCount(bits_below) <= ix) {
            pos = candidate;
        }
    }
    return pos;
}
| |
// Search the reduction tree in sh_bic for the predecessor node of the calling
// invocation. `bic` is the reduction accumulated so far; it is extended
// leftwards through the tree and updated in place.
//
// For N_SEQ == 1:
// The return value is the smallest ix such that the reduction of the leaves
// from ix up to this invocation has b == 0. On return, bic holds that
// reduction (and may have started with a non-default value).
//
// For N_SEQ > 1:
// If the return value, read as a signed integer, is nonnegative, it is the
// index of an element within this workgroup's partition. If it is negative,
// it is an index into the stack (encoded as ~0u - bic.a).
//
// Note: inout is not needed for N_SEQ > 1
uint search_link(inout Bic bic) {
    uint ix = gl_LocalInvocationID.x;
    uint level = 0;

    // Climb the tree. Whenever bit `level` of ix is set, the tree node just to
    // the left of the current position is tried; it is absorbed into bic as
    // long as the combined b stays zero. The climb stops at the first node
    // that would make b positive.
    while (level < LG_WG_SIZE) {
        if (((ix >> level) & 1) != 0) {
            // Offset of the first node of this tree level inside sh_bic.
            uint level_base = 2 * WG_SIZE - (2u << (LG_WG_SIZE - level));
            Bic merged = bic_combine(sh_bic[level_base + (ix >> level) - 1], bic);
            if (merged.b > 0) {
                break;
            }
            bic = merged;
            ix -= 1u << level;
        }
        level++;
    }

    // Descend again from the level where the climb stopped, absorbing every
    // node to the left that still keeps b at zero.
    if (ix > 0) {
        while (level > 0) {
            level--;
            uint level_base = 2 * WG_SIZE - (2u << (LG_WG_SIZE - level));
            Bic merged = bic_combine(sh_bic[level_base + (ix >> level) - 1], bic);
            if (merged.b == 0) {
                bic = merged;
                ix -= 1u << level;
            }
        }
    }
    // ix is now the smallest value such that reduce(ix..th).b == 0

#if N_SEQ > 1
    if (ix == 0) {
        // Nothing in this workgroup qualifies: return a stack index.
        return ~0u - bic.a;
    }
    // The predecessor lies inside the chunk of invocation ix - 1; pick the
    // element within that chunk from its bitmap.
    uint chunk = ix - 1;
    Bic with_chunk = bic_combine(sh_bic[chunk], bic);
    uint ix_in_chunk = search_bit_set(sh_bitmaps[chunk], with_chunk.b - 1);
    return chunk * N_SEQ + ix_in_chunk;
#else
    return ix;
#endif
}
| |
// Shader entry point. Each workgroup handles one partition of PART_SIZE input
// elements, in four stages:
//   1. Reverse-scan the Bic values of the partitions before this one and use
//      the result to copy the relevant stack entries into sh_stack.
//   2. Build a reduction tree over this partition's own inputs in sh_bic.
//   3. Search that tree (search_link) for each invocation's predecessor.
//   4. Write one output value per input element to outbuf.
void main() {
    uint th = gl_LocalInvocationID.x;
    // materialize stack up to start of this partition
    // start with reverse scan of bicyclic semigroup
    // Only partitions with an index below this workgroup's index contribute;
    // all other slots keep the identity Bic(0, 0).
    Bic bic = Bic(0, 0);
    if (th * N_SEQ < gl_WorkGroupID.x) {
        bic = bicbuf[th * N_SEQ];
    }
    for (uint i = 1; i < N_SEQ; i++) {
        if (th * N_SEQ + i < gl_WorkGroupID.x) {
            Bic other = bicbuf[th * N_SEQ + i];
            bic = bic_combine(bic, other);
        }
    }
    sh_bic[th] = bic;
    // Doubling-stride suffix scan: afterwards sh_bic[th] combines the entries
    // of invocations th..WG_SIZE-1. The barrier pairs separate each round's
    // reads from its writes.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (th + (1u << i) < WG_SIZE) {
            Bic other = sh_bic[th + (1u << i)];
            bic = bic_combine(bic, other);
        }
        barrier();
        sh_bic[th] = bic;
    }
    barrier();

#if N_SEQ > 1
    // Start from the exclusive suffix (everything after this invocation's
    // chunk) and walk this chunk's partitions from last to first.
    if (th == WG_SIZE - 1) {
        bic = Bic(0, 0);
    } else {
        bic = sh_bic[th + 1];
    }
    uint last_b = bic.b;
    uint bitmap = 0;
    for (uint i = 0; i < N_SEQ; i++) {
        uint this_ix = th * N_SEQ + N_SEQ - 1 - i;
        if (this_ix < gl_WorkGroupID.x) {
            bic = bic_combine(bicbuf[this_ix], bic);
        }
        // sh_stack temporarily holds the suffix b count per partition.
        sh_stack[this_ix] = bic.b;
        // Mark partitions that raise the b count, i.e. contribute entries.
        if (bic.b > last_b) {
            bitmap |= 1u << (N_SEQ - 1 - i);
        }
        last_b = bic.b;
    }
    sh_bitmaps[th] = bitmap;

    // forward scan links to nonempty bitmaps
    uint link = 0;
    if (bitmap != 0) {
        link = th * N_SEQ + findMSB(bitmap);
    }
    sh_link[th] = link;
    // Inclusive max-scan so that sh_link[th] is the highest marked partition
    // at or before this invocation's chunk.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (th >= (1u << i)) {
            link = max(link, sh_link[th - (1u << i)]);
        }
        barrier();
        sh_link[th] = link;
    }
    barrier();

    // binary search in stack
    uint sp = PART_SIZE - N_SEQ - th * N_SEQ;
    uint ix = 0;
    for (uint i = 0; i < LG_WG_SIZE + LG_N_SEQ; i++) {
        uint probe = ix + (uint(PART_SIZE / 2) >> i);
        if (sp < sh_stack[probe]) {
            ix = probe;
        }
    }
    // ix is the largest value such that sp < sh_stack[ix] (if any)
    uint b = sh_stack[ix];
    uint local_stack[N_SEQ];
    for (uint i = 0; i < N_SEQ; i++) {
        // Probably not really necessary, but avoid UB
        local_stack[i] = 0;
    }
    // Copy stack values sequentially
    uint i = 0;
    while (sp + i < b) {
        local_stack[N_SEQ - 1 - i] = stack[ix * PART_SIZE + b - (sp + i) - 1];
        i++;
        if (i == N_SEQ) {
            break;
        }
        if (sp + i == b) {
            // find previous nonempty slice
            uint bits = sh_bitmaps[ix / N_SEQ] & ((1u << (ix % N_SEQ)) - 1);
            if (bits == 0) {
                ix = sh_link[max(ix / N_SEQ, 1) - 1];
            } else {
                ix = (ix & -N_SEQ) + findMSB(bits);
            }
            b = sh_stack[ix];
        }
    }
    // All reads of the temporary b counts in sh_stack must finish before the
    // array is overwritten with the actual stack values.
    barrier();
    for (uint i = 0; i < N_SEQ; i++) {
        sh_stack[th * N_SEQ + i] = local_stack[i];
    }
#else
    // binary search in stack
    // sp is the stack slot this invocation materializes; invocation th fills
    // sh_stack[th] with the entry at sp = PART_SIZE - 1 - th.
    uint sp = PART_SIZE - 1 - th;
    uint ix = 0;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        uint probe = ix + (uint(PART_SIZE / 2) >> i);
        if (sp < sh_bic[probe].b) {
            ix = probe;
        }
    }
    // ix is the largest value such that sp < sh_bic[ix].b (if any)
    uint b = sh_bic[ix].b;
    if (sp < b) {
        sh_stack[th] = stack[ix * PART_SIZE + b - sp - 1];
    }
#endif
    // sh_bic is about to be reused for the reduction tree, and sh_stack must
    // be complete before the output stage reads it.
    barrier();

    // Do tree reduction of bicyclic semigroups (up-sweep)
    // Each invocation first reduces its own N_SEQ inputs, last element first.
    uint inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + N_SEQ - 1];
    bic = Bic(1 - inp, inp);
#if N_SEQ > 1
    // bitmap marks the 1-inputs of this chunk that nothing later in the
    // chunk cancels (bic.a == 0 when they are reached).
    bitmap = inp << (N_SEQ - 1);
    for (uint i = N_SEQ - 1; i > 0; i--) {
        inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + i - 1];
        if (inp == 1 && bic.a == 0) {
            bitmap |= 1u << (i - 1);
        }
        Bic other = Bic(1 - inp, inp);
        bic = bic_combine(other, bic);
    }
    sh_bitmaps[th] = bitmap;
#endif
    sh_bic[th] = bic;
    // Build the tree level by level: level i + 1 starts at outbase and each
    // of its nodes combines two adjacent nodes of the level at inbase.
    uint inbase = 0;
    for (uint i = 0; i < LG_WG_SIZE - 1; i++) {
        uint outbase = 2 * WG_SIZE - (1u << (LG_WG_SIZE - i));
        barrier();
        if (th < (1u << (LG_WG_SIZE - 1 - i))) {
            sh_bic[outbase + th] = bic_combine(sh_bic[inbase + th * 2], sh_bic[inbase + th * 2 + 1]);
        }
        inbase = outbase;
    }
    barrier();

    // Search for predecessor node.
#if N_SEQ > 1
    bic.b = 0;
    // search for predecessor of first unmatched open paren in this block
    sh_link[th] = search_link(bic);
    // The output loop below reads sh_link entries written by other
    // invocations, so every write must be visible before any read happens.
    barrier();
#endif
    // search for predecessor of first character in this block
    bic = Bic(0, 0);
    ix = search_link(bic);

    // Generate output
    uint outp;
#if N_SEQ > 1
    // Process the chunk sequentially with a small private stack; fall back to
    // the link found above (and then the backlinks) once it runs empty.
    uint loc_stack[N_SEQ];
    uint loc_sp = 0;
    for (uint i = 0; i < N_SEQ; i++) {
        if (loc_sp > 0) {
            outp = loc_stack[loc_sp - 1];
        } else if (int(ix) >= 0) {
            // Nonnegative: element index within this partition.
            outp = gl_WorkGroupID.x * PART_SIZE + ix;
        } else {
            // Negative: index into the materialized stack.
            outp = sh_stack[PART_SIZE + ix];
        }
        outbuf[gl_GlobalInvocationID.x * N_SEQ + i] = outp;
        // store in local memory instead?
        inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + i];
        if (inp == 1) {
            loc_stack[loc_sp] = gl_GlobalInvocationID.x * N_SEQ + i;
            loc_sp++;
        } else if (inp == 0) {
            if (loc_sp > 0) {
                loc_sp--;
            } else {
                // backlink logic
                if (int(ix) >= 0) {
                    uint bits = sh_bitmaps[ix / N_SEQ] & ((1u << (ix % N_SEQ)) - 1);
                    if (bits == 0) {
                        ix = sh_link[ix / N_SEQ];
                    } else {
                        ix = (ix & -N_SEQ) + findMSB(bits);
                    }
                } else {
                    ix--;
                }
            }
        }
    }
#else
    if (ix > 0) {
        // Predecessor is the element just before ix in this partition.
        outp = gl_WorkGroupID.x * PART_SIZE + ix - 1;
    } else {
        // No predecessor inside this partition: take it from the materialized
        // stack, bic.a slots down from the top entry at PART_SIZE - 1.
        outp = sh_stack[PART_SIZE - 1 - bic.a];
    }
    outbuf[gl_GlobalInvocationID.x] = outp;
#endif
}