Treat access chain indexes as signed in SROA (#2776)

Fixes #2768

* In scalar replacement, interpret access chain indexes as signed counts
* Use Constant::GetSignExtendedValue and Constant::GetZeroExtendedValue
where appropriate
* new tests

diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index b5875c9..5c1468b 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -291,7 +291,7 @@
   }
 }
 
-const Constant* ConstantManager::GetConstantFromInst(Instruction* inst) {
+const Constant* ConstantManager::GetConstantFromInst(const Instruction* inst) {
   std::vector<uint32_t> literal_words_or_ids;
 
   // Collect the constant defining literals or component ids.
diff --git a/source/opt/constants.h b/source/opt/constants.h
index 7b9f248..93d0847 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -522,7 +522,7 @@
   // Gets or creates a Constant instance to hold the constant value of the given
   // instruction. It returns a pointer to a Constant instance or nullptr if it
   // could not create the constant.
-  const Constant* GetConstantFromInst(Instruction* inst);
+  const Constant* GetConstantFromInst(const Instruction* inst);
 
   // Gets or creates a constant defining instruction for the given Constant |c|.
   // If |c| had already been defined, it returns a pointer to the existing
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 9ae1ae8..7f352df 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -232,8 +232,12 @@
   // indexes) or a direct use of the replacement variable.
   uint32_t indexId = chain->GetSingleWordInOperand(1u);
   const Instruction* index = get_def_use_mgr()->GetDef(indexId);
-  uint64_t indexValue = GetConstantInteger(index);
-  if (indexValue >= replacements.size()) {
+  int64_t indexValue = context()
+                           ->get_constant_mgr()
+                           ->GetConstantFromInst(index)
+                           ->GetSignExtendedValue();
+  if (indexValue < 0 ||
+      indexValue >= static_cast<int64_t>(replacements.size())) {
     // Out of bounds access, this is illegal IR.  Notice that OpAccessChain
     // indexing is 0-based, so we should also reject index == size-of-array.
     return false;
@@ -269,7 +273,7 @@
     Instruction* inst, std::vector<Instruction*>* replacements) {
   Instruction* type = GetStorageType(inst);
 
-  std::unique_ptr<std::unordered_set<uint64_t>> components_used =
+  std::unique_ptr<std::unordered_set<int64_t>> components_used =
       GetUsedComponents(inst);
 
   uint32_t elem = 0;
@@ -467,35 +471,15 @@
   }
 }
 
-uint64_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
-  assert(op.words.size() <= 2);
-  uint64_t len = 0;
-  for (uint32_t i = 0; i != op.words.size(); ++i) {
-    len |= (op.words[i] << (32 * i));
-  }
-  return len;
-}
-
-uint64_t ScalarReplacementPass::GetConstantInteger(
-    const Instruction* constant) const {
-  assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
-         SpvOpTypeInt);
-  assert(constant->opcode() == SpvOpConstant ||
-         constant->opcode() == SpvOpConstantNull);
-  if (constant->opcode() == SpvOpConstantNull) {
-    return 0;
-  }
-
-  const Operand& op = constant->GetInOperand(0u);
-  return GetIntegerLiteral(op);
-}
-
 uint64_t ScalarReplacementPass::GetArrayLength(
     const Instruction* arrayType) const {
   assert(arrayType->opcode() == SpvOpTypeArray);
   const Instruction* length =
       get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
-  return GetConstantInteger(length);
+  return context()
+      ->get_constant_mgr()
+      ->GetConstantFromInst(length)
+      ->GetZeroExtendedValue();
 }
 
 uint64_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
@@ -734,10 +718,10 @@
   return length > max_num_elements_;
 }
 
-std::unique_ptr<std::unordered_set<uint64_t>>
+std::unique_ptr<std::unordered_set<int64_t>>
 ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
-  std::unique_ptr<std::unordered_set<uint64_t>> result(
-      new std::unordered_set<uint64_t>());
+  std::unique_ptr<std::unordered_set<int64_t>> result(
+      new std::unordered_set<int64_t>());
 
   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
 
@@ -775,18 +759,8 @@
         const analysis::Constant* index_const =
             const_mgr->FindDeclaredConstant(index_id);
         if (index_const) {
-          const analysis::Integer* index_type =
-              index_const->type()->AsInteger();
-          assert(index_type);
-          if (index_type->width() == 32) {
-            result->insert(index_const->GetU32());
-            return true;
-          } else if (index_type->width() == 64) {
-            result->insert(index_const->GetU64());
-            return true;
-          }
-          result.reset(nullptr);
-          return false;
+          result->insert(index_const->GetSignExtendedValue());
+          return true;
         } else {
           // Could be any element.  Assuming all are used.
           result.reset(nullptr);
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index 3a17045..5b51981 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -158,14 +158,6 @@
   bool CreateReplacementVariables(Instruction* inst,
                                   std::vector<Instruction*>* replacements);
 
-  // Returns the value of an OpConstant of integer type.
-  //
-  // |constant| must use two or fewer words to generate the value.
-  uint64_t GetConstantInteger(const Instruction* constant) const;
-
-  // Returns the integer literal for |op|.
-  uint64_t GetIntegerLiteral(const Operand& op) const;
-
   // Returns the array length for |arrayInst|.
   uint64_t GetArrayLength(const Instruction* arrayInst) const;
 
@@ -216,7 +208,7 @@
   // Returns a set containing the which components of the result of |inst| are
   // potentially used.  If the return value is |nullptr|, then every components
   // is possibly used.
-  std::unique_ptr<std::unordered_set<uint64_t>> GetUsedComponents(
+  std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents(
       Instruction* inst);
 
   // Returns an instruction defining a null constant with type |type_id|.  If
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index 2ed7b5a..04721c2 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -1701,6 +1701,67 @@
   EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
 }
 
+TEST_F(ScalarReplacementTest, CharIndex) {
+  const std::string text = R"(
+; CHECK: [[int:%\w+]] = OpTypeInt 32 0
+; CHECK: [[ptr:%\w+]] = OpTypePointer Function [[int]]
+; CHECK: OpVariable [[ptr]] Function
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 0
+%char_1 = OpConstant %char 1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<ScalarReplacementPass>(text, true, 0);
+}
+
+TEST_F(ScalarReplacementTest, OutOfBoundsOpAccessChainNegative) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Int8
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int_1024 = OpConstant %int 1024
+%char = OpTypeInt 8 1
+%char_n1 = OpConstant %char -1
+%array = OpTypeArray %int %int_1024
+%ptr_func_array = OpTypePointer Function %array
+%ptr_func_int = OpTypePointer Function %int
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+%var = OpVariable %ptr_func_array Function
+%gep = OpAccessChain %ptr_func_int %var %char_n1
+OpStore %gep %int_1024
+OpReturn
+OpFunctionEnd
+)";
+
+  auto result =
+      SinglePassRunAndDisassemble<ScalarReplacementPass>(text, true, true, 0);
+  EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools