Have PredicateBlocks jump to the existing merge blocks. (#1849)

* Refactor PredicateBlocks

Refactor PredicateBlocks so that we know which constructs a return
is contained in.  Will be used later.

* Have PredicateBlocks jump to the existing merge blocks.

In PredicateBlocks, we currently skip instructions with side effects,
but still follow the same control flow (sort of).  This causes a
problem when we are trying to predicate code in a loop.  We skip all
of the code with side effects (e.g. the induction-variable increment),
but still follow the same control flow (jump back to the start of the
loop).  This creates an infinite loop because the code will keep
jumping back to the start of the loop without changing the values that
affect the exit condition.

This is a large change to merge-return.  When predicating a block that
is in a loop or merge construct, it will jump to the merge block of the
construct.  Once out of all constructs we will generate code as we did
before.
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index 4ad7b36..28b6461 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -87,8 +87,33 @@
     }
   }
 
-  // Predicate successors of the original return blocks as necessary.
-  PredicateBlocks(return_blocks);
+  state_.clear();
+  state_.emplace_back(nullptr, nullptr);
+  std::unordered_set<BasicBlock*> predicated;
+  for (auto block : order) {
+    if (cfg()->IsPseudoEntryBlock(block) || cfg()->IsPseudoExitBlock(block)) {
+      continue;
+    }
+
+    auto blockId = block->GetLabelInst()->result_id();
+    if (blockId == CurrentState().CurrentMergeId()) {
+      // Pop the current state as we've hit the merge
+      state_.pop_back();
+    }
+
+    // Predicate successors of the original return blocks as necessary.
+    if (std::find(return_blocks.begin(), return_blocks.end(), block) !=
+        return_blocks.end()) {
+      PredicateBlocks(block, &predicated);
+    }
+
+    // Generate state for next block
+    if (Instruction* mergeInst = block->GetMergeInst()) {
+      Instruction* loopMergeInst = block->GetLoopMergeInst();
+      if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst();
+      state_.emplace_back(loopMergeInst, mergeInst);
+    }
+  }
 
   // We have not kept the dominator tree up-to-date.
   // Invalidate it at this point to make sure it will be rebuilt.
@@ -165,24 +190,8 @@
     RecordReturned(block);
     RecordReturnValue(block);
   }
-
-  // Fix up existing phi nodes.
-  //
-  // A new edge is being added from |block| to |target|, so go through
-  // |target|'s phi nodes add an undef incoming value for |block|.
   BasicBlock* target_block = context()->get_instr_block(target);
-  target_block->ForEachPhiInst([this, block](Instruction* inst) {
-    uint32_t undefId = Type2Undef(inst->type_id());
-    inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}});
-    inst->AddOperand({SPV_OPERAND_TYPE_ID, {block->id()}});
-    context()->UpdateDefUse(inst);
-  });
-
-  const auto& target_pred = cfg()->preds(target);
-  if (target_pred.size() == 1) {
-    MarkForNewPhiNodes(target_block,
-                       context()->get_instr_block(target_pred[0]));
-  }
+  UpdatePhiNodes(block, target_block);
 
   Instruction* return_inst = block->terminator();
   return_inst->SetOpcode(SpvOpBranch);
@@ -191,6 +200,23 @@
   cfg()->AddEdge(block->id(), target);
 }
 
+void MergeReturnPass::UpdatePhiNodes(BasicBlock* new_source,
+                                     BasicBlock* target) {
+  // A new edge is being added from |new_source| to |target|, so go through
+  // |target|'s phi nodes and add an undef incoming value for |new_source|.
+  target->ForEachPhiInst([this, new_source](Instruction* inst) {
+    uint32_t undefId = Type2Undef(inst->type_id());
+    inst->AddOperand({SPV_OPERAND_TYPE_ID, {undefId}});
+    inst->AddOperand({SPV_OPERAND_TYPE_ID, {new_source->id()}});
+    context()->UpdateDefUse(inst);
+  });
+
+  const auto& target_pred = cfg()->preds(target->id());
+  if (target_pred.size() == 1) {
+    MarkForNewPhiNodes(target, context()->get_instr_block(target_pred[0]));
+  }
+}
+
 void MergeReturnPass::CreatePhiNodesForInst(BasicBlock* merge_block,
                                             uint32_t predecessor,
                                             Instruction& inst) {
@@ -251,42 +277,72 @@
 }
 
 void MergeReturnPass::PredicateBlocks(
-    const std::vector<BasicBlock*>& return_blocks) {
+    BasicBlock* return_block, std::unordered_set<BasicBlock*>* predicated) {
   // The CFG is being modified as the function proceeds so avoid caching
   // successors.
-  std::vector<BasicBlock*> stack;
-  auto add_successors = [this, &stack](BasicBlock* block) {
-    const BasicBlock* const_block = const_cast<const BasicBlock*>(block);
-    const_block->ForEachSuccessorLabel([this, &stack](const uint32_t idx) {
-      stack.push_back(context()->get_instr_block(idx));
-    });
-  };
 
+  if (predicated->count(return_block)) {
+    return;
+  }
+
+  BasicBlock* block = nullptr;
+  const BasicBlock* const_block = const_cast<const BasicBlock*>(return_block);
+  const_block->ForEachSuccessorLabel([this, &block](const uint32_t idx) {
+    BasicBlock* succ_block = context()->get_instr_block(idx);
+    assert(block == nullptr);
+    block = succ_block;
+  });
+  assert(block &&
+         "Return blocks should have returns already replaced by a single "
+         "unconditional branch.");
+
+  auto state = state_.rbegin();
   std::unordered_set<BasicBlock*> seen;
-  std::unordered_set<BasicBlock*> predicated;
-  for (auto b : return_blocks) {
-    seen.clear();
-    add_successors(b);
-
-    while (!stack.empty()) {
-      BasicBlock* block = stack.back();
-      assert(block);
-      stack.pop_back();
-
-      if (block == b) continue;
-      if (block == final_return_block_) continue;
-      if (!seen.insert(block).second) continue;
-      if (!predicated.insert(block).second) continue;
-
-      // Skip structured subgraphs.
-      BasicBlock* next = block;
-      while (next->GetMergeInst()) {
-        next = context()->get_instr_block(next->MergeBlockIdIfAny());
-      }
-      add_successors(next);
-      PredicateBlock(block, next, &predicated);
+  if (block->id() == state->CurrentMergeId()) {
+    state++;
+  } else if (block->id() == state->LoopMergeId()) {
+    while (state->LoopMergeId() == block->id()) {
+      state++;
     }
   }
+
+  while (block != nullptr && block != final_return_block_) {
+    if (!predicated->insert(block).second) break;
+
+    // Skip structured subgraphs.
+    BasicBlock* next = nullptr;
+    if (state->InLoop()) {
+      next = context()->get_instr_block(state->LoopMergeId());
+      while (state->LoopMergeId() == next->id()) {
+        state++;
+      }
+      BreakFromConstruct(block, next, predicated);
+    } else if (state->InStructuredFlow()) {
+      next = context()->get_instr_block(state->CurrentMergeId());
+      state++;
+      BreakFromConstruct(block, next, predicated);
+    } else {
+      BasicBlock* tail = block;
+      while (tail->GetMergeInst()) {
+        tail = context()->get_instr_block(tail->MergeBlockIdIfAny());
+      }
+
+      // Must find |next| (the successor of |tail|) before predicating the
+      // block because, if |block| == |tail|, then |tail| will have multiple
+      // successors.
+      next = nullptr;
+      tail->ForEachSuccessorLabel([this, &next](const uint32_t idx) {
+        BasicBlock* succ_block = context()->get_instr_block(idx);
+        assert(
+            next == nullptr &&
+            "Found block with multiple successors and no merge instruction.");
+        next = succ_block;
+      });
+
+      PredicateBlock(block, tail, predicated);
+    }
+    block = next;
+  }
 }
 
 bool MergeReturnPass::RequiresPredication(const BasicBlock* block,
@@ -312,18 +368,18 @@
     return;
   }
 
-  // Make sure the cfg is build here.  If we don't then it becomes very hard to
-  // know which new blocks need to be updated.
+  // Make sure the cfg is built here.  If we don't then it becomes very hard
+  // to know which new blocks need to be updated.
   context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG);
 
-  // When predicating, be aware of whether this block is a header block, a merge
-  // block or both.
+  // When predicating, be aware of whether this block is a header block, a
+  // merge block or both.
   //
   // If this block is a merge block, ensure the appropriate header stays
   // up-to-date with any changes (i.e. points to the pre-header).
   //
-  // If this block is a header block, predicate the entire structured subgraph.
-  // This can act recursively.
+  // If this block is a header block, predicate the entire structured
+  // subgraph. This can act recursively.
 
   // If |block| is a loop header, then the back edge must jump to the original
   // code, not the new header.
@@ -385,8 +441,8 @@
   get_def_use_mgr()->AnalyzeInstUse(new_merge->terminator());
   context()->set_instr_block(new_merge->terminator(), new_merge);
 
-  // Add a branch to the new merge. If we jumped multiple blocks, the branch is
-  // added to tail_block, otherwise the branch belongs in old_body.
+  // Add a branch to the new merge. If we jumped multiple blocks, the branch
+  // is added to tail_block, otherwise the branch belongs in old_body.
   tail_block->AddInstruction(
       MakeUnique<Instruction>(context(), SpvOpBranch, 0, 0,
                               std::initializer_list<Operand>{
@@ -443,6 +499,88 @@
   MarkForNewPhiNodes(new_merge, tail_block);
 }
 
+void MergeReturnPass::BreakFromConstruct(
+    BasicBlock* block, BasicBlock* merge_block,
+    std::unordered_set<BasicBlock*>* predicated) {
+  // Make sure the cfg is built here.  If we don't then it becomes very hard
+  // to know which new blocks need to be updated.
+  context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG);
+
+  // When predicating, be aware of whether this block is a header block, a
+  // merge block or both.
+  //
+  // If this block is a merge block, ensure the appropriate header stays
+  // up-to-date with any changes (i.e. points to the pre-header).
+  //
+  // If this block is a header block, predicate the entire structured
+  // subgraph. This can act recursively.
+
+  // If |block| is a loop header, then the back edge must jump to the original
+  // code, not the new header.
+  if (block->GetLoopMergeInst()) {
+    cfg()->SplitLoopHeader(block);
+  }
+
+  // Leave the phi instructions behind.
+  auto iter = block->begin();
+  while (iter->opcode() == SpvOpPhi) {
+    ++iter;
+  }
+
+  // Forget about the edges leaving block.  They will be removed.
+  cfg()->RemoveSuccessorEdges(block);
+
+  std::unique_ptr<BasicBlock> new_block(
+      block->SplitBasicBlock(context(), TakeNextId(), iter));
+  BasicBlock* old_body =
+      function_->InsertBasicBlockAfter(std::move(new_block), block);
+  predicated->insert(old_body);
+
+  // Within the new header we need the following:
+  // 1. Load of the return status flag
+  // 2. Branch to new merge (true) or old body (false)
+  // 3. Update OpPhi instructions in |merge_block|.
+  //
+  // Since we are branching to the merge block of the current construct, there
+  // is no need for an OpSelectionMerge.
+
+  // 1. Load of the return status flag
+  analysis::Bool bool_type;
+  uint32_t bool_id = context()->get_type_mgr()->GetId(&bool_type);
+  assert(bool_id != 0);
+  uint32_t load_id = TakeNextId();
+  block->AddInstruction(MakeUnique<Instruction>(
+      context(), SpvOpLoad, bool_id, load_id,
+      std::initializer_list<Operand>{
+          {SPV_OPERAND_TYPE_ID, {return_flag_->result_id()}}}));
+  get_def_use_mgr()->AnalyzeInstDefUse(block->terminator());
+  context()->set_instr_block(block->terminator(), block);
+
+  // 2. Branch to |merge_block| (true) or |old_body| (false)
+  block->AddInstruction(MakeUnique<Instruction>(
+      context(), SpvOpBranchConditional, 0, 0,
+      std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {load_id}},
+                                     {SPV_OPERAND_TYPE_ID, {merge_block->id()}},
+                                     {SPV_OPERAND_TYPE_ID, {old_body->id()}}}));
+  get_def_use_mgr()->AnalyzeInstUse(block->terminator());
+  context()->set_instr_block(block->terminator(), block);
+
+  // Update the cfg
+  cfg()->AddEdges(block);
+  cfg()->RegisterBlock(old_body);
+
+  // 3. Update OpPhi instructions in |merge_block|.
+  BasicBlock* merge_original_pred = MarkedSinglePred(merge_block);
+  if (merge_original_pred == nullptr) {
+    UpdatePhiNodes(block, merge_block);
+  } else if (merge_original_pred == block) {
+    MarkForNewPhiNodes(merge_block, old_body);
+  }
+
+  assert(old_body->begin() != old_body->end());
+  assert(block->begin() != block->end());
+}
+
 void MergeReturnPass::RecordReturned(BasicBlock* block) {
   if (block->tail()->opcode() != SpvOpReturn &&
       block->tail()->opcode() != SpvOpReturnValue)
@@ -637,8 +775,8 @@
 void MergeReturnPass::AddNewPhiNodes(BasicBlock* bb, BasicBlock* pred,
                                      uint32_t header_id) {
   DominatorAnalysis* dom_tree = context()->GetDominatorAnalysis(function_);
-  // Insert as a stopping point.  We do not have to add anything in the block or
-  // above because the header dominates |bb|.
+  // Insert as a stopping point.  We do not have to add anything in the block
+  // or above because the header dominates |bb|.
 
   BasicBlock* current_bb = pred;
   while (current_bb != nullptr && current_bb->id() != header_id) {
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index 0c7a3b4..0a77b1b 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -220,7 +220,8 @@
   // not be executed because the original code would have already returned. This
   // involves adding new selections constructs to jump around these
   // instructions.
-  void PredicateBlocks(const std::vector<BasicBlock*>& return_blocks);
+  void PredicateBlocks(BasicBlock* return_block,
+                       std::unordered_set<BasicBlock*>* pSet);
 
   // Add the predication code (see |PredicateBlocks|) to |tail_block| if it
   // requires predication.  |tail_block| and any new blocks that are known to
@@ -298,6 +299,9 @@
   // it is mapped to it original single predcessor.  It is assumed there are no
   // values that will need a phi on the new edges.
   std::unordered_map<BasicBlock*, BasicBlock*> new_merge_nodes_;
+  void BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block,
+                          std::unordered_set<BasicBlock*>* predicated);
+  void UpdatePhiNodes(BasicBlock* new_source, BasicBlock* target);
 };
 
 }  // namespace opt
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index 5b2f6c6..fc1f112 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -296,6 +296,7 @@
 OpFunctionEnd
 )";
 
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   SinglePassRunAndMatch<MergeReturnPass>(before, false);
 }
 
@@ -344,6 +345,7 @@
 OpFunctionEnd
 )";
 
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   SinglePassRunAndMatch<MergeReturnPass>(before, false);
 }
 
@@ -593,16 +595,16 @@
                OpReturn
          %11 = OpLabel
                OpSelectionMerge %12 None
-               OpBranchConditional %false %14 %15
-         %14 = OpLabel
-         %16 = OpIAdd %uint %uint_0 %uint_0
+               OpBranchConditional %false %13 %14
+         %13 = OpLabel
+         %15 = OpIAdd %uint %uint_0 %uint_0
                OpBranch %12
-         %15 = OpLabel
+         %14 = OpLabel
                OpReturn
          %12 = OpLabel
                OpBranch %9
           %9 = OpLabel
-         %17 = OpIAdd %uint %16 %16
+         %16 = OpIAdd %uint %15 %15
                OpReturn
                OpFunctionEnd
 )";
@@ -621,7 +623,7 @@
 %7 = OpTypeFunction %void
 %_ptr_Function_bool = OpTypePointer Function %bool
 %true = OpConstantTrue %bool
-%24 = OpUndef %uint
+%26 = OpUndef %uint
 %1 = OpFunction %void None %7
 %8 = OpLabel
 %19 = OpVariable %_ptr_Function_bool Function %false
