Avoid replacing access chain with OOB access (#4819)
An access chain could have a constant index that is an out of bounds
access. This is valid spir-v, even if it can cause problems at runtime.
However, it is not valid to have an OpCompositeExtract with an out of
bounds access. This means we have to stop local-access-chain-convert
from making that change.
Fixes #4605
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index da4cac3..9491798 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -237,7 +237,8 @@
}
// Rule out variables with nested access chains
// TODO(): Convert nested access chains
- if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
+ bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
+ if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
kAccessChainPtrIdInIdx) != varId) {
seen_non_target_vars_.insert(varId);
seen_target_vars_.erase(varId);
@@ -249,6 +250,12 @@
seen_target_vars_.erase(varId);
break;
}
+
+ if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
+ seen_non_target_vars_.insert(varId);
+ seen_target_vars_.erase(varId);
+ break;
+ }
} break;
default:
break;
@@ -446,5 +453,42 @@
});
}
+bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
+ const Instruction* access_chain_inst) {
+ assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
+
+ analysis::TypeManager* type_mgr = context()->get_type_mgr();
+ analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+ auto constants = const_mgr->GetOperandConstants(access_chain_inst);
+ uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
+ Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
+ const analysis::Pointer* base_pointer_type =
+ type_mgr->GetType(base_pointer->type_id())->AsPointer();
+ assert(base_pointer_type != nullptr &&
+ "The base of the access chain is not a pointer.");
+ const analysis::Type* current_type = base_pointer_type->pointee_type();
+ for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
+ if (IsIndexOutOfBounds(constants[i], current_type)) {
+ return true;
+ }
+
+ uint32_t index =
+ (constants[i]
+ ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
+ : 0);
+ current_type = type_mgr->GetMemberType(current_type, {index});
+ }
+
+ return false;
+}
+
+bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
+ const analysis::Constant* index, const analysis::Type* type) const {
+ if (index == nullptr) {
+ return false;
+ }
+ return index->GetZeroExtendedValue() >= type->NumberOfComponents();
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
index 8548e16..eabf864 100644
--- a/source/opt/local_access_chain_convert_pass.h
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -111,6 +111,17 @@
// Returns a status to indicate success or failure, and change or no change.
Status ConvertLocalAccessChains(Function* func);
+ // Returns true one of the indexes in the |access_chain_inst| is definitly out
+ // of bounds. If the size of the type or the value of the index is unknown,
+ // then it will be considered in-bounds.
+ bool AnyIndexIsOutOfBounds(const Instruction* access_chain_inst);
+
+ // Returns true if getting element |index| from |type| would be out-of-bounds.
+ // If |index| is nullptr or the size of the type are unknown, then it will be
+ // considered in-bounds.
+ bool IsIndexOutOfBounds(const analysis::Constant* index,
+ const analysis::Type* type) const;
+
// Initialize extensions allowlist
void InitExtensions();
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index ebbdc36..056aceb 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -16,6 +16,7 @@
#include <algorithm>
#include <cassert>
+#include <climits>
#include <cstdint>
#include <sstream>
#include <string>
@@ -246,6 +247,35 @@
return ComputeHashValue(0, &seen);
}
+uint64_t Type::NumberOfComponents() const {
+ switch (kind()) {
+ case kVector:
+ return AsVector()->element_count();
+ case kMatrix:
+ return AsMatrix()->element_count();
+ case kArray: {
+ Array::LengthInfo length_info = AsArray()->length_info();
+ if (length_info.words[0] != Array::LengthInfo::kConstant) {
+ return UINT64_MAX;
+ }
+ assert(length_info.words.size() <= 3 &&
+ "The size of the array could not fit size_t.");
+ uint64_t length = 0;
+ length |= length_info.words[1];
+ if (length_info.words.size() > 2) {
+ length |= static_cast<uint64_t>(length_info.words[2]) << 32;
+ }
+ return length;
+ }
+ case kRuntimeArray:
+ return UINT64_MAX;
+ case kStruct:
+ return AsStruct()->element_types().size();
+ default:
+ return 0;
+ }
+}
+
bool Integer::IsSameImpl(const Type* that, IsSameCache*) const {
const Integer* it = that->AsInteger();
return it && width_ == it->width_ && signed_ == it->signed_ &&
diff --git a/source/opt/types.h b/source/opt/types.h
index f5a4a6b..a92669e 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -160,6 +160,10 @@
size_t ComputeHashValue(size_t hash, SeenTypes* seen) const;
+ // Returns the number of components in a composite type. Returns 0 for a
+ // non-composite type.
+ uint64_t NumberOfComponents() const;
+
// A bunch of methods for casting this type to a given type. Returns this if the
// cast can be done, nullptr otherwise.
// clang-format off
diff --git a/test/opt/local_access_chain_convert_test.cpp b/test/opt/local_access_chain_convert_test.cpp
index 2b3231c..6f5021c 100644
--- a/test/opt/local_access_chain_convert_test.cpp
+++ b/test/opt/local_access_chain_convert_test.cpp
@@ -1252,6 +1252,70 @@
true);
}
+TEST_F(LocalAccessChainConvertTest, OutOfBoundsAccess) {
+ // The access chain indexes element 12 in an array of size 10. Nothing should
+ // be done.
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%int_10 = OpConstant %int 10
+%_arr_int_int_10 = OpTypeArray %int %int_10
+%_ptr_Function_int = OpTypePointer Function %int
+%int_12 = OpConstant %int 12
+%_ptr_Output_int = OpTypePointer Output %int
+%3 = OpVariable %_ptr_Output_int Output
+%_ptr_Function__arr_int_int_10 = OpTypePointer Function %_arr_int_int_10
+%2 = OpFunction %void None %5
+%13 = OpLabel
+%14 = OpVariable %_ptr_Function__arr_int_int_10 Function
+%15 = OpAccessChain %_ptr_Function_int %14 %int_12
+%16 = OpLoad %int %15
+OpStore %3 %16
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<LocalAccessChainConvertPass>(assembly, assembly, false,
+ true);
+}
+
+TEST_F(LocalAccessChainConvertTest, OutOfBoundsAccessAtBoundary) {
+ // The access chain indexes element 10 in an array of size 10. Nothing should
+ // be done.
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main" %3
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%int_10 = OpConstant %int 10
+%_arr_int_int_10 = OpTypeArray %int %int_10
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Output_int = OpTypePointer Output %int
+%3 = OpVariable %_ptr_Output_int Output
+%_ptr_Function__arr_int_int_10 = OpTypePointer Function %_arr_int_int_10
+%2 = OpFunction %void None %5
+%12 = OpLabel
+%13 = OpVariable %_ptr_Function__arr_int_int_10 Function
+%14 = OpAccessChain %_ptr_Function_int %13 %int_10
+%15 = OpLoad %int %14
+OpStore %3 %15
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<LocalAccessChainConvertPass>(assembly, assembly, false,
+ true);
+}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// Assorted vector and matrix types
diff --git a/test/opt/types_test.cpp b/test/opt/types_test.cpp
index 82e4040..552ad97 100644
--- a/test/opt/types_test.cpp
+++ b/test/opt/types_test.cpp
@@ -266,6 +266,67 @@
}
}
+TEST(Types, TestNumberOfComponentsOnArrays) {
+ Float f32(32);
+ EXPECT_EQ(f32.NumberOfComponents(), 0);
+
+ Array array_size_42(
+ &f32, Array::LengthInfo{99u, {Array::LengthInfo::kConstant, 42u}});
+ EXPECT_EQ(array_size_42.NumberOfComponents(), 42);
+
+ Array array_size_0xDEADBEEF00C0FFEE(
+ &f32, Array::LengthInfo{
+ 99u, {Array::LengthInfo::kConstant, 0xC0FFEE, 0xDEADBEEF}});
+ EXPECT_EQ(array_size_0xDEADBEEF00C0FFEE.NumberOfComponents(),
+ 0xDEADBEEF00C0FFEEull);
+
+ Array array_size_unknown(
+ &f32,
+ Array::LengthInfo{99u, {Array::LengthInfo::kConstantWithSpecId, 10}});
+ EXPECT_EQ(array_size_unknown.NumberOfComponents(), UINT64_MAX);
+
+ RuntimeArray runtime_array(&f32);
+ EXPECT_EQ(runtime_array.NumberOfComponents(), UINT64_MAX);
+}
+
+TEST(Types, TestNumberOfComponentsOnVectors) {
+ Float f32(32);
+ EXPECT_EQ(f32.NumberOfComponents(), 0);
+
+ for (uint32_t vector_size = 1; vector_size < 4; ++vector_size) {
+ Vector vector(&f32, vector_size);
+ EXPECT_EQ(vector.NumberOfComponents(), vector_size);
+ }
+}
+
+TEST(Types, TestNumberOfComponentsOnMatrices) {
+ Float f32(32);
+ Vector vector(&f32, 2);
+
+ for (uint32_t number_of_columns = 1; number_of_columns < 4;
+ ++number_of_columns) {
+ Matrix matrix(&vector, number_of_columns);
+ EXPECT_EQ(matrix.NumberOfComponents(), number_of_columns);
+ }
+}
+
+TEST(Types, TestNumberOfComponentsOnStructs) {
+ Float f32(32);
+ Vector vector(&f32, 2);
+
+ Struct empty_struct({});
+ EXPECT_EQ(empty_struct.NumberOfComponents(), 0);
+
+ Struct struct_f32({&f32});
+ EXPECT_EQ(struct_f32.NumberOfComponents(), 1);
+
+ Struct struct_f32_vec({&f32, &vector});
+ EXPECT_EQ(struct_f32_vec.NumberOfComponents(), 2);
+
+ Struct struct_100xf32(std::vector<const Type*>(100, &f32));
+ EXPECT_EQ(struct_100xf32.NumberOfComponents(), 100);
+}
+
TEST(Types, IntSignedness) {
std::vector<bool> signednesses = {true, false, false, true};
std::vector<std::unique_ptr<Integer>> types;