// blob: 8b4b2a1032e5b2ce959b8a208fb433ab4a6c034e [file] [log] [blame]
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Minimal fixed-size array wrapper with unchecked element access.
//
// T   - element type.
// Num - requested element count; a request of zero still reserves one slot so
//       the member array is never declared with zero length.
//
// One subscript operator is supplied per address space the object can live in,
// each returning a reference in that same address space. No overload performs
// any bounds checking on the index.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    T elements[Num ? Num : 1];

    // --- object in the thread address space ---
    thread T& operator [] (size_t index) thread
    {
        return elements[index];
    }

    constexpr const thread T& operator [] (size_t index) const thread
    {
        return elements[index];
    }

    // --- object in the device address space ---
    device T& operator [] (size_t index) device
    {
        return elements[index];
    }

    constexpr const device T& operator [] (size_t index) const device
    {
        return elements[index];
    }

    // --- object in the constant address space (read-only access only) ---
    constexpr const constant T& operator [] (size_t index) const constant
    {
        return elements[index];
    }

    // --- object in the threadgroup address space ---
    threadgroup T& operator [] (size_t index) threadgroup
    {
        return elements[index];
    }

    constexpr const threadgroup T& operator [] (size_t index) const threadgroup
    {
        return elements[index];
    }
};
// In-register representation of one transform, as consumed and produced by
// combine_monoid().
//
// combine_monoid() treats `mat` as a 2x2 matrix whose first column is
// (mat.x, mat.y) and whose second column is (mat.z, mat.w), and treats
// `translate` as the offset that is added after applying that matrix.
struct Transform
{
float4 mat;
float2 translate;
};
// Buffer-resident counterpart of Transform; this is the element type of
// DataBuf::data. It carries the same two fields plus 8 explicit trailing
// padding bytes.
// NOTE(review): the padding presumably pins the per-element stride to the
// layout the host side expects for the buffer -- confirm against the caller.
struct Transform_1
{
float4 mat;
float2 translate;
// Explicit tail padding; never read or written by the code in this file.
char _m0_final_padding[8];
};
// Layout of the device buffer bound to main0 at [[buffer(0)]].
//
// `data` is declared with a single element, but main0 indexes it with values
// derived from the global invocation ID, so the declared length of 1 appears
// to be a placeholder for a runtime-sized array.
// NOTE(review): the real element count is determined by whoever binds the
// buffer -- confirm against the host code.
struct DataBuf
{
Transform_1 data[1];
};
// Workgroup dimensions recorded for this shader: 256 x 1 x 1. The x extent
// matches the 256-entry threadgroup scratch array declared in main0. The
// constant is not referenced by the code in this file (hence [[maybe_unused]]).
// NOTE(review): the host presumably must dispatch with this same threadgroup
// size -- confirm against the dispatch code.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(256u, 1u, 1u);
// Composes two transforms and returns the result by value.
//
// With each `mat` read as a 2x2 matrix stored column by column
// ((x, y) is the first column, (z, w) the second), the result is:
//   result.mat       = a.mat * b.mat
//   result.translate = a.mat * b.translate + a.translate
//
// a - left-hand transform.
// b - right-hand transform.
//
// The operation is not commutative, so argument order matters.
static inline __attribute__((always_inline))
Transform combine_monoid(thread const Transform& a, thread const Transform& b)
{
    // Contributions of a's first and second matrix columns to the product.
    const float4 from_first_column = a.mat.xyxy * b.mat.xxzz;
    const float4 from_second_column = a.mat.zwzw * b.mat.yyww;

    // b's offset pushed through a's matrix, one column at a time.
    const float2 offset_via_first_column = a.mat.xy * b.translate.x;
    const float2 offset_via_second_column = a.mat.zw * b.translate.y;

    Transform result;
    result.mat = from_first_column + from_second_column;
    result.translate = (offset_via_first_column + offset_via_second_column) + a.translate;
    return result;
}
// Returns the neutral element for combine_monoid(): matrix coefficients
// (1, 0, 0, 1) and a zero offset. Combining any transform with this value,
// on either side, leaves that transform's coefficients unchanged.
static inline __attribute__((always_inline))
Transform monoid_identity()
{
    Transform identity;
    identity.mat = float4(1.0, 0.0, 0.0, 1.0);
    identity.translate = float2(0.0);
    return identity;
}
// Compute kernel: in-place prefix combination of the transforms in `_89.data`
// using combine_monoid(), performed independently per threadgroup.
//
// _89                   - device buffer (buffer index 0) holding the transforms;
//                         read and then overwritten in place.
// gl_GlobalInvocationID - thread position in the grid; .x selects this thread's
//                         run of 8 consecutive elements.
// gl_LocalInvocationID  - thread position in the threadgroup; .x selects this
//                         thread's slot in the shared scratch array.
//
// Each thread owns 8 consecutive elements. Results are combined across the
// threads of one threadgroup through `sh_scratch`; nothing in this kernel
// carries a prefix in from other threadgroups. Buffer indices are not
// bounds-checked here.
kernel void main0(device DataBuf& _89 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
// One slot per thread; sized for a threadgroup of 256 threads.
threadgroup Transform sh_scratch[256];
// Index of the first of the 8 elements handled by this thread.
uint ix = gl_GlobalInvocationID.x * 8u;
// Phase 1: running (inclusive) combination over this thread's 8 elements.
// Fields are copied one by one because the buffer element type (Transform_1)
// differs from the in-register type (Transform).
spvUnsafeArray<Transform, 8> local;
local[0].mat = _89.data[ix].mat;
local[0].translate = _89.data[ix].translate;
Transform param_1;
for (uint i = 1u; i < 8u; i++)
{
uint _113 = ix + i;
Transform param = local[i - 1u];
param_1.mat = _89.data[_113].mat;
param_1.translate = _89.data[_113].translate;
// local[i] = (combination of elements 0..i-1) combined with element i.
local[i] = combine_monoid(param, param_1);
}
// Aggregate of all 8 of this thread's elements.
Transform agg = local[7];
sh_scratch[gl_LocalInvocationID.x] = agg;
// Phase 2: inclusive scan of the per-thread aggregates across the threadgroup.
// Round i_1 folds in the value published by the thread (1 << i_1) slots to the
// left, so the offsets are 1, 2, 4, ..., 128 and 8 rounds span all 256 slots.
for (uint i_1 = 0u; i_1 < 8u; i_1++)
{
// Wait until every thread's latest scratch write is visible before reading.
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= (1u << i_1))
{
Transform other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Transform param_2 = other;
Transform param_3 = agg;
// The left-hand neighbour's value goes first: the combination is ordered.
agg = combine_monoid(param_2, param_3);
}
// Wait until every thread has finished reading before slots are overwritten.
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
// Make the final round's writes visible before the neighbour lookup below.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Phase 3: `row` is the combination of everything owned by lower-numbered
// threads in this threadgroup (the identity for thread 0).
Transform row = monoid_identity();
if (gl_LocalInvocationID.x > 0u)
{
row = sh_scratch[gl_LocalInvocationID.x - 1u];
}
// Prepend `row` to each of this thread's running results and write them back
// over the input elements.
for (uint i_2 = 0u; i_2 < 8u; i_2++)
{
Transform param_4 = row;
Transform param_5 = local[i_2];
Transform m = combine_monoid(param_4, param_5);
uint _208 = ix + i_2;
_89.data[_208].mat = m.mat;
_89.data[_208].translate = m.translate;
}
}