spirv-val: Add SPV_KHR_ray_query (#4848)

diff --git a/Android.mk b/Android.mk
index 6dd1834..cd1d7f8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -68,6 +68,7 @@
 		source/val/validate_logicals.cpp \
 		source/val/validate_non_uniform.cpp \
 		source/val/validate_primitives.cpp \
+		source/val/validate_ray_query.cpp \
 		source/val/validate_scopes.cpp \
 		source/val/validate_small_type_uses.cpp \
 		source/val/validate_type.cpp
diff --git a/BUILD.gn b/BUILD.gn
index 71a584f..9e9f6e5 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -530,6 +530,7 @@
     "source/val/validate_mode_setting.cpp",
     "source/val/validate_non_uniform.cpp",
     "source/val/validate_primitives.cpp",
+    "source/val/validate_ray_query.cpp",
     "source/val/validate_scopes.cpp",
     "source/val/validate_scopes.h",
     "source/val/validate_small_type_uses.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index c0974e1..1ceb78f 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -322,6 +322,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_mode_setting.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_query.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index dc6401a..55e9fd2 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -350,6 +350,7 @@
     if (auto error = NonUniformPass(*vstate, &instruction)) return error;
 
     if (auto error = LiteralsPass(*vstate, &instruction)) return error;
+    if (auto error = RayQueryPass(*vstate, &instruction)) return error;
   }
 
   // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
diff --git a/source/val/validate.h b/source/val/validate.h
index cb1d05a..97d4683 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -197,6 +197,9 @@
 /// Validates correctness of miscellaneous instructions.
 spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
 
+/// Validates correctness of ray query instructions.
+spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
+
 /// Calculates the reachability of basic blocks.
 void ReachabilityPass(ValidationState_t& _);
 
diff --git a/source/val/validate_ray_query.cpp b/source/val/validate_ray_query.cpp
new file mode 100644
index 0000000..f92bf01
--- /dev/null
+++ b/source/val/validate_ray_query.cpp
@@ -0,0 +1,271 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates ray query instructions from SPV_KHR_ray_query
+
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
+                                     const Instruction* inst,
+                                     uint32_t ray_query_index) {
+  const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
+  auto variable = _.FindDef(ray_query_id);
+  if (!variable || (variable->opcode() != SpvOpVariable &&
+                    variable->opcode() != SpvOpFunctionParameter)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Ray Query must be a memory object declaration";
+  }
+  auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
+  if (!pointer || pointer->opcode() != SpvOpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Ray Query must be a pointer";
+  }
+  auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
+  if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Ray Query must be a pointer to OpTypeRayQueryKHR";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateIntersectionId(ValidationState_t& _,
+                                    const Instruction* inst,
+                                    uint32_t intersection_index) {
+  const uint32_t intersection_id =
+      inst->GetOperandAs<uint32_t>(intersection_index);
+  const uint32_t intersection_type = _.GetTypeId(intersection_id);
+  const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
+  if (!_.IsIntScalarType(intersection_type) ||
+      _.GetBitWidth(intersection_type) != 32 ||
+      !spvOpcodeIsConstant(intersection_opcode)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "expected Intersection ID to be a constant 32-bit int scalar";
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace
+
+spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    case SpvOpRayQueryInitializeKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+
+      if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
+          SpvOpTypeAccelerationStructureKHR) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Acceleration Structure to be of type "
+                  "OpTypeAccelerationStructureKHR";
+      }
+
+      const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
+      if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Flags must be a 32-bit int scalar";
+      }
+
+      const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
+      if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cull Mask must be a 32-bit int scalar";
+      }
+
+      const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
+      if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
+          _.GetBitWidth(ray_origin) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Origin must be a 32-bit float 3-component vector";
+      }
+
+      const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
+      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray TMin must be a 32-bit float scalar";
+      }
+
+      const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
+      if (!_.IsFloatVectorType(ray_direction) ||
+          _.GetDimension(ray_direction) != 3 ||
+          _.GetBitWidth(ray_direction) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Direction must be a 32-bit float 3-component vector";
+      }
+
+      const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
+      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray TMax must be a 32-bit float scalar";
+      }
+      break;
+    }
+
+    case SpvOpRayQueryTerminateKHR:
+    case SpvOpRayQueryConfirmIntersectionKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+      break;
+    }
+
+    case SpvOpRayQueryGenerateIntersectionKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
+
+      const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
+      if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Hit T must be a 32-bit float scalar";
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionFrontFaceKHR:
+    case SpvOpRayQueryProceedKHR:
+    case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+      if (!_.IsBoolScalarType(result_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be bool scalar type";
+      }
+
+      if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
+        if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionTKHR:
+    case SpvOpRayQueryGetRayTMinKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+      if (!_.IsFloatScalarType(result_type) ||
+          _.GetBitWidth(result_type) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be 32-bit float scalar type";
+      }
+
+      if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
+        if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionTypeKHR:
+    case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
+    case SpvOpRayQueryGetIntersectionInstanceIdKHR:
+    case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
+    case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
+    case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
+    case SpvOpRayQueryGetRayFlagsKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+      if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be 32-bit int scalar type";
+      }
+
+      if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
+        if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
+    case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
+    case SpvOpRayQueryGetWorldRayDirectionKHR:
+    case SpvOpRayQueryGetWorldRayOriginKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+
+      if (!_.IsFloatVectorType(result_type) ||
+          _.GetDimension(result_type) != 3 ||
+          _.GetBitWidth(result_type) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be 32-bit float 3-component "
+                  "vector type";
+      }
+
+      if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
+          opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
+        if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+      if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+
+      if (!_.IsFloatVectorType(result_type) ||
+          _.GetDimension(result_type) != 2 ||
+          _.GetBitWidth(result_type) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be 32-bit float 2-component "
+                  "vector type";
+      }
+
+      break;
+    }
+
+    case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
+    case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
+      if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
+      if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
+
+      uint32_t num_rows = 0;
+      uint32_t num_cols = 0;
+      uint32_t col_type = 0;
+      uint32_t component_type = 0;
+      if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
+                               &component_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected matrix type as Result Type";
+      }
+
+      if (num_cols != 4) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type matrix to have a Column Count of 4";
+      }
+
+      if (!_.IsFloatScalarType(component_type) ||
+          _.GetBitWidth(result_type) != 32 || num_rows != 3) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type matrix to have a Column Type of "
+                  "3-component 32-bit float vectors";
+      }
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 65f2791..d02807a 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -88,8 +88,9 @@
   PCH_FILE pch_test_val
 )
 
