spirv-fuzz: Fix handling of OpPhi in FlattenConditionalBranch (#3916)

Fixes #3915.
diff --git a/source/fuzz/transformation_flatten_conditional_branch.cpp b/source/fuzz/transformation_flatten_conditional_branch.cpp
index 9dead67..7fc636a 100644
--- a/source/fuzz/transformation_flatten_conditional_branch.cpp
+++ b/source/fuzz/transformation_flatten_conditional_branch.cpp
@@ -145,16 +145,8 @@
         current_block->terminator()->GetSingleWordInOperand(0);
   }
 
-  // Get the mapping from instructions to fresh ids.
-  auto insts_to_info = GetInstructionsToWrapperInfo(ir_context);
-
   auto branch_instruction = header_block->terminator();
 
-  // Get a reference to the last block in the first branch that will be laid out
-  // (this depends on |message_.true_branch_first|). The last block is the block
-  // in the branch just before flow converges (it might not exist).
-  opt::BasicBlock* last_block_first_branch = nullptr;
-
   // branch = 1 corresponds to the true branch, branch = 2 corresponds to the
   // false branch. If the true branch is to be laid out first, we need to visit
   // the false branch first, because each branch is moved to right after the
@@ -166,6 +158,68 @@
     branches = {1, 2};
   }
 
+  // Get the ids of the starting blocks of the first and last branches to be
+  // laid out. The first branch is the true branch iff
+  // |message_.true_branch_first| is true.
+  uint32_t first_block_first_branch_id =
+      branch_instruction->GetSingleWordInOperand(branches[1]);
+  uint32_t first_block_last_branch_id =
+      branch_instruction->GetSingleWordInOperand(branches[0]);
+
+  // If the OpBranchConditional instruction in the header branches to the same
+  // block for both values of the condition, this is the convergence block (the
+  // flow does not actually diverge) and the OpPhi instructions in it are still
+  // valid, so we do not need to make any changes.
+  if (first_block_first_branch_id != first_block_last_branch_id) {
+    // Replace all of the current OpPhi instructions in the convergence block
+    // with OpSelect.
+    ir_context->get_instr_block(convergence_block_id)
+        ->ForEachPhiInst([branch_instruction, header_block,
+                          ir_context](opt::Instruction* phi_inst) {
+          assert(phi_inst->NumInOperands() == 4 &&
+                 "We are going to replace an OpPhi with an OpSelect.  This "
+                 "only makes sense if the block has two distinct "
+                 "predecessors.");
+          // The OpPhi takes values from two distinct predecessors.  One
+          // predecessor is associated with the "true" path of the conditional
+          // we are flattening, the other with the "false" path, but these
+          // predecessors can appear in either order as operands to the OpPhi
+          // instruction.
+
+          std::vector<opt::Operand> operands;
+          operands.emplace_back(branch_instruction->GetInOperand(0));
+
+          uint32_t branch_instruction_true_block_id =
+              branch_instruction->GetSingleWordInOperand(1);
+
+          if (ir_context->GetDominatorAnalysis(header_block->GetParent())
+                  ->Dominates(branch_instruction_true_block_id,
+                              phi_inst->GetSingleWordInOperand(1))) {
+            // The "true" branch is handled first in the OpPhi's operands; we
+            // thus provide operands to OpSelect in the same order that they
+            // appear in the OpPhi.
+            operands.emplace_back(phi_inst->GetInOperand(0));
+            operands.emplace_back(phi_inst->GetInOperand(2));
+          } else {
+            // The "false" branch is handled first in the OpPhi's operands; we
+            // thus provide operands to OpSelect in reverse of the order that
+            // they appear in the OpPhi.
+            operands.emplace_back(phi_inst->GetInOperand(2));
+            operands.emplace_back(phi_inst->GetInOperand(0));
+          }
+          phi_inst->SetOpcode(SpvOpSelect);
+          phi_inst->SetInOperands(std::move(operands));
+        });
+  }
+
+  // Get the mapping from instructions to fresh ids.
+  auto insts_to_info = GetInstructionsToWrapperInfo(ir_context);
+
+  // Get a reference to the last block in the first branch that will be laid out
+  // (this depends on |message_.true_branch_first|). The last block is the block
+  // in the branch just before flow converges (it might not exist).
+  opt::BasicBlock* last_block_first_branch = nullptr;
+
   // Keep track of blocks and ids for which we should later add dead block and
   // irrelevant id facts.  We wait until we have finished applying the
   // transformation before adding these facts, so that the fact manager has
@@ -271,21 +325,6 @@
     }
   }
 
-  // Get the condition operand and the ids of the starting blocks of the first
-  // and last branches to be laid out. The first branch is the true branch iff
-  // |message_.true_branch_first| is true.
-  auto condition_operand = branch_instruction->GetInOperand(0);
-  uint32_t first_block_first_branch_id =
-      branch_instruction->GetSingleWordInOperand(branches[1]);
-  uint32_t first_block_last_branch_id =
-      branch_instruction->GetSingleWordInOperand(branches[0]);
-
-  // Record the block that will be reached if the branch condition is true.
-  // This information is needed later to determine how to rewrite OpPhi
-  // instructions as OpSelect instructions at the branch's convergence point.
-  uint32_t branch_instruction_true_block_id =
-      branch_instruction->GetSingleWordInOperand(1);
-
   // The current header should unconditionally branch to the starting block in
   // the first branch to be laid out, if such a branch exists (i.e. the header
   // does not branch directly to the convergence block), and to the starting
@@ -330,51 +369,6 @@
     }
   }
 
-  // If the OpBranchConditional instruction in the header branches to the same
-  // block for both values of the condition, this is the convergence block (the
-  // flow does not actually diverge) and the OpPhi instructions in it are still
-  // valid, so we do not need to make any changes.
-  if (first_block_first_branch_id != first_block_last_branch_id) {
-    // Replace all of the current OpPhi instructions in the convergence block
-    // with OpSelect.
-
-    ir_context->get_instr_block(convergence_block_id)
-        ->ForEachPhiInst([branch_instruction_true_block_id, &condition_operand,
-                          header_block,
-                          ir_context](opt::Instruction* phi_inst) {
-          assert(phi_inst->NumInOperands() == 4 &&
-                 "We are going to replace an OpPhi with an OpSelect.  This "
-                 "only makes sense if the block has two distinct "
-                 "predecessors.");
-          // The OpPhi takes values from two distinct predecessors.  One
-          // predecessor is associated with the "true" path of the conditional
-          // we are flattening, the other with the "false" path, but these
-          // predecessors can appear in either order as operands to the OpPhi
-          // instruction.
-
-          std::vector<opt::Operand> operands;
-          operands.emplace_back(condition_operand);
-
-          if (ir_context->GetDominatorAnalysis(header_block->GetParent())
-                  ->Dominates(branch_instruction_true_block_id,
-                              phi_inst->GetSingleWordInOperand(1))) {
-            // The "true" branch is handled first in the OpPhi's operands; we
-            // thus provide operands to OpSelect in the same order that they
-            // appear in the OpPhi.
-            operands.emplace_back(phi_inst->GetInOperand(0));
-            operands.emplace_back(phi_inst->GetInOperand(2));
-          } else {
-            // The "false" branch is handled first in the OpPhi's operands; we
-            // thus provide operands to OpSelect in reverse of the order that
-            // they appear in the OpPhi.
-            operands.emplace_back(phi_inst->GetInOperand(2));
-            operands.emplace_back(phi_inst->GetInOperand(0));
-          }
-          phi_inst->SetOpcode(SpvOpSelect);
-          phi_inst->SetInOperands(std::move(operands));
-        });
-  }
-
   // Invalidate all analyses
   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
 
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index 579b696..1bb8ee8 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -1066,6 +1066,108 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect5) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+        %100 = OpTypePointer Function %4
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+        %101 = OpVariable %100 Function
+        %102 = OpVariable %100 Function
+               OpSelectionMerge %470 None
+               OpBranchConditional %5 %454 %462
+        %454 = OpLabel
+        %522 = OpLoad %4 %101
+               OpBranch %470
+        %462 = OpLabel
+        %466 = OpLoad %4 %102
+               OpBranch %470
+        %470 = OpLabel
+        %534 = OpPhi %4 %522 %454 %466 %462
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(
+      7, true,
+      {MakeSideEffectWrapperInfo(MakeInstructionDescriptor(522, SpvOpLoad, 0),
+                                 200, 201, 202, 203, 204, 5),
+       MakeSideEffectWrapperInfo(MakeInstructionDescriptor(466, SpvOpLoad, 0),
+                                 300, 301, 302, 303, 304, 5)});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+        %100 = OpTypePointer Function %4
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+        %101 = OpVariable %100 Function
+        %102 = OpVariable %100 Function
+               OpBranch %454
+        %454 = OpLabel
+               OpSelectionMerge %200 None
+               OpBranchConditional %5 %201 %203
+        %201 = OpLabel
+        %202 = OpLoad %4 %101
+               OpBranch %200
+        %203 = OpLabel
+        %204 = OpCopyObject %4 %5
+               OpBranch %200
+        %200 = OpLabel
+        %522 = OpPhi %4 %202 %201 %204 %203
+               OpBranch %462
+        %462 = OpLabel
+               OpSelectionMerge %300 None
+               OpBranchConditional %5 %303 %301
+        %301 = OpLabel
+        %302 = OpLoad %4 %102
+               OpBranch %300
+        %303 = OpLabel
+        %304 = OpCopyObject %4 %5
+               OpBranch %300
+        %300 = OpLabel
+        %466 = OpPhi %4 %302 %301 %304 %303
+               OpBranch %470
+        %470 = OpLabel
+        %534 = OpSelect %4 %5 %522 %466
+               OpReturn
+               OpFunctionEnd
+)";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 TEST(TransformationFlattenConditionalBranchTest,
      LoadFromBufferBlockDecoratedStruct) {
   std::string shader = R"(