Implement WebGPU specific CFG validation (#2386)
In WebGPU all blocks are required to be reachable, unless they are one of two
specific degenerate cases for merge-block or continue-target. This PR adds in
checking for these conditions.
Fixes #2068
diff --git a/source/val/function.cpp b/source/val/function.cpp
index f638fb5..90893b3 100644
--- a/source/val/function.cpp
+++ b/source/val/function.cpp
@@ -86,6 +86,12 @@
continue_construct.set_corresponding_constructs({&loop_construct});
loop_construct.set_corresponding_constructs({&continue_construct});
merge_block_header_[&merge_block] = current_block_;
+ if (continue_target_headers_.find(&continue_target_block) ==
+ continue_target_headers_.end()) {
+ continue_target_headers_[&continue_target_block] = {current_block_};
+ } else {
+ continue_target_headers_[&continue_target_block].push_back(current_block_);
+ }
return SPV_SUCCESS;
}
diff --git a/source/val/function.h b/source/val/function.h
index a052bbd..9cda2ff 100644
--- a/source/val/function.h
+++ b/source/val/function.h
@@ -232,6 +232,28 @@
return function_call_targets_;
}
+ // Returns the block containing the OpSelectionMerge or OpLoopMerge that
+ // references |merge_block|.
+ // Values of |merge_block_header_| inserted by CFGPass, so do not call before
+ // the first iteration of ordered instructions in
+ // ValidateBinaryUsingContextAndValidationState has completed.
+ BasicBlock* GetMergeHeader(BasicBlock* merge_block) {
+ return merge_block_header_[merge_block];
+ }
+
+ // Returns vector of the blocks containing a OpLoopMerge that references
+ // |continue_target|.
+ // Values of |continue_target_headers_| inserted by CFGPass, so do not call
+ // before the first iteration of ordered instructions in
+ // ValidateBinaryUsingContextAndValidationState has completed.
+ std::vector<BasicBlock*> GetContinueHeaders(BasicBlock* continue_target) {
+ if (continue_target_headers_.find(continue_target) ==
+ continue_target_headers_.end()) {
+ return {};
+ }
+ return continue_target_headers_[continue_target];
+ }
+
private:
// Computes the representation of the augmented CFG.
// Populates augmented_successors_map_ and augmented_predecessors_map_.
@@ -340,6 +362,10 @@
/// This map provides the header block for a given merge block.
std::unordered_map<BasicBlock*, BasicBlock*> merge_block_header_;
+ /// This map provides the header blocks for a given continue target.
+ std::unordered_map<BasicBlock*, std::vector<BasicBlock*>>
+ continue_target_headers_;
+
/// Stores the control flow nesting depth of a given basic block
std::unordered_map<BasicBlock*, int> block_depth_;
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index fe79dde..6e7acbd 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -29,6 +29,7 @@
#include "source/cfa.h"
#include "source/opcode.h"
+#include "source/spirv_target_env.h"
#include "source/spirv_validator_options.h"
#include "source/val/basic_block.h"
#include "source/val/construct.h"
@@ -610,6 +611,120 @@
return SPV_SUCCESS;
}
+spv_result_t PerformWebGPUCfgChecks(ValidationState_t& _, Function* function) {
+ for (auto& block : function->ordered_blocks()) {
+ if (block->reachable()) continue;
+ if (block->is_type(kBlockTypeMerge)) {
+ // 1. Find the referencing merge and confirm that it is reachable.
+ BasicBlock* merge_header = function->GetMergeHeader(block);
+ assert(merge_header != nullptr);
+ if (!merge_header->reachable()) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable merge-blocks must be referenced by "
+ "a reachable merge instruction.";
+ }
+
+ // 2. Check that the only instructions are OpLabel and OpUnreachable.
+ auto* label_inst = block->label();
+ auto* terminator_inst = block->terminator();
+ assert(label_inst != nullptr);
+ assert(terminator_inst != nullptr);
+
+ if (terminator_inst->opcode() != SpvOpUnreachable) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable merge-blocks must terminate with "
+ "OpUnreachable.";
+ }
+
+ auto label_idx = label_inst - &_.ordered_instructions()[0];
+ auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
+ if (label_idx + 1 != terminator_idx) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable merge-blocks must only contain an "
+ "OpLabel and OpUnreachable instruction.";
+ }
+
+ // 3. Use label instruction to confirm there is no uses by branches.
+ for (auto use : label_inst->uses()) {
+ const auto* use_inst = use.first;
+ if (spvOpcodeIsBranch(use_inst->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable merge-blocks cannot be the target "
+ "of a branch.";
+ }
+ }
+ } else if (block->is_type(kBlockTypeContinue)) {
+ // 1. Find referencing loop and confirm that it is reachable.
+ std::vector<BasicBlock*> continue_headers =
+ function->GetContinueHeaders(block);
+ if (continue_headers.empty()) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target must be referenced "
+ "by a loop instruction.";
+ }
+
+ std::vector<BasicBlock*> reachable_headers(continue_headers.size());
+ auto iter =
+ std::copy_if(continue_headers.begin(), continue_headers.end(),
+ reachable_headers.begin(),
+ [](BasicBlock* header) { return header->reachable(); });
+ reachable_headers.resize(std::distance(reachable_headers.begin(), iter));
+
+ if (reachable_headers.empty()) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target must be referenced "
+ "by a reachable loop instruction.";
+ }
+
+ // 2. Check that the only instructions are OpLabel and OpBranch.
+ auto* label_inst = block->label();
+ auto* terminator_inst = block->terminator();
+ assert(label_inst != nullptr);
+ assert(terminator_inst != nullptr);
+
+ if (terminator_inst->opcode() != SpvOpBranch) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target must terminate with "
+ "OpBranch.";
+ }
+
+ auto label_idx = label_inst - &_.ordered_instructions()[0];
+ auto terminator_idx = terminator_inst - &_.ordered_instructions()[0];
+ if (label_idx + 1 != terminator_idx) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target must only contain "
+ "an OpLabel and an OpBranch instruction.";
+ }
+
+ // 3. Use label instruction to confirm there is no uses by branches.
+ for (auto use : label_inst->uses()) {
+ const auto* use_inst = use.first;
+ if (spvOpcodeIsBranch(use_inst->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target cannot be the "
+ "target of a branch.";
+ }
+ }
+
+ // 4. Confirm that continue-target has a back edge to a reachable loop
+ // header block.
+ auto branch_target = terminator_inst->GetOperandAs<uint32_t>(0);
+ for (auto* continue_header : reachable_headers) {
+ if (branch_target != continue_header->id()) {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, unreachable continue-target must only have a "
+ "back edge to a single reachable loop instruction.";
+ }
+ }
+ } else {
+ return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
+ << "For WebGPU, all blocks must be reachable, unless they are "
+ << "degenerate cases of merge-block or continue-target.";
+ }
+ }
+ return SPV_SUCCESS;
+}
+
spv_result_t PerformCfgChecks(ValidationState_t& _) {
for (auto& function : _.functions()) {
// Check all referenced blocks are defined within a function
@@ -689,6 +804,13 @@
<< _.getIdName(idom->id());
}
}
+
+ // For WebGPU check that all unreachable blocks are degenerate cases for
+ // merge-block or continue-target.
+ if (spvIsWebGPUEnv(_.context()->target_env)) {
+ spv_result_t result = PerformWebGPUCfgChecks(_, &function);
+ if (result != SPV_SUCCESS) return result;
+ }
}
// If we have structed control flow, check that no block has a control
// flow nesting depth larger than the limit.
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index 15913c8..a0ee79a 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -25,6 +25,7 @@
#include "gmock/gmock.h"
#include "source/diagnostic.h"
+#include "source/spirv_target_env.h"
#include "source/val/validate.h"
#include "test/test_fixture.h"
#include "test/unit_spirv.h"
@@ -103,6 +104,12 @@
}
out << ss.str();
} break;
+ case SpvOpLoopMerge: {
+ assert(successors_.size() == 2);
+ out << "OpLoopMerge %" + successors_[0].label_ + " %" +
+ successors_[0].label_ + "None";
+ } break;
+
case SpvOpReturn:
assert(successors_.size() == 0);
out << "OpReturn\n";
@@ -115,6 +122,10 @@
assert(successors_.size() == 1);
out << "OpBranch %" + successors_.front().label_;
break;
+ case SpvOpKill:
+ assert(successors_.size() == 0);
+ out << "OpKill\n";
+ break;
default:
assert(1 == 0 && "Unhandled");
}
@@ -144,13 +155,13 @@
return lhs;
}
-const char* header(SpvCapability cap) {
- static const char* shader_header =
+const std::string& GetDefaultHeader(SpvCapability cap) {
+ static const std::string shader_header =
"OpCapability Shader\n"
"OpCapability Linkage\n"
"OpMemoryModel Logical GLSL450\n";
- static const char* kernel_header =
+ static const std::string kernel_header =
"OpCapability Kernel\n"
"OpCapability Linkage\n"
"OpMemoryModel Logical OpenCL\n";
@@ -158,8 +169,17 @@
return (cap == SpvCapabilityShader) ? shader_header : kernel_header;
}
-const char* types_consts() {
- static const char* types =
+const std::string& GetWebGPUHeader() {
+ static const std::string header =
+ "OpCapability Shader\n"
+ "OpCapability VulkanMemoryModelKHR\n"
+ "OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
+ "OpMemoryModel Logical VulkanKHR\n";
+ return header;
+}
+
+const std::string& types_consts() {
+ static const std::string types =
"%voidt = OpTypeVoid\n"
"%boolt = OpTypeBool\n"
"%intt = OpTypeInt 32 0\n"
@@ -167,7 +187,6 @@
"%two = OpConstant %intt 2\n"
"%ptrt = OpTypePointer Function %intt\n"
"%funct = OpTypeFunction %voidt\n";
-
return types;
}
@@ -270,7 +289,7 @@
loop.SetBody("OpLoopMerge %merge %cont None\n");
}
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("loop", "entry", "cont", "merge",
std::make_pair("func", "Main")) +
types_consts() +
@@ -293,7 +312,7 @@
entry.SetBody("%var = OpVariable %ptrt Function\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps(std::make_pair("func", "Main")) + types_consts() +
" %func = OpFunction %voidt None %funct\n";
str += entry >> cont;
@@ -313,7 +332,7 @@
// This operation should only be performed in the entry block
cont.SetBody("%var = OpVariable %ptrt Function\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps(std::make_pair("func", "Main")) + types_consts() +
" %func = OpFunction %voidt None %funct\n";
@@ -339,7 +358,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("loop", "merge", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -364,7 +383,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) branch.SetBody("OpSelectionMerge %merge None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("cont", "branch", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -396,9 +415,10 @@
// cannot share the same merge
if (is_shader) selection.SetBody("OpSelectionMerge %merge None\n");
- std::string str =
- header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("merge", std::make_pair("func", "Main")) +
+ types_consts() +
+ "%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> selection;
@@ -431,9 +451,10 @@
// cannot share the same merge
if (is_shader) loop.SetBody(" OpLoopMerge %merge %loop None\n");
- std::string str =
- header(GetParam()) + nameOps("merge", std::make_pair("func", "Main")) +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("merge", std::make_pair("func", "Main")) +
+ types_consts() +
+ "%func = OpFunction %voidt None %funct\n";
str += entry >> selection;
str += selection >> std::vector<Block>({merge, loop});
@@ -457,7 +478,7 @@
Block entry("entry");
Block bad("bad");
Block end("end", SpvOpReturn);
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "bad", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -481,7 +502,7 @@
Block bad("bad");
Block end("end", SpvOpReturn);
Block badvalue("undef"); // This referenes the OpUndef.
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "bad", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -507,7 +528,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
bad.SetBody(" OpLoopMerge %entry %exit None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "bad", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -535,7 +556,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
bad.SetBody("OpLoopMerge %merge %cont None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "bad", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -567,7 +588,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
bad.SetBody("OpSelectionMerge %merge None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "bad", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -602,9 +623,10 @@
Block middle2("middle2");
Block end2("end2", SpvOpReturn);
- std::string str =
- header(GetParam()) + nameOps("middle2", std::make_pair("func", "Main")) +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("middle2", std::make_pair("func", "Main")) +
+ types_consts() +
+ "%func = OpFunction %voidt None %funct\n";
str += entry >> middle;
str += middle >> std::vector<Block>({end, middle2});
@@ -637,7 +659,7 @@
if (is_shader) head.AppendBody("OpSelectionMerge %merge None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("head", "merge", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -672,7 +694,7 @@
if (is_shader) head.AppendBody("OpSelectionMerge %head None\n");
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("head", "exit", std::make_pair("func", "Main")) +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -694,8 +716,10 @@
}
}
-TEST_P(ValidateCFG, UnreachableMerge) {
- bool is_shader = GetParam() == SpvCapabilityShader;
+std::string GetUnreachableMergeNoMergeInst(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
Block entry("entry");
Block branch("branch", SpvOpBranchConditional);
Block t("t", SpvOpReturn);
@@ -703,13 +727,18 @@
Block merge("merge", SpvOpReturn);
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
- if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n");
+ if (!spvIsWebGPUEnv(env) && cap == SpvCapabilityShader)
+ branch.AppendBody("OpSelectionMerge %merge None\n");
- std::string str = header(GetParam()) +
- nameOps("branch", "merge", std::make_pair("func", "Main")) +
- types_consts() +
- "%func = OpFunction %voidt None %funct\n";
-
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+ str += types_consts() + "%func = OpFunction %voidt None %funct\n";
str += entry >> branch;
str += branch >> std::vector<Block>({t, f});
str += t;
@@ -717,12 +746,209 @@
str += merge;
str += "OpFunctionEnd\n";
- CompileSuccessfully(str);
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeNoMergeInst) {
+ CompileSuccessfully(
+ GetUnreachableMergeNoMergeInst(GetParam(), SPV_ENV_UNIVERSAL_1_0));
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
-TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) {
- bool is_shader = GetParam() == SpvCapabilityShader;
+TEST_F(ValidateCFG, WebGPUUnreachableMergeNoMergeInst) {
+ CompileSuccessfully(
+ GetUnreachableMergeNoMergeInst(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, all blocks must be reachable"));
+}
+
+std::string GetUnreachableMergeTerminatedBy(SpvCapability cap,
+ spv_target_env env, SpvOp op) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranchConditional);
+ Block t("t", SpvOpReturn);
+ Block f("f", SpvOpReturn);
+ Block merge("merge", op);
+
+ entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpSelectionMerge %merge None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({t, f});
+ str += t;
+ str += f;
+ str += merge;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeTerminatedByOpUnreachable) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpUnreachable));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, UnreachableMergeTerminatedByOpKill) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_UNIVERSAL_1_0, SpvOpKill));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(ValidateCFG, UnreachableMergeTerminatedByOpReturn) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpReturn));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpUnreachable) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpUnreachable));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpKill) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpKill));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must terminate with OpUnreachable"));
+}
+
+TEST_P(ValidateCFG, WebGPUUnreachableMergeTerminatedByOpReturn) {
+ CompileSuccessfully(GetUnreachableMergeTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpReturn));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must terminate with OpUnreachable"));
+}
+
+std::string GetUnreachableContinueTerminatedBy(SpvCapability cap,
+ spv_target_env env, SpvOp op) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", op);
+
+ if (op == SpvOpBranch) target >> branch;
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({merge});
+ str += merge;
+ str += target;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpUnreachable) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpUnreachable));
+ if (GetParam() == SpvCapabilityShader) {
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("targeted by 0 back-edge blocks"));
+ } else {
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+ }
+}
+
+TEST_F(ValidateCFG, UnreachableContinueTerminatedBySpvOpKill) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_UNIVERSAL_1_0, SpvOpKill));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("targeted by 0 back-edge blocks"));
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpReturn) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpReturn));
+ if (GetParam() == SpvCapabilityShader) {
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("targeted by 0 back-edge blocks"));
+ } else {
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+ }
+}
+
+TEST_P(ValidateCFG, UnreachableContinueTerminatedBySpvOpBranch) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0, SpvOpBranch));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpUnreachable) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpUnreachable));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, unreachable continue-target must "
+ "terminate with OpBranch.\n %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpKill) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpKill));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, unreachable continue-target must "
+ "terminate with OpBranch.\n %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpReturn) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpReturn));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, unreachable continue-target must "
+ "terminate with OpBranch.\n %12 = OpLabel\n"));
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueTerminatedBySpvOpBranch) {
+ CompileSuccessfully(GetUnreachableContinueTerminatedBy(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0, SpvOpBranch));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+std::string GetUnreachableMergeUnreachableMergeInst(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block body("body", SpvOpReturn);
Block entry("entry");
Block branch("branch", SpvOpBranchConditional);
Block t("t", SpvOpReturn);
@@ -730,13 +956,134 @@
Block merge("merge", SpvOpUnreachable);
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
- if (is_shader) branch.AppendBody("OpSelectionMerge %merge None\n");
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpSelectionMerge %merge None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
- std::string str = header(GetParam()) +
- nameOps("branch", "merge", std::make_pair("func", "Main")) +
- types_consts() +
- "%func = OpFunction %voidt None %funct\n";
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += body;
+ str += merge;
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({t, f});
+ str += t;
+ str += f;
+ str += "OpFunctionEnd\n";
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeUnreachableMergeInst) {
+ CompileSuccessfully(GetUnreachableMergeUnreachableMergeInst(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeUnreachableMergeInst) {
+ CompileSuccessfully(GetUnreachableMergeUnreachableMergeInst(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must be referenced by a reachable merge instruction"));
+}
+
+std::string GetUnreachableContinueUnreachableLoopInst(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block body("body", SpvOpReturn);
+ Block entry("entry");
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", SpvOpBranch);
+
+ target >> branch;
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += body;
+ str += target;
+ str += merge;
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({merge});
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueUnreachableLoopInst) {
+ CompileSuccessfully(GetUnreachableContinueUnreachableLoopInst(
+ GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ if (GetParam() == SpvCapabilityShader) {
+ // Shader causes additional structured CFG checks that cause a failure.
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Back-edges (1[%branch] -> 3[%target]) can only be "
+ "formed between a block and a loop header."));
+
+ } else {
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+ }
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueUnreachableLoopInst) {
+ CompileSuccessfully(GetUnreachableContinueUnreachableLoopInst(
+ SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("must be referenced by a reachable loop instruction"));
+}
+
+std::string GetUnreachableMergeWithComplexBody(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranchConditional);
+ Block t("t", SpvOpReturn);
+ Block f("f", SpvOpReturn);
+ Block merge("merge", SpvOpUnreachable);
+
+ entry.AppendBody(spvIsWebGPUEnv(env)
+ ? "%dummy = OpVariable %intptrt Function %two\n"
+ : "%dummy = OpVariable %intptrt Function\n");
+ entry.AppendBody("%cond = OpSLessThan %boolt %one %two\n");
+ merge.AppendBody("OpStore %dummy %one\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpSelectionMerge %merge None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%intptrt = OpTypePointer Function %intt\n";
+ str += "%func = OpFunction %voidt None %funct\n";
str += entry >> branch;
str += branch >> std::vector<Block>({t, f});
str += t;
@@ -744,31 +1091,406 @@
str += merge;
str += "OpFunctionEnd\n";
- CompileSuccessfully(str);
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithComplexBody) {
+ CompileSuccessfully(
+ GetUnreachableMergeWithComplexBody(GetParam(), SPV_ENV_UNIVERSAL_1_0));
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
-TEST_P(ValidateCFG, UnreachableBlock) {
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithComplexBody) {
+ CompileSuccessfully(GetUnreachableMergeWithComplexBody(SpvCapabilityShader,
+ SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("must only contain an OpLabel and OpUnreachable instruction"));
+}
+
+std::string GetUnreachableContinueWithComplexBody(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", SpvOpBranch);
+
+ target >> branch;
+
+ entry.AppendBody(spvIsWebGPUEnv(env)
+ ? "%dummy = OpVariable %intptrt Function %two\n"
+ : "%dummy = OpVariable %intptrt Function\n");
+ target.AppendBody("OpStore %dummy %one\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%intptrt = OpTypePointer Function %intt\n";
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({merge});
+ str += merge;
+ str += target;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueWithComplexBody) {
+ CompileSuccessfully(
+ GetUnreachableContinueWithComplexBody(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueWithComplexBody) {
+ CompileSuccessfully(GetUnreachableContinueWithComplexBody(SpvCapabilityShader,
+ SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("must only contain an OpLabel and an OpBranch instruction"));
+}
+
+std::string GetUnreachableMergeWithBranchUse(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranchConditional);
+ Block t("t", SpvOpBranch);
+ Block f("f", SpvOpReturn);
+ Block merge("merge", SpvOpUnreachable);
+
+ entry.AppendBody("%cond = OpSLessThan %boolt %one %two\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpSelectionMerge %merge None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({t, f});
+ str += t >> merge;
+ str += f;
+ str += merge;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithBranchUse) {
+ CompileSuccessfully(
+ GetUnreachableMergeWithBranchUse(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithBranchUse) {
+ CompileSuccessfully(
+ GetUnreachableMergeWithBranchUse(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("cannot be the target of a branch."));
+}
+
+std::string GetUnreachableMergeWithMultipleUses(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranchConditional);
+ Block t("t", SpvOpReturn);
+ Block f("f", SpvOpReturn);
+ Block merge("merge", SpvOpUnreachable);
+ Block duplicate("duplicate", SpvOpBranchConditional);
+
+ entry.AppendBody("%cond = OpSLessThan %boolt %one %two\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader) {
+ branch.AppendBody("OpSelectionMerge %merge None\n");
+ duplicate.AppendBody("OpSelectionMerge %merge None\n");
+ }
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({t, f});
+ str += duplicate >> std::vector<Block>({t, f});
+ str += t;
+ str += f;
+ str += merge;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeWithMultipleUses) {
+ CompileSuccessfully(
+ GetUnreachableMergeWithMultipleUses(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ if (GetParam() == SpvCapabilityShader) {
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("is already a merge block for another header"));
+ } else {
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+ }
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeWithMultipleUses) {
+ CompileSuccessfully(GetUnreachableMergeWithMultipleUses(SpvCapabilityShader,
+ SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("is already a merge block for another header"));
+}
+
+std::string GetUnreachableContinueWithBranchUse(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block foo("foo", SpvOpBranch);
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", SpvOpBranch);
+
+ foo >> target;
+ target >> branch;
+
+ entry.AppendBody(spvIsWebGPUEnv(env)
+ ? "%dummy = OpVariable %intptrt Function %two\n"
+ : "%dummy = OpVariable %intptrt Function\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader)
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%intptrt = OpTypePointer Function %intt\n";
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({merge});
+ str += merge;
+ str += target;
+ str += foo;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableContinueWithBranchUse) {
+ CompileSuccessfully(
+ GetUnreachableContinueWithBranchUse(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableContinueWithBranchUse) {
+ CompileSuccessfully(GetUnreachableContinueWithBranchUse(SpvCapabilityShader,
+ SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("cannot be the target of a branch."));
+}
+
+std::string GetReachableMergeAndContinue(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", SpvOpBranch);
+ Block body("body", SpvOpBranchConditional);
+ Block t("t", SpvOpBranch);
+ Block f("f", SpvOpBranch);
+
+ target >> branch;
+ body.SetBody("%cond = OpSLessThan %boolt %one %two\n");
+ t >> merge;
+ f >> target;
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader) {
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ body.AppendBody("OpSelectionMerge %target None\n");
+ }
+
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", "body", "t", "f",
+ std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({body});
+ str += body >> std::vector<Block>({t, f});
+ str += t;
+ str += f;
+ str += merge;
+ str += target;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, ReachableMergeAndContinue) {
+ CompileSuccessfully(
+ GetReachableMergeAndContinue(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUReachableMergeAndContinue) {
+ CompileSuccessfully(
+ GetReachableMergeAndContinue(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
+}
+
+std::string GetUnreachableMergeAndContinue(SpvCapability cap,
+ spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
+ Block entry("entry");
+ Block branch("branch", SpvOpBranch);
+ Block merge("merge", SpvOpReturn);
+ Block target("target", SpvOpBranch);
+ Block body("body", SpvOpBranchConditional);
+ Block t("t", SpvOpReturn);
+ Block f("f", SpvOpReturn);
+
+ target >> branch;
+ body.SetBody("%cond = OpSLessThan %boolt %one %two\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (cap == SpvCapabilityShader) {
+ branch.AppendBody("OpLoopMerge %merge %target None\n");
+ body.AppendBody("OpSelectionMerge %target None\n");
+ }
+
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("branch", "merge", "target", "body", "t", "f",
+ std::make_pair("func", "Main"));
+
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
+ str += entry >> branch;
+ str += branch >> std::vector<Block>({body});
+ str += body >> std::vector<Block>({t, f});
+ str += t;
+ str += f;
+ str += merge;
+ str += target;
+ str += "OpFunctionEnd\n";
+
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableMergeAndContinue) {
+ CompileSuccessfully(
+ GetUnreachableMergeAndContinue(GetParam(), SPV_ENV_UNIVERSAL_1_0));
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateCFG, WebGPUUnreachableMergeAndContinue) {
+ CompileSuccessfully(
+ GetUnreachableMergeAndContinue(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("unreachable merge-blocks must terminate with OpUnreachable"));
+}
+
+std::string GetUnreachableBlock(SpvCapability cap, spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
Block entry("entry");
Block unreachable("unreachable");
Block exit("exit", SpvOpReturn);
- std::string str =
- header(GetParam()) +
- nameOps("unreachable", "exit", std::make_pair("func", "Main")) +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
-
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("unreachable", "exit", std::make_pair("func", "Main"));
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
str += entry >> exit;
str += unreachable >> exit;
str += exit;
str += "OpFunctionEnd\n";
- CompileSuccessfully(str);
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableBlock) {
+ CompileSuccessfully(GetUnreachableBlock(GetParam(), SPV_ENV_UNIVERSAL_1_0));
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
-TEST_P(ValidateCFG, UnreachableBranch) {
- bool is_shader = GetParam() == SpvCapabilityShader;
+TEST_F(ValidateCFG, WebGPUUnreachableBlock) {
+ CompileSuccessfully(
+ GetUnreachableBlock(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("all blocks must be reachable"));
+}
+
+std::string GetUnreachableBranch(SpvCapability cap, spv_target_env env) {
+ std::string header =
+ spvIsWebGPUEnv(env) ? GetWebGPUHeader() : GetDefaultHeader(cap);
+
Block entry("entry");
Block unreachable("unreachable", SpvOpBranchConditional);
Block unreachablechildt("unreachablechildt");
@@ -777,11 +1499,19 @@
Block exit("exit", SpvOpReturn);
unreachable.SetBody("%cond = OpSLessThan %boolt %one %two\n");
- if (is_shader) unreachable.AppendBody("OpSelectionMerge %merge None\n");
- std::string str =
- header(GetParam()) +
- nameOps("unreachable", "exit", std::make_pair("func", "Main")) +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ if (cap == SpvCapabilityShader)
+ unreachable.AppendBody("OpSelectionMerge %merge None\n");
+
+ std::string str = header;
+ if (spvIsWebGPUEnv(env)) {
+ str +=
+ "OpEntryPoint Fragment %func \"func\"\n"
+ "OpExecutionMode %func OriginUpperLeft\n";
+ }
+ if (!spvIsWebGPUEnv(env))
+ str += nameOps("unreachable", "exit", std::make_pair("func", "Main"));
+ str += types_consts();
+ str += "%func = OpFunction %voidt None %funct\n";
str += entry >> exit;
str +=
@@ -792,12 +1522,23 @@
str += exit;
str += "OpFunctionEnd\n";
- CompileSuccessfully(str);
+ return str;
+}
+
+TEST_P(ValidateCFG, UnreachableBranch) {
+ CompileSuccessfully(GetUnreachableBranch(GetParam(), SPV_ENV_UNIVERSAL_1_0));
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateCFG, WebGPUUnreachableBranch) {
+ CompileSuccessfully(
+ GetUnreachableBranch(SpvCapabilityShader, SPV_ENV_WEBGPU_0));
+ ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(), HasSubstr("all blocks must be reachable"));
+}
+
TEST_P(ValidateCFG, EmptyFunction) {
- std::string str = header(GetParam()) + std::string(types_consts()) +
+ std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
R"(%func = OpFunction %voidt None %funct
%l = OpLabel
OpReturn
@@ -816,7 +1557,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.AppendBody("OpLoopMerge %exit %loop None\n");
- std::string str = header(GetParam()) + std::string(types_consts()) +
+ std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
@@ -845,8 +1586,8 @@
loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n");
}
- std::string str = header(GetParam()) + nameOps("loop2", "loop2_merge") +
- types_consts() +
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("loop2", "loop2_merge") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop1;
@@ -885,7 +1626,7 @@
if_blocks[i].SetBody("OpSelectionMerge %if_merge" + ss.str() + " None\n");
merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch);
}
- std::string str = header(GetParam()) + std::string(types_consts()) +
+ std::string str = GetDefaultHeader(GetParam()) + std::string(types_consts()) +
"%func = OpFunction %voidt None %funct\n";
str += entry >> if_blocks[0];
@@ -920,7 +1661,7 @@
loop2.SetBody("OpLoopMerge %loop2_merge %loop2 None\n");
}
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("loop1", "loop2", "be_block", "loop2_merge") +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -957,7 +1698,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) split.SetBody("OpSelectionMerge %exit None\n");
- std::string str = header(GetParam()) + nameOps("split", "f") +
+ std::string str = GetDefaultHeader(GetParam()) + nameOps("split", "f") +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -990,7 +1731,8 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) split.SetBody("OpSelectionMerge %exit None\n");
- std::string str = header(GetParam()) + nameOps("split") + types_consts() +
+ std::string str = GetDefaultHeader(GetParam()) + nameOps("split") +
+ types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> split;
@@ -1022,8 +1764,8 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.SetBody("OpLoopMerge %merge %back0 None\n");
- std::string str = header(GetParam()) + nameOps("loop", "back0", "back1") +
- types_consts() +
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("loop", "back0", "back1") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
@@ -1059,8 +1801,8 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.SetBody("OpLoopMerge %merge %cheader None\n");
- std::string str = header(GetParam()) + nameOps("cheader", "be_block") +
- types_consts() +
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("cheader", "be_block") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
@@ -1094,7 +1836,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
- std::string str = header(GetParam()) + nameOps("cont", "loop") +
+ std::string str = GetDefaultHeader(GetParam()) + nameOps("cont", "loop") +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -1129,7 +1871,7 @@
entry.SetBody("%cond = OpSLessThan %boolt %one %two\n");
if (is_shader) loop.SetBody("OpLoopMerge %merge %loop None\n");
- std::string str = header(GetParam()) + nameOps("cont", "loop") +
+ std::string str = GetDefaultHeader(GetParam()) + nameOps("cont", "loop") +
types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -1280,7 +2022,7 @@
inner_head.SetBody("OpSelectionMerge %inner_merge None\n");
}
- std::string str = header(GetParam()) +
+ std::string str = GetDefaultHeader(GetParam()) +
nameOps("entry", "inner_merge", "exit") + types_consts() +
"%func = OpFunction %voidt None %funct\n";
@@ -1318,7 +2060,7 @@
}
std::string str =
- header(GetParam()) +
+ GetDefaultHeader(GetParam()) +
nameOps("entry", "loop", "if_head", "if_true", "if_merge", "merge") +
types_consts() + "%func = OpFunction %voidt None %funct\n";
@@ -1348,9 +2090,10 @@
loop.SetBody("OpLoopMerge %merge %latch None\n");
}
- std::string str =
- header(GetParam()) + nameOps("entry", "loop", "latch", "merge") +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("entry", "loop", "latch", "merge") +
+ types_consts() +
+ "%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> std::vector<Block>({latch, merge});
@@ -1381,9 +2124,10 @@
loop.SetBody("OpLoopMerge %merge %loop None\n");
}
- std::string str =
- header(GetParam()) + nameOps("entry", "loop", "latch", "merge") +
- types_consts() + "%func = OpFunction %voidt None %funct\n";
+ std::string str = GetDefaultHeader(GetParam()) +
+ nameOps("entry", "loop", "latch", "merge") +
+ types_consts() +
+ "%func = OpFunction %voidt None %funct\n";
str += entry >> loop;
str += loop >> latch;