| // SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense |
| |
// The scan pass of a tree-reduction prefix scan (built as either the root pass or a non-root pass, depending on whether ROOT is defined).
| |
| #version 450 |
| #extension GL_GOOGLE_include_directive : enable |
| |
| #include "setup.h" |
| |
// Number of consecutive elements each invocation scans sequentially.
#define N_ROWS 8
// log2 of the workgroup size. LG_WG_FACTOR is not defined in this file;
// presumably it comes from setup.h (confirm there).
#define LG_WG_SIZE (LG_WG_FACTOR + 7)
// Invocations per workgroup.
#define WG_SIZE (1 << LG_WG_SIZE)
// Elements covered by one workgroup.
#define PARTITION_SIZE (N_ROWS * WG_SIZE)

layout(local_size_x = WG_SIZE, local_size_y = 1, local_size_z = 1) in;
| |
// Duplicated from scene.h rather than included: including it would be
// DRY-er, but it drags in declarations this shader has no use for.
//
// An affine transform. `mat` packs a 2x2 matrix with column 0 in .xy and
// column 1 in .zw (see combine_monoid); `translate` is the offset.
// Member order must not change: it determines the buffer layout.
struct Transform {
    vec4 mat;
    vec2 translate;
};

// The scan below is written against a generic "Monoid"; here it is Transform.
#define Monoid Transform

// Elements to scan. Read and then overwritten in place by main().
layout(binding = 0) buffer DataBuf {
    Monoid data[];
};

#ifndef ROOT
// Non-root passes only: one entry per workgroup, read-only.
// Presumably the already-scanned aggregates from the level above -- confirm
// against the code that dispatches these passes.
layout(binding = 1) readonly buffer ParentBuf {
    Monoid parent[];
};
#endif
| |
// Returns the neutral element for combine_monoid: the 2x2 identity matrix
// with a zero translation.
Monoid monoid_identity() {
    Monoid id;
    id.mat = vec4(1.0, 0.0, 0.0, 1.0);
    id.translate = vec2(0.0, 0.0);
    return id;
}
| |
// Composes two affine transforms so that the result maps p to a(b(p)):
//   matrix      = A * B
//   translation = A * b.translate + a.translate
// where A and B are the 2x2 matrices held column-wise in .xy / .zw.
// The operation is associative but not commutative, so argument order matters.
Monoid combine_monoid(Monoid a, Monoid b) {
    // Matrix product, evaluated as two component-wise products that are summed.
    vec4 from_col0 = a.mat.xyxy * b.mat.xxzz;
    vec4 from_col1 = a.mat.zwzw * b.mat.yyww;

    // b's translation pushed through a's linear part.
    vec2 moved = a.mat.xy * b.translate.x + a.mat.zw * b.translate.y;

    return Monoid(from_col0 + from_col1, moved + a.translate);
}
| |
// Workgroup-shared scratch: one running aggregate per invocation.
shared Monoid sh_scratch[WG_SIZE];

// In-place inclusive prefix scan of `data` under combine_monoid.
//
// Each invocation owns N_ROWS consecutive elements. It scans them
// sequentially, the workgroup then scans the per-invocation totals through
// shared memory, and finally each element is rewritten with the prefix of
// everything that precedes the invocation's slice.
void main() {
    uint lid = gl_LocalInvocationID.x;
    uint base = gl_GlobalInvocationID.x * N_ROWS;

    // Step 1: sequential inclusive scan of this invocation's slice.
    Monoid rows[N_ROWS];
    rows[0] = data[base];
    for (uint r = 1; r < N_ROWS; r++) {
        rows[r] = combine_monoid(rows[r - 1], data[base + r]);
    }

    // Step 2: inclusive scan of the slice totals across the workgroup.
    // On each step an invocation folds in the value `stride` slots to its
    // left, with `stride` doubling every step.
    Monoid total = rows[N_ROWS - 1];
    sh_scratch[lid] = total;
    for (uint step = 0; step < LG_WG_SIZE; step++) {
        uint stride = 1u << step;
        // Make the previous round's writes visible before anyone reads.
        barrier();
        if (lid >= stride) {
            total = combine_monoid(sh_scratch[lid - stride], total);
        }
        // Every read of this round must finish before slots are overwritten.
        barrier();
        sh_scratch[lid] = total;
    }

    barrier();
    // Step 3: build the prefix that precedes this invocation's slice.
    // A semigroup would suffice here (no identity needed) if the conditionals
    // were restructured, though that might cost performance.
    Monoid prefix = monoid_identity();
#ifdef ROOT
    // Root pass: only earlier invocations of this workgroup contribute.
    if (lid != 0u) {
        prefix = sh_scratch[lid - 1u];
    }
#else
    // Non-root pass: start from the parent entry of the previous workgroup,
    // then append the earlier invocations of this workgroup.
    uint wg = gl_WorkGroupID.x;
    if (wg != 0u) {
        prefix = parent[wg - 1u];
    }
    if (lid != 0u) {
        prefix = combine_monoid(prefix, sh_scratch[lid - 1u]);
    }
#endif

    // Step 4: apply the prefix to every locally scanned element and store.
    for (uint r = 0; r < N_ROWS; r++) {
        // TODO: gate buffer write
        data[base + r] = combine_monoid(prefix, rows[r]);
    }
}