spirv-fuzz: Fix bugs in TransformationFlattenConditionalBranch (#4006)

Fixes #4005.
Fixes #3993.
diff --git a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
index c7c2933..1e21aa5 100644
--- a/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
+++ b/source/fuzz/fuzzer_pass_flatten_conditional_branches.cpp
@@ -71,7 +71,8 @@
       // Do not consider this header if the conditional cannot be flattened.
       if (!TransformationFlattenConditionalBranch::
               GetProblematicInstructionsIfConditionalCanBeFlattened(
-                  GetIRContext(), header, &instructions_that_need_ids)) {
+                  GetIRContext(), header, *GetTransformationContext(),
+                  &instructions_that_need_ids)) {
         continue;
       }
 
@@ -214,7 +215,7 @@
           }
         }
 
-        wrappers_info.emplace_back(wrapper_info);
+        wrappers_info.push_back(std::move(wrapper_info));
       }
 
       // Apply the transformation, evenly choosing whether to lay out the true
diff --git a/source/fuzz/transformation_flatten_conditional_branch.cpp b/source/fuzz/transformation_flatten_conditional_branch.cpp
index bad4972..dec933c 100644
--- a/source/fuzz/transformation_flatten_conditional_branch.cpp
+++ b/source/fuzz/transformation_flatten_conditional_branch.cpp
@@ -82,65 +82,6 @@
     }
   }
 
-  if (OpSelectArgumentsAreRestricted(ir_context)) {
-    // OpPhi instructions at the convergence block for the selection are handled
-    // by turning them into OpSelect instructions.  As the SPIR-V version in use
-    // has restrictions on the arguments that OpSelect can take, we must check
-    // that any OpPhi instructions are compatible with these restrictions.
-    uint32_t convergence_block_id =
-        FindConvergenceBlock(ir_context, *header_block);
-    // Consider every OpPhi instruction at the convergence block.
-    if (!ir_context->cfg()
-             ->block(convergence_block_id)
-             ->WhileEachPhiInst([this,
-                                 ir_context](opt::Instruction* inst) -> bool {
-               // Decide whether the OpPhi can be handled based on its result
-               // type.
-               opt::Instruction* phi_result_type =
-                   ir_context->get_def_use_mgr()->GetDef(inst->type_id());
-               switch (phi_result_type->opcode()) {
-                 case SpvOpTypeBool:
-                 case SpvOpTypeInt:
-                 case SpvOpTypeFloat:
-                 case SpvOpTypePointer:
-                   // Fine: OpSelect can work directly on scalar and pointer
-                   // types.
-                   return true;
-                 case SpvOpTypeVector: {
-                   // In its restricted form, OpSelect can only select between
-                   // vectors if the condition of the select is a boolean
-                   // boolean vector.  We thus require the appropriate boolean
-                   // vector type to be present.
-                   uint32_t bool_type_id =
-                       fuzzerutil::MaybeGetBoolType(ir_context);
-                   uint32_t dimension =
-                       phi_result_type->GetSingleWordInOperand(1);
-                   if (fuzzerutil::MaybeGetVectorType(ir_context, bool_type_id,
-                                                      dimension) == 0) {
-                     // The required boolean vector type is not present.
-                     return false;
-                   }
-                   // The transformation needs to be equipped with a fresh id
-                   // in which to store the vectorized version of the selection
-                   // construct's condition.
-                   switch (dimension) {
-                     case 2:
-                       return message_.fresh_id_for_bvec2_selector() != 0;
-                     case 3:
-                       return message_.fresh_id_for_bvec3_selector() != 0;
-                     default:
-                       assert(dimension == 4 && "Invalid vector dimension.");
-                       return message_.fresh_id_for_bvec4_selector() != 0;
-                   }
-                 }
-                 default:
-                   return false;
-               }
-             })) {
-      return false;
-    }
-  }
-
   // Use a set to keep track of the instructions that require fresh ids.
   std::set<opt::Instruction*> instructions_that_need_ids;
 
@@ -148,7 +89,8 @@
   // if so, add all the problematic instructions that need to be enclosed inside
   // conditionals to |instructions_that_need_ids|.
   if (!GetProblematicInstructionsIfConditionalCanBeFlattened(
-          ir_context, header_block, &instructions_that_need_ids)) {
+          ir_context, header_block, transformation_context,
+          &instructions_that_need_ids)) {
     return false;
   }
 
@@ -205,6 +147,69 @@
     }
   }
 
+  if (OpSelectArgumentsAreRestricted(ir_context)) {
+    // OpPhi instructions at the convergence block for the selection are handled
+    // by turning them into OpSelect instructions.  As the SPIR-V version in use
+    // has restrictions on the arguments that OpSelect can take, we must check
+    // that any OpPhi instructions are compatible with these restrictions.
+    uint32_t convergence_block_id =
+        FindConvergenceBlock(ir_context, *header_block);
+    // Consider every OpPhi instruction at the convergence block.
+    if (!ir_context->cfg()
+             ->block(convergence_block_id)
+             ->WhileEachPhiInst([this,
+                                 ir_context](opt::Instruction* inst) -> bool {
+               // Decide whether the OpPhi can be handled based on its result
+               // type.
+               opt::Instruction* phi_result_type =
+                   ir_context->get_def_use_mgr()->GetDef(inst->type_id());
+               switch (phi_result_type->opcode()) {
+                 case SpvOpTypeBool:
+                 case SpvOpTypeInt:
+                 case SpvOpTypeFloat:
+                 case SpvOpTypePointer:
+                   // Fine: OpSelect can work directly on scalar and pointer
+                   // types.
+                   return true;
+                 case SpvOpTypeVector: {
+                   // In its restricted form, OpSelect can only select between
+                   // vectors if the condition of the select is a boolean
+                   // boolean vector.  We thus require the appropriate boolean
+                   // vector type to be present.
+                   uint32_t bool_type_id =
+                       fuzzerutil::MaybeGetBoolType(ir_context);
+                   if (!bool_type_id) {
+                     return false;
+                   }
+
+                   uint32_t dimension =
+                       phi_result_type->GetSingleWordInOperand(1);
+                   if (fuzzerutil::MaybeGetVectorType(ir_context, bool_type_id,
+                                                      dimension) == 0) {
+                     // The required boolean vector type is not present.
+                     return false;
+                   }
+                   // The transformation needs to be equipped with a fresh id
+                   // in which to store the vectorized version of the selection
+                   // construct's condition.
+                   switch (dimension) {
+                     case 2:
+                       return message_.fresh_id_for_bvec2_selector() != 0;
+                     case 3:
+                       return message_.fresh_id_for_bvec3_selector() != 0;
+                     default:
+                       assert(dimension == 4 && "Invalid vector dimension.");
+                       return message_.fresh_id_for_bvec4_selector() != 0;
+                   }
+                 }
+                 default:
+                   return false;
+               }
+             })) {
+      return false;
+    }
+  }
+
   // All checks were passed.
   return true;
 }
@@ -428,6 +433,7 @@
 bool TransformationFlattenConditionalBranch::
     GetProblematicInstructionsIfConditionalCanBeFlattened(
         opt::IRContext* ir_context, opt::BasicBlock* header,
+        const TransformationContext& transformation_context,
         std::set<opt::Instruction*>* instructions_that_need_ids) {
   uint32_t merge_block_id = header->MergeBlockIdIfAny();
   assert(merge_block_id &&
@@ -441,6 +447,11 @@
   auto postdominator_analysis =
       ir_context->GetPostDominatorAnalysis(enclosing_function);
 
+  // |header| must be reachable.
+  if (!dominator_analysis->IsReachable(header)) {
+    return false;
+  }
+
   // Check that the header and the merge block describe a single-entry,
   // single-exit region.
   if (!dominator_analysis->Dominates(header->id(), merge_block_id) ||
@@ -454,13 +465,22 @@
   //  - they branch unconditionally to another block
   //  Add any side-effecting instruction, requiring fresh ids, to
   //  |instructions_that_need_ids|
-  std::list<uint32_t> to_check;
+  std::queue<uint32_t> to_check;
   header->ForEachSuccessorLabel(
-      [&to_check](uint32_t label) { to_check.push_back(label); });
+      [&to_check](uint32_t label) { to_check.push(label); });
 
+  auto* structured_cfg = ir_context->GetStructuredCFGAnalysis();
   while (!to_check.empty()) {
     uint32_t block_id = to_check.front();
-    to_check.pop_front();
+    to_check.pop();
+
+    if (structured_cfg->ContainingConstruct(block_id) != header->id() &&
+        block_id != merge_block_id) {
+      // This block can be reached from the |header| but doesn't belong to its
+      // selection construct. This might be a continue target of some loop -
+      // we can't flatten the |header|.
+      return false;
+    }
 
     // If the block post-dominates the header, this is where flow converges, and
     // we don't need to check this branch any further, because the
@@ -470,6 +490,15 @@
       continue;
     }
 
+    if (!transformation_context.GetFactManager()->BlockIsDead(header->id()) &&
+        transformation_context.GetFactManager()->BlockIsDead(block_id)) {
+      // The |header| is not dead but the |block_id| is. Since |block_id|
+      // doesn't postdominate the |header|, CFG hasn't converged yet. Thus, we
+      // don't flatten the construct to prevent |block_id| from becoming
+      // executable.
+      return false;
+    }
+
     auto block = ir_context->cfg()->block(block_id);
 
     // The block must not have a merge instruction, because inner constructs are
@@ -518,7 +547,7 @@
 
     // Add the successor of this block to the list of blocks that need to be
     // checked.
-    to_check.push_back(block->terminator()->GetSingleWordInOperand(0));
+    to_check.push(block->terminator()->GetSingleWordInOperand(0));
   }
 
   // All the blocks are compatible with the transformation and this is indeed a
@@ -564,7 +593,7 @@
     opt::Instruction* instruction,
     const protobufs::SideEffectWrapperInfo& wrapper_info, uint32_t condition_id,
     bool exec_if_cond_true, std::vector<uint32_t>* dead_blocks,
-    std::vector<uint32_t>* irrelevant_ids) const {
+    std::vector<uint32_t>* irrelevant_ids) {
   // Get the next instruction (it will be useful for splitting).
   auto next_instruction = instruction->NextNode();
 
@@ -810,7 +839,7 @@
 void TransformationFlattenConditionalBranch::AddBooleanVectorConstructorToBlock(
     uint32_t fresh_id, uint32_t dimension,
     const opt::Operand& branch_condition_operand, opt::IRContext* ir_context,
-    opt::BasicBlock* block) const {
+    opt::BasicBlock* block) {
   opt::Instruction::OperandList in_operands;
   for (uint32_t i = 0; i < dimension; i++) {
     in_operands.emplace_back(branch_condition_operand);
diff --git a/source/fuzz/transformation_flatten_conditional_branch.h b/source/fuzz/transformation_flatten_conditional_branch.h
index 2d5e8d7..e8cb414 100644
--- a/source/fuzz/transformation_flatten_conditional_branch.h
+++ b/source/fuzz/transformation_flatten_conditional_branch.h
@@ -72,6 +72,7 @@
   // instructions are OpSelectionMerge and OpBranchConditional.
   static bool GetProblematicInstructionsIfConditionalCanBeFlattened(
       opt::IRContext* ir_context, opt::BasicBlock* header,
+      const TransformationContext& transformation_context,
       std::set<opt::Instruction*>* instructions_that_need_ids);
 
   // Returns true iff the given instruction needs a placeholder to be enclosed
@@ -117,14 +118,14 @@
   // |dead_blocks| and |irrelevant_ids| are used to record the ids of blocks
   // and instructions for which dead block and irrelevant id facts should
   // ultimately be created.
-  opt::BasicBlock* EncloseInstructionInConditional(
+  static opt::BasicBlock* EncloseInstructionInConditional(
       opt::IRContext* ir_context,
       const TransformationContext& transformation_context,
       opt::BasicBlock* block, opt::Instruction* instruction,
       const protobufs::SideEffectWrapperInfo& wrapper_info,
       uint32_t condition_id, bool exec_if_cond_true,
       std::vector<uint32_t>* dead_blocks,
-      std::vector<uint32_t>* irrelevant_ids) const;
+      std::vector<uint32_t>* irrelevant_ids);
 
   // Turns every OpPhi instruction of |convergence_block| -- the convergence
   // block for |header_block| (both in |ir_context|) into an OpSelect
@@ -137,10 +138,10 @@
   // |ir_context|, with result id given by |fresh_id|.  The instruction will
   // make a |dimension|-dimensional boolean vector with
   // |branch_condition_operand| at every component.
-  void AddBooleanVectorConstructorToBlock(
+  static void AddBooleanVectorConstructorToBlock(
       uint32_t fresh_id, uint32_t dimension,
       const opt::Operand& branch_condition_operand, opt::IRContext* ir_context,
-      opt::BasicBlock* block) const;
+      opt::BasicBlock* block);
 
   // Returns true if the given instruction either has no side effects or it can
   // be handled by being enclosed in a conditional.
diff --git a/test/fuzz/transformation_flatten_conditional_branch_test.cpp b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
index 0aaf20e..e0697d4 100644
--- a/test/fuzz/transformation_flatten_conditional_branch_test.cpp
+++ b/test/fuzz/transformation_flatten_conditional_branch_test.cpp
@@ -2041,6 +2041,100 @@
   ASSERT_TRUE(IsEqual(env, expected, context.get()));
 }
 
+TEST(TransformationFlattenConditionalBranchTest, ContainsDeadBlocksTest) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpName %4 "main"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeBool
+          %7 = OpConstantFalse %6
+          %4 = OpFunction %2 None %3
+          %5 = OpLabel
+               OpSelectionMerge %9 None
+               OpBranchConditional %7 %8 %9
+          %8 = OpLabel
+         %10 = OpCopyObject %6 %7
+               OpBranch %9
+          %9 = OpLabel
+         %11 = OpPhi %6 %10 %8 %7 %5
+         %12 = OpPhi %6 %7 %5 %10 %8
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  spvtools::ValidatorOptions validator_options;
+  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+                                               kConsoleMessageConsumer));
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  TransformationFlattenConditionalBranch transformation(5, true, 0, 0, 0, {});
+  ASSERT_TRUE(
+      transformation.IsApplicable(context.get(), transformation_context));
+
+  transformation_context.GetFactManager()->AddFactBlockIsDead(8);
+
+  ASSERT_FALSE(
+      transformation.IsApplicable(context.get(), transformation_context));
+}
+
+TEST(TransformationFlattenConditionalBranchTest, ContainsContinueBlockTest) {
+  std::string shader = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %4 "main"
+               OpExecutionMode %4 OriginUpperLeft
+               OpSource ESSL 320
+               OpName %4 "main"
+          %2 = OpTypeVoid
+          %3 = OpTypeFunction %2
+          %6 = OpTypeBool
+          %7 = OpConstantFalse %6
+          %4 = OpFunction %2 None %3
+         %12 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+               OpLoopMerge %15 %14 None
+               OpBranchConditional %7 %5 %15
+          %5 = OpLabel
+               OpSelectionMerge %11 None
+               OpBranchConditional %7 %9 %10
+          %9 = OpLabel
+               OpBranch %11
+         %10 = OpLabel
+               OpBranch %14
+         %11 = OpLabel
+               OpBranch %14
+         %14 = OpLabel
+               OpBranch %13
+         %15 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  const auto env = SPV_ENV_UNIVERSAL_1_3;
+  const auto consumer = nullptr;
+  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+  spvtools::ValidatorOptions validator_options;
+  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+                                               kConsoleMessageConsumer));
+  TransformationContext transformation_context(
+      MakeUnique<FactManager>(context.get()), validator_options);
+
+  ASSERT_FALSE(TransformationFlattenConditionalBranch(5, true, 0, 0, 0, {})
+                   .IsApplicable(context.get(), transformation_context));
+}
+
 }  // namespace
 }  // namespace fuzz
 }  // namespace spvtools