Handle back edges better in dead branch elim. (#2417)

* Handle back edges better in dead branch elim.

Loop header must have exactly one back edge.  Sometimes the branch
with the back edge can be folded.  However, it should not be folded
if it removes the back edge.

The code to check this simply avoids folding the branch in the
continue block.  That needs to be changed to not fold the back edge,
wherever it is.

At the same time, the branch can be folded if it folds to a branch to
the header, because the back edge will still exist.

Fixes #2391.
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
index 9893536..df36541 100644
--- a/source/opt/dead_branch_elim_pass.cpp
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -95,7 +95,7 @@
     Function* func, std::unordered_set<BasicBlock*>* live_blocks) {
   StructuredCFGAnalysis* cfgAnalysis = context()->GetStructuredCFGAnalysis();
 
-  std::unordered_set<BasicBlock*> continues;
+  std::unordered_set<BasicBlock*> blocks_with_backedge;
   std::vector<BasicBlock*> stack;
   stack.push_back(&*func->begin());
   bool modified = false;
@@ -107,7 +107,10 @@
     if (!live_blocks->insert(block).second) continue;
 
     uint32_t cont_id = block->ContinueBlockIdIfAny();
-    if (cont_id != 0) continues.insert(GetParentBlock(cont_id));
+    if (cont_id != 0) {
+      AddBlocksWithBackEdge(cont_id, block->id(), block->MergeBlockIdIfAny(),
+                            &blocks_with_backedge);
+    }
 
     Instruction* terminator = block->terminator();
     uint32_t live_lab_id = 0;
@@ -146,13 +149,23 @@
       }
     }
 
-    // Don't simplify branches of continue blocks. A path from the continue to
-    // the header is required.
-    // TODO(alan-baker): They can be simplified iff there remains a path to the
-    // backedge. Structured control flow should guarantee one path hits the
-    // backedge, but I've removed the requirement for structured control flow
-    // from this pass.
-    bool simplify = live_lab_id != 0 && !continues.count(block);
+    // Don't simplify back edges unless it becomes a branch to the header. Every
+    // loop must have exactly one back edge to the loop header, so we cannot
+    // remove it.
+    bool simplify = false;
+    if (live_lab_id != 0) {
+      if (!blocks_with_backedge.count(block)) {
+        // This is not a back edge.
+        simplify = true;
+      } else {
+        const auto& struct_cfg_analysis = context()->GetStructuredCFGAnalysis();
+        uint32_t header_id = struct_cfg_analysis->ContainingLoop(block->id());
+        if (live_lab_id == header_id) {
+          // The new branch will be a branch to the header.
+          simplify = true;
+        }
+      }
+    }
 
     if (simplify) {
       modified = true;
@@ -538,5 +551,39 @@
   return nullptr;
 }
 
+void DeadBranchElimPass::AddBlocksWithBackEdge(
+    uint32_t cont_id, uint32_t header_id, uint32_t merge_id,
+    std::unordered_set<BasicBlock*>* blocks_with_back_edges) {
+  std::unordered_set<uint32_t> visited;
+  visited.insert(cont_id);
+  visited.insert(header_id);
+  visited.insert(merge_id);
+
+  std::vector<uint32_t> work_list;
+  work_list.push_back(cont_id);
+
+  while (!work_list.empty()) {
+    uint32_t bb_id = work_list.back();
+    work_list.pop_back();
+
+    BasicBlock* bb = context()->get_instr_block(bb_id);
+
+    bool has_back_edge = false;
+    bb->ForEachSuccessorLabel([header_id, &visited, &work_list,
+                               &has_back_edge](uint32_t* succ_label_id) {
+      if (visited.insert(*succ_label_id).second) {
+        work_list.push_back(*succ_label_id);
+      }
+      if (*succ_label_id == header_id) {
+        has_back_edge = true;
+      }
+    });
+
+    if (has_back_edge) {
+      blocks_with_back_edges->insert(bb);
+    }
+  }
+}
+
 }  // namespace opt
 }  // namespace spvtools
diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h
index 4a1b6b4..2ced589 100644
--- a/source/opt/dead_branch_elim_pass.h
+++ b/source/opt/dead_branch_elim_pass.h
@@ -148,6 +148,14 @@
                                                uint32_t merge_block_id,
                                                uint32_t loop_merge_id,
                                                uint32_t loop_continue_id);
+
+  // Adds to |blocks_with_back_edges| all of the blocks on the path from the
+  // basic block |cont_id| to |header_id| and |merge_id|.  The intention is that
+  // |cond_id| is a the continue target of a loop, |header_id| is the header of
+  // the loop, and |merge_id| is the merge block of the loop.
+  void AddBlocksWithBackEdge(
+      uint32_t cont_id, uint32_t header_id, uint32_t merge_id,
+      std::unordered_set<BasicBlock*>* blocks_with_back_edges);
 };
 
 }  // namespace opt
diff --git a/test/opt/dead_branch_elim_test.cpp b/test/opt/dead_branch_elim_test.cpp
index 66e3bdd..3b9aab8 100644
--- a/test/opt/dead_branch_elim_test.cpp
+++ b/test/opt/dead_branch_elim_test.cpp
@@ -2868,6 +2868,83 @@
   SinglePassRunAndMatch<DeadBranchElimPass>(body, true);
 }
 
+TEST_F(DeadBranchElimTest, DontFoldBackedge) {
+  const std::string body =
+      R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%2 = OpFunction %void None %4
+%7 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpLoopMerge %9 %10 None
+OpBranch %11
+%11 = OpLabel
+%12 = OpUndef %bool
+OpSelectionMerge %10 None
+OpBranchConditional %12 %13 %10
+%13 = OpLabel
+OpBranch %9
+%10 = OpLabel
+OpBranch %14
+%14 = OpLabel
+OpBranchConditional %false %8 %9
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndCheck<DeadBranchElimPass>(body, body, true);
+}
+
+TEST_F(DeadBranchElimTest, FoldBackedgeToHeader) {
+  const std::string body =
+      R"(
+; CHECK: OpLabel
+; CHECK: [[header:%\w+]] = OpLabel
+; CHECK-NEXT: OpLoopMerge {{%\w+}} [[cont:%\w+]]
+; CHECK: [[cont]] = OpLabel
+; This branch may not be in the continue block, but must come after it.
+; CHECK: OpBranch [[header]]
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%2 = OpFunction %void None %4
+%7 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpLoopMerge %9 %10 None
+OpBranch %11
+%11 = OpLabel
+%12 = OpUndef %bool
+OpSelectionMerge %10 None
+OpBranchConditional %12 %13 %10
+%13 = OpLabel
+OpBranch %9
+%10 = OpLabel
+OpBranch %14
+%14 = OpLabel
+OpBranchConditional %true %8 %9
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<DeadBranchElimPass>(body, true);
+}
+
 // TODO(greg-lunarg): Add tests to verify handling of these cases:
 //
 //    More complex control flow