spirv-val: Add SPV_KHR_ray_query (#4848)
diff --git a/Android.mk b/Android.mk
index 6dd1834..cd1d7f8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -68,6 +68,7 @@
source/val/validate_logicals.cpp \
source/val/validate_non_uniform.cpp \
source/val/validate_primitives.cpp \
+ source/val/validate_ray_query.cpp \
source/val/validate_scopes.cpp \
source/val/validate_small_type_uses.cpp \
source/val/validate_type.cpp
diff --git a/BUILD.gn b/BUILD.gn
index 71a584f..9e9f6e5 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -530,6 +530,7 @@
"source/val/validate_mode_setting.cpp",
"source/val/validate_non_uniform.cpp",
"source/val/validate_primitives.cpp",
+ "source/val/validate_ray_query.cpp",
"source/val/validate_scopes.cpp",
"source/val/validate_scopes.h",
"source/val/validate_small_type_uses.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index c0974e1..1ceb78f 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -322,6 +322,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_query.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index dc6401a..55e9fd2 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -350,6 +350,7 @@
if (auto error = NonUniformPass(*vstate, &instruction)) return error;
if (auto error = LiteralsPass(*vstate, &instruction)) return error;
+ if (auto error = RayQueryPass(*vstate, &instruction)) return error;
}
// Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
diff --git a/source/val/validate.h b/source/val/validate.h
index cb1d05a..97d4683 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -197,6 +197,9 @@
/// Validates correctness of miscellaneous instructions.
spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
+/// Validates correctness of ray query instructions.
+spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
+
/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);
diff --git a/source/val/validate_ray_query.cpp b/source/val/validate_ray_query.cpp
new file mode 100644
index 0000000..f92bf01
--- /dev/null
+++ b/source/val/validate_ray_query.cpp
@@ -0,0 +1,271 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates ray query instructions from SPV_KHR_ray_query
+
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
+ const Instruction* inst,
+ uint32_t ray_query_index) {
+ const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
+ auto variable = _.FindDef(ray_query_id);
+ if (!variable || (variable->opcode() != SpvOpVariable &&
+ variable->opcode() != SpvOpFunctionParameter)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Query must be a memory object declaration";
+ }
+ auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
+ if (!pointer || pointer->opcode() != SpvOpTypePointer) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Query must be a pointer";
+ }
+ auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
+ if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Query must be a pointer to OpTypeRayQueryKHR";
+ }
+ return SPV_SUCCESS;
+}
+
+spv_result_t ValidateIntersectionId(ValidationState_t& _,
+ const Instruction* inst,
+ uint32_t intersection_index) {
+ const uint32_t intersection_id =
+ inst->GetOperandAs<uint32_t>(intersection_index);
+ const uint32_t intersection_type = _.GetTypeId(intersection_id);
+ const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
+ if (!_.IsIntScalarType(intersection_type) ||
+ _.GetBitWidth(intersection_type) != 32 ||
+ !spvOpcodeIsConstant(intersection_opcode)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Intersection ID to be a constant 32-bit int scalar";
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace
+
+spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
+ const SpvOp opcode = inst->opcode();
+ const uint32_t result_type = inst->type_id();
+
+ switch (opcode) {
+ case SpvOpRayQueryInitializeKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+
+ if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
+ SpvOpTypeAccelerationStructureKHR) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Acceleration Structure to be of type "
+ "OpTypeAccelerationStructureKHR";
+ }
+
+ const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
+ if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Flags must be a 32-bit int scalar";
+ }
+
+ const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
+ if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cull Mask must be a 32-bit int scalar";
+ }
+
+ const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
+ if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
+ _.GetBitWidth(ray_origin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Origin must be a 32-bit float 3-component vector";
+ }
+
+ const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
+ if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMin must be a 32-bit float scalar";
+ }
+
+ const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
+ if (!_.IsFloatVectorType(ray_direction) ||
+ _.GetDimension(ray_direction) != 3 ||
+ _.GetBitWidth(ray_direction) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Direction must be a 32-bit float 3-component vector";
+ }
+
+ const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
+ if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMax must be a 32-bit float scalar";
+ }
+ break;
+ }
+
+ case SpvOpRayQueryTerminateKHR:
+ case SpvOpRayQueryConfirmIntersectionKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+ break;
+ }
+
+ case SpvOpRayQueryGenerateIntersectionKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+
+ const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hit T must be a 32-bit float scalar";
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionFrontFaceKHR:
+ case SpvOpRayQueryProceedKHR:
+ case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+ if (!_.IsBoolScalarType(result_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be bool scalar type";
+ }
+
+ if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionTKHR:
+ case SpvOpRayQueryGetRayTMinKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+ if (!_.IsFloatScalarType(result_type) ||
+ _.GetBitWidth(result_type) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be 32-bit float scalar type";
+ }
+
+ if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionTypeKHR:
+ case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
+ case SpvOpRayQueryGetIntersectionInstanceIdKHR:
+ case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
+ case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
+ case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
+ case SpvOpRayQueryGetRayFlagsKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+ if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be 32-bit int scalar type";
+ }
+
+ if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
+ case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
+ case SpvOpRayQueryGetWorldRayDirectionKHR:
+ case SpvOpRayQueryGetWorldRayOriginKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+ if (!_.IsFloatVectorType(result_type) ||
+ _.GetDimension(result_type) != 3 ||
+ _.GetBitWidth(result_type) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be 32-bit float 3-component "
+ "vector type";
+ }
+
+ if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
+ opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+
+ if (!_.IsFloatVectorType(result_type) ||
+ _.GetDimension(result_type) != 2 ||
+ _.GetBitWidth(result_type) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be 32-bit float 2-component "
+ "vector type";
+ }
+
+ break;
+ }
+
+ case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
+ case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
+ if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+ if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+
+ uint32_t num_rows = 0;
+ uint32_t num_cols = 0;
+ uint32_t col_type = 0;
+ uint32_t component_type = 0;
+ if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
+ &component_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected matrix type as Result Type";
+ }
+
+ if (num_cols != 4) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type matrix to have a Column Count of 4";
+ }
+
+ if (!_.IsFloatScalarType(component_type) ||
+ _.GetBitWidth(result_type) != 32 || num_rows != 3) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type matrix to have a Column Type of "
+ "3-component 32-bit float vectors";
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace val
+} // namespace spvtools
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 65f2791..d02807a 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -88,8 +88,9 @@
PCH_FILE pch_test_val
)
-add_spvtools_unittest(TARGET val_stuvw
+add_spvtools_unittest(TARGET val_rstuvw
SRCS
+ val_ray_query.cpp
val_small_type_uses_test.cpp
val_ssa_test.cpp
val_state_test.cpp
diff --git a/test/val/val_ray_query.cpp b/test/val/val_ray_query.cpp
new file mode 100644
index 0000000..e9b9696
--- /dev/null
+++ b/test/val/val_ray_query.cpp
@@ -0,0 +1,578 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests ray query instructions from SPV_KHR_ray_query.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateRayQuery = spvtest::ValidateBase<bool>;
+
+std::string GenerateShaderCode(
+ const std::string& body,
+ const std::string& capabilities_and_extensions = "",
+ const std::string& declarations = "") {
+ std::ostringstream ss;
+ ss << R"(
+OpCapability Shader
+OpCapability Int64
+OpCapability Float64
+OpCapability RayQueryKHR
+OpExtension "SPV_KHR_ray_query"
+)";
+
+ ss << capabilities_and_extensions;
+
+ ss << R"(
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+
+OpDecorate %top_level_as DescriptorSet 0
+OpDecorate %top_level_as Binding 0
+
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+%u64 = OpTypeInt 64 0
+%s64 = OpTypeInt 64 1
+%type_rq = OpTypeRayQueryKHR
+%type_as = OpTypeAccelerationStructureKHR
+
+%s32vec2 = OpTypeVector %s32 2
+%u32vec2 = OpTypeVector %u32 2
+%f32vec2 = OpTypeVector %f32 2
+%u32vec3 = OpTypeVector %u32 3
+%s32vec3 = OpTypeVector %s32 3
+%f32vec3 = OpTypeVector %f32 3
+%u32vec4 = OpTypeVector %u32 4
+%s32vec4 = OpTypeVector %s32 4
+%f32vec4 = OpTypeVector %f32 4
+
+%mat4x3 = OpTypeMatrix %f32vec3 4
+
+%f32_0 = OpConstant %f32 0
+%f64_0 = OpConstant %f64 0
+%s32_0 = OpConstant %s32 0
+%u32_0 = OpConstant %u32 0
+%u64_0 = OpConstant %u64 0
+
+%u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
+%f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0
+%f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0
+
+%ptr_rq = OpTypePointer Private %type_rq
+%ray_query = OpVariable %ptr_rq Private
+
+%ptr_as = OpTypePointer UniformConstant %type_as
+%top_level_as = OpVariable %ptr_as UniformConstant
+
+%ptr_function_u32 = OpTypePointer Function %u32
+%ptr_function_f32 = OpTypePointer Function %f32
+%ptr_function_f32vec3 = OpTypePointer Function %f32vec3
+)";
+
+ ss << declarations;
+
+ ss << R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+ ss << body;
+
+ ss << R"(
+OpReturn
+OpFunctionEnd)";
+ return ss.str();
+}
+
+std::string RayQueryResult(std::string opcode) {
+ if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+ opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
+ opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+ "OffsetKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+ opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
+ opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+ return "%result =";
+ }
+ return "";
+}
+
+std::string RayQueryResultType(std::string opcode, bool valid) {
+ if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+ opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+ "OffsetKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0) {
+ return valid ? "%u32" : "%f64";
+ }
+
+ if (opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionTKHR") == 0) {
+ return valid ? "%f32" : "%f64";
+ }
+
+ if (opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0) {
+ return valid ? "%f32vec2" : "%f64";
+ }
+
+ if (opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+ opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
+ opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0) {
+ return valid ? "%f32vec3" : "%f64";
+ }
+
+ if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0) {
+ return valid ? "%bool" : "%f64";
+ }
+
+ if (opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+ return valid ? "%mat4x3" : "%f64";
+ }
+ return "";
+}
+
+std::string RayQueryIntersection(std::string opcode, bool valid) {
+ if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+ "OffsetKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+ opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+ return valid ? "%s32_0" : "%f32_0";
+ }
+ return "";
+}
+
+using RayQueryCommon = spvtest::ValidateBase<std::string>;
+
+TEST_P(RayQueryCommon, Success) {
+ std::string opcode = GetParam();
+ std::ostringstream ss;
+ ss << RayQueryResult(opcode);
+ ss << " " << opcode << " ";
+ ss << RayQueryResultType(opcode, true);
+ ss << " %ray_query ";
+ ss << RayQueryIntersection(opcode, true);
+ CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(RayQueryCommon, BadQuery) {
+ std::string opcode = GetParam();
+ std::ostringstream ss;
+ ss << RayQueryResult(opcode);
+ ss << " " << opcode << " ";
+ ss << RayQueryResultType(opcode, true);
+ ss << " %top_level_as ";
+ ss << RayQueryIntersection(opcode, true);
+ CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_P(RayQueryCommon, BadResult) {
+ std::string opcode = GetParam();
+ std::string result_type = RayQueryResultType(opcode, false);
+ if (!result_type.empty()) {
+ std::ostringstream ss;
+ ss << RayQueryResult(opcode);
+ ss << " " << opcode << " ";
+ ss << result_type;
+ ss << " %ray_query ";
+ ss << RayQueryIntersection(opcode, true);
+ CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+
+ std::string correct_result_type = RayQueryResultType(opcode, true);
+ if (correct_result_type.compare("%u32") == 0) {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("expected Result Type to be 32-bit int scalar type"));
+ } else if (correct_result_type.compare("%f32") == 0) {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("expected Result Type to be 32-bit float scalar type"));
+ } else if (correct_result_type.compare("%f32vec2") == 0) {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("expected Result Type to be 32-bit float "
+ "2-component vector type"));
+ } else if (correct_result_type.compare("%f32vec3") == 0) {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("expected Result Type to be 32-bit float "
+ "3-component vector type"));
+ } else if (correct_result_type.compare("%bool") == 0) {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("expected Result Type to be bool scalar type"));
+ } else if (correct_result_type.compare("%mat4x3") == 0) {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("expected matrix type as Result Type"));
+ }
+ }
+}
+
+TEST_P(RayQueryCommon, BadIntersection) {
+ std::string opcode = GetParam();
+ std::string intersection = RayQueryIntersection(opcode, false);
+ if (!intersection.empty()) {
+ std::ostringstream ss;
+ ss << RayQueryResult(opcode);
+ ss << " " << opcode << " ";
+ ss << RayQueryResultType(opcode, true);
+ ss << " %ray_query ";
+ ss << intersection;
+ CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "expected Intersection ID to be a constant 32-bit int scalar"));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ ValidateRayQueryCommon, RayQueryCommon,
+ Values("OpRayQueryTerminateKHR", "OpRayQueryConfirmIntersectionKHR",
+ "OpRayQueryProceedKHR", "OpRayQueryGetIntersectionTypeKHR",
+ "OpRayQueryGetRayTMinKHR", "OpRayQueryGetRayFlagsKHR",
+ "OpRayQueryGetWorldRayDirectionKHR",
+ "OpRayQueryGetWorldRayOriginKHR", "OpRayQueryGetIntersectionTKHR",
+ "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
+ "OpRayQueryGetIntersectionInstanceIdKHR",
+ "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR",
+ "OpRayQueryGetIntersectionGeometryIndexKHR",
+ "OpRayQueryGetIntersectionPrimitiveIndexKHR",
+ "OpRayQueryGetIntersectionBarycentricsKHR",
+ "OpRayQueryGetIntersectionFrontFaceKHR",
+ "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR",
+ "OpRayQueryGetIntersectionObjectRayDirectionKHR",
+ "OpRayQueryGetIntersectionObjectRayOriginKHR",
+ "OpRayQueryGetIntersectionObjectToWorldKHR",
+ "OpRayQueryGetIntersectionWorldToObjectKHR"));
+
+// tests various Intersection operand types
+TEST_F(ValidateRayQuery, IntersectionSuccess) {
+ const std::string body = R"(
+%result_1 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %s32_0
+%result_2 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, IntersectionVector) {
+ const std::string body = R"(
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32vec3_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, IntersectionNonConstantVariable) {
+ const std::string body = R"(
+%var = OpVariable %ptr_function_u32 Function
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %var
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, IntersectionNonConstantLoad) {
+ const std::string body = R"(
+%var = OpVariable %ptr_function_u32 Function
+%load = OpLoad %u32 %var
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %load
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeSuccess) {
+ const std::string body = R"(
+%var_u32 = OpVariable %ptr_function_u32 Function
+%var_f32 = OpVariable %ptr_function_f32 Function
+%var_f32vec3 = OpVariable %ptr_function_f32vec3 Function
+
+%as = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+
+%_u32 = OpLoad %u32 %var_u32
+%_f32 = OpLoad %f32 %var_f32
+%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
+OpRayQueryInitializeKHR %ray_query %as %_u32 %_u32 %_f32vec3 %_f32 %_f32vec3 %_f32
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, InitializeFunctionSuccess) {
+ const std::string declaration = R"(
+%rq_ptr = OpTypePointer Private %type_rq
+%rq_func_type = OpTypeFunction %void %rq_ptr
+%rq_var_1 = OpVariable %rq_ptr Private
+%rq_var_2 = OpVariable %rq_ptr Private
+)";
+
+ const std::string body = R"(
+%fcall_1 = OpFunctionCall %void %rq_func %rq_var_1
+%as_1 = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %rq_var_1 %as_1 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+%fcall_2 = OpFunctionCall %void %rq_func %rq_var_2
+OpReturn
+OpFunctionEnd
+%rq_func = OpFunction %void None %rq_func_type
+%rq_param = OpFunctionParameter %rq_ptr
+%label = OpLabel
+%as_2 = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body, "", declaration).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayQuery) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %top_level_as %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadAS) {
+ const std::string body = R"(
+OpRayQueryInitializeKHR %ray_query %ray_query %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Expected Acceleration Structure to be of type "
+ "OpTypeAccelerationStructureKHR"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayFlags64) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u64_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray Flags must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayFlagsVector) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32vec2 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Operand 15[%v2uint] cannot be a type"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadCullMask) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %f32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Cull Mask must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginVec4) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec4_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginFloat) {
+ const std::string body = R"(
+%var_f32 = OpVariable %ptr_function_f32 Function
+%_f32 = OpLoad %f32 %var_f32
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %_f32 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginInt) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %u32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayTMin) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %u32_0 %f32vec3_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray TMin must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayDirection) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec4_0 %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayTMax) {
+ const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f64_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray TMax must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionSuccess) {
+ const std::string body = R"(
+%var = OpVariable %ptr_function_f32 Function
+%load = OpLoad %f32 %var
+OpRayQueryGenerateIntersectionKHR %ray_query %f32_0
+OpRayQueryGenerateIntersectionKHR %ray_query %load
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionBadRayQuery) {
+ const std::string body = R"(
+OpRayQueryGenerateIntersectionKHR %top_level_as %f32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionBadHitT) {
+ const std::string body = R"(
+OpRayQueryGenerateIntersectionKHR %ray_query %u32_0
+)";
+
+ CompileSuccessfully(GenerateShaderCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Hit T must be a 32-bit float scalar"));
+}
+
+} // namespace
+} // namespace val
+} // namespace spvtools