blob: 23bf47ab089d4945f16dfc95888ea5ab30bbe5ce [file] [log] [blame]
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// The main phase for the stack monoid
//
// Each invocation processes N_SEQ consecutive input elements; a workgroup of
// WG_SIZE invocations therefore covers PART_SIZE elements (one partition).
// The LG_* values are log2 of the corresponding sizes.
#version 450
#define LG_N_SEQ 0
#define N_SEQ (1 << LG_N_SEQ)
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PART_SIZE (WG_SIZE * N_SEQ)
// NOTE(review): per-chunk bitmaps of N_SEQ bits are stored in a uint, so
// N_SEQ must not exceed 32.
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
// The bicyclic monoid
// A value summarizes a segment of the input after matched push/pop pairs
// have cancelled (see bic_combine):
//   a - unmatched pops (input value 0) left in the segment
//   b - unmatched pushes (input value 1) left in the segment
// main builds single-element values as Bic(1 - inp, inp).
struct Bic {
uint a;
uint b;
};
// Monoid operation: the segment x followed by the segment y.
// Pushes left unmatched in x (x.b) are cancelled by pops left unmatched in
// y (y.a); whatever remains on either side carries through to the result.
Bic bic_combine(Bic x, Bic y) {
    uint matched = min(x.b, y.a);
    Bic result;
    result.a = x.a + (y.a - matched);
    result.b = y.b + (x.b - matched);
    return result;
}
// Input elements, N_SEQ per invocation. main turns each value inp into
// Bic(1 - inp, inp): 1 is a push and 0 is a pop.
layout(binding = 0) readonly buffer InBuf {
uint[] inbuf;
};
// One Bic per partition. main reads only the entries for partitions that
// precede the current workgroup (index < gl_WorkGroupID.x).
// NOTE(review): presumably written by an earlier reduction pass — confirm.
layout(binding = 1) readonly buffer BicBuf {
Bic[] bicbuf;
};
// Stack contents, indexed by main as partition * PART_SIZE + offset.
// NOTE(review): presumably each partition's slice holds the entries it leaves
// on the stack, written by an earlier pass — confirm.
layout(binding = 2) readonly buffer StackBuf {
uint[] stack;
};
// One output per input element: the predecessor index computed by main.
layout(binding = 3) buffer OutBuf {
uint[] outbuf;
};
// Bic scratch. It is first used for a suffix scan over per-partition
// reductions. It then holds this partition's reduction tree: WG_SIZE leaves
// followed by the WG_SIZE - 2 entries of the upper levels (the root is not
// stored).
shared Bic sh_bic[WG_SIZE * 2 - 2];
// With N_SEQ > 1, this first holds per-partition stack depths. Afterwards it
// holds the stack entries preceding this partition; the output phase reads
// them as sh_stack[PART_SIZE - 1 - depth].
shared uint sh_stack[PART_SIZE];
#if N_SEQ > 1
// NOTE(review): sh_next is not referenced anywhere in the code shown.
shared uint sh_next[WG_SIZE];
// Per-invocation link indices: a prefix-max scan during stack
// materialization, later overwritten with search_link results for backlinks.
shared uint sh_link[WG_SIZE];
// Per-invocation N_SEQ-bit bitmaps (bit k describes element k of the chunk).
shared uint sh_bitmaps[WG_SIZE];
#endif
// search for the ix'th set bit in bitmask.
// More formally, the largest i such that bitmask & ((1 << i) - 1) has at
// most ix bits set. When bitmask has more than ix bits set, this is the
// position of the ix'th (0-based, counting up from bit 0) set bit.
//
// This is a binary search over bit positions. Bitmaps in this shader only
// use bits 0..N_SEQ-1, so LG_N_SEQ halving steps suffice (step N_SEQ / 2 down
// to 1) rather than always searching a full 32-bit word. The result therefore
// stays within [0, N_SEQ); if bitmask has ix or fewer bits set, it is N_SEQ - 1.
uint search_bit_set(uint bitmask, uint ix) {
    uint result = 0;
    for (uint step = uint(N_SEQ) >> 1; step > 0; step >>= 1) {
        // Advance past the first (result + step) positions if they contain
        // at most ix set bits. result + step <= N_SEQ - 1 <= 31, so the
        // shift is always in range.
        if (bitCount(bitmask & ((1u << (result + step)) - 1u)) <= ix) {
            result += step;
        }
    }
    return result;
}
// Search for predecessor node.
// For N_SEQ == 1:
// On return, ix is the smallest value such that reduce(ix..th).b == 0
// bic contains the reduction (and may start with a non-default value)
//
// For N_SEQ > 1:
// If return value is nonnegative, then it is an index to an element
// within this workgroup's partition. If it is negative, then it is an
// index into the stack.
//
// Note: inout is not needed for N_SEQ > 1
//
// Here reduce(ix..th) means the sh_bic level-0 entries ix..th-1 combined
// left to right, with the incoming bic appended on the right.
// The search reads the reduction tree that main builds in sh_bic; level j
// starts at offset 2 * WG_SIZE - (2 << (LG_WG_SIZE - j)).
// It first climbs the tree, absorbing each left-sibling subtree while the
// combination still has b == 0. It then descends to refine ix one level at a
// time.
// "Negative" means the returned uint reinterpreted as int. That value is
// ~0u - bic.a, and callers index sh_stack with PART_SIZE + result.
uint search_link(inout Bic bic) {
uint ix = gl_LocalInvocationID.x;
uint j = 0;
// Ascend: at level j, if ix is a right child, try absorbing the subtree to
// its left (tree node (ix >> j) - 1 at level j).
while (j < LG_WG_SIZE) {
uint base = 2 * WG_SIZE - (2u << (LG_WG_SIZE - j));
if (((ix >> j) & 1) != 0) {
Bic test = bic_combine(sh_bic[base + (ix >> j) - 1], bic);
if (test.b > 0) {
// That subtree leaves an unmatched push; the answer lies inside it.
break;
}
bic = test;
ix -= 1u << j;
}
j++;
}
// Descend: at each lower level, absorb the subtree immediately left of ix
// whenever the combination keeps b == 0.
if (ix > 0) {
while (j > 0) {
j--;
uint base = 2 * WG_SIZE - (2u << (LG_WG_SIZE - j));
Bic test = bic_combine(sh_bic[base + (ix >> j) - 1], bic);
if (test.b == 0) {
bic = test;
ix -= 1u << j;
}
}
}
// ix is the smallest value such that reduce(ix..th).b == 0
#if N_SEQ > 1
if (ix > 0) {
// The predecessor lies in chunk ix - 1. After cancelling against bic,
// test.b unmatched pushes of that chunk remain. The result is the
// position of the (test.b - 1)'th set bit of its bitmap, counting from
// bit 0.
ix--;
Bic test = bic_combine(sh_bic[ix], bic);
uint ix_in_chunk = search_bit_set(sh_bitmaps[ix], test.b - 1);
return ix * N_SEQ + ix_in_chunk;
} else {
// No element of this partition qualifies: return a "negative" index that
// callers resolve against the materialized stack in sh_stack.
return ~0u - bic.a;
}
#else
return ix;
#endif
}
// Entry point. Each invocation handles N_SEQ consecutive input elements and
// writes one outbuf entry per element: the predecessor index found via
// search_link, i.e. the top of the stack just before that element.
// Phases:
//   1. Materialize the stack entries that precede this partition into
//      sh_stack, using the per-partition reductions (bicbuf) and the stack
//      slices (stack).
//   2. Build the reduction tree of this partition's inputs in sh_bic.
//   3. Locate each element's predecessor and write outbuf.
void main() {
    uint th = gl_LocalInvocationID.x;
    // materialize stack up to start of this partition
    // start with reverse scan of bicyclic semigroup
    // (only partitions before this workgroup, index < gl_WorkGroupID.x, count)
    Bic bic = Bic(0, 0);
    if (th * N_SEQ < gl_WorkGroupID.x) {
        bic = bicbuf[th * N_SEQ];
    }
    for (uint i = 1; i < N_SEQ; i++) {
        if (th * N_SEQ + i < gl_WorkGroupID.x) {
            Bic other = bicbuf[th * N_SEQ + i];
            bic = bic_combine(bic, other);
        }
    }
    // Inclusive suffix scan: afterwards sh_bic[th] combines the per-invocation
    // values of invocations th..WG_SIZE-1.
    sh_bic[th] = bic;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (th + (1u << i) < WG_SIZE) {
            Bic other = sh_bic[th + (1u << i)];
            bic = bic_combine(bic, other);
        }
        // All reads of this round must complete before any write.
        barrier();
        sh_bic[th] = bic;
    }
    barrier();
#if N_SEQ > 1
    // Start from the reduction of all chunks after this invocation's chunk.
    if (th == WG_SIZE - 1) {
        bic = Bic(0, 0);
    } else {
        bic = sh_bic[th + 1];
    }
    // Walk this chunk's partitions from last to first. sh_stack[p] receives
    // the .b of the reduction of partitions p onward. Bitmap bit k is set when
    // partition th * N_SEQ + k increases that value over the one after it.
    uint last_b = bic.b;
    uint bitmap = 0;
    for (uint i = 0; i < N_SEQ; i++) {
        uint this_ix = th * N_SEQ + N_SEQ - 1 - i;
        if (this_ix < gl_WorkGroupID.x) {
            bic = bic_combine(bicbuf[this_ix], bic);
        }
        sh_stack[this_ix] = bic.b;
        if (bic.b > last_b) {
            bitmap |= 1u << (N_SEQ - 1 - i);
        }
        last_b = bic.b;
    }
    sh_bitmaps[th] = bitmap;
    // forward scan links to nonempty bitmaps
    // (prefix max: sh_link[th] becomes the highest partition index with a set
    // bitmap bit among chunks 0..th, or 0 if there is none)
    uint link = 0;
    if (bitmap != 0) {
        link = th * N_SEQ + findMSB(bitmap);
    }
    sh_link[th] = link;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (th >= (1u << i)) {
            link = max(link, sh_link[th - (1u << i)]);
        }
        barrier();
        sh_link[th] = link;
    }
    barrier();
    // binary search in stack
    uint sp = PART_SIZE - N_SEQ - th * N_SEQ;
    uint ix = 0;
    for (uint i = 0; i < LG_WG_SIZE + LG_N_SEQ; i++) {
        uint probe = ix + (uint(PART_SIZE / 2) >> i);
        if (sp < sh_stack[probe]) {
            ix = probe;
        }
    }
    // ix is the largest value such that sp < sh_stack[ix] (if any)
    uint b = sh_stack[ix];
    uint local_stack[N_SEQ];
    for (uint i = 0; i < N_SEQ; i++) {
        // Probably not really necessary, but avoid UB
        local_stack[i] = 0;
    }
    // Copy stack values sequentially
    // (up to N_SEQ entries for positions sp, sp + 1, ...; when partition ix
    // runs out, continue with the previous partition that has a set bit)
    uint i = 0;
    while (sp + i < b) {
        local_stack[N_SEQ - 1 - i] = stack[ix * PART_SIZE + b - (sp + i) - 1];
        i++;
        if (i == N_SEQ) {
            break;
        }
        if (sp + i == b) {
            // find previous nonempty slice
            uint bits = sh_bitmaps[ix / N_SEQ] & ((1u << (ix % N_SEQ)) - 1);
            if (bits == 0) {
                ix = sh_link[max(ix / N_SEQ, 1) - 1];
            } else {
                ix = (ix & -N_SEQ) + findMSB(bits);
            }
            b = sh_stack[ix];
        }
    }
    // Every invocation must be done reading sh_stack before it is overwritten.
    barrier();
    for (uint i = 0; i < N_SEQ; i++) {
        sh_stack[th * N_SEQ + i] = local_stack[i];
    }
#else
    // binary search in stack
    uint sp = PART_SIZE - 1 - th;
    uint ix = 0;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        uint probe = ix + (uint(PART_SIZE / 2) >> i);
        if (sp < sh_bic[probe].b) {
            ix = probe;
        }
    }
    // ix is the largest value such that sp < sh_bic[ix].b (if any)
    uint b = sh_bic[ix].b;
    if (sp < b) {
        sh_stack[th] = stack[ix * PART_SIZE + b - sp - 1];
    }
#endif
    barrier();
    // Do tree reduction of bicyclic semigroups (up-sweep)
    // Each invocation first reduces its own N_SEQ elements, right to left.
    uint inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + N_SEQ - 1];
    bic = Bic(1 - inp, inp);
#if N_SEQ > 1
    // Bit k marks element k of the chunk as a push that no later element of
    // the chunk pops.
    bitmap = inp << (N_SEQ - 1);
    for (uint i = N_SEQ - 1; i > 0; i--) {
        inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + i - 1];
        if (inp == 1 && bic.a == 0) {
            bitmap |= 1u << (i - 1);
        }
        Bic other = Bic(1 - inp, inp);
        bic = bic_combine(other, bic);
    }
    sh_bitmaps[th] = bitmap;
#endif
    // Level 0 of the tree occupies sh_bic[0..WG_SIZE); each higher level is
    // packed directly after the previous one.
    sh_bic[th] = bic;
    uint inbase = 0;
    for (uint i = 0; i < LG_WG_SIZE - 1; i++) {
        uint outbase = 2 * WG_SIZE - (1u << (LG_WG_SIZE - i));
        barrier();
        if (th < (1u << (LG_WG_SIZE - 1 - i))) {
            sh_bic[outbase + th] = bic_combine(sh_bic[inbase + th * 2], sh_bic[inbase + th * 2 + 1]);
        }
        inbase = outbase;
    }
    barrier();
    // Search for predecessor node.
#if N_SEQ > 1
    bic.b = 0;
    // search for predecessor of first unmatched open paren in this block
    sh_link[th] = search_link(bic);
    // The backlink logic below reads sh_link entries written by other
    // invocations (sh_link[ix / N_SEQ]). Without this barrier those reads
    // race with the write above.
    barrier();
#endif
    // search for predecessor of first character in this block
    bic = Bic(0, 0);
    ix = search_link(bic);
    // Generate output
    uint outp;
#if N_SEQ > 1
    // loc_stack holds the global indices of pushes in this chunk that have
    // not yet been popped.
    uint loc_stack[N_SEQ];
    uint loc_sp = 0;
    for (uint i = 0; i < N_SEQ; i++) {
        if (loc_sp > 0) {
            outp = loc_stack[loc_sp - 1];
        } else if (int(ix) >= 0) {
            outp = gl_WorkGroupID.x * PART_SIZE + ix;
        } else {
            outp = sh_stack[PART_SIZE + ix];
        }
        outbuf[gl_GlobalInvocationID.x * N_SEQ + i] = outp;
        // store in local memory instead?
        inp = inbuf[gl_GlobalInvocationID.x * N_SEQ + i];
        if (inp == 1) {
            loc_stack[loc_sp] = gl_GlobalInvocationID.x * N_SEQ + i;
            loc_sp++;
        } else if (inp == 0) {
            if (loc_sp > 0) {
                loc_sp--;
            } else {
                // backlink logic
                if (int(ix) >= 0) {
                    uint bits = sh_bitmaps[ix / N_SEQ] & ((1u << (ix % N_SEQ)) - 1);
                    if (bits == 0) {
                        ix = sh_link[ix / N_SEQ];
                    } else {
                        ix = (ix & -N_SEQ) + findMSB(bits);
                    }
                } else {
                    ix--;
                }
            }
        }
    }
#else
    if (ix > 0) {
        outp = gl_WorkGroupID.x * PART_SIZE + ix - 1;
    } else {
        outp = sh_stack[PART_SIZE - 1 - bic.a];
    }
    outbuf[gl_GlobalInvocationID.x] = outp;
#endif
}