Treat access chain indexes as signed in SROA (#2776)
Fixes #2768
* In scalar replacement, interpret access chain indexes as signed counts
* Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue
where appropriate
* new tests
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index b5875c9..5c1468b 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -291,7 +291,7 @@
}
}
-const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
+const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
std::vector<uint32_t> literal_words_or_ids;
// Collect the constant defining literals or component ids.
diff --git a/source/opt/constants.h b/source/opt/constants.h
index 7b9f248..93d0847 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -522,7 +522,7 @@
// Gets or creates a Constant instance to hold the constant value of the given
// instruction. It returns a pointer to a Constant instance or nullptr if it
// could not create the constant.
- const Constant* GetConstantFromInst(Instruction* inst);
+ const Constant* GetConstantFromInst(const Instruction* inst);
// Gets or creates a constant defining instruction for the given Constant |c|.
// If |c| had already been defined, it returns a pointer to the existing
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 9ae1ae8..7f352df 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -232,8 +232,12 @@
// indexes) or a direct use of the replacement variable.
uint32_t indexId = chain->GetSingleWordInOperand(1u);
const Instruction* index = get_def_use_mgr()->GetDef(indexId);
- uint64_t indexValue = GetConstantInteger(index);
- if (indexValue >= replacements.size()) {
+ int64_t indexValue = context()
+ ->get_constant_mgr()
+ ->GetConstantFromInst(index)
+ ->GetSignExtendedValue();
+ if (indexValue < 0 ||
+ indexValue >= static_cast<int64_t>(replacements.size())) {
// Out of bounds access, this is illegal IR. Notice that OpAccessChain
// indexing is 0-based, so we should also reject index == size-of-array.
return false;
@@ -269,7 +273,7 @@
Instruction* inst, std::vector<Instruction*>* replacements) {
Instruction* type = GetStorageType(inst);
- std::unique_ptr<std::unordered_set<uint64_t>> components_used =
+ std::unique_ptr<std::unordered_set<int64_t>> components_used =
GetUsedComponents(inst);
uint32_t elem = 0;
@@ -467,35 +471,15 @@
}
}
-uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
- assert(op.words.size() <= 2);
- uint64_t len = 0;
- for (uint32_t i = 0; i != op.words.size(); ++i) {
- len |= (op.words[i] << (32 * i));
- }
- return len;
-}
-
-uint64_t ScalarReplacementPass::GetConstantInteger(
- const Instruction* constant) const {
- assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
- SpvOpTypeInt);
- assert(constant->opcode() == SpvOpConstant ||
- constant->opcode() == SpvOpConstantNull);
- if (constant->opcode() == SpvOpConstantNull) {
- return 0;
- }
-
- const Operand& op = constant->GetInOperand(0u);
- return GetIntegerLiteral(op);
-}
-
uint64_t ScalarReplacementPass::GetArrayLength(
const Instruction* arrayType) const {
assert(arrayType->opcode() == SpvOpTypeArray);
const Instruction* length =
get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
- return GetConstantInteger(length);
+ return context()
+ ->get_constant_mgr()
+ ->GetConstantFromInst(length)
+ ->GetZeroExtendedValue();
}
uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
@@ -734,10 +718,10 @@
return length > max_num_elements_;
}
-std::unique_ptr<std::unordered_set<uint64_t>>
+std::unique_ptr<std::unordered_set<int64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
- std::unique_ptr<std::unordered_set<uint64_t>> result(
- new std::unordered_set<uint64_t>());
+ std::unique_ptr<std::unordered_set<int64_t>> result(
+ new std::unordered_set<int64_t>());
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
@@ -775,18 +759,8 @@
const analysis::Constant* index_const =
const_mgr->FindDeclaredConstant(index_id);
if (index_const) {
- const analysis::Integer* index_type =
- index_const->type()->AsInteger();
- assert(index_type);
- if (index_type->width() == 32) {
- result->insert(index_const->GetU32());
- return true;
- } else if (index_type->width() == 64) {
- result->insert(index_const->GetU64());
- return true;
- }
- result.reset(nullptr);
- return false;
+ result->insert(index_const->GetSignExtendedValue());
+ return true;
} else {
// Could be any element. Assuming all are used.
result.reset(nullptr);
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index 3a17045..5b51981 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -158,14 +158,6 @@
bool CreateReplacementVariables(Instruction* inst,
std::vector<Instruction*>* replacements);
- // Returns the value of an OpConstant of integer type.
- //
- // |constant| must use two or fewer words to generate the value.
- uint64_t GetConstantInteger(const Instruction* constant) const;
-
- // Returns the integer literal for |op|.
- uint64_t GetIntegerLiteral(const Operand& op) const;
-
// Returns the array length for |arrayInst|.
uint64_t GetArrayLength(const Instruction* arrayInst) const;
@@ -216,7 +208,7 @@
// Returns a set containing the which components of the result of |inst| are
// potentially used. If the return value is |nullptr|, then every components
// is possibly used.
- std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
+ std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
Instruction* inst);
// Returns an instruction defining a null constant with type |type_id|. If
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index 2ed7b5a..04721c2 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -1701,6 +1701,67 @@
EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
}
+TEST_F(ScalarReplacementTest, CharIndex) {
+ const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
+; CHECK: OpVariable [[ptr]] Function
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 0
+%char_1 = OpConstant %char 1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
+}
+
+TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 1
+%char_n1 = OpConstant %char -1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_n1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+ auto result =
+ SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
+ EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools