Implement WebGPU specific CFG validation (#2386)

In WebGPU all blocks are required to be reachable, unless they are one of two
specific degenerate cases for merge-block or continue-target. This PR adds in
checking for these conditions.

Fixes #2068
diff --git a/source/val/function.cpp b/source/val/function.cpp
index f638fb5..90893b3 100644
--- a/source/val/function.cpp
+++ b/source/val/function.cpp
@@ -86,6 +86,12 @@
   continue_construct.set_corresponding_constructs({&loop_construct});
   loop_construct.set_corresponding_constructs({&continue_construct});
   merge_block_header_[&merge_block] = current_block_;
+  if (continue_target_headers_.find(&continue_target_block) ==
+      continue_target_headers_.end()) {
+    continue_target_headers_[&continue_target_block] = {current_block_};
+  } else {
+    continue_target_headers_[&continue_target_block].push_back(current_block_);
+  }
 
   return SPV_SUCCESS;
 }
diff --git a/source/val/function.h b/source/val/function.h
index a052bbd..9cda2ff 100644
--- a/source/val/function.h
+++ b/source/val/function.h
@@ -232,6 +232,28 @@
     return function_call_targets_;
   }
 
+  // Returns the block containing the OpSelectionMerge or OpLoopMerge that
+  // references |merge_block|.
+  // Values of |merge_block_header_| inserted by CFGPass, so do not call before
+  // the first iteration of ordered instructions in
+  // ValidateBinaryUsingContextAndValidationState has completed.
+  BasicBlock* GetMergeHeader(BasicBlock* merge_block) {
+    return merge_block_header_[merge_block];
+  }
+
+  // Returns vector of the blocks containing a OpLoopMerge that references
+  // |continue_target|.
+  // Values of |continue_target_headers_| inserted by CFGPass, so do not call
+  // before the first iteration of ordered instructions in
+  // ValidateBinaryUsingContextAndValidationState has completed.
+  std::vector<BasicBlock*> GetContinueHeaders(BasicBlock* continue_target) {
+    if (continue_target_headers_.find(continue_target) ==
+        continue_target_headers_.end()) {
+      return {};
+    }
+    return continue_target_headers_[continue_target];
+  }
+
  private:
   // Computes the representation of the augmented CFG.
   // Populates augmented_successors_map_ and augmented_predecessors_map_.
@@ -340,6 +362,10 @@
   /// This map provides the header block for a given merge block.
   std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_;
 
+  /// This map provides the header blocks for a given continue target.
+  std::unordered_map<BasicBlock*, std::vector<BasicBlock*>>
+      continue_target_headers_;
+
   /// Stores the control flow nesting depth of a given basic block
   std::unordered_map<BasicBlock*, int> block_depth_;
 
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index fe79dde..6e7acbd 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -29,6 +29,7 @@
 
 #include "source/cfa.h"
 #include "source/opcode.h"
+#include "source/spirv_target_env.h"
 #include "source/spirv_validator_options.h"
 #include "source/val/basic_block.h"
 #include "source/val/construct.h"
@@ -610,6 +611,120 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t PerformWebGPUCfgChecks(ValidationState_t& _, Function* function) {
+  for (auto& block : function->ordered_blocks()) {
+    if (block->reachable()) continue;
+    if (block->is_type(kBlockTypeMerge)) {
+      // 1. Find the referencing merge and confirm that it is reachable.
+      BasicBlock* merge_header = function->GetMergeHeader(block);
+      assert(merge_header != nullptr);
+      if (!merge_header->reachable()) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable merge-blocks must be referenced by "
+                  "a reachable merge instruction.";
+      }
+
+      // 2. Check that the only instructions are OpLabel and OpUnreachable.
+      auto* label_inst = block->label();
+      auto* terminator_inst = block->terminator();
+      assert(label_inst != nullptr);
+      assert(terminator_inst != nullptr);
+
+      if (terminator_inst->opcode() != SpvOpUnreachable) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable merge-blocks must terminate with "
+                  "OpUnreachable.";
+      }
+
+      auto label_idx = label_inst - &_.ordered_instructions()[0];
+      auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
+      if (label_idx + 1 != terminator_idx) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable merge-blocks must only contain an "
+                  "OpLabel and OpUnreachable instruction.";
+      }
+
+      // 3. Use label instruction to confirm there is no uses by branches.
+      for (auto use : label_inst->uses()) {
+        const auto* use_inst = use.first;
+        if (spvOpcodeIsBranch(use_inst->opcode())) {
+          return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+                 << "For WebGPU, unreachable merge-blocks cannot be the target "
+                    "of a branch.";
+        }
+      }
+    } else if (block->is_type(kBlockTypeContinue)) {
+      // 1. Find referencing loop and confirm that it is reachable.
+      std::vector<BasicBlock*> continue_headers =
+          function->GetContinueHeaders(block);
+      if (continue_headers.empty()) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable continue-target must be referenced "
+                  "by a loop instruction.";
+      }
+
+      std::vector<BasicBlock*> reachable_headers(continue_headers.size());
+      auto iter =
+          std::copy_if(continue_headers.begin(), continue_headers.end(),
+                       reachable_headers.begin(),
+                       [](BasicBlock* header) { return header->reachable(); });
+      reachable_headers.resize(std::distance(reachable_headers.begin(), iter));
+
+      if (reachable_headers.empty()) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable continue-target must be referenced "
+                  "by a reachable loop instruction.";
+      }
+
+      // 2. Check that the only instructions are OpLabel and OpBranch.
+      auto* label_inst = block->label();
+      auto* terminator_inst = block->terminator();
+      assert(label_inst != nullptr);
+      assert(terminator_inst != nullptr);
+
+      if (terminator_inst->opcode() != SpvOpBranch) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable continue-target must terminate with "
+                  "OpBranch.";
+      }
+
+      auto label_idx = label_inst - &_.ordered_instructions()[0];
+      auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
+      if (label_idx + 1 != terminator_idx) {
+        return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+               << "For WebGPU, unreachable continue-target must only contain "
+                  "an OpLabel and an OpBranch instruction.";
+      }
+
+      // 3. Use label instruction to confirm there is no uses by branches.
+      for (auto use : label_inst->uses()) {
+        const auto* use_inst = use.first;
+        if (spvOpcodeIsBranch(use_inst->opcode())) {
+          return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+                 << "For WebGPU, unreachable continue-target cannot be the "
+                    "target of a branch.";
+        }
+      }
+
+      // 4. Confirm that continue-target has a back edge to a reachable loop
+      //    header block.
+      auto branch_target = terminator_inst->GetOperandAs<uint32_t>(0);
+      for (auto* continue_header : reachable_headers) {
+        if (branch_target != continue_header->id()) {
+          return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+                 << "For WebGPU, unreachable continue-target must only have a "
+                    "back edge to a single reachable loop instruction.";
+        }
+      }
+    } else {
+      return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+             << "For WebGPU, all blocks must be reachable, unless they are "
+             << "degenerate cases of merge-block or continue-target.";
+    }
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t PerformCfgChecks(ValidationState_t& _) {
   for (auto& function : _.functions()) {
     // Check all referenced blocks are defined within a function
@@ -689,6 +804,13 @@
                    << _.getIdName(idom->id());
           }
         }
+
+        // For WebGPU check that all unreachable blocks are degenerate cases for
+        // merge-block or continue-target.
+        if (spvIsWebGPUEnv(_.context()->target_env)) {
+          spv_result_t result = PerformWebGPUCfgChecks(_, &function);
+          if (result != SPV_SUCCESS) return result;
+        }
       }
       // If we have structed control flow, check that no block has a control
       // flow nesting depth larger than the limit.
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 15913c8..a0ee79a 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -25,6 +25,7 @@
 #include "gmock/gmock.h"
 
 #include "source/diagnostic.h"
+#include "source/spirv_target_env.h"
 #include "source/val/validate.h"
 #include "test/test_fixture.h"
 #include "test/unit_spirv.h"
@@ -103,6 +104,12 @@
         }
         out << ss.str();
       } break;
+      case SpvOpLoopMerge: {
+        assert(successors_.size() == 2);
+        out << "OpLoopMerge %" + successors_[0].label_ + " %" +
+                   successors_[0].label_ + "None";
+      } break;
+
       case SpvOpReturn:
         assert(successors_.size() == 0);
         out << "OpReturn\n";
@@ -115,6 +122,10 @@
         assert(successors_.size() == 1);
         out << "OpBranch %" + successors_.front().label_;
         break;
+      case SpvOpKill:
+        assert(successors_.size() == 0);
+        out << "OpKill\n";
+        break;
       default:
         assert(1 == 0 && "Unhandled");
     }
@@ -144,13 +155,13 @@
   return lhs;
 }
 
-const char* header(SpvCapability cap) {
-  static const char* shader_header =
+const std::string& GetDefaultHeader(SpvCapability cap) {
+  static const std::string shader_header =
       "OpCapability Shader\n"
       "OpCapability Linkage\n"
       "OpMemoryModel Logical GLSL450\n";
 
-  static const char* kernel_header =
+  static const std::string kernel_header =
       "OpCapability Kernel\n"
       "OpCapability Linkage\n"
       "OpMemoryModel Logical OpenCL\n";
@@ -158,8 +169,17 @@
   return (cap == SpvCapabilityShader) ? shader_header : kernel_header;
 }
 
-const char* types_consts() {
-  static const char* types =
+const std::string& GetWebGPUHeader() {
+  static const std::string header =
+      "OpCapability Shader\n"
+      "OpCapability VulkanMemoryModelKHR\n"
+      "OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
+      "OpMemoryModel Logical VulkanKHR\n";
+  return header;
+}
+
+const std::string& types_consts() {
+  static const std::string types =
       "%voidt   = OpTypeVoid\n"
       "%boolt   = OpTypeBool\n"
       "%intt    = OpTypeInt 32 0\n"
@@ -167,7 +187,6 @@
       "%two     = OpConstant %intt 2\n"
       "%ptrt    = OpTypePointer Function %intt\n"
       "%funct   = OpTypeFunction %voidt\n";
-
   return types;
 }
 
@@ -270,7 +289,7 @@
     loop.SetBody("OpLoopMerge %merge %cont None\n");
   }
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("loop", "entry", "cont", "merge",
                             std::make_pair("func", "Main")) +
                     types_consts() +
@@ -293,7 +312,7 @@
 
   entry.SetBody("%var = OpVariable %ptrt Function\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps(std::make_pair("func", "Main")) + types_consts() +
                     " %func    = OpFunction %voidt None %funct\n";
   str += entry >> cont;
@@ -313,7 +332,7 @@
   // This operation should only be performed in the entry block
   cont.SetBody("%var = OpVariable %ptrt Function\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps(std::make_pair("func", "Main")) + types_consts() +
                     " %func    = OpFunction %voidt None %funct\n";
 
@@ -339,7 +358,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("loop", "merge", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -364,7 +383,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) branch.SetBody("OpSelectionMerge %merge None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("cont", "branch", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -396,9 +415,10 @@
   // cannot share the same merge
   if (is_shader) selection.SetBody("OpSelectionMerge %merge None\n");
 
-  std::string str =
-      header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("merge", std::make_pair("func", "Main")) +
+                    types_consts() +
+                    "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
   str += loop >> selection;
@@ -431,9 +451,10 @@
   // cannot share the same merge
   if (is_shader) loop.SetBody(" OpLoopMerge %merge %loop None\n");
 
-  std::string str =
-      header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("merge", std::make_pair("func", "Main")) +
+                    types_consts() +
+                    "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> selection;
   str += selection >> std::vector<Block>({merge, loop});
@@ -457,7 +478,7 @@
   Block entry("entry");
   Block bad("bad");
   Block end("end", SpvOpReturn);
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "bad", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -481,7 +502,7 @@
   Block bad("bad");
   Block end("end", SpvOpReturn);
   Block badvalue("undef");  // This referenes the OpUndef.
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "bad", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -507,7 +528,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   bad.SetBody(" OpLoopMerge %entry %exit None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "bad", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -535,7 +556,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   bad.SetBody("OpLoopMerge %merge %cont None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "bad", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -567,7 +588,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   bad.SetBody("OpSelectionMerge %merge None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "bad", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -602,9 +623,10 @@
   Block middle2("middle2");
   Block end2("end2", SpvOpReturn);
 
-  std::string str =
-      header(GetParam()) + nameOps("middle2", std::make_pair("func", "Main")) +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("middle2", std::make_pair("func", "Main")) +
+                    types_consts() +
+                    "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> middle;
   str += middle >> std::vector<Block>({end, middle2});
@@ -637,7 +659,7 @@
 
   if (is_shader) head.AppendBody("OpSelectionMerge %merge None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("head", "merge", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -672,7 +694,7 @@
 
   if (is_shader) head.AppendBody("OpSelectionMerge %head None\n");
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("head", "exit", std::make_pair("func", "Main")) +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -694,8 +716,10 @@
   }
 }
 
-TEST_P(ValidateCFG, UnreachableMerge) {
-  bool is_shader = GetParam() == SpvCapabilityShader;
+std::string GetUnreachableMergeNoMergeInst(SpvCapability cap,
+                                           spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
   Block entry("entry");
   Block branch("branch", SpvOpBranchConditional);
   Block t("t", SpvOpReturn);
@@ -703,13 +727,18 @@
   Block merge("merge", SpvOpReturn);
 
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
-  if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n");
+  if (!spvIsWebGPUEnv(env) && cap == SpvCapabilityShader)
+    branch.AppendBody("OpSelectionMerge %merge None\n");
 
-  std::string str = header(GetParam()) +
-                    nameOps("branch", "merge", std::make_pair("func", "Main")) +
-                    types_consts() +
-                    "%func    = OpFunction %voidt None %funct\n";
-
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+  str += types_consts() + "%func    = OpFunction %voidt None %funct\n";
   str += entry >> branch;
   str += branch >> std::vector<Block>({t, f});
   str += t;
@@ -717,12 +746,209 @@
   str += merge;
   str += "OpFunctionEnd\n";
 
-  CompileSuccessfully(str);
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeNoMergeInst) {
+  CompileSuccessfully(
+      GetUnreachableMergeNoMergeInst(GetParam(), SPV_ENV_UNIVERSAL_1_0));
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
-  bool is_shader = GetParam() == SpvCapabilityShader;
+TEST_F(ValidateCFG, WebGPUUnreachableMergeNoMergeInst) {
+  CompileSuccessfully(
+      GetUnreachableMergeNoMergeInst(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("For WebGPU, all blocks must be reachable"));
+}
+
+std::string GetUnreachableMergeTerminatedBy(SpvCapability cap,
+                                            spv_target_env env, SpvOp op) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+  Block merge("merge", op);
+
+  entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpSelectionMerge %merge None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeTerminatedByOpUnreachable) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpUnreachable));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, UnreachableMergeTerminatedByOpKill) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_UNIVERSAL_1_0, SpvOpKill));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateCFG, UnreachableMergeTerminatedByOpReturn) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpReturn));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpUnreachable) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpUnreachable));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpKill) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpKill));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("must terminate with OpUnreachable"));
+}
+
+TEST_P(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpReturn) {
+  CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpReturn));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("must terminate with OpUnreachable"));
+}
+
+std::string GetUnreachableContinueTerminatedBy(SpvCapability cap,
+                                               spv_target_env env, SpvOp op) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", op);
+
+  if (op == SpvOpBranch) target >> branch;
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({merge});
+  str += merge;
+  str += target;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpUnreachable) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpUnreachable));
+  if (GetParam() == SpvCapabilityShader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("targeted by 0 back-edge blocks"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_F(ValidateCFG, UnreachableContinueTerminatedBySpvOpKill) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_UNIVERSAL_1_0, SpvOpKill));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("targeted by 0 back-edge blocks"));
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpReturn) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpReturn));
+  if (GetParam() == SpvCapabilityShader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("targeted by 0 back-edge blocks"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpBranch) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpBranch));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpUnreachable) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpUnreachable));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("For WebGPU, unreachable continue-target must "
+                        "terminate with OpBranch.\n  %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpKill) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpKill));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("For WebGPU, unreachable continue-target must "
+                        "terminate with OpBranch.\n  %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpReturn) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpReturn));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("For WebGPU, unreachable continue-target must "
+                        "terminate with OpBranch.\n  %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpBranch) {
+  CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpBranch));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+std::string GetUnreachableMergeUnreachableMergeInst(SpvCapability cap,
+                                                    spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block body("body", SpvOpReturn);
   Block entry("entry");
   Block branch("branch", SpvOpBranchConditional);
   Block t("t", SpvOpReturn);
@@ -730,13 +956,134 @@
   Block merge("merge", SpvOpUnreachable);
 
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
-  if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n");
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpSelectionMerge %merge None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
 
-  std::string str = header(GetParam()) +
-                    nameOps("branch", "merge", std::make_pair("func", "Main")) +
-                    types_consts() +
-                    "%func    = OpFunction %voidt None %funct\n";
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += body;
+  str += merge;
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += "OpFunctionEnd\n";
 
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeUnreachableMergeInst) {
+  CompileSuccessfully(GetUnreachableMergeUnreachableMergeInst(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeUnreachableMergeInst) {
+  CompileSuccessfully(GetUnreachableMergeUnreachableMergeInst(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("must be referenced by a reachable merge instruction"));
+}
+
+std::string GetUnreachableContinueUnreachableLoopInst(SpvCapability cap,
+                                                      spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block body("body", SpvOpReturn);
+  Block entry("entry");
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", SpvOpBranch);
+
+  target >> branch;
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += body;
+  str += target;
+  str += merge;
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({merge});
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueUnreachableLoopInst) {
+  CompileSuccessfully(GetUnreachableContinueUnreachableLoopInst(
+      GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  if (GetParam() == SpvCapabilityShader) {
+    // Shader causes additional structured CFG checks that cause a failure.
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("Back-edges (1[%branch] -> 3[%target]) can only be "
+                          "formed between a block and a loop header."));
+
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueUnreachableLoopInst) {
+  CompileSuccessfully(GetUnreachableContinueUnreachableLoopInst(
+      SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("must be referenced by a reachable loop instruction"));
+}
+
+std::string GetUnreachableMergeWithComplexBody(SpvCapability cap,
+                                               spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+  Block merge("merge", SpvOpUnreachable);
+
+  entry.AppendBody(spvIsWebGPUEnv(env)
+                       ? "%dummy   = OpVariable %intptrt Function %two\n"
+                       : "%dummy   = OpVariable %intptrt Function\n");
+  entry.AppendBody("%cond    = OpSLessThan %boolt %one %two\n");
+  merge.AppendBody("OpStore %dummy %one\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpSelectionMerge %merge None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%intptrt = OpTypePointer Function %intt\n";
+  str += "%func    = OpFunction %voidt None %funct\n";
   str += entry >> branch;
   str += branch >> std::vector<Block>({t, f});
   str += t;
@@ -744,31 +1091,406 @@
   str += merge;
   str += "OpFunctionEnd\n";
 
-  CompileSuccessfully(str);
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithComplexBody) {
+  CompileSuccessfully(
+      GetUnreachableMergeWithComplexBody(GetParam(), SPV_ENV_UNIVERSAL_1_0));
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_P(ValidateCFG, UnreachableBlock) {
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithComplexBody) {
+  CompileSuccessfully(GetUnreachableMergeWithComplexBody(SpvCapabilityShader,
+                                                         SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("must only contain an OpLabel and OpUnreachable instruction"));
+}
+
+std::string GetUnreachableContinueWithComplexBody(SpvCapability cap,
+                                                  spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", SpvOpBranch);
+
+  target >> branch;
+
+  entry.AppendBody(spvIsWebGPUEnv(env)
+                       ? "%dummy   = OpVariable %intptrt Function %two\n"
+                       : "%dummy   = OpVariable %intptrt Function\n");
+  target.AppendBody("OpStore %dummy %one\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%intptrt = OpTypePointer Function %intt\n";
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({merge});
+  str += merge;
+  str += target;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueWithComplexBody) {
+  CompileSuccessfully(
+      GetUnreachableContinueWithComplexBody(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueWithComplexBody) {
+  CompileSuccessfully(GetUnreachableContinueWithComplexBody(SpvCapabilityShader,
+                                                            SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("must only contain an OpLabel and an OpBranch instruction"));
+}
+
+std::string GetUnreachableMergeWithBranchUse(SpvCapability cap,
+                                             spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpBranch);
+  Block f("f", SpvOpReturn);
+  Block merge("merge", SpvOpUnreachable);
+
+  entry.AppendBody("%cond    = OpSLessThan %boolt %one %two\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpSelectionMerge %merge None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({t, f});
+  str += t >> merge;
+  str += f;
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithBranchUse) {
+  CompileSuccessfully(
+      GetUnreachableMergeWithBranchUse(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithBranchUse) {
+  CompileSuccessfully(
+      GetUnreachableMergeWithBranchUse(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("cannot be the target of a branch."));
+}
+
+std::string GetUnreachableMergeWithMultipleUses(SpvCapability cap,
+                                                spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+  Block merge("merge", SpvOpUnreachable);
+  Block duplicate("duplicate", SpvOpBranchConditional);
+
+  entry.AppendBody("%cond    = OpSLessThan %boolt %one %two\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader) {
+    branch.AppendBody("OpSelectionMerge %merge None\n");
+    duplicate.AppendBody("OpSelectionMerge %merge None\n");
+  }
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({t, f});
+  str += duplicate >> std::vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithMultipleUses) {
+  CompileSuccessfully(
+      GetUnreachableMergeWithMultipleUses(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  if (GetParam() == SpvCapabilityShader) {
+    ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("is already a merge block for another header"));
+  } else {
+    ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+  }
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithMultipleUses) {
+  CompileSuccessfully(GetUnreachableMergeWithMultipleUses(SpvCapabilityShader,
+                                                          SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("is already a merge block for another header"));
+}
+
+std::string GetUnreachableContinueWithBranchUse(SpvCapability cap,
+                                                spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block foo("foo", SpvOpBranch);
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", SpvOpBranch);
+
+  foo >> target;
+  target >> branch;
+
+  entry.AppendBody(spvIsWebGPUEnv(env)
+                       ? "%dummy   = OpVariable %intptrt Function %two\n"
+                       : "%dummy   = OpVariable %intptrt Function\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader)
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%intptrt = OpTypePointer Function %intt\n";
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({merge});
+  str += merge;
+  str += target;
+  str += foo;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueWithBranchUse) {
+  CompileSuccessfully(
+      GetUnreachableContinueWithBranchUse(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueWithBranchUse) {
+  CompileSuccessfully(GetUnreachableContinueWithBranchUse(SpvCapabilityShader,
+                                                          SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("cannot be the target of a branch."));
+}
+
+std::string GetReachableMergeAndContinue(SpvCapability cap,
+                                         spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", SpvOpBranch);
+  Block body("body", SpvOpBranchConditional);
+  Block t("t", SpvOpBranch);
+  Block f("f", SpvOpBranch);
+
+  target >> branch;
+  body.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
+  t >> merge;
+  f >> target;
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader) {
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+    body.AppendBody("OpSelectionMerge %target None\n");
+  }
+
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", "body", "t", "f",
+                   std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({body});
+  str += body >> std::vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge;
+  str += target;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, ReachableMergeAndContinue) {
+  CompileSuccessfully(
+      GetReachableMergeAndContinue(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUReachableMergeAndContinue) {
+  CompileSuccessfully(
+      GetReachableMergeAndContinue(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+std::string GetUnreachableMergeAndContinue(SpvCapability cap,
+                                           spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+  Block entry("entry");
+  Block branch("branch", SpvOpBranch);
+  Block merge("merge", SpvOpReturn);
+  Block target("target", SpvOpBranch);
+  Block body("body", SpvOpBranchConditional);
+  Block t("t", SpvOpReturn);
+  Block f("f", SpvOpReturn);
+
+  target >> branch;
+  body.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (cap == SpvCapabilityShader) {
+    branch.AppendBody("OpLoopMerge %merge %target None\n");
+    body.AppendBody("OpSelectionMerge %target None\n");
+  }
+
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("branch", "merge", "target", "body", "t", "f",
+                   std::make_pair("func", "Main"));
+
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
+  str += entry >> branch;
+  str += branch >> std::vector<Block>({body});
+  str += body >> std::vector<Block>({t, f});
+  str += t;
+  str += f;
+  str += merge;
+  str += target;
+  str += "OpFunctionEnd\n";
+
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeAndContinue) {
+  CompileSuccessfully(
+      GetUnreachableMergeAndContinue(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeAndContinue) {
+  CompileSuccessfully(
+      GetUnreachableMergeAndContinue(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("unreachable merge-blocks must terminate with OpUnreachable"));
+}
+
+std::string GetUnreachableBlock(SpvCapability cap, spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
   Block entry("entry");
   Block unreachable("unreachable");
   Block exit("exit", SpvOpReturn);
 
-  std::string str =
-      header(GetParam()) +
-      nameOps("unreachable", "exit", std::make_pair("func", "Main")) +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
-
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("unreachable", "exit", std::make_pair("func", "Main"));
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
   str += entry >> exit;
   str += unreachable >> exit;
   str += exit;
   str += "OpFunctionEnd\n";
 
-  CompileSuccessfully(str);
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableBlock) {
+  CompileSuccessfully(GetUnreachableBlock(GetParam(), SPV_ENV_UNIVERSAL_1_0));
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_P(ValidateCFG, UnreachableBranch) {
-  bool is_shader = GetParam() == SpvCapabilityShader;
+TEST_F(ValidateCFG, WebGPUUnreachableBlock) {
+  CompileSuccessfully(
+      GetUnreachableBlock(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("all blocks must be reachable"));
+}
+
+std::string GetUnreachableBranch(SpvCapability cap, spv_target_env env) {
+  std::string header =
+      spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
   Block entry("entry");
   Block unreachable("unreachable", SpvOpBranchConditional);
   Block unreachablechildt("unreachablechildt");
@@ -777,11 +1499,19 @@
   Block exit("exit", SpvOpReturn);
 
   unreachable.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
-  if (is_shader) unreachable.AppendBody("OpSelectionMerge %merge None\n");
-  std::string str =
-      header(GetParam()) +
-      nameOps("unreachable", "exit", std::make_pair("func", "Main")) +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  if (cap == SpvCapabilityShader)
+    unreachable.AppendBody("OpSelectionMerge %merge None\n");
+
+  std::string str = header;
+  if (spvIsWebGPUEnv(env)) {
+    str +=
+        "OpEntryPoint Fragment %func \"func\"\n"
+        "OpExecutionMode %func OriginUpperLeft\n";
+  }
+  if (!spvIsWebGPUEnv(env))
+    str += nameOps("unreachable", "exit", std::make_pair("func", "Main"));
+  str += types_consts();
+  str += "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> exit;
   str +=
@@ -792,12 +1522,23 @@
   str += exit;
   str += "OpFunctionEnd\n";
 
-  CompileSuccessfully(str);
+  return str;
+}
+
+TEST_P(ValidateCFG, UnreachableBranch) {
+  CompileSuccessfully(GetUnreachableBranch(GetParam(), SPV_ENV_UNIVERSAL_1_0));
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateCFG, WebGPUUnreachableBranch) {
+  CompileSuccessfully(
+      GetUnreachableBranch(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(), HasSubstr("all blocks must be reachable"));
+}
+
 TEST_P(ValidateCFG, EmptyFunction) {
-  std::string str = header(GetParam()) + std::string(types_consts()) +
+  std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
                     R"(%func    = OpFunction %voidt None %funct
                   %l = OpLabel
                   OpReturn
@@ -816,7 +1557,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.AppendBody("OpLoopMerge %exit %loop None\n");
 
-  std::string str = header(GetParam()) + std::string(types_consts()) +
+  std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
@@ -845,8 +1586,8 @@
     loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n");
   }
 
-  std::string str = header(GetParam()) + nameOps("loop2", "loop2_merge") +
-                    types_consts() +
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("loop2", "loop2_merge") + types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop1;
@@ -885,7 +1626,7 @@
       if_blocks[i].SetBody("OpSelectionMerge %if_merge" + ss.str() + " None\n");
     merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch);
   }
-  std::string str = header(GetParam()) + std::string(types_consts()) +
+  std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> if_blocks[0];
@@ -920,7 +1661,7 @@
     loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n");
   }
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("loop1", "loop2", "be_block", "loop2_merge") +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
@@ -957,7 +1698,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) split.SetBody("OpSelectionMerge %exit None\n");
 
-  std::string str = header(GetParam()) + nameOps("split", "f") +
+  std::string str = GetDefaultHeader(GetParam()) + nameOps("split", "f") +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
@@ -990,7 +1731,8 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) split.SetBody("OpSelectionMerge %exit None\n");
 
-  std::string str = header(GetParam()) + nameOps("split") + types_consts() +
+  std::string str = GetDefaultHeader(GetParam()) + nameOps("split") +
+                    types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> split;
@@ -1022,8 +1764,8 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.SetBody("OpLoopMerge %merge %back0 None\n");
 
-  std::string str = header(GetParam()) + nameOps("loop", "back0", "back1") +
-                    types_consts() +
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("loop", "back0", "back1") + types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
@@ -1059,8 +1801,8 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.SetBody("OpLoopMerge %merge %cheader None\n");
 
-  std::string str = header(GetParam()) + nameOps("cheader", "be_block") +
-                    types_consts() +
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("cheader", "be_block") + types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
@@ -1094,7 +1836,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
 
-  std::string str = header(GetParam()) + nameOps("cont", "loop") +
+  std::string str = GetDefaultHeader(GetParam()) + nameOps("cont", "loop") +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
@@ -1129,7 +1871,7 @@
   entry.SetBody("%cond    = OpSLessThan %boolt %one %two\n");
   if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
 
-  std::string str = header(GetParam()) + nameOps("cont", "loop") +
+  std::string str = GetDefaultHeader(GetParam()) + nameOps("cont", "loop") +
                     types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
@@ -1280,7 +2022,7 @@
     inner_head.SetBody("OpSelectionMerge %inner_merge None\n");
   }
 
-  std::string str = header(GetParam()) +
+  std::string str = GetDefaultHeader(GetParam()) +
                     nameOps("entry", "inner_merge", "exit") + types_consts() +
                     "%func    = OpFunction %voidt None %funct\n";
 
@@ -1318,7 +2060,7 @@
   }
 
   std::string str =
-      header(GetParam()) +
+      GetDefaultHeader(GetParam()) +
       nameOps("entry", "loop", "if_head", "if_true", "if_merge", "merge") +
       types_consts() + "%func    = OpFunction %voidt None %funct\n";
 
@@ -1348,9 +2090,10 @@
     loop.SetBody("OpLoopMerge %merge %latch None\n");
   }
 
-  std::string str =
-      header(GetParam()) + nameOps("entry", "loop", "latch", "merge") +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("entry", "loop", "latch", "merge") +
+                    types_consts() +
+                    "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
   str += loop >> std::vector<Block>({latch, merge});
@@ -1381,9 +2124,10 @@
     loop.SetBody("OpLoopMerge %merge %loop None\n");
   }
 
-  std::string str =
-      header(GetParam()) + nameOps("entry", "loop", "latch", "merge") +
-      types_consts() + "%func    = OpFunction %voidt None %funct\n";
+  std::string str = GetDefaultHeader(GetParam()) +
+                    nameOps("entry", "loop", "latch", "merge") +
+                    types_consts() +
+                    "%func    = OpFunction %voidt None %funct\n";
 
   str += entry >> loop;
   str += loop >> latch;