Extra restrictions for accesses of block arrays

* If a PtrAccessChain interprets its base as a Block- or
  BufferBlock-decorated struct, the Element operand (if constant) must
  be zero
* Untyped access chains now check that arrays of blocks are not
  reinterpreted as a different type
* Add basic Element operand checks (Element must be an integer) for
  pointer access chains
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index ae7de40..4d98ac7 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1555,6 +1555,55 @@
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Base type must be a non-pointer type";
}
+
+  // Matches an OpTypeStruct decorated with Block or BufferBlock.
+  const auto ContainsBlock = [&_](const Instruction* type_inst) {
+    if (type_inst->opcode() != spv::Op::OpTypeStruct) {
+      return false;
+    }
+    const auto struct_id = type_inst->id();
+    return _.HasDecoration(struct_id, spv::Decoration::Block) ||
+           _.HasDecoration(struct_id, spv::Decoration::BufferBlock);
+  };
+
+ const bool base_type_block_array =
+ base_type->opcode() == spv::Op::OpTypeArray &&
+ _.ContainsType(base_type->id(), ContainsBlock,
+ /* traverse_all_types = */ false);
+
+ const auto base_index = untyped_pointer ? 3 : 2;
+ const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
+ auto base = _.FindDef(base_id);
+ while (base->opcode() == spv::Op::OpCopyObject) {
+ base = _.FindDef(base->GetOperandAs<uint32_t>(2));
+ }
+ const Instruction* base_data_type = nullptr;
+ if (base->opcode() == spv::Op::OpVariable) {
+ const auto ptr_type = _.FindDef(base->type_id());
+ base_data_type = _.FindDef(ptr_type->GetOperandAs<uint32_t>(2));
+ } else if (base->opcode() == spv::Op::OpUntypedVariableKHR) {
+ if (base->operands().size() > 3) {
+ base_data_type = _.FindDef(base->GetOperandAs<uint32_t>(3));
+ }
+ }
+
+ if (base_data_type) {
+ const bool base_block_array =
+ base_data_type->opcode() == spv::Op::OpTypeArray &&
+ _.ContainsType(base_data_type->id(), ContainsBlock,
+ /* traverse_all_types = */ false);
+
+ if (base_type_block_array != base_block_array) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Both Base Type and Base must be Block or BufferBlock arrays "
+ "or neither can be";
+ } else if (base_type_block_array && base_block_array &&
+ base_type->id() != base_data_type->id()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "If Base or Base Type is a Block or BufferBlock array, the "
+ "other must also be the same array";
+ }
+ }
}
// Base must be a pointer, pointing to the base of a composite object.
@@ -1845,14 +1894,34 @@
const bool untyped_pointer = spvOpcodeGeneratesUntypedPointer(inst->opcode());
- const auto base_id = inst->GetOperandAs<uint32_t>(2);
- const auto base = _.FindDef(base_id);
- const auto base_type = untyped_pointer
- ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
- : _.FindDef(base->type_id());
+ const auto base_idx = untyped_pointer ? 3 : 2;
+ const auto base = _.FindDef(inst->GetOperandAs<uint32_t>(base_idx));
+ const auto base_type = _.FindDef(base->type_id());
const auto base_type_storage_class =
base_type->GetOperandAs<spv::StorageClass>(1);
+ const auto element_idx = untyped_pointer ? 4 : 3;
+ const auto element = _.FindDef(inst->GetOperandAs<uint32_t>(element_idx));
+ const auto element_type = _.FindDef(element->type_id());
+ if (!element_type || element_type->opcode() != spv::Op::OpTypeInt) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Element must be an integer";
+ }
+ uint64_t element_val = 0;
+ if (_.EvalConstantValUint64(element->id(), &element_val)) {
+ if (element_val != 0) {
+ const auto interp_type = untyped_pointer
+ ? _.FindDef(inst->GetOperandAs<uint32_t>(2))
+ : _.FindDef(base_type->GetOperandAs<uint32_t>(2));
+ if (interp_type->opcode() == spv::Op::OpTypeStruct &&
+ (_.HasDecoration(interp_type->id(), spv::Decoration::Block) ||
+ _.HasDecoration(interp_type->id(), spv::Decoration::BufferBlock))) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Element must be 0 if the interpretation type is a Block- or "
+ "BufferBlock-decorated structure";
+ }
+ }
+ }
+
if (_.HasCapability(spv::Capability::Shader) &&
(base_type_storage_class == spv::StorageClass::Uniform ||
base_type_storage_class == spv::StorageClass::StorageBuffer ||
diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp
index bb0ec03..8640865 100644
--- a/test/opt/eliminate_dead_member_test.cpp
+++ b/test/opt/eliminate_dead_member_test.cpp
@@ -958,7 +958,7 @@
; CHECK: OpMemberDecorate %type__Globals 1 Offset 16
; CHECK: %type__Globals = OpTypeStruct %float %float
; CHECK: [[ac:%\w+]] = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0
-; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_1 %uint_0
+; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_0
; CHECK: OpPtrAccessChain %_ptr_Uniform_float [[ac]] %uint_0 %uint_1
OpCapability Shader
OpCapability VariablePointersStorageBuffer
@@ -995,7 +995,7 @@
%main = OpFunction %void None %14
%16 = OpLabel
%17 = OpAccessChain %_ptr_Uniform_type__Globals %_Globals %uint_0
- %18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_1 %uint_0
+ %18 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_0
%19 = OpPtrAccessChain %_ptr_Uniform_float %17 %uint_0 %uint_2
OpReturn
OpFunctionEnd
diff --git a/test/val/val_decoration_test.cpp b/test/val/val_decoration_test.cpp
index 7a19df6..6dad5fb 100644
--- a/test/val/val_decoration_test.cpp
+++ b/test/val/val_decoration_test.cpp
@@ -10304,6 +10304,7 @@
OpMemberDecorate %struct 1 Offset 4
)" + set + R"(OpMemberDecorate %test_type 0 Offset 0
OpMemberDecorate %test_type 1 Offset 1
+OpDecorate %ptr ArrayStride 16
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_0 = OpConstant %int 0
@@ -10355,6 +10356,7 @@
OpMemberDecorate %struct 0 Offset 0
OpMemberDecorate %struct 1 Offset 4
)" + set + R"(OpDecorate %test_type ArrayStride 4
+OpDecorate %ptr ArrayStride 16
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_0 = OpConstant %int 0
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index 15b1663..d8db7c8 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -6148,6 +6148,7 @@
const bool ptr = opcode == "OpUntypedPtrAccessChainKHR" ||
opcode == "OpUntypedInBoundsPtrAccessChainKHR";
const std::string extra_param = ptr ? "%int_0" : "";
+ const std::string deco = ptr ? "OpDecorate %ptr_ssbo ArrayStride 4" : "";
const std::string spirv = R"(
OpCapability Shader
@@ -6158,6 +6159,7 @@
OpExtension "SPV_KHR_untyped_pointers"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
+)" + deco + R"(
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_0 = OpConstant %int 0
@@ -6183,6 +6185,7 @@
const bool ptr = opcode == "OpUntypedPtrAccessChainKHR" ||
opcode == "OpUntypedInBoundsPtrAccessChainKHR";
const std::string extra_param = ptr ? "%int_0" : "";
+ const std::string deco = ptr ? "OpDecorate %ptr ArrayStride 4" : "";
const std::string spirv = R"(
OpCapability Shader
@@ -6193,6 +6196,7 @@
OpExtension "SPV_KHR_untyped_pointers"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
+)" + deco + R"(
%void = OpTypeVoid
%int = OpTypeInt 32 0
%int_0 = OpConstant %int 0