blob: d285020407e88ba1abf6f0845e9106af6ad2f0ba [file] [log] [blame]
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A scan pass for draw tag scan implemented as a tree reduction.
#version 450
#extension GL_GOOGLE_include_directive : enable
#include "setup.h"
#include "drawtag.h"
// Number of consecutive elements each invocation scans sequentially
// before the workgroup-wide scan.
#define N_ROWS 8
// log2 of the workgroup size. LG_WG_FACTOR is not defined in this file
// (presumably supplied by setup.h -- confirm).
#define LG_WG_SIZE (7 + LG_WG_FACTOR)
// Number of invocations per workgroup.
#define WG_SIZE (1 << LG_WG_SIZE)
// Number of elements covered by one workgroup (WG_SIZE invocations,
// N_ROWS elements each). Not referenced by the code visible in this file.
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
// Generic aliases so the scan body below is written in terms of an
// abstract monoid. The Draw* names are not defined in this file
// (presumably declared in drawtag.h -- confirm).
#define Monoid DrawMonoid
#define combine_monoid combine_draw_monoid
#define monoid_identity draw_monoid_identity
// Elements to scan. main() reads each element and then overwrites it
// with its scanned value, so the scan is performed in place.
layout(binding = 0) buffer DataBuf {
Monoid[] data;
};
#ifndef ROOT
// Only bound when ROOT is not defined. main() reads
// parent[gl_WorkGroupID.x - 1] as the prefix to apply to everything in
// workgroup gl_WorkGroupID.x (workgroup 0 uses the identity instead).
layout(binding = 1) readonly buffer ParentBuf {
Monoid[] parent;
};
#endif
// Workgroup-shared scratch, one slot per invocation, used for the
// workgroup-wide scan of the per-invocation aggregates.
shared Monoid sh_scratch[WG_SIZE];
// Scans one partition of `data` in place.
//
// Each invocation first scans its own N_ROWS consecutive elements, the
// per-invocation totals are then scanned across the workgroup through
// shared memory, and finally every element is combined with the prefix
// of everything that precedes its invocation. When ROOT is not defined,
// that prefix additionally includes parent[gl_WorkGroupID.x - 1].
void main() {
    uint lid = gl_LocalInvocationID.x;
    uint base = gl_GlobalInvocationID.x * N_ROWS;

    // Sequential inclusive scan of this invocation's N_ROWS elements.
    Monoid local[N_ROWS];
    Monoid running = data[base];
    local[0] = running;
    for (uint r = 1; r < N_ROWS; r++) {
        running = combine_monoid(running, data[base + r]);
        local[r] = running;
    }

    // Inclusive scan of the per-invocation totals across the workgroup.
    // In each step an invocation folds in the value held `stride` slots
    // to its left. The barriers keep every read of a step ahead of that
    // step's writes, and every write ahead of the next step's reads.
    Monoid agg = running;
    sh_scratch[lid] = agg;
    for (uint step = 0; step < LG_WG_SIZE; step++) {
        uint stride = 1u << step;
        barrier();
        if (lid >= stride) {
            agg = combine_monoid(sh_scratch[lid - stride], agg);
        }
        barrier();
        sh_scratch[lid] = agg;
    }
    barrier();

    // Exclusive prefix for this invocation. Starting from the identity
    // is what makes this need a monoid; reworking the conditionals could
    // relax that to a semigroup, but that might impact performance.
    Monoid prefix = monoid_identity();
#ifdef ROOT
    if (lid > 0) {
        prefix = sh_scratch[lid - 1];
    }
#else
    uint wg = gl_WorkGroupID.x;
    if (wg > 0) {
        prefix = parent[wg - 1];
    }
    if (lid > 0) {
        prefix = combine_monoid(prefix, sh_scratch[lid - 1]);
    }
#endif

    // Apply the prefix to the locally scanned values and store them
    // back over the inputs.
    for (uint r = 0; r < N_ROWS; r++) {
        data[base + r] = combine_monoid(prefix, local[r]);
    }
}