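Add extra restrictions for accesses to Block/BufferBlock arrays * If a PtrAccessChain's base is interpreted as a Block- or BufferBlock-decorated structure, its Element operand (when it is a constant) must be zero * Untyped access chains check that block arrays are not reinterpreted: Base Type and Base must either both be the same Block/BufferBlock array or neither may be one * Add basic Element operand checks for pointer access chains (Element must be an integer)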
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index ae7de40..4d98ac7 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp
@@ -1555,6 +1555,55 @@ return _.diag(SPV_ERROR_INVALID_ID, inst) << "Base type must be a non-pointer type"; } + + const auto ContainsBlock = [&_](const Instruction* type_inst) { + if (type_inst->opcode() == spv::Op::OpTypeStruct) { + if (_.HasDecoration(type_inst->id(), spv::Decoration::Block) || + _.HasDecoration(type_inst->id(), spv::Decoration::BufferBlock)) { + return true; + } + } + return false; + }; + + const bool base_type_block_array = + base_type->opcode() == spv::Op::OpTypeArray && + _.ContainsType(base_type->id(), ContainsBlock, + /* traverse_all_types = */ false); + + const auto base_index = untyped_pointer ? 3 : 2; + const auto base_id = inst->GetOperandAs<uint32_t>(base_index); + auto base = _.FindDef(base_id); + while (base->opcode() == spv::Op::OpCopyObject) { + base = _.FindDef(base->GetOperandAs<uint32_t>(2)); + } + const Instruction* base_data_type = nullptr; + if (base->opcode() == spv::Op::OpVariable) { + const auto ptr_type = _.FindDef(base->type_id()); + base_data_type = _.FindDef(ptr_type->GetOperandAs<uint32_t>(2)); + } else if (base->opcode() == spv::Op::OpUntypedVariableKHR) { + if (base->operands().size() > 3) { + base_data_type = _.FindDef(base->GetOperandAs<uint32_t>(3)); + } + } + + if (base_data_type) { + const bool base_block_array = + base_data_type->opcode() == spv::Op::OpTypeArray && + _.ContainsType(base_data_type->id(), ContainsBlock, + /* traverse_all_types = */ false); + + if (base_type_block_array != base_block_array) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "Both Base Type and Base must be Block or BufferBlock arrays " + "or neither can be"; + } else if (base_type_block_array && base_block_array && + base_type->id() != base_data_type->id()) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "If Base or Base Type is a Block or BufferBlock array, the " + "other must also be the same array"; + } + } } // Base must be a pointer, pointing to the base of a composite object. 
@@ -1845,14 +1894,34 @@ const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode()); - const auto base_id = inst->GetOperandAs<uint32_t>(2); - const auto base = _.FindDef(base_id); - const auto base_type = untyped_pointer - ? _.FindDef(inst->GetOperandAs<uint32_t>(2)) - : _.FindDef(base->type_id()); + const auto base_idx = untyped_pointer ? 3 : 2; + const auto base = _.FindDef(inst->GetOperandAs<uint32_t>(base_idx)); + const auto base_type = _.FindDef(base->type_id()); const auto base_type_storage_class = base_type->GetOperandAs<spv::StorageClass>(1); + const auto element_idx = untyped_pointer ? 4 : 3; + const auto element = _.FindDef(inst->GetOperandAs<uint32_t>(element_idx)); + const auto element_type = _.FindDef(element->type_id()); + if (!element_type || element_type->opcode() != spv::Op::OpTypeInt) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Element must be an integer"; + } + uint64_t element_val = 0; + if (_.EvalConstantValUint64(element->id(), &element_val)) { + if (element_val != 0) { + const auto interp_type = untyped_pointer + ? _.FindDef(inst->GetOperandAs<uint32_t>(2)) + : _.FindDef(base_type->GetOperandAs<uint32_t>(2)); + if (interp_type->opcode() == spv::Op::OpTypeStruct && + (_.HasDecoration(interp_type->id(), spv::Decoration::Block) || + _.HasDecoration(interp_type->id(), spv::Decoration::BufferBlock))) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "Element must be 0 if the interpretation type is a Block- or " + "BufferBlock-decorated structure"; + } + } + } + if (_.HasCapability(spv::Capability::Shader) && (base_type_storage_class == spv::StorageClass::Uniform || base_type_storage_class == spv::StorageClass::StorageBuffer ||
diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp index bb0ec03..8640865 100644 --- a/test/opt/eliminate_dead_member_test.cpp +++ b/test/opt/eliminate_dead_member_test.cpp
@@ -958,7 +958,7 @@ ; CHECK: OpMemberDecorate %type__Globals 1 Offset 16 ; CHECK: %type__Globals = OpTypeStruct %float %float ; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0 -; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_1 %uint_0 +; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_0 ; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_1 OpCapability Shader OpCapability VariablePointersStorageBuffer @@ -995,7 +995,7 @@ %main = OpFunction %void None %14 %16 = OpLabel %17 = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0 - %18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_1 %uint_0 + %18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_0 %19 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_2 OpReturn OpFunctionEnd
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp index 7a19df6..6dad5fb 100644 --- a/test/val/val_decoration_test.cpp +++ b/test/val/val_decoration_test.cpp
@@ -10304,6 +10304,7 @@ OpMemberDecorate %struct 1 Offset 4 )" + set + R"(OpMemberDecorate %test_type 0 Offset 0 OpMemberDecorate %test_type 1 Offset 1 +OpDecorate %ptr ArrayStride 16 %void = OpTypeVoid %int = OpTypeInt 32 0 %int_0 = OpConstant %int 0 @@ -10355,6 +10356,7 @@ OpMemberDecorate %struct 0 Offset 0 OpMemberDecorate %struct 1 Offset 4 )" + set + R"(OpDecorate %test_type ArrayStride 4 +OpDecorate %ptr ArrayStride 16 %void = OpTypeVoid %int = OpTypeInt 32 0 %int_0 = OpConstant %int 0
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index 15b1663..d8db7c8 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp
@@ -6148,6 +6148,7 @@ const bool ptr = opcode == "OpUntypedPtrAccessChainKHR" || opcode == "OpUntypedInBoundsPtrAccessChainKHR"; const std::string extra_param = ptr ? "%int_0" : ""; + const std::string deco = ptr ? "OpDecorate %ptr_ssbo ArrayStride 4" : ""; const std::string spirv = R"( OpCapability Shader @@ -6158,6 +6159,7 @@ OpExtension "SPV_KHR_untyped_pointers" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" +)" + deco + R"( %void = OpTypeVoid %int = OpTypeInt 32 0 %int_0 = OpConstant %int 0 @@ -6183,6 +6185,7 @@ const bool ptr = opcode == "OpUntypedPtrAccessChainKHR" || opcode == "OpUntypedInBoundsPtrAccessChainKHR"; const std::string extra_param = ptr ? "%int_0" : ""; + const std::string deco = ptr ? "OpDecorate %ptr ArrayStride 4" : ""; const std::string spirv = R"( OpCapability Shader @@ -6193,6 +6196,7 @@ OpExtension "SPV_KHR_untyped_pointers" OpMemoryModel Logical GLSL450 OpEntryPoint GLCompute %main "main" +)" + deco + R"( %void = OpTypeVoid %int = OpTypeInt 32 0 %int_0 = OpConstant %int 0