| #pragma clang diagnostic ignored "-Wmissing-prototypes" |
| #pragma clang diagnostic ignored "-Wmissing-braces" |
| |
| #include <metal_stdlib> |
| #include <simd/simd.h> |
| |
| using namespace metal; |
| |
// Thin fixed-size array wrapper (no bounds checking on operator[]).
//
// A request for Num == 0 still reserves one slot, since zero-length arrays
// are ill-formed. Metal references carry an address space, so operator[]
// is provided once per address space the wrapper can live in (thread,
// device, constant, threadgroup). Each returns a reference in that same
// space, const-qualified on const objects. `constant` memory is read-only,
// so it only has the const overload.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    // Backing storage; always at least one element.
    T elements[Num != 0 ? Num : 1];

    // --- thread address space ---
    thread T& operator [] (size_t index) thread { return elements[index]; }
    constexpr const thread T& operator [] (size_t index) const thread { return elements[index]; }

    // --- device address space ---
    device T& operator [] (size_t index) device { return elements[index]; }
    constexpr const device T& operator [] (size_t index) const device { return elements[index]; }

    // --- constant address space (read-only) ---
    constexpr const constant T& operator [] (size_t index) const constant { return elements[index]; }

    // --- threadgroup address space ---
    threadgroup T& operator [] (size_t index) threadgroup { return elements[index]; }
    constexpr const threadgroup T& operator [] (size_t index) const threadgroup { return elements[index]; }
};
| |
// Scan value used for thread-local and threadgroup-shared data in main0
// (`local`, `sh_scratch`, `agg`, `row`). Layout-identical to Monoid_1, which
// is the type used inside the device buffers. Values move between the two
// types by copying `.element`. The duplicate type is presumably a
// SPIRV-Cross artifact.
struct Monoid
{
    // The quantity being accumulated; combine_monoid adds these as uint.
    uint element;
};
| |
// Buffer-layout counterpart of Monoid: the element type stored in DataBuf
// and ParentBuf. Same single uint field; the kernel converts to and from
// Monoid by copying `.element`.
struct Monoid_1
{
    // The quantity being accumulated (uint).
    uint element;
};
| |
// Layout of buffer(0): the array that main0 scans in place.
// Declared with one element, but main0 indexes it at gid * 8 + [0, 8), so it
// is used as a runtime-sized array. Its real length is whatever the host
// binds; there is no bounds check.
struct DataBuf
{
    Monoid_1 data[1];
};
| |
// Layout of buffer(1): per-threadgroup carry-in values read by main0.
// Threadgroup g (> 0) reads parent[g - 1]; declared with one element but
// used as a runtime-sized array.
// NOTE(review): presumably filled by an earlier pass with the inclusive
// prefix of per-threadgroup totals - confirm against the host pipeline.
struct ParentBuf
{
    Monoid_1 parent[1];
};
| |
// Threadgroup size this kernel was written for (512 x 1 x 1). Not referenced
// by main0, which instead hard-codes the same size: sh_scratch[512] and
// 9 scan steps (2^9 == 512). Keep these in sync if the size changes.
// NOTE(review): the actual threads-per-threadgroup is chosen by the host at
// dispatch time; presumably it dispatches exactly 512 - confirm.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
| |
// Associative combine step used by every stage of the scan in main0:
// unsigned 32-bit addition, so the sum wraps modulo 2^32 on overflow.
// Monoid{0} is the identity (main0 seeds its carry-in with it).
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid sum;
    sum.element = a.element + b.element;
    return sum;
}
| |
// In-place inclusive prefix "sum" (fold with combine_monoid, i.e. uint
// addition) over _42.data, one threadgroup's slice per dispatch group.
//
// Work split: each thread owns 8 consecutive elements starting at
// gl_GlobalInvocationID.x * 8. With the 512-thread threadgroup described by
// gl_WorkGroupSize, one threadgroup covers 4096 elements.
//
// Buffers:
//   _42  (buffer 0): input values, overwritten with their inclusive prefix.
//   _141 (buffer 1): carry-in. Threadgroup g > 0 combines parent[g - 1]
//                    into all of its outputs; threadgroup 0 uses identity 0.
//                    NOTE(review): presumably the inclusive prefix of the
//                    totals of threadgroups 0..g-1 from a previous pass -
//                    confirm.
//
// NOTE(review): no bounds checks. Assumes the grid covers exactly
// len(data) / 8 threads and at most 512 threads per threadgroup
// (sh_scratch size and the 9-step scan depend on it) - confirm on the host.
kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _141 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    // One slot per thread in the threadgroup; after the scan below, slot t
    // holds the combined totals of threads 0..t.
    threadgroup Monoid sh_scratch[512];
    // First of the 8 elements owned by this thread.
    uint ix = gl_GlobalInvocationID.x * 8u;

    // Step 1: serial inclusive scan of this thread's 8 elements;
    // local[i] = data[ix] + ... + data[ix + i].
    spvUnsafeArray<Monoid, 8> local;
    local[0].element = _42.data[ix].element;
    Monoid param_1;
    for (uint i = 1u; i < 8u; i++)
    {
        Monoid param = local[i - 1u];
        param_1.element = _42.data[ix + i].element;
        local[i] = combine_monoid(param, param_1);
    }

    // Step 2: publish this thread's total (last inclusive value).
    Monoid agg = local[7];
    sh_scratch[gl_LocalInvocationID.x] = agg;

    // Step 3: Hillis-Steele inclusive scan across the threadgroup. At step
    // i, every thread with index >= 2^i folds in the partial held by the
    // thread 2^i positions earlier; 9 steps reach 2^9 = 512 threads.
    // The first barrier makes the previous step's writes visible before
    // reading. The second keeps a thread from overwriting its slot while
    // another thread may still be reading it in the same step.
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            agg = combine_monoid(param_2, param_3);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    // Make the final step's writes visible before neighbours are read below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 4: carry-in for this thread = contribution of earlier
    // threadgroups (identity 0 for the first threadgroup) combined with
    // this threadgroup's exclusive prefix, i.e. the previous thread's
    // inclusive result.
    Monoid row = Monoid{ 0u };
    if (gl_WorkGroupID.x > 0u)
    {
        row.element = _141.parent[gl_WorkGroupID.x - 1u].element;
    }
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid param_4 = row;
        Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
        row = combine_monoid(param_4, param_5);
    }

    // Step 5: add the carry-in to each local inclusive value and write it
    // back over the original input. Each thread touches only its own 8
    // slots, which it already read in step 1.
    for (uint i_2 = 0u; i_2 < 8u; i_2++)
    {
        Monoid param_6 = row;
        Monoid param_7 = local[i_2];
        Monoid m = combine_monoid(param_6, param_7);
        _42.data[ix + i_2].element = m.element;
    }
}
| |