Add validation support for SPV_EXT_shader_invocation_reorder. (#6401)
Co-authored-by: Steven Perron <stevenperron@google.com>
diff --git a/source/opcode.cpp b/source/opcode.cpp
index a8b4102..08e5b10 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -267,6 +267,7 @@
// spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR:
case spv::Op::OpTypeHitObjectNV:
+ case spv::Op::OpTypeHitObjectEXT:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeNodePayloadArrayAMDX:
case spv::Op::OpTypeTensorLayoutNV:
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index c30545f..2ce9e85 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -557,6 +557,7 @@
(uint32_t)spv::Op::OpTypeAccelerationStructureKHR,
(uint32_t)spv::Op::OpTypeRayQueryKHR,
(uint32_t)spv::Op::OpTypeHitObjectNV,
+ (uint32_t)spv::Op::OpTypeHitObjectEXT,
(uint32_t)spv::Op::OpTypeArray,
(uint32_t)spv::Op::OpTypeRuntimeArray,
(uint32_t)spv::Op::OpTypeNodePayloadArrayAMDX,
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 43e82cf..21c5df6 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -237,6 +237,7 @@
DefineParameterlessCase(AccelerationStructureNV);
DefineParameterlessCase(RayQueryKHR);
DefineParameterlessCase(HitObjectNV);
+ DefineParameterlessCase(HitObjectEXT);
#undef DefineParameterlessCase
case Type::kInteger:
typeInst = MakeUnique<Instruction>(
@@ -654,6 +655,7 @@
DefineNoSubtypeCase(AccelerationStructureNV);
DefineNoSubtypeCase(RayQueryKHR);
DefineNoSubtypeCase(HitObjectNV);
+ DefineNoSubtypeCase(HitObjectEXT);
#undef DefineNoSubtypeCase
case Type::kVector: {
const Vector* vec_ty = type.AsVector();
@@ -1082,6 +1084,9 @@
case spv::Op::OpTypeHitObjectNV:
type = new HitObjectNV();
break;
+ case spv::Op::OpTypeHitObjectEXT:
+ type = new HitObjectEXT();
+ break;
case spv::Op::OpTypeTensorLayoutNV:
type = new TensorLayoutNV(inst.GetSingleWordInOperand(0),
inst.GetSingleWordInOperand(1));
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index f0dc1c6..b6fa015 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -135,6 +135,7 @@
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
+ DeclareKindCase(HitObjectEXT);
DeclareKindCase(TensorARM);
DeclareKindCase(GraphARM);
#undef DeclareKindCase
@@ -187,6 +188,7 @@
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
+ DeclareKindCase(HitObjectEXT);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
DeclareKindCase(TensorARM);
@@ -249,6 +251,7 @@
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
+ DeclareKindCase(HitObjectEXT);
DeclareKindCase(TensorLayoutNV);
DeclareKindCase(TensorViewNV);
DeclareKindCase(TensorARM);
diff --git a/source/opt/types.h b/source/opt/types.h
index 2dd6c75..4fa3e66 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -67,6 +67,7 @@
class CooperativeVectorNV;
class RayQueryKHR;
class HitObjectNV;
+class HitObjectEXT;
class TensorLayoutNV;
class TensorViewNV;
class TensorARM;
@@ -114,6 +115,7 @@
kCooperativeVectorNV,
kRayQueryKHR,
kHitObjectNV,
+ kHitObjectEXT,
kTensorLayoutNV,
kTensorViewNV,
kTensorARM,
@@ -222,6 +224,7 @@
DeclareCastMethod(CooperativeVectorNV)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
+ DeclareCastMethod(HitObjectEXT)
DeclareCastMethod(TensorLayoutNV)
DeclareCastMethod(TensorViewNV)
DeclareCastMethod(TensorARM)
@@ -862,6 +865,7 @@
DefineParameterlessType(AccelerationStructureNV, accelerationStructureNV);
DefineParameterlessType(RayQueryKHR, rayQueryKHR);
DefineParameterlessType(HitObjectNV, hitObjectNV);
+DefineParameterlessType(HitObjectEXT, hitObjectEXT);
#undef DefineParameterlessType
} // namespace analysis
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index dd1cb40..27bb5cb 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -399,6 +399,7 @@
if (auto error = RayQueryPass(*vstate, &instruction)) return error;
if (auto error = RayTracingPass(*vstate, &instruction)) return error;
if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
+ if (auto error = RayReorderEXTPass(*vstate, &instruction)) return error;
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
if (auto error = TensorPass(*vstate, &instruction)) return error;
diff --git a/source/val/validate.h b/source/val/validate.h
index eb54d28..b21972e 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -220,6 +220,9 @@
/// Validates correctness of shader execution reorder instructions.
spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst);
+/// Validates correctness of shader execution reorder EXT instructions.
+spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst);
+
/// Validates correctness of mesh shading instructions.
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
diff --git a/source/val/validate_annotation.cpp b/source/val/validate_annotation.cpp
index 2545f2f..8ef9eef 100644
--- a/source/val/validate_annotation.cpp
+++ b/source/val/validate_annotation.cpp
@@ -217,6 +217,7 @@
sc != spv::StorageClass::IncomingCallableDataKHR &&
sc != spv::StorageClass::ShaderRecordBufferKHR &&
sc != spv::StorageClass::HitObjectAttributeNV &&
+ sc != spv::StorageClass::HitObjectAttributeEXT &&
sc != spv::StorageClass::TileImageEXT) {
return _.diag(SPV_ERROR_INVALID_ID, target)
<< _.VkErrorID(6672) << _.SpvDecorationString(dec)
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index 45a721c..4cd8f82 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -1088,6 +1088,7 @@
ExtensionToString(kSPV_KHR_workgroup_memory_explicit_layout) ||
extension == ExtensionToString(kSPV_EXT_mesh_shader) ||
extension == ExtensionToString(kSPV_NV_shader_invocation_reorder) ||
+ extension == ExtensionToString(kSPV_EXT_shader_invocation_reorder) ||
extension ==
ExtensionToString(kSPV_NV_cluster_acceleration_structure) ||
extension == ExtensionToString(kSPV_NV_linear_swept_spheres) ||
diff --git a/source/val/validate_logical_pointers.cpp b/source/val/validate_logical_pointers.cpp
index fcc9db3..1528701 100644
--- a/source/val/validate_logical_pointers.cpp
+++ b/source/val/validate_logical_pointers.cpp
@@ -247,6 +247,42 @@
case spv::Op::OpHitObjectIsEmptyNV:
case spv::Op::OpHitObjectIsHitNV:
case spv::Op::OpHitObjectIsMissNV:
+ // SPV_EXT_shader_invocation_reorder
+ case spv::Op::OpHitObjectRecordFromQueryEXT:
+ case spv::Op::OpHitObjectRecordMissEXT:
+ case spv::Op::OpHitObjectRecordMissMotionEXT:
+ case spv::Op::OpHitObjectGetIntersectionTriangleVertexPositionsEXT:
+ case spv::Op::OpHitObjectGetRayFlagsEXT:
+ case spv::Op::OpHitObjectSetShaderBindingTableRecordIndexEXT:
+ case spv::Op::OpHitObjectReorderExecuteShaderEXT:
+ case spv::Op::OpHitObjectTraceReorderExecuteEXT:
+ case spv::Op::OpHitObjectTraceMotionReorderExecuteEXT:
+ case spv::Op::OpReorderThreadWithHintEXT:
+ case spv::Op::OpReorderThreadWithHitObjectEXT:
+ case spv::Op::OpHitObjectTraceRayEXT:
+ case spv::Op::OpHitObjectTraceRayMotionEXT:
+ case spv::Op::OpHitObjectRecordEmptyEXT:
+ case spv::Op::OpHitObjectExecuteShaderEXT:
+ case spv::Op::OpHitObjectGetCurrentTimeEXT:
+ case spv::Op::OpHitObjectGetAttributesEXT:
+ case spv::Op::OpHitObjectGetHitKindEXT:
+ case spv::Op::OpHitObjectGetPrimitiveIndexEXT:
+ case spv::Op::OpHitObjectGetGeometryIndexEXT:
+ case spv::Op::OpHitObjectGetInstanceIdEXT:
+ case spv::Op::OpHitObjectGetInstanceCustomIndexEXT:
+ case spv::Op::OpHitObjectGetObjectRayOriginEXT:
+ case spv::Op::OpHitObjectGetObjectRayDirectionEXT:
+ case spv::Op::OpHitObjectGetWorldRayDirectionEXT:
+ case spv::Op::OpHitObjectGetWorldRayOriginEXT:
+ case spv::Op::OpHitObjectGetObjectToWorldEXT:
+ case spv::Op::OpHitObjectGetWorldToObjectEXT:
+ case spv::Op::OpHitObjectGetRayTMaxEXT:
+ case spv::Op::OpHitObjectGetRayTMinEXT:
+ case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexEXT:
+ case spv::Op::OpHitObjectGetShaderRecordBufferHandleEXT:
+ case spv::Op::OpHitObjectIsEmptyEXT:
+ case spv::Op::OpHitObjectIsHitEXT:
+ case spv::Op::OpHitObjectIsMissEXT:
// SPV_NV_raw_access_chains
case spv::Op::OpRawAccessChainNV:
// SPV_NV_cooperative_matrix2
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 67f85e1..9372f5c 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -515,6 +515,7 @@
storage_class != spv::StorageClass::IncomingCallableDataKHR &&
storage_class != spv::StorageClass::TaskPayloadWorkgroupEXT &&
storage_class != spv::StorageClass::HitObjectAttributeNV &&
+ storage_class != spv::StorageClass::HitObjectAttributeEXT &&
storage_class != spv::StorageClass::NodePayloadAMDX) {
bool storage_input_or_output = storage_class == spv::StorageClass::Input ||
storage_class == spv::StorageClass::Output;
@@ -756,6 +757,11 @@
<< "OpVariable, <id> " << _.getIdName(inst->id())
<< ", initializer are not allowed for HitObjectAttributeNV";
}
+ if (storage_class == spv::StorageClass::HitObjectAttributeEXT) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpVariable, <id> " << _.getIdName(inst->id())
+ << ", initializer are not allowed for HitObjectAttributeEXT";
+ }
}
if (storage_class == spv::StorageClass::PhysicalStorageBuffer) {
diff --git a/source/val/validate_ray_tracing_reorder.cpp b/source/val/validate_ray_tracing_reorder.cpp
index 3685a76..f189b44 100644
--- a/source/val/validate_ray_tracing_reorder.cpp
+++ b/source/val/validate_ray_tracing_reorder.cpp
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Validates ray tracing instructions from SPV_NV_shader_execution_reorder
+// Validates ray tracing instructions from SPV_NV_shader_invocation_reorder and
+// SPV_EXT_shader_invocation_reorder
#include "source/opcode.h"
#include "source/val/instruction.h"
@@ -37,18 +38,29 @@
return array_length;
}
+spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
+ const Instruction* inst,
+ uint32_t ray_query_index) {
+ const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
+ auto variable = _.FindDef(ray_query_id);
+ auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
+ if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Query must be a pointer";
+ }
+ auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
+ if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Query must be a pointer to OpTypeRayQueryKHR";
+ }
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateHitObjectPointer(ValidationState_t& _,
const Instruction* inst,
uint32_t hit_object_index) {
const uint32_t hit_object_id = inst->GetOperandAs<uint32_t>(hit_object_index);
auto variable = _.FindDef(hit_object_id);
- const auto var_opcode = variable->opcode();
- if (!variable || (var_opcode != spv::Op::OpVariable &&
- var_opcode != spv::Op::OpFunctionParameter &&
- var_opcode != spv::Op::OpAccessChain)) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Hit Object must be a memory object declaration";
- }
auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -62,6 +74,24 @@
return SPV_SUCCESS;
}
+spv_result_t ValidateHitObjectPointerEXT(ValidationState_t& _,
+ const Instruction* inst,
+ uint32_t hit_object_index) {
+ const uint32_t hit_object_id = inst->GetOperandAs<uint32_t>(hit_object_index);
+ auto variable = _.FindDef(hit_object_id);
+ auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
+ if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hit Object must be a pointer";
+ }
+ auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
+ if (!type || type->opcode() != spv::Op::OpTypeHitObjectEXT) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Type must be OpTypeHitObjectEXT";
+ }
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateHitObjectInstructionCommonParameters(
ValidationState_t& _, const Instruction* inst,
uint32_t acceleration_struct_index, uint32_t instance_id_index,
@@ -247,8 +277,10 @@
auto variable = _.FindDef(hit_object_attr_id);
const auto var_opcode = variable->opcode();
if (!variable || var_opcode != spv::Op::OpVariable ||
- (variable->GetOperandAs<spv::StorageClass>(2)) !=
- spv::StorageClass::HitObjectAttributeNV) {
+ !((variable->GetOperandAs<spv::StorageClass>(2) ==
+ spv::StorageClass::HitObjectAttributeNV) ||
+ (variable->GetOperandAs<spv::StorageClass>(2) ==
+ spv::StorageClass::HitObjectAttributeEXT))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Hit Object Attributes id must be a OpVariable of storage "
"class HitObjectAttributeNV";
@@ -728,5 +760,651 @@
}
return SPV_SUCCESS;
}
+
+spv_result_t RayReorderEXTPass(ValidationState_t& _, const Instruction* inst) {
+ const spv::Op opcode = inst->opcode();
+ const uint32_t result_type = inst->type_id();
+
+ auto RegisterOpcodeForValidModel = [](ValidationState_t& vs,
+ const Instruction* rtinst) {
+ std::string opcode_name = spvOpcodeString(rtinst->opcode());
+ vs.function(rtinst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR &&
+ model != spv::ExecutionModel::ClosestHitKHR &&
+ model != spv::ExecutionModel::MissKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR, ClosestHitKHR and "
+ "MissKHR execution models";
+ }
+ return false;
+ }
+ return true;
+ });
+ return;
+ };
+
+ switch (opcode) {
+ case spv::Op::OpHitObjectIsMissEXT:
+ case spv::Op::OpHitObjectIsHitEXT:
+ case spv::Op::OpHitObjectIsEmptyEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (!_.IsBoolScalarType(result_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type to be bool scalar type";
+ }
+
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetShaderRecordBufferHandleEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ if (!_.IsIntVectorType(result_type) ||
+ (_.GetDimension(result_type) != 2) ||
+ (_.GetBitWidth(result_type) != 32))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected 32-bit integer type 2-component vector as Result "
+ "Type: "
+ << spvOpcodeString(opcode);
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetHitKindEXT:
+ case spv::Op::OpHitObjectGetPrimitiveIndexEXT:
+ case spv::Op::OpHitObjectGetGeometryIndexEXT:
+ case spv::Op::OpHitObjectGetInstanceIdEXT:
+ case spv::Op::OpHitObjectGetInstanceCustomIndexEXT:
+ case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexEXT:
+ case spv::Op::OpHitObjectGetRayFlagsEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected 32-bit integer type scalar as Result Type: "
+ << spvOpcodeString(opcode);
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetCurrentTimeEXT:
+ case spv::Op::OpHitObjectGetRayTMaxEXT:
+ case spv::Op::OpHitObjectGetRayTMinEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected 32-bit floating-point type scalar as Result Type: "
+ << spvOpcodeString(opcode);
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetObjectToWorldEXT:
+ case spv::Op::OpHitObjectGetWorldToObjectEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ uint32_t num_rows = 0;
+ uint32_t num_cols = 0;
+ uint32_t col_type = 0;
+ uint32_t component_type = 0;
+
+ if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
+ &component_type)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected matrix type as Result Type: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (num_cols != 4) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type matrix to have a Column Count of 4"
+ << spvOpcodeString(opcode);
+ }
+
+ if (!_.IsFloatScalarType(component_type) ||
+ _.GetBitWidth(result_type) != 32 || num_rows != 3) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "expected Result Type matrix to have a Column Type of "
+ "3-component 32-bit float vectors: "
+ << spvOpcodeString(opcode);
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetObjectRayOriginEXT:
+ case spv::Op::OpHitObjectGetObjectRayDirectionEXT:
+ case spv::Op::OpHitObjectGetWorldRayDirectionEXT:
+ case spv::Op::OpHitObjectGetWorldRayOriginEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ if (!_.IsFloatVectorType(result_type) ||
+ (_.GetDimension(result_type) != 3) ||
+ (_.GetBitWidth(result_type) != 32))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected 32-bit floating-point type 3-component vector as "
+ "Result Type: "
+ << spvOpcodeString(opcode);
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetIntersectionTriangleVertexPositionsEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 2)) return error;
+
+ auto result_id = _.FindDef(result_type);
+ if ((result_id->opcode() != spv::Op::OpTypeArray) ||
+ (GetArrayLength(_, result_id) != 3) ||
+ !_.IsFloatVectorType(_.GetComponentType(result_type)) ||
+ _.GetDimension(_.GetComponentType(result_type)) != 3 ||
+ _.GetBitWidth(_.GetComponentType(result_type)) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected 3 element array of 32-bit 3 component float "
+ "vectors as Result Type: "
+ << spvOpcodeString(opcode);
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectGetAttributesEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
+ auto variable = _.FindDef(hit_object_attr_id);
+ const auto var_opcode = variable->opcode();
+ if (!variable || var_opcode != spv::Op::OpVariable ||
+ variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::HitObjectAttributeEXT) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hit Object Attributes id must be a OpVariable of storage "
+ "class HitObjectAttributeEXT";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectSetShaderBindingTableRecordIndexEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ const uint32_t sbt_index_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsIntScalarType(sbt_index_id) ||
+ _.GetBitWidth(sbt_index_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "SBT Index must be a 32-bit integer scalar";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectExecuteShaderEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ const uint32_t payload_id = inst->GetOperandAs<uint32_t>(1);
+ auto variable = _.FindDef(payload_id);
+ const auto var_opcode = variable->opcode();
+ if (!variable || var_opcode != spv::Op::OpVariable ||
+ (variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::RayPayloadKHR &&
+ variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::IncomingRayPayloadKHR)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Payload must be a OpVariable of storage "
+ "class RayPayloadKHR or IncomingRayPayloadKHR";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectRecordEmptyEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+ break;
+ }
+
+ case spv::Op::OpHitObjectRecordFromQueryEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+ if (auto error = ValidateRayQueryPointer(_, inst, 1)) return error;
+
+ if (!_.HasCapability(spv::Capability::RayQueryKHR))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": requires RayQueryKHR capability";
+
+ // Validate SBT Record Index (operand 2)
+ const uint32_t sbt_record_index_id = _.GetOperandTypeId(inst, 2);
+ if (!_.IsIntScalarType(sbt_record_index_id) ||
+ _.GetBitWidth(sbt_record_index_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "SBT Record Index must be a 32-bit integer scalar";
+ }
+
+ // Validate Hit Object Attributes (operand 3)
+ const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(3);
+ auto attr_variable = _.FindDef(hit_object_attr_id);
+ const auto attr_var_opcode = attr_variable->opcode();
+ if (!attr_variable || attr_var_opcode != spv::Op::OpVariable ||
+ attr_variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::HitObjectAttributeEXT) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hit Object Attributes id must be a OpVariable of storage "
+ "class HitObjectAttributeEXT";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectRecordMissEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ // Ray Flags (operand 1)
+ const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsIntScalarType(ray_flags_id) ||
+ _.GetBitWidth(ray_flags_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Flags must be a 32-bit int scalar";
+ }
+
+ // Miss Index (operand 2)
+ const uint32_t miss_index = _.GetOperandTypeId(inst, 2);
+ if (!_.IsUnsignedIntScalarType(miss_index) ||
+ _.GetBitWidth(miss_index) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Miss Index must be a 32-bit unsigned int scalar";
+ }
+
+ // Ray Origin (operand 3)
+ const uint32_t ray_origin = _.GetOperandTypeId(inst, 3);
+ if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
+ _.GetBitWidth(ray_origin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Origin must be a 32-bit float 3-component vector";
+ }
+
+ // Ray TMin (operand 4)
+ const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
+ if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMin must be a 32-bit float scalar";
+ }
+
+ // Ray Direction (operand 5)
+ const uint32_t ray_direction = _.GetOperandTypeId(inst, 5);
+ if (!_.IsFloatVectorType(ray_direction) ||
+ _.GetDimension(ray_direction) != 3 ||
+ _.GetBitWidth(ray_direction) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Direction must be a 32-bit float 3-component vector";
+ }
+
+ // Ray TMax (operand 6)
+ const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
+ if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMax must be a 32-bit float scalar";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectRecordMissMotionEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ // Ray Flags (operand 1)
+ const uint32_t ray_flags_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsIntScalarType(ray_flags_id) ||
+ _.GetBitWidth(ray_flags_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Flags must be a 32-bit int scalar";
+ }
+
+ // Miss Index (operand 2)
+ const uint32_t miss_index = _.GetOperandTypeId(inst, 2);
+ if (!_.IsUnsignedIntScalarType(miss_index) ||
+ _.GetBitWidth(miss_index) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Miss Index must be a 32-bit unsigned int scalar";
+ }
+
+ // Ray Origin (operand 3)
+ const uint32_t ray_origin = _.GetOperandTypeId(inst, 3);
+ if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
+ _.GetBitWidth(ray_origin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Origin must be a 32-bit float 3-component vector";
+ }
+
+ // Ray TMin (operand 4)
+ const uint32_t ray_tmin = _.GetOperandTypeId(inst, 4);
+ if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMin must be a 32-bit float scalar";
+ }
+
+ // Ray Direction (operand 5)
+ const uint32_t ray_direction = _.GetOperandTypeId(inst, 5);
+ if (!_.IsFloatVectorType(ray_direction) ||
+ _.GetDimension(ray_direction) != 3 ||
+ _.GetBitWidth(ray_direction) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray Direction must be a 32-bit float 3-component vector";
+ }
+
+ // Ray TMax (operand 6)
+ const uint32_t ray_tmax = _.GetOperandTypeId(inst, 6);
+ if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Ray TMax must be a 32-bit float scalar";
+ }
+
+ // Current Time (operand 7)
+ const uint32_t current_time_id = _.GetOperandTypeId(inst, 7);
+ if (!_.IsFloatScalarType(current_time_id) ||
+ _.GetBitWidth(current_time_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Current Time must be a 32-bit float scalar";
+ }
+ break;
+ }
+
+ case spv::Op::OpReorderThreadWithHintEXT: {
+ std::string opcode_name = spvOpcodeString(inst->opcode());
+ _.function(inst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
+
+ const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
+ if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint must be a 32-bit int scalar";
+ }
+
+ const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Bits must be a 32-bit int scalar";
+ }
+ break;
+ }
+
+ case spv::Op::OpReorderThreadWithHitObjectEXT: {
+ std::string opcode_name = spvOpcodeString(inst->opcode());
+ _.function(inst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
+
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ if (inst->operands().size() > 1) {
+ if (inst->operands().size() != 3) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint and Bits are optional together i.e "
+ << " Either both Hint and Bits should be provided or neither.";
+ }
+
+ // Validate the optional operands Hint and Bits
+ const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
+ if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint must be a 32-bit int scalar";
+ }
+ const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
+ if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Bits must be a 32-bit int scalar";
+ }
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectTraceRayEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ if (auto error = ValidateHitObjectInstructionCommonParameters(
+ _, inst, 1 /* Acceleration Struct */,
+ KRayParamInvalidId /* Instance Id */,
+ KRayParamInvalidId /* Primitive Id */,
+ KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
+ 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
+ KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
+ 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
+ KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
+ 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
+ 10 /* Ray TMax */, 11 /* Payload */,
+ KRayParamInvalidId /* Hit Object Attribute */))
+ return error;
+ break;
+ }
+
+ case spv::Op::OpHitObjectTraceRayMotionEXT: {
+ RegisterOpcodeForValidModel(_, inst);
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ if (auto error = ValidateHitObjectInstructionCommonParameters(
+ _, inst, 1 /* Acceleration Struct */,
+ KRayParamInvalidId /* Instance Id */,
+ KRayParamInvalidId /* Primitive Id */,
+ KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
+ 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
+ KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
+ 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
+ KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
+ 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
+ 10 /* Ray TMax */, 12 /* Payload */,
+ KRayParamInvalidId /* Hit Object Attribute */))
+ return error;
+
+ // Current Time (operand 11)
+ const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
+ if (!_.IsFloatScalarType(current_time_id) ||
+ _.GetBitWidth(current_time_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Current Time must be a 32-bit float scalar";
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectReorderExecuteShaderEXT: {
+ std::string opcode_name = spvOpcodeString(inst->opcode());
+ _.function(inst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
+
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ // Validate Payload (operand 1)
+ const uint32_t payload_id = inst->GetOperandAs<uint32_t>(1);
+ auto variable = _.FindDef(payload_id);
+ const auto var_opcode = variable->opcode();
+ if (!variable || var_opcode != spv::Op::OpVariable ||
+ (variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::RayPayloadKHR &&
+ variable->GetOperandAs<spv::StorageClass>(2) !=
+ spv::StorageClass::IncomingRayPayloadKHR)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Payload must be a OpVariable of storage "
+ "class RayPayloadKHR or IncomingRayPayloadKHR";
+ }
+
+ // Check for optional Hint and Bits (operands 2 and 3)
+ if (inst->operands().size() > 2) {
+ if (inst->operands().size() != 4) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint and Bits are optional together i.e "
+ << " Either both Hint and Bits should be provided or neither.";
+ }
+
+ // Validate optional Hint and Bits
+ const uint32_t hint_id = _.GetOperandTypeId(inst, 2);
+ if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint must be a 32-bit int scalar";
+ }
+ const uint32_t bits_id = _.GetOperandTypeId(inst, 3);
+ if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Bits must be a 32-bit int scalar";
+ }
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectTraceReorderExecuteEXT: {
+ std::string opcode_name = spvOpcodeString(inst->opcode());
+ _.function(inst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
+
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ // Validate base trace ray parameters (operands 1-11)
+ if (auto error = ValidateHitObjectInstructionCommonParameters(
+ _, inst, 1 /* Acceleration Struct */,
+ KRayParamInvalidId /* Instance Id */,
+ KRayParamInvalidId /* Primitive Id */,
+ KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
+ 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
+ KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
+ 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
+ KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
+ 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
+ 10 /* Ray TMax */, 11 /* Payload */,
+ KRayParamInvalidId /* Hit Object Attribute */))
+ return error;
+
+ // Check for optional Hint and Bits (operands 12 and 13)
+ if (inst->operands().size() > 12) {
+ if (inst->operands().size() != 14) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint and Bits are optional together i.e "
+ << " Either both Hint and Bits should be provided or neither.";
+ }
+
+ // Validate optional Hint and Bits
+ const uint32_t hint_id = _.GetOperandTypeId(inst, 12);
+ if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint must be a 32-bit int scalar";
+ }
+ const uint32_t bits_id = _.GetOperandTypeId(inst, 13);
+ if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Bits must be a 32-bit int scalar";
+ }
+ }
+ break;
+ }
+
+ case spv::Op::OpHitObjectTraceMotionReorderExecuteEXT: {
+ std::string opcode_name = spvOpcodeString(inst->opcode());
+ _.function(inst->function()->id())
+ ->RegisterExecutionModelLimitation(
+ [opcode_name](spv::ExecutionModel model, std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR) {
+ if (message) {
+ *message = opcode_name +
+ " requires RayGenerationKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
+
+ if (auto error = ValidateHitObjectPointerEXT(_, inst, 0)) return error;
+
+ // Validate base trace ray parameters (operands 1-12)
+ if (auto error = ValidateHitObjectInstructionCommonParameters(
+ _, inst, 1 /* Acceleration Struct */,
+ KRayParamInvalidId /* Instance Id */,
+ KRayParamInvalidId /* Primitive Id */,
+ KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
+ 3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
+ KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
+ 5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
+ KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
+ 7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
+ 10 /* Ray TMax */, 12 /* Payload */,
+ KRayParamInvalidId /* Hit Object Attribute */))
+ return error;
+
+ // Current Time (operand 11)
+ const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
+ if (!_.IsFloatScalarType(current_time_id) ||
+ _.GetBitWidth(current_time_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Current Time must be a 32-bit float scalar";
+ }
+
+ // Check for optional Hint and Bits (operands 13 and 14)
+ if (inst->operands().size() > 13) {
+ if (inst->operands().size() != 15) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint and Bits are optional together i.e "
+ << " Either both Hint and Bits should be provided or neither.";
+ }
+
+ // Validate optional Hint and Bits
+ const uint32_t hint_id = _.GetOperandTypeId(inst, 13);
+ if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Hint must be a 32-bit int scalar";
+ }
+ const uint32_t bits_id = _.GetOperandTypeId(inst, 14);
+ if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Bits must be a 32-bit int scalar";
+ }
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ return SPV_SUCCESS;
+}
} // namespace val
} // namespace spvtools
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 57d544b..bc9d835 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -859,6 +859,22 @@
}
return true;
});
+ } else if (storage_class == spv::StorageClass::HitObjectAttributeEXT) {
+ function(consumer->function()->id())
+ ->RegisterExecutionModelLimitation([](spv::ExecutionModel model,
+ std::string* message) {
+ if (model != spv::ExecutionModel::RayGenerationKHR &&
+ model != spv::ExecutionModel::ClosestHitKHR &&
+ model != spv::ExecutionModel::MissKHR) {
+ if (message) {
+ *message =
+ "HitObjectAttributeEXT Storage Class is limited to "
+ "RayGenerationKHR, ClosestHitKHR or MissKHR execution model";
+ }
+ return false;
+ }
+ return true;
+ });
}
}
@@ -2032,6 +2048,7 @@
case spv::StorageClass::ShaderRecordBufferKHR:
case spv::StorageClass::TaskPayloadWorkgroupEXT:
case spv::StorageClass::HitObjectAttributeNV:
+ case spv::StorageClass::HitObjectAttributeEXT:
case spv::StorageClass::TileImageEXT:
case spv::StorageClass::NodePayloadAMDX:
case spv::StorageClass::TileAttachmentQCOM:
diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp
index de97bfa..080a329 100644
--- a/test/opt/type_manager_test.cpp
+++ b/test/opt/type_manager_test.cpp
@@ -177,6 +177,7 @@
types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002));
types.emplace_back(new RayQueryKHR());
types.emplace_back(new HitObjectNV());
+ types.emplace_back(new HitObjectEXT());
types.emplace_back(new CooperativeVectorNV(f32, 16));
// SPV_AMDX_shader_enqueue
diff --git a/test/val/val_ray_tracing_reorder_test.cpp b/test/val/val_ray_tracing_reorder_test.cpp
index a41af80..4c6f4f0 100644
--- a/test/val/val_ray_tracing_reorder_test.cpp
+++ b/test/val/val_ray_tracing_reorder_test.cpp
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Tests instructions from SPV_NV_shader_invocation_reorder.
+// Tests instructions from SPV_NV_shader_invocation_reorder and
+// SPV_EXT_shader_invocation_reorder.
#include <sstream>
#include <string>
@@ -791,6 +792,778 @@
SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}
+
+// EXT Extension Tests
+using ValidateRayTracingReorderEXT = spvtest::ValidateBase<bool>;
+
+std::string GenerateReorderThreadCodeEXT(const std::string& body = "",
+ const std::string& declarations = "",
+ const std::string& extensions = "",
+ const std::string& capabilities = "") {
+ std::ostringstream ss;
+ ss << R"(
+ OpCapability RayTracingKHR
+ OpCapability ShaderInvocationReorderEXT
+ )";
+ ss << capabilities;
+ ss << R"(
+ OpExtension "SPV_KHR_ray_tracing"
+ OpExtension "SPV_EXT_shader_invocation_reorder"
+ )";
+ ss << extensions;
+ ss << R"(
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint RayGenerationNV %main "main" %hObj
+ OpSourceExtension "GL_EXT_ray_tracing"
+ OpSourceExtension "GL_EXT_shader_invocation_reorder"
+ OpName %main "main"
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %6 = OpTypeHitObjectEXT
+%_ptr_Private_6 = OpTypePointer Private %6
+ %hObj = OpVariable %_ptr_Private_6 Private
+ )";
+ ss << declarations;
+
+ ss << R"(
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ )";
+
+ ss << body;
+
+ ss << R"(
+ OpReturn
+ OpFunctionEnd
+ )";
+ return ss.str();
+}
+
+std::string GenerateReorderShaderCodeEXT(const std::string& body = "",
+ const std::string& declarations = "",
+ const std::string& extensions = "",
+ const std::string& capabilties = "") {
+ std::ostringstream ss;
+ ss << R"(
+ OpCapability RayTracingKHR
+ OpCapability ShaderInvocationReorderEXT
+ )";
+ ss << capabilties;
+ ss << R"( OpExtension "SPV_KHR_ray_tracing"
+ OpExtension "SPV_EXT_shader_invocation_reorder"
+ )";
+ ss << extensions;
+ ss << R"(
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint RayGenerationKHR %main "main" %attr %_ %hObj %payload %__0 %as %__1
+ OpSource GLSL 460
+ OpSourceExtension "GL_EXT_ray_tracing"
+ OpSourceExtension "GL_EXT_shader_invocation_reorder"
+ OpName %main "main"
+ OpName %attr "attr"
+ OpName %hBlock "hBlock"
+ OpMemberName %hBlock 0 "attrval"
+ OpName %_ ""
+ OpName %hObj "hObj"
+ OpName %payload "payload"
+ OpName %pBlock "pBlock"
+ OpMemberName %pBlock 0 "val1"
+ OpMemberName %pBlock 1 "val2"
+ OpName %__0 ""
+ OpName %as "as"
+ OpName %block "block"
+ OpMemberName %block 0 "op"
+ OpName %__1 ""
+ OpDecorate %hBlock Block
+ OpDecorate %pBlock Block
+ OpDecorate %as DescriptorSet 0
+ OpDecorate %as Binding 0
+ OpMemberDecorate %block 0 Offset 0
+ OpDecorate %block Block
+ OpDecorate %__1 DescriptorSet 0
+ OpDecorate %__1 Binding 1
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v2float = OpTypeVector %float 2
+%_ptr_HitObjectAttributeEXT_v2float = OpTypePointer HitObjectAttributeEXT %v2float
+ %attr = OpVariable %_ptr_HitObjectAttributeEXT_v2float HitObjectAttributeEXT
+ %float_1 = OpConstant %float 1
+ %11 = OpConstantComposite %v2float %float_1 %float_1
+ %hBlock = OpTypeStruct %float
+%_ptr_HitObjectAttributeEXT_hBlock = OpTypePointer HitObjectAttributeEXT %hBlock
+ %_ = OpVariable %_ptr_HitObjectAttributeEXT_hBlock HitObjectAttributeEXT
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+ %float_2 = OpConstant %float 2
+%_ptr_HitObjectAttributeEXT_float = OpTypePointer HitObjectAttributeEXT %float
+ %20 = OpTypeHitObjectEXT
+ %_ptr_Private_20 = OpTypePointer Private %20
+ %hObj = OpVariable %_ptr_Private_20 Private
+ %23 = OpTypeAccelerationStructureKHR
+ %_ptr_UniformConstant_23 = OpTypePointer UniformConstant %23
+ %as = OpVariable %_ptr_UniformConstant_23 UniformConstant
+ %v4float = OpTypeVector %float 4
+%_ptr_RayPayloadKHR_v4float = OpTypePointer RayPayloadKHR %v4float
+ %payload = OpVariable %_ptr_RayPayloadKHR_v4float RayPayloadKHR
+ %pBlock = OpTypeStruct %v2float %v2float
+%_ptr_RayPayloadKHR_pBlock = OpTypePointer RayPayloadKHR %pBlock
+ %__0 = OpVariable %_ptr_RayPayloadKHR_pBlock RayPayloadKHR
+ %block = OpTypeStruct %float
+%_ptr_StorageBuffer_block = OpTypePointer StorageBuffer %block
+ %__1 = OpVariable %_ptr_StorageBuffer_block StorageBuffer
+ )";
+
+ ss << declarations;
+
+ ss << R"(
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ )";
+
+ ss << body;
+
+ ss << R"(
+ OpReturn
+ OpFunctionEnd)";
+ return ss.str();
+}
+
+TEST_F(ValidateRayTracingReorderEXT, ReorderThreadWithHintEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+ )";
+
+ const std::string body = R"(
+ OpReorderThreadWithHintEXT %uint_4 %uint_4
+ )";
+
+ CompileSuccessfully(GenerateReorderThreadCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, ReorderThreadWithHitObjectEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+ %uint_2 = OpConstant %uint 2
+ )";
+
+ const std::string body = R"(
+ OpReorderThreadWithHitObjectEXT %hObj
+ OpReorderThreadWithHitObjectEXT %hObj %uint_4 %uint_2
+ )";
+
+ CompileSuccessfully(GenerateReorderThreadCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectTraceRayEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %31 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %32 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+ %int_1 = OpConstant %int 1
+ )";
+
+ const std::string body = R"(
+ OpStore %attr %11
+ %26 = OpLoad %23 %as
+ OpHitObjectTraceRayEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %payload
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectTraceRayMotionEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %31 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %32 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+ %float_10 = OpConstant %float 10
+ %int_2 = OpConstant %int 2
+ )";
+
+ const std::string body = R"(
+ OpStore %attr %11
+ %26 = OpLoad %23 %as
+ OpHitObjectTraceRayMotionEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %float_10 %__0
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectRecordEmptyEXT) {
+ const std::string body = R"(
+ OpHitObjectRecordEmptyEXT %hObj
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectRecordMissEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %29 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %float_1_5 = OpConstant %float 1.5
+ %31 = OpConstantComposite %v3float %float_1_5 %float_1_5 %float_1_5
+ %float_5 = OpConstant %float 5
+ )";
+
+ const std::string body = R"(
+ OpHitObjectRecordMissEXT %hObj %uint_1 %uint_1 %29 %float_2 %31 %float_5
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectIsHitEXT) {
+ const std::string declarations = R"(
+ %bool = OpTypeBool
+ %_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
+ )";
+
+ const std::string body = R"(
+ %26 = OpHitObjectIsHitEXT %bool %hObj
+ OpSelectionMerge %28 None
+ OpBranchConditional %26 %27 %28
+ %27 = OpLabel
+ %33 = OpAccessChain %_ptr_StorageBuffer_float %__1 %int_0
+ OpStore %33 %float_1
+ OpBranch %28
+ %28 = OpLabel
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectIsMissEXT) {
+ const std::string declarations = R"(
+ %bool = OpTypeBool
+ %_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
+ )";
+
+ const std::string body = R"(
+ %26 = OpHitObjectIsMissEXT %bool %hObj
+ OpSelectionMerge %28 None
+ OpBranchConditional %26 %27 %28
+ %27 = OpLabel
+ %33 = OpAccessChain %_ptr_StorageBuffer_float %__1 %int_0
+ OpStore %33 %float_1
+ OpBranch %28
+ %28 = OpLabel
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectIsEmptyEXT) {
+ const std::string declarations = R"(
+ %bool = OpTypeBool
+ %_ptr_StorageBuffer_float = OpTypePointer StorageBuffer %float
+ )";
+
+ const std::string body = R"(
+ %26 = OpHitObjectIsEmptyEXT %bool %hObj
+ OpSelectionMerge %28 None
+ OpBranchConditional %26 %27 %28
+ %27 = OpLabel
+ %33 = OpAccessChain %_ptr_StorageBuffer_float %__1 %int_0
+ OpStore %33 %float_1
+ OpBranch %28
+ %28 = OpLabel
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetGeometryIndexEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_int = OpTypePointer Function %int
+ )";
+
+ const std::string body = R"(
+ %id = OpVariable %_ptr_Function_int Function
+ %12 = OpHitObjectGetGeometryIndexEXT %int %hObj
+ OpStore %id %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetPrimitiveIndexEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_int = OpTypePointer Function %int
+ )";
+
+ const std::string body = R"(
+ %id = OpVariable %_ptr_Function_int Function
+ %12 = OpHitObjectGetPrimitiveIndexEXT %int %hObj
+ OpStore %id %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetInstanceIdEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_int = OpTypePointer Function %int
+ )";
+
+ const std::string body = R"(
+ %id = OpVariable %_ptr_Function_int Function
+ %12 = OpHitObjectGetInstanceIdEXT %int %hObj
+ OpStore %id %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetInstanceCustomIndexEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_int = OpTypePointer Function %int
+ )";
+
+ const std::string body = R"(
+ %id = OpVariable %_ptr_Function_int Function
+ %12 = OpHitObjectGetInstanceCustomIndexEXT %int %hObj
+ OpStore %id %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetHitKindEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %_ptr_Function_uint = OpTypePointer Function %uint
+ )";
+
+ const std::string body = R"(
+ %uid = OpVariable %_ptr_Function_uint Function
+ %12 = OpHitObjectGetHitKindEXT %uint %hObj
+ OpStore %uid %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetCurrentTimeEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_float = OpTypePointer Function %float
+ )";
+
+ const std::string body = R"(
+ %time = OpVariable %_ptr_Function_float Function
+ %12 = OpHitObjectGetCurrentTimeEXT %float %hObj
+ OpStore %time %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetObjectRayOriginEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %_ptr_Function_v3float = OpTypePointer Function %v3float
+ )";
+
+ const std::string body = R"(
+ %oorig = OpVariable %_ptr_Function_v3float Function
+ %13 = OpHitObjectGetObjectRayOriginEXT %v3float %hObj
+ OpStore %oorig %13
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetObjectRayDirectionEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %_ptr_Function_v3float = OpTypePointer Function %v3float
+ )";
+
+ const std::string body = R"(
+ %odir = OpVariable %_ptr_Function_v3float Function
+ %13 = OpHitObjectGetObjectRayDirectionEXT %v3float %hObj
+ OpStore %odir %13
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetRayTMaxEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_float = OpTypePointer Function %float
+ )";
+
+ const std::string body = R"(
+ %tmax = OpVariable %_ptr_Function_float Function
+ %12 = OpHitObjectGetRayTMaxEXT %float %hObj
+ OpStore %tmax %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetRayTMinEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_float = OpTypePointer Function %float
+ )";
+
+ const std::string body = R"(
+ %tmin = OpVariable %_ptr_Function_float Function
+ %12 = OpHitObjectGetRayTMinEXT %float %hObj
+ OpStore %tmin %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetRayFlagsEXT) {
+ const std::string declarations = R"(
+ %_ptr_Function_int = OpTypePointer Function %int
+ )";
+
+ const std::string body = R"(
+ %flags = OpVariable %_ptr_Function_int Function
+ %12 = OpHitObjectGetRayFlagsEXT %int %hObj
+ OpStore %flags %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetWorldRayOriginEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %_ptr_Function_v3float = OpTypePointer Function %v3float
+ )";
+
+ const std::string body = R"(
+ %orig = OpVariable %_ptr_Function_v3float Function
+ %13 = OpHitObjectGetWorldRayOriginEXT %v3float %hObj
+ OpStore %orig %13
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetWorldRayDirectionEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %_ptr_Function_v3float = OpTypePointer Function %v3float
+ )";
+
+ const std::string body = R"(
+ %dir = OpVariable %_ptr_Function_v3float Function
+ %13 = OpHitObjectGetWorldRayDirectionEXT %v3float %hObj
+ OpStore %dir %13
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetObjectToWorldEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %mat4v3float = OpTypeMatrix %v3float 4
+ %_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
+ )";
+
+ const std::string body = R"(
+ %otw = OpVariable %_ptr_Function_mat4v3float Function
+ %14 = OpHitObjectGetObjectToWorldEXT %mat4v3float %hObj
+ OpStore %otw %14
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetWorldToObjectEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %mat4v3float = OpTypeMatrix %v3float 4
+ %_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
+ )";
+
+ const std::string body = R"(
+ %wto = OpVariable %_ptr_Function_mat4v3float Function
+ %14 = OpHitObjectGetWorldToObjectEXT %mat4v3float %hObj
+ OpStore %wto %14
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetShaderRecordBufferHandleEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %v2uint = OpTypeVector %uint 2
+ %_ptr_Function_v2uint = OpTypePointer Function %v2uint
+ )";
+
+ const std::string body = R"(
+ %handle = OpVariable %_ptr_Function_v2uint Function
+ %13 = OpHitObjectGetShaderRecordBufferHandleEXT %v2uint %hObj
+ OpStore %handle %13
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT,
+ HitObjectGetShaderBindingTableRecordIndexEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %_ptr_Function_uint = OpTypePointer Function %uint
+ )";
+
+ const std::string body = R"(
+ %rid = OpVariable %_ptr_Function_uint Function
+ %12 = OpHitObjectGetShaderBindingTableRecordIndexEXT %uint %hObj
+ OpStore %rid %12
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT,
+ HitObjectSetShaderBindingTableRecordIndexEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_5 = OpConstant %uint 5
+ )";
+
+ const std::string body = R"(
+ OpHitObjectSetShaderBindingTableRecordIndexEXT %hObj %uint_5
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectGetAttributesEXT) {
+ const std::string body = R"(
+ OpHitObjectGetAttributesEXT %hObj %attr
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectExecuteShaderEXT) {
+ const std::string body = R"(
+ OpHitObjectExecuteShaderEXT %hObj %payload
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT,
+ HitObjectGetIntersectionTriangleVertexPositionsEXT) {
+ const std::string declarations = R"(
+ %v3float = OpTypeVector %float 3
+ %uint = OpTypeInt 32 0
+ %uint_3 = OpConstant %uint 3
+ %arr_3_v3float = OpTypeArray %v3float %uint_3
+ %_ptr_Function_arr_3_v3float = OpTypePointer Function %arr_3_v3float
+ )";
+
+ const std::string body = R"(
+ %vertices = OpVariable %_ptr_Function_arr_3_v3float Function
+ %result = OpHitObjectGetIntersectionTriangleVertexPositionsEXT %arr_3_v3float %hObj
+ OpStore %vertices %result
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectRecordFromQueryEXT) {
+ const std::string cap = R"(
+ OpCapability RayQueryKHR
+ )";
+ const std::string extensions = R"(
+ OpExtension "SPV_KHR_ray_query"
+ )";
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_5 = OpConstant %uint 5
+ %rayquery_type = OpTypeRayQueryKHR
+ %_ptr_Function_rayquery = OpTypePointer Function %rayquery_type
+ )";
+
+ const std::string body = R"(
+ %ray_query = OpVariable %_ptr_Function_rayquery Function
+ OpHitObjectRecordFromQueryEXT %hObj %ray_query %uint_5 %attr
+ )";
+
+ CompileSuccessfully(
+ GenerateReorderShaderCodeEXT(body, declarations, extensions, cap).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectRecordMissMotionEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %29 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %float_1_5 = OpConstant %float 1.5
+ %31 = OpConstantComposite %v3float %float_1_5 %float_1_5 %float_1_5
+ %float_5 = OpConstant %float 5
+ %float_10 = OpConstant %float 10
+ )";
+
+ const std::string body = R"(
+ OpHitObjectRecordMissMotionEXT %hObj %uint_1 %uint_1 %29 %float_2 %31 %float_5 %float_10
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+// Fused Hit Object Instructions Tests
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectReorderExecuteShaderEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+ %uint_2 = OpConstant %uint 2
+ )";
+
+ const std::string body = R"(
+ OpHitObjectReorderExecuteShaderEXT %hObj %payload
+ OpHitObjectReorderExecuteShaderEXT %hObj %payload %uint_4 %uint_2
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectTraceReorderExecuteEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_4 = OpConstant %uint 4
+ %uint_2 = OpConstant %uint 2
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %31 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %32 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+ %int_1 = OpConstant %int 1
+ )";
+
+ const std::string body = R"(
+ OpStore %attr %11
+ %26 = OpLoad %23 %as
+ OpHitObjectTraceReorderExecuteEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %payload
+ OpHitObjectTraceReorderExecuteEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %payload %uint_4 %uint_2
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateRayTracingReorderEXT, HitObjectTraceMotionReorderExecuteEXT) {
+ const std::string declarations = R"(
+ %uint = OpTypeInt 32 0
+ %uint_1 = OpConstant %uint 1
+ %uint_4 = OpConstant %uint 4
+ %uint_2 = OpConstant %uint 2
+ %v3float = OpTypeVector %float 3
+ %float_0_5 = OpConstant %float 0.5
+ %31 = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0_5
+ %32 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+ %float_10 = OpConstant %float 10
+ %int_2 = OpConstant %int 2
+ )";
+
+ const std::string body = R"(
+ OpStore %attr %11
+ %26 = OpLoad %23 %as
+ OpHitObjectTraceMotionReorderExecuteEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %float_10 %__0
+ OpHitObjectTraceMotionReorderExecuteEXT %hObj %26 %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %31 %float_0_5 %32 %float_1 %float_10 %__0 %uint_4 %uint_2
+ )";
+
+ CompileSuccessfully(GenerateReorderShaderCodeEXT(body, declarations).c_str(),
+ SPV_ENV_VULKAN_1_2);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
} // namespace
} // namespace val
} // namespace spvtools