// Copyright 2020 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// --------

// Filter 1: Sub.

// filter_1 reverses the Sub filter (filter type 1) in place on args.curr:
// every byte has the already-updated byte that sits filter_distance positions
// before it added to it, modulo 256. Bytes with no such predecessor are added
// to zero, i.e. left unchanged.
//
// It is marked choosy, so a different implementation with the same signature
// may be selected to run in its place.
//
// NOTE(review): this.filter_distance is presumably the number of bytes per
// pixel - confirm against the code that sets it.
pri func decoder.filter_1!(curr: slice base.u8),
        choosy,
{
    var filter_distance : base.u64[..= 8]
    var fa              : base.u8
    var i_start         : base.u64
    var i               : base.u64

    filter_distance = this.filter_distance as base.u64
    // Process the row as filter_distance independent lanes. Lane i_start
    // covers indexes i_start, i_start + filter_distance, and so on.
    i_start = 0
    while i_start < filter_distance {
        // Proof hint: i_start < filter_distance <= 8, which bounds i_start so
        // that the "i_start += 1" below cannot overflow.
        assert i_start < 8 via "a < b: a < c; c <= b"(c: filter_distance)
        // fa carries the previous updated byte of this lane. It starts at
        // zero because the first byte of a lane has no predecessor.
        fa = 0
        i = i_start
        while i < args.curr.length(),
                inv i_start < 8,
        {
            args.curr[i] = args.curr[i] ~mod+ fa
            fa = args.curr[i]
            // Wrapping add, so no overflow proof is required for the stride.
            i ~mod+= filter_distance
        } endwhile
        i_start += 1
    } endwhile
}

// filter_1_distance_3_fallback reverses the Sub filter in place for a filter
// distance of exactly 3 bytes. It walks args.curr in 3-byte groups; within a
// group, byte k has the updated byte k of the previous group added to it,
// modulo 256. The first group is added to zero because the lane accumulators
// start out zero-initialized.
pri func decoder.filter_1_distance_3_fallback!(curr: slice base.u8) {
    var c : slice base.u8

    // One running "byte to the left" per lane.
    var left0 : base.u8
    var left1 : base.u8
    var left2 : base.u8

    iterate (c = args.curr)(length: 3, advance: 3, unroll: 2) {
        c[0] = c[0] ~mod+ left0
        left0 = c[0]

        c[1] = c[1] ~mod+ left1
        left1 = c[1]

        c[2] = c[2] ~mod+ left2
        left2 = c[2]
    }
}

// filter_1_distance_4_fallback reverses the Sub filter in place for a filter
// distance of exactly 4 bytes. It walks args.curr in 4-byte groups; within a
// group, byte k has the updated byte k of the previous group added to it,
// modulo 256. The first group is added to zero because the lane accumulators
// start out zero-initialized.
pri func decoder.filter_1_distance_4_fallback!(curr: slice base.u8) {
    var c : slice base.u8

    // One running "byte to the left" per lane.
    var left0 : base.u8
    var left1 : base.u8
    var left2 : base.u8
    var left3 : base.u8

    iterate (c = args.curr)(length: 4, advance: 4, unroll: 1) {
        c[0] = c[0] ~mod+ left0
        left0 = c[0]

        c[1] = c[1] ~mod+ left1
        left1 = c[1]

        c[2] = c[2] ~mod+ left2
        left2 = c[2]

        c[3] = c[3] ~mod+ left3
        left3 = c[3]
    }
}

// Filter 2: Up.

// filter_2 reverses the Up filter (filter type 2) in place on args.curr: each
// byte has the byte at the same index in args.prev added to it, modulo 256.
//
// Only the first n bytes are touched, where n is the shorter of the two slice
// lengths. In particular, an empty args.prev leaves args.curr unchanged.
pri func decoder.filter_2!(curr: slice base.u8, prev: slice base.u8) {
    var n : base.u64
    var i : base.u64

    n = args.curr.length().min(no_more_than: args.prev.length())
    i = 0
    while i < n,
            inv n <= args.curr.length(),
            inv n <= args.prev.length(),
    {
        // Proof hints for the checks below: i is below the u64 maximum (so
        // "i += 1" cannot overflow) and below both slice lengths (so both
        // index expressions are in bounds).
        assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: n)
        assert i < args.curr.length() via "a < b: a < c; c <= b"(c: n)
        assert i < args.prev.length() via "a < b: a < c; c <= b"(c: n)
        args.curr[i] = args.curr[i] ~mod+ (args.prev[i])
        i += 1
    } endwhile
}

// Filter 3: Average.

// filter_3 reverses the Average filter (filter type 3) in place on args.curr:
// each byte has the rounded-down average of its "left" byte (the updated byte
// filter_distance positions earlier in args.curr) and its "above" byte (the
// byte at the same index in args.prev) added to it, modulo 256. A missing
// neighbor counts as zero.
//
// When args.prev is empty, the whole of args.curr is processed using only the
// left neighbor. Otherwise only the first n bytes are processed, where n is
// the shorter of the two slice lengths.
//
// It is marked choosy, so a different implementation with the same signature
// may be selected to run in its place.
//
// NOTE(review): this.filter_distance is presumably the number of bytes per
// pixel - confirm against the code that sets it.
pri func decoder.filter_3!(curr: slice base.u8, prev: slice base.u8),
        choosy,
{
    var filter_distance : base.u64[..= 8]
    var n               : base.u64
    var i               : base.u64

    filter_distance = this.filter_distance as base.u64
    if args.prev.length() == 0 {
        // No row above: the average is just (left / 2). The first
        // filter_distance bytes have no left neighbor either, so they are
        // skipped (adding zero would not change them).
        i = filter_distance
        assert i >= filter_distance via "a >= b: a == b"()
        while i < args.curr.length(),
                inv i >= filter_distance,
        {
            // Proof hints: "i += 1" cannot overflow, and (i - filter_distance)
            // is a valid index into args.curr.
            assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: args.curr.length())
            assert (i - filter_distance) < args.curr.length() via "(a - b) < c: a < c; 0 <= b"()
            args.curr[i] = args.curr[i] ~mod+ (args.curr[i - filter_distance] / 2)
            i += 1
            // Re-establish the loop invariant after the increment.
            assert i >= filter_distance via "a >= b: a >= (b + c); 0 <= c"(c: 1)
        } endwhile

    } else {
        n = args.curr.length().min(no_more_than: args.prev.length())
        // Leading bytes (index below filter_distance) have no left neighbor:
        // the average is just (above / 2).
        i = 0
        while (i < n) and (i < filter_distance),
                inv n <= args.curr.length(),
                inv n <= args.prev.length(),
        {
            assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: n)
            assert i < args.curr.length() via "a < b: a < c; c <= b"(c: n)
            assert i < args.prev.length() via "a < b: a < c; c <= b"(c: n)
            args.curr[i] = args.curr[i] ~mod+ (args.prev[i] / 2)
            i += 1
        } endwhile

        // Remaining bytes have both neighbors. i is reset to filter_distance
        // explicitly (rather than continuing from the loop above) so that the
        // "i >= filter_distance" invariant can be proven.
        i = filter_distance
        assert i >= filter_distance via "a >= b: a == b"()
        while i < n,
                inv i >= filter_distance,
                inv n <= args.curr.length(),
                inv n <= args.prev.length(),
        {
            assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: n)
            assert i < args.curr.length() via "a < b: a < c; c <= b"(c: n)
            assert i < args.prev.length() via "a < b: a < c; c <= b"(c: n)
            assert (i - filter_distance) < args.curr.length() via "(a - b) < c: a < c; 0 <= b"()
            // The sum of two bytes can exceed 255, so it is formed in u32 and
            // only narrowed back to u8 after halving.
            args.curr[i] = args.curr[i] ~mod+ (((
                    (args.curr[i - filter_distance] as base.u32) +
                    (args.prev[i] as base.u32)) / 2) as base.u8)
            i += 1
            // Re-establish the loop invariant after the increment.
            assert i >= filter_distance via "a >= b: a >= (b + c); 0 <= c"(c: 1)
        } endwhile
    }
}

// filter_3_distance_3_fallback reverses the Average filter in place for a
// filter distance of exactly 3 bytes. args.curr is walked in 3-byte groups.
// Byte k of a group has the rounded-down average of its left neighbor (the
// updated byte k of the previous group, zero for the first group) and its
// above neighbor (the matching byte of args.prev) added to it, modulo 256.
//
// When args.prev is empty, the above neighbor counts as zero and only
// args.curr is walked.
pri func decoder.filter_3_distance_3_fallback!(curr: slice base.u8, prev: slice base.u8) {
    var c : slice base.u8
    var p : slice base.u8

    // One running "byte to the left" per lane, zero-initialized.
    var left0 : base.u8
    var left1 : base.u8
    var left2 : base.u8

    if args.prev.length() > 0 {
        // Both neighbors exist. The two-byte sum is widened to u32 so that it
        // cannot overflow before being halved.
        iterate (c = args.curr, p = args.prev)(length: 3, advance: 3, unroll: 2) {
            c[0] = c[0] ~mod+ ((((left0 as base.u32) + (p[0] as base.u32)) / 2) as base.u8)
            left0 = c[0]

            c[1] = c[1] ~mod+ ((((left1 as base.u32) + (p[1] as base.u32)) / 2) as base.u8)
            left1 = c[1]

            c[2] = c[2] ~mod+ ((((left2 as base.u32) + (p[2] as base.u32)) / 2) as base.u8)
            left2 = c[2]
        }

    } else {
        // No row above: the average is just half of the left neighbor.
        iterate (c = args.curr)(length: 3, advance: 3, unroll: 2) {
            c[0] = c[0] ~mod+ (left0 / 2)
            left0 = c[0]

            c[1] = c[1] ~mod+ (left1 / 2)
            left1 = c[1]

            c[2] = c[2] ~mod+ (left2 / 2)
            left2 = c[2]
        }
    }
}

// filter_3_distance_4_fallback reverses the Average filter in place for a
// filter distance of exactly 4 bytes. args.curr is walked in 4-byte groups.
// Byte k of a group has the rounded-down average of its left neighbor (the
// updated byte k of the previous group, zero for the first group) and its
// above neighbor (the matching byte of args.prev) added to it, modulo 256.
//
// When args.prev is empty, the above neighbor counts as zero and only
// args.curr is walked.
pri func decoder.filter_3_distance_4_fallback!(curr: slice base.u8, prev: slice base.u8) {
    var c : slice base.u8
    var p : slice base.u8

    // One running "byte to the left" per lane, zero-initialized.
    var left0 : base.u8
    var left1 : base.u8
    var left2 : base.u8
    var left3 : base.u8

    if args.prev.length() > 0 {
        // Both neighbors exist. The two-byte sum is widened to u32 so that it
        // cannot overflow before being halved.
        iterate (c = args.curr, p = args.prev)(length: 4, advance: 4, unroll: 1) {
            c[0] = c[0] ~mod+ ((((left0 as base.u32) + (p[0] as base.u32)) / 2) as base.u8)
            left0 = c[0]

            c[1] = c[1] ~mod+ ((((left1 as base.u32) + (p[1] as base.u32)) / 2) as base.u8)
            left1 = c[1]

            c[2] = c[2] ~mod+ ((((left2 as base.u32) + (p[2] as base.u32)) / 2) as base.u8)
            left2 = c[2]

            c[3] = c[3] ~mod+ ((((left3 as base.u32) + (p[3] as base.u32)) / 2) as base.u8)
            left3 = c[3]
        }

    } else {
        // No row above: the average is just half of the left neighbor.
        iterate (c = args.curr)(length: 4, advance: 4, unroll: 1) {
            c[0] = c[0] ~mod+ (left0 / 2)
            left0 = c[0]

            c[1] = c[1] ~mod+ (left1 / 2)
            left1 = c[1]

            c[2] = c[2] ~mod+ (left2 / 2)
            left2 = c[2]

            c[3] = c[3] ~mod+ (left3 / 2)
            left3 = c[3]
        }
    }
}

// Filter 4: Paeth.

// filter_4 reverses the Paeth filter (filter type 4) in place on args.curr.
// For each byte, a predictor is picked from three neighbors and added to the
// byte, modulo 256:
//  - fa: "left", the updated byte filter_distance positions earlier in curr.
//  - fb: "above", the byte at the same index in prev.
//  - fc: "upper left", the byte filter_distance positions earlier in prev.
// The predictor is whichever of fa, fb and fc is closest to (fa + fb - fc),
// with ties resolved in that order.
//
// Only the first n bytes are processed, where n is the shorter of the two
// slice lengths. In particular, an empty args.prev leaves args.curr
// unchanged.
//
// It is marked choosy, so a different implementation with the same signature
// may be selected to run in its place.
//
// NOTE(review): this.filter_distance is presumably the number of bytes per
// pixel - confirm against the code that sets it.
pri func decoder.filter_4!(curr: slice base.u8, prev: slice base.u8),
        choosy,
{
    var filter_distance : base.u64[..= 8]
    var n               : base.u64
    var i               : base.u64

    var fa : base.u32
    var fb : base.u32
    var fc : base.u32
    var pp : base.u32
    var pa : base.u32
    var pb : base.u32
    var pc : base.u32

    filter_distance = this.filter_distance as base.u64
    n = args.curr.length().min(no_more_than: args.prev.length())
    // Leading bytes (index below filter_distance) have neither a left nor an
    // upper-left neighbor, so the above byte is added directly.
    i = 0
    while (i < n) and (i < filter_distance),
            inv n <= args.curr.length(),
            inv n <= args.prev.length(),
    {
        // Proof hints: "i += 1" cannot overflow and both indexes are in
        // bounds.
        assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: n)
        assert i < args.curr.length() via "a < b: a < c; c <= b"(c: n)
        assert i < args.prev.length() via "a < b: a < c; c <= b"(c: n)
        args.curr[i] = args.curr[i] ~mod+ args.prev[i]
        i += 1
    } endwhile

    // Remaining bytes have all three neighbors. i is reset to filter_distance
    // explicitly so that the "i >= filter_distance" invariant can be proven.
    i = filter_distance
    assert i >= filter_distance via "a >= b: a == b"()
    while i < n,
            inv i >= filter_distance,
            inv n <= args.curr.length(),
            inv n <= args.prev.length(),
    {
        assert i < 0xFFFF_FFFF_FFFF_FFFF via "a < b: a < c; c <= b"(c: n)
        assert i < args.curr.length() via "a < b: a < c; c <= b"(c: n)
        assert i < args.prev.length() via "a < b: a < c; c <= b"(c: n)
        assert (i - filter_distance) < args.curr.length() via "(a - b) < c: a < c; 0 <= b"()
        assert (i - filter_distance) < args.prev.length() via "(a - b) < c: a < c; 0 <= b"()
        fa = args.curr[i - filter_distance] as base.u32
        fb = args.prev[i] as base.u32
        fc = args.prev[i - filter_distance] as base.u32
        // pp is the initial estimate (fa + fb - fc), held as a wrapping u32.
        pp = (fa ~mod+ fb) ~mod- fc
        // pa, pb and pc are the absolute distances from pp to fa, fb and fc.
        // The subtractions wrap, so a result with the top bit set stands for
        // a negative difference and is negated to get its magnitude.
        pa = pp ~mod- fa
        if pa >= 0x8000_0000 {
            pa = 0 ~mod- pa
        }
        pb = pp ~mod- fb
        if pb >= 0x8000_0000 {
            pb = 0 ~mod- pb
        }
        pc = pp ~mod- fc
        if pc >= 0x8000_0000 {
            pc = 0 ~mod- pc
        }
        // Pick the nearest neighbor, leaving the chosen predictor in fa.
        if (pa <= pb) and (pa <= pc) {
            // No-op.
        } else if pb <= pc {
            fa = fb
        } else {
            fa = fc
        }
        // The mask bounds the value to 0 ..= 0xFF ahead of the u8 conversion.
        args.curr[i] = args.curr[i] ~mod+ ((fa & 0xFF) as base.u8)
        i += 1
        // Re-establish the loop invariant after the increment.
        assert i >= filter_distance via "a >= b: a >= (b + c); 0 <= c"(c: 1)
    } endwhile
}

