Fix bug in merge return (#2734)

* Fix bug in merge return

The merge return pass seems to assume that the only new edges in the cfg
are from return block to merge blocks.  However, it is possible that a
merge block branches to a merge block when it did not before.

This change add a new variable to track all of the new edges.  It also
renames some other variables and cleans us the code to make it a bit
easier to read.

Fixes #2702.
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index 0446331..6e7284b 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -224,6 +224,7 @@
   return_inst->SetOpcode(SpvOpBranch);
   return_inst->ReplaceOperands({{SPV_OPERAND_TYPE_ID, {target}}});
   context()->get_def_use_mgr()->AnalyzeInstDefUse(return_inst);
+  new_edges_[target_block].insert(block->id());
   cfg()->AddEdge(block->id(), target);
 }
 
@@ -236,28 +237,18 @@
     context()->UpdateDefUse(inst);
   });
 
-  const auto& target_pred = cfg()->preds(target->id());
-  if (target_pred.size() == 1) {
-    MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0]));
-  } else {
-    // If the loop contained a break and a return, OpPhi instructions may be
-    // required starting from the dominator of the loop merge.
-    DominatorAnalysis* dom_tree =
-        context()->GetDominatorAnalysis(target->GetParent());
-    auto idom = dom_tree->ImmediateDominator(target);
-    if (idom) {
-      MarkForNewPhiNodes(target, idom);
-    }
-  }
+  // Store the immediate dominator for this block in case new phi nodes will be
+  // needed later.
+  RecordImmediateDominator(target);
 }
 
 void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
                                             Instruction& inst) {
   DominatorAnalysis* dom_tree =
       context()->GetDominatorAnalysis(merge_block->GetParent());
-  BasicBlock* inst_bb = context()->get_instr_block(&inst);
 
   if (inst.result_id() != 0) {
+    BasicBlock* inst_bb = context()->get_instr_block(&inst);
     std::vector<Instruction*> users_to_update;
     context()->get_def_use_mgr()->ForEachUser(
         &inst,
@@ -295,12 +286,13 @@
         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
     uint32_t undef_id = Type2Undef(inst.type_id());
     std::vector<uint32_t> phi_operands;
+    const std::set<uint32_t>& new_edges = new_edges_[merge_block];
 
     // Add the OpPhi operands. If the predecessor is a return block use undef,
     // otherwise use |inst|'s id.
     std::vector<uint32_t> preds = cfg()->preds(merge_block->id());
     for (uint32_t pred_id : preds) {
-      if (return_blocks_.count(pred_id)) {
+      if (new_edges.count(pred_id)) {
         phi_operands.push_back(undef_id);
       } else {
         phi_operands.push_back(inst.result_id());
@@ -417,6 +409,8 @@
   auto old_body_id = TakeNextId();
   BasicBlock* old_body = block->SplitBasicBlock(context(), old_body_id, iter);
   predicated->insert(old_body);
+  cfg()->AddEdges(old_body);
+
   // If a return block is being split, mark the new body block also as a return
   // block.
   if (return_blocks_.count(block->id())) {
@@ -456,14 +450,15 @@
   builder.AddConditionalBranch(load_id, merge_block->id(), old_body->id(),
                                old_body->id());
 
-  // 3. Update OpPhi instructions in |merge_block|.
-  BasicBlock* merge_original_pred = MarkedSinglePred(merge_block);
-  if (merge_original_pred == nullptr) {
-    UpdatePhiNodes(block, merge_block);
-  } else if (merge_original_pred == block) {
-    MarkForNewPhiNodes(merge_block, old_body);
+  if (!new_edges_[merge_block].insert(block->id()).second) {
+    // It is possible that we already inserted a new edge to the merge block.
+    // If so, that edge now goes from |old_body| to |merge_block|.
+    new_edges_[merge_block].insert(old_body->id());
   }
 
+  // 3. Update OpPhi instructions in |merge_block|.
+  UpdatePhiNodes(block, merge_block);
+
   // 4. Update the CFG.  We do this after updating the OpPhi instructions
   // because |UpdatePhiNodes| assumes the edge from |block| has not been added
   // to the CFG yet.
@@ -659,26 +654,37 @@
 }
 
 void MergeReturnPass::AddNewPhiNodes() {
-  DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_);
   std::list<BasicBlock*> order;
   cfg()->ComputeStructuredOrder(function_, &*function_->begin(), &order);
 
   for (BasicBlock* bb : order) {
-    BasicBlock* dominator = dom_tree->ImmediateDominator(bb);
-    if (dominator) {
-      AddNewPhiNodes(bb, new_merge_nodes_[bb], dominator->id());
-    }
+    AddNewPhiNodes(bb);
   }
 }
 
-void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred,
-                                     uint32_t header_id) {
-  DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_);
-  // Insert as a stopping point.  We do not have to add anything in the block
-  // or above because the header dominates |bb|.
+void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb) {
+  // New phi nodes are needed for any id whose definition used to dominate |bb|,
+  // but no longer dominates |bb|.  These are found by walking the dominator
+  // tree starting at the original immediate dominator of |bb| and ending at its
+  // current dominator.
 
-  BasicBlock* current_bb = pred;
-  while (current_bb != nullptr && current_bb->id() != header_id) {
+  // Because we are walking the updated dominator tree it is important that the
+  // new phi nodes for the original dominators of |bb| have already been added.
+  // Otherwise some ids might be missed.  Consider the case where bb1 dominates
+  // bb2, and bb2 dominates bb3.  Suppose there are changes such that bb1 no
+  // longer dominates bb2 and the same for bb2 and bb3.  This algorithm will not
+  // look at the ids defined in bb1.  However, calling |AddNewPhiNodes(bb2)|
+  // first will add a phi node in bb2 for that value.  Then a call to
+  // |AddNewPhiNodes(bb3)| will process that value by processing the phi in bb2.
+  DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_);
+
+  BasicBlock* dominator = dom_tree->ImmediateDominator(bb);
+  if (dominator == nullptr) {
+    return;
+  }
+
+  BasicBlock* current_bb = new_merge_nodes_[bb];
+  while (current_bb != nullptr && current_bb != dominator) {
     for (Instruction& inst : *current_bb) {
       CreatePhiNodesForInst(bb, inst);
     }
@@ -686,9 +692,11 @@
   }
 }
 
-void MergeReturnPass::MarkForNewPhiNodes(BasicBlock* block,
-                                         BasicBlock* single_original_pred) {
-  new_merge_nodes_[block] = single_original_pred;
+void MergeReturnPass::RecordImmediateDominator(BasicBlock* block) {
+  DominatorAnalysis* dom_tree =
+      context()->GetDominatorAnalysis(block->GetParent());
+  auto idom = dom_tree->ImmediateDominator(block);
+  new_merge_nodes_[block] = idom;
 }
 
 void MergeReturnPass::InsertAfterElement(BasicBlock* element,
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index 63094b7..d9eae39 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -251,31 +251,19 @@
   // there are no unreachable blocks in the control flow graph.
   void AddNewPhiNodes();
 
-  // Creates any new phi nodes that are needed in |bb| now that |pred| is no
-  // longer the only block that preceedes |bb|.  |header_id| is the id of the
-  // basic block for the loop or selection construct that merges at |bb|.
-  void AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred, uint32_t header_id);
+  // Creates any new phi nodes that are needed in |bb|.  |AddNewPhiNodes| must
+  // have already been called on the original dominators of |bb|.
+  void AddNewPhiNodes(BasicBlock* bb);
 
   // Saves |block| to a list of basic block that will require OpPhi nodes to be
   // added by calling |AddNewPhiNodes|.  It is assumed that |block| used to have
   // a single predecessor, |single_original_pred|, but now has more.
-  void MarkForNewPhiNodes(BasicBlock* block, BasicBlock* single_original_pred);
-
-  // Return the original single predcessor of |block| if it was flagged as
-  // having a single predecessor.  |nullptr| is returned otherwise.
-  BasicBlock* MarkedSinglePred(BasicBlock* block) {
-    auto it = new_merge_nodes_.find(block);
-    if (it != new_merge_nodes_.end()) {
-      return it->second;
-    } else {
-      return nullptr;
-    }
-  }
+  void RecordImmediateDominator(BasicBlock* block);
 
   // Modifies existing OpPhi instruction in |target| block to account for the
   // new edge from |new_source|.  The value for that edge will be an Undef. If
   // |target| only had a single predecessor, then it is marked as needing new
-  // phi nodes.  See |MarkForNewPhiNodes|.
+  // phi nodes.  See |RecordImmediateDominator|.
   //
   // The CFG must not include the edge from |new_source| to |target| yet.
   void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target);
@@ -301,6 +289,11 @@
   // |merge_target| as the merge node.
   void CreateDummyLoop(BasicBlock* merge_target);
 
+  // Returns true if |function| has an unreachable block that is not a continue
+  // target that simply branches back to the header, or a merge block containing
+  // 1 instruction which is OpUnreachable.
+  bool HasNontrivialUnreachableBlocks(Function* function);
+
   // A stack used to keep track of the innermost contain loop and selection
   // constructs.
   std::vector<StructuredControlState> state_;
@@ -324,12 +317,13 @@
   // after processing the current function.
   BasicBlock* final_return_block_;
 
-  // This map contains the set of nodes that use to have a single predcessor,
-  // but now have more.  They will need new OpPhi nodes.  For each of the nodes,
-  // it is mapped to it original single predcessor.  It is assumed there are no
-  // values that will need a phi on the new edges.
+  // This is a map from a node to its original immediate dominator.  This is
+  // used to determine which values will require a new phi node.
   std::unordered_map<BasicBlock*, BasicBlock*> new_merge_nodes_;
-  bool HasNontrivialUnreachableBlocks(Function* function);
+
+  // A map from a basic block, bb, to the set of basic blocks which represent
+  // the new edges that reach |bb|.
+  std::unordered_map<BasicBlock*, std::set<uint32_t>> new_edges_;
 
   // Contains all return blocks that are merged. This is set is populated while
   // processing structured blocks and used to properly construct OpPhi
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index 423e81d..f6b30de 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1724,6 +1724,62 @@
   SinglePassRunAndMatch<MergeReturnPass>(predefs + caller + callee, true);
 }
 
+TEST_F(MergeReturnPassTest, MergeToMergeBranch) {
+  const std::string text =
+      R"(
+; CHECK: [[new_undef:%\w+]] = OpUndef %uint
+; CHECK: OpLoopMerge
+; CHECK: OpLoopMerge [[merge1:%\w+]]
+; CHECK: OpLoopMerge [[merge2:%\w+]]
+; CHECK: [[merge1]] = OpLabel
+; CHECK-NEXT: OpPhi %uint [[new_undef]] [[merge2]]
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %2 "main"
+               OpExecutionMode %2 LocalSize 100 1 1
+               OpSource ESSL 310
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+     %uint_1 = OpConstant %uint 1
+       %bool = OpTypeBool
+      %false = OpConstantFalse %bool
+     %uint_0 = OpConstant %uint 0
+        %int = OpTypeInt 32 1
+      %int_0 = OpConstant %int 0
+      %int_1 = OpConstant %int 1
+         %13 = OpUndef %bool
+          %2 = OpFunction %void None %4
+         %14 = OpLabel
+               OpBranch %15
+         %15 = OpLabel
+               OpLoopMerge %16 %17 None
+               OpBranch %18
+         %18 = OpLabel
+               OpLoopMerge %19 %20 None
+               OpBranchConditional %13 %21 %19
+         %21 = OpLabel
+               OpReturn
+         %20 = OpLabel
+               OpBranch %18
+         %19 = OpLabel
+         %22 = OpUndef %uint
+               OpBranch %23
+         %23 = OpLabel
+               OpBranch %16
+         %17 = OpLabel
+               OpBranch %15
+         %16 = OpLabel
+         %24 = OpCopyObject %uint %22
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(text, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools