spirv-opt: Fix OpCompositeExtract relaxation with struct operands (#5536)
diff --git a/source/opt/convert_to_half_pass.cpp b/source/opt/convert_to_half_pass.cpp
index cb0065d..e243bed 100644
--- a/source/opt/convert_to_half_pass.cpp
+++ b/source/opt/convert_to_half_pass.cpp
@@ -171,6 +171,19 @@
bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
bool modified = false;
+ // If this is a OpCompositeExtract instruction and has a struct operand, we
+ // should not relax this instruction. Doing so could cause a mismatch between
+ // the result type and the struct member type.
+ bool hasStructOperand = false;
+ if (inst->opcode() == spv::Op::OpCompositeExtract) {
+ inst->ForEachInId([&hasStructOperand, this](uint32_t* idp) {
+ Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
+ if (IsStruct(op_inst)) hasStructOperand = true;
+ });
+ if (hasStructOperand) {
+ return false;
+ }
+ }
// Convert all float32 based operands to float16 equivalent and change
// instruction type to float16 equivalent.
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
@@ -303,12 +316,19 @@
if (closure_ops_.count(inst->opcode()) == 0) return false;
// Can relax if all float operands are relaxed
bool relax = true;
- inst->ForEachInId([&relax, this](uint32_t* idp) {
+ bool hasStructOperand = false;
+ inst->ForEachInId([&relax, &hasStructOperand, this](uint32_t* idp) {
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
- if (IsStruct(op_inst)) relax = false;
+ if (IsStruct(op_inst)) hasStructOperand = true;
if (!IsFloat(op_inst, 32)) return;
if (!IsRelaxed(*idp)) relax = false;
});
+ // If the instruction has a struct operand, we should not relax it, even if
+ // all its uses are relaxed. Doing so could cause a mismatch between the
+ // result type and the struct member type.
+ if (hasStructOperand) {
+ return false;
+ }
if (relax) {
AddRelaxed(inst->result_id());
return true;
diff --git a/test/opt/convert_relaxed_to_half_test.cpp b/test/opt/convert_relaxed_to_half_test.cpp
index 62b9ae4..c577404 100644
--- a/test/opt/convert_relaxed_to_half_test.cpp
+++ b/test/opt/convert_relaxed_to_half_test.cpp
@@ -1713,6 +1713,75 @@
SinglePassRunAndMatch<ConvertToHalfPass>(test, true);
}
+TEST_F(ConvertToHalfTest, DontRelaxDecoratedOpCompositeExtract) {
+ // This test checks that a OpCompositeExtract with a Struct operand won't be
+ // relaxed, even if it is explicitly decorated with RelaxedPrecision.
+ const std::string test =
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %9 RelaxedPrecision
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string expected =
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<ConvertToHalfPass>(test, expected, true);
+}
+
+TEST_F(ConvertToHalfTest, DontRelaxOpCompositeExtract) {
+ // This test checks that a OpCompositeExtract with a Struct operand won't be
+ // relaxed, even if its result has no uses.
+ const std::string test =
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<ConvertToHalfPass>(test, test, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools