opt: add Int16 and Float16 to capability trim pass (#5519)

Add support for Int16 and Float16 trim.

Signed-off-by: Nathan Gauër <brioche@google.com>
diff --git a/source/opt/trim_capabilities_pass.cpp b/source/opt/trim_capabilities_pass.cpp
index 0549947..19f8569 100644
--- a/source/opt/trim_capabilities_pass.cpp
+++ b/source/opt/trim_capabilities_pass.cpp
@@ -137,6 +137,16 @@
 // Handler names follow the following convention:
 //  Handler_<Opcode>_<Capability>()
 
+static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeFloat &&
+         "This handler only support OpTypeFloat opcodes.");
+
+  const uint32_t size =
+      instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
+  return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
+}
+
 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
     const Instruction* instruction) {
   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
@@ -274,6 +284,16 @@
                         : std::nullopt;
 }
 
+static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
+    const Instruction* instruction) {
+  assert(instruction->opcode() == spv::Op::OpTypeInt &&
+         "This handler only support OpTypeInt opcodes.");
+
+  const uint32_t size =
+      instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
+  return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
+}
+
 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
     const Instruction* instruction) {
   assert(instruction->opcode() == spv::Op::OpTypeInt &&
@@ -341,12 +361,14 @@
 }
 
 // Opcode of interest to determine capabilities requirements.
-constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 10> kOpcodeHandlers{{
+constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
     // clang-format off
     {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
     {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
+    {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float16 },
     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
     {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
+    {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int16 },
     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
diff --git a/source/opt/trim_capabilities_pass.h b/source/opt/trim_capabilities_pass.h
index 73d5dc8..9f23732 100644
--- a/source/opt/trim_capabilities_pass.h
+++ b/source/opt/trim_capabilities_pass.h
@@ -76,12 +76,14 @@
       // clang-format off
       spv::Capability::ComputeDerivativeGroupLinearNV,
       spv::Capability::ComputeDerivativeGroupQuadsNV,
+      spv::Capability::Float16,
       spv::Capability::Float64,
       spv::Capability::FragmentShaderPixelInterlockEXT,
       spv::Capability::FragmentShaderSampleInterlockEXT,
       spv::Capability::FragmentShaderShadingRateInterlockEXT,
       spv::Capability::Groups,
       spv::Capability::ImageMSArray,
+      spv::Capability::Int16,
       spv::Capability::Int64,
       spv::Capability::Linkage,
       spv::Capability::MinLod,
diff --git a/test/opt/trim_capabilities_pass_test.cpp b/test/opt/trim_capabilities_pass_test.cpp
index 14a8aa3..c90afb4 100644
--- a/test/opt/trim_capabilities_pass_test.cpp
+++ b/test/opt/trim_capabilities_pass_test.cpp
@@ -2486,6 +2486,104 @@
   EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
 }
 
+TEST_F(TrimCapabilitiesPassTest, Float16_RemovedWhenUnused) {
+  const std::string kTest = R"(
+               OpCapability Float16
+; CHECK-NOT:   OpCapability Float16
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd;
+  )";
+  const auto result =
+      SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+  EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Float16_RemainsWhenUsed) {
+  const std::string kTest = R"(
+               OpCapability Float16
+; CHECK:       OpCapability Float16
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+       %void = OpTypeVoid
+      %float = OpTypeFloat 16
+          %3 = OpTypeFunction %void
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd;
+  )";
+  const auto result =
+      SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+  EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Int16_RemovedWhenUnused) {
+  const std::string kTest = R"(
+               OpCapability Int16
+; CHECK-NOT:   OpCapability Int16
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd;
+  )";
+  const auto result =
+      SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+  EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Int16_RemainsWhenUsed) {
+  const std::string kTest = R"(
+               OpCapability Int16
+; CHECK:       OpCapability Int16
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+       %void = OpTypeVoid
+        %int = OpTypeInt 16 1
+          %3 = OpTypeFunction %void
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd;
+  )";
+  const auto result =
+      SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+  EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, UInt16_RemainsWhenUsed) {
+  const std::string kTest = R"(
+               OpCapability Int16
+; CHECK:       OpCapability Int16
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+       %void = OpTypeVoid
+       %uint = OpTypeInt 16 0
+          %3 = OpTypeFunction %void
+          %1 = OpFunction %void None %3
+          %6 = OpLabel
+               OpReturn
+               OpFunctionEnd;
+  )";
+  const auto result =
+      SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+  EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools