Fix structured exit validation (#3141)

Fixes #3139

* If the header of the construct is also a merge block, jump to the
associated header instead of the immediate dominator
  * prevents spurious failures from unrelated constructs
* new tests

diff --git a/source/val/construct.cpp b/source/val/construct.cpp
index 1564449..733856c 100644
--- a/source/val/construct.cpp
+++ b/source/val/construct.cpp
@@ -169,9 +169,22 @@
       return true;
     }
 
+    // The next block in the traversal is either:
+    //  i.  The header block that declares |block| as its merge block.
+    //  ii. The immediate dominator of |block|.
+    auto NextBlock = [](const BasicBlock* block) -> const BasicBlock* {
+      for (auto& use : block->label()->uses()) {
+        if ((use.first->opcode() == SpvOpLoopMerge ||
+             use.first->opcode() == SpvOpSelectionMerge) &&
+            use.second == 1)
+          return use.first->block();
+      }
+      return block->immediate_dominator();
+    };
+
     bool seen_switch = false;
     auto header = entry_block();
-    auto block = header->immediate_dominator();
+    auto block = NextBlock(header);
     while (block) {
       auto terminator = block->terminator();
       auto index = terminator - &_.ordered_instructions()[0];
@@ -183,7 +196,7 @@
         auto merge_target = merge_inst->GetOperandAs<uint32_t>(0u);
         auto merge_block = merge_inst->function()->GetBlock(merge_target).first;
         if (merge_block->dominates(*header)) {
-          block = block->immediate_dominator();
+          block = NextBlock(block);
           continue;
         }
 
@@ -197,13 +210,15 @@
           }
         }
 
-        if (terminator->opcode() == SpvOpSwitch) seen_switch = true;
+        if (terminator->opcode() == SpvOpSwitch) {
+          seen_switch = true;
+        }
 
         // Hit an enclosing loop and didn't break or continue.
         if (merge_inst->opcode() == SpvOpLoopMerge) return false;
       }
 
-      block = block->immediate_dominator();
+      block = NextBlock(block);
     }
   }
 
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 43a6af7..f3019d1 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -728,10 +728,10 @@
     }
 
     Construct::ConstructBlockSet construct_blocks = construct.blocks(function);
+    std::string construct_name, header_name, exit_name;
+    std::tie(construct_name, header_name, exit_name) =
+        ConstructNames(construct.type());
     for (auto block : construct_blocks) {
-      std::string construct_name, header_name, exit_name;
-      std::tie(construct_name, header_name, exit_name) =
-          ConstructNames(construct.type());
       // Check that all exits from the construct are via structured exits.
       for (auto succ : *block->successors()) {
         if (block->reachable() && !construct_blocks.count(succ) &&
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 8d8de37..4cf4029 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -4187,6 +4187,115 @@
           "1[%loop], but its merge block 2[%continue] is not"));
 }
 
+TEST_F(ValidateCFG, ExitFromConstructWhoseHeaderIsAMerge) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%2 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%4 = OpUndef %int
+%bool = OpTypeBool
+%6 = OpUndef %bool
+%7 = OpFunction %void None %2
+%8 = OpLabel
+OpSelectionMerge %9 None
+OpSwitch %4 %10 0 %11
+%10 = OpLabel
+OpBranch %9
+%11 = OpLabel
+OpBranch %12
+%12 = OpLabel
+OpLoopMerge %13 %14 None
+OpBranch %15
+%15 = OpLabel
+OpSelectionMerge %16 None
+OpSwitch %4 %17 1 %18 2 %19
+%17 = OpLabel
+OpBranch %16
+%18 = OpLabel
+OpBranch %14
+%19 = OpLabel
+OpBranch %16
+%16 = OpLabel
+OpBranch %14
+%14 = OpLabel
+OpBranchConditional %6 %12 %13
+%13 = OpLabel
+OpSelectionMerge %20 None
+OpBranchConditional %6 %21 %20
+%21 = OpLabel
+OpBranch %9
+%20 = OpLabel
+OpBranch %10
+%9 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, ExitFromConstructWhoseHeaderIsAMerge2) {
+  const std::string text = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+          %6 = OpUndef %int
+       %bool = OpTypeBool
+          %8 = OpUndef %bool
+          %2 = OpFunction %void None %4
+          %9 = OpLabel
+               OpSelectionMerge %10 None
+               OpSwitch %6 %11 0 %12
+         %11 = OpLabel
+               OpBranch %10
+         %12 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+               OpLoopMerge %14 %15 None
+               OpBranch %16
+         %16 = OpLabel
+               OpSelectionMerge %17 None
+               OpSwitch %6 %18 1 %19 2 %20
+         %18 = OpLabel
+               OpBranch %17
+         %19 = OpLabel
+               OpBranch %15
+         %20 = OpLabel
+               OpBranch %17
+         %17 = OpLabel
+               OpBranch %15
+         %15 = OpLabel
+               OpBranchConditional %8 %13 %14
+         %14 = OpLabel
+               OpSelectionMerge %21 None
+               OpBranchConditional %8 %22 %21
+         %22 = OpLabel
+               OpSelectionMerge %23 None
+               OpBranchConditional %8 %24 %23
+         %24 = OpLabel
+               OpBranch %10
+         %23 = OpLabel
+               OpBranch %21
+         %21 = OpLabel
+               OpBranch %11
+         %10 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(text);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools