spirv-fuzz: Fix flatten conditional branch transformation (#3859)

Fixes #3850.
diff --git a/source/fuzz/transformation_flatten_conditional_branch.cpp b/source/fuzz/transformation_flatten_conditional_branch.cpp
index bd814b7..9dead67 100644
--- a/source/fuzz/transformation_flatten_conditional_branch.cpp
+++ b/source/fuzz/transformation_flatten_conditional_branch.cpp
@@ -280,6 +280,12 @@
   uint32_t first_block_last_branch_id =
       branch_instruction->GetSingleWordInOperand(branches[0]);
 
+  // Record the block that will be reached if the branch condition is true.
+  // This information is needed later to determine how to rewrite OpPhi
+  // instructions as OpSelect instructions at the branch's convergence point.
+  uint32_t branch_instruction_true_block_id =
+      branch_instruction->GetSingleWordInOperand(1);
+
   // The current header should unconditionally branch to the starting block in
   // the first branch to be laid out, if such a branch exists (i.e. the header
   // does not branch directly to the convergence block), and to the starting
@@ -333,16 +339,38 @@
     // with OpSelect.
 
     ir_context->get_instr_block(convergence_block_id)
-        ->ForEachPhiInst([&condition_operand](opt::Instruction* phi_inst) {
-          phi_inst->SetOpcode(SpvOpSelect);
+        ->ForEachPhiInst([branch_instruction_true_block_id, &condition_operand,
+                          header_block,
+                          ir_context](opt::Instruction* phi_inst) {
+          assert(phi_inst->NumInOperands() == 4 &&
+                 "We are going to replace an OpPhi with an OpSelect.  This "
+                 "only makes sense if the block has two distinct "
+                 "predecessors.");
+          // The OpPhi takes values from two distinct predecessors.  One
+          // predecessor is associated with the "true" path of the conditional
+          // we are flattening, the other with the "false" path, but these
+          // predecessors can appear in either order as operands to the OpPhi
+          // instruction.
+
           std::vector<opt::Operand> operands;
           operands.emplace_back(condition_operand);
-          // Only consider the operands referring to the instructions ids, as
-          // the block labels are not necessary anymore.
-          for (uint32_t i = 0; i < phi_inst->NumInOperands(); i += 2) {
-            operands.emplace_back(phi_inst->GetInOperand(i));
-          }
 
+          if (ir_context->GetDominatorAnalysis(header_block->GetParent())
+                  ->Dominates(branch_instruction_true_block_id,
+                              phi_inst->GetSingleWordInOperand(1))) {
+            // The "true" branch is handled first in the OpPhi's operands; we
+            // thus provide operands to OpSelect in the same order that they
+            // appear in the OpPhi.
+            operands.emplace_back(phi_inst->GetInOperand(0));
+            operands.emplace_back(phi_inst->GetInOperand(2));
+          } else {
+            // The "false" branch is handled first in the OpPhi's operands; we
+            // thus provide operands to OpSelect in reverse of the order that
+            // they appear in the OpPhi.
+            operands.emplace_back(phi_inst->GetInOperand(2));
+            operands.emplace_back(phi_inst->GetInOperand(0));
+          }
+          phi_inst->SetOpcode(SpvOpSelect);
           phi_inst->SetInOperands(std::move(operands));
         });
   }
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index f36ecbe..d77173d 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -791,6 +791,274 @@
   ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
 }
 
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect1) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpSelectionMerge %8 None
+               OpBranchConditional %5 %9 %8
+          %9 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %4 %5 %9 %10 %7
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpSelect %4 %5 %5 %10
+               OpReturn
+               OpFunctionEnd
+)";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect2) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpSelectionMerge %8 None
+               OpBranchConditional %5 %9 %8
+          %9 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %4 %10 %7 %5 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpSelect %4 %5 %5 %10
+               OpReturn
+               OpFunctionEnd
+)";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect3) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpSelectionMerge %8 None
+               OpBranchConditional %5 %9 %12
+          %9 = OpLabel
+               OpBranch %8
+         %12 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %4 %10 %12 %5 %9
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpBranch %12
+         %12 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpSelect %4 %5 %5 %10
+               OpReturn
+               OpFunctionEnd
+)";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, PhiToSelect4) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpSelectionMerge %8 None
+               OpBranchConditional %5 %9 %12
+          %9 = OpLabel
+               OpBranch %8
+         %12 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpPhi %4 %5 %9 %10 %12
+               OpReturn
+               OpFunctionEnd
+)";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_5;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  spvtools::ValidatorOptions validator_options;
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  auto transformation = TransformationFlattenConditionalBranch(7, true, {});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
+  ASSERT_TRUE(IsValid(env, context.get()));
+
+  std::string after_transformation = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+               OpSource ESSL 310
+          %3 = OpTypeVoid
+          %4 = OpTypeBool
+          %5 = OpConstantTrue %4
+         %10 = OpConstantFalse %4
+          %6 = OpTypeFunction %3
+          %2 = OpFunction %3 None %6
+          %7 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpBranch %12
+         %12 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+         %11 = OpSelect %4 %5 %5 %10
+               OpReturn
+               OpFunctionEnd
+)";
+  ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools