| #pragma clang diagnostic ignored "-Wmissing-prototypes" |
| #pragma clang diagnostic ignored "-Wmissing-braces" |
| |
| #include <metal_stdlib> |
| #include <simd/simd.h> |
| |
| using namespace metal; |
| |
| // Fixed-length array wrapper with value semantics. |
| // operator[] is overloaded once per Metal address space so that the returned |
| // reference carries the same address space as the object being indexed. |
| // No bounds checking is performed: the index is forwarded straight to the |
| // underlying C array. |
| template<typename T, size_t Num> |
| struct spvUnsafeArray |
| { |
|     // Backing storage; a requested length of 0 still reserves one element. |
|     T elements[(Num != 0) ? Num : 1]; |
| |
|     // --- thread address space --- |
|     thread T& operator [] (size_t idx) thread { return elements[idx]; } |
|     constexpr const thread T& operator [] (size_t idx) const thread { return elements[idx]; } |
| |
|     // --- device address space --- |
|     device T& operator [] (size_t idx) device { return elements[idx]; } |
|     constexpr const device T& operator [] (size_t idx) const device { return elements[idx]; } |
| |
|     // --- constant address space (read-only accessor only) --- |
|     constexpr const constant T& operator [] (size_t idx) const constant { return elements[idx]; } |
| |
|     // --- threadgroup address space --- |
|     threadgroup T& operator [] (size_t idx) threadgroup { return elements[idx]; } |
|     constexpr const threadgroup T& operator [] (size_t idx) const threadgroup { return elements[idx]; } |
| }; |
| |
| // Working (thread / threadgroup) representation of one transform element. |
| // combine_monoid reads `mat` as a 2x2 matrix stored as two columns, |
| // (x, y) and (z, w), and `translate` as a 2-component offset. |
| struct Transform |
| { |
| float4 mat; |
| float2 translate; |
| }; |
| |
| // Buffer-side counterpart of Transform: the same `mat` / `translate` fields |
| // followed by 8 explicit trailing bytes. main0 copies `mat` and `translate` |
| // field by field between this type and Transform. |
| // NOTE(review): the trailing bytes presumably pin the per-element stride the |
| // host-side buffer uses -- confirm against the code that fills the buffer. |
| struct Transform_1 |
| { |
| float4 mat; |
| float2 translate; |
| char _m0_final_padding[8]; |
| }; |
| |
| // Layout of the device buffer bound at [[buffer(0)]] in main0. |
| // main0 indexes `data` with gl_GlobalInvocationID.x * 8 + (0..7), i.e. past |
| // the declared extent of 1 for every thread but the first. |
| // NOTE(review): the extent of 1 is presumably a placeholder for a |
| // runtime-sized array whose real length is that of the bound buffer -- confirm. |
| struct DataBuf |
| { |
| Transform_1 data[1]; |
| }; |
| |
| // Work-group dimensions (512 x 1 x 1). Marked [[maybe_unused]]; nothing in |
| // the visible code reads it. |
| // NOTE(review): main0 sizes its threadgroup scratch array to 512 entries and |
| // runs 9 doubling steps, which presumably relies on the dispatch using this |
| // threadgroup width -- confirm on the host side. |
| constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); |
| |
| // Combines two transforms into one. |
| // Reading `mat` as a 2x2 matrix with columns (x, y) and (z, w): |
| //   result.mat       = a.mat * b.mat           (matrix product) |
| //   result.translate = a.mat * b.translate + a.translate |
| // This is consistent with composing affine maps p -> mat * p + translate, |
| // with b applied first and a second. The operation is not commutative, so |
| // argument order matters. |
| static inline __attribute__((always_inline)) |
| Transform combine_monoid(thread const Transform& a, thread const Transform& b) |
| { |
|     return Transform{ |
|         // Column-wise 2x2 product: each output column is a.mat times a column of b.mat. |
|         (a.mat.xyxy * b.mat.xxzz) + (a.mat.zwzw * b.mat.yyww), |
|         // a.mat applied to b's offset, then shifted by a's own offset. |
|         ((a.mat.xy * b.translate.x) + (a.mat.zw * b.translate.y)) + a.translate |
|     }; |
| } |
| |
| // Returns the neutral element for combine_monoid: mat = (1, 0, 0, 1), which |
| // combine_monoid treats as the 2x2 identity matrix, and a zero translate. |
| static inline __attribute__((always_inline)) |
| Transform monoid_identity() |
| { |
|     Transform identity; |
|     identity.mat = float4(1.0, 0.0, 0.0, 1.0); |
|     identity.translate = float2(0.0); |
|     return identity; |
| } |
| |
| // In-place inclusive scan of the transforms in `_89.data` under combine_monoid. |
| // |
| // Each thread owns 8 consecutive elements starting at gl_GlobalInvocationID.x * 8: |
| //   1. it scans those 8 elements sequentially into thread-local storage; |
| //   2. the per-thread totals are scanned across the threadgroup through |
| //      `sh_scratch` using 9 steps with a doubling stride (1, 2, 4, ... 256); |
| //   3. each thread left-combines the scanned total of the preceding thread |
| //      (identity for local thread 0) with its 8 local results and writes them |
| //      back to the same buffer slots. |
| // |
| // The scan is per threadgroup: nothing is carried over from one threadgroup |
| // to the next. No bounds checks are made on the buffer accesses. |
| // NOTE(review): the 512-entry scratch array and the 9 steps assume a 1D |
| // threadgroup of at most 512 threads -- confirm against the dispatch. |
| kernel void main0(device DataBuf& _89 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) |
| { |
|     threadgroup Transform sh_scratch[512]; |
| |
|     const uint lane = gl_LocalInvocationID.x; |
|     const uint base = gl_GlobalInvocationID.x * 8u; |
| |
|     // Phase 1: sequential inclusive scan over this thread's 8 elements. |
|     spvUnsafeArray<Transform, 8> partial; |
|     partial[0].mat = _89.data[base].mat; |
|     partial[0].translate = _89.data[base].translate; |
|     for (uint k = 1u; k < 8u; ++k) |
|     { |
|         const uint src = base + k; |
|         Transform loaded; |
|         loaded.mat = _89.data[src].mat; |
|         loaded.translate = _89.data[src].translate; |
|         Transform previous = partial[k - 1u]; |
|         partial[k] = combine_monoid(previous, loaded); |
|     } |
| |
|     // Phase 2: scan the per-thread totals across the threadgroup. |
|     Transform total = partial[7]; |
|     sh_scratch[lane] = total; |
|     for (uint step = 0u; step < 9u; ++step) |
|     { |
|         const uint stride = 1u << step; |
| |
|         // Make the previous round's writes visible before anyone reads. |
|         threadgroup_barrier(mem_flags::mem_threadgroup); |
|         if (lane >= stride) |
|         { |
|             // Copy into thread space: combine_monoid takes thread references. |
|             Transform left = sh_scratch[lane - stride]; |
|             Transform right = total; |
|             total = combine_monoid(left, right); |
|         } |
|         // All reads of this round must finish before any slot is overwritten. |
|         threadgroup_barrier(mem_flags::mem_threadgroup); |
|         sh_scratch[lane] = total; |
|     } |
|     threadgroup_barrier(mem_flags::mem_threadgroup); |
| |
|     // Phase 3: prefix from the preceding thread, applied to the local results. |
|     Transform prefix = monoid_identity(); |
|     if (lane > 0u) |
|     { |
|         prefix = sh_scratch[lane - 1u]; |
|     } |
|     for (uint k = 0u; k < 8u; ++k) |
|     { |
|         Transform element = partial[k]; |
|         Transform combined = combine_monoid(prefix, element); |
|         const uint dst = base + k; |
|         _89.data[dst].mat = combined.mat; |
|         _89.data[dst].translate = combined.translate; |
|     } |
| } |
| |