blob: cbf72a3056194385485460ef4fa2328bdefc1c8e [file] [log] [blame]
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Fixed-size array wrapper. A Num of 0 is padded to a single element so the
// member array is never zero-length. operator[] is provided for every Metal
// address space the wrapper may live in (thread, device, constant,
// threadgroup), in mutable and const flavours where the space allows writes.
// No bounds checking is performed.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    T elements[Num ? Num : 1];

    // thread address space
    thread T& operator [] (size_t index) thread { return elements[index]; }
    constexpr const thread T& operator [] (size_t index) const thread { return elements[index]; }

    // device address space
    device T& operator [] (size_t index) device { return elements[index]; }
    constexpr const device T& operator [] (size_t index) const device { return elements[index]; }

    // constant address space (read-only by definition)
    constexpr const constant T& operator [] (size_t index) const constant { return elements[index]; }

    // threadgroup address space
    threadgroup T& operator [] (size_t index) threadgroup { return elements[index]; }
    constexpr const threadgroup T& operator [] (size_t index) const threadgroup { return elements[index]; }
};
// Implementation of the unsigned GLSL findMSB() function
template<typename T>
inline T spvFindUMSB(T x)
{
return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));
}
// Push/pop balance summary for a run of stack operations, as combined by
// bic_combine and built in main0 as Bic{ 1u - inp, inp } (inp == 1u pushes,
// inp == 0u pops):
//   a - pops in the run not matched by an earlier push inside the run
//   b - pushes in the run not matched by a later pop inside the run
struct Bic
{
    uint a, b;
};
// Same two-field layout as Bic. This is the element type of the device buffer
// BicBuf; main0 copies its fields into a Bic one by one. Presumably the
// translator emitted a separate type because it lives in a buffer block.
struct Bic_1
{
    uint a, b;
};
// Device buffer of per-partition Bic summaries. It is declared with one
// element but indexed past it by main0, so it stands in for a runtime-sized
// array.
struct BicBuf { Bic_1 bicbuf[1]; };
// Device buffer of uint stack entries. It is declared with one element but
// indexed past it by main0, so it stands in for a runtime-sized array.
struct StackBuf { uint stack[1]; };
// Device buffer of input elements (main0 treats 1u as push and 0u as pop).
// It is declared with one element but indexed past it, so it stands in for a
// runtime-sized array.
struct InBuf { uint inbuf[1]; };
// Device buffer receiving one uint result per input element. It is declared
// with one element but indexed past it, so it stands in for a runtime-sized
// array.
struct OutBuf { uint outbuf[1]; };
// Compute workgroup size (64 x 1 x 1). Nothing in this file reads it, but
// main0's threadgroup arrays (sh_bitmaps[64], sh_link[64], sh_stack[512]
// = 64 * 8) and its `th == 63u` / `< 64u` logic hard-code this size.
// NOTE(review): the host must therefore dispatch exactly 64 threads per
// threadgroup.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(64u, 1u, 1u);
// Concatenates two push/pop summaries: the run described by x followed by
// the run described by y. Pushes left open by x (x.b) are cancelled against
// pops y could not match on its own (y.a), pairwise.
static inline __attribute__((always_inline))
Bic bic_combine(thread const Bic& x, thread const Bic& y)
{
    const uint matched = min(x.b, y.a);
    Bic joined;
    joined.a = x.a + (y.a - matched);
    joined.b = (x.b - matched) + y.b;
    return joined;
}
// Returns the position of the ix-th (0-based) set bit of bitmask, found by a
// binary search over the bit positions. The result is the largest position p
// in [0, 31] whose lower bits contain at most ix set bits. If bitmask has
// fewer than ix + 1 set bits, the result is 31.
static inline __attribute__((always_inline))
uint search_bit_set(thread const uint& bitmask, thread const uint& ix)
{
    uint pos = 0u;
    uint step = 16u;
    while (step != 0u)
    {
        const uint candidate = pos + step;
        const uint below = bitmask & ((1u << candidate) - 1u);
        if (popcount(below) <= ix)
        {
            pos = candidate;
        }
        step >>= 1u;
    }
    return pos;
}
// Searches the Bic reduction tree in sh_bic, starting from this thread's leaf
// (gl_LocalInvocationID.x) and moving towards lower indices.
// Tree layout: level j (0..5) starts at offset 128u - (2u << (6u - j)),
// i.e. 0, 64, 96, 112, 120, 124, and holds 64 >> j nodes (126 in total).
// Phase 1 climbs: whenever the current node is a right child, its left
// sibling's summary is folded in front of `bic`. The climb stops at the first
// sibling whose combined summary has unmatched pushes (b > 0).
// Phase 2 descends again, folding in left siblings only while the result
// keeps b == 0. This narrows the search to a single leaf (one thread's chunk
// of 8 elements).
// Returns ix * 8 + search_bit_set(sh_bitmaps[ix], test_2.b - 1) for that
// leaf. If no such leaf exists, returns 4294967295u - bic.a (negative when
// reinterpreted as int).
// `bic` is updated in place with the accumulated summary.
static inline __attribute__((always_inline))
uint search_link(thread Bic& bic, thread uint3& gl_LocalInvocationID, threadgroup Bic (&sh_bic)[126], threadgroup uint (&sh_bitmaps)[64])
{
uint ix = gl_LocalInvocationID.x;
uint j = 0u;
// Phase 1: climb the tree.
while (j < 6u)
{
// Offset of tree level j inside sh_bic.
uint base = 128u - (2u << (6u - j));
// Node at this level is a right child: try absorbing its left sibling.
if (((ix >> j) & 1u) != 0u)
{
Bic param = sh_bic[(base + (ix >> j)) - 1u];
Bic param_1 = bic;
Bic test = bic_combine(param, param_1);
// Sibling subtree contains an unmatched push: stop climbing.
if (test.b > 0u)
{
break;
}
bic = test;
ix -= (1u << j);
}
j++;
}
// Phase 2: descend, absorbing left siblings only while they keep b == 0.
if (ix > 0u)
{
while (j > 0u)
{
j--;
uint base_1 = 128u - (2u << (6u - j));
Bic param_2 = sh_bic[(base_1 + (ix >> j)) - 1u];
Bic param_3 = bic;
Bic test_1 = bic_combine(param_2, param_3);
if (test_1.b == 0u)
{
bic = test_1;
ix -= (1u << j);
}
}
}
// The leaf just left of ix holds the match: step onto it and pick the bit
// from its bitmap.
if (ix > 0u)
{
ix--;
Bic param_4 = sh_bic[ix];
Bic param_5 = bic;
Bic test_2 = bic_combine(param_4, param_5);
uint param_6 = sh_bitmaps[ix];
uint param_7 = test_2.b - 1u;
uint ix_in_chunk = search_bit_set(param_6, param_7);
return (ix * 8u) + ix_in_chunk;
}
else
{
// No match inside the threadgroup: encode bic.a as a negative offset.
return 4294967295u - bic.a;
}
}
// Compute kernel over 512 elements per threadgroup (64 threads x 8 elements).
// Inputs in _593 are treated as stack operations: 1u pushes the element's
// global index, 0u pops (see the final loop). For each element, _751.outbuf
// receives the index that appears to be on top of the stack just before that
// element is applied.
// State from earlier threadgroups comes from two sources:
//   - _250: per-partition Bic summaries; only entries with index
//     < gl_WorkGroupID.x are read.
//   - _512: stack entries, indexed with a stride of 512 per partition.
// NOTE(review): all threadgroup arrays and the `th == 63u` / `< 64u` logic
// assume exactly 64 threads per threadgroup.
kernel void main0(const device InBuf& _593 [[buffer(0)]], const device BicBuf& _250 [[buffer(1)]], const device StackBuf& _512 [[buffer(2)]], device OutBuf& _751 [[buffer(3)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
    threadgroup Bic sh_bic[126];
    threadgroup uint sh_bitmaps[64];
    threadgroup uint sh_stack[512];
    threadgroup uint sh_link[64];
    uint th = gl_LocalInvocationID.x;

    // Phase 1: fold the summaries of this thread's 8 earlier partitions
    // (th*8 .. th*8+7), skipping any index >= gl_WorkGroupID.x.
    Bic bic = Bic{ 0u, 0u };
    if ((th * 8u) < gl_WorkGroupID.x)
    {
        uint _252 = th * 8u;
        bic.a = _250.bicbuf[_252].a;
        bic.b = _250.bicbuf[_252].b;
    }
    Bic other;
    for (uint i = 1u; i < 8u; i++)
    {
        if (((th * 8u) + i) < gl_WorkGroupID.x)
        {
            uint _281 = (th * 8u) + i;
            other.a = _250.bicbuf[_281].a;
            other.b = _250.bicbuf[_281].b;
            Bic param = bic;
            Bic param_1 = other;
            bic = bic_combine(param, param_1);
        }
    }

    // Phase 2: inclusive suffix scan over the 64 per-thread summaries.
    // Thread th combines with th + 1, th + 2, th + 4, ... .
    sh_bic[th] = bic;
    for (uint i_1 = 0u; i_1 < 6u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if ((th + (1u << i_1)) < 64u)
        {
            Bic other_1 = sh_bic[th + (1u << i_1)];
            Bic param_2 = bic;
            Bic param_3 = other_1;
            bic = bic_combine(param_2, param_3);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_bic[th] = bic;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 3: start from the summary of all later threads' partitions and
    // walk this thread's partitions backwards. The running b count per
    // partition goes into sh_stack, and a bitmap bit is set wherever b
    // increases.
    if (th == 63u)
    {
        bic = Bic{ 0u, 0u };
    }
    else
    {
        bic = sh_bic[th + 1u];
    }
    uint last_b = bic.b;
    uint bitmap = 0u;
    Bic param_4;
    for (uint i_2 = 0u; i_2 < 8u; i_2++)
    {
        uint this_ix = (((th * 8u) + 8u) - 1u) - i_2;
        if (this_ix < gl_WorkGroupID.x)
        {
            param_4.a = _250.bicbuf[this_ix].a;
            param_4.b = _250.bicbuf[this_ix].b;
            Bic param_5 = bic;
            bic = bic_combine(param_4, param_5);
        }
        sh_stack[this_ix] = bic.b;
        if (bic.b > last_b)
        {
            bitmap |= (1u << (7u - i_2));
        }
        last_b = bic.b;
    }
    sh_bitmaps[th] = bitmap;

    // Phase 4: each thread's link is its highest flagged partition (or 0).
    // A prefix-max scan over sh_link propagates the nearest earlier link.
    uint link = 0u;
    if (bitmap != 0u)
    {
        link = (th * 8u) + uint(int(spvFindUMSB(bitmap)));
    }
    sh_link[th] = link;
    for (uint i_3 = 0u; i_3 < 6u; i_3++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (th >= (1u << i_3))
        {
            link = max(link, sh_link[th - (1u << i_3)]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_link[th] = link;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 5: a 9-step binary search over sh_stack locates the partition
    // covering depth sp. Up to 8 entries are then copied from _512.stack into
    // local_stack, moving to earlier partitions via sh_bitmaps / sh_link when
    // one runs out.
    uint sp = 504u - (th * 8u);
    uint ix = 0u;
    for (uint i_4 = 0u; i_4 < 9u; i_4++)
    {
        uint probe = ix + (256u >> i_4);
        if (sp < sh_stack[probe])
        {
            ix = probe;
        }
    }
    uint b = sh_stack[ix];
    spvUnsafeArray<uint, 8> local_stack;
    for (uint i_5 = 0u; i_5 < 8u; i_5++)
    {
        local_stack[i_5] = 0u;
    }
    uint i_6 = 0u;
    while ((sp + i_6) < b)
    {
        local_stack[7u - i_6] = _512.stack[(((ix * 512u) + b) - (sp + i_6)) - 1u];
        i_6++;
        if (i_6 == 8u)
        {
            break;
        }
        if ((sp + i_6) == b)
        {
            uint bits = sh_bitmaps[ix / 8u] & ((1u << (ix % 8u)) - 1u);
            if (bits == 0u)
            {
                ix = sh_link[max((ix / 8u), 1u) - 1u];
            }
            else
            {
                ix = (ix & 4294967288u) + uint(int(spvFindUMSB(bits)));
            }
            b = sh_stack[ix];
        }
    }
    // Every thread must finish reading sh_stack before it is overwritten.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i_7 = 0u; i_7 < 8u; i_7++)
    {
        sh_stack[(th * 8u) + i_7] = local_stack[i_7];
    }

    // Phase 6: summarise this thread's 8 input elements, processed last to
    // first. A bitmap bit marks a push (1u) that no later element of the
    // chunk pops (bic.a == 0 for the suffix).
    uint inp = _593.inbuf[((gl_GlobalInvocationID.x * 8u) + 8u) - 1u];
    bic = Bic{ 1u - inp, inp };
    bitmap = inp << uint(7);
    for (uint i_8 = 7u; i_8 > 0u; i_8--)
    {
        inp = _593.inbuf[((gl_GlobalInvocationID.x * 8u) + i_8) - 1u];
        bool _626 = inp == 1u;
        bool _632;
        if (_626)
        {
            _632 = bic.a == 0u;
        }
        else
        {
            _632 = _626;
        }
        if (_632)
        {
            bitmap |= (1u << (i_8 - 1u));
        }
        Bic other_2 = Bic{ 1u - inp, inp };
        Bic param_6 = other_2;
        Bic param_7 = bic;
        bic = bic_combine(param_6, param_7);
    }
    sh_bitmaps[th] = bitmap;

    // Phase 7: build the reduction tree in sh_bic. The 64 leaves sit at
    // offset 0; the 32/16/8/4/2-node levels sit at 64, 96, 112, 120, 124.
    sh_bic[th] = bic;
    uint inbase = 0u;
    for (uint i_9 = 0u; i_9 < 5u; i_9++)
    {
        uint outbase = 128u - (1u << (6u - i_9));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (th < (1u << (5u - i_9)))
        {
            Bic param_8 = sh_bic[inbase + (th * 2u)];
            Bic param_9 = sh_bic[(inbase + (th * 2u)) + 1u];
            sh_bic[outbase + th] = bic_combine(param_8, param_9);
        }
        inbase = outbase;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 8: per-thread link (stored in sh_link) and this thread's
    // starting position ix, both found through the tree.
    bic.b = 0u;
    Bic param_10 = bic;
    uint _706 = search_link(param_10, gl_LocalInvocationID, sh_bic, sh_bitmaps);
    bic = param_10;
    sh_link[th] = _706;
    bic = Bic{ 0u, 0u };
    Bic param_11 = bic;
    uint _711 = search_link(param_11, gl_LocalInvocationID, sh_bic, sh_bitmaps);
    bic = param_11;
    ix = _711;
    // The loop below reads sh_link[ix / 8u], which another thread wrote just
    // above. Without this barrier that read races with the write.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 9: emit the current top-of-stack index for each element, then
    // apply the element. A push goes onto loc_stack. A pop uses loc_stack
    // first; otherwise it moves ix to the previous open push (within the
    // threadgroup via sh_bitmaps / sh_link) or, for negative ix, one entry
    // deeper into sh_stack.
    uint loc_sp = 0u;
    uint outp;
    spvUnsafeArray<uint, 8> loc_stack;
    for (uint i_10 = 0u; i_10 < 8u; i_10++)
    {
        if (loc_sp > 0u)
        {
            outp = loc_stack[loc_sp - 1u];
        }
        else
        {
            if (int(ix) >= 0)
            {
                outp = (gl_WorkGroupID.x * 512u) + ix;
            }
            else
            {
                outp = sh_stack[512u + ix];
            }
        }
        _751.outbuf[(gl_GlobalInvocationID.x * 8u) + i_10] = outp;
        inp = _593.inbuf[(gl_GlobalInvocationID.x * 8u) + i_10];
        if (inp == 1u)
        {
            loc_stack[loc_sp] = (gl_GlobalInvocationID.x * 8u) + i_10;
            loc_sp++;
        }
        else
        {
            if (inp == 0u)
            {
                if (loc_sp > 0u)
                {
                    loc_sp--;
                }
                else
                {
                    if (int(ix) >= 0)
                    {
                        uint bits_1 = sh_bitmaps[ix / 8u] & ((1u << (ix % 8u)) - 1u);
                        if (bits_1 == 0u)
                        {
                            ix = sh_link[ix / 8u];
                        }
                        else
                        {
                            ix = (ix & 4294967288u) + uint(int(spvFindUMSB(bits_1)));
                        }
                    }
                    else
                    {
                        ix--;
                    }
                }
            }
        }
    }
}