spirv-fuzz: Remove OpFunctionCall operands in correct order (#3630)

Fixes #3629.
diff --git a/source/fuzz/transformation_replace_params_with_struct.cpp b/source/fuzz/transformation_replace_params_with_struct.cpp
index 7e76415..8c683a3 100644
--- a/source/fuzz/transformation_replace_params_with_struct.cpp
+++ b/source/fuzz/transformation_replace_params_with_struct.cpp
@@ -101,8 +101,9 @@
     return false;
   }
 
-  auto caller_id_to_fresh_composite_id = fuzzerutil::RepeatedUInt32PairToMap(
-      message_.caller_id_to_fresh_composite_id());
+  const auto caller_id_to_fresh_composite_id =
+      fuzzerutil::RepeatedUInt32PairToMap(
+          message_.caller_id_to_fresh_composite_id());
 
   // Check that |callee_id_to_fresh_composite_id| is valid.
   for (const auto* inst :
@@ -151,24 +152,12 @@
   // Compute indices of replaced parameters. This will be used to adjust
   // OpFunctionCall instructions and create OpCompositeConstruct instructions at
   // every call site.
-  std::vector<uint32_t> indices_of_replaced_params;
-  {
-    // We want to destroy |params| after the loop because it will contain
-    // dangling pointers when we remove parameters from the function.
-    auto params = fuzzerutil::GetParameters(ir_context, function->result_id());
-    for (auto id : message_.parameter_id()) {
-      auto it = std::find_if(params.begin(), params.end(),
-                             [id](const opt::Instruction* param) {
-                               return param->result_id() == id;
-                             });
-      assert(it != params.end() && "Parameter's id is invalid");
-      indices_of_replaced_params.push_back(
-          static_cast<uint32_t>(it - params.begin()));
-    }
-  }
+  const auto indices_of_replaced_params =
+      ComputeIndicesOfReplacedParameters(ir_context);
 
-  auto caller_id_to_fresh_composite_id = fuzzerutil::RepeatedUInt32PairToMap(
-      message_.caller_id_to_fresh_composite_id());
+  const auto caller_id_to_fresh_composite_id =
+      fuzzerutil::RepeatedUInt32PairToMap(
+          message_.caller_id_to_fresh_composite_id());
 
   // Update all function calls.
   for (auto* inst : fuzzerutil::GetCallers(ir_context, function->result_id())) {
@@ -182,12 +171,13 @@
     }
 
     // Remove arguments from the function call. We do it in a separate loop
-    // and in reverse order to make sure we have removed correct operands.
-    for (auto it = indices_of_replaced_params.rbegin();
-         it != indices_of_replaced_params.rend(); ++it) {
+    // and in decreasing order to make sure we have removed correct operands.
+    for (auto index : std::set<uint32_t, std::greater<uint32_t>>(
+             indices_of_replaced_params.begin(),
+             indices_of_replaced_params.end())) {
       // +1 since the first in operand to OpFunctionCall is the result id of
       // the function.
-      inst->RemoveInOperand(*it + 1);
+      inst->RemoveInOperand(index + 1);
     }
 
     // Insert OpCompositeConstruct before the function call.
@@ -305,5 +295,30 @@
   return fuzzerutil::MaybeGetStructType(ir_context, component_type_ids);
 }
 
+std::vector<uint32_t>
+TransformationReplaceParamsWithStruct::ComputeIndicesOfReplacedParameters(
+    opt::IRContext* ir_context) const {
+  assert(!message_.parameter_id().empty() &&
+         "There must be at least one parameter to replace");
+
+  const auto* function = fuzzerutil::GetFunctionFromParameterId(
+      ir_context, message_.parameter_id(0));
+  assert(function && "|parameter_id|s are invalid");
+
+  std::vector<uint32_t> result;
+
+  auto params = fuzzerutil::GetParameters(ir_context, function->result_id());
+  for (auto id : message_.parameter_id()) {
+    auto it = std::find_if(params.begin(), params.end(),
+                           [id](const opt::Instruction* param) {
+                             return param->result_id() == id;
+                           });
+    assert(it != params.end() && "Parameter's id is invalid");
+    result.push_back(static_cast<uint32_t>(it - params.begin()));
+  }
+
+  return result;
+}
+
 }  // namespace fuzz
 }  // namespace spvtools
diff --git a/source/fuzz/transformation_replace_params_with_struct.h b/source/fuzz/transformation_replace_params_with_struct.h
index d2ce204..7e40de8 100644
--- a/source/fuzz/transformation_replace_params_with_struct.h
+++ b/source/fuzz/transformation_replace_params_with_struct.h
@@ -73,6 +73,12 @@
   // transformation (see docs on the IsApplicable method to learn more).
   uint32_t MaybeGetRequiredStructType(opt::IRContext* ir_context) const;
 
+  // Returns a vector of indices of parameters to replace. Concretely, i'th
+  // element is the index of the parameter with result id |parameter_id[i]| in
+  // its function.
+  std::vector<uint32_t> ComputeIndicesOfReplacedParameters(
+      opt::IRContext* ir_context) const;
+
   protobufs::TransformationReplaceParamsWithStruct message_;
 };
 
diff --git a/test/fuzz/transformation_replace_params_with_struct_test.cpp b/test/fuzz/transformation_replace_params_with_struct_test.cpp
index e59f6ea..a198b42 100644
--- a/test/fuzz/transformation_replace_params_with_struct_test.cpp
+++ b/test/fuzz/transformation_replace_params_with_struct_test.cpp
@@ -333,6 +333,147 @@
   ASSERT_TRUE(IsEqual(env, expected_shader, context.get()));
 }
 
+TEST(TransformationReplaceParamsWithStructTest, ParametersRemainValid) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %6 = OpTypeInt 32 1
+          %3 = OpTypeFunction %2
+          %8 = OpTypeFloat 32
+         %10 = OpTypeVector %8 2
+         %12 = OpTypeBool
+         %40 = OpTypePointer Function %12
+         %13 = OpTypeStruct %6 %8
+         %45 = OpTypeStruct %6 %8 %13
+         %47 = OpTypeStruct %45 %12 %10
+         %15 = OpTypeFunction %2 %6 %8 %10 %13 %40 %12
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %20 = OpFunction %2 None %15
+         %16 = OpFunctionParameter %6
+         %17 = OpFunctionParameter %8
+         %18 = OpFunctionParameter %10
+         %19 = OpFunctionParameter %13
+         %42 = OpFunctionParameter %40
+         %43 = OpFunctionParameter %12
+         %21 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(&fact_manager,
+                                               validator_options);
+
+  {
+    // Try to replace parameters in "increasing" order of their declaration.
+    TransformationReplaceParamsWithStruct transformation({16, 17, 19}, 70, 71,
+                                                         {{}});
+    ASSERT_TRUE(
+        transformation.IsApplicable(context.get(), transformation_context));
+    transformation.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+  }
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %6 = OpTypeInt 32 1
+          %3 = OpTypeFunction %2
+          %8 = OpTypeFloat 32
+         %10 = OpTypeVector %8 2
+         %12 = OpTypeBool
+         %40 = OpTypePointer Function %12
+         %13 = OpTypeStruct %6 %8
+         %45 = OpTypeStruct %6 %8 %13
+         %47 = OpTypeStruct %45 %12 %10
+         %15 = OpTypeFunction %2 %10 %40 %12 %45
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %20 = OpFunction %2 None %15
+         %18 = OpFunctionParameter %10
+         %42 = OpFunctionParameter %40
+         %43 = OpFunctionParameter %12
+         %71 = OpFunctionParameter %45
+         %21 = OpLabel
+         %19 = OpCompositeExtract %13 %71 2
+         %17 = OpCompositeExtract %8 %71 1
+         %16 = OpCompositeExtract %6 %71 0
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+
+  {
+    // Try to replace parameters in "decreasing" order of their declaration.
+    TransformationReplaceParamsWithStruct transformation({71, 43, 18}, 72, 73,
+                                                         {{}});
+    ASSERT_TRUE(
+        transformation.IsApplicable(context.get(), transformation_context));
+    transformation.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+  }
+
+  after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %6 = OpTypeInt 32 1
+          %3 = OpTypeFunction %2
+          %8 = OpTypeFloat 32
+         %10 = OpTypeVector %8 2
+         %12 = OpTypeBool
+         %40 = OpTypePointer Function %12
+         %13 = OpTypeStruct %6 %8
+         %45 = OpTypeStruct %6 %8 %13
+         %47 = OpTypeStruct %45 %12 %10
+         %15 = OpTypeFunction %2 %40 %47
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %20 = OpFunction %2 None %15
+         %42 = OpFunctionParameter %40
+         %73 = OpFunctionParameter %47
+         %21 = OpLabel
+         %18 = OpCompositeExtract %10 %73 2
+         %43 = OpCompositeExtract %12 %73 1
+         %71 = OpCompositeExtract %45 %73 0
+         %19 = OpCompositeExtract %13 %71 2
+         %17 = OpCompositeExtract %8 %71 1
+         %16 = OpCompositeExtract %6 %71 0
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools