spirv-fuzz: Refactoring and type-related fixes (#3144)

This change refactors some code for walking access chain indexes to
make it mirror the structure of other code (to improve readability in
the first instance and potentially enable a future refactoring to
extract common code), and fixes a problem related to module donation
and function types.
diff --git a/source/fuzz/fuzzer_pass_donate_modules.cpp b/source/fuzz/fuzzer_pass_donate_modules.cpp
index 33813d2..0587a50 100644
--- a/source/fuzz/fuzzer_pass_donate_modules.cpp
+++ b/source/fuzz/fuzzer_pass_donate_modules.cpp
@@ -299,43 +299,41 @@
       } break;
       case SpvOpTypeFunction: {
         // It is not OK to have multiple function types that use identical ids
-        // for their return and parameter types.  We thus first look for a
-        // matching function type in the recipient module and use the id of this
-        // type if a match is found.  Otherwise we add a remapped version of the
-        // function type.
+        // for their return and parameter types.  We thus go through all
+        // existing function types to look for a match.  We do not use the
+        // type manager here because we want to regard two function types that
+        // are structurally identical but that differ with respect to the
+        // actual ids used for pointer types as different.
+        //
+        // Example:
+        //
+        // %1 = OpTypeVoid
+        // %2 = OpTypeInt 32 0
+        // %3 = OpTypePointer Function %2
+        // %4 = OpTypePointer Function %2
+        // %5 = OpTypeFunction %1 %3
+        // %6 = OpTypeFunction %1 %4
+        //
+        // We regard %5 and %6 as distinct function types here, even though
+        // they both have the form "uint32* -> void"
 
-        // Build a sequence of types used as parameters for the function type.
-        std::vector<const opt::analysis::Type*> parameter_types;
-        // We start iterating at 1 because 0 is the function's return type.
-        for (uint32_t index = 1; index < type_or_value.NumInOperands();
-             index++) {
-          parameter_types.push_back(GetIRContext()->get_type_mgr()->GetType(
-              original_id_to_donated_id->at(
-                  type_or_value.GetSingleWordInOperand(index))));
+        std::vector<uint32_t> return_and_parameter_types;
+        for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
+          return_and_parameter_types.push_back(original_id_to_donated_id->at(
+              type_or_value.GetSingleWordInOperand(i)));
         }
-        // Make a type object corresponding to the function type.
-        opt::analysis::Function function_type(
-            GetIRContext()->get_type_mgr()->GetType(
-                original_id_to_donated_id->at(
-                    type_or_value.GetSingleWordInOperand(0))),
-            parameter_types);
-
-        // Check whether a function type corresponding to this this type object
-        // is already declared by the module.
-        auto function_type_id =
-            GetIRContext()->get_type_mgr()->GetId(&function_type);
-        if (function_type_id) {
-          // A suitable existing function was found - use its id.
-          new_result_id = function_type_id;
+        uint32_t existing_function_id = fuzzerutil::FindFunctionType(
+            GetIRContext(), return_and_parameter_types);
+        if (existing_function_id) {
+          new_result_id = existing_function_id;
         } else {
           // No match was found, so add a remapped version of the function type
           // to the module, with a fresh id.
           new_result_id = GetFuzzerContext()->GetFreshId();
           std::vector<uint32_t> argument_type_ids;
-          for (uint32_t index = 1; index < type_or_value.NumInOperands();
-               index++) {
+          for (uint32_t i = 1; i < type_or_value.NumInOperands(); i++) {
             argument_type_ids.push_back(original_id_to_donated_id->at(
-                type_or_value.GetSingleWordInOperand(index)));
+                type_or_value.GetSingleWordInOperand(i)));
           }
           ApplyTransformation(TransformationAddTypeFunction(
               new_result_id,
diff --git a/source/fuzz/fuzzer_util.cpp b/source/fuzz/fuzzer_util.cpp
index 82d761c..085246e 100644
--- a/source/fuzz/fuzzer_util.cpp
+++ b/source/fuzz/fuzzer_util.cpp
@@ -258,33 +258,36 @@
     auto should_be_composite_type =
         context->get_def_use_mgr()->GetDef(sub_object_type_id);
     assert(should_be_composite_type && "The type should exist.");
-    if (SpvOpTypeArray == should_be_composite_type->opcode()) {
-      auto array_length = GetArraySize(*should_be_composite_type, context);
-      if (array_length == 0 || index >= array_length) {
-        return 0;
+    switch (should_be_composite_type->opcode()) {
+      case SpvOpTypeArray: {
+        auto array_length = GetArraySize(*should_be_composite_type, context);
+        if (array_length == 0 || index >= array_length) {
+          return 0;
+        }
+        sub_object_type_id =
+            should_be_composite_type->GetSingleWordInOperand(0);
+        break;
       }
-      sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
-    } else if (SpvOpTypeMatrix == should_be_composite_type->opcode()) {
-      auto matrix_column_count =
-          should_be_composite_type->GetSingleWordInOperand(1);
-      if (index >= matrix_column_count) {
-        return 0;
+      case SpvOpTypeMatrix:
+      case SpvOpTypeVector: {
+        auto count = should_be_composite_type->GetSingleWordInOperand(1);
+        if (index >= count) {
+          return 0;
+        }
+        sub_object_type_id =
+            should_be_composite_type->GetSingleWordInOperand(0);
+        break;
       }
-      sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
-    } else if (SpvOpTypeStruct == should_be_composite_type->opcode()) {
-      if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
-        return 0;
+      case SpvOpTypeStruct: {
+        if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
+          return 0;
+        }
+        sub_object_type_id =
+            should_be_composite_type->GetSingleWordInOperand(index);
+        break;
       }
-      sub_object_type_id =
-          should_be_composite_type->GetSingleWordInOperand(index);
-    } else if (SpvOpTypeVector == should_be_composite_type->opcode()) {
-      auto vector_length = should_be_composite_type->GetSingleWordInOperand(1);
-      if (index >= vector_length) {
+      default:
         return 0;
-      }
-      sub_object_type_id = should_be_composite_type->GetSingleWordInOperand(0);
-    } else {
-      return 0;
     }
   }
   return sub_object_type_id;
@@ -347,6 +350,35 @@
   return result;
 }
 
+uint32_t FindFunctionType(opt::IRContext* ir_context,
+                          const std::vector<uint32_t>& type_ids) {
+  // |type_ids| is compared position by position against the input operands
+  // of each OpTypeFunction in the module.  The raw ids are compared, so two
+  // structurally identical types with different ids do not match.
+  for (auto& candidate : ir_context->types_values()) {
+    // Skip anything that is not a function type, or that is one with a
+    // different number of operands.
+    if (candidate.opcode() != SpvOpTypeFunction ||
+        candidate.NumInOperands() != type_ids.size()) {
+      continue;
+    }
+    // Advance past the leading operands that agree with |type_ids|.
+    uint32_t operand = 0;
+    while (operand < candidate.NumInOperands() &&
+           type_ids[operand] == candidate.GetSingleWordInOperand(operand)) {
+      operand++;
+    }
+    // Reaching the end means no operand disagreed.
+    if (operand == candidate.NumInOperands()) {
+      // This instruction is the function type being searched for.
+      return candidate.result_id();
+    }
+  }
+  // The module declares no function type whose operands are exactly
+  // |type_ids|.
+  return 0;
+}
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/fuzzer_util.h b/source/fuzz/fuzzer_util.h
index cbd81cd..f0a2953 100644
--- a/source/fuzz/fuzzer_util.h
+++ b/source/fuzz/fuzzer_util.h
@@ -131,6 +131,12 @@
 // Returns true if and only if |block_id| is a merge block or continue target
 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id);
 
+// Returns the result id of an instruction of the form:
+//  %id = OpTypeFunction |type_ids|
+// (return type id first, then parameter type ids), or 0 if none exists.
+uint32_t FindFunctionType(opt::IRContext* ir_context,
+                          const std::vector<uint32_t>& type_ids);
+
 }  // namespace fuzzerutil
 
 }  // namespace fuzz
diff --git a/source/fuzz/transformation.cpp b/source/fuzz/transformation.cpp
index c7aae58..8037af1 100644
--- a/source/fuzz/transformation.cpp
+++ b/source/fuzz/transformation.cpp
@@ -16,6 +16,7 @@
 
 #include <cassert>
 
+#include "source/fuzz/fuzzer_util.h"
 #include "source/fuzz/transformation_add_constant_boolean.h"
 #include "source/fuzz/transformation_add_constant_composite.h"
 #include "source/fuzz/transformation_add_constant_scalar.h"
@@ -159,5 +160,18 @@
   return nullptr;
 }
 
