// blob: bc8738625faf752a452574bee60eda6cfbdd5647 [file] [log] [blame]
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Pair of unsigned counters held in thread/threadgroup memory.
// main0 seeds it per element as { node_type == 1, node_type == 0 } and
// merges pairs with bic_combine().
struct Bic
{
// Set to 1 by main0 when the element's node_type is 1, otherwise 0.
uint a;
// Set to 1 by main0 when the element's node_type is 0, otherwise 0.
uint b;
};
// Thread-local copy of one input element; main0 fills it from an InBuf entry.
// Has the same fields as Node_1.
struct Node
{
// Compared against 0u and 1u in main0 to seed the Bic counters.
uint node_type;
// pad1..pad3 are copied by main0 but not otherwise read in the visible code.
uint pad1;
uint pad2;
uint pad3;
// Combined via bbox_union(), which takes the min of .xy and the max of .zw.
float4 bbox;
};
// Element type of InBuf::inbuf (the device-buffer form of Node).
// The field list mirrors Node exactly.
struct Node_1
{
// Compared against 0u and 1u in main0 after being copied into a Node.
uint node_type;
// pad1..pad3 are copied by main0 but not otherwise read in the visible code.
uint pad1;
uint pad2;
uint pad3;
// Combined via bbox_union(), which takes the min of .xy and the max of .zw.
float4 bbox;
};
// Layout of the read-only input buffer bound at [[buffer(0)]] in main0.
struct InBuf
{
// Declared with one element but indexed by thread_position_in_grid.x in
// main0, so the bound buffer is presumably sized by the host to cover the
// whole dispatch -- TODO confirm against the host code.
Node_1 inbuf[1];
};
// Thread-local record pairing a Bic with a bounding box; main0 builds one
// from its final reduction values before copying it into BicBuf.
struct BicBbox
{
// Combined counters.
Bic bic;
// pad2/pad3 are written as 0u by main0.
uint pad2;
uint pad3;
// Combined bounding box (see bbox_union()).
float4 bbox;
};
// Device-buffer form of Bic, used inside BicBbox_1. Same two fields as Bic.
struct Bic_1
{
// Receives Bic::a when main0 writes a threadgroup result.
uint a;
// Receives Bic::b when main0 writes a threadgroup result.
uint b;
};
// Element type of BicBuf::bicbuf (the device-buffer form of BicBbox).
struct BicBbox_1
{
// Combined counters for one threadgroup.
Bic_1 bic;
// pad2/pad3 are written as 0u by main0.
uint pad2;
uint pad3;
// Combined bounding box for one threadgroup.
float4 bbox;
};
// Layout of the output buffer bound at [[buffer(1)]] in main0.
struct BicBuf
{
// Indexed by threadgroup_position_in_grid.x; only local invocation 0 of
// each threadgroup writes its entry. Declared with one element, so the
// bound buffer is presumably larger -- TODO confirm against the host code.
BicBbox_1 bicbuf[1];
};
// Layout of the output buffer bound at [[buffer(2)]] in main0.
struct StackBuf
{
// main0 stores bounding boxes here at an index computed from the
// threadgroup id and the Bic counters. Declared with one element, so the
// bound buffer is presumably larger -- TODO confirm against the host code.
float4 stack[1];
};
// Threadgroup size constant (512 x 1 x 1). Not referenced by the visible code
// (hence [[maybe_unused]]); main0 hard-codes the matching value 512.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
// Merges two Bic counter pairs, with `x` on the left and `y` on the right.
// min(x.b, y.a) is subtracted from both component sums, so the result is
// { x.a + y.a - min(x.b, y.a), x.b + y.b - min(x.b, y.a) }.
// The operation is order-sensitive: swapping x and y changes the result.
static inline __attribute__((always_inline))
Bic bic_combine(thread const Bic& x, thread const Bic& y)
{
    // Amount removed from both component sums.
    const uint cancelled = min(x.b, y.a);

    Bic merged;
    merged.a = (x.a + y.a) - cancelled;
    merged.b = (x.b + y.b) - cancelled;
    return merged;
}
// Combines two float4 boxes component-wise: the result's .xy is the minimum
// of the inputs' .xy and its .zw is the maximum of the inputs' .zw.
static inline __attribute__((always_inline))
float4 bbox_union(thread const float4& a, thread const float4& b)
{
    const float2 lower = fast::min(a.xy, b.xy);
    const float2 upper = fast::max(a.zw, b.zw);
    return float4(lower, upper);
}
// Compute kernel operating on threadgroups of 512 invocations.
//
// Each invocation loads one element of the input buffer, derives a Bic from
// its node_type, and then takes part in a threadgroup-wide reduction over the
// Bic counters and bounding boxes held in threadgroup memory:
//   * after the reduction loop, invocation i holds the combination of
//     elements i..511 of its threadgroup (combined left-to-right);
//   * invocation 0 stores the whole-threadgroup result in
//     _193.bicbuf[threadgroup id];
//   * every invocation whose node_type is 0 and whose following elements
//     (i+1..511) combine to a Bic with a == 0 stores its reduced bounding box
//     into _255.stack.
//
// _74   input elements, read-only               [[buffer(0)]]
// _193  one BicBbox_1 result per threadgroup     [[buffer(1)]]
// _255  bounding-box output                      [[buffer(2)]]
kernel void main0(const device InBuf& _74 [[buffer(0)]], device BicBuf& _193 [[buffer(1)]], device StackBuf& _255 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    threadgroup Bic sh_bic[512];
    threadgroup float4 sh_bbox[512];

    const uint lid = gl_LocalInvocationID.x;
    const uint group = gl_WorkGroupID.x;

    // Pull this invocation's element out of device memory.
    const device Node_1& src = _74.inbuf[gl_GlobalInvocationID.x];
    const Node inp = Node{ src.node_type, src.pad1, src.pad2, src.pad3, src.bbox };

    // Seed the per-invocation running values and publish them.
    Bic bic = Bic{ uint(inp.node_type == 1u), uint(inp.node_type == 0u) };
    float4 bbox = inp.bbox;
    sh_bic[lid] = bic;
    sh_bbox[lid] = bbox;

    // Nine rounds with strides 1, 2, 4, ..., 256: each round folds in the
    // value published by the invocation `stride` slots further along.
    for (uint stride = 1u; stride < 512u; stride <<= 1u)
    {
        // All values published for this round must be visible before reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const uint other_ix = lid + stride;
        if (other_ix < 512u)
        {
            // Copy into thread memory first: the helpers take thread references.
            const Bic other_bic = sh_bic[other_ix];
            const float4 other_bbox = sh_bbox[other_ix];
            bic = bic_combine(bic, other_bic);
            bbox = bbox_union(bbox, other_bbox);
        }
        // Every invocation must finish reading before any slot is overwritten.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_bic[lid] = bic;
        sh_bbox[lid] = bbox;
    }

    // Invocation 0 now holds the result for the whole threadgroup.
    if (lid == 0u)
    {
        const BicBbox total = BicBbox{ bic, 0u, 0u, bbox };
        device BicBbox_1& dst = _193.bicbuf[group];
        dst.bic.a = total.bic.a;
        dst.bic.b = total.bic.b;
        dst.pad2 = total.pad2;
        dst.pad3 = total.pad3;
        dst.bbox = total.bbox;
    }

    // Make the final round's stores visible before the reads below.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const uint size = sh_bic[0].b;

    // Combination of every element after this one; zero for the last slot.
    Bic tail = Bic{ 0u, 0u };
    if ((lid + 1u) < 512u)
    {
        tail = sh_bic[lid + 1u];
    }

    if (inp.node_type == 0u && tail.a == 0u)
    {
        const uint out_ix = (((group * 512u) + size) - tail.b) - 1u;
        // sh_bbox[lid] holds the reduced box for elements lid..511 at this point.
        _255.stack[out_ix] = sh_bbox[lid];
    }
}