Ensure that only whitelisted extensions are used in WebGPU (#2127)
Fixes #2058
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index fe38f1f..f264c8e 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -21,6 +21,8 @@
#include <vector>
#include "source/diagnostic.h"
+#include "source/enum_string_mapping.h"
+#include "source/extensions.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/opcode.h"
@@ -42,6 +44,21 @@
} // anonymous namespace
+spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
+ if (spvIsWebGPUEnv(_.context()->target_env)) {
+ std::string extension = GetExtensionString(&(inst->c_inst()));
+
+ if (extension != ExtensionToString(kSPV_KHR_vulkan_memory_model)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU, the only valid parameter to OpExtension is "
+ << "\"" << ExtensionToString(kSPV_KHR_vulkan_memory_model)
+ << "\".";
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateExtInstImport(ValidationState_t& _,
const Instruction* inst) {
if (spvIsWebGPUEnv(_.context()->target_env)) {
@@ -2001,6 +2018,7 @@
spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst) {
const SpvOp opcode = inst->opcode();
+ if (opcode == SpvOpExtension) return ValidateExtension(_, inst);
if (opcode == SpvOpExtInstImport) return ValidateExtInstImport(_, inst);
if (opcode == SpvOpExtInst) return ValidateExtInst(_, inst);
diff --git a/test/val/val_webgpu_test.cpp b/test/val/val_webgpu_test.cpp
index ba59198..48ea21d 100644
--- a/test/val/val_webgpu_test.cpp
+++ b/test/val/val_webgpu_test.cpp
@@ -258,6 +258,24 @@
"OpExtInstImport \"OpenCL.std\"\n"));
}
+TEST_F(ValidateWebGPU, NonVulkanKHRMemoryModelExtensionBad) {
+ std::string spirv = R"(
+ OpCapability Shader
+ OpCapability VulkanMemoryModelKHR
+ OpExtension "SPV_KHR_8bit_storage"
+ OpExtension "SPV_KHR_vulkan_memory_model"
+ OpMemoryModel Logical VulkanKHR
+)";
+
+ CompileSuccessfully(spirv);
+
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("For WebGPU, the only valid parameter to OpExtension "
+ "is \"SPV_KHR_vulkan_memory_model\".\n OpExtension "
+ "\"SPV_KHR_8bit_storage\"\n"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools