| // Copyright (c) 2025 Google LLC. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include <iostream> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| #include "source/opcode.h" |
| #include "source/val/validate.h" |
| #include "source/val/validation_state.h" |
| |
| namespace spvtools { |
| namespace val { |
| namespace { |
| |
| // Returns true if inst is a logical pointer. |
| bool IsLogicalPointer(const ValidationState_t& _, const Instruction* inst) { |
| if (!_.IsPointerType(inst->type_id())) { |
| return false; |
| } |
| |
| // Physical storage buffer pointers are not logical pointers. |
| auto type_inst = _.FindDef(inst->type_id()); |
| auto sc = type_inst->GetOperandAs<spv::StorageClass>(1); |
| if (sc == spv::StorageClass::PhysicalStorageBuffer) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| // Returns true if inst is a variable pointer. |
| // Caches the result in variable_pointers. |
| bool IsVariablePointer(const ValidationState_t& _, |
| std::unordered_map<uint32_t, bool>& variable_pointers, |
| const Instruction* inst) { |
| const auto iter = variable_pointers.find(inst->id()); |
| if (iter != variable_pointers.end()) { |
| return iter->second; |
| } |
| |
| // Temporarily mark the instruction as NOT a variable pointer. |
| variable_pointers[inst->id()] = false; |
| |
| bool is_var_ptr = false; |
| switch (inst->opcode()) { |
| case spv::Op::OpPtrAccessChain: |
| case spv::Op::OpUntypedPtrAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsPtrAccessChainKHR: |
| case spv::Op::OpLoad: |
| case spv::Op::OpSelect: |
| case spv::Op::OpPhi: |
| case spv::Op::OpFunctionCall: |
| case spv::Op::OpConstantNull: |
| is_var_ptr = true; |
| break; |
| case spv::Op::OpFunctionParameter: |
| // Special case: skip to function calls. |
| if (IsLogicalPointer(_, inst)) { |
| auto func = inst->function(); |
| auto func_inst = _.FindDef(func->id()); |
| |
| const auto param_inst_num = inst - &_.ordered_instructions()[0]; |
| uint32_t param_index = 0; |
| uint32_t inst_index = 1; |
| while (_.ordered_instructions()[param_inst_num - inst_index].opcode() != |
| spv::Op::OpFunction) { |
| if (_.ordered_instructions()[param_inst_num - inst_index].opcode() == |
| spv::Op::OpFunctionParameter) { |
| param_index++; |
| } |
| ++inst_index; |
| } |
| |
| for (const auto& use_pair : func_inst->uses()) { |
| const auto use_inst = use_pair.first; |
| if (use_inst->opcode() == spv::Op::OpFunctionCall) { |
| const auto arg_id = |
| use_inst->GetOperandAs<uint32_t>(3 + param_index); |
| const auto arg_inst = _.FindDef(arg_id); |
| is_var_ptr |= IsVariablePointer(_, variable_pointers, arg_inst); |
| } |
| } |
| } |
| break; |
| default: { |
| for (uint32_t i = 0; i < inst->operands().size(); ++i) { |
| if (inst->operands()[i].type != SPV_OPERAND_TYPE_ID) { |
| continue; |
| } |
| |
| auto op_inst = _.FindDef(inst->GetOperandAs<uint32_t>(i)); |
| if (IsLogicalPointer(_, op_inst)) { |
| is_var_ptr |= IsVariablePointer(_, variable_pointers, op_inst); |
| } |
| } |
| break; |
| } |
| } |
| variable_pointers[inst->id()] = is_var_ptr; |
| return is_var_ptr; |
| } |
| |
| spv_result_t ValidateLogicalPointerOperands(ValidationState_t& _, |
| const Instruction* inst) { |
| bool has_pointer_operand = false; |
| spv::StorageClass sc = spv::StorageClass::Function; |
| for (uint32_t i = 0; i < inst->operands().size(); ++i) { |
| if (inst->operands()[i].type != SPV_OPERAND_TYPE_ID) { |
| continue; |
| } |
| |
| auto op_inst = _.FindDef(inst->GetOperandAs<uint32_t>(i)); |
| if (IsLogicalPointer(_, op_inst)) { |
| has_pointer_operand = true; |
| |
| // Assume that there are not mixed storage classes in the instruction. |
| // This is not true for OpCopyMemory and OpCopyMemorySized, but they allow |
| // all storage classes. |
| auto type_inst = _.FindDef(op_inst->type_id()); |
| sc = type_inst->GetOperandAs<spv::StorageClass>(1); |
| break; |
| } |
| } |
| |
| if (!has_pointer_operand) { |
| return SPV_SUCCESS; |
| } |
| |
| switch (inst->opcode()) { |
| // The following instructions allow logical pointer operands in all cases |
| // without capabilities. |
| case spv::Op::OpLoad: |
| case spv::Op::OpStore: |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| case spv::Op::OpFunctionCall: |
| case spv::Op::OpImageTexelPointer: |
| case spv::Op::OpCopyMemory: |
| case spv::Op::OpCopyObject: |
| case spv::Op::OpArrayLength: |
| case spv::Op::OpExtInst: |
| // Core spec bugs |
| case spv::Op::OpDecorate: |
| case spv::Op::OpDecorateId: |
| case spv::Op::OpGroupDecorate: |
| case spv::Op::OpEntryPoint: |
| case spv::Op::OpName: |
| case spv::Op::OpDecorateString: |
| // SPV_KHR_untyped_pointers |
| case spv::Op::OpUntypedArrayLengthKHR: |
| case spv::Op::OpUntypedAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsAccessChainKHR: |
| case spv::Op::OpCopyMemorySized: |
| // Cooperative matrix KHR/NV |
| case spv::Op::OpCooperativeMatrixLoadKHR: |
| case spv::Op::OpCooperativeMatrixLoadNV: |
| case spv::Op::OpCooperativeMatrixStoreKHR: |
| case spv::Op::OpCooperativeMatrixStoreNV: |
| // SPV_KHR_ray_tracing |
| case spv::Op::OpTraceRayKHR: |
| case spv::Op::OpExecuteCallableKHR: |
| // SPV_KHR_ray_query |
| case spv::Op::OpRayQueryConfirmIntersectionKHR: |
| case spv::Op::OpRayQueryInitializeKHR: |
| case spv::Op::OpRayQueryTerminateKHR: |
| case spv::Op::OpRayQueryGenerateIntersectionKHR: |
| case spv::Op::OpRayQueryProceedKHR: |
| case spv::Op::OpRayQueryGetIntersectionTypeKHR: |
| case spv::Op::OpRayQueryGetRayTMinKHR: |
| case spv::Op::OpRayQueryGetRayFlagsKHR: |
| case spv::Op::OpRayQueryGetIntersectionTKHR: |
| case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR: |
| case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR: |
| case spv::Op:: |
| OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR: |
| case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR: |
| case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR: |
| case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: |
| case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR: |
| case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: |
| case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR: |
| case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR: |
| case spv::Op::OpRayQueryGetWorldRayDirectionKHR: |
| case spv::Op::OpRayQueryGetWorldRayOriginKHR: |
| case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR: |
| case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: |
| // SPV_KHR_ray_tracing_position_fetch |
| case spv::Op::OpRayQueryGetIntersectionTriangleVertexPositionsKHR: |
| // SPV_NV_cluster_acceleration_structure |
| case spv::Op::OpRayQueryGetClusterIdNV: |
| case spv::Op::OpHitObjectGetClusterIdNV: |
| // SPV_NV_ray_tracing_motion_blur |
| case spv::Op::OpTraceMotionNV: |
| case spv::Op::OpTraceRayMotionNV: |
| // SPV_NV_linear_swept_spheres |
| case spv::Op::OpRayQueryGetIntersectionSpherePositionNV: |
| case spv::Op::OpRayQueryGetIntersectionSphereRadiusNV: |
| case spv::Op::OpRayQueryGetIntersectionLSSPositionsNV: |
| case spv::Op::OpRayQueryGetIntersectionLSSRadiiNV: |
| case spv::Op::OpRayQueryGetIntersectionLSSHitValueNV: |
| case spv::Op::OpRayQueryIsSphereHitNV: |
| case spv::Op::OpRayQueryIsLSSHitNV: |
| case spv::Op::OpHitObjectGetSpherePositionNV: |
| case spv::Op::OpHitObjectGetSphereRadiusNV: |
| case spv::Op::OpHitObjectGetLSSPositionsNV: |
| case spv::Op::OpHitObjectGetLSSRadiiNV: |
| case spv::Op::OpHitObjectIsSphereHitNV: |
| case spv::Op::OpHitObjectIsLSSHitNV: |
| // SPV_NV_shader_invocation_reorder |
| case spv::Op::OpReorderThreadWithHitObjectNV: |
| case spv::Op::OpHitObjectTraceRayNV: |
| case spv::Op::OpHitObjectTraceRayMotionNV: |
| case spv::Op::OpHitObjectRecordHitNV: |
| case spv::Op::OpHitObjectRecordHitMotionNV: |
| case spv::Op::OpHitObjectRecordHitWithIndexNV: |
| case spv::Op::OpHitObjectRecordHitWithIndexMotionNV: |
| case spv::Op::OpHitObjectRecordMissNV: |
| case spv::Op::OpHitObjectRecordMissMotionNV: |
| case spv::Op::OpHitObjectRecordEmptyNV: |
| case spv::Op::OpHitObjectExecuteShaderNV: |
| case spv::Op::OpHitObjectGetCurrentTimeNV: |
| case spv::Op::OpHitObjectGetAttributesNV: |
| case spv::Op::OpHitObjectGetHitKindNV: |
| case spv::Op::OpHitObjectGetPrimitiveIndexNV: |
| case spv::Op::OpHitObjectGetGeometryIndexNV: |
| case spv::Op::OpHitObjectGetInstanceIdNV: |
| case spv::Op::OpHitObjectGetInstanceCustomIndexNV: |
| case spv::Op::OpHitObjectGetObjectRayOriginNV: |
| case spv::Op::OpHitObjectGetObjectRayDirectionNV: |
| case spv::Op::OpHitObjectGetWorldRayDirectionNV: |
| case spv::Op::OpHitObjectGetWorldRayOriginNV: |
| case spv::Op::OpHitObjectGetObjectToWorldNV: |
| case spv::Op::OpHitObjectGetWorldToObjectNV: |
| case spv::Op::OpHitObjectGetRayTMaxNV: |
| case spv::Op::OpHitObjectGetRayTMinNV: |
| case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexNV: |
| case spv::Op::OpHitObjectGetShaderRecordBufferHandleNV: |
| case spv::Op::OpHitObjectIsEmptyNV: |
| case spv::Op::OpHitObjectIsHitNV: |
| case spv::Op::OpHitObjectIsMissNV: |
| // SPV_EXT_shader_invocation_reorder |
| case spv::Op::OpHitObjectRecordFromQueryEXT: |
| case spv::Op::OpHitObjectRecordMissEXT: |
| case spv::Op::OpHitObjectRecordMissMotionEXT: |
| case spv::Op::OpHitObjectGetIntersectionTriangleVertexPositionsEXT: |
| case spv::Op::OpHitObjectGetRayFlagsEXT: |
| case spv::Op::OpHitObjectSetShaderBindingTableRecordIndexEXT: |
| case spv::Op::OpHitObjectReorderExecuteShaderEXT: |
| case spv::Op::OpHitObjectTraceReorderExecuteEXT: |
| case spv::Op::OpHitObjectTraceMotionReorderExecuteEXT: |
| case spv::Op::OpReorderThreadWithHintEXT: |
| case spv::Op::OpReorderThreadWithHitObjectEXT: |
| case spv::Op::OpHitObjectTraceRayEXT: |
| case spv::Op::OpHitObjectTraceRayMotionEXT: |
| case spv::Op::OpHitObjectRecordEmptyEXT: |
| case spv::Op::OpHitObjectExecuteShaderEXT: |
| case spv::Op::OpHitObjectGetCurrentTimeEXT: |
| case spv::Op::OpHitObjectGetAttributesEXT: |
| case spv::Op::OpHitObjectGetHitKindEXT: |
| case spv::Op::OpHitObjectGetPrimitiveIndexEXT: |
| case spv::Op::OpHitObjectGetGeometryIndexEXT: |
| case spv::Op::OpHitObjectGetInstanceIdEXT: |
| case spv::Op::OpHitObjectGetInstanceCustomIndexEXT: |
| case spv::Op::OpHitObjectGetObjectRayOriginEXT: |
| case spv::Op::OpHitObjectGetObjectRayDirectionEXT: |
| case spv::Op::OpHitObjectGetWorldRayDirectionEXT: |
| case spv::Op::OpHitObjectGetWorldRayOriginEXT: |
| case spv::Op::OpHitObjectGetObjectToWorldEXT: |
| case spv::Op::OpHitObjectGetWorldToObjectEXT: |
| case spv::Op::OpHitObjectGetRayTMaxEXT: |
| case spv::Op::OpHitObjectGetRayTMinEXT: |
| case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexEXT: |
| case spv::Op::OpHitObjectGetShaderRecordBufferHandleEXT: |
| case spv::Op::OpHitObjectIsEmptyEXT: |
| case spv::Op::OpHitObjectIsHitEXT: |
| case spv::Op::OpHitObjectIsMissEXT: |
| // SPV_NV_raw_access_chains |
| case spv::Op::OpRawAccessChainNV: |
| // SPV_NV_cooperative_matrix2 |
| case spv::Op::OpCooperativeMatrixLoadTensorNV: |
| case spv::Op::OpCooperativeMatrixStoreTensorNV: |
| // SPV_NV_cooperative_vector |
| case spv::Op::OpCooperativeVectorLoadNV: |
| case spv::Op::OpCooperativeVectorStoreNV: |
| case spv::Op::OpCooperativeVectorMatrixMulNV: |
| case spv::Op::OpCooperativeVectorMatrixMulAddNV: |
| case spv::Op::OpCooperativeVectorOuterProductAccumulateNV: |
| case spv::Op::OpCooperativeVectorReduceSumAccumulateNV: |
| // SPV_EXT_mesh_shader |
| case spv::Op::OpEmitMeshTasksEXT: |
| // SPV_AMD_shader_enqueue (spec bugs) |
| case spv::Op::OpEnqueueNodePayloadsAMDX: |
| case spv::Op::OpNodePayloadArrayLengthAMDX: |
| case spv::Op::OpIsNodePayloadValidAMDX: |
| case spv::Op::OpFinishWritingNodePayloadAMDX: |
| // SPV_ARM_graph |
| case spv::Op::OpGraphEntryPointARM: |
| return SPV_SUCCESS; |
| // The following cases require a variable pointer capability. Since all |
| // instructions are for variable pointers, the storage class and capability |
| // are also checked. |
| case spv::Op::OpReturnValue: |
| case spv::Op::OpPtrAccessChain: |
| case spv::Op::OpPtrEqual: |
| case spv::Op::OpPtrNotEqual: |
| case spv::Op::OpPtrDiff: |
| // Core spec bugs |
| case spv::Op::OpSelect: |
| case spv::Op::OpPhi: |
| case spv::Op::OpVariable: |
| // SPV_KHR_untyped_pointers |
| case spv::Op::OpUntypedPtrAccessChainKHR: |
| if ((_.HasCapability(spv::Capability::VariablePointersStorageBuffer) && |
| sc == spv::StorageClass ::StorageBuffer) || |
| (_.HasCapability(spv::Capability::VariablePointers) && |
| sc == spv::StorageClass::Workgroup)) { |
| return SPV_SUCCESS; |
| } |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Instruction may only have a logical pointer operand in the " |
| "StorageBuffer or Workgroup storage classes with appropriate " |
| "variable pointers capability"; |
| default: |
| if (spvOpcodeIsAtomicOp(inst->opcode())) { |
| return SPV_SUCCESS; |
| } |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Instruction may not have a logical pointer operand"; |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t ValidateLogicalPointerReturns(ValidationState_t& _, |
| const Instruction* inst) { |
| if (!IsLogicalPointer(_, inst)) { |
| return SPV_SUCCESS; |
| } |
| |
| const auto type_inst = _.FindDef(inst->type_id()); |
| const auto sc = type_inst->GetOperandAs<spv::StorageClass>(1u); |
| |
| switch (inst->opcode()) { |
| // Core spec without an variable pointer capability. |
| case spv::Op::OpVariable: |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| case spv::Op::OpFunctionParameter: |
| case spv::Op::OpImageTexelPointer: |
| case spv::Op::OpCopyObject: |
| // Core spec bugs |
| case spv::Op::OpUndef: |
| // SPV_KHR_untyped_pointers |
| case spv::Op::OpUntypedAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsAccessChainKHR: |
| case spv::Op::OpUntypedVariableKHR: |
| // SPV_NV_raw_access_chains |
| case spv::Op::OpRawAccessChainNV: |
| // SPV_AMD_shader_enqueue (spec bugs) |
| case spv::Op::OpAllocateNodePayloadsAMDX: |
| return SPV_SUCCESS; |
| // Core spec with variable pointer capability. Check storage classes since |
| // variable pointers can only be in certain storage classes. |
| case spv::Op::OpSelect: |
| case spv::Op::OpPhi: |
| case spv::Op::OpFunctionCall: |
| case spv::Op::OpPtrAccessChain: |
| case spv::Op::OpLoad: |
| case spv::Op::OpConstantNull: |
| case spv::Op::OpFunction: |
| // SPV_KHR_untyped_pointers |
| case spv::Op::OpUntypedPtrAccessChainKHR: |
| if ((_.HasCapability(spv::Capability::VariablePointersStorageBuffer) && |
| sc == spv::StorageClass ::StorageBuffer) || |
| (_.HasCapability(spv::Capability::VariablePointers) && |
| sc == spv::StorageClass::Workgroup)) { |
| return SPV_SUCCESS; |
| } |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Instruction may only return a logical pointer in the " |
| "StorageBuffer or Workgroup storage classes with appropriate " |
| "variable pointers capability"; |
| default: |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Instruction may not return a logical pointer"; |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t IsBlockArray(ValidationState_t& _, const Instruction* type) { |
| if (type->opcode() == spv::Op::OpTypeArray || |
| type->opcode() == spv::Op::OpTypeRuntimeArray) { |
| const auto element_type = _.FindDef(type->GetOperandAs<uint32_t>(1)); |
| if (element_type->opcode() == spv::Op::OpTypeStruct && |
| (_.HasDecoration(element_type->id(), spv::Decoration::Block) || |
| _.HasDecoration(element_type->id(), spv::Decoration::BufferBlock))) { |
| return SPV_ERROR_INVALID_DATA; |
| } |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t CheckMatrixElementTyped(ValidationState_t& _, |
| const Instruction* inst) { |
| switch (inst->opcode()) { |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| case spv::Op::OpPtrAccessChain: { |
| // Get the type of the base operand. |
| uint32_t start_index = |
| inst->opcode() == spv::Op::OpPtrAccessChain ? 4 : 3; |
| const auto access_type_id = _.GetOperandTypeId(inst, 2); |
| auto access_type = _.FindDef(access_type_id); |
| access_type = _.FindDef(access_type->GetOperandAs<uint32_t>(2)); |
| |
| // If the base operand is a matrix, then it was definitely pointing to a |
| // sub-component. |
| if (access_type->opcode() == spv::Op::OpTypeMatrix) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to a column or a " |
| "component of a column of a matrix"; |
| } |
| |
| // Otherwise, step through the indices to see if we pass a matrix. |
| for (uint32_t i = start_index; i < inst->operands().size(); ++i) { |
| const auto index = inst->GetOperandAs<uint32_t>(i); |
| if (access_type->opcode() == spv::Op::OpTypeStruct) { |
| uint64_t val = 0; |
| _.EvalConstantValUint64(index, &val); |
| access_type = _.FindDef(access_type->GetOperandAs<uint32_t>( |
| 1 + static_cast<uint32_t>(val))); |
| } else { |
| access_type = _.FindDef(_.GetComponentType(access_type->id())); |
| } |
| |
| if (access_type->opcode() == spv::Op::OpTypeMatrix) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to a column or a " |
| "component of a column of a matrix"; |
| } |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t CheckMatrixElementUntyped(ValidationState_t& _, |
| const Instruction* inst) { |
| switch (inst->opcode()) { |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| case spv::Op::OpPtrAccessChain: |
| case spv::Op::OpUntypedAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsAccessChainKHR: |
| case spv::Op::OpUntypedPtrAccessChainKHR: { |
| const bool untyped = spvOpcodeGeneratesUntypedPointer(inst->opcode()); |
| uint32_t start_index; |
| Instruction* access_type = nullptr; |
| if (untyped) { |
| // Get the type of the base operand. |
| start_index = |
| inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR ? 5 : 4; |
| const auto access_type_id = inst->GetOperandAs<uint32_t>(2); |
| access_type = _.FindDef(access_type_id); |
| } else { |
| start_index = inst->opcode() == spv::Op::OpPtrAccessChain ? 4 : 3; |
| const auto access_type_id = _.GetOperandTypeId(inst, 2); |
| access_type = _.FindDef(access_type_id); |
| access_type = _.FindDef(access_type->GetOperandAs<uint32_t>(2)); |
| } |
| |
| // If the base operand is a matrix, then it was definitely pointing to a |
| // sub-component. |
| if (access_type->opcode() == spv::Op::OpTypeMatrix) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to a column or a " |
| "component of a column of a matrix."; |
| } |
| |
| // Otherwise, step through the indices to see if we pass a matrix. |
| for (uint32_t i = start_index; i < inst->operands().size(); ++i) { |
| const auto index = inst->GetOperandAs<uint32_t>(i); |
| if (access_type->opcode() == spv::Op::OpTypeStruct) { |
| uint64_t val = 0; |
| _.EvalConstantValUint64(index, &val); |
| access_type = _.FindDef(access_type->GetOperandAs<uint32_t>( |
| 1 + static_cast<uint32_t>(val))); |
| } else { |
| access_type = _.FindDef(_.GetComponentType(access_type->id())); |
| } |
| |
| if (access_type->opcode() == spv::Op::OpTypeMatrix) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to a column or a " |
| "component of a column of a matrix."; |
| } |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| return SPV_SUCCESS; |
| } |
| |
| // Traces the variable pointer inst backwards. |
| // checker is called on each visited instruction. |
| spv_result_t TraceVariablePointers( |
| ValidationState_t& _, const Instruction* inst, |
| const std::function<spv_result_t(ValidationState_t&, const Instruction*)>& |
| checker) { |
| std::vector<const Instruction*> stack; |
| std::unordered_set<const Instruction*> seen; |
| stack.push_back(inst); |
| while (!stack.empty()) { |
| const Instruction* trace_inst = stack.back(); |
| stack.pop_back(); |
| |
| if (!seen.insert(trace_inst).second) { |
| continue; |
| } |
| |
| if (auto error = checker(_, trace_inst)) { |
| return error; |
| } |
| |
| const auto untyped = spvOpcodeGeneratesUntypedPointer(trace_inst->opcode()); |
| switch (trace_inst->opcode()) { |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| case spv::Op::OpPtrAccessChain: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| break; |
| case spv::Op::OpUntypedAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsAccessChainKHR: |
| case spv::Op::OpUntypedPtrAccessChainKHR: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(3))); |
| break; |
| case spv::Op::OpPhi: |
| for (uint32_t i = 2; i < trace_inst->operands().size(); i += 2) { |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(i))); |
| } |
| break; |
| case spv::Op::OpSelect: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(3))); |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(4))); |
| break; |
| case spv::Op::OpFunctionParameter: { |
| // Jump to function calls |
| auto func = trace_inst->function(); |
| auto func_inst = _.FindDef(func->id()); |
| |
| const auto param_inst_num = trace_inst - &_.ordered_instructions()[0]; |
| uint32_t param_index = 0; |
| uint32_t inst_index = 1; |
| while (_.ordered_instructions()[param_inst_num - inst_index].opcode() != |
| spv::Op::OpFunction) { |
| if (_.ordered_instructions()[param_inst_num - inst_index].opcode() == |
| spv::Op::OpFunctionParameter) { |
| param_index++; |
| } |
| ++inst_index; |
| } |
| |
| for (const auto& use_pair : func_inst->uses()) { |
| const auto use_inst = use_pair.first; |
| if (use_inst->opcode() == spv::Op::OpFunctionCall) { |
| const auto arg_id = |
| use_inst->GetOperandAs<uint32_t>(3 + param_index); |
| const auto arg_inst = _.FindDef(arg_id); |
| stack.push_back(arg_inst); |
| } |
| } |
| break; |
| } |
| case spv::Op::OpFunctionCall: { |
| // Jump to return values. |
| const auto* func = _.function(trace_inst->GetOperandAs<uint32_t>(2)); |
| for (auto* bb : func->ordered_blocks()) { |
| const auto* terminator = bb->terminator(); |
| if (terminator->opcode() == spv::Op::OpReturnValue) { |
| stack.push_back(terminator); |
| } |
| } |
| break; |
| } |
| case spv::Op::OpReturnValue: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(0))); |
| break; |
| case spv::Op::OpCopyObject: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| break; |
| case spv::Op::OpLoad: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| break; |
| case spv::Op::OpStore: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(0))); |
| break; |
| case spv::Op::OpVariable: |
| case spv::Op::OpUntypedVariableKHR: { |
| const auto sc = trace_inst->GetOperandAs<spv::StorageClass>(2); |
| if (sc == spv::StorageClass::Function || |
| sc == spv::StorageClass::Private) { |
| // Add the initializer |
| const uint32_t init_operand = untyped ? 4 : 3; |
| if (trace_inst->operands().size() > init_operand) { |
| stack.push_back( |
| _.FindDef(trace_inst->GetOperandAs<uint32_t>(init_operand))); |
| } |
| // Jump to stores |
| std::vector<std::pair<const Instruction*, uint32_t>> store_stack( |
| trace_inst->uses()); |
| std::unordered_set<const Instruction*> store_seen; |
| while (!store_stack.empty()) { |
| const auto use = store_stack.back(); |
| store_stack.pop_back(); |
| |
| if (!store_seen.insert(use.first).second) { |
| continue; |
| } |
| |
| // If the use is a store pointer, trace the store object. |
| // Note: use.second is a word index. |
| if (use.first->opcode() == spv::Op::OpStore && use.second == 1) { |
| stack.push_back(_.FindDef(use.first->GetOperandAs<uint32_t>(1))); |
| } else { |
| // Most likely a gep so keep tracing. |
| for (auto& next_use : use.first->uses()) { |
| store_stack.push_back(next_use); |
| } |
| } |
| } |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| // Traces the variable pointer inst backwards, but only unmodified pointers. |
| // checker is called on each visited instruction. |
| spv_result_t TraceUnmodifiedVariablePointers( |
| ValidationState_t& _, const Instruction* inst, |
| const std::function<spv_result_t(ValidationState_t&, const Instruction*)>& |
| checker) { |
| std::vector<const Instruction*> stack; |
| std::unordered_set<const Instruction*> seen; |
| stack.push_back(inst); |
| while (!stack.empty()) { |
| const Instruction* trace_inst = stack.back(); |
| stack.pop_back(); |
| |
| if (!seen.insert(trace_inst).second) { |
| continue; |
| } |
| |
| if (auto error = checker(_, trace_inst)) { |
| return error; |
| } |
| |
| const auto untyped = spvOpcodeGeneratesUntypedPointer(trace_inst->opcode()); |
| switch (trace_inst->opcode()) { |
| case spv::Op::OpAccessChain: |
| case spv::Op::OpInBoundsAccessChain: |
| if (trace_inst->operands().size() == 2) { |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| } |
| break; |
| case spv::Op::OpUntypedAccessChainKHR: |
| case spv::Op::OpUntypedInBoundsAccessChainKHR: |
| case spv::Op::OpUntypedPtrAccessChainKHR: |
| if (trace_inst->operands().size() == 3) { |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(3))); |
| } |
| break; |
| case spv::Op::OpPhi: |
| for (uint32_t i = 2; i < trace_inst->operands().size(); i += 2) { |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(i))); |
| } |
| break; |
| case spv::Op::OpSelect: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(3))); |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(4))); |
| break; |
| case spv::Op::OpFunctionParameter: { |
| // Jump to function calls |
| auto func = trace_inst->function(); |
| auto func_inst = _.FindDef(func->id()); |
| |
| const auto param_inst_num = trace_inst - &_.ordered_instructions()[0]; |
| uint32_t param_index = 0; |
| uint32_t inst_index = 1; |
| while (_.ordered_instructions()[param_inst_num - inst_index].opcode() != |
| spv::Op::OpFunction) { |
| if (_.ordered_instructions()[param_inst_num - inst_index].opcode() == |
| spv::Op::OpFunctionParameter) { |
| param_index++; |
| } |
| ++inst_index; |
| } |
| |
| for (const auto& use_pair : func_inst->uses()) { |
| const auto use_inst = use_pair.first; |
| if (use_inst->opcode() == spv::Op::OpFunctionCall) { |
| const auto arg_id = |
| use_inst->GetOperandAs<uint32_t>(3 + param_index); |
| const auto arg_inst = _.FindDef(arg_id); |
| stack.push_back(arg_inst); |
| } |
| } |
| break; |
| } |
| case spv::Op::OpFunctionCall: { |
| // Jump to return values. |
| const auto* func = _.function(trace_inst->GetOperandAs<uint32_t>(2)); |
| for (auto* bb : func->ordered_blocks()) { |
| const auto* terminator = bb->terminator(); |
| if (terminator->opcode() == spv::Op::OpReturnValue) { |
| stack.push_back(terminator); |
| } |
| } |
| break; |
| } |
| case spv::Op::OpReturnValue: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(0))); |
| break; |
| case spv::Op::OpCopyObject: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| break; |
| case spv::Op::OpLoad: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(2))); |
| break; |
| case spv::Op::OpStore: |
| stack.push_back(_.FindDef(trace_inst->GetOperandAs<uint32_t>(0))); |
| break; |
| case spv::Op::OpVariable: |
| case spv::Op::OpUntypedVariableKHR: { |
| const auto sc = trace_inst->GetOperandAs<spv::StorageClass>(2); |
| if (sc == spv::StorageClass::Function || |
| sc == spv::StorageClass::Private) { |
| // Add the initializer |
| const uint32_t init_operand = untyped ? 4 : 3; |
| if (trace_inst->operands().size() > init_operand) { |
| stack.push_back( |
| _.FindDef(trace_inst->GetOperandAs<uint32_t>(init_operand))); |
| } |
| // Jump to stores |
| std::vector<std::pair<const Instruction*, uint32_t>> store_stack( |
| trace_inst->uses()); |
| std::unordered_set<const Instruction*> store_seen; |
| while (!store_stack.empty()) { |
| const auto use = store_stack.back(); |
| store_stack.pop_back(); |
| |
| if (!store_seen.insert(use.first).second) { |
| continue; |
| } |
| |
| // If the use is a store pointer, trace the store object. |
| // Note: use.second is a word index. |
| if (use.first->opcode() == spv::Op::OpStore && use.second == 1) { |
| stack.push_back(_.FindDef(use.first->GetOperandAs<uint32_t>(1))); |
| } else { |
| // Most likely a gep so keep tracing. |
| for (auto& next_use : use.first->uses()) { |
| store_stack.push_back(next_use); |
| } |
| } |
| } |
| } |
| break; |
| } |
| default: |
| break; |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| spv_result_t ValidateVariablePointers( |
| ValidationState_t& _, std::unordered_map<uint32_t, bool>& variable_pointers, |
| const Instruction* inst) { |
| // Variable pointers cannot be operands to array length. |
| if (inst->opcode() == spv::Op::OpArrayLength || |
| inst->opcode() == spv::Op::OpUntypedArrayLengthKHR) { |
| const auto ptr_index = inst->opcode() == spv::Op::OpArrayLength ? 2 : 3; |
| const auto ptr_id = inst->GetOperandAs<uint32_t>(ptr_index); |
| const auto ptr_inst = _.FindDef(ptr_id); |
| if (IsVariablePointer(_, variable_pointers, ptr_inst)) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Pointer operand must not be a variable pointer"; |
| } |
| return SPV_SUCCESS; |
| } |
| |
| // Check untyped loads and stores of variable pointers for matrix types. |
| // Neither instruction would be a variable pointer in a such a case. |
| if (inst->opcode() == spv::Op::OpLoad) { |
| const auto pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2)); |
| const auto pointer_type = _.FindDef(pointer->type_id()); |
| if (pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR && |
| IsVariablePointer(_, variable_pointers, pointer)) { |
| const auto data_type = _.FindDef(inst->type_id()); |
| if (_.ContainsType( |
| data_type->id(), |
| [](const Instruction* type_inst) { |
| return type_inst->opcode() == spv::Op::OpTypeMatrix; |
| }, |
| /* traverse_all_types = */ false)) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to an object that is or " |
| "contains a matrix"; |
| } |
| } |
| } else if (inst->opcode() == spv::Op::OpStore) { |
| const auto pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0)); |
| const auto pointer_type = _.FindDef(pointer->type_id()); |
| if (pointer_type->opcode() == spv::Op::OpTypeUntypedPointerKHR && |
| IsVariablePointer(_, variable_pointers, pointer)) { |
| const auto data_type_id = _.GetOperandTypeId(inst, 1); |
| const auto data_type = _.FindDef(data_type_id); |
| if (_.ContainsType( |
| data_type->id(), |
| [](const Instruction* type_inst) { |
| return type_inst->opcode() == spv::Op::OpTypeMatrix; |
| }, |
| /* traverse_all_types = */ false)) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to an object that is or " |
| "contains a matrix"; |
| } |
| } |
| } |
| |
| if (!IsLogicalPointer(_, inst) || |
| !IsVariablePointer(_, variable_pointers, inst)) { |
| return SPV_SUCCESS; |
| } |
| |
| const auto result_type = _.FindDef(inst->type_id()); |
| const auto untyped = |
| result_type->opcode() == spv::Op::OpTypeUntypedPointerKHR; |
| |
| // Pointers must be selected from the same buffer unless the VariablePointers |
| // capability is declared. |
| if (!_.HasCapability(spv::Capability::VariablePointers) && |
| (inst->opcode() == spv::Op::OpSelect || |
| inst->opcode() == spv::Op::OpPhi)) { |
| std::unordered_set<const Instruction*> sources; |
| const auto checker = [&sources, &inst]( |
| ValidationState_t& vstate, |
| const Instruction* check_inst) -> spv_result_t { |
| switch (check_inst->opcode()) { |
| case spv::Op::OpVariable: |
| case spv::Op::OpUntypedVariableKHR: |
| if (check_inst->GetOperandAs<spv::StorageClass>(2) == |
| spv::StorageClass::StorageBuffer || |
| check_inst->GetOperandAs<spv::StorageClass>(2) == |
| spv::StorageClass::Workgroup) { |
| sources.insert(check_inst); |
| } |
| if (sources.size() > 1) { |
| return vstate.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointers must point into the same structure " |
| "(or OpConstantNull)"; |
| } |
| break; |
| default: |
| break; |
| } |
| return SPV_SUCCESS; |
| }; |
| if (auto error = TraceVariablePointers(_, inst, checker)) { |
| return error; |
| } |
| } |
| |
| // Variable pointers must not: |
| // * point to array of Block- or BufferBlock-decorated structs |
| // * point to an object that is or contains a matrix |
| // * point to a column, or component in a column, of a matrix |
| if (untyped) { |
| if (auto error = |
| TraceVariablePointers(_, inst, CheckMatrixElementUntyped)) { |
| return error; |
| } |
| |
| // Block arrays can only really appear as the top most type so only look at |
| // unmodified pointers to determine if one is used. |
| const auto num_operands = inst->operands().size(); |
| if (!(num_operands == 3 && |
| (inst->opcode() == spv::Op::OpUntypedAccessChainKHR || |
| inst->opcode() == spv::Op::OpUntypedInBoundsAccessChainKHR || |
| inst->opcode() == spv::Op::OpUntypedPtrAccessChainKHR))) { |
| const auto checker = [&inst]( |
| ValidationState_t& vstate, |
| const Instruction* check_inst) -> spv_result_t { |
| bool fail = false; |
| if (check_inst->opcode() == spv::Op::OpUntypedVariableKHR) { |
| if (check_inst->operands().size() > 3) { |
| const auto type = |
| vstate.FindDef(check_inst->GetOperandAs<uint32_t>(3)); |
| fail = IsBlockArray(vstate, type); |
| } |
| } else if (check_inst->opcode() == spv::Op::OpVariable) { |
| const auto res_type = vstate.FindDef(check_inst->type_id()); |
| const auto pointee_type = |
| vstate.FindDef(res_type->GetOperandAs<uint32_t>(2)); |
| fail = IsBlockArray(vstate, pointee_type); |
| } |
| |
| if (fail) { |
| return vstate.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to an array of Block- or " |
| "BufferBlock-decorated structs"; |
| } |
| return SPV_SUCCESS; |
| }; |
| |
| if (auto error = TraceUnmodifiedVariablePointers(_, inst, checker)) { |
| return error; |
| } |
| } |
| } else { |
| const auto pointee_type = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); |
| if (IsBlockArray(_, pointee_type)) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to an array of Block- or " |
| "BufferBlock-decorated structs"; |
| } else if (_.ContainsType( |
| pointee_type->id(), |
| [](const Instruction* type_inst) { |
| return type_inst->opcode() == spv::Op::OpTypeMatrix; |
| }, |
| /* traverse_all_types = */ false)) { |
| return _.diag(SPV_ERROR_INVALID_DATA, inst) |
| << "Variable pointer must not point to an object that is or " |
| "contains a matrix"; |
| } else if (_.IsFloatScalarOrVectorType(pointee_type->id())) { |
| // Pointing to a column or component in a column is trickier to detect. |
| // Trace backwards and check encountered access chains to determine if |
| // this pointer is pointing into a matrix. |
| if (auto error = |
| TraceVariablePointers(_, inst, CheckMatrixElementTyped)) { |
| return error; |
| } |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| } // namespace |
| |
| spv_result_t ValidateLogicalPointers(ValidationState_t& _) { |
| // Only the following addressing models have logical pointers. |
| if (_.addressing_model() != spv::AddressingModel::Logical && |
| _.addressing_model() != spv::AddressingModel::PhysicalStorageBuffer64) { |
| return SPV_SUCCESS; |
| } |
| |
| if (_.options()->relax_logical_pointer) { |
| return SPV_SUCCESS; |
| } |
| |
| // Cache all variable pointers |
| std::unordered_map<uint32_t, bool> variable_pointers; |
| for (auto& inst : _.ordered_instructions()) { |
| if (!IsLogicalPointer(_, &inst)) { |
| continue; |
| } |
| |
| IsVariablePointer(_, variable_pointers, &inst); |
| } |
| |
| for (auto& inst : _.ordered_instructions()) { |
| if (auto error = ValidateLogicalPointerOperands(_, &inst)) { |
| return error; |
| } |
| if (auto error = ValidateLogicalPointerReturns(_, &inst)) { |
| return error; |
| } |
| if (auto error = ValidateVariablePointers(_, variable_pointers, &inst)) { |
| return error; |
| } |
| } |
| |
| return SPV_SUCCESS; |
| } |
| |
| } // namespace val |
| } // namespace spvtools |