// Copyright 2020 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Public error statuses returned by this package. By Wuffs convention, a
// status message that starts with "#" is an error (as opposed to "$" for a
// suspension or "@" for a note).
pub status "#bad header"
pub status "#unsupported BMP file"

// Private error status, returned by decoder.swizzle! when the pixel count
// reported by the swizzler implies more source bytes than were supplied.
pri status "#internal error: inconsistent swizzle count"

// The worst-case work buffer length. It is zero, matching the [0 ..= 0] range
// returned by decoder.workbuf_len.
pub const decoder_workbuf_len_max_incl_worst_case base.u64 = 0

// decoder holds the state for decoding a single BMP image. It implements
// base.image_decoder.
pub struct decoder? implements base.image_decoder(
	// Image dimensions, in pixels, as read by decode_image_config. The height
	// is stored as a magnitude; the sign from the file is kept in top_down.
	width  : base.u32[..= 0x7FFF_FFFF],
	height : base.u32[..= 0x7FFF_FFFF],

	// Progress through the decoding API:
	//  - 0: nothing has been decoded yet.
	//  - 1: the image config has been decoded (or restart_frame was called).
	//  - 2: the frame config has been decoded.
	//  - 3: the frame has been decoded or skipped.
	call_sequence : base.u8,

	// top_down is set when the file's height field is negative (as a signed
	// 32-bit value), in which case rows are written from dst_y 0 upwards.
	// Otherwise, rows are written from the last row to the first.
	top_down       : base.bool,
	// bits_per_pixel is 24 or 32; decode_image_config rejects other values.
	bits_per_pixel : base.u8[..= 32],
	// pad_per_row is the number of padding bytes that follow each source row.
	pad_per_row    : base.u32[..= 3],
	// bytes_per_row is the source row length, including padding.
	bytes_per_row  : base.u64[..= 0x0000_0001_FFFF_FFFC],  // 4 * 0x7FFF_FFFF
	// bytes_total is bytes_per_row multiplied by height.
	bytes_total    : base.u64[..= 0xFFFF_FFFC_0000_0004],  // 4 * 0x7FFF_FFFF * 0x7FFF_FFFF
	// pixfmt is the source pixel format handed to the swizzler.
	pixfmt         : base.pixel_format,

	// padding is the number of bytes to skip between the end of the
	// BITMAPINFOHEADER and the start of the pixel data. It is derived from the
	// u32 read from the file header, minus the 14-byte file header and minus
	// the info header length.
	padding : base.u32,

	// Channel masks, read only when the info header is 108 or 124 bytes long.
	mask_r : base.u32,
	mask_g : base.u32,
	mask_b : base.u32,
	mask_a : base.u32,

	// frame_config_io_position is the source position recorded at the end of
	// decode_image_config (or supplied to restart_frame).
	frame_config_io_position : base.u64,

	// The destination cursor used by swizzle!. dst_y advances by dst_y_inc
	// (with modular arithmetic, so 0xFFFF_FFFF acts as -1) at the end of each
	// row. Decoding is complete when dst_y equals dst_y_end.
	dst_x     : base.u32,
	dst_y     : base.u32,
	dst_y_end : base.u32,
	dst_y_inc : base.u32,

	// stash holds the first num_stashed bytes of a source pixel that was split
	// across two swizzle! calls. pending_pad is the number of end-of-row
	// padding bytes still to be discarded at the start of the next swizzle!
	// call.
	stash       : array[4] base.u8,
	num_stashed : base.u8[..= 4],
	pending_pad : base.u8[..= 3],

	swizzler : base.pixel_swizzler,
	util     : base.utility,
)

// decode_image_config parses the BMP file header and info header from
// args.src, recording the image's geometry and source pixel format in this
// decoder. If args.dst is not nullptr, it is filled in with the image
// configuration.
//
// It returns:
//  - base."#bad call sequence" if it has already been called,
//  - "#bad header" if the magic bytes, the header lengths, the width or the
//    height are invalid,
//  - "#unsupported BMP file" for info header versions, plane counts, bit
//    depths or compression methods that this decoder does not handle.
//
// On success, call_sequence becomes 1.
pub func decoder.decode_image_config?(dst: nptr base.image_config, src: base.io_reader) {
	var magic           : base.u32
	var bitmap_info_len : base.u32
	var width           : base.u32
	var height          : base.u32
	var planes          : base.u32
	var bits_per_pixel  : base.u32
	var compression     : base.u32

	if this.call_sequence <> 0 {
		return base."#bad call sequence"
	}

	// Read the BITMAPFILEHEADER (14 bytes).

	magic = args.src.read_u16le_as_u32?()
	if magic <> 'BM'le {
		return "#bad header"
	}

	// Skip the next 8 bytes of the file header (per the BMP format, the 4-byte
	// file size and 4 reserved bytes).
	args.src.skip32?(n: 8)

	// The final file header field is the offset (from the start of the file)
	// of the pixel data. Subtracting the header lengths as they are consumed
	// leaves this.padding holding the gap between the end of the info header
	// and the pixel data.
	this.padding = args.src.read_u32le?()
	if this.padding < 14 {
		return "#bad header"
	}
	this.padding -= 14

	// Read the BITMAPINFOHEADER (version 3 / 4 / 5 is 40 / 108 / 124 bytes).

	bitmap_info_len = args.src.read_u32le?()
	if (bitmap_info_len <> 40) and (bitmap_info_len <> 108) and (bitmap_info_len <> 124) {
		return "#unsupported BMP file"
	}
	if this.padding < bitmap_info_len {
		return "#bad header"
	}
	this.padding -= bitmap_info_len

	width = args.src.read_u32le?()
	if width >= 0x8000_0000 {
		return "#bad header"
	}
	this.width = width

	// The height is a signed 32-bit value read as unsigned. A negative height
	// means that rows are stored top-down. 0x8000_0000 is rejected as its
	// negation does not fit in this.height's refined bounds.
	height = args.src.read_u32le?()
	if height == 0x8000_0000 {
		return "#bad header"
	} else if height >= 0x8000_0000 {
		// The &0x7FFF_FFFF is redundant, but proves to the compiler that the
		// result is within this.height's refined bounds.
		this.height = (0 ~mod- height) & 0x7FFF_FFFF
		this.top_down = true
	} else {
		this.height = height
	}

	planes = args.src.read_u16le_as_u32?()
	if planes <> 1 {
		return "#unsupported BMP file"
	}

	bits_per_pixel = args.src.read_u16le_as_u32?()
	if bits_per_pixel == 24 {
		this.bits_per_pixel = 24
		// 3 bytes per pixel, but row lengths are rounded up to multiples of 4.
		// The "((x + 3) >> 2) << 2" dance rounds x up.
		this.bytes_per_row = ((((this.width as base.u64) * 3) + 3) >> 2) << 2
		// With 3 bytes per pixel, the number of bytes needed to round a row up
		// to a multiple of 4 works out to be (width & 3): widths of 0, 1, 2
		// and 3 (modulo 4) need 0, 1, 2 and 3 padding bytes respectively.
		this.pad_per_row = this.width & 3
		// TODO: a Wuffs (not just C) name for the
		// WUFFS_BASE__PIXEL_FORMAT__BGR magic pixfmt constant.
		this.pixfmt = this.util.make_pixel_format(repr: 0x4000_0888)
	} else if bits_per_pixel == 32 {
		this.bits_per_pixel = 32
		// 4 bytes per pixel means that rows are already 4-byte aligned.
		this.bytes_per_row = (this.width as base.u64) * 4
		this.pad_per_row = 0
		// TODO: a Wuffs (not just C) name for the
		// WUFFS_BASE__PIXEL_FORMAT__BGRA_NONPREMUL magic pixfmt constant.
		this.pixfmt = this.util.make_pixel_format(repr: 0x4500_8888)
	} else {
		// TODO: support other bits_per_pixel's.
		return "#unsupported BMP file"
	}
	this.bytes_total = this.bytes_per_row * (this.height as base.u64)

	compression = args.src.read_u32le?()

	// We've already read 20 bytes from the BITMAPINFOHEADER: size (4), width
	// (4), height (4), planes (2), bpp (2), compression (4). Skip the rest of
	// the version 3 BITMAPINFOHEADER (whose total size is 40).
	args.src.skip32?(n: 40 - 20)

	if bitmap_info_len >= 108 {
		// Version 4 and 5 headers continue with the four channel masks.
		this.mask_r = args.src.read_u32le?()
		this.mask_g = args.src.read_u32le?()
		this.mask_b = args.src.read_u32le?()
		this.mask_a = args.src.read_u32le?()

		// If compression is 3 (BITFIELDS) but the explicit masks are what the
		// implicit masks are for no compression, treat it as no compression.
		if (compression == 3) and
			(this.mask_r == 0x00FF_0000) and
			(this.mask_g == 0x0000_FF00) and
			(this.mask_b == 0x0000_00FF) and
			(this.mask_a == 0xFF00_0000) {
			compression = 0
		}

		// Skip the rest of the BITMAPINFOHEADER. We've already read (40 + (4 *
		// 4)) bytes.
		args.src.skip32?(n: bitmap_info_len - 56)
	}

	if compression <> 0 {
		// TODO: support compression.
		return "#unsupported BMP file"
	}

	// Remember where the headers ended, so that decode_frame_config can
	// detect a bad restart.
	this.frame_config_io_position = args.src.position()

	if args.dst <> nullptr {
		// TODO: a Wuffs (not just C) name for the
		// WUFFS_BASE__PIXEL_FORMAT__BGRA_NONPREMUL magic pixfmt constant.
		//
		// NOTE(review): first_frame_is_opaque is always true, even when
		// this.pixfmt is BGRA_NONPREMUL (the 32 bits per pixel case). Confirm
		// whether the alpha channel of 32 bits per pixel files is meant to be
		// honored or ignored.
		args.dst.set!(
			pixfmt: 0x4500_8888,  // TODO: this.pixfmt instead?
			pixsub: 0,
			width: this.width,
			height: this.height,
			first_frame_io_position: this.frame_config_io_position,
			first_frame_is_opaque: true)
	}

	this.call_sequence = 1
}

// decode_frame_config reports the configuration of the image's single frame.
// If args.dst is not nullptr, it is updated with the frame's bounds (the whole
// image) and other properties.
//
// What happens first depends on call_sequence:
//  - 0: the image config is implicitly decoded (and discarded).
//  - 1: args.src must be at the position recorded after the image config (or
//    given to restart_frame), otherwise base."#bad restart" is returned.
//  - 2: the frame config was already decoded, so the frame's pixel data is
//    skipped and base."@end of data" is returned, as there is only one frame.
//  - otherwise: base."@end of data" is returned.
//
// On success, call_sequence becomes 2.
pub func decoder.decode_frame_config?(dst: nptr base.frame_config, src: base.io_reader) {
	if this.call_sequence < 1 {
		this.decode_image_config?(dst: nullptr, src: args.src)
	} else if this.call_sequence == 1 {
		if this.frame_config_io_position <> args.src.position() {
			return base."#bad restart"
		}
	} else if this.call_sequence == 2 {
		this.skip_frame?(src: args.src)
		return base."@end of data"
	} else {
		return base."@end of data"
	}

	if args.dst <> nullptr {
		args.dst.update!(bounds: this.util.make_rect_ie_u32(
			min_incl_x: 0,
			min_incl_y: 0,
			max_excl_x: this.width,
			max_excl_y: this.height),
			duration: 0,
			index: 0,
			io_position: this.frame_config_io_position,
			disposal: 0,
			opaque_within_bounds: true,
			overwrite_instead_of_blend: false,
			background_color: 0xFF00_0000)
	}

	this.call_sequence = 2
}

// decode_frame decodes the image's pixel data from args.src into args.dst,
// converting from the source pixel format to args.dst's pixel format with the
// given args.blend mode. The args.workbuf and args.opts arguments are not used
// by this function.
//
// If the frame config has not been decoded yet, it is implicitly decoded
// first. If the frame has already been decoded or skipped, base."@end of data"
// is returned.
//
// On success, call_sequence becomes 3.
pub func decoder.decode_frame?(dst: ptr base.pixel_buffer, src: base.io_reader, blend: base.pixel_blend, workbuf: slice base.u8, opts: nptr base.decode_frame_options) {
	var status          : base.status
	var bytes_remaining : base.u64
	var n               : base.u64
	// src is the chunk of source bytes taken from args.src for one swizzle!
	// call. It is distinct from the args.src reader.
	var src             : slice base.u8

	if this.call_sequence < 2 {
		this.decode_frame_config?(dst: nullptr, src: args.src)
	} else if this.call_sequence == 2 {
		// No-op.
	} else {
		return base."@end of data"
	}

	// Skip any gap between the headers and the pixel data.
	args.src.skip32?(n: this.padding)

	if (this.width > 0) and (this.height > 0) {
		// Set up the destination cursor. Bottom-up images start at the last
		// row and step backwards, using modular u32 arithmetic so that
		// stepping "before" row 0 lands on 0xFFFF_FFFF.
		this.dst_x = 0
		if this.top_down {
			this.dst_y = 0
			this.dst_y_end = this.height
			this.dst_y_inc = 1
		} else {
			this.dst_y = this.height ~mod- 1
			this.dst_y_end = 0xFFFF_FFFF  // -1 as a base.u32.
			this.dst_y_inc = 0xFFFF_FFFF  // -1 as a base.u32.
		}

		status = this.swizzler.prepare!(
			dst_pixfmt: args.dst.pixel_format(),
			dst_palette: args.dst.palette(),
			src_pixfmt: this.pixfmt,
			src_palette: this.util.empty_slice_u8(),
			blend: args.blend)
		if not status.is_ok() {
			return status
		}

		// Feed the swizzle! method with whatever source bytes are available,
		// capped at the number of pixel data bytes still expected, until it
		// reports that every row has been written.
		bytes_remaining = this.bytes_total
		while true {
			n = args.src.available()
			if bytes_remaining >= n {
				bytes_remaining -= n
			} else {
				n = bytes_remaining
				bytes_remaining = 0
			}
			src = args.src.take!(n: n)
			status = this.swizzle!(dst: args.dst, src: src)
			if status.is_ok() {
				break
			} else if status.is_suspension() {
				// swizzle! ran out of source bytes. Suspend until the caller
				// provides more.
				yield? status
			} else {
				return status
			}
		} endwhile
	}

	this.call_sequence = 3
}

// swizzle converts the source pixel bytes in args.src and writes them to plane
// 0 of args.dst, starting at the (this.dst_x, this.dst_y) cursor and advancing
// that cursor as it goes. It is designed to be called repeatedly with
// consecutive chunks of the pixel data: partial pixels are kept in this.stash
// and unconsumed end-of-row padding is counted in this.pending_pad.
//
// It returns:
//  - ok when this.dst_y reaches this.dst_y_end, i.e. every row is done,
//  - base."$short read" when more source bytes are needed,
//  - base."#unsupported option" if args.dst's pixel format does not use a
//    whole number of bytes per pixel,
//  - "#internal error: inconsistent swizzle count" if the swizzler reports
//    converting more pixels than args.src holds.
//
// NOTE(review): when the destination row is too short for the current dst_x
// (the "i < dst.length()" checks fail), no pixels are written and dst_x does
// not advance. Confirm how a pixel buffer smaller than the image is meant to
// be handled.
pri func decoder.swizzle!(dst: ptr base.pixel_buffer, src: slice base.u8) base.status {
	var dst_pixfmt          : base.pixel_format
	var dst_bits_per_pixel  : base.u32[..= 256]
	var dst_bytes_per_pixel : base.u64[..= 32]
	var dst_bytes_per_row   : base.u64
	var src_bytes_per_pixel : base.u8[..= 4]
	var tab                 : table base.u8
	var dst                 : slice base.u8
	var i                   : base.u64
	var n                   : base.u64

	// TODO: the dst_pixfmt variable shouldn't be necessary. We should be able
	// to chain the two calls: "args.dst.pixel_format().bits_per_pixel()".
	dst_pixfmt = args.dst.pixel_format()
	dst_bits_per_pixel = dst_pixfmt.bits_per_pixel()
	if (dst_bits_per_pixel & 7) <> 0 {
		return base."#unsupported option"
	}
	dst_bytes_per_pixel = (dst_bits_per_pixel / 8) as base.u64
	dst_bytes_per_row = (this.width as base.u64) * dst_bytes_per_pixel
	src_bytes_per_pixel = (this.bits_per_pixel / 8) as base.u8
	tab = args.dst.plane(p: 0)

	// Handle the case where the I/O suspension occurred in the middle of the
	// end-of-row padding.
	while this.pending_pad > 0 {
		if args.src.length() <= 0 {
			return base."$short read"
		}
		this.pending_pad -= 1
		args.src = args.src[1 ..]
	} endwhile

	// Handle the case where the I/O suspension occurred in the middle of a
	// source pixel.
	if this.num_stashed <> 0 {
		// Top up the stash until it holds one complete source pixel.
		while this.num_stashed < src_bytes_per_pixel {
			assert this.num_stashed < 4 via "a < b: a < c; c <= b"(c: src_bytes_per_pixel)
			if args.src.length() <= 0 {
				return base."$short read"
			}
			this.stash[this.num_stashed] = args.src[0]
			this.num_stashed += 1
			args.src = args.src[1 ..]
		} endwhile

		// Write the single pixel.
		dst = tab.row(y: this.dst_y)
		if dst_bytes_per_row < dst.length() {
			dst = dst[.. dst_bytes_per_row]
		}
		i = (this.dst_x as base.u64) * dst_bytes_per_pixel
		if i < dst.length() {
			this.swizzler.swizzle_interleaved!(
				dst: dst[i ..],
				dst_palette: this.util.empty_slice_u8(),
				src: this.stash[.. this.num_stashed])
			this.dst_x ~sat+= 1
		}

		this.num_stashed = 0
	}

	while true {
		// At the end of a row, move the cursor to the start of the next row
		// and discard the source row's padding bytes.
		if this.dst_x == this.width {
			this.dst_x = 0
			this.dst_y ~mod+= this.dst_y_inc

			if this.pad_per_row <> 0 {
				n = this.pad_per_row as base.u64
				if n <= args.src.length() {
					args.src = args.src[n ..]
				} else {
					// The padding straddles two chunks. Whatever is left of
					// args.src is padding; remember how much is still owed.
					this.pending_pad = ((n - args.src.length()) & 3) as base.u8
					return base."$short read"
				}
			}
		}

		if this.dst_y == this.dst_y_end {
			break
		}

		// Clip the destination row to the image width, so that the swizzler
		// stops at the end of the row.
		dst = tab.row(y: this.dst_y)
		if dst_bytes_per_row < dst.length() {
			dst = dst[.. dst_bytes_per_row]
		}
		i = (this.dst_x as base.u64) * dst_bytes_per_pixel
		if i < dst.length() {
			// n is first the swizzler's return value, used here as a count of
			// pixels converted, and then the equivalent number of source bytes.
			n = this.swizzler.swizzle_interleaved!(
				dst: dst[i ..],
				dst_palette: this.util.empty_slice_u8(),
				src: args.src)
			this.dst_x ~sat+= (n & 0xFFFF_FFFF) as base.u32
			n = (n & 0xFFFF_FFFF) * (src_bytes_per_pixel as base.u64)
			if n <= args.src.length() {
				args.src = args.src[n ..]
			} else {
				return "#internal error: inconsistent swizzle count"
			}
		}

		// Suspend if we didn't complete the row, potentially in the middle of
		// a source pixel.
		if this.dst_x < this.width {
			// Save any leftover bytes (up to one pixel's worth) for next time.
			while (this.num_stashed < src_bytes_per_pixel) and (args.src.length() > 0) {
				assert this.num_stashed < 4 via "a < b: a < c; c <= b"(c: src_bytes_per_pixel)
				this.stash[this.num_stashed] = args.src[0]
				this.num_stashed += 1
				args.src = args.src[1 ..]
			} endwhile
			return base."$short read"
		}
	} endwhile

	return ok
}

// skip_frame advances args.src past the frame without decoding it: first the
// gap between the headers and the pixel data (this.padding bytes), then the
// pixel data itself (this.bytes_total bytes).
//
// On success, call_sequence becomes 3.
pri func decoder.skip_frame?(src: base.io_reader) {
	args.src.skip32?(n: this.padding)
	args.src.skip?(n: this.bytes_total)

	this.call_sequence = 3
}

// ack_metadata_chunk does nothing and reads nothing from args.src: this
// decoder never reports metadata, so there is never a chunk to acknowledge.
pub func decoder.ack_metadata_chunk?(src: base.io_reader) {
	// TODO: this final line shouldn't be necessary, here and in some other
	// std/*/*.wuffs functions.
	return ok
}

// frame_dirty_rect returns the rectangle, in destination pixel coordinates,
// that decoding the frame can modify. For this decoder, that is always the
// whole image: the half-open ranges [0, width) by [0, height).
pub func decoder.frame_dirty_rect() base.rect_ie_u32 {
	return this.util.make_rect_ie_u32(
		min_incl_x: 0, min_incl_y: 0,
		max_excl_x: this.width, max_excl_y: this.height)
}

// metadata_chunk_length always returns 0, as this decoder does not report
// metadata.
pub func decoder.metadata_chunk_length() base.u64 {
	return 0
}

// metadata_fourcc always returns 0, as this decoder does not report metadata.
pub func decoder.metadata_fourcc() base.u32 {
	return 0
}

// num_animation_loops always returns 0. This decoder produces a single frame,
// so there is no animation loop count to report.
pub func decoder.num_animation_loops() base.u32 {
	return 0
}

// num_decoded_frame_configs returns how many frame configs have been decoded
// so far: 1 once call_sequence has moved past the image config stage (i.e. it
// is 2 or more), and 0 before that.
pub func decoder.num_decoded_frame_configs() base.u64 {
	if this.call_sequence <= 1 {
		return 0
	}
	return 1
}

// num_decoded_frames returns how many frames have been decoded (or skipped) so
// far: 1 once call_sequence has moved past the frame config stage (i.e. it is
// 3 or more), and 0 before that.
pub func decoder.num_decoded_frames() base.u64 {
	if this.call_sequence <= 2 {
		return 0
	}
	return 1
}

// restart_frame rewinds the decoder so that the frame can be decoded again.
// The caller is responsible for also seeking the source to args.io_position,
// which the next decode_frame_config call checks against.
//
// It returns base."#bad call sequence" if the image config has not been
// decoded yet and base."#bad argument" if args.index is not zero.
pub func decoder.restart_frame!(index: base.u64, io_position: base.u64) base.status {
	// There is nothing to restart until the image config has been decoded.
	if this.call_sequence < 1 {
		return base."#bad call sequence"
	}
	// There is only one frame, so zero is the only valid frame index.
	if args.index > 0 {
		return base."#bad argument"
	}
	this.frame_config_io_position = args.io_position
	this.call_sequence = 1
	return ok
}

// set_report_metadata ignores both of its arguments and changes no decoder
// state.
pub func decoder.set_report_metadata!(fourcc: base.u32, report: base.bool) {
	// No-op. BMP doesn't support metadata.
}

// workbuf_len returns the inclusive range of work buffer lengths that
// decode_frame needs. Both ends of the range are zero: this decoder does not
// use a work buffer.
pub func decoder.workbuf_len() base.range_ii_u64 {
	return this.util.make_range_ii_u64(
		min_incl: 0,
		max_incl: 0)
}
