blob: 450f76f69f4f1a1455dc6ae64709522497fe00bf [file] [log] [blame]
// Workgroup size along x. Keep this value in sync with WORKGROUP_SIZE below.
layout(local_size_x = 256) in;
// Mirrors local_size_x; main() compares each invocation's local x index against
// WORKGROUP_SIZE / 2 to decide which of the two counters it increments.
const uint WORKGROUP_SIZE = 256;
// Pair of atomic counters that main() accumulates workgroup tallies into.
struct GlobalCounts {
// Receives the tally of invocations whose local x index is below WORKGROUP_SIZE / 2.
atomicUint firstHalfCount;
// Receives the tally of all remaining invocations in the workgroup.
atomicUint secondHalfCount;
};
// Buffer block at binding 0 exposing the counters that main() adds to.
// NOTE(review): the `metal` layout qualifier presumably restricts this binding
// to the Metal backend -- confirm against the SkSL layout-qualifier rules.
layout(metal, binding = 0) buffer ssbo {
GlobalCounts globalCounts;
};
// Workgroup-qualified tallies: main() increments index 0 for invocations whose
// local x index is below WORKGROUP_SIZE / 2 and index 1 for the rest.
workgroup atomicUint localCounts[2];
// Compute entry point. Every invocation bumps one of two workgroup tallies
// (chosen by which half of the workgroup it sits in); the invocation with
// local x index 0 then folds both tallies into the global counters.
void main() {
    uint localX = sk_LocalInvocationID.x;
    bool isLeader = localX == 0;

    // The leader zeroes both workgroup tallies before anyone increments them.
    if (isLeader) {
        atomicStore(localCounts[0], 0);
        atomicStore(localCounts[1], 0);
    }

    // Hold every invocation here so they all observe the zeroed tallies.
    workgroupBarrier();

    // Lower half of the workgroup counts into slot 0, upper half into slot 1.
    uint slot = 1;
    if (localX < (WORKGROUP_SIZE / 2)) {
        slot = 0;
    }
    atomicAdd(localCounts[slot], 1);

    // Second barrier: all increments must have executed before the tallies
    // are read back below.
    workgroupBarrier();

    // Only the leader publishes the workgroup's tallies to the global counters.
    if (isLeader) {
        uint firstTally = atomicLoad(localCounts[0]);
        atomicAdd(globalCounts.firstHalfCount, firstTally);

        uint secondTally = atomicLoad(localCounts[1]);
        atomicAdd(globalCounts.secondHalfCount, secondTally);
    }
}