// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

// Processing of the path stream, after the tag scan.

#version 450
#extension GL_GOOGLE_include_directive : enable

#include "mem.h"
#include "setup.h"
#include "pathtag.h"

// Number of consecutive path tag bytes handled by each invocation
// (main loops over N_SEQ elements and loads them as one 32-bit tag word).
#define N_SEQ 4
// log2 of the workgroup size. LG_WG_FACTOR is not defined in this file;
// presumably it comes from one of the headers included above.
#define LG_WG_SIZE (7 + LG_WG_FACTOR)
// Number of invocations per workgroup.
#define WG_SIZE (1 << LG_WG_SIZE)
// Number of tag bytes covered by one workgroup.
#define PARTITION_SIZE (WG_SIZE * N_SEQ)

layout(local_size_x = WG_SIZE, local_size_y = 1) in;

// Configuration block; main reads the stream offsets and output allocations
// (pathtag_offset, pathseg_offset, linewidth_offset, trans_alloc,
// pathseg_alloc, path_bbox_alloc) from it.
layout(binding = 1) readonly buffer ConfigBuf {
    Config conf;
};

// Encoded scene, addressed in 32-bit words. Tag bytes, point data and
// linewidths are all read from this buffer.
layout(binding = 2) readonly buffer SceneBuf {
    uint[] scene;
};

#include "tile.h"
#include "pathseg.h"

// One TagMonoid per workgroup. main reads entry (gl_WorkGroupID.x - 1) as the
// prefix for everything preceding the current workgroup, so entry k is
// presumably the inclusive scan through workgroup k, produced by an earlier
// pass -- confirm against the tag scan shaders.
layout(binding = 3) readonly buffer ParentBuf {
    TagMonoid[] parent;
};

// Element of the bounding-box scan monoid.
struct Monoid {
    // Bounding box as (min x, min y, max x, max y).
    vec4 bbox;
    // Bitwise OR of the FLAG_* values below.
    uint flags;
};

// Set on an element that is a path marker (main stores the path bit of the
// tag byte directly into flags, which relies on this value being 1).
#define FLAG_RESET_BBOX 1
// Derived in combine_monoid from the left operand's FLAG_RESET_BBOX shifted
// left by one, so this must stay equal to FLAG_RESET_BBOX << 1.
#define FLAG_SET_BBOX 2

// Combine two bbox monoid elements, with `a` preceding `b` in stream order.
//
// The resulting bbox is:
//   - a.bbox when `a` carries no reset and b.bbox is empty,
//   - the union of both boxes when `a` carries no reset, `b` has no
//     FLAG_SET_BBOX and a.bbox is non-empty,
//   - b.bbox otherwise.
// The resulting flags keep all of b's flags and a's FLAG_SET_BBOX, and turn
// a's FLAG_RESET_BBOX into FLAG_SET_BBOX.
Monoid combine_monoid(Monoid a, Monoid b) {
    bool a_has_reset = (a.flags & FLAG_RESET_BBOX) != 0;
    bool b_has_set = (b.flags & FLAG_SET_BBOX) != 0;
    bool a_nonempty = a.bbox.z > a.bbox.x || a.bbox.w > a.bbox.y;
    bool b_empty = b.bbox.z <= b.bbox.x && b.bbox.w <= b.bbox.y;

    Monoid result;
    // TODO: I think the first case should be gated on b & SET_BBOX == false also.
    if (!a_has_reset && b_empty) {
        result.bbox = a.bbox;
    } else if (!a_has_reset && !b_has_set && a_nonempty) {
        result.bbox = vec4(min(a.bbox.xy, b.bbox.xy), max(a.bbox.zw, b.bbox.zw));
    } else {
        result.bbox = b.bbox;
    }

    // A reset on the left becomes "bbox has been set" on the combined element.
    uint promoted_reset = (a.flags & FLAG_RESET_BBOX) << 1;
    result.flags = (a.flags & FLAG_SET_BBOX) | b.flags | promoted_reset;
    return result;
}

// Identity element for combine_monoid: an empty bbox at the origin, no flags.
Monoid monoid_identity() {
    Monoid id;
    id.bbox = vec4(0.0);
    id.flags = 0u;
    return id;
}

// Workgroup scratch for the two scans in main: sh_tag holds the tag monoid
// scan and sh_scratch holds the bbox monoid scan.
// These are not both live at the same time. A very smart shader compiler
// would be able to figure that out, but I suspect many won't.
shared TagMonoid sh_tag[WG_SIZE];
shared Monoid sh_scratch[WG_SIZE];

// Read a point stored as two consecutive 32-bit floats (x, then y) starting
// at word index `ix` of the scene buffer.
vec2 read_f32_point(uint ix) {
    return vec2(uintBitsToFloat(scene[ix]), uintBitsToFloat(scene[ix + 1]));
}

// Read a point packed into a single scene word as two signed 16-bit integers:
// x in the low half, y in the high half.
vec2 read_i16_point(uint ix) {
    uint word = scene[ix];
    // Move each half into the top 16 bits, then shift back down arithmetically
    // so the sign is extended.
    ivec2 xy = ivec2(int(word << 16), int(word)) >> 16;
    return vec2(float(xy.x), float(xy.y));
}

// Note: these are 16 bit, which is adequate, but we could use 32 bits.

// Round down, add the 32768 bias, and saturate to the 16-bit range [0, 65535].
//
// Both ends are clamped: the low end so the float-to-uint conversion never
// sees a negative value, the high end so the result stays within the same
// 16-bit biased range that round_up produces (and so very large inputs do not
// overflow the conversion).
uint round_down(float x) {
    return uint(clamp(floor(x) + 32768.0, 0.0, 65535.0));
}

// Round up, add the 32768 bias, and saturate to the 16-bit range [0, 65535].
//
// The lower clamp matters for correctness: for x < -32768 the biased value is
// negative, and converting a negative float to uint is undefined in GLSL.
uint round_up(float x) {
    return uint(clamp(ceil(x) + 32768.0, 0.0, 65535.0));
}

// Entry point. Each invocation handles N_SEQ consecutive path tag bytes (one
// 32-bit tag word):
//   1. a workgroup-wide scan of the tag monoid locates this invocation's
//      position in the path segment, linewidth, transform and path streams;
//   2. every path segment is decoded, transformed, raised to a cubic and
//      written to the path segment output;
//   3. a workgroup-wide scan of the bbox monoid accumulates per-path bounding
//      boxes, which are written to the path bbox output (atomically when the
//      path is not fully contained in this workgroup).
void main() {
    // Per-element state kept for the bbox output phase at the end.
    Monoid local[N_SEQ];
    float linewidth[N_SEQ];
    uint save_trans_ix[N_SEQ];

    // Index of the first tag byte handled by this invocation.
    uint ix = gl_GlobalInvocationID.x * N_SEQ;

    // Tag bytes are packed four to a word; pathtag_offset is in bytes.
    uint tag_word = scene[(conf.pathtag_offset >> 2) + (ix >> 2)];

    // Scan the tag monoid
    TagMonoid local_tm = reduce_tag(tag_word);
    sh_tag[gl_LocalInvocationID.x] = local_tm;
    // At step i every invocation folds in the value 2^i slots before it.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        // Make the previous step's writes visible before reading.
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i)) {
            TagMonoid other = sh_tag[gl_LocalInvocationID.x - (1u << i)];
            local_tm = combine_tag_monoid(other, local_tm);
        }
        // All reads of this step must finish before any slot is overwritten.
        barrier();
        sh_tag[gl_LocalInvocationID.x] = local_tm;
    }
    barrier();
    // sh_tag is now the partition-wide inclusive scan of the tag monoid.
    TagMonoid tm = tag_monoid_identity();
    if (gl_WorkGroupID.x > 0) {
        // Prefix contributed by all preceding workgroups.
        tm = parent[gl_WorkGroupID.x - 1];
    }
    if (gl_LocalInvocationID.x > 0) {
        tm = combine_tag_monoid(tm, sh_tag[gl_LocalInvocationID.x - 1]);
    }
    // tm is now the full exclusive scan of the tag monoid.

    // Indices to scene buffer in u32 units.
    uint ps_ix = (conf.pathseg_offset >> 2) + tm.pathseg_offset;
    uint lw_ix = (conf.linewidth_offset >> 2) + tm.linewidth_ix;
    // Path index at the start of this invocation; tm.path_ix is advanced in
    // the loop below, the saved value is used for the bbox output.
    uint save_path_ix = tm.path_ix;
    uint trans_ix = tm.trans_ix;
    TransformSegRef trans_ref = TransformSegRef(conf.trans_alloc.offset + trans_ix * TransformSeg_size);
    PathSegRef ps_ref = PathSegRef(conf.pathseg_alloc.offset + tm.pathseg_ix * PathSeg_size);
    for (uint i = 0; i < N_SEQ; i++) {
        // Record the linewidth and transform index in effect at this element.
        linewidth[i] = uintBitsToFloat(scene[lw_ix]);
        save_trans_ix[i] = trans_ix;
        // if N_SEQ > 4, need to load tag_word from local if N_SEQ % 4 == 0
        uint tag_byte = tag_word >> (i * 8);
        // Low two bits: 0 = not a path segment, otherwise the segment type.
        uint seg_type = tag_byte & 3;
        if (seg_type != 0) {
            // 1 = line, 2 = quad, 3 = cubic
            // Unpack path segment from input
            vec2 p0;
            vec2 p1;
            vec2 p2;
            vec2 p3;
            // Bit 3 selects the point encoding.
            if ((tag_byte & 8) != 0) {
                // 32 bit encoding
                // Two words per point.
                p0 = read_f32_point(ps_ix);
                p1 = read_f32_point(ps_ix + 2);
                if (seg_type >= 2) {
                    p2 = read_f32_point(ps_ix + 4);
                    if (seg_type == 3) {
                        p3 = read_f32_point(ps_ix + 6);
                    }
                }
            } else {
                // 16 bit encoding
                // One word per point.
                p0 = read_i16_point(ps_ix);
                p1 = read_i16_point(ps_ix + 1);
                if (seg_type >= 2) {
                    p2 = read_i16_point(ps_ix + 2);
                    if (seg_type == 3) {
                        p3 = read_i16_point(ps_ix + 3);
                    }
                }
            }
            // Apply the current affine transform: mat.xy and mat.zw are the
            // matrix columns, translate is the offset.
            TransformSeg transform = TransformSeg_read(conf.trans_alloc, trans_ref);
            p0 = transform.mat.xy * p0.x + transform.mat.zw * p0.y + transform.translate;
            p1 = transform.mat.xy * p1.x + transform.mat.zw * p1.y + transform.translate;
            // bbox is (min x, min y, max x, max y) over the control points.
            vec4 bbox = vec4(min(p0, p1), max(p0, p1));
            // Degree-raise and compute bbox
            if (seg_type >= 2) {
                p2 = transform.mat.xy * p2.x + transform.mat.zw * p2.y + transform.translate;
                bbox.xy = min(bbox.xy, p2);
                bbox.zw = max(bbox.zw, p2);
                if (seg_type == 3) {
                    p3 = transform.mat.xy * p3.x + transform.mat.zw * p3.y + transform.translate;
                    bbox.xy = min(bbox.xy, p3);
                    bbox.zw = max(bbox.zw, p3);
                } else {
                    // Quad -> cubic: the end point moves to p3 and each inner
                    // control point lies 2/3 of the way from an end point
                    // towards the quad's control point. p2 must be computed
                    // before p1 is overwritten.
                    p3 = p2;
                    p2 = mix(p1, p2, 1.0 / 3.0);
                    p1 = mix(p1, p0, 1.0 / 3.0);
                }
            } else {
                // Line -> cubic: inner control points at 1/3 and 2/3 along
                // the line.
                p3 = p1;
                p2 = mix(p3, p0, 1.0 / 3.0);
                p1 = mix(p0, p3, 1.0 / 3.0);
            }
            vec2 stroke = vec2(0.0, 0.0);
            // A non-negative linewidth selects stroke handling; the bbox is
            // widened by half the linewidth scaled by the length of each row
            // of the transform matrix.
            if (linewidth[i] >= 0.0) {
                // See https://www.iquilezles.org/www/articles/ellipses/ellipses.htm
                stroke = 0.5 * linewidth[i] * vec2(length(transform.mat.xz), length(transform.mat.yw));
                bbox += vec4(-stroke, stroke);
            }
            local[i].bbox = bbox;
            local[i].flags = 0;

            // Write path segment to output
            PathCubic cubic;
            cubic.p0 = p0;
            cubic.p1 = p1;
            cubic.p2 = p2;
            cubic.p3 = p3;
            cubic.path_ix = tm.path_ix;
            // Not needed, TODO remove from struct
            // NOTE(review): the literal 4 matches N_SEQ today; this is the
            // global element index, not a transform index.
            cubic.trans_ix = gl_GlobalInvocationID.x * 4 + i;
            cubic.stroke = stroke;
            // 1 when the linewidth is non-negative, 0 otherwise.
            uint fill_mode = uint(linewidth[i] >= 0.0);
            PathSeg_Cubic_write(conf.pathseg_alloc, ps_ref, fill_mode, cubic);

            ps_ref.offset += PathSeg_size;
            // Points consumed from the stream: the segment type plus one more
            // when bit 2 is set (presumably the last segment of a subpath,
            // whose end point is not shared with a following segment --
            // confirm against the tag encoding).
            uint n_points = (tag_byte & 3) + ((tag_byte >> 2) & 1);
            // Doubles the count for the 32 bit encoding (the mask is 15 when
            // bit 3 is set, 0 otherwise).
            uint n_words = n_points + (n_points & (((tag_byte >> 3) & 1) * 15));
            ps_ix += n_words;
        } else {
            // Not a path segment: contributes an empty bbox and may advance
            // the path, transform and linewidth streams.
            local[i].bbox = vec4(0.0, 0.0, 0.0, 0.0);
            // These shifts need to be kept in sync with setup.h
            uint is_path = (tag_byte >> 4) & 1;
            // Relies on the fact that RESET_BBOX == 1
            local[i].flags = is_path;
            tm.path_ix += is_path;
            trans_ix += (tag_byte >> 5) & 1;
            trans_ref.offset += ((tag_byte >> 5) & 1) * TransformSeg_size;
            lw_ix += (tag_byte >> 6) & 1;
        }
    }

    // Partition-wide monoid scan for bbox monoid
    Monoid agg = local[0];
    for (uint i = 1; i < N_SEQ; i++) {
        // Note: this could be fused with the map above, but probably
        // a thin performance gain not worth the complexity.
        agg = combine_monoid(agg, local[i]);
        local[i] = agg;
    }
    // local is N_SEQ sub-partition inclusive scan of bbox monoid.
    sh_scratch[gl_LocalInvocationID.x] = agg;
    // Same scan structure and barrier discipline as the tag scan above.
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i)) {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i)];
            agg = combine_monoid(other, agg);
        }
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    // sh_scratch is the partition-wide inclusive scan of the bbox monoid,
    // sampled at the end of the N_SEQ sub-partition.

    barrier();
    uint path_ix = save_path_ix;
    // Each path record is 6 words: four bbox words, linewidth, transform index.
    uint bbox_out_ix = (conf.path_bbox_alloc.offset >> 2) + path_ix * 6;
    // Write bboxes to paths; do atomic min/max if partial
    // row is the scan over all elements of preceding invocations in this
    // workgroup.
    Monoid row = monoid_identity();
    if (gl_LocalInvocationID.x > 0) {
        row = sh_scratch[gl_LocalInvocationID.x - 1];
    }
    for (uint i = 0; i < N_SEQ; i++) {
        Monoid m = combine_monoid(row, local[i]);
        // m is partition-wide inclusive scan of bbox monoid.
        // By construction of combine_monoid, FLAG_RESET_BBOX in m means
        // element i is a path marker, and FLAG_SET_BBOX means an earlier
        // element in this workgroup was a path marker.
        bool do_atomic = false;
        if (i == N_SEQ - 1 && gl_LocalInvocationID.x == WG_SIZE - 1) {
            // last element
            // Whatever has accumulated so far is merged atomically, since it
            // is only the part of the bbox seen by this workgroup.
            do_atomic = true;
        }
        if ((m.flags & FLAG_RESET_BBOX) != 0) {
            // Path marker: record the linewidth and transform for this path.
            memory[bbox_out_ix + 4] = floatBitsToUint(linewidth[i]);
            memory[bbox_out_ix + 5] = save_trans_ix[i];
            if ((m.flags & FLAG_SET_BBOX) == 0) {
                // No earlier path marker in this workgroup, so the
                // accumulated bbox may be partial: merge atomically.
                do_atomic = true;
            } else {
                // The bbox was accumulated entirely within this workgroup:
                // write it directly and move on to the next path record.
                memory[bbox_out_ix] = round_down(m.bbox.x);
                memory[bbox_out_ix + 1] = round_down(m.bbox.y);
                memory[bbox_out_ix + 2] = round_up(m.bbox.z);
                memory[bbox_out_ix + 3] = round_up(m.bbox.w);
                bbox_out_ix += 6;
                do_atomic = false;
            }
        }
        if (do_atomic) {
            // Skip empty boxes so they do not disturb the stored extents.
            if (m.bbox.z > m.bbox.x || m.bbox.w > m.bbox.y) {
                // atomic min/max
                atomicMin(memory[bbox_out_ix], round_down(m.bbox.x));
                atomicMin(memory[bbox_out_ix + 1], round_down(m.bbox.y));
                atomicMax(memory[bbox_out_ix + 2], round_up(m.bbox.z));
                atomicMax(memory[bbox_out_ix + 3], round_up(m.bbox.w));
            }
            bbox_out_ix += 6;
        }
    }
}
