Validator: TaskNV can use LocalSize or LocalSizeId (#1970)
Correponds to the update to Rev2 of SPV_NV_mesh_shader
Fixes #1968
diff --git a/source/val/validate_mode_setting.cpp b/source/val/validate_mode_setting.cpp
index ec13b70..c1bfc27 100644
--- a/source/val/validate_mode_setting.cpp
+++ b/source/val/validate_mode_setting.cpp
@@ -376,6 +376,7 @@
case SpvExecutionModelKernel:
case SpvExecutionModelGLCompute:
return true;
+ case SpvExecutionModelTaskNV:
case SpvExecutionModelMeshNV:
return _.HasCapability(SpvCapabilityMeshShadingNV);
default:
@@ -384,8 +385,8 @@
})) {
if (_.HasCapability(SpvCapabilityMeshShadingNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Execution mode can only be used with a Kernel, GLCompute "
- "or MeshNV execution model.";
+ << "Execution mode can only be used with a Kernel, GLCompute, "
+ "MeshNV, or TaskNV execution model.";
} else {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Execution mode can only be used with a Kernel or "
diff --git a/test/val/val_modes_test.cpp b/test/val/val_modes_test.cpp
index 5d072f0..7f1ef09 100644
--- a/test/val/val_modes_test.cpp
+++ b/test/val/val_modes_test.cpp
@@ -720,6 +720,20 @@
EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateModeExecution, TaskNVLocalSize) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TaskNV %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+)" + kVoidFunction;
+
+ CompileSuccessfully(spirv);
+ EXPECT_THAT(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateModeExecution, MeshNVOutputPoints) {
const std::string spirv = R"(
OpCapability Shader
@@ -765,6 +779,23 @@
EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env));
}
+TEST_F(ValidateModeExecution, TaskNVLocalSizeId) {
+ const std::string spirv = R"(
+OpCapability Shader
+OpCapability MeshShadingNV
+OpExtension "SPV_NV_mesh_shader"
+OpMemoryModel Logical GLSL450
+OpEntryPoint TaskNV %main "main"
+OpExecutionModeId %main LocalSizeId %int_1 %int_1 %int_1
+%int = OpTypeInt 32 0
+%int_1 = OpConstant %int 1
+)" + kVoidFunction;
+
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_3;
+ CompileSuccessfully(spirv, env);
+ EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(env));
+}
+
} // namespace
} // namespace val
} // namespace spvtools