spirv-opt: Fix OpCompositeExtract relaxation with struct operands (#5536)

diff --git a/source/opt/convert_to_half_pass.cpp b/source/opt/convert_to_half_pass.cpp
index cb0065d..e243bed 100644
--- a/source/opt/convert_to_half_pass.cpp
+++ b/source/opt/convert_to_half_pass.cpp
@@ -171,6 +171,19 @@
 
 bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
   bool modified = false;
+  // If this is a OpCompositeExtract instruction and has a struct operand, we
+  // should not relax this instruction. Doing so could cause a mismatch between
+  // the result type and the struct member type.
+  bool hasStructOperand = false;
+  if (inst->opcode() == spv::Op::OpCompositeExtract) {
+    inst->ForEachInId([&hasStructOperand, this](uint32_t* idp) {
+      Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
+      if (IsStruct(op_inst)) hasStructOperand = true;
+    });
+    if (hasStructOperand) {
+      return false;
+    }
+  }
   // Convert all float32 based operands to float16 equivalent and change
   // instruction type to float16 equivalent.
   inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
@@ -303,12 +316,19 @@
   if (closure_ops_.count(inst->opcode()) == 0) return false;
   // Can relax if all float operands are relaxed
   bool relax = true;
-  inst->ForEachInId([&relax, this](uint32_t* idp) {
+  bool hasStructOperand = false;
+  inst->ForEachInId([&relax, &hasStructOperand, this](uint32_t* idp) {
     Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
-    if (IsStruct(op_inst)) relax = false;
+    if (IsStruct(op_inst)) hasStructOperand = true;
     if (!IsFloat(op_inst, 32)) return;
     if (!IsRelaxed(*idp)) relax = false;
   });
+  // If the instruction has a struct operand, we should not relax it, even if
+  // all its uses are relaxed. Doing so could cause a mismatch between the
+  // result type and the struct member type.
+  if (hasStructOperand) {
+    return false;
+  }
   if (relax) {
     AddRelaxed(inst->result_id());
     return true;
diff --git a/test/opt/convert_relaxed_to_half_test.cpp b/test/opt/convert_relaxed_to_half_test.cpp
index 62b9ae4..c577404 100644
--- a/test/opt/convert_relaxed_to_half_test.cpp
+++ b/test/opt/convert_relaxed_to_half_test.cpp
@@ -1713,6 +1713,75 @@
   SinglePassRunAndMatch<ConvertToHalfPass>(test, true);
 }
 
+TEST_F(ConvertToHalfTest, DontRelaxDecoratedOpCompositeExtract) {
+  // This test checks that a OpCompositeExtract with a Struct operand won't be
+  // relaxed, even if it is explicitly decorated with RelaxedPrecision.
+  const std::string test =
+      R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %9 RelaxedPrecision
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+  const std::string expected =
+      R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<ConvertToHalfPass>(test, expected, true);
+}
+
+TEST_F(ConvertToHalfTest, DontRelaxOpCompositeExtract) {
+  // This test checks that a OpCompositeExtract with a Struct operand won't be
+  // relaxed, even if its result has no uses.
+  const std::string test =
+      R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_6 = OpTypeStruct %v4float
+%7 = OpUndef %_struct_6
+%1 = OpFunction %void None %3
+%8 = OpLabel
+%9 = OpCompositeExtract %float %7 0 3
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<ConvertToHalfPass>(test, test, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools