spirv-fuzz: Use overflow ids when inlining functions (#3880)

Fixes #3751.
diff --git a/source/fuzz/transformation_inline_function.cpp b/source/fuzz/transformation_inline_function.cpp
index 31f6fb3..cdcee1b 100644
--- a/source/fuzz/transformation_inline_function.cpp
+++ b/source/fuzz/transformation_inline_function.cpp
@@ -33,7 +33,8 @@
 }
 
 bool TransformationInlineFunction::IsApplicable(
-    opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
+    opt::IRContext* ir_context,
+    const TransformationContext& transformation_context) const {
   // The values in the |message_.result_id_map| must be all fresh and all
   // distinct.
   const auto result_id_map =
@@ -71,7 +72,8 @@
     // Since the entry block label will not be inlined, only the remaining
     // labels must have a corresponding value in the map.
     if (&block != &*called_function->entry() &&
-        !result_id_map.count(block.GetLabel()->result_id())) {
+        !result_id_map.count(block.id()) &&
+        !transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
       return false;
     }
 
@@ -81,7 +83,8 @@
       // If |instruction| has result id, then it must have a mapped id in
       // |result_id_map|.
       if (instruction.HasResultId() &&
-          !result_id_map.count(instruction.result_id())) {
+          !result_id_map.count(instruction.result_id()) &&
+          !transformation_context.GetOverflowIdSource()->HasOverflowIds()) {
         return false;
       }
     }
@@ -100,15 +103,37 @@
 }
 
 void TransformationInlineFunction::Apply(
-    opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
+    opt::IRContext* ir_context,
+    TransformationContext* transformation_context) const {
   auto* function_call_instruction =
       ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
   auto* caller_function =
       ir_context->get_instr_block(function_call_instruction)->GetParent();
   auto* called_function = fuzzerutil::FindFunction(
       ir_context, function_call_instruction->GetSingleWordInOperand(0));
-  const auto result_id_map =
+  std::map<uint32_t, uint32_t> result_id_map =
       fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
+
+  // If there are gaps in the result id map, fill them using overflow ids.
+  for (auto& block : *called_function) {
+    if (&block != &*called_function->entry() &&
+        !result_id_map.count(block.id())) {
+      result_id_map.insert(
+          {block.id(),
+           transformation_context->GetOverflowIdSource()->GetNextOverflowId()});
+    }
+    for (auto& instruction : block) {
+      // If |instruction| has result id, then it must have a mapped id in
+      // |result_id_map|.
+      if (instruction.HasResultId() &&
+          !result_id_map.count(instruction.result_id())) {
+        result_id_map.insert({instruction.result_id(),
+                              transformation_context->GetOverflowIdSource()
+                                  ->GetNextOverflowId()});
+      }
+    }
+  }
+
   auto* successor_block = ir_context->cfg()->block(
       ir_context->get_instr_block(function_call_instruction)
           ->terminator()
@@ -128,7 +153,7 @@
           MakeUnique<opt::Instruction>(entry_block_instruction));
     }
 
-    AdaptInlinedInstruction(ir_context, inlined_instruction);
+    AdaptInlinedInstruction(result_id_map, ir_context, inlined_instruction);
   }
 
   // Inline the |called_function| non-entry blocks.
@@ -141,13 +166,11 @@
     cloned_block = caller_function->InsertBasicBlockBefore(
         std::unique_ptr<opt::BasicBlock>(cloned_block), successor_block);
     cloned_block->SetParent(caller_function);
-    cloned_block->GetLabel()->SetResultId(
-        result_id_map.at(cloned_block->GetLabel()->result_id()));
-    fuzzerutil::UpdateModuleIdBound(ir_context,
-                                    cloned_block->GetLabel()->result_id());
+    cloned_block->GetLabel()->SetResultId(result_id_map.at(cloned_block->id()));
+    fuzzerutil::UpdateModuleIdBound(ir_context, cloned_block->id());
 
     for (auto& inlined_instruction : *cloned_block) {
-      AdaptInlinedInstruction(ir_context, &inlined_instruction);
+      AdaptInlinedInstruction(result_id_map, ir_context, &inlined_instruction);
     }
   }
 
@@ -202,14 +225,13 @@
 }
 
 void TransformationInlineFunction::AdaptInlinedInstruction(
+    const std::map<uint32_t, uint32_t>& result_id_map,
     opt::IRContext* ir_context,
     opt::Instruction* instruction_to_be_inlined) const {
   auto* function_call_instruction =
       ir_context->get_def_use_mgr()->GetDef(message_.function_call_id());
   auto* called_function = fuzzerutil::FindFunction(
       ir_context, function_call_instruction->GetSingleWordInOperand(0));
-  const auto result_id_map =
-      fuzzerutil::RepeatedUInt32PairToMap(message_.result_id_map());
 
   const auto* function_call_block =
       ir_context->get_instr_block(function_call_instruction);
diff --git a/source/fuzz/transformation_inline_function.h b/source/fuzz/transformation_inline_function.h
index 272024a..8105d92 100644
--- a/source/fuzz/transformation_inline_function.h
+++ b/source/fuzz/transformation_inline_function.h
@@ -33,7 +33,7 @@
       const std::map<uint32_t, uint32_t>& result_id_map);
 
   // - |message_.result_id_map| must map the instructions of the called function
-  //   to fresh ids.
+  //   to fresh ids, unless overflow ids are available.
   // - |message_.function_call_id| must be an OpFunctionCall instruction.
   //   It must not have an early return and must not use OpUnreachable or
   //   OpKill. This is to guard against making the module invalid when the
@@ -67,8 +67,9 @@
 
   // Inline |instruction_to_be_inlined| by setting its ids to the corresponding
   // ids in |result_id_map|.
-  void AdaptInlinedInstruction(opt::IRContext* ir_context,
-                               opt::Instruction* instruction) const;
+  void AdaptInlinedInstruction(
+      const std::map<uint32_t, uint32_t>& result_id_map,
+      opt::IRContext* ir_context, opt::Instruction* instruction) const;
 };
 
 }  // namespace fuzz
diff --git a/test/fuzz/transformation_inline_function_test.cpp b/test/fuzz/transformation_inline_function_test.cpp
index 89887be..09cf936 100644
--- a/test/fuzz/transformation_inline_function_test.cpp
+++ b/test/fuzz/transformation_inline_function_test.cpp
@@ -14,6 +14,7 @@
 
 #include "source/fuzz/transformation_inline_function.h"
 
+#include "source/fuzz/counter_overflow_id_source.h"
 #include "source/fuzz/instruction_descriptor.h"
 #include "test/fuzz/fuzz_test_util.h"
 
@@ -533,6 +534,7 @@
   ASSERT_FALSE(
       transformation.IsApplicable(context.get(), transformation_context));
 
+#ifndef NDEBUG
   // Tests the id of the returned value not included in the id map.
   transformation = TransformationInlineFunction(25, {{56, 69},
                                                      {57, 70},
@@ -544,8 +546,10 @@
                                                      {64, 76},
                                                      {65, 77},
                                                      {66, 78}});
-  ASSERT_FALSE(
-      transformation.IsApplicable(context.get(), transformation_context));
+  ASSERT_DEATH(
+      transformation.IsApplicable(context.get(), transformation_context),
+      "Bad attempt to query whether overflow ids are available.");
+#endif
 
   transformation = TransformationInlineFunction(25, {{57, 69},
                                                      {58, 70},
@@ -819,6 +823,198 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationInlineFunctionTest, OverflowIds) {
+  std::string reference_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %39 "main"
+
+; Types
+          %2 = OpTypeFloat 32
+          %3 = OpTypeVector %2 4
+          %4 = OpTypePointer Function %3
+          %5 = OpTypeVoid
+          %6 = OpTypeFunction %5
+          %7 = OpTypeFunction %2 %4 %4
+
+; Constant scalars
+          %8 = OpConstant %2 1
+          %9 = OpConstant %2 2
+         %10 = OpConstant %2 3
+         %11 = OpConstant %2 4
+         %12 = OpConstant %2 5
+         %13 = OpConstant %2 6
+         %14 = OpConstant %2 7
+         %15 = OpConstant %2 8
+
+; Constant vectors
+         %16 = OpConstantComposite %3 %8 %9 %10 %11
+         %17 = OpConstantComposite %3 %12 %13 %14 %15
+
+; dot product function
+         %18 = OpFunction %2 None %7
+         %19 = OpFunctionParameter %4
+         %20 = OpFunctionParameter %4
+         %21 = OpLabel
+         %22 = OpLoad %3 %19
+         %23 = OpLoad %3 %20
+         %24 = OpCompositeExtract %2 %22 0
+         %25 = OpCompositeExtract %2 %23 0
+         %26 = OpFMul %2 %24 %25
+         %27 = OpCompositeExtract %2 %22 1
+         %28 = OpCompositeExtract %2 %23 1
+         %29 = OpFMul %2 %27 %28
+               OpBranch %100
+        %100 = OpLabel
+         %30 = OpCompositeExtract %2 %22 2
+         %31 = OpCompositeExtract %2 %23 2
+         %32 = OpFMul %2 %30 %31
+         %33 = OpCompositeExtract %2 %22 3
+         %34 = OpCompositeExtract %2 %23 3
+         %35 = OpFMul %2 %33 %34
+         %36 = OpFAdd %2 %26 %29
+         %37 = OpFAdd %2 %32 %36
+         %38 = OpFAdd %2 %35 %37
+               OpReturnValue %38
+               OpFunctionEnd
+
+; main function
+         %39 = OpFunction %5 None %6
+         %40 = OpLabel
+         %41 = OpVariable %4 Function
+         %42 = OpVariable %4 Function
+               OpStore %41 %16
+               OpStore %42 %17
+         %43 = OpFunctionCall %2 %18 %41 %42 ; dot product function call
+               OpBranch %44
+         %44 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context =
+      BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  auto overflow_ids_unique_ptr = MakeUnique<CounterOverflowIdSource>(1000);
+  auto overflow_ids_ptr = overflow_ids_unique_ptr.get();
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options,
+      std::move(overflow_ids_unique_ptr));
+  auto transformation = TransformationInlineFunction(43, {{22, 45},
+                                                          {23, 46},
+                                                          {24, 47},
+                                                          {25, 48},
+                                                          {26, 49},
+                                                          {27, 50},
+                                                          {28, 51},
+                                                          {29, 52}});
+
+  // The following ids are left un-mapped; overflow ids will be required for
+  // them: 30, 31, 32, 33, 34, 35, 36, 37, 38, 100
+
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context,
+                        overflow_ids_ptr->GetIssuedOverflowIds());
+
+  std::string variant_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %39 "main"
+
+; Types
+          %2 = OpTypeFloat 32
+          %3 = OpTypeVector %2 4
+          %4 = OpTypePointer Function %3
+          %5 = OpTypeVoid
+          %6 = OpTypeFunction %5
+          %7 = OpTypeFunction %2 %4 %4
+
+; Constant scalars
+          %8 = OpConstant %2 1
+          %9 = OpConstant %2 2
+         %10 = OpConstant %2 3
+         %11 = OpConstant %2 4
+         %12 = OpConstant %2 5
+         %13 = OpConstant %2 6
+         %14 = OpConstant %2 7
+         %15 = OpConstant %2 8
+
+; Constant vectors
+         %16 = OpConstantComposite %3 %8 %9 %10 %11
+         %17 = OpConstantComposite %3 %12 %13 %14 %15
+
+; dot product function
+         %18 = OpFunction %2 None %7
+         %19 = OpFunctionParameter %4
+         %20 = OpFunctionParameter %4
+         %21 = OpLabel
+         %22 = OpLoad %3 %19
+         %23 = OpLoad %3 %20
+         %24 = OpCompositeExtract %2 %22 0
+         %25 = OpCompositeExtract %2 %23 0
+         %26 = OpFMul %2 %24 %25
+         %27 = OpCompositeExtract %2 %22 1
+         %28 = OpCompositeExtract %2 %23 1
+         %29 = OpFMul %2 %27 %28
+               OpBranch %100
+        %100 = OpLabel
+         %30 = OpCompositeExtract %2 %22 2
+         %31 = OpCompositeExtract %2 %23 2
+         %32 = OpFMul %2 %30 %31
+         %33 = OpCompositeExtract %2 %22 3
+         %34 = OpCompositeExtract %2 %23 3
+         %35 = OpFMul %2 %33 %34
+         %36 = OpFAdd %2 %26 %29
+         %37 = OpFAdd %2 %32 %36
+         %38 = OpFAdd %2 %35 %37
+               OpReturnValue %38
+               OpFunctionEnd
+
+; main function
+         %39 = OpFunction %5 None %6
+         %40 = OpLabel
+         %41 = OpVariable %4 Function
+         %42 = OpVariable %4 Function
+               OpStore %41 %16
+               OpStore %42 %17
+         %45 = OpLoad %3 %41
+         %46 = OpLoad %3 %42
+         %47 = OpCompositeExtract %2 %45 0
+         %48 = OpCompositeExtract %2 %46 0
+         %49 = OpFMul %2 %47 %48
+         %50 = OpCompositeExtract %2 %45 1
+         %51 = OpCompositeExtract %2 %46 1
+         %52 = OpFMul %2 %50 %51
+               OpBranch %1000
+       %1000 = OpLabel
+       %1001 = OpCompositeExtract %2 %45 2
+       %1002 = OpCompositeExtract %2 %46 2
+       %1003 = OpFMul %2 %1001 %1002
+       %1004 = OpCompositeExtract %2 %45 3
+       %1005 = OpCompositeExtract %2 %46 3
+       %1006 = OpFMul %2 %1004 %1005
+       %1007 = OpFAdd %2 %49 %52
+       %1008 = OpFAdd %2 %1003 %1007
+       %1009 = OpFAdd %2 %1006 %1008
+         %43 = OpCopyObject %2 %1009
+               OpBranch %44
+         %44 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(IsValid(env, context.get()));
+  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools