Rework of compute encoder abstraction

The current plan is to more or less follow the wgpu/wgpu-hal approach. In the mux/backend layer (which corresponds fairly strongly to wgpu-hal), there isn't explicit construction of a compute encoder, but there are new methods for beginning and ending a compute pass. At the hub layer (which corresponds to wgpu) there will be a ComputeEncoder object.

That said, there will be some differences. The WebGPU "end" method on a compute encoder is implemented in wgpu as Drop, and that is not ideal. Also, the wgpu-hal approach to timer queries (still based on write_timestamp) is not up to the task of Metal timer queries, where the query offsets have to be specified at compute encoder creation. That's why there are different projects :)

WIP: current state is that stage-style queries work on Apple Silicon, but non-Metal backends are broken, and piet-gpu is not yet updated to use new API.
diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs
index dae5b31..7aff938 100644
--- a/piet-gpu-hal/examples/collatz.rs
+++ b/piet-gpu-hal/examples/collatz.rs
@@ -1,4 +1,4 @@
-use piet_gpu_hal::{include_shader, BindType};
+use piet_gpu_hal::{include_shader, BindType, ComputePassDescriptor};
 use piet_gpu_hal::{BufferUsage, Instance, InstanceFlags, Session};
 
 fn main() {
@@ -20,9 +20,9 @@
         let mut cmd_buf = session.cmd_buf().unwrap();
         cmd_buf.begin();
         cmd_buf.reset_query_pool(&query_pool);
-        cmd_buf.write_timestamp(&query_pool, 0);
-        cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
-        cmd_buf.write_timestamp(&query_pool, 1);
+        let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
+        pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
+        pass.end();
         cmd_buf.finish_timestamps(&query_pool);
         cmd_buf.host_barrier();
         cmd_buf.finish();
diff --git a/piet-gpu-hal/src/backend.rs b/piet-gpu-hal/src/backend.rs
index c1b2132..f2c67a1 100644
--- a/piet-gpu-hal/src/backend.rs
+++ b/piet-gpu-hal/src/backend.rs
@@ -17,7 +17,8 @@
 //! The generic trait for backends to implement.
 
 use crate::{
-    BindType, BufferUsage, Error, GpuInfo, ImageFormat, ImageLayout, MapMode, SamplerParams,
+    BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout,
+    MapMode, SamplerParams,
 };
 
 pub trait Device: Sized {
@@ -159,16 +160,32 @@
     unsafe fn create_sampler(&self, params: SamplerParams) -> Result<Self::Sampler, Error>;
 }
 
+/// The trait implemented by backend command buffer implementations.
+///
+/// Valid encoding is represented by a state machine (currently not validated
+/// but it is easy to imagine there might be at least debug validation). Most
+/// methods are only valid in a particular state, and some move it to another
+/// state.
 pub trait CmdBuf<D: Device> {
-    type ComputeEncoder;
-
+    /// Begin encoding.
+    ///
+    /// State: init -> ready
     unsafe fn begin(&mut self);
 
+    /// State: ready -> finished
     unsafe fn finish(&mut self);
 
     /// Return true if the command buffer is suitable for reuse.
     unsafe fn reset(&mut self) -> bool;
 
+    /// Begin a compute pass.
+    ///
+    /// State: ready -> in_compute_pass
+    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor);
+
+    /// Dispatch
+    ///
+    /// State: in_compute_pass
     unsafe fn dispatch(
         &mut self,
         pipeline: &D::Pipeline,
@@ -177,6 +194,9 @@
         workgroup_size: (u32, u32, u32),
     );
 
+    /// State: in_compute_pass -> ready
+    unsafe fn end_compute_pass(&mut self);
+
     /// Insert an execution and memory barrier.
     ///
     /// Compute kernels (and other actions) after this barrier may read from buffers
@@ -229,12 +249,10 @@
     unsafe fn finish_timestamps(&mut self, _pool: &D::QueryPool) {}
 
     /// Begin a labeled section for debugging and profiling purposes.
-    unsafe fn begin_debug_label(&mut self, label: &str) {}
+    unsafe fn begin_debug_label(&mut self, _label: &str) {}
 
     /// End a section opened by `begin_debug_label`.
     unsafe fn end_debug_label(&mut self) {}
-
-    unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder;
 }
 
 /// A builder for descriptor sets with more complex layouts.
@@ -256,16 +274,3 @@
     fn add_textures(&mut self, images: &[&D::Image]);
     unsafe fn build(self, device: &D, pipeline: &D::Pipeline) -> Result<D::DescriptorSet, Error>;
 }
-
-pub trait ComputeEncoder<D: Device> {
-    unsafe fn dispatch(
-        &mut self,
-        pipeline: &D::Pipeline,
-        descriptor_set: &D::DescriptorSet,
-        workgroup_count: (u32, u32, u32),
-        workgroup_size: (u32, u32, u32),
-    );
-
-    // Question: should be self?
-    unsafe fn finish(&mut self);
-}
diff --git a/piet-gpu-hal/src/hub.rs b/piet-gpu-hal/src/hub.rs
index cc09832..37c59df 100644
--- a/piet-gpu-hal/src/hub.rs
+++ b/piet-gpu-hal/src/hub.rs
@@ -13,7 +13,7 @@
 use bytemuck::Pod;
 use smallvec::SmallVec;
 
-use crate::{mux, BackendType, BufWrite, ImageFormat, MapMode};
+use crate::{mux, BackendType, BufWrite, ComputePassDescriptor, ImageFormat, MapMode};
 
 use crate::{BindType, BufferUsage, Error, GpuInfo, ImageLayout, SamplerParams};
 
@@ -135,6 +135,11 @@
     size: u64,
 }
 
+/// A sub-object of a command buffer for a sequence of compute dispatches.
+pub struct ComputePass<'a> {
+    cmd_buf: &'a mut CmdBuf,
+}
+
 impl Session {
     /// Create a new session, choosing the best backend.
     pub fn new(device: mux::Device) -> Session {
@@ -471,6 +476,12 @@
         self.cmd_buf().finish();
     }
 
+    /// Begin a compute pass.
+    pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) -> ComputePass {
+        self.cmd_buf().begin_compute_pass(desc);
+        ComputePass { cmd_buf: self }
+    }
+
     /// Dispatch a compute shader.
     ///
     /// Request a compute shader to be run, using the pipeline to specify the
@@ -479,6 +490,11 @@
     /// Both the workgroup count (number of workgroups) and the workgroup size
     /// (number of threads in a workgroup) must be specified here, though not
     /// all back-ends require the latter info.
+    ///
+    /// This version is deprecated because (a) you do not get timer queries and
+    /// (b) it doesn't aggregate multiple dispatches into a single compute
+    /// pass, which is a performance concern.
+    #[deprecated(note = "moving to ComputePass")]
     pub unsafe fn dispatch(
         &mut self,
         pipeline: &Pipeline,
@@ -486,8 +502,9 @@
         workgroup_count: (u32, u32, u32),
         workgroup_size: (u32, u32, u32),
     ) {
-        self.cmd_buf()
-            .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+        let mut pass = self.begin_compute_pass(&Default::default());
+        pass.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+        pass.end();
     }
 
     /// Insert an execution and memory barrier.
@@ -692,6 +709,32 @@
     }
 }
 
+impl<'a> ComputePass<'a> {
+    /// Dispatch a compute shader.
+    ///
+    /// Request a compute shader to be run, using the pipeline to specify the
+    /// code, and the descriptor set to address the resources read and written.
+    ///
+    /// Both the workgroup count (number of workgroups) and the workgroup size
+    /// (number of threads in a workgroup) must be specified here, though not
+    /// all back-ends require the latter info.
+    pub unsafe fn dispatch(
+        &mut self,
+        pipeline: &Pipeline,
+        descriptor_set: &DescriptorSet,
+        workgroup_count: (u32, u32, u32),
+        workgroup_size: (u32, u32, u32),
+    ) {
+        self.cmd_buf
+            .cmd_buf()
+            .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+    }
+
+    pub unsafe fn end(self) {
+        self.cmd_buf.cmd_buf().end_compute_pass();
+    }
+}
+
 impl Drop for BufferInner {
     fn drop(&mut self) {
         if let Some(session) = Weak::upgrade(&self.session) {
diff --git a/piet-gpu-hal/src/lib.rs b/piet-gpu-hal/src/lib.rs
index fab7d65..241cdfd 100644
--- a/piet-gpu-hal/src/lib.rs
+++ b/piet-gpu-hal/src/lib.rs
@@ -189,3 +189,17 @@
     /// dimension.
     pub max_invocations: u32,
 }
+
+#[derive(Default)]
+pub struct ComputePassDescriptor<'a> {
+    // Maybe label should go here? It does in wgpu and wgpu_hal.
+    timer_queries: Option<(&'a QueryPool, u32, u32)>,
+}
+
+impl<'a> ComputePassDescriptor<'a> {
+    pub fn timer(pool: &'a QueryPool, start_query: u32, end_query: u32) -> ComputePassDescriptor {
+        ComputePassDescriptor {
+            timer_queries: Some((pool, start_query, end_query)),
+        }
+    }
+}
diff --git a/piet-gpu-hal/src/metal.rs b/piet-gpu-hal/src/metal.rs
index 23cc256..c907d77 100644
--- a/piet-gpu-hal/src/metal.rs
+++ b/piet-gpu-hal/src/metal.rs
@@ -33,11 +33,13 @@
 
 use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};
 
-use crate::{BufferUsage, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits};
+use crate::{
+    BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits,
+};
 
 use util::*;
 
-use self::timer::{CounterSampleBuffer, CounterSet};
+use self::timer::{CounterSampleBuffer, CounterSet, TimeCalibration};
 
 pub struct MtlInstance;
 
@@ -110,15 +112,11 @@
 }
 
 #[derive(Default)]
-struct TimeCalibration {
-    cpu_start_ts: u64,
-    gpu_start_ts: u64,
-    cpu_end_ts: u64,
-    gpu_end_ts: u64,
+pub struct QueryPool {
+    counter_sample_buf: Option<CounterSampleBuffer>,
+    calibration: Arc<Mutex<Option<Arc<Mutex<TimeCalibration>>>>>,
 }
 
-pub struct QueryPool(Option<CounterSampleBuffer>);
-
 pub struct Pipeline(metal::ComputePipelineState);
 
 #[derive(Default)]
@@ -134,10 +132,6 @@
     clear_pipeline: metal::ComputePipelineState,
 }
 
-pub struct ComputeEncoder {
-    raw: metal::ComputeCommandEncoder,
-}
-
 impl MtlInstance {
     pub fn new(
         window_handle: Option<&dyn HasRawWindowHandle>,
@@ -263,7 +257,7 @@
             helpers,
             timer_set,
             counter_style,
-    }
+        }
     }
 
     pub fn cmd_buf_from_raw_mtl(&self, raw_cmd_buf: metal::CommandBuffer) -> CmdBuf {
@@ -409,16 +403,28 @@
         if let Some(timer_set) = &self.timer_set {
             let pool = CounterSampleBuffer::new(&self.device, n_queries as u64, timer_set)
                 .ok_or("error creating timer query pool")?;
-            return Ok(QueryPool(Some(pool)));
+            return Ok(QueryPool {
+                counter_sample_buf: Some(pool),
+                calibration: Default::default(),
+            });
         }
-        Ok(QueryPool(None))
+        Ok(QueryPool::default())
     }
 
     unsafe fn fetch_query_pool(&self, pool: &Self::QueryPool) -> Result<Vec<f64>, Error> {
-        if let Some(raw) = &pool.0 {
+        if let Some(raw) = &pool.counter_sample_buf {
             let resolved = raw.resolve();
-            println!("resolved = {:?}", resolved);
+            let calibration = pool.calibration.lock().unwrap();
+            if let Some(calibration) = &*calibration {
+                let calibration = calibration.lock().unwrap();
+                let result = resolved
+                    .iter()
+                    .map(|time_ns| calibration.correlate(*time_ns))
+                    .collect();
+                return Ok(result);
+            }
         }
+        // Maybe should return None indicating it wasn't successful? But that might break.
         Ok(Vec::new())
     }
 
@@ -444,10 +450,6 @@
                 let gpu_ts_ptr = &mut time_calibration.gpu_start_ts as *mut _;
                 // TODO: only do this if supported.
                 let () = msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
-                println!(
-                    "scheduled, {}, {}",
-                    time_calibration.cpu_start_ts, time_calibration.gpu_start_ts
-                );
             })
             .copy();
             add_scheduled_handler(&cmd_buf.cmd_buf, &start_block);
@@ -461,10 +463,6 @@
                     // TODO: only do this if supported.
                     let () =
                         msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
-                    println!(
-                        "completed, {}, {}",
-                        time_calibration.cpu_end_ts, time_calibration.gpu_end_ts
-                    );
                 })
                 .copy();
             cmd_buf.cmd_buf.add_completed_handler(&completed_block);
@@ -546,8 +544,6 @@
 }
 
 impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
-    type ComputeEncoder = ComputeEncoder;
-
     unsafe fn begin(&mut self) {}
 
     unsafe fn finish(&mut self) {
@@ -558,6 +554,35 @@
         false
     }
 
+    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
+        debug_assert!(matches!(self.cur_encoder, Encoder::None));
+        let encoder = if let Some(queries) = &desc.timer_queries {
+            let descriptor: id = msg_send![class!(MTLComputePassDescriptor), computePassDescriptor];
+            let attachments: id = msg_send![descriptor, sampleBufferAttachments];
+            let index: NSUInteger = 0;
+            let attachment: id = msg_send![attachments, objectAtIndexedSubscript: index];
+            // Here we break the hub/mux separation a bit, for expedience
+            #[allow(irrefutable_let_patterns)]
+            if let crate::hub::QueryPool::Mtl(query_pool) = queries.0 {
+                if let Some(sample_buf) = &query_pool.counter_sample_buf {
+                    let () = msg_send![attachment, setSampleBuffer: sample_buf.id()];
+                }
+            }
+            let start_index = queries.1 as NSUInteger;
+            let end_index = queries.2 as NSUInteger;
+            let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index];
+            let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index];
+            let encoder = msg_send![
+                self.cmd_buf,
+                computeCommandEncoderWithDescriptor: descriptor
+            ];
+            encoder
+        } else {
+            self.cmd_buf.new_compute_command_encoder()
+        };
+        self.cur_encoder = Encoder::Compute(encoder.to_owned());
+    }
+
     unsafe fn dispatch(
         &mut self,
         pipeline: &Pipeline,
@@ -590,6 +615,11 @@
         encoder.dispatch_thread_groups(workgroup_count, workgroup_size);
     }
 
+    unsafe fn end_compute_pass(&mut self) {
+        // TODO: might validate that we are in a compute encoder state
+        self.flush_encoder();
+    }
+
     unsafe fn memory_barrier(&mut self) {
         // We'll probably move to explicit barriers, but for now rely on
         // Metal's own tracking.
@@ -690,10 +720,13 @@
         );
     }
 
-    unsafe fn reset_query_pool(&mut self, _pool: &QueryPool) {}
+    unsafe fn reset_query_pool(&mut self, pool: &QueryPool) {
+        let mut calibration = pool.calibration.lock().unwrap();
+        *calibration = Some(self.time_calibration.clone());
+    }
 
     unsafe fn write_timestamp(&mut self, pool: &QueryPool, query: u32) {
-        if let Some(buf) = &pool.0 {
+        if let Some(buf) = &pool.counter_sample_buf {
             if matches!(self.cur_encoder, Encoder::None) {
                 self.cur_encoder =
                     Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned());
@@ -709,21 +742,14 @@
                 }
             } else if self.counter_style == CounterStyle::Stage {
                 match &self.cur_encoder {
-                    Encoder::Compute(e) => {
-                        println!("here we are");
+                    Encoder::Compute(_e) => {
+                        println!("write_timestamp is not supported for stage-style encoders");
                     }
                     _ => (),
                 }
             }
         }
     }
-
-    unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder {
-        let raw = self.cmd_buf.new_compute_command_encoder().to_owned();
-        ComputeEncoder {
-            raw
-        }
-    }
 }
 
 impl CmdBuf {
@@ -761,43 +787,6 @@
     }
 }
 
-impl crate::backend::ComputeEncoder<MtlDevice> for ComputeEncoder {
-    unsafe fn dispatch(
-        &mut self,
-        pipeline: &Pipeline,
-        descriptor_set: &DescriptorSet,
-        workgroup_count: (u32, u32, u32),
-        workgroup_size: (u32, u32, u32),
-    ) {
-        self.raw.set_compute_pipeline_state(&pipeline.0);
-        let mut buf_ix = 0;
-        for buffer in &descriptor_set.buffers {
-            self.raw.set_buffer(buf_ix, Some(&buffer.buffer), 0);
-            buf_ix += 1;
-        }
-        let mut img_ix = buf_ix;
-        for image in &descriptor_set.images {
-            self.raw.set_texture(img_ix, Some(&image.texture));
-            img_ix += 1;
-        }
-        let workgroup_count = metal::MTLSize {
-            width: workgroup_count.0 as u64,
-            height: workgroup_count.1 as u64,
-            depth: workgroup_count.2 as u64,
-        };
-        let workgroup_size = metal::MTLSize {
-            width: workgroup_size.0 as u64,
-            height: workgroup_size.1 as u64,
-            depth: workgroup_size.2 as u64,
-        };
-        self.raw.dispatch_thread_groups(workgroup_count, workgroup_size);
-    }
-
-    unsafe fn finish(&mut self) {
-        self.raw.end_encoding();
-    }
-}
-
 impl crate::backend::DescriptorSetBuilder<MtlDevice> for DescriptorSetBuilder {
     fn add_buffers(&mut self, buffers: &[&Buffer]) {
         self.0.buffers.extend(buffers.iter().copied().cloned());
diff --git a/piet-gpu-hal/src/metal/timer.rs b/piet-gpu-hal/src/metal/timer.rs
index a51bc6d..a8b80d6 100644
--- a/piet-gpu-hal/src/metal/timer.rs
+++ b/piet-gpu-hal/src/metal/timer.rs
@@ -36,6 +36,14 @@
     id: id,
 }
 
+#[derive(Default)]
+pub struct TimeCalibration {
+    pub cpu_start_ts: u64,
+    pub gpu_start_ts: u64,
+    pub cpu_end_ts: u64,
+    pub gpu_end_ts: u64,
+}
+
 impl Drop for CounterSampleBuffer {
     fn drop(&mut self) {
         unsafe { msg_send![self.id, release] }
@@ -87,7 +95,6 @@
         unsafe {
             let desc_cls = class!(MTLCounterSampleBufferDescriptor);
             let descriptor: id = msg_send![desc_cls, alloc];
-            println!("descriptor = {:?}", descriptor);
             let _: id = msg_send![descriptor, init];
             let count = count as NSUInteger;
             let () = msg_send![descriptor, setSampleCount: count];
@@ -121,3 +128,21 @@
         }
     }
 }
+
+impl TimeCalibration {
+    /// Convert GPU timestamp into CPU time base.
+    ///
+    /// See https://developer.apple.com/documentation/metal/performance_tuning/correlating_cpu_and_gpu_timestamps
+    pub fn correlate(&self, raw_ts: u64) -> f64 {
+        let delta_cpu = self.cpu_end_ts - self.cpu_start_ts;
+        let delta_gpu = self.gpu_end_ts - self.gpu_start_ts;
+        let adj_ts = if delta_gpu > 0 {
+            let scale = delta_cpu as f64 / delta_gpu as f64;
+            self.cpu_start_ts as f64 + (raw_ts - self.gpu_start_ts) as f64 * scale
+        } else {
+            // Default is ns on Apple Silicon; on other hardware this will be wrong
+            raw_ts as f64
+        };
+        adj_ts * 1e-9
+    }
+}
diff --git a/piet-gpu-hal/src/mux.rs b/piet-gpu-hal/src/mux.rs
index 7853c2b..9795193 100644
--- a/piet-gpu-hal/src/mux.rs
+++ b/piet-gpu-hal/src/mux.rs
@@ -35,6 +35,7 @@
 use crate::backend::Device as DeviceTrait;
 use crate::BackendType;
 use crate::BindType;
+use crate::ComputePassDescriptor;
 use crate::ImageFormat;
 use crate::MapMode;
 use crate::{BufferUsage, Error, GpuInfo, ImageLayout, InstanceFlags};
@@ -100,14 +101,6 @@
 QueryPool }
 mux_device_enum! { Sampler }
 
-mux_enum! {
-    pub enum ComputeEncoder {
-        Vk(<crate::vulkan::CmdBuf as crate::backend::CmdBuf<vulkan::VkDevice>>::ComputeEncoder),
-        Dx12(<crate::dx12::Dx12Device as crate::backend::CmdBuf<dx12::Dx12Device>>::ComputeEncoder),
-        Mtl(<crate::metal::CmdBuf as crate::backend::CmdBuf<metal::MtlDevice>>::ComputeEncoder),
-    }
-}
-
 /// The code for a shader, either as source or intermediate representation.
 pub enum ShaderCode<'a> {
     /// SPIR-V (binary intermediate representation)
@@ -666,6 +659,14 @@
         }
     }
 
+    pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
+        mux_match! { self;
+            CmdBuf::Vk(c) => c.begin_compute_pass(desc),
+            CmdBuf::Dx12(c) => c.begin_compute_pass(desc),
+            CmdBuf::Mtl(c) => c.begin_compute_pass(desc),
+        }
+    }
+
     /// Dispatch a compute shader.
     ///
     /// Note that both the number of workgroups (`workgroup_count`) and the number of
@@ -688,6 +689,14 @@
         }
     }
 
+    pub unsafe fn end_compute_pass(&mut self) {
+        mux_match! { self;
+            CmdBuf::Vk(c) => c.end_compute_pass(),
+            CmdBuf::Dx12(c) => c.end_compute_pass(),
+            CmdBuf::Mtl(c) => c.end_compute_pass(),
+        }
+    }
+
     pub unsafe fn memory_barrier(&mut self) {
         mux_match! { self;
             CmdBuf::Vk(c) => c.memory_barrier(),