Fix merge return in the face of breaks (#2466)
Fixes #2453
* Enable addition of OpPhi instructions when the loop merge block has
multiple predecessors due to a break
* This can result in some values no longer dominating their uses
* Track return blocks in structured flow to produce OpPhis that have
multiple undef and non-undef arguments
* New tests to catch the bug
* When a block is predicated, mark the new body as a return if the old
block was already a return
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index afbb900..a12b2ca 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -197,6 +197,7 @@
tail_opcode == SpvOpUnreachable) {
assert(CurrentState().InLoop() && "Should be in the dummy loop.");
BranchToBlock(block, CurrentState().LoopMergeId());
+ return_blocks_.insert(block->id());
}
}
@@ -232,11 +233,19 @@
const auto& target_pred = cfg()->preds(target->id());
if (target_pred.size() == 1) {
MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0]));
+ } else {
+ // If the loop contained a break and a return, OpPhi instructions may be
+ // required starting from the dominator of the loop merge.
+ DominatorAnalysis* dom_tree =
+ context()->GetDominatorAnalysis(target->GetParent());
+ auto idom = dom_tree->ImmediateDominator(target);
+ if (idom) {
+ MarkForNewPhiNodes(target, idom);
+ }
}
}
void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
- uint32_t predecessor,
Instruction& inst) {
DominatorAnalysis* dom_tree =
context()->GetDominatorAnalysis(merge_block->GetParent());
@@ -281,17 +290,16 @@
uint32_t undef_id = Type2Undef(inst.type_id());
std::vector<uint32_t> phi_operands;
- // Add the operands for the defining instructions.
- phi_operands.push_back(inst.result_id());
- phi_operands.push_back(predecessor);
-
- // Add undef from all other blocks.
+ // Add the OpPhi operands. If the predecessor is a return block use undef,
+ // otherwise use |inst|'s id.
std::vector<uint32_t> preds = cfg()->preds(merge_block->id());
for (uint32_t pred_id : preds) {
- if (pred_id != predecessor) {
+ if (return_blocks_.count(pred_id)) {
phi_operands.push_back(undef_id);
- phi_operands.push_back(pred_id);
+ } else {
+ phi_operands.push_back(inst.result_id());
}
+ phi_operands.push_back(pred_id);
}
Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands);
@@ -400,8 +408,14 @@
// Forget about the edges leaving block. They will be removed.
cfg()->RemoveSuccessorEdges(block);
- BasicBlock* old_body = block->SplitBasicBlock(context(), TakeNextId(), iter);
+ auto old_body_id = TakeNextId();
+ BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
predicated->insert(old_body);
+ // If a return block is being split, mark the new body block also as a return
+ // block.
+ if (return_blocks_.count(block->id())) {
+ return_blocks_.insert(old_body_id);
+ }
// If |block| was a continue target for a loop |old_body| is now the correct
// continue target.
@@ -660,7 +674,7 @@
BasicBlock* current_bb = pred;
while (current_bb != nullptr && current_bb->id() != header_id) {
for (Instruction& inst : *current_bb) {
- CreatePhiNodesForInst(bb, pred->id(), inst);
+ CreatePhiNodesForInst(bb, inst);
}
current_bb = dom_tree->ImmediateDominator(current_bb);
}
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index d7e18f0..63094b7 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -240,12 +240,11 @@
// return block at the end of the pass.
void CreateReturnBlock();
- // Creates a Phi node in |merge_block| for the result of |inst| coming from
- // |predecessor|. Any uses of the result of |inst| that are no longer
+ // Creates a Phi node in |merge_block| for the result of |inst|.
+ // Any uses of the result of |inst| that are no longer
// dominated by |inst|, are replaced with the result of the new |OpPhi|
// instruction.
- void CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor,
- Instruction& inst);
+ void CreatePhiNodesForInst(BasicBlock* merge_block, Instruction& inst);
// Traverse the nodes in |new_merge_nodes_|, and adds the OpPhi instructions
// that are needed to make the code correct. It is assumed that at this point
@@ -331,6 +330,11 @@
// values that will need a phi on the new edges.
std::unordered_map<BasicBlock*, BasicBlock*> new_merge_nodes_;
bool HasNontrivialUnreachableBlocks(Function* function);
+
+ // Contains all return blocks that are merged. This set is populated while
+ // processing structured blocks and used to properly construct OpPhi
+ // instructions.
+ std::unordered_set<uint32_t> return_blocks_;
};
} // namespace opt
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index f985c89..2f2e74a 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1206,7 +1206,7 @@
; CHECK: [[bb:%\w+]] = OpLabel
; CHECK-NEXT: [[val:%\w+]] = OpUndef %bool
; CHECK: [[merge]] = OpLabel
-; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool [[val]] [[bb]] {{%\w+}} [[old_ret_block]]
+; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool {{%\w+}} [[old_ret_block]] [[val]] [[bb]]
; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb2:%\w+]]
; CHECK: [[bb2]] = OpLabel
; CHECK: OpBranch [[header2:%\w+]]
@@ -1263,7 +1263,7 @@
; CHECK: [[continue]] = OpLabel
; CHECK-NEXT: [[undef:%\w+]] = OpUndef
; CHECK: [[merge]] = OpLabel
- ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool [[undef]] [[continue]] {{%\w+}} {{%\w+}}
+ ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool {{%\w+}} {{%\w+}} [[undef]] [[continue]]
; CHECK: OpCopyObject %bool [[phi]]
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
@@ -1328,7 +1328,7 @@
; CHECK: OpLoopMerge [[merge:%\w+]]
; CHECK: [[def:%\w+]] = OpFOrdLessThan
; CHECK: [[merge]] = OpLabel
-; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[def]]
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[def]]
; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
; CHECK: [[cont]] = OpLabel
; CHECK-NEXT: OpBranchConditional [[phi]]
@@ -1480,6 +1480,198 @@
SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
SinglePassRunAndMatch<MergeReturnPass>(before, false);
}
+
+TEST_F(MergeReturnPassTest, BreakFromLoopUseNoLongerDominated) {
+ const std::string spirv = R"(
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: OpLoopMerge
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
+; CHECK-NEXT: OpBranch [[body:%\w+]]
+; CHECK: [[body]] = OpLabel
+; CHECK-NEXT: OpSelectionMerge [[non_ret:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret:%\w+]] [[non_ret]]
+; CHECK: [[ret]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[non_ret]] = OpLabel
+; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]]
+; CHECK: [[break]] = OpLabel
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[cont]] = OpLabel
+; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret]] [[def]] [[break]] [[def]] [[cont]]
+; CHECK: OpLogicalNot {{%\w+}} [[phi]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %func "func"
+OpExecutionMode %func LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%func = OpFunction %void None %void_fn
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %8 %7 None
+OpBranch %3
+%3 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %true %4 %5
+%4 = OpLabel
+OpReturn
+%5 = OpLabel
+%def = OpLogicalNot %bool %true
+OpBranchConditional %true %6 %7
+%6 = OpLabel
+OpBranch %8
+%7 = OpLabel
+OpBranchConditional %true %2 %8
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+%use = OpLogicalNot %bool %def
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
+TEST_F(MergeReturnPassTest, TwoBreaksFromLoopUsesNoLongerDominated) {
+ const std::string spirv = R"(
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: OpLoopMerge
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
+; CHECK-NEXT: OpBranch [[body:%\w+]]
+; CHECK: [[body]] = OpLabel
+; CHECK-NEXT: OpSelectionMerge [[body2:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret1:%\w+]] [[body2]]
+; CHECK: [[ret1]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[body2]] = OpLabel
+; CHECK-NEXT: [[def1:%\w+]] = OpLogicalNot
+; CHECK-NEXT: OpSelectionMerge [[body3:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret2:%\w+]] [[body3]]
+; CHECK: [[ret2]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[body3]] = OpLabel
+; CHECK-NEXT: [[def2:%\w+]] = OpLogicalAnd
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]]
+; CHECK: [[break]] = OpLabel
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[cont]] = OpLabel
+; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi1:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def1]] [[break]] [[def1]] [[cont]]
+; CHECK-NEXT: [[phi2:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def2]] [[break]] [[def2]] [[cont]]
+; CHECK: OpLogicalNot {{%\w+}} [[phi1]]
+; CHECK: OpLogicalAnd {{%\w+}} [[phi2]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %func "func"
+OpExecutionMode %func LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%func = OpFunction %void None %void_fn
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %10 %9 None
+OpBranch %3
+%3 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %true %4 %5
+%4 = OpLabel
+OpReturn
+%5 = OpLabel
+%def1 = OpLogicalNot %bool %true
+OpSelectionMerge %7 None
+OpBranchConditional %true %6 %7
+%6 = OpLabel
+OpReturn
+%7 = OpLabel
+%def2 = OpLogicalAnd %bool %true %true
+OpBranchConditional %true %8 %9
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranchConditional %true %2 %10
+%10 = OpLabel
+OpBranch %11
+%11 = OpLabel
+%use1 = OpLogicalNot %bool %def1
+%use2 = OpLogicalAnd %bool %def2 %true
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
+TEST_F(MergeReturnPassTest, PredicateBreakBlock) {
+ const std::string spirv = R"(
+; IDs are being preserved so we can rely on basic block labels.
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: %13 = OpLabel
+; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot
+; CHECK: %8 = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] {{%\w+}} [[undef]] {{%\w+}} [[def]] %13 [[undef]] {{%\w+}}
+; CHECK: OpLogicalAnd {{%\w+}} [[phi]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %1 "func"
+OpExecutionMode %1 LocalSize 1 1 1
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpUndef %bool
+%1 = OpFunction %void None %3
+%6 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpLoopMerge %8 %9 None
+OpBranch %10
+%10 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %true %12 %13
+%12 = OpLabel
+OpLoopMerge %14 %15 None
+OpBranch %16
+%16 = OpLabel
+OpReturn
+%15 = OpLabel
+OpBranch %12
+%14 = OpLabel
+OpUnreachable
+%13 = OpLabel
+%17 = OpLogicalNot %bool %true
+OpBranch %8
+%11 = OpLabel
+OpUnreachable
+%9 = OpLabel
+OpBranch %7
+%8 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpLogicalAnd %bool %17 %true
+OpReturn
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools