blob: 95fdf582bdeb7b28878c84855e4cb935f880d723 [file] [log] [blame]
// Implementation of the parallel prefix sum algorithm
// Half the number of elements in the scan: shared_data holds SIZE * 2 floats and main() handles
// two elements per thread. SIZE also determines how many accumulation passes main() runs.
const int SIZE = 512;
// Source values. Each thread in main() reads elements id * 2 and id * 2 + 1.
layout(set=0, binding=0) in float[] in_data;
// Destination for the results. Each thread in main() writes elements id * 2 and id * 2 + 1.
layout(set=0, binding=1) out float[] out_data;
// Scratch array, declared with the threadgroup qualifier, in which main() accumulates the sums in
// place; passes over it are separated by threadgroupBarrier() calls.
threadgroup float[SIZE * 2] shared_data;
// Writes `value` into element `i` of shared_data.
//
// Use a separate function here simply to ensure that we're passing the threadgroups struct
// correctly
// NOTE(review): the function is marked noinline, presumably so that the call (and therefore the
// argument passing being tested) is not optimized away -- confirm before restructuring it.
noinline void store(uint i, float value) {
shared_data[i] = value;
}
// Computes an inclusive prefix sum over SIZE * 2 floats.
//
// Each thread copies two elements of in_data into shared_data, the sums are accumulated in place
// over a series of barrier-separated passes, and each thread then copies its two elements of
// shared_data to out_data.
//
// NOTE(review): the indexing assumes the thread ids cover exactly [0, SIZE); the dispatch / local
// size is not visible in this file -- confirm against the caller.
void main() {
    uint id = sk_ThreadPosition.x;

    // Each thread is responsible for two elements of the output array
    shared_data[id * 2] = in_data[id * 2];
    shared_data[id * 2 + 1] = in_data[id * 2 + 1];

    // Every element must be loaded before any thread starts reading its neighbours' values.
    threadgroupBarrier();

    // Run one pass for every `step` with 2^step <= SIZE, i.e. floor(log2(SIZE)) + 1 passes (10 for
    // SIZE = 512). The bound is derived with integer shifts rather than by truncating a
    // floating-point log2(): log2() is not guaranteed to be exact on every GPU, and a result that
    // lands just below the true integer would truncate to one pass too few and silently produce an
    // incomplete sum. The condition depends only on `step` and a constant, so every thread runs
    // the same number of passes and reaches the same barriers.
    const uint size = uint(SIZE);
    for (uint step = 0; (uint(1) << step) <= size; step++) {
        // View the array as blocks of 2^(step + 1) elements. Within this thread's block:
        //   rd_id is the last element of the lower half, which holds that half's running total;
        //   wr_id is one element of the upper half, selected by the low `step` bits of id.
        uint mask = (uint(1) << step) - 1;
        uint rd_id = ((id >> step) << (step + 1)) + mask;
        uint wr_id = rd_id + 1 + (id & mask);

        // Accumulate the read data into our element. Within a pass every rd_id has bit `step`
        // clear and every wr_id has it set, and each thread has its own wr_id, so no thread reads
        // or writes an element that another thread writes during the same pass.
        store(wr_id, shared_data[wr_id] + shared_data[rd_id]);

        // The next pass reads totals written by other threads during this one.
        threadgroupBarrier();
    }

    // Write the final result out
    out_data[id * 2] = shared_data[id * 2];
    out_data[id * 2 + 1] = shared_data[id * 2 + 1];
}