spirv-fuzz: Take care of OpPhi instructions when inlining (#3939)

Fixes #3938.
diff --git a/source/fuzz/transformation_inline_function.cpp b/source/fuzz/transformation_inline_function.cpp
index b2a6f09..f58b123 100644
--- a/source/fuzz/transformation_inline_function.cpp
+++ b/source/fuzz/transformation_inline_function.cpp
@@ -158,12 +158,28 @@
     AdaptInlinedInstruction(result_id_map, ir_context, inlined_instruction);
   }
 
+  // If the function call's successor block contains OpPhi instructions that
+  // refer to the block containing the call then these will need to be rewritten
+  // to instead refer to the block associated with "returning" from the inlined
+  // function, as this block will be the predecessor of what used to be the
+  // function call's successor block.  We look out for this block.
+  uint32_t new_return_block_id = 0;
+
   // Inline the |called_function| non-entry blocks.
   for (auto& block : *called_function) {
     if (&block == &*called_function->entry()) {
       continue;
     }
 
+    // Check whether this is the function's return block.  Take note if it is,
+    // so that OpPhi instructions in the successor of the original function call
+    // block can be re-written.
+    if (block.terminator()->IsReturn()) {
+      assert(new_return_block_id == 0 &&
+             "There should be only one return block.");
+      new_return_block_id = result_id_map.at(block.id());
+    }
+
     auto* cloned_block = block.Clone(ir_context);
     cloned_block = caller_function->InsertBasicBlockBefore(
         std::unique_ptr<opt::BasicBlock>(cloned_block), successor_block);
@@ -176,10 +192,31 @@
     }
   }
 
+  opt::BasicBlock* block_containing_function_call =
+      ir_context->get_instr_block(function_call_instruction);
+
+  assert(((new_return_block_id == 0) ==
+          called_function->entry()->terminator()->IsReturn()) &&
+         "We should have found a return block unless the function being "
+         "inlined returns in its first block.");
+  if (new_return_block_id != 0) {
+    // Rewrite any OpPhi instructions in the successor block so that they refer
+    // to the new return block instead of the block that originally contained
+    // the function call.
+    ir_context->get_def_use_mgr()->ForEachUse(
+        block_containing_function_call->id(),
+        [ir_context, new_return_block_id, successor_block](
+            opt::Instruction* use_instruction, uint32_t operand_index) {
+          if (use_instruction->opcode() == SpvOpPhi &&
+              ir_context->get_instr_block(use_instruction) == successor_block) {
+            use_instruction->SetOperand(operand_index, {new_return_block_id});
+          }
+        });
+  }
+
   // Removes the function call instruction and its block termination instruction
   // from |caller_function|.
-  ir_context->KillInst(
-      ir_context->get_instr_block(function_call_instruction)->terminator());
+  ir_context->KillInst(block_containing_function_call->terminator());
   ir_context->KillInst(function_call_instruction);
 
   // Since the SPIR-V module has changed, no analyses must be validated.
diff --git a/test/fuzz/transformation_inline_function_test.cpp b/test/fuzz/transformation_inline_function_test.cpp
index 092043d..4cd465f 100644
--- a/test/fuzz/transformation_inline_function_test.cpp
+++ b/test/fuzz/transformation_inline_function_test.cpp
@@ -1021,6 +1021,96 @@
   ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
 }
 
+TEST(TransformationInlineFunctionTest, OpPhiInBlockFollowingCall) {
+  // This test checks that if the block after the inlined function call has an
+  // OpPhi instruction and the called function contains multiple blocks then the
+  // OpPhi instruction gets updated to refer to the return block of the inlined
+  // function, since the block containing the call will no longer be a
+  // predecessor.
+
+  std::string reference_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %13 = OpTypeBool
+         %14 = OpConstantTrue %13
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+          %8 = OpFunctionCall %2 %6
+               OpBranch %11
+         %11 = OpLabel
+         %12 = OpPhi %13 %14 %10
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+               OpBranch %20
+         %20 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context =
+      BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+  spvtools::ValidatorOptions validator_options;
+  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+                                               kConsoleMessageConsumer));
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+  auto transformation = TransformationInlineFunction(8, {{7, 100}, {20, 101}});
+
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+
+  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+                                               kConsoleMessageConsumer));
+
+  std::string variant_shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+         %13 = OpTypeBool
+         %14 = OpConstantTrue %13
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+               OpBranch %101
+        %101 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+         %12 = OpPhi %13 %14 %101
+               OpReturn
+               OpFunctionEnd
+          %6 = OpFunction %2 None %3
+          %7 = OpLabel
+               OpBranch %20
+         %20 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+                                               kConsoleMessageConsumer));
+  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools