// Implementation of the parallel prefix sum algorithm
// Sizing constant for the scan. The shared scratch array below holds SIZE * 2
// floats, and main() has every invocation handle two of them, so presumably the
// shader is dispatched with SIZE invocations — confirm against the dispatch.
const int SIZE = 512;
// Read-only buffer block (set 0, binding 0) holding the values to be scanned.
layout(set=0, binding=0) readonly buffer inputs {
float[] in_data;
};
// Write-only buffer block (set 0, binding 1) that receives the scanned values.
layout(set=0, binding=1) writeonly buffer outputs {
float[] out_data;
};
// Workgroup-shared scratch array in which the prefix sum is computed in place.
workgroup float[SIZE * 2] shared_data;
// Test that workgroup-shared variables are passed to user-defined functions
// correctly.
//
// Writes `value` into element `i` of the workgroup-shared `shared_data` array.
// The function is marked `noinline`, presumably so the write happens through a
// real function call instead of being folded into main() — which is what the
// test above wants to exercise.
noinline void store(uint i, float value) {
shared_data[i] = value;
}
// Entry point: computes an in-place running (inclusive) sum over `shared_data`
// and copies the result to `out_data`.
//
// NOTE(review): the shared array is indexed with the global invocation id, which
// presumes a single workgroup of SIZE invocations — confirm against the dispatch.
void main() {
    uint id = sk_GlobalInvocationID.x;

    // Every invocation owns one adjacent pair of elements.
    uint even = id * 2;
    uint odd = even + 1;

    // Stage this invocation's pair into the shared array.
    shared_data[even] = in_data[even];
    shared_data[odd] = in_data[odd];
    workgroupBarrier();

    const uint steps = uint(log2(float(SIZE))) + 1;
    for (uint level = 0; level < steps; ++level) {
        // At this level the array is viewed as blocks of 2^(level + 1) elements.
        // `rd_id` is the last element of the lower half of this invocation's
        // block; `wr_id` is the element of the upper half assigned to it.
        uint mask = (1 << level) - 1;
        uint rd_id = ((id >> level) << (level + 1)) + mask;
        uint wr_id = rd_id + 1 + (id & mask);

        // Fold the lower half's total into the upper-half element. Reads come
        // only from lower halves (never written at this level) and from the
        // invocation's own write slot, so the barrier after the store suffices.
        store(wr_id, shared_data[wr_id] + shared_data[rd_id]);
        workgroupBarrier();
    }

    // Publish this invocation's pair of results.
    out_data[even] = shared_data[even];
    out_data[odd] = shared_data[odd];
}