Add helper for 'is Vulkan or WebGPU' (#2324)
Fixes #2323
diff --git a/source/spirv_target_env.cpp b/source/spirv_target_env.cpp
index 1dc2508..320b306 100644
--- a/source/spirv_target_env.cpp
+++ b/source/spirv_target_env.cpp
@@ -249,6 +249,10 @@
return false;
}
+bool spvIsVulkanOrWebGPUEnv(spv_target_env env) {
+ return spvIsVulkanEnv(env) || spvIsWebGPUEnv(env);
+}
+
std::string spvLogStringForEnv(spv_target_env env) {
switch (env) {
case SPV_ENV_OPENCL_1_2:
diff --git a/source/spirv_target_env.h b/source/spirv_target_env.h
index 9061a3a..d463570 100644
--- a/source/spirv_target_env.h
+++ b/source/spirv_target_env.h
@@ -32,6 +32,9 @@
// Returns true if |env| is an WEBGPU environment, false otherwise.
bool spvIsWebGPUEnv(spv_target_env env);
+// Returns true if |env| is a VULKAN or WEBGPU environment, false otherwise.
+bool spvIsVulkanOrWebGPUEnv(spv_target_env env);
+
// Returns the version number for the given SPIR-V target environment.
uint32_t spvVersionForTargetEnv(spv_target_env env);
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 9797d31..4024f61 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -249,8 +249,7 @@
// For Vulkan and WebGPU, the static function-call graph for an entry point
// must not contain cycles.
- if (spvIsWebGPUEnv(_.context()->target_env) ||
- spvIsVulkanEnv(_.context()->target_env)) {
+ if (spvIsVulkanOrWebGPUEnv(_.context()->target_env)) {
if (_.recursive_entry_points().find(entry_point) !=
_.recursive_entry_points().end()) {
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 3b104be..9e93cf1 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -515,23 +515,13 @@
if (inst->operands().size() > 3 && storage_class != SpvStorageClassOutput &&
storage_class != SpvStorageClassPrivate &&
storage_class != SpvStorageClassFunction) {
- if (spvIsVulkanEnv(_.context()->target_env)) {
+ if (spvIsVulkanOrWebGPUEnv(_.context()->target_env)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpVariable, <id> '" << _.getIdName(inst->id())
<< "', has a disallowed initializer & storage class "
<< "combination.\n"
- << "From Vulkan spec, Appendix A:\n"
- << "Variable declarations that include initializers must have "
- << "one of the following storage classes: Output, Private, or "
- << "Function";
- }
-
- if (spvIsWebGPUEnv(_.context()->target_env)) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpVariable, <id> '" << _.getIdName(inst->id())
- << "', has a disallowed initializer & storage class "
- << "combination.\n"
- << "From WebGPU execution environment spec:\n"
+ << "From " << spvLogStringForEnv(_.context()->target_env)
+ << " spec:\n"
<< "Variable declarations that include initializers must have "
<< "one of the following storage classes: Output, Private, or "
<< "Function";
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 365d4cc..7669f72 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -107,8 +107,7 @@
<< "' is a void type.";
}
- if ((spvIsVulkanEnv(_.context()->target_env) ||
- spvIsWebGPUEnv(_.context()->target_env)) &&
+ if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
element_type->opcode() == SpvOpTypeRuntimeArray) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeArray Element Type <id> '" << _.getIdName(element_type_id)
@@ -171,8 +170,7 @@
<< _.getIdName(element_id) << "' is a void type.";
}
- if ((spvIsVulkanEnv(_.context()->target_env) ||
- spvIsWebGPUEnv(_.context()->target_env)) &&
+ if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
element_type->opcode() == SpvOpTypeRuntimeArray) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeRuntimeArray Element Type <id> '"
@@ -226,8 +224,7 @@
}
}
- if ((spvIsVulkanEnv(_.context()->target_env) ||
- spvIsWebGPUEnv(_.context()->target_env)) &&
+ if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
member_type->opcode() == SpvOpTypeRuntimeArray) {
const bool is_last_member =
member_type_index == inst->operands().size() - 1;
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index 167b11f..9ebbac0 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -449,12 +449,11 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_WEBGPU_0));
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr(
- "OpVariable, <id> '5[%5]', has a disallowed initializer & storage "
- "class combination.\nFrom WebGPU execution environment spec:\n"
- "Variable declarations that include initializers must have one of "
- "the following storage classes: Output, Private, or Function\n"
- " %5 = OpVariable %_ptr_Uniform_float Uniform %float_1\n"));
+ HasSubstr("OpVariable, <id> '5[%5]', has a disallowed initializer & "
+ "storage class combination.\nFrom WebGPU spec:\nVariable "
+ "declarations that include initializers must have one of the "
+ "following storage classes: Output, Private, or Function\n %5 "
+ "= OpVariable %_ptr_Uniform_float Uniform %float_1\n"));
}
TEST_F(ValidateMemory, WebGPUOutputStorageClassWithoutInitializerBad) {
@@ -628,12 +627,11 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr(
- "OpVariable, <id> '5[%5]', has a disallowed initializer & storage "
- "class combination.\nFrom Vulkan spec, Appendix A:\n"
- "Variable declarations that include initializers must have one of "
- "the following storage classes: Output, Private, or Function\n "
- "%5 = OpVariable %_ptr_Input_float Input %float_1\n"));
+ HasSubstr("OpVariable, <id> '5[%5]', has a disallowed initializer & "
+ "storage class combination.\nFrom Vulkan spec:\nVariable "
+ "declarations that include initializers must have one of the "
+ "following storage classes: Output, Private, or Function\n %5 "
+ "= OpVariable %_ptr_Input_float Input %float_1\n"));
}
TEST_F(ValidateMemory, ArrayLenCorrectResultType) {