Fix merge return in the face of breaks (#2466)

Fixes #2453

* Enable addition of OpPhi instructions when the loop has multiple
predecessors of the merge due to a break
 * This can result in some values no longer dominating their uses
* Track return blocks in structured flow to produce OpPhis that have
multiple undef and non-undef arguments
* New tests to catch the bug
* When a block is predicated, mark the new body as a return if the old
block as already a return
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index afbb900..a12b2ca 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -197,6 +197,7 @@
       tail_opcode == SpvOpUnreachable) {
     assert(CurrentState().InLoop() && "Should be in the dummy loop.");
     BranchToBlock(block, CurrentState().LoopMergeId());
+    return_blocks_.insert(block->id());
   }
 }
 
@@ -232,11 +233,19 @@
   const auto& target_pred = cfg()->preds(target->id());
   if (target_pred.size() == 1) {
     MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0]));
+  } else {
+    // If the loop contained a break and a return, OpPhi instructions may be
+    // required starting from the dominator of the loop merge.
+    DominatorAnalysis* dom_tree =
+        context()->GetDominatorAnalysis(target->GetParent());
+    auto idom = dom_tree->ImmediateDominator(target);
+    if (idom) {
+      MarkForNewPhiNodes(target, idom);
+    }
   }
 }
 
 void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
-                                            uint32_t predecessor,
                                             Instruction& inst) {
   DominatorAnalysis* dom_tree =
       context()->GetDominatorAnalysis(merge_block->GetParent());
@@ -281,17 +290,16 @@
     uint32_t undef_id = Type2Undef(inst.type_id());
     std::vector<uint32_t> phi_operands;
 
-    // Add the operands for the defining instructions.
-    phi_operands.push_back(inst.result_id());
-    phi_operands.push_back(predecessor);
-
-    // Add undef from all other blocks.
+    // Add the OpPhi operands. If the predecessor is a return block use undef,
+    // otherwise use |inst|'s id.
     std::vector<uint32_t> preds = cfg()->preds(merge_block->id());
     for (uint32_t pred_id : preds) {
-      if (pred_id != predecessor) {
+      if (return_blocks_.count(pred_id)) {
         phi_operands.push_back(undef_id);
-        phi_operands.push_back(pred_id);
+      } else {
+        phi_operands.push_back(inst.result_id());
       }
+      phi_operands.push_back(pred_id);
     }
 
     Instruction* new_phi = builder.AddPhi(inst.type_id(), phi_operands);
@@ -400,8 +408,14 @@
   // Forget about the edges leaving block.  They will be removed.
   cfg()->RemoveSuccessorEdges(block);
 
-  BasicBlock* old_body = block->SplitBasicBlock(context(), TakeNextId(), iter);
+  auto old_body_id = TakeNextId();
+  BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
   predicated->insert(old_body);
+  // If a return block is being split, mark the new body block also as a return
+  // block.
+  if (return_blocks_.count(block->id())) {
+    return_blocks_.insert(old_body_id);
+  }
 
   // If |block| was a continue target for a loop |old_body| is now the correct
   // continue target.
@@ -660,7 +674,7 @@
   BasicBlock* current_bb = pred;
   while (current_bb != nullptr && current_bb->id() != header_id) {
     for (Instruction& inst : *current_bb) {
-      CreatePhiNodesForInst(bb, pred->id(), inst);
+      CreatePhiNodesForInst(bb, inst);
     }
     current_bb = dom_tree->ImmediateDominator(current_bb);
   }
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index d7e18f0..63094b7 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -240,12 +240,11 @@
   // return block at the end of the pass.
   void CreateReturnBlock();
 
-  // Creates a Phi node in |merge_block| for the result of |inst| coming from
-  // |predecessor|.  Any uses of the result of |inst| that are no longer
+  // Creates a Phi node in |merge_block| for the result of |inst|.
+  // Any uses of the result of |inst| that are no longer
   // dominated by |inst|, are replaced with the result of the new |OpPhi|
   // instruction.
-  void CreatePhiNodesForInst(BasicBlock* merge_block, uint32_t predecessor,
-                             Instruction& inst);
+  void CreatePhiNodesForInst(BasicBlock* merge_block, Instruction& inst);
 
   // Traverse the nodes in |new_merge_nodes_|, and adds the OpPhi instructions
   // that are needed to make the code correct.  It is assumed that at this point
@@ -331,6 +330,11 @@
   // values that will need a phi on the new edges.
   std::unordered_map<BasicBlock*, BasicBlock*> new_merge_nodes_;
   bool HasNontrivialUnreachableBlocks(Function* function);
+
+  // Contains all return blocks that are merged. This is set is populated while
+  // processing structured blocks and used to properly construct OpPhi
+  // instructions.
+  std::unordered_set<uint32_t> return_blocks_;
 };
 
 }  // namespace opt
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index f985c89..2f2e74a 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1206,7 +1206,7 @@
 ; CHECK: [[bb:%\w+]] = OpLabel
 ; CHECK-NEXT: [[val:%\w+]] = OpUndef %bool
 ; CHECK: [[merge]] = OpLabel
-; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool [[val]] [[bb]] {{%\w+}} [[old_ret_block]]
+; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool {{%\w+}} [[old_ret_block]] [[val]] [[bb]]
 ; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb2:%\w+]]
 ; CHECK: [[bb2]] = OpLabel
 ; CHECK: OpBranch [[header2:%\w+]]
@@ -1263,7 +1263,7 @@
       ; CHECK: [[continue]] = OpLabel
       ; CHECK-NEXT: [[undef:%\w+]] = OpUndef
       ; CHECK: [[merge]] = OpLabel
-      ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool [[undef]] [[continue]] {{%\w+}} {{%\w+}}
+      ; CHECK-NEXT: [[phi:%\w+]] = OpPhi %bool {{%\w+}} {{%\w+}} [[undef]] [[continue]]
       ; CHECK: OpCopyObject %bool [[phi]]
                OpCapability Shader
           %1 = OpExtInstImport "GLSL.std.450"
@@ -1328,7 +1328,7 @@
 ; CHECK: OpLoopMerge [[merge:%\w+]]
 ; CHECK: [[def:%\w+]] = OpFOrdLessThan
 ; CHECK: [[merge]] = OpLabel
-; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[def]]
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} {{%\w+}} {{%\w+}} [[def]]
 ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
 ; CHECK: [[cont]] = OpLabel
 ; CHECK-NEXT: OpBranchConditional [[phi]]
@@ -1480,6 +1480,198 @@
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   SinglePassRunAndMatch<MergeReturnPass>(before, false);
 }
+
+TEST_F(MergeReturnPassTest, BreakFromLoopUseNoLongerDominated) {
+  const std::string spirv = R"(
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: OpLoopMerge
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
+; CHECK-NEXT: OpBranch [[body:%\w+]]
+; CHECK: [[body]] = OpLabel
+; CHECK-NEXT: OpSelectionMerge [[non_ret:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret:%\w+]] [[non_ret]]
+; CHECK: [[ret]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[non_ret]] = OpLabel
+; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]]
+; CHECK: [[break]] = OpLabel
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[cont]] = OpLabel
+; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret]] [[def]] [[break]] [[def]] [[cont]]
+; CHECK: OpLogicalNot {{%\w+}} [[phi]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %func "func"
+OpExecutionMode %func LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%func = OpFunction %void None %void_fn
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %8 %7 None
+OpBranch %3
+%3 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %true %4 %5
+%4 = OpLabel
+OpReturn
+%5 = OpLabel
+%def = OpLogicalNot %bool %true
+OpBranchConditional %true %6 %7
+%6 = OpLabel
+OpBranch %8
+%7 = OpLabel
+OpBranchConditional %true %2 %8
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+%use = OpLogicalNot %bool %def
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
+TEST_F(MergeReturnPassTest, TwoBreaksFromLoopUsesNoLongerDominated) {
+  const std::string spirv = R"(
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: OpLoopMerge
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
+; CHECK-NEXT: OpBranch [[body:%\w+]]
+; CHECK: [[body]] = OpLabel
+; CHECK-NEXT: OpSelectionMerge [[body2:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret1:%\w+]] [[body2]]
+; CHECK: [[ret1]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[body2]] = OpLabel
+; CHECK-NEXT: [[def1:%\w+]] = OpLogicalNot
+; CHECK-NEXT: OpSelectionMerge [[body3:%\w+]]
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[ret2:%\w+]] [[body3:%\w+]]
+; CHECK: [[ret2]] = OpLabel
+; CHECK-NEXT: OpStore
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[body3]] = OpLabel
+; CHECK-NEXT: [[def2:%\w+]] = OpLogicalAnd
+; CHECK-NEXT: OpBranchConditional {{%\w+}} [[break:%\w+]] [[cont]]
+; CHECK: [[break]] = OpLabel
+; CHECK-NEXT: OpBranch [[merge]]
+; CHECK: [[cont]] = OpLabel
+; CHECK-NEXT: OpBranchConditional {{%\w+}} {{%\w+}} [[merge]]
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi1:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def1]] [[break]] [[def1]] [[cont]]
+; CHECK-NEXT: [[phi2:%\w+]] = OpPhi {{%\w+}} [[undef]] [[ret1]] [[undef]] [[ret2]] [[def2]] [[break]] [[def2]] [[cont]]
+; CHECK: OpLogicalNot {{%\w+}} [[phi1]]
+; CHECK: OpLogicalAnd {{%\w+}} [[phi2]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %func "func"
+OpExecutionMode %func LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%func = OpFunction %void None %void_fn
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %10 %9 None
+OpBranch %3
+%3 = OpLabel
+OpSelectionMerge %5 None
+OpBranchConditional %true %4 %5
+%4 = OpLabel
+OpReturn
+%5 = OpLabel
+%def1 = OpLogicalNot %bool %true
+OpSelectionMerge %7 None
+OpBranchConditional %true %6 %7
+%6 = OpLabel
+OpReturn
+%7 = OpLabel
+%def2 = OpLogicalAnd %bool %true %true
+OpBranchConditional %true %8 %9
+%8 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpBranchConditional %true %2 %10
+%10 = OpLabel
+OpBranch %11
+%11 = OpLabel
+%use1 = OpLogicalNot %bool %def1
+%use2 = OpLogicalAnd %bool %def2 %true
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
+TEST_F(MergeReturnPassTest, PredicateBreakBlock) {
+  const std::string spirv = R"(
+; IDs are being preserved so we can rely on basic block labels.
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: [[undef:%\w+]] = OpUndef
+; CHECK: %13 = OpLabel
+; CHECK-NEXT: [[def:%\w+]] = OpLogicalNot
+; CHECK: %8 = OpLabel
+; CHECK-NEXT: [[phi:%\w+]] = OpPhi {{%\w+}} [[undef]] {{%\w+}} [[undef]] {{%\w+}} [[def]] %13 [[undef]] {{%\w+}}
+; CHECK: OpLogicalAnd {{%\w+}} [[phi]]
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %1 "func"
+OpExecutionMode %1 LocalSize 1 1 1
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpUndef %bool
+%1 = OpFunction %void None %3
+%6 = OpLabel
+OpBranch %7
+%7 = OpLabel
+OpLoopMerge %8 %9 None
+OpBranch %10
+%10 = OpLabel
+OpSelectionMerge %11 None
+OpBranchConditional %true %12 %13
+%12 = OpLabel
+OpLoopMerge %14 %15 None
+OpBranch %16
+%16 = OpLabel
+OpReturn
+%15 = OpLabel
+OpBranch %12
+%14 = OpLabel
+OpUnreachable
+%13 = OpLabel
+%17 = OpLogicalNot %bool %true
+OpBranch %8
+%11 = OpLabel
+OpUnreachable
+%9 = OpLabel
+OpBranch %7
+%8 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%19 = OpLogicalAnd %bool %17 %true
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(spirv, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools