// blob: 13ce9faac1706de3c6d92984f8f3e50b135e3f26 [file] [log] [blame]
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Fixed-size array wrapper whose operator[] forwards straight to the
// underlying C array without any bounds check. One overload is provided per
// address-space / constness combination so an instance can be indexed whether
// it is accessed through thread, device, constant, or threadgroup memory.
// NOTE(review): the "spv" prefix suggests a SPIRV-Cross generated helper — confirm.
template<typename T, size_t Num>
struct spvUnsafeArray
{
// Backing storage. When Num is 0 one element is still declared, so the
// array type is never zero-length.
T elements[Num ? Num : 1];
// thread address space: mutable access.
thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
// thread address space: read-only access.
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}
// device address space: mutable access.
device T& operator [] (size_t pos) device
{
return elements[pos];
}
// device address space: read-only access.
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}
// constant address space: read-only access (no mutable overload is declared).
constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}
// threadgroup address space: mutable access.
threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
// threadgroup address space: read-only access.
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};
// Two unsigned counters, a and b, that bic_combine() merges pairwise.
// NOTE(review): presumably an element of the bicyclic semigroup (a = unmatched
// "pop" count, b = unmatched "push" count) — confirm against the original shader.
struct Bic
{
uint a;
uint b;
};
// Layout of the buffer bound at [[buffer(0)]] of main0, which only reads it.
// inbuf is declared with one element, but main0 indexes it by the global
// invocation id, so the usable length evidently comes from the bound buffer.
// NOTE(review): main0 forms Bic{1 - v, v} from each value v, so values are
// presumably restricted to 0 or 1 — confirm with the producer of this buffer.
struct InBuf
{
uint inbuf[1];
};
// Same two-field shape as Bic; used as the element type of OutBuf::outbuf.
// NOTE(review): probably a separate declaration so the buffer layout has its
// own type — confirm; main0 copies Bic fields into it one at a time.
struct Bic_1
{
uint a;
uint b;
};
// Layout of the buffer bound at [[buffer(1)]] of main0.
// main0 stores one Bic_1 per threadgroup at outbuf[gl_WorkGroupID.x]; the
// declared length of 1 is therefore not the usable length of the bound buffer.
struct OutBuf
{
Bic_1 outbuf[1];
};
// Layout of the buffer bound at [[buffer(2)]] of main0.
// main0 writes input element indices into stack[] at positions it computes
// per invocation; the declared length of 1 is not the usable length.
struct StackBuf
{
uint stack[1];
};
// Threadgroup dimensions (512 x 1 x 1). Marked [[maybe_unused]]: main0 does
// not reference this constant and instead hard-codes 512 for its shared array
// size and bounds checks, so the two must be kept in agreement by hand.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
// Merges two Bic values, x on the left and y on the right.
// The smaller of x.b and y.a is subtracted from both summed fields, so the
// result is { x.a + y.a - m, x.b + y.b - m } with m = min(x.b, y.a).
// The operation is order-sensitive: swapping x and y changes m.
static inline __attribute__((always_inline))
Bic bic_combine(thread const Bic& x, thread const Bic& y)
{
    // Amount that cancels between the left operand's b and the right operand's a.
    const uint cancelled = min(x.b, y.a);

    Bic merged;
    merged.a = (x.a + y.a) - cancelled;
    merged.b = (x.b + y.b) - cancelled;
    return merged;
}
// Compute kernel; the shared array and all bounds below assume 512
// invocations per threadgroup along x.
//
//   _48  (buffer 0): input values, one read per invocation.
//   _159 (buffer 1): receives one combined Bic per threadgroup.
//   _221 (buffer 2): receives indices of selected input elements.
//
// Steps: build a Bic from the invocation's input value, combine it with the
// values held by invocations to its right through threadgroup memory, have
// invocation 0 publish its combined value, then use the combined value of the
// right-hand neighbour to decide whether and where this invocation's element
// index is written to the stack buffer.
kernel void main0(const device InBuf& _48 [[buffer(0)]], device OutBuf& _159 [[buffer(1)]], device StackBuf& _221 [[buffer(2)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    threadgroup Bic sh_bic[512];

    const uint lid = gl_LocalInvocationID.x;
    const uint group = gl_WorkGroupID.x;
    // Index of the first input element owned by this invocation.
    const uint base = gl_GlobalInvocationID.x * 1u;

    // Load the invocation's input value and turn it into a Bic.
    spvUnsafeArray<uint, 1> inp;
    inp[0] = _48.inbuf[base];
    Bic bic = Bic{ 1u - inp[0], inp[0] };

    // With the bounds as written (start 1, limit 1) this body never executes;
    // it is kept so the per-invocation element loop stays visible.
    for (uint k = 1u; k < 1u; k++)
    {
        inp[k] = _48.inbuf[base + k];
        Bic next = Bic{ 1u - inp[k], inp[k] };
        bic = bic_combine(bic, next);
    }

    sh_bic[lid] = bic;

    // Nine passes with strides 1, 2, 4, ..., 256. In each pass an invocation
    // folds in the value `stride` slots to its right, when that slot exists.
    // The first barrier makes the previous writes visible before reading; the
    // second keeps every read ahead of the overwrite that follows.
    for (uint stride = 1u; stride < 512u; stride <<= 1u)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const uint partner = lid + stride;
        if (partner < 512u)
        {
            Bic right = sh_bic[partner];
            bic = bic_combine(bic, right);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_bic[lid] = bic;
    }

    // Invocation 0 publishes its combined value for this threadgroup.
    if (lid == 0u)
    {
        _159.outbuf[group].a = bic.a;
        _159.outbuf[group].b = bic.b;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // b field of slot 0 after the passes above.
    const uint size = sh_bic[0].b;

    // Combined value of the slot to the right, or {0, 0} for the last slot.
    if ((lid + 1u) < 512u)
    {
        bic = sh_bic[lid + 1u];
    }
    else
    {
        bic = Bic{ 0u, 0u };
    }

    // One past the stack position this invocation may write; it is
    // decremented before each store.
    uint out_ix = ((group * 512u) + size) - bic.b;

    // Visit this invocation's elements from last to first (a single element here).
    for (uint remaining = 1u; remaining > 0u; remaining--)
    {
        const uint slot = remaining - 1u;
        const uint value = inp[slot];

        // Store the element's index only when its input value is 1 and the
        // running value to its right has a == 0.
        if (value == 1u && bic.a == 0u)
        {
            out_ix--;
            _221.stack[out_ix] = base + slot;
        }

        // Prepend this element to the running right-hand value.
        Bic current = Bic{ 1u - value, value };
        bic = bic_combine(current, bic);
    }
}