// filter_4_distance_3_fallback reverses the Paeth filter in place for a
// filter distance of exactly 3 bytes. args.curr and args.prev are walked
// together in 3-byte groups, one lane per byte position in the group.
//
// Per lane k, three neighbors are tracked as u32 values:
//  - leftK: the updated byte k of the previous group in curr.
//  - upK:   byte k of the current group in prev.
//  - diagK: byte k of the previous group in prev.
// leftK and diagK start at zero for the first group (locals are
// zero-initialized). The predictor is whichever of leftK, upK and diagK is
// closest to (leftK + upK - diagK), ties resolved in that order, and it is
// added to the current byte modulo 256.
pri func decoder.filter_4_distance_3_fallback!(curr: slice base.u8, prev: slice base.u8) {
    var c : slice base.u8
    var p : slice base.u8

    var left0 : base.u32
    var up0   : base.u32
    var diag0 : base.u32
    var est0  : base.u32
    var dl0   : base.u32
    var du0   : base.u32
    var dd0   : base.u32

    var left1 : base.u32
    var up1   : base.u32
    var diag1 : base.u32
    var est1  : base.u32
    var dl1   : base.u32
    var du1   : base.u32
    var dd1   : base.u32

    var left2 : base.u32
    var up2   : base.u32
    var diag2 : base.u32
    var est2  : base.u32
    var dl2   : base.u32
    var du2   : base.u32
    var dd2   : base.u32

    iterate (c = args.curr, p = args.prev)(length: 3, advance: 3, unroll: 1) {
        // ---- Lane 0.
        up0 = p[0] as base.u32
        est0 = (left0 ~mod+ up0) ~mod- diag0
        // Wrapping differences from the estimate to each neighbor.
        dl0 = est0 ~mod- left0
        du0 = est0 ~mod- up0
        dd0 = est0 ~mod- diag0
        // A set top bit stands for a negative difference: negate it to get
        // the magnitude.
        if dl0 >= 0x8000_0000 {
            dl0 = 0 ~mod- dl0
        }
        if du0 >= 0x8000_0000 {
            du0 = 0 ~mod- du0
        }
        if dd0 >= 0x8000_0000 {
            dd0 = 0 ~mod- dd0
        }
        // Keep left0 as the predictor unless another neighbor is strictly
        // closer.
        if (dl0 > du0) or (dl0 > dd0) {
            if du0 <= dd0 {
                left0 = up0
            } else {
                left0 = diag0
            }
        }
        c[0] = c[0] ~mod+ ((left0 & 0xFF) as base.u8)
        // Carry state into the next group.
        left0 = c[0] as base.u32
        diag0 = up0

        // ---- Lane 1.
        up1 = p[1] as base.u32
        est1 = (left1 ~mod+ up1) ~mod- diag1
        dl1 = est1 ~mod- left1
        du1 = est1 ~mod- up1
        dd1 = est1 ~mod- diag1
        if dl1 >= 0x8000_0000 {
            dl1 = 0 ~mod- dl1
        }
        if du1 >= 0x8000_0000 {
            du1 = 0 ~mod- du1
        }
        if dd1 >= 0x8000_0000 {
            dd1 = 0 ~mod- dd1
        }
        if (dl1 > du1) or (dl1 > dd1) {
            if du1 <= dd1 {
                left1 = up1
            } else {
                left1 = diag1
            }
        }
        c[1] = c[1] ~mod+ ((left1 & 0xFF) as base.u8)
        left1 = c[1] as base.u32
        diag1 = up1

        // ---- Lane 2.
        up2 = p[2] as base.u32
        est2 = (left2 ~mod+ up2) ~mod- diag2
        dl2 = est2 ~mod- left2
        du2 = est2 ~mod- up2
        dd2 = est2 ~mod- diag2
        if dl2 >= 0x8000_0000 {
            dl2 = 0 ~mod- dl2
        }
        if du2 >= 0x8000_0000 {
            du2 = 0 ~mod- du2
        }
        if dd2 >= 0x8000_0000 {
            dd2 = 0 ~mod- dd2
        }
        if (dl2 > du2) or (dl2 > dd2) {
            if du2 <= dd2 {
                left2 = up2
            } else {
                left2 = diag2
            }
        }
        c[2] = c[2] ~mod+ ((left2 & 0xFF) as base.u8)
        left2 = c[2] as base.u32
        diag2 = up2
    }
}

// filter_4_distance_4_fallback reverses the Paeth filter in place for a
// filter distance of exactly 4 bytes. args.curr and args.prev are walked
// together in 4-byte groups, one lane per byte position in the group.
//
// Per lane k, three neighbors are tracked as u32 values:
//  - leftK: the updated byte k of the previous group in curr.
//  - upK:   byte k of the current group in prev.
//  - diagK: byte k of the previous group in prev.
// leftK and diagK start at zero for the first group (locals are
// zero-initialized). The predictor is whichever of leftK, upK and diagK is
// closest to (leftK + upK - diagK), ties resolved in that order, and it is
// added to the current byte modulo 256.
pri func decoder.filter_4_distance_4_fallback!(curr: slice base.u8, prev: slice base.u8) {
    var c : slice base.u8
    var p : slice base.u8

    var left0 : base.u32
    var up0   : base.u32
    var diag0 : base.u32
    var est0  : base.u32
    var dl0   : base.u32
    var du0   : base.u32
    var dd0   : base.u32

    var left1 : base.u32
    var up1   : base.u32
    var diag1 : base.u32
    var est1  : base.u32
    var dl1   : base.u32
    var du1   : base.u32
    var dd1   : base.u32

    var left2 : base.u32
    var up2   : base.u32
    var diag2 : base.u32
    var est2  : base.u32
    var dl2   : base.u32
    var du2   : base.u32
    var dd2   : base.u32

    var left3 : base.u32
    var up3   : base.u32
    var diag3 : base.u32
    var est3  : base.u32
    var dl3   : base.u32
    var du3   : base.u32
    var dd3   : base.u32

    iterate (c = args.curr, p = args.prev)(length: 4, advance: 4, unroll: 1) {
        // ---- Lane 0.
        up0 = p[0] as base.u32
        est0 = (left0 ~mod+ up0) ~mod- diag0
        // Wrapping differences from the estimate to each neighbor.
        dl0 = est0 ~mod- left0
        du0 = est0 ~mod- up0
        dd0 = est0 ~mod- diag0
        // A set top bit stands for a negative difference: negate it to get
        // the magnitude.
        if dl0 >= 0x8000_0000 {
            dl0 = 0 ~mod- dl0
        }
        if du0 >= 0x8000_0000 {
            du0 = 0 ~mod- du0
        }
        if dd0 >= 0x8000_0000 {
            dd0 = 0 ~mod- dd0
        }
        // Keep left0 as the predictor unless another neighbor is strictly
        // closer.
        if (dl0 > du0) or (dl0 > dd0) {
            if du0 <= dd0 {
                left0 = up0
            } else {
                left0 = diag0
            }
        }
        c[0] = c[0] ~mod+ ((left0 & 0xFF) as base.u8)
        // Carry state into the next group.
        left0 = c[0] as base.u32
        diag0 = up0

        // ---- Lane 1.
        up1 = p[1] as base.u32
        est1 = (left1 ~mod+ up1) ~mod- diag1
        dl1 = est1 ~mod- left1
        du1 = est1 ~mod- up1
        dd1 = est1 ~mod- diag1
        if dl1 >= 0x8000_0000 {
            dl1 = 0 ~mod- dl1
        }
        if du1 >= 0x8000_0000 {
            du1 = 0 ~mod- du1
        }
        if dd1 >= 0x8000_0000 {
            dd1 = 0 ~mod- dd1
        }
        if (dl1 > du1) or (dl1 > dd1) {
            if du1 <= dd1 {
                left1 = up1
            } else {
                left1 = diag1
            }
        }
        c[1] = c[1] ~mod+ ((left1 & 0xFF) as base.u8)
        left1 = c[1] as base.u32
        diag1 = up1

        // ---- Lane 2.
        up2 = p[2] as base.u32
        est2 = (left2 ~mod+ up2) ~mod- diag2
        dl2 = est2 ~mod- left2
        du2 = est2 ~mod- up2
        dd2 = est2 ~mod- diag2
        if dl2 >= 0x8000_0000 {
            dl2 = 0 ~mod- dl2
        }
        if du2 >= 0x8000_0000 {
            du2 = 0 ~mod- du2
        }
        if dd2 >= 0x8000_0000 {
            dd2 = 0 ~mod- dd2
        }
        if (dl2 > du2) or (dl2 > dd2) {
            if du2 <= dd2 {
                left2 = up2
            } else {
                left2 = diag2
            }
        }
        c[2] = c[2] ~mod+ ((left2 & 0xFF) as base.u8)
        left2 = c[2] as base.u32
        diag2 = up2

        // ---- Lane 3.
        up3 = p[3] as base.u32
        est3 = (left3 ~mod+ up3) ~mod- diag3
        dl3 = est3 ~mod- left3
        du3 = est3 ~mod- up3
        dd3 = est3 ~mod- diag3
        if dl3 >= 0x8000_0000 {
            dl3 = 0 ~mod- dl3
        }
        if du3 >= 0x8000_0000 {
            du3 = 0 ~mod- du3
        }
        if dd3 >= 0x8000_0000 {
            dd3 = 0 ~mod- dd3
        }
        if (dl3 > du3) or (dl3 > dd3) {
            if du3 <= dd3 {
                left3 = up3
            } else {
                left3 = diag3
            }
        }
        c[3] = c[3] ~mod+ ((left3 & 0xFF) as base.u8)
        left3 = c[3] as base.u32
        diag3 = up3
    }
}