+bool Transformation::CheckIdIsFreshAndNotUsedByThisTransformation(
+    uint32_t id, opt::IRContext* context,
+    std::set<uint32_t>* ids_used_by_this_transformation) {
+  // An id that is not fresh for the module can be rejected immediately.
+  if (!fuzzerutil::IsFreshId(context, id)) {
+    return false;
+  }
+  // Inserting leaves the set untouched when |id| is already present, and
+  // the flag in the returned pair says whether a new element was added; this
+  // both records |id| and detects reuse within the transformation.
+  return ids_used_by_this_transformation->insert(id).second;
+}
+
 }  // namespace fuzz
 }  // namespace spvtools
diff --git a/source/fuzz/transformation.h b/source/fuzz/transformation.h
index c6b852f..dbe803f 100644
--- a/source/fuzz/transformation.h
+++ b/source/fuzz/transformation.h
@@ -83,6 +83,15 @@
   // representation of a transformation given by |message|.
   static std::unique_ptr<Transformation> FromMessage(
       const protobufs::Transformation& message);
+
+  // Helper that returns true if and only if (a) |id| is a fresh id for the
+  // module, and (b) |id| is not in |ids_used_by_this_transformation|, a set of
+  // ids already known to be in use by a transformation.  This is useful when
+  // checking id freshness for a transformation that uses many ids, all of which
+  // must be distinct.
+  static bool CheckIdIsFreshAndNotUsedByThisTransformation(
+      uint32_t id, opt::IRContext* context,
+      std::set<uint32_t>* ids_used_by_this_transformation);
 };
 
 }  // namespace fuzz
diff --git a/source/fuzz/transformation_outline_function.cpp b/source/fuzz/transformation_outline_function.cpp
index b50b9c5..1b308c4 100644
--- a/source/fuzz/transformation_outline_function.cpp
+++ b/source/fuzz/transformation_outline_function.cpp
@@ -368,20 +368,6 @@
   return result;
 }
 
-bool TransformationOutlineFunction::
-    CheckIdIsFreshAndNotUsedByThisTransformation(
-        uint32_t id, opt::IRContext* context,
-        std::set<uint32_t>* ids_used_by_this_transformation) const {
-  if (!fuzzerutil::IsFreshId(context, id)) {
-    return false;
-  }
-  if (ids_used_by_this_transformation->count(id) != 0) {
-    return false;
-  }
-  ids_used_by_this_transformation->insert(id);
-  return true;
-}
-
 std::vector<uint32_t> TransformationOutlineFunction::GetRegionInputIds(
     opt::IRContext* context, const std::set<opt::BasicBlock*>& region_set,
     opt::BasicBlock* region_exit_block) {
@@ -540,15 +526,16 @@
   // not exist there cannot already be a function type with this struct as its
   // return type.
   if (region_output_ids.empty()) {
+    std::vector<uint32_t> return_and_parameter_types;
     opt::analysis::Void void_type;
     return_type_id = context->get_type_mgr()->GetId(&void_type);
-    std::vector<const opt::analysis::Type*> argument_types;
+    return_and_parameter_types.push_back(return_type_id);
     for (auto id : region_input_ids) {
-      argument_types.push_back(context->get_type_mgr()->GetType(
-          context->get_def_use_mgr()->GetDef(id)->type_id()));
+      return_and_parameter_types.push_back(
+          context->get_def_use_mgr()->GetDef(id)->type_id());
     }
-    opt::analysis::Function function_type(&void_type, argument_types);
-    function_type_id = context->get_type_mgr()->GetId(&function_type);
+    function_type_id =
+        fuzzerutil::FindFunctionType(context, return_and_parameter_types);
   }
 
   // If no existing function type was found, we need to create one.
diff --git a/source/fuzz/transformation_outline_function.h b/source/fuzz/transformation_outline_function.h
index b59e660..43bdf3b 100644
--- a/source/fuzz/transformation_outline_function.h
+++ b/source/fuzz/transformation_outline_function.h
@@ -128,15 +128,6 @@
       opt::BasicBlock* region_exit_block);
 
  private:
-  // A helper method for the applicability check.  Returns true if and only if
-  // |id| is (a) a fresh id for the module, and (b) an id that has not
-  // previously been subject to this check.  We use this to check whether the
-  // ids given for the transformation are not only fresh but also different from
-  // one another.
-  bool CheckIdIsFreshAndNotUsedByThisTransformation(
-      uint32_t id, opt::IRContext* context,
-      std::set<uint32_t>* ids_used_by_this_transformation) const;
-
   // Ensures that the module's id bound is at least the maximum of any fresh id
   // associated with the transformation.
   void UpdateModuleIdBoundForFreshIds(
diff --git a/test/fuzz/fuzzer_pass_donate_modules_test.cpp b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
index 988b675..7342dd4 100644
--- a/test/fuzz/fuzzer_pass_donate_modules_test.cpp
+++ b/test/fuzz/fuzzer_pass_donate_modules_test.cpp
@@ -438,6 +438,58 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, recipient_context.get()));
 }
 
+TEST(FuzzerPassDonateModulesTest, DonateFunctionTypeWithDifferentPointers) {
+  std::string recipient_and_donor_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeInt 32 0
+          %7 = OpTypePointer Function %6
+          %8 = OpTypeFunction %2 %7
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+          %9 = OpVariable %7 Function
+         %10 = OpFunctionCall %2 %11 %9
+               OpReturn
+               OpFunctionEnd
+         %11 = OpFunction %2 None %8
+         %12 = OpFunctionParameter %7
+         %13 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto recipient_context = BuildModule(
+      env, consumer, recipient_and_donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+  // The donor is a second, independent build of the very same module.
+  const auto donor_context = BuildModule(
+      env, consumer, recipient_and_donor_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, donor_context.get()));
+
+  FactManager fact_manager;
+  // The generator is a named local so it outlives every use by the context.
+  PseudoRandomGenerator prng(0);
+  FuzzerContext fuzzer_context(&prng, 100);
+  protobufs::TransformationSequence transformation_sequence;
+
+  FuzzerPassDonateModules fuzzer_pass(recipient_context.get(), &fact_manager,
+                                      &fuzzer_context, &transformation_sequence,
+                                      {});
+
+  fuzzer_pass.DonateSingleModule(donor_context.get());
+
+  // Only validity is checked; pinning the exact resulting module is fragile.
+  ASSERT_TRUE(IsValid(env, recipient_context.get()));
+}
+
 TEST(FuzzerPassDonateModulesTest, Miscellaneous1) {
   std::string recipient_shader = R"(
                OpCapability Shader
diff --git a/test/fuzz/transformation_outline_function_test.cpp b/test/fuzz/transformation_outline_function_test.cpp
index 4f828b6..5cd1437 100644
--- a/test/fuzz/transformation_outline_function_test.cpp
+++ b/test/fuzz/transformation_outline_function_test.cpp
@@ -2040,6 +2040,91 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationOutlineFunctionTest, Miscellaneous4) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %6 "main"
+               OpExecutionMode %6 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %21 = OpTypeBool
+        %100 = OpTypeInt 32 0
+        %101 = OpTypePointer Function %100
+        %102 = OpTypePointer Function %100
+        %103 = OpTypeFunction %2 %101
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+        %104 = OpVariable %102 Function
+               OpBranch %80
+         %80 = OpLabel
+        %105 = OpLoad %100 %104
+               OpBranch %106
+        %106 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  FactManager fact_manager;
+  // Outline blocks %80 to %106; %104 (pointer type %102) is the only input.
+  TransformationOutlineFunction transformation(
+      /*entry_block*/ 80,
+      /*exit_block*/ 106,
+      /*new_function_struct_return_type_id*/ 300,
+      /*new_function_type_id*/ 301,
+      /*new_function_id*/ 302,
+      /*new_function_region_entry_block*/ 304,
+      /*new_caller_result_id*/ 305,
+      /*new_callee_result_id*/ 306,
+      /*input_id_to_fresh_id*/ {{104, 307}},
+      /*output_id_to_fresh_id*/ {});
+
+  ASSERT_TRUE(transformation.IsApplicable(context.get(), fact_manager));
+  transformation.Apply(context.get(), &fact_manager);
+  ASSERT_TRUE(IsValid(env, context.get()));
+  // %103 takes %101, not %102, so a fresh function type %301 is expected.
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %6 "main"
+               OpExecutionMode %6 OriginUpperLeft
+               OpSource ESSL 310
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %21 = OpTypeBool
+        %100 = OpTypeInt 32 0
+        %101 = OpTypePointer Function %100
+        %102 = OpTypePointer Function %100
+        %103 = OpTypeFunction %2 %101
+        %301 = OpTypeFunction %2 %102
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+        %104 = OpVariable %102 Function
+               OpBranch %80
+         %80 = OpLabel
+        %305 = OpFunctionCall %2 %302 %104
+               OpReturn
+               OpFunctionEnd
+        %302 = OpFunction %2 None %301
+        %307 = OpFunctionParameter %102
+        %304 = OpLabel
+        %105 = OpLoad %100 %307
+               OpBranch %106
+        %106 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools