spirv-val: Add SPV_KHR_ray_query (#4848)
diff --git a/Android.mk b/Android.mk index 6dd1834..cd1d7f8 100644 --- a/Android.mk +++ b/Android.mk
@@ -68,6 +68,7 @@ source/val/validate_logicals.cpp \ source/val/validate_non_uniform.cpp \ source/val/validate_primitives.cpp \ + source/val/validate_ray_query.cpp \ source/val/validate_scopes.cpp \ source/val/validate_small_type_uses.cpp \ source/val/validate_type.cpp
diff --git a/BUILD.gn b/BUILD.gn index 71a584f..9e9f6e5 100644 --- a/BUILD.gn +++ b/BUILD.gn
@@ -530,6 +530,7 @@ "source/val/validate_mode_setting.cpp", "source/val/validate_non_uniform.cpp", "source/val/validate_primitives.cpp", + "source/val/validate_ray_query.cpp", "source/val/validate_scopes.cpp", "source/val/validate_scopes.h", "source/val/validate_small_type_uses.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index c0974e1..1ceb78f 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt
@@ -322,6 +322,7 @@ ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_query.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp index dc6401a..55e9fd2 100644 --- a/source/val/validate.cpp +++ b/source/val/validate.cpp
@@ -350,6 +350,7 @@ if (auto error = NonUniformPass(*vstate, &instruction)) return error; if (auto error = LiteralsPass(*vstate, &instruction)) return error; + if (auto error = RayQueryPass(*vstate, &instruction)) return error; } // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
diff --git a/source/val/validate.h b/source/val/validate.h index cb1d05a..97d4683 100644 --- a/source/val/validate.h +++ b/source/val/validate.h
@@ -197,6 +197,9 @@ /// Validates correctness of miscellaneous instructions. spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst); +/// Validates correctness of ray query instructions. +spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst); + /// Calculates the reachability of basic blocks. void ReachabilityPass(ValidationState_t& _);
diff --git a/source/val/validate_ray_query.cpp b/source/val/validate_ray_query.cpp new file mode 100644 index 0000000..f92bf01 --- /dev/null +++ b/source/val/validate_ray_query.cpp
@@ -0,0 +1,271 @@ +// Copyright (c) 2022 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Validates ray query instructions from SPV_KHR_ray_query + +#include "source/opcode.h" +#include "source/val/instruction.h" +#include "source/val/validate.h" +#include "source/val/validation_state.h" + +namespace spvtools { +namespace val { +namespace { + +spv_result_t ValidateRayQueryPointer(ValidationState_t& _, + const Instruction* inst, + uint32_t ray_query_index) { + const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index); + auto variable = _.FindDef(ray_query_id); + if (!variable || (variable->opcode() != SpvOpVariable && + variable->opcode() != SpvOpFunctionParameter)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Query must be a memory object declaration"; + } + auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0)); + if (!pointer || pointer->opcode() != SpvOpTypePointer) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Query must be a pointer"; + } + auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2)); + if (!type || type->opcode() != SpvOpTypeRayQueryKHR) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Query must be a pointer to OpTypeRayQueryKHR"; + } + return SPV_SUCCESS; +} + +spv_result_t ValidateIntersectionId(ValidationState_t& _, + const Instruction* inst, + uint32_t intersection_index) { + const uint32_t intersection_id = + inst->GetOperandAs<uint32_t>(intersection_index); + const uint32_t intersection_type = _.GetTypeId(intersection_id); + const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id); + if (!_.IsIntScalarType(intersection_type) || + _.GetBitWidth(intersection_type) != 32 || + !spvOpcodeIsConstant(intersection_opcode)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Intersection ID to be a constant 32-bit int scalar"; + } + + return SPV_SUCCESS; +} + +} // namespace + +spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) { + const SpvOp opcode = inst->opcode(); + const uint32_t result_type = inst->type_id(); + + switch (opcode) { + case SpvOpRayQueryInitializeKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error; + + if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) != + SpvOpTypeAccelerationStructureKHR) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Expected Acceleration Structure to be of type " + "OpTypeAccelerationStructureKHR"; + } + + const uint32_t ray_flags = _.GetOperandTypeId(inst, 2); + if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Flags must be a 32-bit int scalar"; + } + + const uint32_t cull_mask = _.GetOperandTypeId(inst, 3); + if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Cull Mask must be a 32-bit int scalar"; + } + + const uint32_t ray_origin = _.GetOperandTypeId(inst, 4); + if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 || + _.GetBitWidth(ray_origin) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Origin must be a 32-bit float 3-component vector"; + } + + const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5); + if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray TMin must be a 32-bit float scalar"; + } + + const uint32_t ray_direction = _.GetOperandTypeId(inst, 6); + if (!_.IsFloatVectorType(ray_direction) || + _.GetDimension(ray_direction) != 3 || + _.GetBitWidth(ray_direction) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray Direction must be a 32-bit float 3-component vector"; + } + + const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7); + if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Ray TMax must be a 32-bit float scalar"; + } + break; + } + + case SpvOpRayQueryTerminateKHR: + case SpvOpRayQueryConfirmIntersectionKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error; + break; + } + + case SpvOpRayQueryGenerateIntersectionKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error; + + const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1); + if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Hit T must be a 32-bit float scalar"; + } + + break; + } + + case SpvOpRayQueryGetIntersectionFrontFaceKHR: + case SpvOpRayQueryProceedKHR: + case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + + if (!_.IsBoolScalarType(result_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be bool scalar type"; + } + + if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) { + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + } + + break; + } + + case SpvOpRayQueryGetIntersectionTKHR: + case SpvOpRayQueryGetRayTMinKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + + if (!_.IsFloatScalarType(result_type) || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit float scalar type"; + } + + if (opcode == SpvOpRayQueryGetIntersectionTKHR) { + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + } + + break; + } + + case SpvOpRayQueryGetIntersectionTypeKHR: + case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR: + case SpvOpRayQueryGetIntersectionInstanceIdKHR: + case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: + case SpvOpRayQueryGetIntersectionGeometryIndexKHR: + case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR: + case SpvOpRayQueryGetRayFlagsKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + + if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit int scalar type"; + } + + if (opcode != SpvOpRayQueryGetRayFlagsKHR) { + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + } + + break; + } + + case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR: + case SpvOpRayQueryGetIntersectionObjectRayOriginKHR: + case SpvOpRayQueryGetWorldRayDirectionKHR: + case SpvOpRayQueryGetWorldRayOriginKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 3 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit float 3-component " + "vector type"; + } + + if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR || + opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) { + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + } + + break; + } + + case SpvOpRayQueryGetIntersectionBarycentricsKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + if (!_.IsFloatVectorType(result_type) || + _.GetDimension(result_type) != 2 || + _.GetBitWidth(result_type) != 32) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type to be 32-bit float 2-component " + "vector type"; + } + + break; + } + + case SpvOpRayQueryGetIntersectionObjectToWorldKHR: + case SpvOpRayQueryGetIntersectionWorldToObjectKHR: { + if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error; + if (auto error = ValidateIntersectionId(_, inst, 3)) return error; + + uint32_t num_rows = 0; + uint32_t num_cols = 0; + uint32_t col_type = 0; + uint32_t component_type = 0; + if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type, + &component_type)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected matrix type as Result Type"; + } + + if (num_cols != 4) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type matrix to have a Column Count of 4"; + } + + if (!_.IsFloatScalarType(component_type) || + _.GetBitWidth(result_type) != 32 || num_rows != 3) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "expected Result Type matrix to have a Column Type of " + "3-component 32-bit float vectors"; + } + break; + } + + default: + break; + } + + return SPV_SUCCESS; +} + +} // namespace val +} // namespace spvtools
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt index 65f2791..d02807a 100644 --- a/test/val/CMakeLists.txt +++ b/test/val/CMakeLists.txt
@@ -88,8 +88,9 @@ PCH_FILE pch_test_val ) -add_spvtools_unittest(TARGET val_stuvw +add_spvtools_unittest(TARGET val_rstuvw SRCS + val_ray_query.cpp val_small_type_uses_test.cpp val_ssa_test.cpp val_state_test.cpp
diff --git a/test/val/val_ray_query.cpp b/test/val/val_ray_query.cpp new file mode 100644 index 0000000..e9b9696 --- /dev/null +++ b/test/val/val_ray_query.cpp
@@ -0,0 +1,578 @@ +// Copyright (c) 2022 The Khronos Group Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests ray query instructions from SPV_KHR_ray_query. + +#include <sstream> +#include <string> + +#include "gmock/gmock.h" +#include "test/val/val_fixtures.h" + +namespace spvtools { +namespace val { +namespace { + +using ::testing::HasSubstr; +using ::testing::Values; + +using ValidateRayQuery = spvtest::ValidateBase<bool>; + +std::string GenerateShaderCode( + const std::string& body, + const std::string& capabilities_and_extensions = "", + const std::string& declarations = "") { + std::ostringstream ss; + ss << R"( +OpCapability Shader +OpCapability Int64 +OpCapability Float64 +OpCapability RayQueryKHR +OpExtension "SPV_KHR_ray_query" +)"; + + ss << capabilities_and_extensions; + + ss << R"( +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" +OpExecutionMode %main LocalSize 1 1 1 + +OpDecorate %top_level_as DescriptorSet 0 +OpDecorate %top_level_as Binding 0 + +%void = OpTypeVoid +%func = OpTypeFunction %void +%bool = OpTypeBool +%f32 = OpTypeFloat 32 +%f64 = OpTypeFloat 64 +%u32 = OpTypeInt 32 0 +%s32 = OpTypeInt 32 1 +%u64 = OpTypeInt 64 0 +%s64 = OpTypeInt 64 1 +%type_rq = OpTypeRayQueryKHR +%type_as = OpTypeAccelerationStructureKHR + +%s32vec2 = OpTypeVector %s32 2 +%u32vec2 = OpTypeVector %u32 2 +%f32vec2 = OpTypeVector %f32 2 +%u32vec3 = OpTypeVector %u32 3 +%s32vec3 = OpTypeVector %s32 3 +%f32vec3 = OpTypeVector %f32 3 +%u32vec4 = OpTypeVector %u32 4 +%s32vec4 = OpTypeVector %s32 4 +%f32vec4 = OpTypeVector %f32 4 + +%mat4x3 = OpTypeMatrix %f32vec3 4 + +%f32_0 = OpConstant %f32 0 +%f64_0 = OpConstant %f64 0 +%s32_0 = OpConstant %s32 0 +%u32_0 = OpConstant %u32 0 +%u64_0 = OpConstant %u64 0 + +%u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0 +%f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0 +%f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0 + +%ptr_rq = OpTypePointer Private %type_rq +%ray_query = OpVariable %ptr_rq Private + +%ptr_as = OpTypePointer UniformConstant %type_as +%top_level_as = OpVariable %ptr_as UniformConstant + +%ptr_function_u32 = OpTypePointer Function %u32 +%ptr_function_f32 = OpTypePointer Function %f32 +%ptr_function_f32vec3 = OpTypePointer Function %f32vec3 +)"; + + ss << declarations; + + ss << R"( +%main = OpFunction %void None %func +%main_entry = OpLabel +)"; + + ss << body; + + ss << R"( +OpReturn +OpFunctionEnd)"; + return ss.str(); +} + +std::string RayQueryResult(std::string opcode) { + if (opcode.compare("OpRayQueryProceedKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 || + opcode.compare("OpRayQueryGetRayTMinKHR") == 0 || + opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord" + "OffsetKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 || + opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 || + opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) { + return "%result ="; + } + return ""; +} + +std::string RayQueryResultType(std::string opcode, bool valid) { + if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 || + opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord" + "OffsetKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0) { + return valid ? "%u32" : "%f64"; + } + + if (opcode.compare("OpRayQueryGetRayTMinKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionTKHR") == 0) { + return valid ? "%f32" : "%f64"; + } + + if (opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0) { + return valid ? "%f32vec2" : "%f64"; + } + + if (opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 || + opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 || + opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0) { + return valid ? "%f32vec3" : "%f64"; + } + + if (opcode.compare("OpRayQueryProceedKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0) { + return valid ? "%bool" : "%f64"; + } + + if (opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) { + return valid ? "%mat4x3" : "%f64"; + } + return ""; +} + +std::string RayQueryIntersection(std::string opcode, bool valid) { + if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord" + "OffsetKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 || + opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) { + return valid ? "%s32_0" : "%f32_0"; + } + return ""; +} + +using RayQueryCommon = spvtest::ValidateBase<std::string>; + +TEST_P(RayQueryCommon, Success) { + std::string opcode = GetParam(); + std::ostringstream ss; + ss << RayQueryResult(opcode); + ss << " " << opcode << " "; + ss << RayQueryResultType(opcode, true); + ss << " %ray_query "; + ss << RayQueryIntersection(opcode, true); + CompileSuccessfully(GenerateShaderCode(ss.str()).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_P(RayQueryCommon, BadQuery) { + std::string opcode = GetParam(); + std::ostringstream ss; + ss << RayQueryResult(opcode); + ss << " " << opcode << " "; + ss << RayQueryResultType(opcode, true); + ss << " %top_level_as "; + ss << RayQueryIntersection(opcode, true); + CompileSuccessfully(GenerateShaderCode(ss.str()).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR")); +} + +TEST_P(RayQueryCommon, BadResult) { + std::string opcode = GetParam(); + std::string result_type = RayQueryResultType(opcode, false); + if (!result_type.empty()) { + std::ostringstream ss; + ss << RayQueryResult(opcode); + ss << " " << opcode << " "; + ss << result_type; + ss << " %ray_query "; + ss << RayQueryIntersection(opcode, true); + CompileSuccessfully(GenerateShaderCode(ss.str()).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + + std::string correct_result_type = RayQueryResultType(opcode, true); + if (correct_result_type.compare("%u32") == 0) { + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("expected Result Type to be 32-bit int scalar type")); + } else if (correct_result_type.compare("%f32") == 0) { + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("expected Result Type to be 32-bit float scalar type")); + } else if (correct_result_type.compare("%f32vec2") == 0) { + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected Result Type to be 32-bit float " + "2-component vector type")); + } else if (correct_result_type.compare("%f32vec3") == 0) { + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected Result Type to be 32-bit float " + "3-component vector type")); + } else if (correct_result_type.compare("%bool") == 0) { + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected Result Type to be bool scalar type")); + } else if (correct_result_type.compare("%mat4x3") == 0) { + EXPECT_THAT(getDiagnosticString(), + HasSubstr("expected matrix type as Result Type")); + } + } +} + +TEST_P(RayQueryCommon, BadIntersection) { + std::string opcode = GetParam(); + std::string intersection = RayQueryIntersection(opcode, false); + if (!intersection.empty()) { + std::ostringstream ss; + ss << RayQueryResult(opcode); + ss << " " << opcode << " "; + ss << RayQueryResultType(opcode, true); + ss << " %ray_query "; + ss << intersection; + CompileSuccessfully(GenerateShaderCode(ss.str()).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr( + "expected Intersection ID to be a constant 32-bit int scalar")); + } +} + +INSTANTIATE_TEST_SUITE_P( + ValidateRayQueryCommon, RayQueryCommon, + Values("OpRayQueryTerminateKHR", "OpRayQueryConfirmIntersectionKHR", + "OpRayQueryProceedKHR", "OpRayQueryGetIntersectionTypeKHR", + "OpRayQueryGetRayTMinKHR", "OpRayQueryGetRayFlagsKHR", + "OpRayQueryGetWorldRayDirectionKHR", + "OpRayQueryGetWorldRayOriginKHR", "OpRayQueryGetIntersectionTKHR", + "OpRayQueryGetIntersectionInstanceCustomIndexKHR", + "OpRayQueryGetIntersectionInstanceIdKHR", + "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR", + "OpRayQueryGetIntersectionGeometryIndexKHR", + "OpRayQueryGetIntersectionPrimitiveIndexKHR", + "OpRayQueryGetIntersectionBarycentricsKHR", + "OpRayQueryGetIntersectionFrontFaceKHR", + "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR", + "OpRayQueryGetIntersectionObjectRayDirectionKHR", + "OpRayQueryGetIntersectionObjectRayOriginKHR", + "OpRayQueryGetIntersectionObjectToWorldKHR", + "OpRayQueryGetIntersectionWorldToObjectKHR")); + +// tests various Intersection operand types +TEST_F(ValidateRayQuery, IntersectionSuccess) { + const std::string body = R"( +%result_1 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %s32_0 +%result_2 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, IntersectionVector) { + const std::string body = R"( +%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32vec3_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("expected Intersection ID to be a constant 32-bit int scalar")); +} + +TEST_F(ValidateRayQuery, IntersectionNonConstantVariable) { + const std::string body = R"( +%var = OpVariable %ptr_function_u32 Function +%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %var +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("expected Intersection ID to be a constant 32-bit int scalar")); +} + +TEST_F(ValidateRayQuery, IntersectionNonConstantLoad) { + const std::string body = R"( +%var = OpVariable %ptr_function_u32 Function +%load = OpLoad %u32 %var +%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %load +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("expected Intersection ID to be a constant 32-bit int scalar")); +} + +TEST_F(ValidateRayQuery, InitializeSuccess) { + const std::string body = R"( +%var_u32 = OpVariable %ptr_function_u32 Function +%var_f32 = OpVariable %ptr_function_f32 Function +%var_f32vec3 = OpVariable %ptr_function_f32vec3 Function + +%as = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 + +%_u32 = OpLoad %u32 %var_u32 +%_f32 = OpLoad %f32 %var_f32 +%_f32vec3 = OpLoad %f32vec3 %var_f32vec3 +OpRayQueryInitializeKHR %ray_query %as %_u32 %_u32 %_f32vec3 %_f32 %_f32vec3 %_f32 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, InitializeFunctionSuccess) { + const std::string declaration = R"( +%rq_ptr = OpTypePointer Private %type_rq +%rq_func_type = OpTypeFunction %void %rq_ptr +%rq_var_1 = OpVariable %rq_ptr Private +%rq_var_2 = OpVariable %rq_ptr Private +)"; + + const std::string body = R"( +%fcall_1 = OpFunctionCall %void %rq_func %rq_var_1 +%as_1 = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %rq_var_1 %as_1 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +%fcall_2 = OpFunctionCall %void %rq_func %rq_var_2 +OpReturn +OpFunctionEnd +%rq_func = OpFunction %void None %rq_func_type +%rq_param = OpFunctionParameter %rq_ptr +%label = OpLabel +%as_2 = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body, "", declaration).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, InitializeBadRayQuery) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %top_level_as %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR")); +} + +TEST_F(ValidateRayQuery, InitializeBadAS) { + const std::string body = R"( +OpRayQueryInitializeKHR %ray_query %ray_query %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Expected Acceleration Structure to be of type " + "OpTypeAccelerationStructureKHR")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayFlags64) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u64_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Flags must be a 32-bit int scalar")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayFlagsVector) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32vec2 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Operand 15[%v2uint] cannot be a type")); +} + +TEST_F(ValidateRayQuery, InitializeBadCullMask) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %f32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Cull Mask must be a 32-bit int scalar")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayOriginVec4) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec4_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Origin must be a 32-bit float 3-component vector")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayOriginFloat) { + const std::string body = R"( +%var_f32 = OpVariable %ptr_function_f32 Function +%_f32 = OpLoad %f32 %var_f32 +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %_f32 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Origin must be a 32-bit float 3-component vector")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayOriginInt) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %u32vec3_0 %f32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Origin must be a 32-bit float 3-component vector")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayTMin) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %u32_0 %f32vec3_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray TMin must be a 32-bit float scalar")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayDirection) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec4_0 %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Ray Direction must be a 32-bit float 3-component vector")); +} + +TEST_F(ValidateRayQuery, InitializeBadRayTMax) { + const std::string body = R"( +%load = OpLoad %type_as %top_level_as +OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f64_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray TMax must be a 32-bit float scalar")); +} + +TEST_F(ValidateRayQuery, GenerateIntersectionSuccess) { + const std::string body = R"( +%var = OpVariable %ptr_function_f32 Function +%load = OpLoad %f32 %var +OpRayQueryGenerateIntersectionKHR %ray_query %f32_0 +OpRayQueryGenerateIntersectionKHR %ray_query %load +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateRayQuery, GenerateIntersectionBadRayQuery) { + const std::string body = R"( +OpRayQueryGenerateIntersectionKHR %top_level_as %f32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR")); +} + +TEST_F(ValidateRayQuery, GenerateIntersectionBadHitT) { + const std::string body = R"( +OpRayQueryGenerateIntersectionKHR %ray_query %u32_0 +)"; + + CompileSuccessfully(GenerateShaderCode(body).c_str()); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("Hit T must be a 32-bit float scalar")); +} + +} // namespace +} // namespace val +} // namespace spvtools