Alternative strategies for elements shader barriers

This patch adds an #ifdef METAL for switching between two strategies
for the barriers in the elements shader. With METAL defined, the
barriers sit in workgroup-uniform control flow, which is compatible
with translation to Metal but fails to compile under FXC, whose
uniformity analysis can't prove the flow uniform. With METAL
undefined, the lookback logic (including barriers) runs in a single
thread only, which compiles under FXC (though this hasn't been tested
beyond compilation) but creates problems in Metal.
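
Schematically, the two shapes look like this (a condensed sketch of
the diff below, with the body of the lookback elided):

    #ifdef METAL
        // Every thread iterates the lookback loop, so barrier() is
        // reached in workgroup-uniform control flow; the last thread
        // reads the flag and broadcasts it through shared memory.
        while (true) {
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                sh_flag = state[state_flag_index(look_back_ix)];
            }
            barrier();
            memoryBarrierBuffer();
            uint flag = sh_flag;
            // ... lookback driven by the broadcast flag ...
        }
    #else
        // Only the last thread executes the lookback at all, so no
        // execution barrier is needed inside the loop.
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            while (true) {
                uint flag = state[state_flag_index(look_back_ix)];
                // ... lookback in a single thread ...
            }
        }
    #endif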

In testing on Android, the METAL version seems slightly faster. On an
AMD 5700 XT, there is no measurable difference.

I'm inclined not to commit this, but it's potentially useful if we want
to explore cs_5_0 compatibility.
diff --git a/piet-gpu/shader/elements.comp b/piet-gpu/shader/elements.comp
index e4bbfec..c410da3 100644
--- a/piet-gpu/shader/elements.comp
+++ b/piet-gpu/shader/elements.comp
@@ -175,7 +175,15 @@
 
 shared uint sh_part_ix;
 shared State sh_prefix;
+#define METAL
+#ifdef METAL
 shared uint sh_flag;
+#define LAST_THREAD_INNER (gl_LocalInvocationID.x == WG_SIZE - 1)
+#define LAST_THREAD_OUTER true
+#else
+#define LAST_THREAD_INNER true
+#define LAST_THREAD_OUTER (gl_LocalInvocationID.x == WG_SIZE - 1)
+#endif
 
 void main() {
     State th_state[N_ROWS];
@@ -234,78 +242,90 @@
         }
         state[state_flag_index(part_ix)] = flag;
     }
-    if (part_ix != 0) {
-        // step 4 of paper: decoupled lookback
-        uint look_back_ix = part_ix - 1;
+    if (LAST_THREAD_OUTER) {
+        if (part_ix != 0) {
+            // step 4 of paper: decoupled lookback
+            uint look_back_ix = part_ix - 1;
 
-        State their_agg;
-        uint their_ix = 0;
-        while (true) {
-            // Read flag with acquire semantics.
-            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                sh_flag = state[state_flag_index(look_back_ix)];
-            }
-            // The flag load is done only in the last thread. However, because the
-            // translation of memoryBarrierBuffer to Metal requires uniform control
-            // flow, we broadcast it to all threads.
-            barrier();
-            memoryBarrierBuffer();
-            uint flag = sh_flag;
+            State their_agg;
+            uint their_ix = 0;
+            while (true) {
+                // Read flag with acquire semantics.
+#ifdef METAL
+                if (LAST_THREAD_INNER) {
+                    sh_flag = state[state_flag_index(look_back_ix)];
+                }
+                // The flag load is done only in the last thread. However, because the
+                // translation of memoryBarrierBuffer to Metal requires uniform control
+                // flow, we broadcast it to all threads.
+                barrier();
+                memoryBarrierBuffer();
+                uint flag = sh_flag;
+#else
+                uint flag = state[state_flag_index(look_back_ix)];
+#endif
 
-            if (flag == FLAG_PREFIX_READY) {
-                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                    State their_prefix = State_read(state_prefix_ref(look_back_ix));
-                    exclusive = combine_state(their_prefix, exclusive);
+                if (flag == FLAG_PREFIX_READY) {
+                    if (LAST_THREAD_INNER) {
+                        State their_prefix = State_read(state_prefix_ref(look_back_ix));
+                        exclusive = combine_state(their_prefix, exclusive);
+                    }
+                    break;
+                } else if (flag == FLAG_AGGREGATE_READY) {
+                    if (LAST_THREAD_INNER) {
+                        their_agg = State_read(state_aggregate_ref(look_back_ix));
+                        exclusive = combine_state(their_agg, exclusive);
+                    }
+                    look_back_ix--;
+                    their_ix = 0;
+                    continue;
                 }
-                break;
-            } else if (flag == FLAG_AGGREGATE_READY) {
-                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                    their_agg = State_read(state_aggregate_ref(look_back_ix));
-                    exclusive = combine_state(their_agg, exclusive);
-                }
-                look_back_ix--;
-                their_ix = 0;
-                continue;
-            }
-            // else spin
+                // else spin
 
-            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                // Unfortunately there's no guarantee of forward progress of other
-                // workgroups, so compute a bit of the aggregate before trying again.
-                // In the worst case, spinning stops when the aggregate is complete.
-                ElementRef ref = ElementRef((look_back_ix * PARTITION_SIZE + their_ix) * Element_size);
-                State s = map_element(ref);
-                if (their_ix == 0) {
-                    their_agg = s;
-                } else {
-                    their_agg = combine_state(their_agg, s);
-                }
-                their_ix++;
-                if (their_ix == PARTITION_SIZE) {
-                    exclusive = combine_state(their_agg, exclusive);
-                    if (look_back_ix == 0) {
-                        sh_flag = FLAG_PREFIX_READY;
+                if (LAST_THREAD_INNER) {
+                    // Unfortunately there's no guarantee of forward progress of other
+                    // workgroups, so compute a bit of the aggregate before trying again.
+                    // In the worst case, spinning stops when the aggregate is complete.
+                    ElementRef ref = ElementRef((look_back_ix * PARTITION_SIZE + their_ix) * Element_size);
+                    State s = map_element(ref);
+                    if (their_ix == 0) {
+                        their_agg = s;
                     } else {
-                        look_back_ix--;
-                        their_ix = 0;
+                        their_agg = combine_state(their_agg, s);
+                    }
+                    their_ix++;
+                    if (their_ix == PARTITION_SIZE) {
+                        exclusive = combine_state(their_agg, exclusive);
+                        if (look_back_ix == 0) {
+#ifdef METAL
+                            sh_flag = FLAG_PREFIX_READY;
+#else
+                            break;
+#endif
+                        } else {
+                            look_back_ix--;
+                            their_ix = 0;
+                        }
                     }
                 }
+#ifdef METAL
+                barrier();
+                flag = sh_flag;
+                if (flag == FLAG_PREFIX_READY) {
+                    break;
+                }
+#endif
             }
-            barrier();
-            flag = sh_flag;
-            if (flag == FLAG_PREFIX_READY) {
-                break;
+            // step 5 of paper: compute inclusive prefix
+            if (LAST_THREAD_INNER) {
+                State inclusive_prefix = combine_state(exclusive, agg);
+                sh_prefix = exclusive;
+                State_write(state_prefix_ref(part_ix), inclusive_prefix);
             }
-        }
-        // step 5 of paper: compute inclusive prefix
-        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-            State inclusive_prefix = combine_state(exclusive, agg);
-            sh_prefix = exclusive;
-            State_write(state_prefix_ref(part_ix), inclusive_prefix);
-        }
-        memoryBarrierBuffer();
-        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-            state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+            memoryBarrierBuffer();
+            if (LAST_THREAD_INNER) {
+                state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+            }
         }
     }
     barrier();
diff --git a/piet-gpu/shader/elements.spv b/piet-gpu/shader/elements.spv
index 60517b0..183c892 100644
--- a/piet-gpu/shader/elements.spv
+++ b/piet-gpu/shader/elements.spv
Binary files differ