-add_spvtools_unittest(TARGET val_stuvw
+add_spvtools_unittest(TARGET val_rstuvw
   SRCS
+       val_ray_query.cpp
        val_small_type_uses_test.cpp
        val_ssa_test.cpp
        val_state_test.cpp
diff --git a/test/val/val_ray_query.cpp b/test/val/val_ray_query.cpp
new file mode 100644
index 0000000..e9b9696
--- /dev/null
+++ b/test/val/val_ray_query.cpp
@@ -0,0 +1,578 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests ray query instructions from SPV_KHR_ray_query.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateRayQuery = spvtest::ValidateBase<bool>;
+
+std::string GenerateShaderCode(
+    const std::string& body,
+    const std::string& capabilities_and_extensions = "",
+    const std::string& declarations = "") {
+  std::ostringstream ss;
+  ss << R"(
+OpCapability Shader
+OpCapability Int64
+OpCapability Float64
+OpCapability RayQueryKHR
+OpExtension "SPV_KHR_ray_query"
+)";
+
+  ss << capabilities_and_extensions;
+
+  ss << R"(
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+
+OpDecorate %top_level_as DescriptorSet 0
+OpDecorate %top_level_as Binding 0
+
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+%u64 = OpTypeInt 64 0
+%s64 = OpTypeInt 64 1
+%type_rq = OpTypeRayQueryKHR
+%type_as = OpTypeAccelerationStructureKHR
+
+%s32vec2 = OpTypeVector %s32 2
+%u32vec2 = OpTypeVector %u32 2
+%f32vec2 = OpTypeVector %f32 2
+%u32vec3 = OpTypeVector %u32 3
+%s32vec3 = OpTypeVector %s32 3
+%f32vec3 = OpTypeVector %f32 3
+%u32vec4 = OpTypeVector %u32 4
+%s32vec4 = OpTypeVector %s32 4
+%f32vec4 = OpTypeVector %f32 4
+
+%mat4x3 = OpTypeMatrix %f32vec3 4
+
+%f32_0 = OpConstant %f32 0
+%f64_0 = OpConstant %f64 0
+%s32_0 = OpConstant %s32 0
+%u32_0 = OpConstant %u32 0
+%u64_0 = OpConstant %u64 0
+
+%u32vec3_0 = OpConstantComposite %u32vec3 %u32_0 %u32_0 %u32_0
+%f32vec3_0 = OpConstantComposite %f32vec3 %f32_0 %f32_0 %f32_0
+%f32vec4_0 = OpConstantComposite %f32vec4 %f32_0 %f32_0 %f32_0 %f32_0
+
+%ptr_rq = OpTypePointer Private %type_rq
+%ray_query = OpVariable %ptr_rq Private
+
+%ptr_as = OpTypePointer UniformConstant %type_as
+%top_level_as = OpVariable %ptr_as UniformConstant
+
+%ptr_function_u32 = OpTypePointer Function %u32
+%ptr_function_f32 = OpTypePointer Function %f32
+%ptr_function_f32vec3 = OpTypePointer Function %f32vec3
+)";
+
+  ss << declarations;
+
+  ss << R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+  return ss.str();
+}
+
+std::string RayQueryResult(std::string opcode) {
+  if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+      opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
+      opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+                     "OffsetKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+      opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
+      opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+    return "%result =";
+  }
+  return "";
+}
+
+std::string RayQueryResultType(std::string opcode, bool valid) {
+  if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+      opcode.compare("OpRayQueryGetRayFlagsKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+                     "OffsetKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0) {
+    return valid ? "%u32" : "%f64";
+  }
+
+  if (opcode.compare("OpRayQueryGetRayTMinKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionTKHR") == 0) {
+    return valid ? "%f32" : "%f64";
+  }
+
+  if (opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0) {
+    return valid ? "%f32vec2" : "%f64";
+  }
+
+  if (opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+      opcode.compare("OpRayQueryGetWorldRayDirectionKHR") == 0 ||
+      opcode.compare("OpRayQueryGetWorldRayOriginKHR") == 0) {
+    return valid ? "%f32vec3" : "%f64";
+  }
+
+  if (opcode.compare("OpRayQueryProceedKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionCandidateAABBOpaqueKHR") == 0) {
+    return valid ? "%bool" : "%f64";
+  }
+
+  if (opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+    return valid ? "%mat4x3" : "%f64";
+  }
+  return "";
+}
+
+std::string RayQueryIntersection(std::string opcode, bool valid) {
+  if (opcode.compare("OpRayQueryGetIntersectionTypeKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionTKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceCustomIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceIdKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionInstanceShaderBindingTableRecord"
+                     "OffsetKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionGeometryIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionPrimitiveIndexKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionBarycentricsKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionFrontFaceKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectRayDirectionKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectRayOriginKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionObjectToWorldKHR") == 0 ||
+      opcode.compare("OpRayQueryGetIntersectionWorldToObjectKHR") == 0) {
+    return valid ? "%s32_0" : "%f32_0";
+  }
+  return "";
+}
+
+using RayQueryCommon = spvtest::ValidateBase<std::string>;
+
+TEST_P(RayQueryCommon, Success) {
+  std::string opcode = GetParam();
+  std::ostringstream ss;
+  ss << RayQueryResult(opcode);
+  ss << " " << opcode << " ";
+  ss << RayQueryResultType(opcode, true);
+  ss << " %ray_query ";
+  ss << RayQueryIntersection(opcode, true);
+  CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_P(RayQueryCommon, BadQuery) {
+  std::string opcode = GetParam();
+  std::ostringstream ss;
+  ss << RayQueryResult(opcode);
+  ss << " " << opcode << " ";
+  ss << RayQueryResultType(opcode, true);
+  ss << " %top_level_as ";
+  ss << RayQueryIntersection(opcode, true);
+  CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_P(RayQueryCommon, BadResult) {
+  std::string opcode = GetParam();
+  std::string result_type = RayQueryResultType(opcode, false);
+  if (!result_type.empty()) {
+    std::ostringstream ss;
+    ss << RayQueryResult(opcode);
+    ss << " " << opcode << " ";
+    ss << result_type;
+    ss << " %ray_query ";
+    ss << RayQueryIntersection(opcode, true);
+    CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+    EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+
+    std::string correct_result_type = RayQueryResultType(opcode, true);
+    if (correct_result_type.compare("%u32") == 0) {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("expected Result Type to be 32-bit int scalar type"));
+    } else if (correct_result_type.compare("%f32") == 0) {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("expected Result Type to be 32-bit float scalar type"));
+    } else if (correct_result_type.compare("%f32vec2") == 0) {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("expected Result Type to be 32-bit float "
+                            "2-component vector type"));
+    } else if (correct_result_type.compare("%f32vec3") == 0) {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("expected Result Type to be 32-bit float "
+                            "3-component vector type"));
+    } else if (correct_result_type.compare("%bool") == 0) {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("expected Result Type to be bool scalar type"));
+    } else if (correct_result_type.compare("%mat4x3") == 0) {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("expected matrix type as Result Type"));
+    }
+  }
+}
+
+TEST_P(RayQueryCommon, BadIntersection) {
+  std::string opcode = GetParam();
+  std::string intersection = RayQueryIntersection(opcode, false);
+  if (!intersection.empty()) {
+    std::ostringstream ss;
+    ss << RayQueryResult(opcode);
+    ss << " " << opcode << " ";
+    ss << RayQueryResultType(opcode, true);
+    ss << " %ray_query ";
+    ss << intersection;
+    CompileSuccessfully(GenerateShaderCode(ss.str()).c_str());
+    EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+    EXPECT_THAT(
+        getDiagnosticString(),
+        HasSubstr(
+            "expected Intersection ID to be a constant 32-bit int scalar"));
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    ValidateRayQueryCommon, RayQueryCommon,
+    Values("OpRayQueryTerminateKHR", "OpRayQueryConfirmIntersectionKHR",
+           "OpRayQueryProceedKHR", "OpRayQueryGetIntersectionTypeKHR",
+           "OpRayQueryGetRayTMinKHR", "OpRayQueryGetRayFlagsKHR",
+           "OpRayQueryGetWorldRayDirectionKHR",
+           "OpRayQueryGetWorldRayOriginKHR", "OpRayQueryGetIntersectionTKHR",
+           "OpRayQueryGetIntersectionInstanceCustomIndexKHR",
+           "OpRayQueryGetIntersectionInstanceIdKHR",
+           "OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR",
+           "OpRayQueryGetIntersectionGeometryIndexKHR",
+           "OpRayQueryGetIntersectionPrimitiveIndexKHR",
+           "OpRayQueryGetIntersectionBarycentricsKHR",
+           "OpRayQueryGetIntersectionFrontFaceKHR",
+           "OpRayQueryGetIntersectionCandidateAABBOpaqueKHR",
+           "OpRayQueryGetIntersectionObjectRayDirectionKHR",
+           "OpRayQueryGetIntersectionObjectRayOriginKHR",
+           "OpRayQueryGetIntersectionObjectToWorldKHR",
+           "OpRayQueryGetIntersectionWorldToObjectKHR"));
+
+// tests various Intersection operand types
+TEST_F(ValidateRayQuery, IntersectionSuccess) {
+  const std::string body = R"(
+%result_1 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %s32_0
+%result_2 = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, IntersectionVector) {
+  const std::string body = R"(
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %u32vec3_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, IntersectionNonConstantVariable) {
+  const std::string body = R"(
+%var = OpVariable %ptr_function_u32 Function
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %var
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, IntersectionNonConstantLoad) {
+  const std::string body = R"(
+%var = OpVariable %ptr_function_u32 Function
+%load = OpLoad %u32 %var
+%result = OpRayQueryGetIntersectionFrontFaceKHR %bool %ray_query %load
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("expected Intersection ID to be a constant 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeSuccess) {
+  const std::string body = R"(
+%var_u32 = OpVariable %ptr_function_u32 Function
+%var_f32 = OpVariable %ptr_function_f32 Function
+%var_f32vec3 = OpVariable %ptr_function_f32vec3 Function
+
+%as = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %as %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+
+%_u32 = OpLoad %u32 %var_u32
+%_f32 = OpLoad %f32 %var_f32
+%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
+OpRayQueryInitializeKHR %ray_query %as %_u32 %_u32 %_f32vec3 %_f32 %_f32vec3 %_f32
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, InitializeFunctionSuccess) {
+  const std::string declaration = R"(
+%rq_ptr = OpTypePointer Private %type_rq
+%rq_func_type = OpTypeFunction %void %rq_ptr
+%rq_var_1 = OpVariable %rq_ptr Private
+%rq_var_2 = OpVariable %rq_ptr Private
+)";
+
+  const std::string body = R"(
+%fcall_1 = OpFunctionCall %void %rq_func %rq_var_1
+%as_1 = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %rq_var_1 %as_1 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+%fcall_2 = OpFunctionCall %void %rq_func %rq_var_2
+OpReturn
+OpFunctionEnd
+%rq_func = OpFunction %void None %rq_func_type
+%rq_param = OpFunctionParameter %rq_ptr
+%label = OpLabel
+%as_2 = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %rq_param %as_2 %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body, "", declaration).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayQuery) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %top_level_as %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadAS) {
+  const std::string body = R"(
+OpRayQueryInitializeKHR %ray_query %ray_query %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected Acceleration Structure to be of type "
+                        "OpTypeAccelerationStructureKHR"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayFlags64) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u64_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray Flags must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayFlagsVector) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32vec2 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Operand 15[%v2uint] cannot be a type"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadCullMask) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %f32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Cull Mask must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginVec4) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec4_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginFloat) {
+  const std::string body = R"(
+%var_f32 = OpVariable %ptr_function_f32 Function
+%_f32 = OpLoad %f32 %var_f32
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %_f32 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayOriginInt) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %u32vec3_0 %f32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayTMin) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %u32_0 %f32vec3_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray TMin must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayDirection) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec4_0 %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayQuery, InitializeBadRayTMax) {
+  const std::string body = R"(
+%load = OpLoad %type_as %top_level_as
+OpRayQueryInitializeKHR %ray_query %load %u32_0 %u32_0 %f32vec3_0 %f32_0 %f32vec3_0 %f64_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray TMax must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionSuccess) {
+  const std::string body = R"(
+%var = OpVariable %ptr_function_f32 Function
+%load = OpLoad %f32 %var
+OpRayQueryGenerateIntersectionKHR %ray_query %f32_0
+OpRayQueryGenerateIntersectionKHR %ray_query %load
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionBadRayQuery) {
+  const std::string body = R"(
+OpRayQueryGenerateIntersectionKHR %top_level_as %f32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray Query must be a pointer to OpTypeRayQueryKHR"));
+}
+
+TEST_F(ValidateRayQuery, GenerateIntersectionBadHitT) {
+  const std::string body = R"(
+OpRayQueryGenerateIntersectionKHR %ray_query %u32_0
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Hit T must be a 32-bit float scalar"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools