Use dummy switch instead of dummy loop in MergeReturn pass. (#3151)

Fixes #3127
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index 18c49f5..bbac4bb 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -69,6 +69,32 @@
   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
 }
 
+void MergeReturnPass::GenerateState(BasicBlock* block) {
+  if (Instruction* mergeInst = block->GetMergeInst()) {
+    if (mergeInst->opcode() == SpvOpLoopMerge) {
+      // If new loop, break to this loop merge block
+      state_.emplace_back(mergeInst, mergeInst);
+    } else {
+      auto branchInst = mergeInst->NextNode();
+      if (branchInst->opcode() == SpvOpSwitch) {
+        // If switch inside of loop, break to innermost loop merge block.
+        // Otherwise need to break to this switch merge block.
+        auto lastMergeInst = state_.back().BreakMergeInst();
+        if (lastMergeInst && lastMergeInst->opcode() == SpvOpLoopMerge)
+          state_.emplace_back(lastMergeInst, mergeInst);
+        else
+          state_.emplace_back(mergeInst, mergeInst);
+      } else {
+        // If branch conditional inside loop, always break to innermost
+        // loop merge block. If branch conditional inside switch, break to
+        // innermost switch merge block.
+        auto lastMergeInst = state_.back().BreakMergeInst();
+        state_.emplace_back(lastMergeInst, mergeInst);
+      }
+    }
+  }
+}
+
 bool MergeReturnPass::ProcessStructured(
     Function* function, const std::vector<BasicBlock*>& return_blocks) {
   if (HasNontrivialUnreachableBlocks(function)) {
@@ -82,7 +108,7 @@
   }
 
   RecordImmediateDominators(function);
-  AddDummyLoopAroundFunction();
+  AddDummySwitchAroundFunction();
 
   std::list<BasicBlock*> order;
   cfg()->ComputeStructuredOrder(function, &*function->begin(), &order);
@@ -103,12 +129,8 @@
 
     ProcessStructuredBlock(block);
 
-    // Generate state for next block
-    if (Instruction* mergeInst = block->GetMergeInst()) {
-      Instruction* loopMergeInst = block->GetLoopMergeInst();
-      if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst();
-      state_.emplace_back(loopMergeInst, mergeInst);
-    }
+    // Generate state for next block if warranted
+    GenerateState(block);
   }
 
   state_.clear();
@@ -133,12 +155,8 @@
       }
     }
 
-    // Generate state for next block
-    if (Instruction* mergeInst = block->GetMergeInst()) {
-      Instruction* loopMergeInst = block->GetLoopMergeInst();
-      if (!loopMergeInst) loopMergeInst = state_.back().LoopMergeInst();
-      state_.emplace_back(loopMergeInst, mergeInst);
-    }
+    // Generate state for next block if warranted
+    GenerateState(block);
   }
 
   // We have not kept the dominator tree up-to-date.
@@ -202,8 +220,8 @@
 
   if (tail_opcode == SpvOpReturn || tail_opcode == SpvOpReturnValue ||
       tail_opcode == SpvOpUnreachable) {
-    assert(CurrentState().InLoop() && "Should be in the dummy loop.");
-    BranchToBlock(block, CurrentState().LoopMergeId());
+    assert(CurrentState().InBreakable() && "Should be in the dummy construct.");
+    BranchToBlock(block, CurrentState().BreakMergeId());
     return_blocks_.insert(block->id());
   }
 }
@@ -337,8 +355,8 @@
   std::unordered_set<BasicBlock*> seen;
   if (block->id() == state->CurrentMergeId()) {
     state++;
-  } else if (block->id() == state->LoopMergeId()) {
-    while (state->LoopMergeId() == block->id()) {
+  } else if (block->id() == state->BreakMergeId()) {
+    while (state->BreakMergeId() == block->id()) {
       state++;
     }
   }
@@ -346,15 +364,14 @@
   while (block != nullptr && block != final_return_block_) {
     if (!predicated->insert(block).second) break;
     // Skip structured subgraphs.
-    assert(state->InLoop() && "Should be in the dummy loop at the very least.");
-    Instruction* current_loop_merge_inst = state->LoopMergeInst();
-    uint32_t merge_block_id =
-        current_loop_merge_inst->GetSingleWordInOperand(0);
-    while (state->LoopMergeId() == merge_block_id) {
+    assert(state->InBreakable() &&
+           "Should be in the dummy construct at the very least.");
+    Instruction* break_merge_inst = state->BreakMergeInst();
+    uint32_t merge_block_id = break_merge_inst->GetSingleWordInOperand(0);
+    while (state->BreakMergeId() == merge_block_id) {
       state++;
     }
-    if (!BreakFromConstruct(block, predicated, order,
-                            current_loop_merge_inst)) {
+    if (!BreakFromConstruct(block, predicated, order, break_merge_inst)) {
       return false;
     }
     block = context()->get_instr_block(merge_block_id);
@@ -364,9 +381,7 @@
 
 bool MergeReturnPass::BreakFromConstruct(
     BasicBlock* block, std::unordered_set<BasicBlock*>* predicated,
-    std::list<BasicBlock*>* order, Instruction* loop_merge_inst) {
-  assert(loop_merge_inst->opcode() == SpvOpLoopMerge &&
-         "loop_merge_inst must be a loop merge instruction.");
+    std::list<BasicBlock*>* order, Instruction* break_merge_inst) {
   // Make sure the CFG is build here.  If we don't then it becomes very hard
   // to know which new blocks need to be updated.
   context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG);
@@ -388,7 +403,7 @@
     }
   }
 
-  uint32_t merge_block_id = loop_merge_inst->GetSingleWordInOperand(0);
+  uint32_t merge_block_id = break_merge_inst->GetSingleWordInOperand(0);
   BasicBlock* merge_block = context()->get_instr_block(merge_block_id);
   if (merge_block->GetLoopMergeInst()) {
     cfg()->SplitLoopHeader(merge_block);
@@ -416,9 +431,10 @@
 
   // If |block| was a continue target for a loop |old_body| is now the correct
   // continue target.
-  if (loop_merge_inst->GetSingleWordInOperand(1) == block->id()) {
-    loop_merge_inst->SetInOperand(1, {old_body->id()});
-    context()->UpdateDefUse(loop_merge_inst);
+  if (break_merge_inst->opcode() == SpvOpLoopMerge &&
+      break_merge_inst->GetSingleWordInOperand(1) == block->id()) {
+    break_merge_inst->SetInOperand(1, {old_body->id()});
+    context()->UpdateDefUse(break_merge_inst);
   }
 
   // Update |order| so old_block will be traversed.
@@ -430,8 +446,8 @@
   // 3. Update OpPhi instructions in |merge_block|.
   // 4. Update the CFG.
   //
-  // Sine we are branching to the merge block of the current construct, there is
-  // no need for an OpSelectionMerge.
+  // Since we are branching to the merge block of the current construct, there
+  // is no need for an OpSelectionMerge.
 
   InstructionBuilder builder(
       context(), block,
@@ -710,7 +726,7 @@
   list->insert(pos, new_element);
 }
 
-void MergeReturnPass::AddDummyLoopAroundFunction() {
+void MergeReturnPass::AddDummySwitchAroundFunction() {
   CreateReturnBlock();
   CreateReturn(final_return_block_);
 
@@ -718,7 +734,7 @@
     cfg()->RegisterBlock(final_return_block_);
   }
 
-  CreateDummyLoop(final_return_block_);
+  CreateDummySwitch(final_return_block_);
 }
 
 BasicBlock* MergeReturnPass::CreateContinueTarget(uint32_t header_label_id) {
@@ -753,14 +769,8 @@
   return new_block;
 }
 
-void MergeReturnPass::CreateDummyLoop(BasicBlock* merge_target) {
-  std::unique_ptr<Instruction> label(
-      new Instruction(context(), SpvOpLabel, 0u, TakeNextId(), {}));
-
-  // Create the new basic block
-  std::unique_ptr<BasicBlock> block(new BasicBlock(std::move(label)));
-
-  // Insert the new block before any code is run.  We have to split the entry
+void MergeReturnPass::CreateDummySwitch(BasicBlock* merge_target) {
+  // Insert the switch before any code is run.  We have to split the entry
   // block to make sure the OpVariable instructions remain in the entry block.
   BasicBlock* start_block = &*function_->begin();
   auto split_pos = start_block->begin();
@@ -771,38 +781,16 @@
   BasicBlock* old_block =
       start_block->SplitBasicBlock(context(), TakeNextId(), split_pos);
 
-  // The new block must be inserted after the entry block.  We cannot make the
-  // entry block the header for the dummy loop because it is not valid to have a
-  // branch to the entry block, and the continue target must branch back to the
-  // loop header.
-  auto pos = function_->begin();
-  pos++;
-  BasicBlock* header_block = &*pos.InsertBefore(std::move(block));
-  context()->AnalyzeDefUse(header_block->GetLabelInst());
-  header_block->SetParent(function_);
-
-  // We have to create the continue block before OpLoopMerge instruction.
-  // Otherwise the def-use manager will compalain that there is a use without a
-  // definition.
-  uint32_t continue_target = CreateContinueTarget(header_block->id())->id();
-
-  // Add the code the the header block.
+  // Add the switch to the end of the entry block.
   InstructionBuilder builder(
-      context(), header_block,
-      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-
-  builder.AddLoopMerge(merge_target->id(), continue_target);
-  builder.AddBranch(old_block->id());
-
-  // Fix up the entry block by adding a branch to the loop header.
-  InstructionBuilder builder2(
       context(), start_block,
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-  builder2.AddBranch(header_block->id());
+
+  builder.AddSwitch(builder.GetUintConstantId(0u), old_block->id(), {},
+                    merge_target->id());
 
   if (context()->AreAnalysesValid(IRContext::kAnalysisCFG)) {
     cfg()->RegisterBlock(old_block);
-    cfg()->RegisterBlock(header_block);
     cfg()->AddEdges(start_block);
   }
 }
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index f8edd27..fe85557 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -33,7 +33,8 @@
  * Structured control flow guarantees that the CFG will converge at a given
  * point (the merge block). Within structured control flow, all blocks must be
  * post-dominated by the merge block, except return blocks and break blocks.
- * A break block is a block that branches to the innermost loop's merge block.
+ * A break block is a block that branches to a containing construct's merge
+ * block.
  *
  * Beyond this, we further assume that all unreachable blocks have been
  * cleaned up.  This means that the only unreachable blocks are those necessary
@@ -46,13 +47,14 @@
  * with a branch. If current block is not within structured control flow, this
  * is the final return. This block should branch to the new return block (its
  * direct successor). If the current block is within structured control flow,
- * the branch destination should be the innermost loop's merge.  This loop will
- * always exist because a dummy loop is added around the entire function.
- * If the merge block produces any live values it will need to be predicated.
- * While the merge is nested in structured control flow, the predication path
- *should branch to the merge block of the inner-most loop it is contained in.
- *Once structured control flow has been exited, it will be at the merge of the
- *dummy loop, with will simply return.
+ * the branch destination should be the innermost construct's merge.  This
+ * merge will always exist because a dummy switch is added around the
+ * entire function. If the merge block produces any live values it will need to
+ * be predicated. While the merge is nested in structured control flow, the
+ * predication path should branch to the merge block of the inner-most loop
+ * (or switch if no loop) it is contained in. Once structured control flow has
+ * been exited, it will be at the merge of the dummy switch, which will simply
+ * return.
  *
  * In the final return block, the return value should be loaded and returned.
  * Memory promotion passes should be able to promote the newly introduced
@@ -71,7 +73,7 @@
  *         ||
  *         \/
  *
- *          0 (dummy loop header)
+ *          0 (dummy switch header)
  *          |
  *          1 (loop header)
  *         / \
@@ -81,11 +83,11 @@
  *        / \
  *        |  3 (original code in 3)
  *        \ /
- *   (ret) 4 (dummy loop merge)
+ *   (ret) 4 (dummy switch merge)
  *
  * In the above (simple) example, the return originally in |2| is passed through
- * the merge. That merge is predicated such that the old body of the block is
- * the else branch. The branch condition is based on the value of the "has
+ * the loop merge. That merge is predicated such that the old body of the block
+ * is the else branch. The branch condition is based on the value of the "has
  * returned" variable.
  *
  ******************************************************************************/
@@ -108,17 +110,17 @@
   }
 
  private:
-  // This class is used to store the a loop merge instruction and a selection
-  // merge instruction.  The intended use is that is represent the inner most
-  // contain selection construct and the inner most loop construct.
+  // This class is used to store the a break merge instruction and a current
+  // merge instruction.  The intended use is to keep track of the block to
+  // break to and the current innermost control flow construct merge block.
   class StructuredControlState {
    public:
-    StructuredControlState(Instruction* loop, Instruction* merge)
-        : loop_merge_(loop), current_merge_(merge) {}
+    StructuredControlState(Instruction* break_merge, Instruction* merge)
+        : break_merge_(break_merge), current_merge_(merge) {}
 
     StructuredControlState(const StructuredControlState&) = default;
 
-    bool InLoop() const { return loop_merge_; }
+    bool InBreakable() const { return break_merge_; }
     bool InStructuredFlow() const { return CurrentMergeId() != 0; }
 
     uint32_t CurrentMergeId() const {
@@ -132,20 +134,14 @@
                             : 0;
     }
 
-    uint32_t LoopMergeId() const {
-      return loop_merge_ ? loop_merge_->GetSingleWordInOperand(0u) : 0u;
+    uint32_t BreakMergeId() const {
+      return break_merge_ ? break_merge_->GetSingleWordInOperand(0u) : 0u;
     }
 
-    uint32_t CurrentLoopHeader() const {
-      return loop_merge_
-                 ? loop_merge_->context()->get_instr_block(loop_merge_)->id()
-                 : 0;
-    }
-
-    Instruction* LoopMergeInst() const { return loop_merge_; }
+    Instruction* BreakMergeInst() const { return break_merge_; }
 
    private:
-    Instruction* loop_merge_;
+    Instruction* break_merge_;
     Instruction* current_merge_;
   };
 
@@ -159,6 +155,9 @@
   void MergeReturnBlocks(Function* function,
                          const std::vector<BasicBlock*>& returnBlocks);
 
+  // Generate and push new control flow state if |block| contains a merge.
+  void GenerateState(BasicBlock* block);
+
   // Merges the return instruction in |function| so that it has a single return
   // statement.  It is assumed that |function| has structured control flow, and
   // that |return_blocks| is a list of all of the basic blocks in |function|
@@ -219,9 +218,9 @@
                        std::list<BasicBlock*>* order);
 
   // Add a conditional branch at the start of |block| that either jumps to
-  // the merge block of |loop_merge_inst| or the original code in |block|
+  // the merge block of |break_merge_inst| or the original code in |block|
   // depending on the value in |return_flag_|.  The continue target in
-  // |loop_merge_inst| will be updated if needed.
+  // |break_merge_inst| will be updated if needed.
   //
   // If new blocks that are created will be added to |order|.  This way a call
   // can traverse these new block in structured order.
@@ -230,7 +229,7 @@
   bool BreakFromConstruct(BasicBlock* block,
                           std::unordered_set<BasicBlock*>* predicated,
                           std::list<BasicBlock*>* order,
-                          Instruction* loop_merge_inst);
+                          Instruction* break_merge_inst);
 
   // Add an |OpReturn| or |OpReturnValue| to the end of |block|.  If an
   // |OpReturnValue| is needed, the return value is loaded from |return_value_|.
@@ -274,27 +273,28 @@
   void InsertAfterElement(BasicBlock* element, BasicBlock* new_element,
                           std::list<BasicBlock*>* list);
 
-  // Creates a single iteration loop around all of the exectuable code of the
-  // current function and returns after the loop is done. Sets
+  // Creates a single case switch around all of the exectuable code of the
+  // current function where the switch and case value are both zero and the
+  // default is the merge block. Returns after the switch is executed. Sets
   // |final_return_block_|.
-  void AddDummyLoopAroundFunction();
+  void AddDummySwitchAroundFunction();
 
   // Creates a new basic block that branches to |header_label_id|.  Returns the
   // new basic block.  The block will be the second last basic block in the
   // function.
   BasicBlock* CreateContinueTarget(uint32_t header_label_id);
 
-  // Creates a loop around the executable code of the function with
+  // Creates a one case switch around the executable code of the function with
   // |merge_target| as the merge node.
-  void CreateDummyLoop(BasicBlock* merge_target);
+  void CreateDummySwitch(BasicBlock* merge_target);
 
   // Returns true if |function| has an unreachable block that is not a continue
   // target that simply branches back to the header, or a merge block containing
   // 1 instruction which is OpUnreachable.
   bool HasNontrivialUnreachableBlocks(Function* function);
 
-  // A stack used to keep track of the innermost contain loop and selection
-  // constructs.
+  // A stack used to keep track of the break and current control flow construct
+  // merge blocks.
   std::vector<StructuredControlState> state_;
 
   // The current function being transformed.
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index 4118169..d16b65c 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -268,7 +268,7 @@
 ; CHECK: [[true:%\w+]] = OpConstantTrue
 ; CHECK: OpFunction
 ; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]]
-; CHECK: OpLoopMerge [[return_block:%\w+]]
+; CHECK: OpSelectionMerge [[return_block:%\w+]]
 ; CHECK: OpSelectionMerge [[merge_lab:%\w+]]
 ; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]]
 ; CHECK: [[if_lab]] = OpLabel
@@ -314,7 +314,7 @@
 ; CHECK: [[true:%\w+]] = OpConstantTrue
 ; CHECK: OpFunction
 ; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]]
-; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]]
+; CHECK: OpSelectionMerge [[dummy_loop_merge:%\w+]]
 ; CHECK: OpSelectionMerge [[merge_lab:%\w+]]
 ; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]]
 ; CHECK: [[if_lab]] = OpLabel
@@ -364,7 +364,7 @@
 ; CHECK: [[true:%\w+]] = OpConstantTrue
 ; CHECK: OpFunction
 ; CHECK: [[var:%\w+]] = OpVariable [[:%\w+]] Function [[false]]
-; CHECK: OpLoopMerge [[return_block:%\w+]]
+; CHECK: OpSelectionMerge [[return_block:%\w+]]
 ; CHECK: OpSelectionMerge [[merge_lab:%\w+]]
 ; CHECK: OpBranchConditional [[cond:%\w+]] [[if_lab:%\w+]] [[then_lab:%\w+]]
 ; CHECK: [[if_lab]] = OpLabel
@@ -411,7 +411,7 @@
   const std::string before =
       R"(
 ; CHECK: OpFunction
-; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]]
+; CHECK: OpSelectionMerge [[dummy_loop_merge:%\w+]]
 ; CHECK: OpLoopMerge [[loop_merge:%\w+]]
 ; CHECK: [[loop_merge]] = OpLabel
 ; CHECK: OpBranchConditional {{%\w+}} [[dummy_loop_merge]] [[old_code_path:%\w+]]
@@ -525,7 +525,7 @@
       R"(
 ; CHECK: OpFunction
 ; CHECK: [[ret_flag:%\w+]] = OpVariable %_ptr_Function_bool Function %false
-; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]]
+; CHECK: OpSelectionMerge [[dummy_loop_merge:%\w+]]
 ; CHECK: OpLoopMerge [[loop1_merge:%\w+]] {{%\w+}}
 ; CHECK-NEXT: OpBranchConditional {{%\w+}} [[if_lab:%\w+]] {{%\w+}}
 ; CHECK: [[if_lab]] = OpLabel
@@ -914,7 +914,7 @@
   const std::string test =
       R"(
 ; CHECK: OpFunction
-; CHECK: OpLoopMerge [[dummy_loop_merge:%\w+]]
+; CHECK: OpSelectionMerge [[dummy_loop_merge:%\w+]]
 ; CHECK: OpLoopMerge [[outer_loop_merge:%\w+]]
 ; CHECK: OpLoopMerge [[inner_loop_merge:%\w+]]
 ; CHECK: OpSelectionMerge
@@ -1150,7 +1150,6 @@
       R"(
 ; CHECK: OpFunction %void
 ; CHECK: OpLabel
-; CHECK: OpLabel
 ; CHECK: [[pre_header:%\w+]] = OpLabel
 ; CHECK: [[header:%\w+]] = OpLabel
 ; CHECK-NEXT: OpPhi %bool {{%\w+}} [[pre_header]] [[iv:%\w+]] [[continue:%\w+]]
@@ -1196,7 +1195,6 @@
       R"(
 ; CHECK: OpFunction %void
 ; CHECK: OpLabel
-; CHECK: OpLabel
 ; CHECK: [[pre_header:%\w+]] = OpLabel
 ; CHECK: [[header:%\w+]] = OpLabel
 ; CHECK-NEXT: OpPhi
@@ -1258,7 +1256,9 @@
 TEST_F(MergeReturnPassTest, GeneratePhiInOuterLoop) {
   const std::string before =
       R"(
-      ; CHECK: OpLoopMerge
+      ; CHECK: OpSelectionMerge
+      ; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+      ; CHECK-NEXT: [[def_bb1]] = OpLabel
       ; CHECK: OpLoopMerge [[merge:%\w+]] [[continue:%\w+]]
       ; CHECK: [[continue]] = OpLabel
       ; CHECK-NEXT: [[undef:%\w+]] = OpUndef
@@ -1322,7 +1322,9 @@
   // #2455: This test case triggers phi insertions that use previously inserted
   // phis. Without the fix, it fails to validate.
   const std::string spirv = R"(
-; CHECK: OpLoopMerge
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
 ; CHECK: OpLoopMerge
 ; CHECK: OpLoopMerge
 ; CHECK: OpLoopMerge [[merge:%\w+]]
@@ -1430,9 +1432,9 @@
 TEST_F(MergeReturnPassTest, InnerLoopMergeIsOuterLoopContinue) {
   const std::string before =
       R"(
-      ; CHECK: OpLoopMerge
-      ; CHECK-NEXT: OpBranch [[bb1:%\w+]]
-      ; CHECK: [[bb1]] = OpLabel
+      ; CHECK: OpSelectionMerge
+      ; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+      ; CHECK-NEXT: [[def_bb1]] = OpLabel
       ; CHECK-NEXT: OpBranch [[outer_loop_header:%\w+]]
       ; CHECK: [[outer_loop_header]] = OpLabel
       ; CHECK-NEXT: OpLoopMerge [[outer_loop_merge:%\w+]] [[outer_loop_continue:%\w+]] None
@@ -1481,7 +1483,9 @@
 TEST_F(MergeReturnPassTest, BreakFromLoopUseNoLongerDominated) {
   const std::string spirv = R"(
 ; CHECK: [[undef:%\w+]] = OpUndef
-; CHECK: OpLoopMerge
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
 ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
 ; CHECK-NEXT: OpBranch [[body:%\w+]]
 ; CHECK: [[body]] = OpLabel
@@ -1541,7 +1545,9 @@
 TEST_F(MergeReturnPassTest, TwoBreaksFromLoopUsesNoLongerDominated) {
   const std::string spirv = R"(
 ; CHECK: [[undef:%\w+]] = OpUndef
-; CHECK: OpLoopMerge
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
 ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]]
 ; CHECK-NEXT: OpBranch [[body:%\w+]]
 ; CHECK: [[body]] = OpLabel
@@ -1725,7 +1731,9 @@
   const std::string text =
       R"(
 ; CHECK: [[new_undef:%\w+]] = OpUndef %uint
-; CHECK: OpLoopMerge
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
 ; CHECK: OpLoopMerge [[merge1:%\w+]]
 ; CHECK: OpLoopMerge [[merge2:%\w+]]
 ; CHECK: [[merge1]] = OpLabel
@@ -1781,7 +1789,9 @@
   //  Add and use a phi in the second merge block from the return.
   const std::string text =
       R"(
-; CHECK: OpLoopMerge
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
 ; CHECK: OpLoopMerge [[merge_bb:%\w+]] [[continue_bb:%\w+]]
 ; CHECK: [[continue_bb]] = OpLabel
 ; CHECK-NEXT: [[val:%\w+]] = OpUndef %float
@@ -1831,6 +1841,91 @@
   SinglePassRunAndMatch<MergeReturnPass>(text, true);
 }
 
+TEST_F(MergeReturnPassTest, ReturnsInSwitch) {
+  //  Cannot branch directly to dummy switch merge block from original switch.
+  //  Must branch to merge block of original switch and then do predicated
+  //  branch to merge block of dummy switch.
+  const std::string text =
+      R"(
+; CHECK: OpSelectionMerge [[dummy_merge_bb:%\w+]]
+; CHECK-NEXT: OpSwitch {{%\w+}} [[def_bb1:%\w+]]
+; CHECK-NEXT: [[def_bb1]] = OpLabel
+; CHECK: OpSelectionMerge
+; CHECK-NEXT: OpSwitch {{%\w+}} [[inner_merge_bb:%\w+]] 0 {{%\w+}} 1 {{%\w+}}
+; CHECK: OpBranch [[inner_merge_bb]]
+; CHECK: OpBranch [[inner_merge_bb]]
+; CHECK-NEXT: [[inner_merge_bb]] = OpLabel
+; CHECK: OpBranchConditional {{%\w+}} [[dummy_merge_bb]] {{%\w+}}
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %PSMain "PSMain" %_entryPointOutput_color
+               OpExecutionMode %PSMain OriginUpperLeft
+               OpSource HLSL 500
+               OpMemberDecorate %cb 0 Offset 0
+               OpMemberDecorate %cb 1 Offset 16
+               OpMemberDecorate %cb 2 Offset 32
+               OpMemberDecorate %cb 3 Offset 48
+               OpDecorate %cb Block
+               OpDecorate %_ DescriptorSet 0
+               OpDecorate %_ Binding 0
+               OpDecorate %_entryPointOutput_color Location 0
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+          %8 = OpTypeFunction %v4float
+        %int = OpTypeInt 32 1
+         %cb = OpTypeStruct %v4float %v4float %v4float %int
+%_ptr_Uniform_cb = OpTypePointer Uniform %cb
+          %_ = OpVariable %_ptr_Uniform_cb Uniform
+      %int_3 = OpConstant %int 3
+%_ptr_Uniform_int = OpTypePointer Uniform %int
+      %int_0 = OpConstant %int 0
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+      %int_1 = OpConstant %int 1
+      %int_2 = OpConstant %int 2
+    %float_0 = OpConstant %float 0
+    %float_1 = OpConstant %float 1
+         %45 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_1
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%_entryPointOutput_color = OpVariable %_ptr_Output_v4float Output
+     %PSMain = OpFunction %void None %3
+          %5 = OpLabel
+         %50 = OpFunctionCall %v4float %BlendValue_
+               OpStore %_entryPointOutput_color %50
+               OpReturn
+               OpFunctionEnd
+%BlendValue_ = OpFunction %v4float None %8
+         %10 = OpLabel
+         %21 = OpAccessChain %_ptr_Uniform_int %_ %int_3
+         %22 = OpLoad %int %21
+               OpSelectionMerge %25 None
+               OpSwitch %22 %25 0 %23 1 %24
+         %23 = OpLabel
+         %28 = OpAccessChain %_ptr_Uniform_v4float %_ %int_0
+         %29 = OpLoad %v4float %28
+               OpReturnValue %29
+         %24 = OpLabel
+         %31 = OpAccessChain %_ptr_Uniform_v4float %_ %int_0
+         %32 = OpLoad %v4float %31
+         %34 = OpAccessChain %_ptr_Uniform_v4float %_ %int_1
+         %35 = OpLoad %v4float %34
+         %37 = OpAccessChain %_ptr_Uniform_v4float %_ %int_2
+         %38 = OpLoad %v4float %37
+         %39 = OpFMul %v4float %35 %38
+         %40 = OpFAdd %v4float %32 %39
+               OpReturnValue %40
+         %25 = OpLabel
+               OpReturnValue %45
+               OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(text, true);
+}
+
 TEST_F(MergeReturnPassTest, UnreachableMergeAndContinue) {
   // Make sure that the pass can handle a single block that is both a merge and
   // a continue.