Fix invalid OpPhi generated by merge-return. (#2172)

* Fix invalid OpPhi generated by merge-return.

When we create a new phi node for a value say %10, we have to replace
all of the uses of %10 that are no longer dominated by the def of %10
by the result id of the new phi.  However, if the use is in a phi node,
it is possible that the bb contains the use is not dominated by either.
In this case, needs to be handled differently.

* Split loop headers before add a new branch to them.

In merge return, Phi node in loop header that are also merges for loop
do not get updated correctly.  Those cases do not fit in with our
current analysis.  Doing this will simplify the code by reducing the
number of cases that have to be handled.
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index d0434f2..a13540c 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -206,7 +206,11 @@
     RecordReturned(block);
     RecordReturnValue(block);
   }
+
   BasicBlock* target_block = context()->get_instr_block(target);
+  if (target_block->GetLoopMergeInst()) {
+    cfg()->SplitLoopHeader(target_block);
+  }
   UpdatePhiNodes(block, target_block);
 
   Instruction* return_inst = block->terminator();
@@ -241,8 +245,22 @@
   if (inst.result_id() != 0) {
     std::vector<Instruction*> users_to_update;
     context()->get_def_use_mgr()->ForEachUser(
-        &inst, [&users_to_update, &dom_tree, inst_bb, this](Instruction* user) {
-          BasicBlock* user_bb = context()->get_instr_block(user);
+        &inst,
+        [&users_to_update, &dom_tree, &inst, inst_bb, this](Instruction* user) {
+          BasicBlock* user_bb = nullptr;
+          if (user->opcode() != SpvOpPhi) {
+            user_bb = context()->get_instr_block(user);
+          } else {
+            // For OpPhi, the use should be considered to be in the predecessor.
+            for (uint32_t i = 0; i < user->NumInOperands(); i += 2) {
+              if (user->GetSingleWordInOperand(i) == inst.result_id()) {
+                uint32_t user_bb_id = user->GetSingleWordInOperand(i + 1);
+                user_bb = context()->get_instr_block(user_bb_id);
+                break;
+              }
+            }
+          }
+
           // If |user_bb| is nullptr, then |user| is not in the function.  It is
           // something like an OpName or decoration, which should not be
           // replaced with the result of the OpPhi.
@@ -362,6 +380,9 @@
       return false;
     }
   }
+  if (merge_block->GetLoopMergeInst()) {
+    cfg()->SplitLoopHeader(merge_block);
+  }
 
   // Leave the phi instructions behind.
   auto iter = block->begin();
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index 7f2c058..e49fec9 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1145,6 +1145,115 @@
   EXPECT_TRUE(messages.empty());
 }
 
+TEST_F(MergeReturnPassTest, StructuredControlFlowDontChangeEntryPhi) {
+  const std::string before =
+      R"(
+; CHECK: OpFunction %void
+; CHECK: OpLabel
+; CHECK: OpLabel
+; CHECK: [[pre_header:%\w+]] = OpLabel
+; CHECK: [[header:%\w+]] = OpLabel
+; CHECK-NEXT: OpPhi %bool {{%\w+}} [[pre_header]] [[iv:%\w+]] [[continue:%\w+]]
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]] [[continue]]
+; CHECK: [[continue]] = OpLabel
+; CHECK-NEXT: [[iv]] = Op
+; CHECK: [[merge]] = OpLabel
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %1 "main"
+       %void = OpTypeVoid
+       %bool = OpTypeBool
+          %4 = OpTypeFunction %void
+          %1 = OpFunction %void None %4
+          %5 = OpLabel
+          %6 = OpUndef %bool
+               OpBranch %7
+          %7 = OpLabel
+          %8 = OpPhi %bool %6 %5 %9 %10
+               OpLoopMerge %11 %10 None
+               OpBranch %12
+         %12 = OpLabel
+         %13 = OpUndef %bool
+               OpSelectionMerge %10 DontFlatten
+               OpBranchConditional %13 %10 %14
+         %14 = OpLabel
+               OpReturn
+         %10 = OpLabel
+          %9 = OpUndef %bool
+               OpBranchConditional %13 %7 %11
+         %11 = OpLabel
+               OpReturn
+               OpFunctionEnd
+
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(before, false);
+}
+
+TEST_F(MergeReturnPassTest, StructuredControlFlowPartialReplacePhi) {
+  const std::string before =
+      R"(
+; CHECK: OpFunction %void
+; CHECK: OpLabel
+; CHECK: OpLabel
+; CHECK: [[pre_header:%\w+]] = OpLabel
+; CHECK: [[header:%\w+]] = OpLabel
+; CHECK-NEXT: OpPhi
+; CHECK-NEXT: OpLoopMerge [[merge:%\w+]]
+; CHECK: OpLabel
+; CHECK: [[old_ret_block:%\w+]] = OpLabel
+; CHECK: [[bb:%\w+]] = OpLabel
+; CHECK-NEXT: [[val:%\w+]] = OpUndef %bool
+; CHECK: [[merge]] = OpLabel
+; CHECK-NEXT: [[phi1:%\w+]] = OpPhi %bool [[val]] [[bb]] {{%\w+}} [[old_ret_block]]
+; CHECK: OpBranchConditional {{%\w+}} {{%\w+}} [[bb2:%\w+]]
+; CHECK: [[bb2]] = OpLabel
+; CHECK: OpBranch [[header2:%\w+]]
+; CHECK: [[header2]] = OpLabel
+; CHECK-NEXT: [[phi2:%\w+]] = OpPhi %bool [[phi1]] [[continue2:%\w+]] [[phi1]] [[bb2]]
+; CHECK-NEXT: OpLoopMerge {{%\w+}} [[continue2]]
+; CHECK: [[continue2]] = OpLabel
+; CHECK-NEXT: OpBranch [[header2]]
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %1 "main"
+       %void = OpTypeVoid
+       %bool = OpTypeBool
+          %4 = OpTypeFunction %void
+          %1 = OpFunction %void None %4
+          %5 = OpLabel
+          %6 = OpUndef %bool
+               OpBranch %7
+          %7 = OpLabel
+          %8 = OpPhi %bool %6 %5 %9 %10
+               OpLoopMerge %11 %10 None
+               OpBranch %12
+         %12 = OpLabel
+         %13 = OpUndef %bool
+               OpSelectionMerge %10 DontFlatten
+               OpBranchConditional %13 %10 %14
+         %14 = OpLabel
+               OpReturn
+         %10 = OpLabel
+          %9 = OpUndef %bool
+               OpBranchConditional %13 %7 %11
+         %11 = OpLabel
+          %phi = OpPhi %bool %9 %10 %9 %cont
+               OpLoopMerge %ret %cont None
+               OpBranch %bb
+         %bb = OpLabel
+               OpBranchConditional %13 %ret %cont
+         %cont = OpLabel
+               OpBranch %11
+         %ret = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndMatch<MergeReturnPass>(before, false);
+}
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools