spirv-val: SBT Index for OpExecuteCallableKHR (#4900)
diff --git a/source/val/validate_ray_tracing.cpp b/source/val/validate_ray_tracing.cpp
index 81fa593..78bac19 100644
--- a/source/val/validate_ray_tracing.cpp
+++ b/source/val/validate_ray_tracing.cpp
@@ -174,6 +174,13 @@
return true;
});
+ const uint32_t sbt_index = _.GetOperandTypeId(inst, 0);
+ if (!_.IsUnsignedIntScalarType(sbt_index) ||
+ _.GetBitWidth(sbt_index) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "SBT Index must be a 32-bit unsigned int scalar";
+ }
+
const auto callable_data = _.FindDef(inst->GetOperandAs<uint32_t>(1));
if (callable_data->opcode() != SpvOpVariable) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
diff --git a/test/val/val_ray_tracing.cpp b/test/val/val_ray_tracing.cpp
index 9486777..58b9356 100644
--- a/test/val/val_ray_tracing.cpp
+++ b/test/val/val_ray_tracing.cpp
@@ -334,6 +334,34 @@
"or IncomingCallableDataKHR"));
}
+TEST_F(ValidateRayTracing, ExecuteCallableSbtIndex) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint CallableKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%int_1 = OpConstant %int 1
+%data_ptr = OpTypePointer CallableDataKHR %int
+%data = OpVariable %data_ptr CallableDataKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %int_1 %data
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("SBT Index must be a 32-bit unsigned int scalar"));
+}
+
std::string GenerateRayTraceCode(
const std::string& body,
const std::string execution_model = "RayGenerationKHR") {