/*
 * MVKDXTnCodec.def
 *
 * Copyright (c) 2018-2022 Chip Davis for CodeWeavers
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#ifndef MVK_DECOMPRESS_CODE
#error MVK_DECOMPRESS_CODE must be defined before including this file
#endif

MVK_DECOMPRESS_CODE(
	// Reports whether format is one of the four BC1 block formats
	// (RGB or RGBA, each in its UNORM and sRGB variant).
	static bool isBC1Format(VkFormat format) {
		// Formats without an alpha channel.
		if (format == VK_FORMAT_BC1_RGB_UNORM_BLOCK) { return true; }
		if (format == VK_FORMAT_BC1_RGB_SRGB_BLOCK) { return true; }

		// Formats with an alpha channel.
		if (format == VK_FORMAT_BC1_RGBA_UNORM_BLOCK) { return true; }
		return format == VK_FORMAT_BC1_RGBA_SRGB_BLOCK;
	}

	// Reports whether format is a BC2 block format (UNORM or sRGB).
	static bool isBC2Format(VkFormat format) {
		const bool isUnorm = (format == VK_FORMAT_BC2_UNORM_BLOCK);
		const bool isSRGB = (format == VK_FORMAT_BC2_SRGB_BLOCK);
		return isUnorm || isSRGB;
	}

	// Reports whether format is a BC3 block format (UNORM or sRGB).
	static bool isBC3Format(VkFormat format) {
		if (format == VK_FORMAT_BC3_UNORM_BLOCK) { return true; }
		return format == VK_FORMAT_BC3_SRGB_BLOCK;
	}

	// Reports whether format is one of the sRGB variants of the BC1, BC2 or BC3 block formats.
	static bool isSRGBFormat(VkFormat format) {
		const bool isBC1SRGB = (format == VK_FORMAT_BC1_RGB_SRGB_BLOCK) || (format == VK_FORMAT_BC1_RGBA_SRGB_BLOCK);
		if (isBC1SRGB) { return true; }

		if (format == VK_FORMAT_BC2_SRGB_BLOCK) { return true; }
		return format == VK_FORMAT_BC3_SRGB_BLOCK;
	}

	// Fills the four-entry colour palette of a DXTn block.
	//
	// colour0 and colour1 are the two packed 5:6:5 endpoints; they are unpacked into
	// entries 0 and 1. Entries 2 and 3 are derived from them:
	//  - BC1 format with colour0 <= colour1: entry 2 is the midpoint of the endpoints
	//    and entry 3 is zero.
	//  - every other case: entries 2 and 3 sit one third and two thirds of the way
	//    from entry 0 to entry 1.
	// pColourTable must have room for four float3 values.
	static void buildDXTnColourTable(uint16_t colour0, uint16_t colour1, thread float3* pColourTable, VkFormat format) {
		pColourTable[0] = unpack_unorm565_to_float(colour0);
		pColourTable[1] = unpack_unorm565_to_float(colour1);

		const bool threeColourMode = isBC1Format(format) && colour0 <= colour1;
		if (threeColourMode) {
			pColourTable[2] = (pColourTable[0] + pColourTable[1]) / 2;
			pColourTable[3] = float3(0);
			return;
		}

		// Four-colour mode: each interpolated entry weights its nearer endpoint twice.
		pColourTable[2] = (2 * pColourTable[0] + pColourTable[1]) / 3;
		pColourTable[3] = (2 * pColourTable[1] + pColourTable[0]) / 3;
	}

	// Fills the eight-entry alpha palette of a DXT5 (BC3) block.
	//
	// alpha0 and alpha1 are the 8-bit endpoints; they are normalised to [0, 1] and
	// stored in entries 0 and 1. The remaining entries depend on their order:
	//  - alpha0 > alpha1: entries 2..7 are six evenly spaced steps between the endpoints.
	//  - otherwise: entries 2..5 are four evenly spaced steps, entry 6 is 0 and entry 7 is 1.
	// pAlphaTable must have room for eight floats.
	static void buildDXT5AlphaTable(uint8_t alpha0, uint8_t alpha1, thread float* pAlphaTable) {
		pAlphaTable[0] = alpha0 / 255.0f;
		pAlphaTable[1] = alpha1 / 255.0f;

		if (alpha0 <= alpha1) {
			// Four interpolated values followed by the two fixed extremes.
			for (uint32_t step = 0; step < 4; ++step) {
				pAlphaTable[2 + step] = ((4 - step) * pAlphaTable[0] + (step + 1) * pAlphaTable[1]) / 5;
			}
			pAlphaTable[6] = 0;
			pAlphaTable[7] = 1;
			return;
		}

		// Six interpolated values.
		for (uint32_t step = 0; step < 6; ++step) {
			pAlphaTable[2 + step] = ((6 - step) * pAlphaTable[0] + (step + 1) * pAlphaTable[1]) / 7;
		}
	}

	// Converts an sRGB-encoded colour to linear, component by component.
	// Components at or below 0.04045 use the linear segment (value / 12.92);
	// the rest use the power curve ((value + 0.055) / 1.055) ^ 2.4.
	static float3 sRGBCorrect(float3 colour) {
		const float3 curvePart = pow((colour + 0.055)/1.055, float3(2.4));
		const float3 linearPart = colour/12.92;
		// select() yields its second argument wherever the condition holds.
		return select(curvePart, linearPart, colour <= 0.04045);
	}

	// Decompresses one BC1, BC2 or BC3 (DXTn) block into packed 8-bit-per-channel RGBA texels.
	//
	// pSrc          the compressed block, read as 32-bit words: two words for BC1 formats
	//               (endpoints then colour indices) and four words otherwise (two words of
	//               alpha data, then endpoints, then colour indices).
	// pDest         first destination texel; each texel is written as one uint32_t.
	// extent        number of texels to write in each dimension.
	// destRowPitch  distance in bytes between consecutive destination rows.
	// format        format of the block. A format that is neither BC1 nor BC2 is decoded as BC3.
	//
	// For sRGB formats the colour is converted to linear before it is packed.
	//
	// NOTE(review): the bit offsets used to fetch indices only stay inside the index
	// words when extent is at most 4x4 - presumably guaranteed by the caller; confirm.
	static void decompressDXTnBlock(const device void* pSrc, thread void* pDest, VkExtent2D extent, VkDeviceSize destRowPitch, VkFormat format) {
		const device uint32_t* pSrcBlock = (const device uint32_t *)pSrc;

		// The format does not change while decoding a block, so classify it once
		// instead of once per texel.
		const bool isBC1 = isBC1Format(format);
		const bool isBC2 = isBC2Format(format);
		const bool isSRGB = isSRGBFormat(format);

		bool isBC1Alpha = false;
		float3 colourTable[4];
		float alphaTable[8];
		size_t alphaBits;
		uint32_t colourBits;

		if (isBC1) {
			alphaBits = 0;

			uint16_t colour0 = pSrcBlock[0] & 0xffff;
			uint16_t colour1 = pSrcBlock[0] >> 16;
			colourBits = pSrcBlock[1];
			buildDXTnColourTable(colour0, colour1, colourTable, format);

			// Index 3 is transparent only in three-colour mode (colour0 <= colour1) AND only
			// for the BC1 RGBA formats. The BC1 RGB formats carry no alpha and must decode
			// index 3 as opaque black.
			const bool hasAlphaChannel = (format == VK_FORMAT_BC1_RGBA_UNORM_BLOCK) || (format == VK_FORMAT_BC1_RGBA_SRGB_BLOCK);
			isBC1Alpha = hasAlphaChannel && colour0 <= colour1;
		} else {
			// The first two words hold 64 bits of alpha data.
			alphaBits = pSrcBlock[0] | ((size_t)pSrcBlock[1] << 32);
			if (isBC3Format(format)) {
				// BC3: the low 16 bits are the two alpha endpoints; the rest are 3-bit indices.
				buildDXT5AlphaTable(alphaBits & 0xff, (alphaBits >> 8) & 0xff, alphaTable);
				alphaBits >>= 16;
			}

			colourBits = pSrcBlock[3];
			buildDXTnColourTable(pSrcBlock[2] & 0xffff, pSrcBlock[2] >> 16, colourTable, format);
		}

		for (uint32_t y = 0; y < extent.height; ++y) {
			thread uint32_t* pDestRow = (thread uint32_t *)((thread uint8_t *)pDest + y * destRowPitch);
			for (uint32_t x = 0; x < extent.width; ++x) {
				// Two bits of colour index per texel, eight bits per row.
				uint8_t colourIndex = (colourBits >> (y * 8 + x * 2)) & 0x3;
				float alpha;
				if (isBC1) {
					alpha = (isBC1Alpha && colourIndex == 3) ? 0.0f : 1.0f;
				} else if (isBC2) {
					// Explicit 4-bit alpha per texel.
					alpha = ((alphaBits >> (y * 16 + x * 4)) & 0xf) / 15.0f;
				} else {	// Must be a BC3 format
					// 3-bit index into the interpolated alpha table.
					alpha = alphaTable[(alphaBits >> (y * 12 + x * 3)) & 0x7];
				}
				float4 colour;
				colour.rgb = colourTable[colourIndex];
				if (isSRGB) {
					// Convert sRGB back to linear.
					colour.rgb = sRGBCorrect(colour.rgb);
				}
				colour.a = alpha;
				pDestRow[x] = pack_float_to_unorm4x8(colour);
			}
		}
	}
)
