Compile shaders on Windows computer
diff --git a/tests/shader/gen/intersection_leaf.dxil b/tests/shader/gen/intersection_leaf.dxil
new file mode 100644
index 0000000..ca4e6ec
--- /dev/null
+++ b/tests/shader/gen/intersection_leaf.dxil
Binary files differ
diff --git a/tests/shader/gen/intersection_leaf.hlsl b/tests/shader/gen/intersection_leaf.hlsl
new file mode 100644
index 0000000..2b48f28
--- /dev/null
+++ b/tests/shader/gen/intersection_leaf.hlsl
@@ -0,0 +1,234 @@
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct Node // Input record; comp_main reads it with a 32-byte stride and bbox at byte offset 16.
+{
+ uint node_type; // comp_main compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are loaded and copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); // same extent as [numthreads(512, 1, 1)] on main
+
+static const Bic _76 = { 0u, 0u }; // identity element of bic_combine (min with 0 cancels nothing)
+
+ByteAddressBuffer _89 : register(t1); // read as 8-byte {a, b} records
+ByteAddressBuffer _167 : register(t2); // read as 16-byte float4 records
+ByteAddressBuffer _285 : register(t0); // read as 32-byte Node records
+RWByteAddressBuffer _520 : register(u3); // output: one float4 stored per global invocation
+
+static uint3 gl_WorkGroupID; // these three statics are filled in by main() from the stage input
+static uint3 gl_LocalInvocationID;
+static uint3 gl_GlobalInvocationID;
+struct SPIRV_Cross_Input // system-value inputs received by main()
+{
+ uint3 gl_WorkGroupID : SV_GroupID;
+ uint3 gl_LocalInvocationID : SV_GroupThreadID;
+ uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
+};
+
+groupshared Bic sh_bic[1022]; // slots 0..511 are per-thread; 512..1021 hold the upper tree levels built in comp_main
+groupshared float4 sh_bbox[512]; // per-thread bbox scratch for the scans in comp_main
+groupshared float4 sh_stack[512]; // written at [th] in comp_main, read back at [512u + link]
+groupshared uint sh_link[512]; // per-thread link index, followed in the final loop of comp_main
+
+Bic bic_combine(Bic x, Bic y) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ Bic _46 = { (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+ return _46;
+}
+
+float4 bbox_intersect(float4 a, float4 b) // Returns component-wise max of .xy and min of .zw of the two inputs.
+{
+ return float4(max(a.xy, b.xy), min(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+void comp_main() // Body for one 512-thread group: scans the Bic records below this group's id, fills sh_stack from _167, then intersects each node bbox along its link chain and stores it to _520.
+{
+ uint th = gl_LocalInvocationID.x;
+ Bic bic = _76;
+ if (th < gl_WorkGroupID.x) // only records with index below this workgroup's id are loaded; the rest stay {0, 0}
+ {
+ Bic _93;
+ _93.a = _89.Load(th * 8 + 0);
+ _93.b = _89.Load(th * 8 + 4);
+ bic.a = _93.a;
+ bic.b = _93.b;
+ }
+ sh_bic[th] = bic;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: sh_bic[th] ends up as the combination of entries th..511
+ {
+ GroupMemoryBarrierWithGroupSync();
+ uint other_ix = th + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ }
+ GroupMemoryBarrierWithGroupSync(); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ uint size = sh_bic[0].b; // .b of the combination of all loaded records
+ uint bic_next_b = 0u;
+ if ((th + 1u) < 512u)
+ {
+ bic_next_b = sh_bic[th + 1u].b;
+ }
+ float4 bbox = float4(-1000000000.0f, -1000000000.0f, 1000000000.0f, 1000000000.0f); // leaves bbox_intersect operands unchanged for coordinates within +/-1e9
+ if (bic.b > bic_next_b) // this record raises the suffix .b count
+ {
+ bbox = asfloat(_167.Load4(((((th * 512u) + bic.b) - bic_next_b) - 1u) * 16 + 0)); // float4 index th * 512 + (bic.b - bic_next_b - 1)
+ }
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // doubling-stride prefix scan: bbox becomes the intersection over threads 0..th
+ {
+ sh_bbox[th] = bbox;
+ GroupMemoryBarrierWithGroupSync();
+ if (th >= (1u << i_1))
+ {
+ float4 param_2 = sh_bbox[th - (1u << i_1)];
+ float4 param_3 = bbox;
+ bbox = bbox_intersect(param_2, param_3);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ sh_bbox[th] = bbox;
+ GroupMemoryBarrierWithGroupSync();
+ uint sp = 511u - th; // thread th handles stack position 511 - th
+ uint ix = 0u;
+ for (uint i_2 = 0u; i_2 < 9u; i_2++) // descending-step search (steps 256, 128, ..., 1) over the suffix .b counts
+ {
+ uint probe = ix + (256u >> i_2);
+ if (sp < sh_bic[probe].b) // advance while the probed suffix count still exceeds sp
+ {
+ ix = probe;
+ }
+ }
+ uint b = sh_bic[ix].b;
+ if (sp < b) // sh_stack[th] is written only when this holds
+ {
+ bbox = asfloat(_167.Load4(((((ix * 512u) + b) - sp) - 1u) * 16 + 0)); // float4 index ix * 512 + (b - sp - 1)
+ if (ix > 0u)
+ {
+ float4 param_4 = sh_bbox[ix - 1u]; // scanned intersection over records 0..ix-1
+ float4 param_5 = bbox;
+ bbox = bbox_intersect(param_4, param_5);
+ }
+ sh_stack[th] = bbox;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ Node _291; // field-by-field load of the Node at index gl_GlobalInvocationID.x
+ _291.node_type = _285.Load(gl_GlobalInvocationID.x * 32 + 0);
+ _291.pad1 = _285.Load(gl_GlobalInvocationID.x * 32 + 4);
+ _291.pad2 = _285.Load(gl_GlobalInvocationID.x * 32 + 8);
+ _291.pad3 = _285.Load(gl_GlobalInvocationID.x * 32 + 12);
+ _291.bbox = asfloat(_285.Load4(gl_GlobalInvocationID.x * 32 + 16));
+ Node inp;
+ inp.node_type = _291.node_type;
+ inp.pad1 = _291.pad1;
+ inp.pad2 = _291.pad2;
+ inp.pad3 = _291.pad3;
+ inp.bbox = _291.bbox;
+ uint node_type = inp.node_type;
+ Bic _314 = { uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ bic = _314;
+ sh_bic[th] = bic;
+ uint inbase = 0u;
+ for (uint i_3 = 0u; i_3 < 8u; i_3++) // builds a pairwise reduction tree above the 512 per-thread slots
+ {
+ uint outbase = 1024u - (1u << (9u - i_3)); // 512, 768, ..., 1020
+ GroupMemoryBarrierWithGroupSync();
+ if (th < (1u << (8u - i_3))) // 256 >> i_3 threads write this level
+ {
+ Bic param_6 = sh_bic[inbase + (th * 2u)];
+ Bic param_7 = sh_bic[(inbase + (th * 2u)) + 1u];
+ sh_bic[outbase + th] = bic_combine(param_6, param_7);
+ }
+ inbase = outbase;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ ix = th;
+ bic = _76;
+ uint j = 0u;
+ while (j < 9u) // upward pass: absorb the tree node just before ix while the combined .b stays 0
+ {
+ uint base = 1024u - (2u << (9u - j)); // start of tree level j in sh_bic (0 for the per-thread slots)
+ if (((ix >> j) & 1u) != 0u)
+ {
+ Bic param_8 = sh_bic[(base + (ix >> j)) - 1u];
+ Bic param_9 = bic;
+ Bic test = bic_combine(param_8, param_9);
+ if (test.b > 0u) // stop at the first node whose inclusion would make .b nonzero
+ {
+ break;
+ }
+ bic = test;
+ ix -= (1u << j);
+ }
+ j++;
+ }
+ if (ix > 0u)
+ {
+ while (j > 0u) // downward pass: refine ix within the node where the upward pass stopped
+ {
+ j--;
+ uint base_1 = 1024u - (2u << (9u - j));
+ Bic param_10 = sh_bic[(base_1 + (ix >> j)) - 1u];
+ Bic param_11 = bic;
+ Bic test_1 = bic_combine(param_10, param_11);
+ if (test_1.b == 0u)
+ {
+ bic = test_1;
+ ix -= (1u << j);
+ }
+ }
+ }
+ uint _455;
+ if (ix > 0u)
+ {
+ _455 = ix - 1u; // link to a thread inside this workgroup
+ }
+ else
+ {
+ _455 = 4294967295u - bic.a; // equals -1 - bic.a when viewed as int, i.e. negative for small bic.a
+ }
+ uint link = _455;
+ bbox = inp.bbox;
+ for (uint i_4 = 0u; i_4 < 9u; i_4++) // each round: intersect with the linked thread's bbox, then adopt that thread's link
+ {
+ sh_link[th] = link;
+ sh_bbox[th] = bbox;
+ GroupMemoryBarrierWithGroupSync();
+ if (int(link) >= 0) // only in-group links are followed
+ {
+ float4 param_12 = sh_bbox[link];
+ float4 param_13 = bbox;
+ bbox = bbox_intersect(param_12, param_13);
+ link = sh_link[link];
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ if (int(link + size) >= 0) // for link = -1 - a this holds when a < size
+ {
+ float4 param_14 = sh_stack[512u + link]; // unsigned wrap-around: index 511 - a for link = -1 - a
+ float4 param_15 = bbox;
+ bbox = bbox_intersect(param_14, param_15);
+ }
+ _520.Store4(gl_GlobalInvocationID.x * 16 + 0, asuint(bbox));
+}
+
+[numthreads(512, 1, 1)] // 512 threads per group, matching gl_WorkGroupSize
+void main(SPIRV_Cross_Input stage_input) // Entry point: copies the system values into the file-scope statics, then runs comp_main().
+{
+ gl_WorkGroupID = stage_input.gl_WorkGroupID;
+ gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
+ gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
+ comp_main();
+}
diff --git a/tests/shader/gen/intersection_leaf.msl b/tests/shader/gen/intersection_leaf.msl
new file mode 100644
index 0000000..0bf83f1
--- /dev/null
+++ b/tests/shader/gen/intersection_leaf.msl
@@ -0,0 +1,241 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct Bic_1 // Same fields as Bic; this is the element type used inside BicBuf.
+{
+ uint a;
+ uint b;
+};
+
+struct BicBuf // bound at buffer(1) in main0
+{
+ Bic_1 bicbuf[1]; // declared with length 1 but indexed by thread id in main0 — presumably stands in for a runtime-sized array
+};
+
+struct StackBuf // bound at buffer(2) in main0
+{
+ float4 stack[1]; // indexed with computed offsets in main0
+};
+
+struct Node // thread-local copy of one input record
+{
+ uint node_type; // main0 compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct Node_1 // Same fields as Node; this is the element type used inside InBuf.
+{
+ uint node_type;
+ uint pad1;
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct InBuf // bound at buffer(0) in main0
+{
+ Node_1 inbuf[1]; // indexed by gl_GlobalInvocationID.x in main0
+};
+
+struct OutBuf // bound at buffer(3) in main0
+{
+ float4 outbuf[1]; // one float4 written per global invocation in main0
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); // not referenced by main0; the shared arrays there are sized for 512 threads
+
+static inline __attribute__((always_inline)) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+Bic bic_combine(thread const Bic& x, thread const Bic& y)
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ return Bic{ (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+}
+
+static inline __attribute__((always_inline)) // Returns component-wise max of .xy and min of .zw of the two inputs.
+float4 bbox_intersect(thread const float4& a, thread const float4& b)
+{
+ return float4(fast::max(a.xy, b.xy), fast::min(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+kernel void main0(const device InBuf& _285 [[buffer(0)]], const device BicBuf& _89 [[buffer(1)]], const device StackBuf& _167 [[buffer(2)]], device OutBuf& _520 [[buffer(3)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) // Scans the Bic records below this group's id, fills sh_stack from _167, then intersects each node bbox along its link chain and stores it to _520.
+{
+ threadgroup Bic sh_bic[1022]; // slots 0..511 are per-thread; 512..1021 hold the upper tree levels built below
+ threadgroup float4 sh_bbox[512];
+ threadgroup float4 sh_stack[512];
+ threadgroup uint sh_link[512];
+ uint th = gl_LocalInvocationID.x;
+ Bic bic = Bic{ 0u, 0u }; // identity element of bic_combine
+ if (th < gl_WorkGroupID.x) // only records with index below this threadgroup's id are loaded
+ {
+ bic.a = _89.bicbuf[th].a;
+ bic.b = _89.bicbuf[th].b;
+ }
+ sh_bic[th] = bic;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: sh_bic[th] ends up as the combination of entries th..511
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint other_ix = th + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint size = sh_bic[0].b; // .b of the combination of all loaded records
+ uint bic_next_b = 0u;
+ if ((th + 1u) < 512u)
+ {
+ bic_next_b = sh_bic[th + 1u].b;
+ }
+ float4 bbox = float4(-1000000000.0, -1000000000.0, 1000000000.0, 1000000000.0); // leaves bbox_intersect operands unchanged for coordinates within +/-1e9
+ if (bic.b > bic_next_b) // this record raises the suffix .b count
+ {
+ bbox = _167.stack[(((th * 512u) + bic.b) - bic_next_b) - 1u]; // index th * 512 + (bic.b - bic_next_b - 1)
+ }
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // doubling-stride prefix scan: bbox becomes the intersection over threads 0..th
+ {
+ sh_bbox[th] = bbox;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (th >= (1u << i_1))
+ {
+ float4 param_2 = sh_bbox[th - (1u << i_1)];
+ float4 param_3 = bbox;
+ bbox = bbox_intersect(param_2, param_3);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ sh_bbox[th] = bbox;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint sp = 511u - th; // thread th handles stack position 511 - th
+ uint ix = 0u;
+ for (uint i_2 = 0u; i_2 < 9u; i_2++) // descending-step search (steps 256, 128, ..., 1) over the suffix .b counts
+ {
+ uint probe = ix + (256u >> i_2);
+ if (sp < sh_bic[probe].b) // advance while the probed suffix count still exceeds sp
+ {
+ ix = probe;
+ }
+ }
+ uint b = sh_bic[ix].b;
+ if (sp < b) // sh_stack[th] is written only when this holds
+ {
+ bbox = _167.stack[(((ix * 512u) + b) - sp) - 1u]; // index ix * 512 + (b - sp - 1)
+ if (ix > 0u)
+ {
+ float4 param_4 = sh_bbox[ix - 1u]; // scanned intersection over records 0..ix-1
+ float4 param_5 = bbox;
+ bbox = bbox_intersect(param_4, param_5);
+ }
+ sh_stack[th] = bbox;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ Node inp; // field-by-field copy of the Node at index gl_GlobalInvocationID.x
+ inp.node_type = _285.inbuf[gl_GlobalInvocationID.x].node_type;
+ inp.pad1 = _285.inbuf[gl_GlobalInvocationID.x].pad1;
+ inp.pad2 = _285.inbuf[gl_GlobalInvocationID.x].pad2;
+ inp.pad3 = _285.inbuf[gl_GlobalInvocationID.x].pad3;
+ inp.bbox = _285.inbuf[gl_GlobalInvocationID.x].bbox;
+ uint node_type = inp.node_type;
+ bic = Bic{ uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ sh_bic[th] = bic;
+ uint inbase = 0u;
+ for (uint i_3 = 0u; i_3 < 8u; i_3++) // builds a pairwise reduction tree above the 512 per-thread slots
+ {
+ uint outbase = 1024u - (1u << (9u - i_3)); // 512, 768, ..., 1020
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (th < (1u << (8u - i_3))) // 256 >> i_3 threads write this level
+ {
+ Bic param_6 = sh_bic[inbase + (th * 2u)];
+ Bic param_7 = sh_bic[(inbase + (th * 2u)) + 1u];
+ sh_bic[outbase + th] = bic_combine(param_6, param_7);
+ }
+ inbase = outbase;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ ix = th;
+ bic = Bic{ 0u, 0u };
+ uint j = 0u;
+ while (j < 9u) // upward pass: absorb the tree node just before ix while the combined .b stays 0
+ {
+ uint base = 1024u - (2u << (9u - j)); // start of tree level j in sh_bic (0 for the per-thread slots)
+ if (((ix >> j) & 1u) != 0u)
+ {
+ Bic param_8 = sh_bic[(base + (ix >> j)) - 1u];
+ Bic param_9 = bic;
+ Bic test = bic_combine(param_8, param_9);
+ if (test.b > 0u) // stop at the first node whose inclusion would make .b nonzero
+ {
+ break;
+ }
+ bic = test;
+ ix -= (1u << j);
+ }
+ j++;
+ }
+ if (ix > 0u)
+ {
+ while (j > 0u) // downward pass: refine ix within the node where the upward pass stopped
+ {
+ j--;
+ uint base_1 = 1024u - (2u << (9u - j));
+ Bic param_10 = sh_bic[(base_1 + (ix >> j)) - 1u];
+ Bic param_11 = bic;
+ Bic test_1 = bic_combine(param_10, param_11);
+ if (test_1.b == 0u)
+ {
+ bic = test_1;
+ ix -= (1u << j);
+ }
+ }
+ }
+ uint _455;
+ if (ix > 0u)
+ {
+ _455 = ix - 1u; // link to a thread inside this threadgroup
+ }
+ else
+ {
+ _455 = 4294967295u - bic.a; // equals -1 - bic.a when viewed as int, i.e. negative for small bic.a
+ }
+ uint link = _455;
+ bbox = inp.bbox;
+ for (uint i_4 = 0u; i_4 < 9u; i_4++) // each round: intersect with the linked thread's bbox, then adopt that thread's link
+ {
+ sh_link[th] = link;
+ sh_bbox[th] = bbox;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (int(link) >= 0) // only in-group links are followed
+ {
+ float4 param_12 = sh_bbox[link];
+ float4 param_13 = bbox;
+ bbox = bbox_intersect(param_12, param_13);
+ link = sh_link[link];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ if (int(link + size) >= 0) // for link = -1 - a this holds when a < size
+ {
+ float4 param_14 = sh_stack[512u + link]; // unsigned wrap-around: index 511 - a for link = -1 - a
+ float4 param_15 = bbox;
+ bbox = bbox_intersect(param_14, param_15);
+ }
+ _520.outbuf[gl_GlobalInvocationID.x] = bbox;
+}
+
diff --git a/tests/shader/gen/intersection_leaf.spv b/tests/shader/gen/intersection_leaf.spv
new file mode 100644
index 0000000..001a7d7
--- /dev/null
+++ b/tests/shader/gen/intersection_leaf.spv
Binary files differ
diff --git a/tests/shader/gen/intersection_reduce.dxil b/tests/shader/gen/intersection_reduce.dxil
new file mode 100644
index 0000000..e8737e8
--- /dev/null
+++ b/tests/shader/gen/intersection_reduce.dxil
Binary files differ
diff --git a/tests/shader/gen/intersection_reduce.hlsl b/tests/shader/gen/intersection_reduce.hlsl
new file mode 100644
index 0000000..90f9540
--- /dev/null
+++ b/tests/shader/gen/intersection_reduce.hlsl
@@ -0,0 +1,152 @@
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct Node // Input record; comp_main reads it with a 32-byte stride and bbox at byte offset 16.
+{
+ uint node_type; // comp_main compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are loaded and copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); // same extent as [numthreads(512, 1, 1)] on main
+
+static const Bic _181 = { 0u, 0u }; // identity element of bic_combine (min with 0 cancels nothing)
+
+ByteAddressBuffer _82 : register(t0); // read as 32-byte Node records
+RWByteAddressBuffer _165 : register(u1); // one 8-byte {a, b} record written per workgroup (by thread 0)
+RWByteAddressBuffer _273 : register(u2); // float4 written at the global invocation index for threads with th < size
+
+static uint3 gl_WorkGroupID; // these three statics are filled in by main() from the stage input
+static uint3 gl_LocalInvocationID;
+static uint3 gl_GlobalInvocationID;
+struct SPIRV_Cross_Input // system-value inputs received by main()
+{
+ uint3 gl_WorkGroupID : SV_GroupID;
+ uint3 gl_LocalInvocationID : SV_GroupThreadID;
+ uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
+};
+
+groupshared Bic sh_bic[512]; // one slot per thread for the suffix scan in comp_main
+groupshared float4 sh_bbox[512]; // compacted bboxes, then scan scratch, in comp_main
+
+Bic bic_combine(Bic x, Bic y) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ Bic _46 = { (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+ return _46;
+}
+
+float4 bbox_intersect(float4 a, float4 b) // Returns component-wise max of .xy and min of .zw of the two inputs.
+{
+ return float4(max(a.xy, b.xy), min(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+void comp_main() // Body for one 512-thread group: reduces the node Bics to one record in _165 and writes a scanned stack of `size` bboxes to _273.
+{
+ uint th = gl_LocalInvocationID.x;
+ Node _88; // field-by-field load of the Node at index gl_GlobalInvocationID.x
+ _88.node_type = _82.Load(gl_GlobalInvocationID.x * 32 + 0);
+ _88.pad1 = _82.Load(gl_GlobalInvocationID.x * 32 + 4);
+ _88.pad2 = _82.Load(gl_GlobalInvocationID.x * 32 + 8);
+ _88.pad3 = _82.Load(gl_GlobalInvocationID.x * 32 + 12);
+ _88.bbox = asfloat(_82.Load4(gl_GlobalInvocationID.x * 32 + 16));
+ Node inp;
+ inp.node_type = _88.node_type;
+ inp.pad1 = _88.pad1;
+ inp.pad2 = _88.pad2;
+ inp.pad3 = _88.pad3;
+ inp.bbox = _88.bbox;
+ uint node_type = inp.node_type;
+ Bic _114 = { uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ Bic bic = _114;
+ sh_bic[gl_LocalInvocationID.x] = bic;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: sh_bic[th] ends up as the combination of entries th..511
+ {
+ GroupMemoryBarrierWithGroupSync();
+ uint other_ix = gl_LocalInvocationID.x + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ }
+ GroupMemoryBarrierWithGroupSync(); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ }
+ if (th == 0u) // thread 0 holds the combination of all 512 entries
+ {
+ _165.Store(gl_WorkGroupID.x * 8 + 0, bic.a);
+ _165.Store(gl_WorkGroupID.x * 8 + 4, bic.b);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ uint size = sh_bic[0].b; // .b of the whole-group combination
+ bic = _181;
+ if ((th + 1u) < 512u)
+ {
+ bic = sh_bic[th + 1u]; // combination of the entries after this thread (th+1..511)
+ }
+ bool _193 = inp.node_type == 0u; // _199 below is the expanded form of (node_type == 0u && bic.a == 0u)
+ bool _199;
+ if (_193)
+ {
+ _199 = bic.a == 0u;
+ }
+ else
+ {
+ _199 = _193;
+ }
+ if (_199)
+ {
+ uint out_ix = (size - bic.b) - 1u; // slot derived from the .b count of the entries that follow
+ sh_bbox[out_ix] = inp.bbox;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ float4 bbox; // assigned and read only on paths guarded by th < size
+ if (th < size)
+ {
+ bbox = sh_bbox[th];
+ }
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // doubling-stride prefix scan with bbox_intersect over the first `size` slots
+ {
+ bool _235 = th < size; // _242 below is the expanded form of (th < size && th >= (1u << i_1))
+ bool _242;
+ if (_235)
+ {
+ _242 = th >= (1u << i_1);
+ }
+ else
+ {
+ _242 = _235;
+ }
+ if (_242)
+ {
+ float4 param_2 = sh_bbox[th - (1u << i_1)];
+ float4 param_3 = bbox;
+ bbox = bbox_intersect(param_2, param_3);
+ }
+ GroupMemoryBarrierWithGroupSync();
+ if (th < size)
+ {
+ sh_bbox[th] = bbox;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ }
+ if (th < size)
+ {
+ _273.Store4(gl_GlobalInvocationID.x * 16 + 0, asuint(bbox)); // float4 slot at the global invocation index
+ }
+}
+
+[numthreads(512, 1, 1)] // 512 threads per group, matching gl_WorkGroupSize
+void main(SPIRV_Cross_Input stage_input) // Entry point: copies the system values into the file-scope statics, then runs comp_main().
+{
+ gl_WorkGroupID = stage_input.gl_WorkGroupID;
+ gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
+ gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
+ comp_main();
+}
diff --git a/tests/shader/gen/intersection_reduce.msl b/tests/shader/gen/intersection_reduce.msl
new file mode 100644
index 0000000..357a525
--- /dev/null
+++ b/tests/shader/gen/intersection_reduce.msl
@@ -0,0 +1,158 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct Node // thread-local copy of one input record
+{
+ uint node_type; // main0 compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct Node_1 // Same fields as Node; this is the element type used inside InBuf.
+{
+ uint node_type;
+ uint pad1;
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct InBuf // bound at buffer(0) in main0
+{
+ Node_1 inbuf[1]; // declared with length 1 but indexed by gl_GlobalInvocationID.x — presumably stands in for a runtime-sized array
+};
+
+struct Bic_1 // Same fields as Bic; this is the element type used inside BicBuf.
+{
+ uint a;
+ uint b;
+};
+
+struct BicBuf // bound at buffer(1) in main0
+{
+ Bic_1 bicbuf[1]; // one element written per threadgroup (by thread 0) in main0
+};
+
+struct StackBuf // bound at buffer(2) in main0
+{
+ float4 stack[1]; // written at the global invocation index for threads with th < size
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); // not referenced by main0; the shared arrays there are sized for 512 threads
+
+static inline __attribute__((always_inline)) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+Bic bic_combine(thread const Bic& x, thread const Bic& y)
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ return Bic{ (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+}
+
+static inline __attribute__((always_inline)) // Returns component-wise max of .xy and min of .zw of the two inputs.
+float4 bbox_intersect(thread const float4& a, thread const float4& b)
+{
+ return float4(fast::max(a.xy, b.xy), fast::min(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+kernel void main0(const device InBuf& _82 [[buffer(0)]], device BicBuf& _165 [[buffer(1)]], device StackBuf& _273 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) // Reduces the node Bics of one threadgroup to one record in _165 and writes a scanned stack of `size` bboxes to _273.
+{
+ threadgroup Bic sh_bic[512];
+ threadgroup float4 sh_bbox[512];
+ uint th = gl_LocalInvocationID.x;
+ Node inp; // field-by-field copy of the Node at index gl_GlobalInvocationID.x
+ inp.node_type = _82.inbuf[gl_GlobalInvocationID.x].node_type;
+ inp.pad1 = _82.inbuf[gl_GlobalInvocationID.x].pad1;
+ inp.pad2 = _82.inbuf[gl_GlobalInvocationID.x].pad2;
+ inp.pad3 = _82.inbuf[gl_GlobalInvocationID.x].pad3;
+ inp.bbox = _82.inbuf[gl_GlobalInvocationID.x].bbox;
+ uint node_type = inp.node_type;
+ Bic bic = Bic{ uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ sh_bic[gl_LocalInvocationID.x] = bic;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: sh_bic[th] ends up as the combination of entries th..511
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint other_ix = gl_LocalInvocationID.x + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ }
+ if (th == 0u) // thread 0 holds the combination of all 512 entries
+ {
+ _165.bicbuf[gl_WorkGroupID.x].a = bic.a;
+ _165.bicbuf[gl_WorkGroupID.x].b = bic.b;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint size = sh_bic[0].b; // .b of the whole-group combination
+ bic = Bic{ 0u, 0u };
+ if ((th + 1u) < 512u)
+ {
+ bic = sh_bic[th + 1u]; // combination of the entries after this thread (th+1..511)
+ }
+ bool _193 = inp.node_type == 0u; // _199 below is the expanded form of (node_type == 0u && bic.a == 0u)
+ bool _199;
+ if (_193)
+ {
+ _199 = bic.a == 0u;
+ }
+ else
+ {
+ _199 = _193;
+ }
+ if (_199)
+ {
+ uint out_ix = (size - bic.b) - 1u; // slot derived from the .b count of the entries that follow
+ sh_bbox[out_ix] = inp.bbox;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ float4 bbox; // assigned and read only on paths guarded by th < size
+ if (th < size)
+ {
+ bbox = sh_bbox[th];
+ }
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // doubling-stride prefix scan with bbox_intersect over the first `size` slots
+ {
+ bool _235 = th < size; // _242 below is the expanded form of (th < size && th >= (1u << i_1))
+ bool _242;
+ if (_235)
+ {
+ _242 = th >= (1u << i_1);
+ }
+ else
+ {
+ _242 = _235;
+ }
+ if (_242)
+ {
+ float4 param_2 = sh_bbox[th - (1u << i_1)];
+ float4 param_3 = bbox;
+ bbox = bbox_intersect(param_2, param_3);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (th < size)
+ {
+ sh_bbox[th] = bbox;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ if (th < size)
+ {
+ _273.stack[gl_GlobalInvocationID.x] = bbox; // slot at the global invocation index
+ }
+}
+
diff --git a/tests/shader/gen/intersection_reduce.spv b/tests/shader/gen/intersection_reduce.spv
new file mode 100644
index 0000000..587476d
--- /dev/null
+++ b/tests/shader/gen/intersection_reduce.spv
Binary files differ
diff --git a/tests/shader/gen/stack_leaf.dxil b/tests/shader/gen/stack_leaf.dxil
index 14658c3..16867ad 100644
--- a/tests/shader/gen/stack_leaf.dxil
+++ b/tests/shader/gen/stack_leaf.dxil
Binary files differ
diff --git a/tests/shader/gen/stack_reduce.dxil b/tests/shader/gen/stack_reduce.dxil
index 50ceb82..440b017 100644
--- a/tests/shader/gen/stack_reduce.dxil
+++ b/tests/shader/gen/stack_reduce.dxil
Binary files differ
diff --git a/tests/shader/gen/union_leaf.dxil b/tests/shader/gen/union_leaf.dxil
new file mode 100644
index 0000000..03d8011
--- /dev/null
+++ b/tests/shader/gen/union_leaf.dxil
Binary files differ
diff --git a/tests/shader/gen/union_leaf.hlsl b/tests/shader/gen/union_leaf.hlsl
new file mode 100644
index 0000000..16ee5b1
--- /dev/null
+++ b/tests/shader/gen/union_leaf.hlsl
@@ -0,0 +1,223 @@
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct BicBbox // Record read from _94 in comp_main with a 32-byte stride: Bic at offsets 0 and 4, bbox at offset 16.
+{
+ Bic bic;
+ uint pad2; // pad2/pad3 are never read in this shader
+ uint pad3;
+ float4 bbox;
+};
+
+struct Node // Input record; comp_main reads it with a 32-byte stride and bbox at byte offset 16.
+{
+ uint node_type; // comp_main compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are loaded and copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); // same extent as [numthreads(512, 1, 1)] on main
+
+static const Bic _76 = { 0u, 0u }; // identity element of bic_combine (min with 0 cancels nothing)
+
+ByteAddressBuffer _94 : register(t1); // read as 32-byte BicBbox records
+ByteAddressBuffer _213 : register(t2); // read as 16-byte float4 records
+ByteAddressBuffer _249 : register(t0); // read as 32-byte Node records
+RWByteAddressBuffer _492 : register(u3); // output: one float4 stored per global invocation
+
+static uint3 gl_WorkGroupID; // these three statics are filled in by main() from the stage input
+static uint3 gl_LocalInvocationID;
+static uint3 gl_GlobalInvocationID;
+struct SPIRV_Cross_Input // system-value inputs received by main()
+{
+ uint3 gl_WorkGroupID : SV_GroupID;
+ uint3 gl_LocalInvocationID : SV_GroupThreadID;
+ uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
+};
+
+groupshared Bic sh_bic[1022]; // slots 0..511 are per-thread; 512..1021 hold the upper tree levels built in comp_main
+groupshared float4 sh_bbox[1022]; // same layout as sh_bic, holding the bbox unions
+groupshared float4 sh_stack[512]; // written at [th] in comp_main, read back at [511u - bic.a]
+
+Bic bic_combine(Bic x, Bic y) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ Bic _46 = { (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+ return _46;
+}
+
+float4 bbox_union(float4 a, float4 b) // Returns component-wise min of .xy and max of .zw of the two inputs.
+{
+ return float4(min(a.xy, b.xy), max(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+void comp_main() // Body for one 512-thread group: scans the BicBbox records below this group's id, fills sh_stack from _213, then unions bboxes into each type-1 node and stores one bbox per invocation to _492.
+{
+ uint th = gl_LocalInvocationID.x;
+ Bic bic = _76;
+ float4 bbox = float4(1000000000.0f, 1000000000.0f, -1000000000.0f, -1000000000.0f); // leaves bbox_union operands unchanged for coordinates within +/-1e9
+ if (th < gl_WorkGroupID.x) // only records with index below this workgroup's id are loaded
+ {
+ Bic _98;
+ _98.a = _94.Load(th * 32 + 0);
+ _98.b = _94.Load(th * 32 + 4);
+ bic.a = _98.a;
+ bic.b = _98.b;
+ bbox = asfloat(_94.Load4(th * 32 + 16));
+ }
+ sh_bic[th] = bic;
+ sh_bbox[th] = bbox;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: slot th ends up covering entries th..511 for both bic and bbox
+ {
+ GroupMemoryBarrierWithGroupSync();
+ uint other_ix = th + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ float4 param_2 = bbox;
+ float4 param_3 = sh_bbox[other_ix];
+ bbox = bbox_union(param_2, param_3);
+ }
+ GroupMemoryBarrierWithGroupSync(); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ sh_bbox[th] = bbox;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ uint size = sh_bic[0].b; // .b of the combination of all loaded records
+ uint sp = 511u - th; // thread th handles stack position 511 - th
+ uint ix = 0u;
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // descending-step search (steps 256, 128, ..., 1) over the suffix .b counts
+ {
+ uint probe = ix + (256u >> i_1);
+ if (sp < sh_bic[probe].b) // advance while the probed suffix count still exceeds sp
+ {
+ ix = probe;
+ }
+ }
+ uint b = sh_bic[ix].b;
+ if (sp < b) // sh_stack[th] is written only when this holds
+ {
+ float4 bbox_1 = asfloat(_213.Load4(((((ix * 512u) + b) - sp) - 1u) * 16 + 0)); // float4 index ix * 512 + (b - sp - 1)
+ if ((ix + 1u) < 512u)
+ {
+ float4 param_4 = bbox_1;
+ float4 param_5 = sh_bbox[ix + 1u]; // scanned union over records ix+1..511
+ bbox_1 = bbox_union(param_4, param_5);
+ }
+ sh_stack[th] = bbox_1;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ Node _255; // field-by-field load of the Node at index gl_GlobalInvocationID.x
+ _255.node_type = _249.Load(gl_GlobalInvocationID.x * 32 + 0);
+ _255.pad1 = _249.Load(gl_GlobalInvocationID.x * 32 + 4);
+ _255.pad2 = _249.Load(gl_GlobalInvocationID.x * 32 + 8);
+ _255.pad3 = _249.Load(gl_GlobalInvocationID.x * 32 + 12);
+ _255.bbox = asfloat(_249.Load4(gl_GlobalInvocationID.x * 32 + 16));
+ Node inp;
+ inp.node_type = _255.node_type;
+ inp.pad1 = _255.pad1;
+ inp.pad2 = _255.pad2;
+ inp.pad3 = _255.pad3;
+ inp.bbox = _255.bbox;
+ uint node_type = inp.node_type;
+ Bic _277 = { uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ bic = _277;
+ sh_bic[th] = bic;
+ sh_bbox[th] = inp.bbox;
+ uint inbase = 0u;
+ for (uint i_2 = 0u; i_2 < 8u; i_2++) // builds pairwise reduction trees (bic and bbox) above the 512 per-thread slots
+ {
+ uint outbase = 1024u - (1u << (9u - i_2)); // 512, 768, ..., 1020
+ GroupMemoryBarrierWithGroupSync();
+ if (th < (1u << (8u - i_2))) // 256 >> i_2 threads write this level
+ {
+ Bic param_6 = sh_bic[inbase + (th * 2u)];
+ Bic param_7 = sh_bic[(inbase + (th * 2u)) + 1u];
+ sh_bic[outbase + th] = bic_combine(param_6, param_7);
+ float4 param_8 = sh_bbox[inbase + (th * 2u)];
+ float4 param_9 = sh_bbox[(inbase + (th * 2u)) + 1u];
+ sh_bbox[outbase + th] = bbox_union(param_8, param_9);
+ }
+ inbase = outbase;
+ }
+ GroupMemoryBarrierWithGroupSync();
+ ix = th;
+ bbox = inp.bbox;
+ bic = _76;
+ if (node_type == 1u) // other node types store their own bbox unchanged
+ {
+ uint j = 0u;
+ while (j < 9u) // upward pass: absorb the tree node just before ix while the combined .b stays 0
+ {
+ uint base = 1024u - (2u << (9u - j)); // start of tree level j (0 for the per-thread slots)
+ if (((ix >> j) & 1u) != 0u)
+ {
+ Bic param_10 = sh_bic[(base + (ix >> j)) - 1u];
+ Bic param_11 = bic;
+ Bic test = bic_combine(param_10, param_11);
+ if (test.b > 0u) // stop at the first node whose inclusion would make .b nonzero
+ {
+ break;
+ }
+ bic = test;
+ float4 param_12 = sh_bbox[(base + (ix >> j)) - 1u];
+ float4 param_13 = bbox;
+ bbox = bbox_union(param_12, param_13);
+ ix -= (1u << j);
+ }
+ j++;
+ }
+ if (ix > 0u)
+ {
+ while (j > 0u) // downward pass: refine ix within the node where the upward pass stopped
+ {
+ j--;
+ uint base_1 = 1024u - (2u << (9u - j));
+ Bic param_14 = sh_bic[(base_1 + (ix >> j)) - 1u];
+ Bic param_15 = bic;
+ Bic test_1 = bic_combine(param_14, param_15);
+ if (test_1.b == 0u)
+ {
+ bic = test_1;
+ float4 param_16 = sh_bbox[(base_1 + (ix >> j)) - 1u];
+ float4 param_17 = bbox;
+ bbox = bbox_union(param_16, param_17);
+ ix -= (1u << j);
+ }
+ }
+ }
+ bool _470 = ix == 0u; // _477 below is the expanded form of (ix == 0u && bic.a < size)
+ bool _477;
+ if (_470)
+ {
+ _477 = bic.a < size;
+ }
+ else
+ {
+ _477 = _470;
+ }
+ if (_477)
+ {
+ float4 param_18 = sh_stack[511u - bic.a]; // slot written by thread 511 - bic.a, i.e. stack position bic.a
+ float4 param_19 = bbox;
+ bbox = bbox_union(param_18, param_19);
+ }
+ }
+ _492.Store4(gl_GlobalInvocationID.x * 16 + 0, asuint(bbox));
+}
+
+[numthreads(512, 1, 1)] // 512 threads per group, matching gl_WorkGroupSize
+void main(SPIRV_Cross_Input stage_input) // Entry point: copies the system values into the file-scope statics, then runs comp_main().
+{
+ gl_WorkGroupID = stage_input.gl_WorkGroupID;
+ gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
+ gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
+ comp_main();
+}
diff --git a/tests/shader/gen/union_leaf.msl b/tests/shader/gen/union_leaf.msl
new file mode 100644
index 0000000..d42b1eb
--- /dev/null
+++ b/tests/shader/gen/union_leaf.msl
@@ -0,0 +1,230 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct Bic // Pair of unsigned counters; combined with bic_combine() below.
+{
+ uint a;
+ uint b;
+};
+
+struct Bic_1 // Same fields as Bic; this is the member type used inside BicBbox.
+{
+ uint a;
+ uint b;
+};
+
+struct BicBbox // element type of BicBuf: a Bic plus a bbox
+{
+ Bic_1 bic;
+ uint pad2; // pad2/pad3 are never read in this shader
+ uint pad3;
+ float4 bbox;
+};
+
+struct BicBuf // bound at buffer(1) in main0
+{
+ BicBbox bicbuf[1]; // declared with length 1 but indexed by thread id in main0 — presumably stands in for a runtime-sized array
+};
+
+struct StackBuf // bound at buffer(2) in main0
+{
+ float4 stack[1]; // indexed with computed offsets in main0
+};
+
+struct Node // thread-local copy of one input record
+{
+ uint node_type; // main0 compares this against 0u and 1u
+ uint pad1; // pad1..pad3 are copied but not otherwise read in this shader
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct Node_1 // Same fields as Node; this is the element type used inside InBuf.
+{
+ uint node_type;
+ uint pad1;
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct InBuf // bound at buffer(0) in main0
+{
+ Node_1 inbuf[1]; // indexed by gl_GlobalInvocationID.x in main0
+};
+
+struct OutBuf // bound at buffer(3) in main0
+{
+ float4 outbuf[1]; // one float4 written per global invocation in main0
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); // not referenced by main0; the shared arrays there are sized for 512 threads
+
+static inline __attribute__((always_inline)) // Combines x followed by y: min(x.b, y.a) is cancelled from both summed counts.
+Bic bic_combine(thread const Bic& x, thread const Bic& y)
+{
+ uint m = min(x.b, y.a); // amount of x.b that cancels against y.a
+ return Bic{ (x.a + y.a) - m, (x.b + y.b) - m }; // order matters: the operation is not commutative
+}
+
+static inline __attribute__((always_inline)) // Returns component-wise min of .xy and max of .zw of the two inputs.
+float4 bbox_union(thread const float4& a, thread const float4& b)
+{
+ return float4(fast::min(a.xy, b.xy), fast::max(a.zw, b.zw)); // presumably boxes are (x0, y0, x1, y1) — confirm against the shader source
+}
+
+kernel void main0(const device InBuf& _249 [[buffer(0)]], const device BicBuf& _94 [[buffer(1)]], const device StackBuf& _213 [[buffer(2)]], device OutBuf& _492 [[buffer(3)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) // Scans the BicBbox records below this group's id, fills sh_stack from _213, then unions bboxes into each type-1 node and stores one bbox per invocation to _492.
+{
+ threadgroup Bic sh_bic[1022]; // slots 0..511 are per-thread; 512..1021 hold the upper tree levels built below
+ threadgroup float4 sh_bbox[1022]; // same layout as sh_bic, holding the bbox unions
+ threadgroup float4 sh_stack[512];
+ uint th = gl_LocalInvocationID.x;
+ Bic bic = Bic{ 0u, 0u }; // identity element of bic_combine
+ float4 bbox = float4(1000000000.0, 1000000000.0, -1000000000.0, -1000000000.0); // leaves bbox_union operands unchanged for coordinates within +/-1e9
+ if (th < gl_WorkGroupID.x) // only records with index below this threadgroup's id are loaded
+ {
+ bic.a = _94.bicbuf[th].bic.a;
+ bic.b = _94.bicbuf[th].bic.b;
+ bbox = _94.bicbuf[th].bbox;
+ }
+ sh_bic[th] = bic;
+ sh_bbox[th] = bbox;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling-stride rounds: slot th ends up covering entries th..511 for both bic and bbox
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint other_ix = th + (1u << i);
+ if (other_ix < 512u)
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1);
+ float4 param_2 = bbox;
+ float4 param_3 = sh_bbox[other_ix];
+ bbox = bbox_union(param_2, param_3);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup); // all reads of this round finish before any slot is overwritten
+ sh_bic[th] = bic;
+ sh_bbox[th] = bbox;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ uint size = sh_bic[0].b; // .b of the combination of all loaded records
+ uint sp = 511u - th; // thread th handles stack position 511 - th
+ uint ix = 0u;
+ for (uint i_1 = 0u; i_1 < 9u; i_1++) // descending-step search (steps 256, 128, ..., 1) over the suffix .b counts
+ {
+ uint probe = ix + (256u >> i_1);
+ if (sp < sh_bic[probe].b) // advance while the probed suffix count still exceeds sp
+ {
+ ix = probe;
+ }
+ }
+ uint b = sh_bic[ix].b;
+ if (sp < b) // sh_stack[th] is written only when this holds
+ {
+ float4 bbox_1 = _213.stack[(((ix * 512u) + b) - sp) - 1u]; // index ix * 512 + (b - sp - 1)
+ if ((ix + 1u) < 512u)
+ {
+ float4 param_4 = bbox_1;
+ float4 param_5 = sh_bbox[ix + 1u]; // scanned union over records ix+1..511
+ bbox_1 = bbox_union(param_4, param_5);
+ }
+ sh_stack[th] = bbox_1;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ Node inp; // field-by-field copy of the Node at index gl_GlobalInvocationID.x
+ inp.node_type = _249.inbuf[gl_GlobalInvocationID.x].node_type;
+ inp.pad1 = _249.inbuf[gl_GlobalInvocationID.x].pad1;
+ inp.pad2 = _249.inbuf[gl_GlobalInvocationID.x].pad2;
+ inp.pad3 = _249.inbuf[gl_GlobalInvocationID.x].pad3;
+ inp.bbox = _249.inbuf[gl_GlobalInvocationID.x].bbox;
+ uint node_type = inp.node_type;
+ bic = Bic{ uint(node_type == 1u), uint(node_type == 0u) }; // type 1 -> {1, 0}, type 0 -> {0, 1}, anything else -> {0, 0}
+ sh_bic[th] = bic;
+ sh_bbox[th] = inp.bbox;
+ uint inbase = 0u;
+ for (uint i_2 = 0u; i_2 < 8u; i_2++) // builds pairwise reduction trees (bic and bbox) above the 512 per-thread slots
+ {
+ uint outbase = 1024u - (1u << (9u - i_2)); // 512, 768, ..., 1020
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (th < (1u << (8u - i_2))) // 256 >> i_2 threads write this level
+ {
+ Bic param_6 = sh_bic[inbase + (th * 2u)];
+ Bic param_7 = sh_bic[(inbase + (th * 2u)) + 1u];
+ sh_bic[outbase + th] = bic_combine(param_6, param_7);
+ float4 param_8 = sh_bbox[inbase + (th * 2u)];
+ float4 param_9 = sh_bbox[(inbase + (th * 2u)) + 1u];
+ sh_bbox[outbase + th] = bbox_union(param_8, param_9);
+ }
+ inbase = outbase;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ ix = th;
+ bbox = inp.bbox;
+ bic = Bic{ 0u, 0u };
+ if (node_type == 1u) // other node types store their own bbox unchanged
+ {
+ uint j = 0u;
+ while (j < 9u) // upward pass: absorb the tree node just before ix while the combined .b stays 0
+ {
+ uint base = 1024u - (2u << (9u - j)); // start of tree level j (0 for the per-thread slots)
+ if (((ix >> j) & 1u) != 0u)
+ {
+ Bic param_10 = sh_bic[(base + (ix >> j)) - 1u];
+ Bic param_11 = bic;
+ Bic test = bic_combine(param_10, param_11);
+ if (test.b > 0u) // stop at the first node whose inclusion would make .b nonzero
+ {
+ break;
+ }
+ bic = test;
+ float4 param_12 = sh_bbox[(base + (ix >> j)) - 1u];
+ float4 param_13 = bbox;
+ bbox = bbox_union(param_12, param_13);
+ ix -= (1u << j);
+ }
+ j++;
+ }
+ if (ix > 0u)
+ {
+ while (j > 0u) // downward pass: refine ix within the node where the upward pass stopped
+ {
+ j--;
+ uint base_1 = 1024u - (2u << (9u - j));
+ Bic param_14 = sh_bic[(base_1 + (ix >> j)) - 1u];
+ Bic param_15 = bic;
+ Bic test_1 = bic_combine(param_14, param_15);
+ if (test_1.b == 0u)
+ {
+ bic = test_1;
+ float4 param_16 = sh_bbox[(base_1 + (ix >> j)) - 1u];
+ float4 param_17 = bbox;
+ bbox = bbox_union(param_16, param_17);
+ ix -= (1u << j);
+ }
+ }
+ }
+ bool _470 = ix == 0u; // _477 below is the expanded form of (ix == 0u && bic.a < size)
+ bool _477;
+ if (_470)
+ {
+ _477 = bic.a < size;
+ }
+ else
+ {
+ _477 = _470;
+ }
+ if (_477)
+ {
+ float4 param_18 = sh_stack[511u - bic.a]; // slot written by thread 511 - bic.a, i.e. stack position bic.a
+ float4 param_19 = bbox;
+ bbox = bbox_union(param_18, param_19);
+ }
+ }
+ _492.outbuf[gl_GlobalInvocationID.x] = bbox;
+}
+
diff --git a/tests/shader/gen/union_leaf.spv b/tests/shader/gen/union_leaf.spv
new file mode 100644
index 0000000..e81edb2
--- /dev/null
+++ b/tests/shader/gen/union_leaf.spv
Binary files differ
diff --git a/tests/shader/gen/union_reduce.dxil b/tests/shader/gen/union_reduce.dxil
new file mode 100644
index 0000000..22fbb6a
--- /dev/null
+++ b/tests/shader/gen/union_reduce.dxil
Binary files differ
diff --git a/tests/shader/gen/union_reduce.hlsl b/tests/shader/gen/union_reduce.hlsl
new file mode 100644
index 0000000..696d217
--- /dev/null
+++ b/tests/shader/gen/union_reduce.hlsl
@@ -0,0 +1,135 @@
+struct Bic // Pair of counters that bic_combine() merges.
+{
+ uint a; // seeded per node as uint(node_type == 1u) in comp_main
+ uint b; // seeded per node as uint(node_type == 0u) in comp_main
+};
+
+struct Node // One input record; comp_main loads it from _74 at a 32-byte stride.
+{
+ uint node_type; // byte offset 0; comp_main compares it against 0u and 1u
+ uint pad1; // byte offset 4
+ uint pad2; // byte offset 8
+ uint pad3; // byte offset 12
+ float4 bbox; // byte offset 16; merged by bbox_union as (min of .xy, max of .zw)
+};
+
+struct BicBbox // Per-workgroup result record; comp_main stores it to _193 at a 32-byte stride.
+{
+ Bic bic; // byte offsets 0 (a) and 4 (b)
+ uint pad2; // byte offset 8; comp_main writes 0u
+ uint pad3; // byte offset 12; comp_main writes 0u
+ float4 bbox; // byte offset 16
+};
+
+static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); // same extent as [numthreads(512, 1, 1)] on main
+
+static const Bic _219 = { 0u, 0u }; // zero Bic (leaves the other operand of bic_combine unchanged); used by comp_main for the last thread
+
+ByteAddressBuffer _74 : register(t0); // read-only input: Node records, 32 bytes each
+RWByteAddressBuffer _193 : register(u1); // output: one BicBbox (32 bytes) per workgroup
+RWByteAddressBuffer _255 : register(u2); // output: float4 boxes, 16 bytes each
+
+static uint3 gl_WorkGroupID; // filled from SV_GroupID by main
+static uint3 gl_LocalInvocationID; // filled from SV_GroupThreadID by main
+static uint3 gl_GlobalInvocationID; // filled from SV_DispatchThreadID by main
+struct SPIRV_Cross_Input // entry-point argument carrying the system-value semantics
+{
+ uint3 gl_WorkGroupID : SV_GroupID;
+ uint3 gl_LocalInvocationID : SV_GroupThreadID;
+ uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
+};
+
+groupshared Bic sh_bic[512]; // one slot per thread of the group
+groupshared float4 sh_bbox[512]; // one slot per thread of the group
+
+Bic bic_combine(Bic x, Bic y) // Merge x (left) with y (right): min(x.b, y.a) is removed from both summed fields.
+{
+ uint cancelled = min(x.b, y.a); // amount that x.b and y.a cancel against each other
+ Bic merged = { x.a + y.a - cancelled, x.b + y.b - cancelled };
+ return merged;
+}
+
+float4 bbox_union(float4 a, float4 b) // Smallest .xy and largest .zw of the two boxes, taken per component.
+{
+ return float4(min(a.x, b.x), min(a.y, b.y), max(a.z, b.z), max(a.w, b.w));
+}
+
+void comp_main() // Per-workgroup reduction of Bic/bbox pairs; writes one BicBbox per group to _193 and selected boxes to _255.
+{
+ Node _84; // raw record for this thread, fetched field by field (32-byte stride)
+ _84.node_type = _74.Load(gl_GlobalInvocationID.x * 32 + 0);
+ _84.pad1 = _74.Load(gl_GlobalInvocationID.x * 32 + 4);
+ _84.pad2 = _74.Load(gl_GlobalInvocationID.x * 32 + 8);
+ _84.pad3 = _74.Load(gl_GlobalInvocationID.x * 32 + 12);
+ _84.bbox = asfloat(_74.Load4(gl_GlobalInvocationID.x * 32 + 16));
+ Node inp; // working copy of the loaded record
+ inp.node_type = _84.node_type;
+ inp.pad1 = _84.pad1;
+ inp.pad2 = _84.pad2;
+ inp.pad3 = _84.pad3;
+ inp.bbox = _84.bbox;
+ uint node_type = inp.node_type;
+ float4 bbox = inp.bbox;
+ Bic _113 = { uint(node_type == 1u), uint(node_type == 0u) }; // a = 1 for type 1, b = 1 for type 0, both 0 otherwise
+ Bic bic = _113;
+ sh_bic[gl_LocalInvocationID.x] = bic; // publish this thread's starting values
+ sh_bbox[gl_LocalInvocationID.x] = bbox;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling rounds (2^9 = 512); afterwards slot t covers threads t..511
+ {
+ GroupMemoryBarrierWithGroupSync(); // previous round's writes must be visible before reading
+ uint other_ix = gl_LocalInvocationID.x + (1u << i); // partner 2^i slots to the right
+ if (other_ix < 512u) // partners past the end of the group are skipped
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1); // own value is the left operand, partner the right
+ float4 param_2 = bbox;
+ float4 param_3 = sh_bbox[other_ix];
+ bbox = bbox_union(param_2, param_3);
+ }
+ GroupMemoryBarrierWithGroupSync(); // all reads of this round finish before any slot is overwritten
+ sh_bic[gl_LocalInvocationID.x] = bic;
+ sh_bbox[gl_LocalInvocationID.x] = bbox;
+ }
+ if (gl_LocalInvocationID.x == 0u) // thread 0 holds the combination over the whole group
+ {
+ BicBbox _187 = { bic, 0u, 0u, bbox };
+ BicBbox bic_bbox = _187;
+ _193.Store(gl_WorkGroupID.x * 32 + 0, bic_bbox.bic.a); // one 32-byte record per workgroup
+ _193.Store(gl_WorkGroupID.x * 32 + 4, bic_bbox.bic.b);
+ _193.Store(gl_WorkGroupID.x * 32 + 8, bic_bbox.pad2);
+ _193.Store(gl_WorkGroupID.x * 32 + 12, bic_bbox.pad3);
+ _193.Store4(gl_WorkGroupID.x * 32 + 16, asuint(bic_bbox.bbox));
+ }
+ GroupMemoryBarrierWithGroupSync(); // final round's shared writes must be visible to every thread
+ uint size = sh_bic[0].b; // b field of the whole-group combination
+ bic = _219; // last thread has nothing to its right: keep the zero Bic
+ if ((gl_LocalInvocationID.x + 1u) < 512u)
+ {
+ bic = sh_bic[gl_LocalInvocationID.x + 1u]; // combination of every thread to the right of this one
+ }
+ bool _233 = inp.node_type == 0u;
+ bool _239; // expanded form of (node_type == 0u && bic.a == 0u)
+ if (_233)
+ {
+ _239 = bic.a == 0u;
+ }
+ else
+ {
+ _239 = _233;
+ }
+ if (_239)
+ {
+ uint out_ix = (((gl_WorkGroupID.x * 512u) + size) - bic.b) - 1u; // group base (512 per group) + size - bic.b - 1
+ _255.Store4(out_ix * 16 + 0, asuint(sh_bbox[gl_LocalInvocationID.x])); // slot holds the union over this thread and all threads to its right
+ }
+}
+
+[numthreads(512, 1, 1)] // 512 threads per group along x
+void main(SPIRV_Cross_Input stage_input) // Entry point: mirror the system values into the file-level statics, then run comp_main.
+{
+ gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
+ gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
+ gl_WorkGroupID = stage_input.gl_WorkGroupID;
+ comp_main();
+}
diff --git a/tests/shader/gen/union_reduce.msl b/tests/shader/gen/union_reduce.msl
new file mode 100644
index 0000000..bc87386
--- /dev/null
+++ b/tests/shader/gen/union_reduce.msl
@@ -0,0 +1,148 @@
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+
+#include <metal_stdlib>
+#include <simd/simd.h>
+
+using namespace metal;
+
+struct Bic // Pair of counters that bic_combine() merges.
+{
+ uint a; // seeded per node as uint(node_type == 1u) in main0
+ uint b; // seeded per node as uint(node_type == 0u) in main0
+};
+
+struct Node // Thread-local copy of one input record (filled field by field in main0).
+{
+ uint node_type; // main0 compares it against 0u and 1u
+ uint pad1;
+ uint pad2;
+ uint pad3;
+ float4 bbox; // merged by bbox_union as (min of .xy, max of .zw)
+};
+
+struct Node_1 // Same fields as Node; this is the element type of the input buffer.
+{
+ uint node_type;
+ uint pad1;
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct InBuf // Layout of the buffer bound at [[buffer(0)]] in main0.
+{
+ Node_1 inbuf[1]; // indexed by gl_GlobalInvocationID.x in main0; the [1] presumably stands in for a runtime-sized array - confirm
+};
+
+struct BicBbox // Thread-local form of the per-workgroup result built by thread 0 in main0.
+{
+ Bic bic;
+ uint pad2; // main0 writes 0u
+ uint pad3; // main0 writes 0u
+ float4 bbox;
+};
+
+struct Bic_1 // Same fields as Bic; used inside the output buffer element.
+{
+ uint a;
+ uint b;
+};
+
+struct BicBbox_1 // Same fields as BicBbox; this is the element type of the per-workgroup output buffer.
+{
+ Bic_1 bic;
+ uint pad2;
+ uint pad3;
+ float4 bbox;
+};
+
+struct BicBuf // Layout of the buffer bound at [[buffer(1)]] in main0.
+{
+ BicBbox_1 bicbuf[1]; // indexed by gl_WorkGroupID.x in main0; the [1] presumably stands in for a runtime-sized array - confirm
+};
+
+struct StackBuf // Layout of the buffer bound at [[buffer(2)]] in main0.
+{
+ float4 stack[1]; // indexed by out_ix in main0; the [1] presumably stands in for a runtime-sized array - confirm
+};
+
+constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); // 512 along x, matching the 512-entry threadgroup arrays in main0
+
+static inline __attribute__((always_inline))
+Bic bic_combine(thread const Bic& x, thread const Bic& y) // Merge x (left) with y (right): min(x.b, y.a) is removed from both summed fields.
+{
+ const uint cancelled = min(x.b, y.a); // amount that x.b and y.a cancel against each other
+ return Bic{ x.a + y.a - cancelled, x.b + y.b - cancelled };
+}
+
+static inline __attribute__((always_inline))
+float4 bbox_union(thread const float4& a, thread const float4& b) // Smallest .xy and largest .zw of the two boxes, taken per component.
+{
+ return float4(fast::min(a.x, b.x), fast::min(a.y, b.y), fast::max(a.z, b.z), fast::max(a.w, b.w));
+}
+
+kernel void main0(const device InBuf& _74 [[buffer(0)]], device BicBuf& _193 [[buffer(1)]], device StackBuf& _255 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) // Per-threadgroup reduction of Bic/bbox pairs; writes one record per group to _193 and selected boxes to _255.
+{
+ threadgroup Bic sh_bic[512]; // one slot per thread of the group
+ threadgroup float4 sh_bbox[512]; // one slot per thread of the group
+ Node inp; // thread-local copy of this thread's input record
+ inp.node_type = _74.inbuf[gl_GlobalInvocationID.x].node_type;
+ inp.pad1 = _74.inbuf[gl_GlobalInvocationID.x].pad1;
+ inp.pad2 = _74.inbuf[gl_GlobalInvocationID.x].pad2;
+ inp.pad3 = _74.inbuf[gl_GlobalInvocationID.x].pad3;
+ inp.bbox = _74.inbuf[gl_GlobalInvocationID.x].bbox;
+ uint node_type = inp.node_type;
+ float4 bbox = inp.bbox;
+ Bic bic = Bic{ uint(node_type == 1u), uint(node_type == 0u) }; // a = 1 for type 1, b = 1 for type 0, both 0 otherwise
+ sh_bic[gl_LocalInvocationID.x] = bic; // publish this thread's starting values
+ sh_bbox[gl_LocalInvocationID.x] = bbox;
+ for (uint i = 0u; i < 9u; i++) // 9 doubling rounds (2^9 = 512); afterwards slot t covers threads t..511
+ {
+ threadgroup_barrier(mem_flags::mem_threadgroup); // previous round's writes must be visible before reading
+ uint other_ix = gl_LocalInvocationID.x + (1u << i); // partner 2^i slots to the right
+ if (other_ix < 512u) // partners past the end of the group are skipped
+ {
+ Bic param = bic;
+ Bic param_1 = sh_bic[other_ix];
+ bic = bic_combine(param, param_1); // own value is the left operand, partner the right
+ float4 param_2 = bbox;
+ float4 param_3 = sh_bbox[other_ix];
+ bbox = bbox_union(param_2, param_3);
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup); // all reads of this round finish before any slot is overwritten
+ sh_bic[gl_LocalInvocationID.x] = bic;
+ sh_bbox[gl_LocalInvocationID.x] = bbox;
+ }
+ if (gl_LocalInvocationID.x == 0u) // thread 0 holds the combination over the whole group
+ {
+ BicBbox bic_bbox = BicBbox{ bic, 0u, 0u, bbox };
+ _193.bicbuf[gl_WorkGroupID.x].bic.a = bic_bbox.bic.a; // one record per threadgroup
+ _193.bicbuf[gl_WorkGroupID.x].bic.b = bic_bbox.bic.b;
+ _193.bicbuf[gl_WorkGroupID.x].pad2 = bic_bbox.pad2;
+ _193.bicbuf[gl_WorkGroupID.x].pad3 = bic_bbox.pad3;
+ _193.bicbuf[gl_WorkGroupID.x].bbox = bic_bbox.bbox;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup); // final round's threadgroup writes must be visible to every thread
+ uint size = sh_bic[0].b; // b field of the whole-group combination
+ bic = Bic{ 0u, 0u }; // last thread has nothing to its right: keep the zero Bic
+ if ((gl_LocalInvocationID.x + 1u) < 512u)
+ {
+ bic = sh_bic[gl_LocalInvocationID.x + 1u]; // combination of every thread to the right of this one
+ }
+ bool _233 = inp.node_type == 0u;
+ bool _239; // expanded form of (node_type == 0u && bic.a == 0u)
+ if (_233)
+ {
+ _239 = bic.a == 0u;
+ }
+ else
+ {
+ _239 = _233;
+ }
+ if (_239)
+ {
+ uint out_ix = (((gl_WorkGroupID.x * 512u) + size) - bic.b) - 1u; // group base (512 per group) + size - bic.b - 1
+ _255.stack[out_ix] = sh_bbox[gl_LocalInvocationID.x]; // slot holds the union over this thread and all threads to its right
+ }
+}
+
diff --git a/tests/shader/gen/union_reduce.spv b/tests/shader/gen/union_reduce.spv
new file mode 100644
index 0000000..a049391
--- /dev/null
+++ b/tests/shader/gen/union_reduce.spv
Binary files differ