opt: add Int16 and Float16 to capability trim pass (#5519)
Add support for Int16 and Float16 trim.
Signed-off-by: Nathan Gauër <brioche@google.com>
diff --git a/source/opt/trim_capabilities_pass.cpp b/source/opt/trim_capabilities_pass.cpp
index 0549947..19f8569 100644
--- a/source/opt/trim_capabilities_pass.cpp
+++ b/source/opt/trim_capabilities_pass.cpp
@@ -137,6 +137,16 @@
// Handler names follow the following convention:
// Handler_<Opcode>_<Capability>()
+static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
+ const Instruction* instruction) {
+ assert(instruction->opcode() == spv::Op::OpTypeFloat &&
+ "This handler only support OpTypeFloat opcodes.");
+
+ const uint32_t size =
+ instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
+ return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
+}
+
static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeFloat &&
@@ -274,6 +284,16 @@
: std::nullopt;
}
+static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
+ const Instruction* instruction) {
+ assert(instruction->opcode() == spv::Op::OpTypeInt &&
+ "This handler only support OpTypeInt opcodes.");
+
+ const uint32_t size =
+ instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
+ return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
+}
+
static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeInt &&
@@ -341,12 +361,14 @@
}
// Opcode of interest to determine capabilities requirements.
-constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 10> kOpcodeHandlers{{
+constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
// clang-format off
{spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
{spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
+ {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
{spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
+ {spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
{spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
diff --git a/source/opt/trim_capabilities_pass.h b/source/opt/trim_capabilities_pass.h
index 73d5dc8..9f23732 100644
--- a/source/opt/trim_capabilities_pass.h
+++ b/source/opt/trim_capabilities_pass.h
@@ -76,12 +76,14 @@
// clang-format off
spv::Capability::ComputeDerivativeGroupLinearNV,
spv::Capability::ComputeDerivativeGroupQuadsNV,
+ spv::Capability::Float16,
spv::Capability::Float64,
spv::Capability::FragmentShaderPixelInterlockEXT,
spv::Capability::FragmentShaderSampleInterlockEXT,
spv::Capability::FragmentShaderShadingRateInterlockEXT,
spv::Capability::Groups,
spv::Capability::ImageMSArray,
+ spv::Capability::Int16,
spv::Capability::Int64,
spv::Capability::Linkage,
spv::Capability::MinLod,
diff --git a/test/opt/trim_capabilities_pass_test.cpp b/test/opt/trim_capabilities_pass_test.cpp
index 14a8aa3..c90afb4 100644
--- a/test/opt/trim_capabilities_pass_test.cpp
+++ b/test/opt/trim_capabilities_pass_test.cpp
@@ -2486,6 +2486,104 @@
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}
+TEST_F(TrimCapabilitiesPassTest, Float16_RemovedWhenUnused) {
+ const std::string kTest = R"(
+ OpCapability Float16
+; CHECK-NOT: OpCapability Float16
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd;
+ )";
+ const auto result =
+ SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+ EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Float16_RemainsWhenUsed) {
+ const std::string kTest = R"(
+ OpCapability Float16
+; CHECK: OpCapability Float16
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ %void = OpTypeVoid
+ %float = OpTypeFloat 16
+ %3 = OpTypeFunction %void
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd;
+ )";
+ const auto result =
+ SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+ EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Int16_RemovedWhenUnused) {
+ const std::string kTest = R"(
+ OpCapability Int16
+; CHECK-NOT: OpCapability Int16
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd;
+ )";
+ const auto result =
+ SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+ EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, Int16_RemainsWhenUsed) {
+ const std::string kTest = R"(
+ OpCapability Int16
+; CHECK: OpCapability Int16
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ %void = OpTypeVoid
+ %int = OpTypeInt 16 1
+ %3 = OpTypeFunction %void
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd;
+ )";
+ const auto result =
+ SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+ EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
+TEST_F(TrimCapabilitiesPassTest, UInt16_RemainsWhenUsed) {
+ const std::string kTest = R"(
+ OpCapability Int16
+; CHECK: OpCapability Int16
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %1 "main"
+ %void = OpTypeVoid
+ %uint = OpTypeInt 16 0
+ %3 = OpTypeFunction %void
+ %1 = OpFunction %void None %3
+ %6 = OpLabel
+ OpReturn
+ OpFunctionEnd;
+ )";
+ const auto result =
+ SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
+ EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools