Handle breaks from structured-ifs in DCE. (#1848)

* Handle breaks from structured-ifs in DCE.

dead code elimination assumes that are conditional branches except for
breaks and continues in loops will have an OpSelectionMerge before them.
That is not true when breaking out of a selection construct.

The fix is to look for breaks in selection constructs in the same place
we look for breaks and continues for loops.
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 39970a5..faf278a 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -245,11 +245,13 @@
 }
 
 void AggressiveDCEPass::AddBreaksAndContinuesToWorklist(
-    Instruction* loopMerge) {
-  BasicBlock* header = context()->get_instr_block(loopMerge);
+    Instruction* mergeInst) {
+  assert(mergeInst->opcode() == SpvOpSelectionMerge ||
+         mergeInst->opcode() == SpvOpLoopMerge);
+
+  BasicBlock* header = context()->get_instr_block(mergeInst);
   uint32_t headerIndex = structured_order_index_[header];
-  const uint32_t mergeId =
-      loopMerge->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx);
+  const uint32_t mergeId = mergeInst->GetSingleWordInOperand(0);
   BasicBlock* merge = context()->get_instr_block(mergeId);
   uint32_t mergeIndex = structured_order_index_[merge];
   get_def_use_mgr()->ForEachUser(
@@ -265,8 +267,14 @@
           if (userMerge != nullptr) AddToWorklist(userMerge);
         }
       });
+
+  if (mergeInst->opcode() != SpvOpLoopMerge) {
+    return;
+  }
+
+  // For loops we need to find the continues as well.
   const uint32_t contId =
-      loopMerge->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx);
+      mergeInst->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx);
   get_def_use_mgr()->ForEachUser(contId, [&contId, this](Instruction* user) {
     SpvOp op = user->opcode();
     if (op == SpvOpBranchConditional || op == SpvOpSwitch) {
@@ -373,7 +381,9 @@
         case SpvOpSwitch:
         case SpvOpBranch:
         case SpvOpBranchConditional: {
-          if (assume_branches_live.top()) AddToWorklist(&*ii);
+          if (assume_branches_live.top()) {
+            AddToWorklist(&*ii);
+          }
         } break;
         default: {
           // Function calls, atomics, function params, function returns, etc.
@@ -426,9 +436,7 @@
       AddToWorklist(branchInst);
       Instruction* mergeInst = branch2merge_[branchInst];
       AddToWorklist(mergeInst);
-      // If in a loop, mark all its break and continue instructions live
-      if (mergeInst->opcode() == SpvOpLoopMerge)
-        AddBreaksAndContinuesToWorklist(mergeInst);
+      AddBreaksAndContinuesToWorklist(mergeInst);
     }
     // If local load, add all variable's stores if variable not already live
     if (liveInst->opcode() == SpvOpLoad) {
@@ -445,7 +453,7 @@
       if (varId != 0) {
         ProcessLoad(varId);
       }
-    // If function call, treat as if it loads from all pointer arguments
+      // If function call, treat as if it loads from all pointer arguments
     } else if (liveInst->opcode() == SpvOpFunctionCall) {
       liveInst->ForEachInId([this](const uint32_t* iid) {
         // Skip non-ptr args
@@ -454,11 +462,11 @@
         (void)GetPtr(*iid, &varId);
         ProcessLoad(varId);
       });
-    // If function parameter, treat as if it's result id is loaded from
+      // If function parameter, treat as if it's result id is loaded from
     } else if (liveInst->opcode() == SpvOpFunctionParameter) {
       ProcessLoad(liveInst->result_id());
-    // We treat an OpImageTexelPointer as a load of the pointer, and
-    // that value is manipulated to get the result.
+      // We treat an OpImageTexelPointer as a load of the pointer, and
+      // that value is manipulated to get the result.
     } else if (liveInst->opcode() == SpvOpImageTexelPointer) {
       uint32_t varId;
       (void)GetPtr(liveInst, &varId);
diff --git a/source/opt/aggressive_dead_code_elim_pass.h b/source/opt/aggressive_dead_code_elim_pass.h
index 73457ce..3c03cc6 100644
--- a/source/opt/aggressive_dead_code_elim_pass.h
+++ b/source/opt/aggressive_dead_code_elim_pass.h
@@ -114,7 +114,7 @@
   // Add branch to |labelId| to end of block |bp|.
   void AddBranch(uint32_t labelId, BasicBlock* bp);
 
-  // Add all break and continue branches in the loop associated with
+  // Add all break and continue branches in the construct associated with
   // |mergeInst| to worklist if not already live
   void AddBreaksAndContinuesToWorklist(Instruction* mergeInst);
 
diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp
index 873fc34..287fcef 100644
--- a/test/opt/aggressive_dead_code_elim_test.cpp
+++ b/test/opt/aggressive_dead_code_elim_test.cpp
@@ -5767,6 +5767,51 @@
   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
   SinglePassRunAndCheck<AggressiveDCEPass>(test, result, true, true);
 }
+
+TEST_F(AggressiveDCETest, StructuredIfWithConditionalExit) {
+  // We are able to remove "local2" because it is not loaded, but have to keep
+  // the stores to "local1".
+  const std::string test =
+      R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpSourceExtension "GL_GOOGLE_cpp_style_line_directive"
+OpSourceExtension "GL_GOOGLE_include_directive"
+OpName %main "main"
+OpName %a "a"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Uniform_int = OpTypePointer Uniform %int
+%int_0 = OpConstant %int 0
+%bool = OpTypeBool
+%int_100 = OpConstant %int 100
+%int_1 = OpConstant %int 1
+%a = OpVariable %_ptr_Uniform_int Uniform
+%main = OpFunction %void None %5
+%12 = OpLabel
+%13 = OpLoad %int %a
+%14 = OpSGreaterThan %bool %13 %int_0
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+%17 = OpLoad %int %a
+%18 = OpSLessThan %bool %17 %int_100
+OpBranchConditional %18 %19 %15
+%19 = OpLabel
+OpStore %a %int_1
+OpBranch %15
+%15 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<AggressiveDCEPass>(test, test, true, true);
+}
 // TODO(greg-lunarg): Add tests to verify handling of these cases:
 //
 //    Check that logical addressing required