blob: a5e4003030fdcec4d2f370500c95a1a3dc442bd1 [file] [log] [blame]
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A pass of a tree-reduction prefix scan that outputs the final scan result.
// Output is written into memory at trans_alloc.
#version 450
#extension GL_GOOGLE_include_directive : enable
#include "mem.h"
#include "setup.h"
// Number of consecutive elements each invocation scans sequentially.
#define N_ROWS 8
// log2 of the workgroup size. LG_WG_FACTOR is not defined in this file
// (presumably it comes from one of the includes above -- TODO confirm).
#define LG_WG_SIZE (7 + LG_WG_FACTOR)
// Number of invocations per workgroup.
#define WG_SIZE (1 << LG_WG_SIZE)
// Number of elements covered by one workgroup (WG_SIZE invocations, N_ROWS each).
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
// Read-only configuration; main() reads conf.trans_offset and conf.trans_alloc.
layout(binding = 1) readonly buffer ConfigBuf {
Config conf;
};
// Read-only scene data as raw 32-bit words. Not referenced directly in this
// file; presumably accessed by the helpers from "scene.h" -- TODO confirm.
layout(binding = 2) readonly buffer SceneBuf {
uint[] scene;
};
#include "scene.h"
#include "tile.h"
// The scan is written in terms of a generic "Monoid"; here it is a Transform.
#define Monoid Transform
// Read-only array indexed by workgroup: main() reads parent[gl_WorkGroupID.x - 1]
// as the prefix to apply to everything in workgroup gl_WorkGroupID.x.
// NOTE(review): presumably produced by an earlier pass of the scan -- confirm.
layout(set = 0, binding = 3) readonly buffer ParentBuf {
Monoid[] parent;
};
// Returns the identity element of the transform monoid: a matrix of
// (1, 0, 0, 1) and a zero translation.
Monoid monoid_identity() {
    vec4 identity_mat = vec4(1.0, 0.0, 0.0, 1.0);
    vec2 zero_translate = vec2(0.0, 0.0);
    return Monoid(identity_mat, zero_translate);
}
// Composes two affine transforms; the result is equivalent to applying `b`
// first and then `a`.
//
// `mat` packs a 2x2 matrix with .xy as its first column and .zw as its second,
// so the result matrix is the product a.mat * b.mat, and the result translation
// is a.mat * b.translate + a.translate.
Monoid combine_monoid(Monoid a, Monoid b) {
    // Matrix product, both output columns computed at once via swizzles.
    vec4 mat = a.mat.xyxy * b.mat.xxzz + a.mat.zwzw * b.mat.yyww;
    // b's translation carried through a's linear part.
    vec2 mapped_translate = a.mat.xy * b.translate.x + a.mat.zw * b.translate.y;
    return Monoid(mat, mapped_translate + a.translate);
}
// Workgroup-shared scratch space, one slot per invocation; main() uses it to
// exchange per-invocation aggregates during the workgroup-wide scan.
shared Monoid sh_scratch[WG_SIZE];
// Entry point. Each invocation handles N_ROWS consecutive transforms:
//   1. scan them sequentially into a private array,
//   2. take part in a workgroup-wide inclusive scan of the per-invocation totals,
//   3. build its exclusive prefix from parent[] (previous workgroups) and the
//      preceding invocation's scanned total,
//   4. write prefix-combined results as TransformSeg records into conf.trans_alloc.
void main() {
    uint local_ix = gl_LocalInvocationID.x;
    uint first_elem = gl_GlobalInvocationID.x * N_ROWS;

    // Step 1: sequential inclusive scan over this invocation's own elements.
    Monoid row_scan[N_ROWS];
    TransformRef src = TransformRef(conf.trans_offset + first_elem * Transform_size);
    Monoid total = Transform_read(src);
    row_scan[0] = total;
    for (uint k = 1; k < N_ROWS; k++) {
        total = combine_monoid(total, Transform_read(Transform_index(src, k)));
        row_scan[k] = total;
    }

    // Step 2: inclusive scan of the per-invocation totals across the workgroup,
    // doubling the look-back distance each round.
    sh_scratch[local_ix] = total;
    for (uint level = 0; level < LG_WG_SIZE; level++) {
        uint stride = 1u << level;
        // Make the previous round's writes visible before anyone reads.
        barrier();
        if (local_ix >= stride) {
            total = combine_monoid(sh_scratch[local_ix - stride], total);
        }
        // Every read of this round must finish before slots are overwritten.
        barrier();
        sh_scratch[local_ix] = total;
    }
    barrier();

    // Step 3: exclusive prefix for this invocation.
    Monoid prefix = monoid_identity();
    if (gl_WorkGroupID.x > 0) {
        prefix = parent[gl_WorkGroupID.x - 1];
    }
    if (local_ix > 0) {
        prefix = combine_monoid(prefix, sh_scratch[local_ix - 1]);
    }

    // Step 4: apply the prefix to each locally scanned element and store it.
    for (uint k = 0; k < N_ROWS; k++) {
        Monoid m = combine_monoid(prefix, row_scan[k]);
        TransformSeg transform = TransformSeg(m.mat, m.translate);
        TransformSegRef trans_ref = TransformSegRef(conf.trans_alloc.offset + (first_elem + k) * TransformSeg_size);
        TransformSeg_write(conf.trans_alloc, trans_ref, transform);
    }
}