| #pragma clang diagnostic ignored "-Wmissing-prototypes" |
| #pragma clang diagnostic ignored "-Wmissing-braces" |
| #pragma clang diagnostic ignored "-Wunused-variable" |
| |
| #include <metal_stdlib> |
| #include <simd/simd.h> |
| #include <metal_atomic> |
| |
| using namespace metal; |
| |
| // Fixed-size array wrapper with one operator[] overload per address space. |
| // Indexing is forwarded straight to the raw member array with no bounds check. |
| template<typename T, size_t Num> |
| struct spvUnsafeArray |
| { |
| // A requested size of zero still reserves one slot so the member array is never zero-length. |
| T elements[(Num == 0) ? 1 : Num]; |
| |
| // thread address space |
| thread T& operator [] (size_t index) thread { return elements[index]; } |
| constexpr const thread T& operator [] (size_t index) const thread { return elements[index]; } |
| |
| // device address space |
| device T& operator [] (size_t index) device { return elements[index]; } |
| constexpr const device T& operator [] (size_t index) const device { return elements[index]; } |
| |
| // constant address space (read-only accessor only) |
| constexpr const constant T& operator [] (size_t index) const constant { return elements[index]; } |
| |
| // threadgroup address space |
| threadgroup T& operator [] (size_t index) threadgroup { return elements[index]; } |
| constexpr const threadgroup T& operator [] (size_t index) const threadgroup { return elements[index]; } |
| }; |
| |
| struct Monoid /* Scan element used for the kernel's thread-local and threadgroup values; wraps a single uint that combine_monoid adds. */ |
| { |
| uint element; /* the value being summed */ |
| }; |
| |
| struct Monoid_1 /* Same single-uint layout as Monoid; this is the type embedded in the buffer structs State, InBuf and OutBuf. */ |
| { |
| uint element; /* the value being summed */ |
| }; |
| |
| struct State /* Per-partition record in StateBuf that main0 uses to publish its results to other threadgroups. */ |
| { |
| uint flag; /* main0 stores 1 after writing aggregate and 2 after writing prefix; readers treat any other value as not ready */ |
| Monoid_1 aggregate; /* sum of this partition's own elements */ |
| Monoid_1 prefix; /* inclusive sum of all elements up to and including this partition */ |
| }; |
| |
| struct StateBuf /* Layout of the buffer bound at [[buffer(2)]] of main0. */ |
| { |
| uint part_counter; /* atomically incremented by main0 so each threadgroup claims a distinct partition index */ |
| State state[1]; /* declared with one element but indexed by partition index in main0, i.e. used as a runtime-sized array */ |
| }; |
| |
| struct InBuf /* Layout of the read-only input buffer bound at [[buffer(0)]] of main0. */ |
| { |
| Monoid_1 inbuf[1]; /* declared with one element but indexed up to part_ix * 8192 + 8191 in main0, i.e. used as a runtime-sized array */ |
| }; |
| |
| struct OutBuf /* Layout of the output buffer bound at [[buffer(1)]] of main0. */ |
| { |
| Monoid_1 outbuf[1]; /* declared with one element but written at the same indices main0 reads from inbuf, i.e. used as a runtime-sized array */ |
| }; |
| |
| constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); /* Declared threadgroup size; main0 hard-codes the matching 512-entry scratch array and last-thread index 511. NOTE(review): presumably the host must dispatch 512x1x1 threads per threadgroup — confirm. */ |
| |
| // Combines two scan elements by adding their uint payloads; always inlined into the caller. |
| static inline __attribute__((always_inline)) |
| Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) |
| { |
| Monoid combined; |
| combined.element = a.element + b.element; |
| return combined; |
| } |
| |
| kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _367 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) /* Prefix-sum kernel: each threadgroup claims one 8192-element partition of _67.inbuf, scans it, obtains the sum of all earlier partitions by looking back through _43.state, and writes the inclusive scan to _367.outbuf. */ |
| { /* NOTE(review): no index is bounds-checked; presumably the host sizes the buffers for a whole number of 8192-element partitions and zeroes _43 before dispatch — confirm. */ |
| threadgroup uint sh_part_ix; /* partition index claimed by thread 0, shared with the whole threadgroup */ |
| threadgroup Monoid sh_scratch[512]; /* one slot per thread for the threadgroup-wide scan of per-thread totals */ |
| threadgroup uint sh_flag; /* predecessor flag loaded by thread 511, shared so every thread takes the same branch */ |
| threadgroup Monoid sh_prefix; /* exclusive prefix of this partition, shared by thread 511 after the look-back */ |
| if (gl_LocalInvocationID.x == 0u) |
| { |
| uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed); /* claim the next partition index; the value before the increment is ours */ |
| sh_part_ix = _47; |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| uint part_ix = sh_part_ix; |
| uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); /* first element of this thread: 512 threads x 16 elements = 8192 per partition */ |
| spvUnsafeArray<Monoid, 16> local; /* running (inclusive) sums of this thread's 16 inputs */ |
| local[0].element = _67.inbuf[ix].element; |
| Monoid param_1; |
| for (uint i = 1u; i < 16u; i++) |
| { |
| Monoid param = local[i - 1u]; |
| param_1.element = _67.inbuf[ix + i].element; |
| local[i] = combine_monoid(param, param_1); |
| } |
| Monoid agg = local[15]; /* total of this thread's 16 elements */ |
| sh_scratch[gl_LocalInvocationID.x] = agg; |
| for (uint i_1 = 0u; i_1 < 9u; i_1++) /* 9 doubling steps cover 2^9 = 512 threads; afterwards agg is the inclusive sum of thread totals 0..x */ |
| { |
| threadgroup_barrier(mem_flags::mem_threadgroup); /* make the previous step's writes visible before reading */ |
| if (gl_LocalInvocationID.x >= (1u << i_1)) |
| { |
| Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)]; |
| Monoid param_2 = other; |
| Monoid param_3 = agg; |
| agg = combine_monoid(param_2, param_3); |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); /* all reads of this step finish before any slot is overwritten */ |
| sh_scratch[gl_LocalInvocationID.x] = agg; |
| } |
| if (gl_LocalInvocationID.x == 511u) /* the last thread now holds the whole partition's sum */ |
| { |
| atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].aggregate.element, agg.element, memory_order_relaxed); |
| if (part_ix == 0u) |
| { |
| atomic_store_explicit((volatile device atomic_uint*)&_43.state[0].prefix.element, agg.element, memory_order_relaxed); /* partition 0 has no predecessors, so its inclusive prefix equals its aggregate */ |
| } |
| } |
| threadgroup_barrier(mem_flags::mem_device); /* NOTE(review): every atomic here is relaxed, so publishing aggregate/prefix before the flag relies on this barrier — confirm it is sufficient across threadgroups. */ |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| uint flag = 1u; /* 1: aggregate is available */ |
| if (part_ix == 0u) |
| { |
| flag = 2u; /* 2: inclusive prefix is available */ |
| } |
| atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, flag, memory_order_relaxed); |
| } |
| Monoid exclusive = Monoid{ 0u }; /* sum of all partitions before this one; stays 0 for partition 0 */ |
| if (part_ix != 0u) |
| { |
| uint look_back_ix = part_ix - 1u; /* predecessor partition currently being inspected */ |
| uint their_ix = 0u; /* progress of the element-by-element fallback within that partition */ |
| Monoid their_agg; |
| Monoid m; |
| while (true) |
| { |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| uint _206 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].flag, memory_order_relaxed); |
| sh_flag = _206; |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); /* share the loaded flag with every thread */ |
| threadgroup_barrier(mem_flags::mem_device); |
| uint flag_1 = sh_flag; /* NOTE(review): thread 511 may later write sh_flag = 2u in the fallback below with no barrier after this read — confirm other threads cannot observe that write here. */ |
| if (flag_1 == 2u) |
| { |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| uint _221 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].prefix.element, memory_order_relaxed); |
| Monoid their_prefix = Monoid{ _221 }; |
| Monoid param_4 = their_prefix; |
| Monoid param_5 = exclusive; |
| exclusive = combine_monoid(param_4, param_5); /* predecessor's inclusive prefix completes the sum */ |
| } |
| break; /* look-back finished */ |
| } |
| else |
| { |
| if (flag_1 == 1u) |
| { |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| uint _242 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].aggregate.element, memory_order_relaxed); |
| their_agg.element = _242; |
| Monoid param_6 = their_agg; |
| Monoid param_7 = exclusive; |
| exclusive = combine_monoid(param_6, param_7); /* only the aggregate is known: add it and keep looking further back */ |
| } |
| look_back_ix--; /* partition 0 stores flag 2 directly, so flag 1 is presumably never seen at index 0 — confirm */ |
| their_ix = 0u; |
| continue; |
| } |
| } |
| if (gl_LocalInvocationID.x == 511u) /* predecessor not ready: thread 511 sums that partition's input itself, one element per loop iteration */ |
| { |
| m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element; |
| if (their_ix == 0u) |
| { |
| their_agg = m; |
| } |
| else |
| { |
| Monoid param_8 = their_agg; |
| Monoid param_9 = m; |
| their_agg = combine_monoid(param_8, param_9); |
| } |
| their_ix++; |
| if (their_ix == 8192u) /* the whole predecessor partition has been summed locally */ |
| { |
| Monoid param_10 = their_agg; |
| Monoid param_11 = exclusive; |
| exclusive = combine_monoid(param_10, param_11); |
| if (look_back_ix == 0u) |
| { |
| sh_flag = 2u; /* nothing earlier remains: signal every thread to leave the loop */ |
| } |
| else |
| { |
| look_back_ix--; |
| their_ix = 0u; |
| } |
| } |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); |
| flag_1 = sh_flag; /* NOTE(review): no barrier separates this read from thread 511's next write to sh_flag at the top of the loop — looks like a potential threadgroup race; confirm. */ |
| if (flag_1 == 2u) |
| { |
| break; |
| } |
| } |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| Monoid param_12 = exclusive; |
| Monoid param_13 = agg; |
| Monoid inclusive_prefix = combine_monoid(param_12, param_13); |
| sh_prefix = exclusive; /* hand the exclusive prefix to the rest of the threadgroup */ |
| atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].prefix.element, inclusive_prefix.element, memory_order_relaxed); |
| } |
| threadgroup_barrier(mem_flags::mem_device); /* prefix store above is meant to precede the flag store below; same relaxed-atomics caveat as the earlier mem_device barrier */ |
| if (gl_LocalInvocationID.x == 511u) |
| { |
| atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, 2u, memory_order_relaxed); /* mark this partition's inclusive prefix as available */ |
| } |
| } |
| threadgroup_barrier(mem_flags::mem_threadgroup); /* sh_prefix is now visible to all threads */ |
| if (part_ix != 0u) |
| { |
| exclusive = sh_prefix; |
| } |
| Monoid row = exclusive; /* becomes the sum of everything before this thread's first element */ |
| if (gl_LocalInvocationID.x > 0u) |
| { |
| Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; /* inclusive sum of the totals of threads 0..x-1 */ |
| Monoid param_14 = row; |
| Monoid param_15 = other_1; |
| row = combine_monoid(param_14, param_15); |
| } |
| for (uint i_2 = 0u; i_2 < 16u; i_2++) /* add the thread-local running sums and write the final inclusive scan */ |
| { |
| Monoid param_16 = row; |
| Monoid param_17 = local[i_2]; |
| Monoid m_1 = combine_monoid(param_16, param_17); |
| _367.outbuf[ix + i_2].element = m_1.element; |
| } |
| } |
| |