@@ -640,24 +642,28 @@
 OpStore %19 %true
 OpBranch %12
 %12 = OpLabel
-%25 = OpPhi %uint %15 %13 %24 %14
+%27 = OpPhi %uint %15 %13 %26 %14
+%22 = OpLoad %bool %19
+OpBranchConditional %22 %9 %21
+%21 = OpLabel
 OpBranch %9
 %9 = OpLabel
-%26 = OpPhi %uint %25 %12 %24 %10
-%23 = OpLoad %bool %19
-OpSelectionMerge %22 None
-OpBranchConditional %23 %22 %21
-%21 = OpLabel
-%16 = OpIAdd %uint %26 %26
+%28 = OpPhi %uint %27 %21 %26 %10 %26 %12
+%25 = OpLoad %bool %19
+OpSelectionMerge %24 None
+OpBranchConditional %25 %24 %23
+%23 = OpLabel
+%16 = OpIAdd %uint %28 %28
 OpStore %19 %true
-OpBranch %22
-%22 = OpLabel
+OpBranch %24
+%24 = OpLabel
 OpBranch %17
 %17 = OpLabel
 OpReturn
 OpFunctionEnd
 )";
 
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   SinglePassRunAndCheck<MergeReturnPass>(before, after, false, true);
 }
 
@@ -666,8 +672,7 @@
 // work even if the order of the traversals change.
 TEST_F(MergeReturnPassTest, NestedSelectionMerge2) {
   const std::string before =
-      R"(
-               OpCapability Addresses
+      R"(      OpCapability Addresses
                OpCapability Shader
                OpCapability Linkage
                OpMemoryModel Logical GLSL450
@@ -686,16 +691,16 @@
                OpReturn
          %10 = OpLabel
                OpSelectionMerge %12 None
-               OpBranchConditional %false %14 %15
-         %14 = OpLabel
-         %16 = OpIAdd %uint %uint_0 %uint_0
+               OpBranchConditional %false %13 %14
+         %13 = OpLabel
+         %15 = OpIAdd %uint %uint_0 %uint_0
                OpBranch %12
-         %15 = OpLabel
+         %14 = OpLabel
                OpReturn
          %12 = OpLabel
                OpBranch %9
           %9 = OpLabel
-         %17 = OpIAdd %uint %16 %16
+         %16 = OpIAdd %uint %15 %15
                OpReturn
                OpFunctionEnd
 )";
@@ -714,7 +719,7 @@
 %7 = OpTypeFunction %void
 %_ptr_Function_bool = OpTypePointer Function %bool
 %true = OpConstantTrue %bool
-%24 = OpUndef %uint
+%26 = OpUndef %uint
 %1 = OpFunction %void None %7
 %8 = OpLabel
 %19 = OpVariable %_ptr_Function_bool Function %false
@@ -733,15 +738,18 @@
 OpStore %19 %true
 OpBranch %12
 %12 = OpLabel
-%25 = OpPhi %uint %15 %13 %24 %14
+%27 = OpPhi %uint %15 %13 %26 %14
+%25 = OpLoad %bool %19
+OpBranchConditional %25 %9 %24
+%24 = OpLabel
 OpBranch %9
 %9 = OpLabel
-%26 = OpPhi %uint %25 %12 %24 %11
+%28 = OpPhi %uint %27 %24 %26 %11 %26 %12
 %23 = OpLoad %bool %19
 OpSelectionMerge %22 None
 OpBranchConditional %23 %22 %21
 %21 = OpLabel
-%16 = OpIAdd %uint %26 %26
+%16 = OpIAdd %uint %28 %28
 OpStore %19 %true
 OpBranch %22
 %22 = OpLabel
@@ -756,8 +764,7 @@
 
 TEST_F(MergeReturnPassTest, NestedSelectionMerge3) {
   const std::string before =
-      R"(
-               OpCapability Addresses
+      R"(      OpCapability Addresses
                OpCapability Shader
                OpCapability Linkage
                OpMemoryModel Logical GLSL450
@@ -775,17 +782,17 @@
          %11 = OpLabel
                OpReturn
          %10 = OpLabel
-         %16 = OpIAdd %uint %uint_0 %uint_0
-               OpSelectionMerge %12 None
+         %12 = OpIAdd %uint %uint_0 %uint_0
+               OpSelectionMerge %13 None
                OpBranchConditional %false %14 %15
          %14 = OpLabel
-               OpBranch %12
+               OpBranch %13
          %15 = OpLabel
                OpReturn
-         %12 = OpLabel
+         %13 = OpLabel
                OpBranch %9
           %9 = OpLabel
-         %17 = OpIAdd %uint %16 %16
+         %16 = OpIAdd %uint %12 %12
                OpReturn
                OpFunctionEnd
 )";
@@ -804,7 +811,7 @@
 %7 = OpTypeFunction %void
 %_ptr_Function_bool = OpTypePointer Function %bool
 %true = OpConstantTrue %bool
-%24 = OpUndef %uint
+%26 = OpUndef %uint
 %1 = OpFunction %void None %7
 %8 = OpLabel
 %19 = OpVariable %_ptr_Function_bool Function %false
@@ -823,14 +830,17 @@
 OpStore %19 %true
 OpBranch %13
 %13 = OpLabel
+%25 = OpLoad %bool %19
+OpBranchConditional %25 %9 %24
+%24 = OpLabel
 OpBranch %9
 %9 = OpLabel
-%25 = OpPhi %uint %12 %13 %24 %11
+%27 = OpPhi %uint %12 %24 %26 %11 %26 %13
 %23 = OpLoad %bool %19
 OpSelectionMerge %22 None
 OpBranchConditional %23 %22 %21
 %21 = OpLabel
-%16 = OpIAdd %uint %25 %25
+%16 = OpIAdd %uint %27 %27
 OpStore %19 %true
 OpBranch %22
 %22 = OpLabel
@@ -843,6 +853,158 @@
   SinglePassRunAndCheck<MergeReturnPass>(before, after, false, true);
 }
 
+TEST_F(MergeReturnPassTest, NestedLoopMerge) {
+  const std::string before =
+      R"(               OpCapability SampledBuffer
+               OpCapability StorageImageExtendedFormats
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %2 "CS"
+               OpExecutionMode %2 LocalSize 8 8 1
+               OpSource HLSL 600
+               OpName %function "function"
+       %uint = OpTypeInt 32 0
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+     %v3uint = OpTypeVector %uint 3
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %_struct_13 = OpTypeStruct %v3uint %v3uint %v3uint %uint %uint %uint %uint %uint %uint
+          %2 = OpFunction %void None %6
+         %14 = OpLabel
+         %15 = OpFunctionCall %void %function
+               OpReturn
+               OpFunctionEnd
+   %function = OpFunction %void None %6
+         %16 = OpLabel
+         %17 = OpVariable %_ptr_Function_uint Function
+         %18 = OpVariable %_ptr_Function_uint Function
+               OpStore %17 %uint_0
+               OpBranch %19
+         %19 = OpLabel
+         %20 = OpLoad %uint %17
+         %21 = OpULessThan %bool %20 %uint_1
+               OpLoopMerge %22 %23 DontUnroll
+               OpBranchConditional %21 %24 %22
+         %24 = OpLabel
+               OpStore %18 %uint_1
+               OpBranch %25
+         %25 = OpLabel
+         %26 = OpLoad %uint %18
+         %27 = OpINotEqual %bool %26 %uint_0
+               OpLoopMerge %28 %29 DontUnroll
+               OpBranchConditional %27 %30 %28
+         %30 = OpLabel
+               OpSelectionMerge %31 None
+               OpBranchConditional %true %32 %31
+         %32 = OpLabel
+               OpReturn
+         %31 = OpLabel
+               OpStore %18 %uint_1
+               OpBranch %29
+         %29 = OpLabel
+               OpBranch %25
+         %28 = OpLabel
+               OpBranch %23
+         %23 = OpLabel
+         %33 = OpLoad %uint %17
+         %34 = OpIAdd %uint %33 %uint_1
+               OpStore %17 %34
+               OpBranch %19
+         %22 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability SampledBuffer
+OpCapability StorageImageExtendedFormats
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %2 "CS"
+OpExecutionMode %2 LocalSize 8 8 1
+OpSource HLSL 600
+OpName %function "function"
+%uint = OpTypeInt 32 0
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+%uint_0 = OpConstant %uint 0
+%uint_1 = OpConstant %uint 1
+%v3uint = OpTypeVector %uint 3
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_struct_13 = OpTypeStruct %v3uint %v3uint %v3uint %uint %uint %uint %uint %uint %uint
+%false = OpConstantFalse %bool
+%_ptr_Function_bool = OpTypePointer Function %bool
+%2 = OpFunction %void None %6
+%14 = OpLabel
+%15 = OpFunctionCall %void %function
+OpReturn
+OpFunctionEnd
+%function = OpFunction %void None %6
+%16 = OpLabel
+%38 = OpVariable %_ptr_Function_bool Function %false
+%17 = OpVariable %_ptr_Function_uint Function
+%18 = OpVariable %_ptr_Function_uint Function
+OpStore %17 %uint_0
+OpBranch %19
+%19 = OpLabel
+%20 = OpLoad %uint %17
+%21 = OpULessThan %bool %20 %uint_1
+OpLoopMerge %22 %23 DontUnroll
+OpBranchConditional %21 %24 %22
+%24 = OpLabel
+OpStore %18 %uint_1
+OpBranch %25
+%25 = OpLabel
+%26 = OpLoad %uint %18
+%27 = OpINotEqual %bool %26 %uint_0
+OpLoopMerge %28 %29 DontUnroll
+OpBranchConditional %27 %30 %28
+%30 = OpLabel
+OpSelectionMerge %31 None
+OpBranchConditional %true %32 %31
+%32 = OpLabel
+OpStore %38 %true
+OpBranch %28
+%31 = OpLabel
+OpStore %18 %uint_1
+OpBranch %29
+%29 = OpLabel
+OpBranch %25
+%28 = OpLabel
+%40 = OpLoad %bool %38
+OpBranchConditional %40 %22 %39
+%39 = OpLabel
+OpBranch %23
+%23 = OpLabel
+%33 = OpLoad %uint %17
+%34 = OpIAdd %uint %33 %uint_1
+OpStore %17 %34
+OpBranch %19
+%22 = OpLabel
+%43 = OpLoad %bool %38
+OpSelectionMerge %42 None
+OpBranchConditional %43 %42 %41
+%41 = OpLabel
+OpStore %38 %true
+OpBranch %42
+%42 = OpLabel
+OpBranch %35
+%35 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndCheck<MergeReturnPass>(before, after, false, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools