Add precise check for allowing use of gl_InstanceID for specific vulkan raytracing stages . (#2096)
* Checks that gl_InstanceID is only used in specific execution models
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 98eb98c..aaba324 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -207,6 +207,11 @@
const Instruction& referenced_inst,
const Instruction& referenced_from_inst);
+ spv_result_t ValidateInstanceIdAtReference(
+ const Decoration& decoration, const Instruction& built_in_inst,
+ const Instruction& referenced_inst,
+ const Instruction& referenced_from_inst);
+
spv_result_t ValidateInstanceIndexAtReference(
const Decoration& decoration, const Instruction& built_in_inst,
const Instruction& referenced_inst,
@@ -2098,6 +2103,43 @@
"to be used.";
}
+ if (label == SpvBuiltInInstanceId) {
+ return ValidateInstanceIdAtReference(decoration, inst, inst, inst);
+ }
+ return SPV_SUCCESS;
+}
+
+spv_result_t BuiltInsValidator::ValidateInstanceIdAtReference(
+ const Decoration& decoration, const Instruction& built_in_inst,
+ const Instruction& referenced_inst,
+ const Instruction& referenced_from_inst) {
+ if (spvIsVulkanEnv(_.context()->target_env)) {
+ for (const SpvExecutionModel execution_model : execution_models_) {
+ switch (execution_model) {
+ case SpvExecutionModelIntersectionNV:
+ case SpvExecutionModelClosestHitNV:
+ case SpvExecutionModelAnyHitNV:
+ // Do nothing, valid stages
+ break;
+ default:
+ return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+ << "Vulkan spec allows BuiltIn InstanceId to be used "
+ "only with IntersectionNV, ClosestHitNV and AnyHitNV "
+ "execution models. "
+ << GetReferenceDesc(decoration, built_in_inst, referenced_inst,
+ referenced_from_inst);
+ break;
+ }
+ }
+ }
+
+ if (function_id_ == 0) {
+ // Propagate this rule to all dependant ids in the global scope.
+ id_to_at_reference_checks_[referenced_from_inst.id()].push_back(std::bind(
+ &BuiltInsValidator::ValidateInstanceIdAtReference, this, decoration,
+ built_in_inst, referenced_from_inst, std::placeholders::_1));
+ }
+
return SPV_SUCCESS;
}
diff --git a/test/val/val_builtins_test.cpp b/test/val/val_builtins_test.cpp
index 9e3798b..b1458c9 100644
--- a/test/val/val_builtins_test.cpp
+++ b/test/val/val_builtins_test.cpp
@@ -2209,6 +2209,44 @@
EXPECT_THAT(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
}
+TEST_F(ValidateBuiltIns, DisallowInstanceIdWithRayGenShader) {
+ CodeGenerator generator = GetDefaultShaderCodeGenerator();
+ generator.capabilities_ += R"(
+OpCapability RayTracingNV
+)";
+
+ generator.extensions_ = R"(
+OpExtension "SPV_NV_ray_tracing"
+)";
+
+ generator.before_types_ = R"(
+OpMemberDecorate %input_type 0 BuiltIn InstanceId
+)";
+
+ generator.after_types_ = R"(
+%input_type = OpTypeStruct %u32
+%input_ptr = OpTypePointer Input %input_type
+%input_ptr_u32 = OpTypePointer Input %u32
+%input = OpVariable %input_ptr Input
+)";
+
+ EntryPoint entry_point;
+ entry_point.name = "main_d_r";
+ entry_point.execution_model = "RayGenerationNV";
+ entry_point.interfaces = "%input";
+ entry_point.body = R"(
+%input_member = OpAccessChain %input_ptr_u32 %input %u32_0
+)";
+ generator.entry_points_.push_back(std::move(entry_point));
+
+ CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0);
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Vulkan spec allows BuiltIn InstanceId to be used "
+ "only with IntersectionNV, ClosestHitNV and "
+ "AnyHitNV execution models"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools