blob: 13afb0a8cf74c472f3bb600d435d7bf7a8f3b75b [file] [log] [blame]
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Working value combined pairwise by bic_combine(). main0 maps each input
// word `inp` to Bic{ 1u - inp, inp }, so (assuming inputs are 0 or 1 - TODO
// confirm) an input of 0 contributes one to `a` and an input of 1 contributes
// one to `b`.
struct Bic
{
// Cancelled against the `b` of a value combined on its left (bic_combine's y.a).
uint a;
// Cancelled against the `a` of a value combined on its right (bic_combine's x.b).
uint b;
};
// Element type of the BicBuf device buffer. It has the same members as Bic.
// main0 copies it into a Bic field by field.
struct Bic_1
{
uint a;
uint b;
};
// Device buffer of Bic_1 values (bound at [[buffer(1)]] in main0). It is
// declared with length 1, but main0 indexes it with values up to 511, so it
// acts as a view of a larger buffer. Presumably the host sizes it (this is how
// runtime-sized arrays appear in generated MSL - confirm).
struct BicBuf
{
Bic_1 bicbuf[1];
};
// Read-only device buffer bound at [[buffer(2)]] in main0. It is indexed as
// stack[ix * 512 + ...], i.e. with a stride of 512 entries per preceding
// threadgroup. The length-1 array stands in for a host-sized buffer.
struct StackBuf
{
uint stack[1];
};
// Read-only input buffer bound at [[buffer(0)]]. main0 reads one word per
// global thread (index gl_GlobalInvocationID.x).
struct InBuf
{
uint inbuf[1];
};
// Output buffer bound at [[buffer(3)]]. main0 writes one word per global
// thread (index gl_GlobalInvocationID.x).
struct OutBuf
{
uint outbuf[1];
};
// Threadgroup size: 512 x 1 x 1. This constant is not referenced below. The
// kernel hard-codes 512 and derived constants (511, 256, and 9 rounds = log2(512))
// instead, so changing the size requires editing those literals too.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
// Combines two Bic values, x on the left and y on the right.
// Up to min(x.b, y.a) units of x's `b` cancel against the same number of
// y's `a`, and the remaining counts are summed. This is the composition rule
// of the bicyclic semigroup, which is associative.
// The subtractions cannot underflow, because the cancelled amount is <= y.a
// and <= x.b.
static inline __attribute__((always_inline))
Bic bic_combine(thread const Bic& x, thread const Bic& y)
{
    const uint cancelled = min(x.b, y.a);
    Bic result;
    result.a = x.a + y.a - cancelled;
    result.b = x.b + y.b - cancelled;
    return result;
}
// Searches leftward from position gl_LocalInvocationID.x through the
// reduction tree stored in sh_bic. Level j of the tree starts at offset
// 1024 - (2 << (9 - j)): level 0 holds the 512 leaves at offset 0, level 1
// is at 512, ..., and level 8 is at 1020. The search looks for the nearest
// element strictly to the left whose inclusion in the running combination
// makes its `b` nonzero. This assumes the incoming `bic` has b == 0, as the
// caller's Bic{ 0u, 0u } does.
//   bic    - in/out accumulator. On return it holds the bic_combine of all
//            elements that were absorbed (skipped over) by the search.
//   gl_LocalInvocationID - only .x is read, as the start position.
//   sh_bic - tree built by main0 (512 + 256 + ... + 2 = 1022 entries).
// Returns one past the local index of the element found. Returns 0 when the
// upward pass absorbs every element to the left, i.e. nothing was found in
// this threadgroup.
static inline __attribute__((always_inline))
uint search_link(thread Bic& bic, thread uint3& gl_LocalInvocationID, threadgroup Bic (&sh_bic)[1022])
{
uint ix = gl_LocalInvocationID.x;
uint j = 0u;
// Upward pass. Bits below j of ix are already clear here, so the node at
// level j with index (ix >> j) - 1 covers exactly the 2^j elements ending at
// ix - 1. Absorb it if that keeps b == 0; otherwise the target lies inside it.
while (j < 9u)
{
uint base = 1024u - (2u << (9u - j));
if (((ix >> j) & 1u) != 0u)
{
Bic param = sh_bic[(base + (ix >> j)) - 1u];
Bic param_1 = bic;
Bic test = bic_combine(param, param_1);
if (test.b > 0u)
{
// Found the node containing the target; j is left at this level.
break;
}
bic = test;
ix -= (1u << j);
}
j++;
}
// Downward pass (only when the upward pass stopped early). At each lower
// level, the node just left of ix is the right child of the node being
// searched. Absorb it if it keeps b == 0 (so the target is in the left
// child); otherwise keep ix and descend into it.
if (ix > 0u)
{
while (j > 0u)
{
j--;
uint base_1 = 1024u - (2u << (9u - j));
Bic param_2 = sh_bic[(base_1 + (ix >> j)) - 1u];
Bic param_3 = bic;
Bic test_1 = bic_combine(param_2, param_3);
if (test_1.b == 0u)
{
bic = test_1;
ix -= (1u << j);
}
}
}
return ix;
}
// Compute kernel: one thread per input word, 512 threads per threadgroup.
// For each input element it writes one uint to outbuf. The value is either
// the global index of an element found by search_link() inside this
// threadgroup, or an entry taken from the stack left by earlier threadgroups.
// Buffer arguments have compiler-generated names (looks like SPIRV-Cross):
//   _314 (buffer 0) InBuf    - one input word per global thread
//   _170 (buffer 1) BicBuf   - one Bic per preceding threadgroup, indexed by
//                              threadgroup id (assumed produced by an earlier
//                              pass - TODO confirm)
//   _298 (buffer 2) StackBuf - stack entries, 512-entry stride per threadgroup
//   _399 (buffer 3) OutBuf   - one output word per global thread
// NOTE(review): thread th loads only bicbuf[th] (th <= 511), so at most 512
// preceding threadgroups are taken into account.
kernel void main0(const device InBuf& _314 [[buffer(0)]], const device BicBuf& _170 [[buffer(1)]], const device StackBuf& _298 [[buffer(2)]], device OutBuf& _399 [[buffer(3)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
// sh_bic is used twice. First it holds a 512-entry suffix scan over the
// preceding threadgroups' Bic values. Then it holds a 1022-entry
// (512 + 256 + ... + 2) reduction tree over this threadgroup's inputs.
threadgroup Bic sh_bic[1022];
// sh_stack[th] receives the stack entry at depth sp = 511 - th from the top
// (sh_stack[511] = top), for slots where enough entries exist.
threadgroup uint sh_stack[512];
uint th = gl_LocalInvocationID.x;
// Thread th loads the Bic of preceding threadgroup th. Identity {0,0} otherwise.
Bic bic = Bic{ 0u, 0u };
if ((th * 1u) < gl_WorkGroupID.x)
{
uint _172 = th * 1u;
bic.a = _170.bicbuf[_172].a;
bic.b = _170.bicbuf[_172].b;
}
// Dead loop (1u < 1u is false): generated for one value per thread.
// `other` is never read.
Bic other;
for (uint i = 1u; i < 1u; i++)
{
if (((th * 1u) + i) < gl_WorkGroupID.x)
{
uint _201 = (th * 1u) + i;
other.a = _170.bicbuf[_201].a;
other.b = _170.bicbuf[_201].b;
Bic param = bic;
Bic param_1 = other;
bic = bic_combine(param, param_1);
}
}
// Suffix scan: after 9 rounds (offsets 1, 2, ..., 256), sh_bic[th] holds the
// bic_combine of entries th..511. The two barriers per round separate the
// neighbour reads from the writes.
sh_bic[th] = bic;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((th + (1u << i_1)) < 512u)
{
Bic other_1 = sh_bic[th + (1u << i_1)];
Bic param_2 = bic;
Bic param_3 = other_1;
bic = bic_combine(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_bic[th] = bic;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Binary search for the largest ix with sp < sh_bic[ix].b. This relies on
// the scanned `b` being non-increasing in ix (bic_combine(x, y).b >= y.b).
// ix = 0 is accepted unconditionally here; the sp < b test below covers it.
uint sp = 511u - th;
uint ix = 0u;
for (uint i_2 = 0u; i_2 < 9u; i_2++)
{
uint probe = ix + (256u >> i_2);
if (sp < sh_bic[probe].b)
{
ix = probe;
}
}
// Fetch the entry at depth sp from threadgroup ix's 512-entry stack slice.
// Slots with sp >= b are left unwritten.
uint b = sh_bic[ix].b;
if (sp < b)
{
sh_stack[th] = _298.stack[(((ix * 512u) + b) - sp) - 1u];
}
// This barrier also ensures all reads of the suffix scan above finish before
// sh_bic is overwritten with this threadgroup's leaves.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Leaf for this thread. The index simplifies to gl_GlobalInvocationID.x.
// 1u - inp assumes inp is 0 or 1.
uint inp = _314.inbuf[((gl_GlobalInvocationID.x * 1u) + 1u) - 1u];
bic = Bic{ 1u - inp, inp };
sh_bic[th] = bic;
// Build levels 1..8 of the reduction tree. Level i_3 + 1 (offset
// 1024 - (1 << (9 - i_3)), 1 << (8 - i_3) nodes) combines adjacent pairs
// of level i_3. The root level is not built.
uint inbase = 0u;
for (uint i_3 = 0u; i_3 < 8u; i_3++)
{
uint outbase = 1024u - (1u << (9u - i_3));
threadgroup_barrier(mem_flags::mem_threadgroup);
if (th < (1u << (8u - i_3)))
{
Bic param_4 = sh_bic[inbase + (th * 2u)];
Bic param_5 = sh_bic[(inbase + (th * 2u)) + 1u];
sh_bic[outbase + th] = bic_combine(param_4, param_5);
}
inbase = outbase;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Search left of this thread's element. bic receives the combination of the
// elements skipped over.
bic = Bic{ 0u, 0u };
Bic param_6 = bic;
uint _377 = search_link(param_6, gl_LocalInvocationID, sh_bic);
bic = param_6;
ix = _377;
uint outp;
if (ix > 0u)
{
// Found inside this threadgroup: global index of local element ix - 1.
outp = ((gl_WorkGroupID.x * 512u) + ix) - 1u;
}
else
{
// Not found locally: take the entry bic.a deep in the stack inherited from
// earlier threadgroups.
// NOTE(review): that sh_stack slot is only written above when sp < b; if
// the inherited stack is too shallow, this reads an unwritten entry.
// Presumably the input guarantees this cannot happen - confirm.
outp = sh_stack[511u - bic.a];
}
_399.outbuf[gl_GlobalInvocationID.x] = outp;
}