spirv-fuzz: Take care of OpPhi instructions when inlining (#3939)
Fixes #3938.
diff --git a/source/fuzz/transformation_inline_function.cpp b/source/fuzz/transformation_inline_function.cpp
index b2a6f09..f58b123 100644
--- a/source/fuzz/transformation_inline_function.cpp
+++ b/source/fuzz/transformation_inline_function.cpp
@@ -158,12 +158,28 @@
AdaptInlinedInstruction(result_id_map, ir_context, inlined_instruction);
}
+ // If the function call's successor block contains OpPhi instructions that
+ // refer to the block containing the call then these will need to be rewritten
+ // to instead refer to the block associated with "returning" from the inlined
+ // function, as this block will be the predecessor of what used to be the
+ // function call's successor block. We look out for this block.
+ uint32_t new_return_block_id = 0;
+
// Inline the |called_function| non-entry blocks.
for (auto& block : *called_function) {
if (&block == &*called_function->entry()) {
continue;
}
+ // Check whether this is the function's return block. Take note if it is,
+ // so that OpPhi instructions in the successor of the original function call
+ // block can be re-written.
+ if (block.terminator()->IsReturn()) {
+ assert(new_return_block_id == 0 &&
+ "There should be only one return block.");
+ new_return_block_id = result_id_map.at(block.id());
+ }
+
auto* cloned_block = block.Clone(ir_context);
cloned_block = caller_function->InsertBasicBlockBefore(
std::unique_ptr<opt::BasicBlock>(cloned_block), successor_block);
@@ -176,10 +192,31 @@
}
}
+ opt::BasicBlock* block_containing_function_call =
+ ir_context->get_instr_block(function_call_instruction);
+
+ assert(((new_return_block_id == 0) ==
+ called_function->entry()->terminator()->IsReturn()) &&
+ "We should have found a return block unless the function being "
+ "inlined returns in its first block.");
+ if (new_return_block_id != 0) {
+ // Rewrite any OpPhi instructions in the successor block so that they refer
+ // to the new return block instead of the block that originally contained
+ // the function call.
+ ir_context->get_def_use_mgr()->ForEachUse(
+ block_containing_function_call->id(),
+ [ir_context, new_return_block_id, successor_block](
+ opt::Instruction* use_instruction, uint32_t operand_index) {
+ if (use_instruction->opcode() == SpvOpPhi &&
+ ir_context->get_instr_block(use_instruction) == successor_block) {
+ use_instruction->SetOperand(operand_index, {new_return_block_id});
+ }
+ });
+ }
+
// Removes the function call instruction and its block termination instruction
// from |caller_function|.
- ir_context->KillInst(
- ir_context->get_instr_block(function_call_instruction)->terminator());
+ ir_context->KillInst(block_containing_function_call->terminator());
ir_context->KillInst(function_call_instruction);
// Since the SPIR-V module has changed, no analyses must be validated.
diff --git a/test/fuzz/transformation_inline_function_test.cpp b/test/fuzz/transformation_inline_function_test.cpp
index 092043d..4cd465f 100644
--- a/test/fuzz/transformation_inline_function_test.cpp
+++ b/test/fuzz/transformation_inline_function_test.cpp
@@ -1021,6 +1021,96 @@
ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
+TEST(TransformationInlineFunctionTest, OpPhiInBlockFollowingCall) {
+ // This test checks that if the block after the inlined function call has an
+ // OpPhi instruction and the called function contains multiple blocks then the
+ // OpPhi instruction gets updated to refer to the return block of the inlined
+ // function, since the block containing the call will no longer be a
+ // predecessor.
+
+ std::string reference_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %13 = OpTypeBool
+ %14 = OpConstantTrue %13
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ %8 = OpFunctionCall %2 %6
+ OpBranch %11
+ %11 = OpLabel
+ %12 = OpPhi %13 %14 %10
+ OpReturn
+ OpFunctionEnd
+ %6 = OpFunction %2 None %3
+ %7 = OpLabel
+ OpBranch %20
+ %20 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_5;
+ const auto consumer = nullptr;
+ const auto context =
+ BuildModule(env, consumer, reference_shader, kFuzzAssembleOption);
+ spvtools::ValidatorOptions validator_options;
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+ kConsoleMessageConsumer));
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+ auto transformation = TransformationInlineFunction(8, {{7, 100}, {20, 101}});
+
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+
+ ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+ kConsoleMessageConsumer));
+
+ std::string variant_shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %13 = OpTypeBool
+ %14 = OpConstantTrue %13
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpBranch %101
+ %101 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %12 = OpPhi %13 %14 %101
+ OpReturn
+ OpFunctionEnd
+ %6 = OpFunction %2 None %3
+ %7 = OpLabel
+ OpBranch %20
+ %20 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+ kConsoleMessageConsumer));
+ ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
+}
+
} // namespace
} // namespace fuzz
} // namespace spvtools