spirv-fuzz: Refactor boilerplate in TransformationAddParameter (#3625)

Part of #3534. I forgot to implement this functionality in the original PR.
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index b2d5681..15d1057 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -826,6 +826,12 @@
   operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
                      parameter_type_ids.end());
 
+  // A trivial case - we change nothing.
+  if (FindFunctionType(ir_context, operand_ids) ==
+      old_function_type->result_id()) {
+    return old_function_type->result_id();
+  }
+
   if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
       FindFunctionType(ir_context, operand_ids) == 0) {
     // We can change |old_function_type| only if it's used once in the module
@@ -849,17 +855,16 @@
     // existing one or create a new one.
     auto type_id = FindOrCreateFunctionType(
         ir_context, new_function_type_result_id, operand_ids);
+    assert(type_id != old_function_type->result_id() &&
+           "We should've handled this case above");
 
-    if (type_id != old_function_type->result_id()) {
-      function->DefInst().SetInOperand(1, {type_id});
+    function->DefInst().SetInOperand(1, {type_id});
 
-      // DefUseManager hasn't been updated yet, so if the following condition is
-      // true, then |old_function_type| will have no users when this function
-      // returns. We might as well remove it.
-      if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
-        old_function_type->RemoveFromList();
-        delete old_function_type;
-      }
+    // DefUseManager hasn't been updated yet, so if the following condition is
+    // true, then |old_function_type| will have no users when this function
+    // returns. We might as well remove it.
+    if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
+      ir_context->KillInst(old_function_type);
     }
 
     return type_id;
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index 343c94e..6058022 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -316,6 +316,10 @@
 // more users, it is removed from the module. Returns the result id of the
 // OpTypeFunction instruction that is used as a type of the function with
 // |function_id|.
+//
+// CAUTION: When the old type of the function is removed from the module, its
+//          memory is deallocated. Be sure not to use any pointers to the old
+//          type when this function returns.
 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
                             uint32_t new_function_type_result_id,
                             uint32_t return_type_id,
diff --git a/source/fuzz/transformation_add_parameter.cpp b/source/fuzz/transformation_add_parameter.cpp
index cc32362..6c0ab28 100644
--- a/source/fuzz/transformation_add_parameter.cpp
+++ b/source/fuzz/transformation_add_parameter.cpp
@@ -14,8 +14,6 @@
 
 #include "source/fuzz/transformation_add_parameter.h"
 
-#include <source/spirv_constant.h>
-
 #include "source/fuzz/fuzzer_util.h"
 
 namespace spvtools {
@@ -72,13 +70,13 @@
   auto* function = fuzzerutil::FindFunction(ir_context, message_.function_id());
   assert(function && "Can't find the function");
 
-  auto parameter_type_id =
+  const auto new_parameter_type_id =
       fuzzerutil::GetTypeId(ir_context, message_.initializer_id());
-  assert(parameter_type_id != 0 && "Initializer has invalid type");
+  assert(new_parameter_type_id != 0 && "Initializer has invalid type");
 
   // Add new parameters to the function.
   function->AddParameter(MakeUnique<opt::Instruction>(
-      ir_context, SpvOpFunctionParameter, parameter_type_id,
+      ir_context, SpvOpFunctionParameter, new_parameter_type_id,
       message_.parameter_fresh_id(), opt::Instruction::OperandList()));
 
   fuzzerutil::UpdateModuleIdBound(ir_context, message_.parameter_fresh_id());
@@ -92,43 +90,30 @@
       message_.parameter_fresh_id());
 
   // Fix all OpFunctionCall instructions.
-  ir_context->get_def_use_mgr()->ForEachUser(
-      &function->DefInst(), [this](opt::Instruction* call) {
-        if (call->opcode() != SpvOpFunctionCall ||
-            call->GetSingleWordInOperand(0) != message_.function_id()) {
-          return;
-        }
+  for (auto* inst : fuzzerutil::GetCallers(ir_context, function->result_id())) {
+    inst->AddOperand({SPV_OPERAND_TYPE_ID, {message_.initializer_id()}});
+  }
 
-        call->AddOperand({SPV_OPERAND_TYPE_ID, {message_.initializer_id()}});
-      });
+  // Update function's type.
+  {
+    // We use a separate scope here since |old_function_type| might become a
+    // dangling pointer after the call to the fuzzerutil::UpdateFunctionType.
 
-  auto* old_function_type = fuzzerutil::GetFunctionType(ir_context, function);
-  assert(old_function_type && "Function must have a valid type");
+    const auto* old_function_type =
+        fuzzerutil::GetFunctionType(ir_context, function);
+    assert(old_function_type && "Function must have a valid type");
 
-  if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
-    // Adjust existing function type if it is used only by this function.
-    old_function_type->AddOperand({SPV_OPERAND_TYPE_ID, {parameter_type_id}});
-
-    // We must make sure that all dependencies of |old_function_type| are
-    // evaluated before |old_function_type| (i.e. the domination rules are not
-    // broken). Thus, we move |old_function_type| to the end of the list of all
-    // types in the module.
-    old_function_type->RemoveFromList();
-    ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
-  } else {
-    // Otherwise, either create a new type or use an existing one.
-    std::vector<uint32_t> type_ids;
-    type_ids.reserve(old_function_type->NumInOperands() + 1);
-
-    for (uint32_t i = 0, n = old_function_type->NumInOperands(); i < n; ++i) {
-      type_ids.push_back(old_function_type->GetSingleWordInOperand(i));
+    std::vector<uint32_t> parameter_type_ids;
+    for (uint32_t i = 1; i < old_function_type->NumInOperands(); ++i) {
+      parameter_type_ids.push_back(
+          old_function_type->GetSingleWordInOperand(i));
     }
 
-    type_ids.push_back(parameter_type_id);
+    parameter_type_ids.push_back(new_parameter_type_id);
 
-    function->DefInst().SetInOperand(
-        1, {fuzzerutil::FindOrCreateFunctionType(
-               ir_context, message_.function_type_fresh_id(), type_ids)});
+    fuzzerutil::UpdateFunctionType(
+        ir_context, function->result_id(), message_.function_type_fresh_id(),
+        old_function_type->GetSingleWordInOperand(0), parameter_type_ids);
   }
 
   // Make sure our changes are analyzed.
diff --git a/test/fuzz/transformation_add_parameter_test.cpp b/test/fuzz/transformation_add_parameter_test.cpp
index 6593d00..c57c738 100644
--- a/test/fuzz/transformation_add_parameter_test.cpp
+++ b/test/fuzz/transformation_add_parameter_test.cpp
@@ -32,20 +32,69 @@
           %2 = OpTypeVoid
           %7 = OpTypeBool
          %11 = OpTypeInt 32 1
+         %16 = OpTypeFloat 32
           %3 = OpTypeFunction %2
           %6 = OpTypeFunction %7 %7
           %8 = OpConstant %11 23
          %12 = OpConstantTrue %7
+         %15 = OpTypeFunction %2 %16
+         %24 = OpTypeFunction %2 %16 %7
+         %31 = OpTypeStruct %7 %11
+         %32 = OpConstant %16 23
+         %33 = OpConstantComposite %31 %12 %8
+         %41 = OpTypeStruct %11 %16
+         %42 = OpConstantComposite %41 %8 %32
+         %43 = OpTypeFunction %2 %41
+         %44 = OpTypeFunction %2 %41 %7
           %4 = OpFunction %2 None %3
           %5 = OpLabel
          %13 = OpFunctionCall %7 %9 %12
                OpReturn
                OpFunctionEnd
+
+          ; adjust type of the function in-place
           %9 = OpFunction %7 None %6
          %14 = OpFunctionParameter %7
          %10 = OpLabel
                OpReturnValue %12
                OpFunctionEnd
+
+         ; reuse an existing function type
+         %17 = OpFunction %2 None %15
+         %18 = OpFunctionParameter %16
+         %19 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %20 = OpFunction %2 None %15
+         %21 = OpFunctionParameter %16
+         %22 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %25 = OpFunction %2 None %24
+         %26 = OpFunctionParameter %16
+         %27 = OpFunctionParameter %7
+         %28 = OpLabel
+               OpReturn
+               OpFunctionEnd
+
+         ; create a new function type
+         %29 = OpFunction %2 None %3
+         %30 = OpLabel
+               OpReturn
+               OpFunctionEnd
+
+         ; don't adjust the type of the function if it creates a duplicate
+         %34 = OpFunction %2 None %43
+         %35 = OpFunctionParameter %41
+         %36 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %37 = OpFunction %2 None %44
+         %38 = OpFunctionParameter %41
+         %39 = OpFunctionParameter %7
+         %40 = OpLabel
+               OpReturn
+               OpFunctionEnd
   )";
 
   const auto env = SPV_ENV_UNIVERSAL_1_3;
@@ -59,38 +108,58 @@
                                                validator_options);
 
   // Can't modify entry point function.
-  ASSERT_FALSE(TransformationAddParameter(4, 15, 12, 16)
+  ASSERT_FALSE(TransformationAddParameter(4, 60, 12, 61)
                    .IsApplicable(context.get(), transformation_context));
 
   // There is no function with result id 29.
-  ASSERT_FALSE(TransformationAddParameter(29, 15, 8, 16)
+  ASSERT_FALSE(TransformationAddParameter(60, 60, 8, 61)
                    .IsApplicable(context.get(), transformation_context));
 
   // Parameter id is not fresh.
-  ASSERT_FALSE(TransformationAddParameter(9, 14, 8, 16)
+  ASSERT_FALSE(TransformationAddParameter(9, 14, 8, 61)
                    .IsApplicable(context.get(), transformation_context));
 
   // Function type id is not fresh.
-  ASSERT_FALSE(TransformationAddParameter(9, 15, 8, 14)
+  ASSERT_FALSE(TransformationAddParameter(9, 60, 8, 14)
                    .IsApplicable(context.get(), transformation_context));
 
   // Function type id and parameter type id are equal.
-  ASSERT_FALSE(TransformationAddParameter(9, 15, 8, 15)
+  ASSERT_FALSE(TransformationAddParameter(9, 60, 8, 60)
                    .IsApplicable(context.get(), transformation_context));
 
   // Parameter's initializer doesn't exist.
-  ASSERT_FALSE(TransformationAddParameter(9, 15, 15, 16)
+  ASSERT_FALSE(TransformationAddParameter(9, 60, 60, 61)
                    .IsApplicable(context.get(), transformation_context));
 
-  // Correct transformation.
-  TransformationAddParameter correct(9, 15, 8, 16);
-  ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
-  correct.Apply(context.get(), &transformation_context);
-
-  // The module remains valid.
-  ASSERT_TRUE(IsValid(env, context.get()));
-
-  ASSERT_TRUE(fact_manager.IdIsIrrelevant(15));
+  // Correct transformations.
+  {
+    TransformationAddParameter correct(9, 60, 8, 61);
+    ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+    correct.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+    ASSERT_TRUE(fact_manager.IdIsIrrelevant(60));
+  }
+  {
+    TransformationAddParameter correct(17, 62, 12, 63);
+    ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+    correct.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+    ASSERT_TRUE(fact_manager.IdIsIrrelevant(62));
+  }
+  {
+    TransformationAddParameter correct(29, 64, 33, 65);
+    ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+    correct.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+    ASSERT_TRUE(fact_manager.IdIsIrrelevant(64));
+  }
+  {
+    TransformationAddParameter correct(34, 66, 12, 67);
+    ASSERT_TRUE(correct.IsApplicable(context.get(), transformation_context));
+    correct.Apply(context.get(), &transformation_context);
+    ASSERT_TRUE(IsValid(env, context.get()));
+    ASSERT_TRUE(fact_manager.IdIsIrrelevant(66));
+  }
 
   std::string expected_shader = R"(
                OpCapability Shader
@@ -103,21 +172,73 @@
           %2 = OpTypeVoid
           %7 = OpTypeBool
          %11 = OpTypeInt 32 1
+         %16 = OpTypeFloat 32
           %3 = OpTypeFunction %2
           %8 = OpConstant %11 23
          %12 = OpConstantTrue %7
+         %15 = OpTypeFunction %2 %16
+         %24 = OpTypeFunction %2 %16 %7
+         %31 = OpTypeStruct %7 %11
+         %32 = OpConstant %16 23
+         %33 = OpConstantComposite %31 %12 %8
+         %41 = OpTypeStruct %11 %16
+         %42 = OpConstantComposite %41 %8 %32
+         %44 = OpTypeFunction %2 %41 %7
           %6 = OpTypeFunction %7 %7 %11
+         %65 = OpTypeFunction %2 %31
           %4 = OpFunction %2 None %3
           %5 = OpLabel
          %13 = OpFunctionCall %7 %9 %12 %8
                OpReturn
                OpFunctionEnd
+
+          ; adjust type of the function in-place
           %9 = OpFunction %7 None %6
          %14 = OpFunctionParameter %7
-         %15 = OpFunctionParameter %11
+         %60 = OpFunctionParameter %11
          %10 = OpLabel
                OpReturnValue %12
                OpFunctionEnd
+
+         ; reuse an existing function type
+         %17 = OpFunction %2 None %24
+         %18 = OpFunctionParameter %16
+         %62 = OpFunctionParameter %7
+         %19 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %20 = OpFunction %2 None %15
+         %21 = OpFunctionParameter %16
+         %22 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %25 = OpFunction %2 None %24
+         %26 = OpFunctionParameter %16
+         %27 = OpFunctionParameter %7
+         %28 = OpLabel
+               OpReturn
+               OpFunctionEnd
+
+         ; create a new function type
+         %29 = OpFunction %2 None %65
+         %64 = OpFunctionParameter %31
+         %30 = OpLabel
+               OpReturn
+               OpFunctionEnd
+
+         ; don't adjust the type of the function if it creates a duplicate
+         %34 = OpFunction %2 None %44
+         %35 = OpFunctionParameter %41
+         %66 = OpFunctionParameter %7
+         %36 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %37 = OpFunction %2 None %44
+         %38 = OpFunctionParameter %41
+         %39 = OpFunctionParameter %7
+         %40 = OpLabel
+               OpReturn
+               OpFunctionEnd
   )";
 
   ASSERT_TRUE(IsEqual(env, expected_shader, context.get()));