Merge branch 'master' into prefix
diff --git a/piet-gpu-hal/examples/prefix.rs b/piet-gpu-hal/examples/prefix.rs
new file mode 100644
index 0000000..2f80a20
--- /dev/null
+++ b/piet-gpu-hal/examples/prefix.rs
@@ -0,0 +1,54 @@
+use piet_gpu_hal::vulkan::VkInstance;
+use piet_gpu_hal::{CmdBuf, Device, MemFlags};
+
+const BLOCKSIZE: usize = 16384;
+
+fn main() {
+    let n = 64 * 1024 * 1024;
+    let n_tiles = n / BLOCKSIZE;
+    let instance = VkInstance::new().unwrap();
+    unsafe {
+        let device = instance.device().unwrap();
+        let mem_flags = MemFlags::host_coherent();
+        let device_local = MemFlags::device_local();
+        let src = (0..n).map(|x| (x & 3) as u32).collect::<Vec<u32>>();
+        let buffer = device
+            .create_buffer(std::mem::size_of_val(&src[..]) as u64, mem_flags)
+            .unwrap();
+        let buffer_dev = device
+            .create_buffer(std::mem::size_of_val(&src[..]) as u64, device_local)
+            .unwrap();
+        let dst_buffer = device
+            .create_buffer(std::mem::size_of_val(&src[..]) as u64, device_local)
+            .unwrap();
+        let work_buffer = device
+            .create_buffer((n_tiles * 16) as u64, device_local)
+            .unwrap();
+        device.write_buffer(&buffer, &src).unwrap();
+        let code = include_bytes!("./shader/prefix.spv");
+        let pipeline = device.create_simple_compute_pipeline(code, 3).unwrap();
+        let bufs = [&buffer_dev, &dst_buffer, &work_buffer];
+        let descriptor_set = device.create_descriptor_set(&pipeline, &bufs).unwrap();
+        let query_pool = device.create_query_pool(2).unwrap();
+        let mut cmd_buf = device.create_cmd_buf().unwrap();
+        cmd_buf.begin();
+        cmd_buf.clear_buffer(&work_buffer);
+        cmd_buf.copy_buffer(&buffer, &buffer_dev);
+        cmd_buf.memory_barrier();
+        cmd_buf.reset_query_pool(&query_pool);
+        cmd_buf.write_timestamp(&query_pool, 0);
+        cmd_buf.dispatch(&pipeline, &descriptor_set, (n_tiles as u32, 1, 1));
+        cmd_buf.write_timestamp(&query_pool, 1);
+        cmd_buf.memory_barrier();
+        cmd_buf.copy_buffer(&dst_buffer, &buffer);
+        cmd_buf.finish();
+        device.run_cmd_buf(&cmd_buf).unwrap();
+        let timestamps = device.reap_query_pool(query_pool).unwrap();
+        let mut dst: Vec<u32> = Default::default();
+        device.read_buffer(&buffer, &mut dst).unwrap();
+        for (i, val) in dst.iter().enumerate().take(16) {
+            println!("{}: {}", i, val);
+        }
+        println!("{:?}ms", timestamps[0] * 1e3);
+    }
+}
diff --git a/piet-gpu-hal/examples/shader/build.ninja b/piet-gpu-hal/examples/shader/build.ninja
index 848637a..0aa0c40 100644
--- a/piet-gpu-hal/examples/shader/build.ninja
+++ b/piet-gpu-hal/examples/shader/build.ninja
@@ -5,6 +5,8 @@
 glslang_validator = glslangValidator
 
 rule glsl
-  command = $glslang_validator -V -o $out $in
+  command = $glslang_validator -V -o $out $in --target-env vulkan1.1
 
 build collatz.spv: glsl collatz.comp
+
+build prefix.spv: glsl prefix.comp
diff --git a/piet-gpu-hal/examples/shader/collatz.spv b/piet-gpu-hal/examples/shader/collatz.spv
index 21e4e92..45cc988 100644
--- a/piet-gpu-hal/examples/shader/collatz.spv
+++ b/piet-gpu-hal/examples/shader/collatz.spv
Binary files differ
diff --git a/piet-gpu-hal/examples/shader/prefix.comp b/piet-gpu-hal/examples/shader/prefix.comp
new file mode 100644
index 0000000..67fec78
--- /dev/null
+++ b/piet-gpu-hal/examples/shader/prefix.comp
@@ -0,0 +1,124 @@
+// See https://research.nvidia.com/sites/default/files/pubs/2016-03_Single-pass-Parallel-Prefix/nvr-2016-002.pdf
+
+#version 450
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_memory_scope_semantics : enable
+
+layout(local_size_x = 1024) in;
+
+// One workgroup processes workgroup size * N_ROWS elements.
+#define N_ROWS 16
+
+layout(set = 0, binding = 0) readonly buffer InBuf {
+    uint[] in_buf;
+};
+
+layout(set = 0, binding = 1) buffer OutBuf {
+    uint[] out_buf;
+};
+
+// work_buf[0] is the tile id
+// work_buf[i * 4 + 1] is the flag for tile i
+// work_buf[i * 4 + 2] is the aggregate for tile i
+// work_buf[i * 4 + 3] is the prefix for tile i
+layout(set = 0, binding = 2) buffer WorkBuf {
+    uint[] work_buf;
+};
+
+// These correspond to X, A, P respectively in the paper.
+#define FLAG_NOT_READY 0
+#define FLAG_AGGREGATE_READY 1
+#define FLAG_PREFIX_READY 2
+
+shared uint shared_tile;
+shared uint shared_prefix;
+// Note: a subgroup size of exactly 32 is hard-coded throughout (the lane-31
+// checks, this 32-entry array, and the host's required_subgroup_size of 32).
+shared uint chunks[32];
+
+void main() {
+    uint local_ix = gl_LocalInvocationID.x;
+    // Determine which tile to process via an atomic counter (implements the
+    // idea from section 4.4 of the paper).
+    if (local_ix == 0) {
+        shared_tile = atomicAdd(work_buf[0], 1);
+    }
+    barrier();
+    uint my_tile = shared_tile;
+    uint mem_base = my_tile * 16384;
+    uint aggregates[N_ROWS];
+
+    // Interleave reading of data, computing row prefix sums, and aggregate
+    // (step 3 of paper).
+    uint total = 0;
+    for (uint i = 0; i < N_ROWS; i++) {
+        uint ix = (local_ix & 0x3e0) * N_ROWS + i * 32 + (local_ix & 0x1f);
+        uint data = in_buf[mem_base + ix];
+        uint row = subgroupInclusiveAdd(data);
+        total += row;
+        aggregates[i] = row;
+    }
+    if (gl_SubgroupInvocationID == 31) {
+        chunks[local_ix >> 5] = total;
+    }
+
+    barrier();
+    if (local_ix < 32) {
+        uint chunk = chunks[gl_SubgroupInvocationID];
+        total = subgroupInclusiveAdd(chunk);
+        chunks[gl_SubgroupInvocationID] = total;
+    }
+
+    uint exclusive_prefix = 0;
+    if (local_ix == 31) {
+        atomicStore(work_buf[my_tile * 4 + 2], total, gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelaxed);
+        uint flag = FLAG_AGGREGATE_READY;
+        if (my_tile == 0) {
+            atomicStore(work_buf[my_tile * 4 + 3], total, gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelaxed);
+            flag = FLAG_PREFIX_READY;
+        }
+        atomicStore(work_buf[my_tile * 4 + 1], flag, gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelease);
+        if (my_tile != 0) {
+            // step 4: decoupled lookback
+            uint look_back_ix = my_tile - 1;
+            while (true) {
+                flag = atomicLoad(work_buf[look_back_ix * 4 + 1], gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsAcquire);
+                if (flag == FLAG_PREFIX_READY) {
+                    uint their_prefix = atomicLoad(work_buf[look_back_ix * 4 + 3], gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelaxed);
+                    exclusive_prefix = their_prefix + exclusive_prefix;
+                    break;
+                } else if (flag == FLAG_AGGREGATE_READY) {
+                    uint their_agg = atomicLoad(work_buf[look_back_ix * 4 + 2], gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelaxed);
+                    exclusive_prefix = their_agg + exclusive_prefix;
+                    look_back_ix--;
+                }
+                // else: flag not ready — spin (relies on the predecessor workgroup making progress)
+            }
+
+            // step 5: compute inclusive prefix
+            uint inclusive_prefix = exclusive_prefix + total;
+            shared_prefix = exclusive_prefix;
+            atomicStore(work_buf[my_tile * 4 + 3], inclusive_prefix, gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelaxed);
+            flag = FLAG_PREFIX_READY;
+            atomicStore(work_buf[my_tile * 4 + 1], flag, gl_ScopeDevice, gl_StorageSemanticsBuffer, gl_SemanticsRelease);
+        }
+    }
+    uint prefix = 0;
+    barrier();
+    if (my_tile != 0) {
+        prefix = shared_prefix;
+    }
+
+    // step 6: perform partition-wide scan
+    if (local_ix >> 5 > 0) {
+        prefix += chunks[(local_ix >> 5) - 1];
+    }
+    for (uint i = 0; i < N_ROWS; i++) {
+        uint ix = (local_ix & 0x3e0) * N_ROWS + i * 32 + (local_ix & 0x1f);
+        uint agg = aggregates[i];
+        out_buf[mem_base + ix] = prefix + agg;
+        prefix += subgroupBroadcast(agg, 31);
+    }
+}
diff --git a/piet-gpu-hal/examples/shader/prefix.spv b/piet-gpu-hal/examples/shader/prefix.spv
new file mode 100644
index 0000000..311df27
--- /dev/null
+++ b/piet-gpu-hal/examples/shader/prefix.spv
Binary files differ
diff --git a/piet-gpu-hal/src/vulkan.rs b/piet-gpu-hal/src/vulkan.rs
index 2a5abd9..35cf68f 100644
--- a/piet-gpu-hal/src/vulkan.rs
+++ b/piet-gpu-hal/src/vulkan.rs
@@ -326,6 +326,10 @@
                             .stage(vk::ShaderStageFlags::COMPUTE)
                             .module(compute_shader_module)
                             .name(&entry_name)
+                            .push_next(
+                                &mut vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT::builder()
+                                    .required_subgroup_size(32)
+                            )
                             .build(),
                     )
                     .layout(pipeline_layout)