Even more id overflow in sroa (#2806)

Now we need to handle id overflow when we overflow while replacing uses of the variable.  While looking at this code, I noticed an error in the way we handle access chains that cannot be replaced because of overflow.  Name it will make some change, and then give up by returning SuccessWithoutChange.  But it was changed.

This is fixed up by returning Failure if we notice the error at the time of rewriting the users.  This is for both id overflow or out-of-bounds accesses.

Code is added to "CheckUses" to remove variables that have out-of-bounds accesses from the candidate list, so we don't even try to rewrite its uses.

Fixes https://crbug.com/995032
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 07d8c01..d748e7f 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -78,36 +78,47 @@
   }
 
   std::vector<Instruction*> dead;
-  if (get_def_use_mgr()->WhileEachUser(
-          inst, [this, &replacements, &dead](Instruction* user) {
-            if (!IsAnnotationInst(user->opcode())) {
-              switch (user->opcode()) {
-                case SpvOpLoad:
-                  ReplaceWholeLoad(user, replacements);
-                  dead.push_back(user);
-                  break;
-                case SpvOpStore:
-                  ReplaceWholeStore(user, replacements);
-                  dead.push_back(user);
-                  break;
-                case SpvOpAccessChain:
-                case SpvOpInBoundsAccessChain:
-                  if (ReplaceAccessChain(user, replacements))
-                    dead.push_back(user);
-                  else
-                    return false;
-                  break;
-                case SpvOpName:
-                case SpvOpMemberName:
-                  break;
-                default:
-                  assert(false && "Unexpected opcode");
-                  break;
+  bool replaced_all_uses = get_def_use_mgr()->WhileEachUser(
+      inst, [this, &replacements, &dead](Instruction* user) {
+        if (!IsAnnotationInst(user->opcode())) {
+          switch (user->opcode()) {
+            case SpvOpLoad:
+              if (ReplaceWholeLoad(user, replacements)) {
+                dead.push_back(user);
+              } else {
+                return false;
               }
-            }
-            return true;
-          }))
+              break;
+            case SpvOpStore:
+              if (ReplaceWholeStore(user, replacements)) {
+                dead.push_back(user);
+              } else {
+                return false;
+              }
+              break;
+            case SpvOpAccessChain:
+            case SpvOpInBoundsAccessChain:
+              if (ReplaceAccessChain(user, replacements))
+                dead.push_back(user);
+              else
+                return false;
+              break;
+            case SpvOpName:
+            case SpvOpMemberName:
+              break;
+            default:
+              assert(false && "Unexpected opcode");
+              break;
+          }
+        }
+        return true;
+      });
+
+  if (replaced_all_uses) {
     dead.push_back(inst);
+  } else {
+    return Status::Failure;
+  }
 
   // If there are no dead instructions to clean up, return with no changes.
   if (dead.empty()) return Status::SuccessWithoutChange;
@@ -133,7 +144,7 @@
   return Status::SuccessWithChange;
 }
 
-void ScalarReplacementPass::ReplaceWholeLoad(
+bool ScalarReplacementPass::ReplaceWholeLoad(
     Instruction* load, const std::vector<Instruction*>& replacements) {
   // Replaces the load of the entire composite with a load from each replacement
   // variable followed by a composite construction.
@@ -150,6 +161,9 @@
 
     Instruction* type = GetStorageType(var);
     uint32_t loadId = TakeNextId();
+    if (loadId == 0) {
+      return false;
+    }
     std::unique_ptr<Instruction> newLoad(
         new Instruction(context(), SpvOpLoad, type->result_id(), loadId,
                         std::initializer_list<Operand>{
@@ -168,6 +182,9 @@
 
   // Construct a new composite.
   uint32_t compositeId = TakeNextId();
+  if (compositeId == 0) {
+    return false;
+  }
   where = load;
   std::unique_ptr<Instruction> compositeConstruct(new Instruction(
       context(), SpvOpCompositeConstruct, load->type_id(), compositeId, {}));
@@ -180,9 +197,10 @@
   get_def_use_mgr()->AnalyzeInstDefUse(&*where);
   context()->set_instr_block(&*where, block);
   context()->ReplaceAllUsesWith(load->result_id(), compositeId);
+  return true;
 }
 
-void ScalarReplacementPass::ReplaceWholeStore(
+bool ScalarReplacementPass::ReplaceWholeStore(
     Instruction* store, const std::vector<Instruction*>& replacements) {
   // Replaces a store to the whole composite with a series of extract and stores
   // to each element.
@@ -199,6 +217,9 @@
 
     Instruction* type = GetStorageType(var);
     uint32_t extractId = TakeNextId();
+    if (extractId == 0) {
+      return false;
+    }
     std::unique_ptr<Instruction> extract(new Instruction(
         context(), SpvOpCompositeExtract, type->result_id(), extractId,
         std::initializer_list<Operand>{
@@ -224,6 +245,7 @@
     get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
     context()->set_instr_block(&*iter, block);
   }
+  return true;
 }
 
 bool ScalarReplacementPass::ReplaceAccessChain(
@@ -247,6 +269,9 @@
       // Replace input access chain with another access chain.
       BasicBlock::iterator chainIter(chain);
       uint32_t replacementId = TakeNextId();
+      if (replacementId == 0) {
+        return false;
+      }
       std::unique_ptr<Instruction> replacementChain(new Instruction(
           context(), chain->opcode(), chain->type_id(), replacementId,
           std::initializer_list<Operand>{
@@ -680,44 +705,51 @@
 
 bool ScalarReplacementPass::CheckUses(const Instruction* inst,
                                       VariableStats* stats) const {
+  uint64_t max_legal_index = GetMaxLegalIndex(inst);
+
   bool ok = true;
-  get_def_use_mgr()->ForEachUse(
-      inst, [this, stats, &ok](const Instruction* user, uint32_t index) {
-        // Annotations are check as a group separately.
-        if (!IsAnnotationInst(user->opcode())) {
-          switch (user->opcode()) {
-            case SpvOpAccessChain:
-            case SpvOpInBoundsAccessChain:
-              if (index == 2u && user->NumInOperands() > 1) {
-                uint32_t id = user->GetSingleWordInOperand(1u);
-                const Instruction* opInst = get_def_use_mgr()->GetDef(id);
-                if (!IsCompileTimeConstantInst(opInst->opcode())) {
-                  ok = false;
-                } else {
-                  if (!CheckUsesRelaxed(user)) ok = false;
-                }
-                stats->num_partial_accesses++;
-              } else {
-                ok = false;
-              }
-              break;
-            case SpvOpLoad:
-              if (!CheckLoad(user, index)) ok = false;
-              stats->num_full_accesses++;
-              break;
-            case SpvOpStore:
-              if (!CheckStore(user, index)) ok = false;
-              stats->num_full_accesses++;
-              break;
-            case SpvOpName:
-            case SpvOpMemberName:
-              break;
-            default:
+  get_def_use_mgr()->ForEachUse(inst, [this, max_legal_index, stats, &ok](
+                                          const Instruction* user,
+                                          uint32_t index) {
+    // Annotations are check as a group separately.
+    if (!IsAnnotationInst(user->opcode())) {
+      switch (user->opcode()) {
+        case SpvOpAccessChain:
+        case SpvOpInBoundsAccessChain:
+          if (index == 2u && user->NumInOperands() > 1) {
+            uint32_t id = user->GetSingleWordInOperand(1u);
+            const Instruction* opInst = get_def_use_mgr()->GetDef(id);
+            const auto* constant =
+                context()->get_constant_mgr()->GetConstantFromInst(opInst);
+            if (!constant) {
               ok = false;
-              break;
+            } else if (constant->GetZeroExtendedValue() >= max_legal_index) {
+              ok = false;
+            } else {
+              if (!CheckUsesRelaxed(user)) ok = false;
+            }
+            stats->num_partial_accesses++;
+          } else {
+            ok = false;
           }
-        }
-      });
+          break;
+        case SpvOpLoad:
+          if (!CheckLoad(user, index)) ok = false;
+          stats->num_full_accesses++;
+          break;
+        case SpvOpStore:
+          if (!CheckStore(user, index)) ok = false;
+          stats->num_full_accesses++;
+          break;
+        case SpvOpName:
+        case SpvOpMemberName:
+          break;
+        default:
+          ok = false;
+          break;
+      }
+    }
+  });
 
   return ok;
 }
@@ -847,5 +879,24 @@
   return null_inst;
 }
 
+uint64_t ScalarReplacementPass::GetMaxLegalIndex(
+    const Instruction* var_inst) const {
+  assert(var_inst->opcode() == SpvOpVariable &&
+         "|var_inst| must be a variable instruction.");
+  Instruction* type = GetStorageType(var_inst);
+  switch (type->opcode()) {
+    case SpvOpTypeStruct:
+      return type->NumInOperands();
+    case SpvOpTypeArray:
+      return GetArrayLength(type);
+    case SpvOpTypeMatrix:
+    case SpvOpTypeVector:
+      return GetNumElements(type);
+    default:
+      return 0;
+  }
+  return 0;
+}
+
 }  // namespace opt
 }  // namespace spvtools
diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h
index d2eb897..e20f1f1 100644
--- a/source/opt/scalar_replacement_pass.h
+++ b/source/opt/scalar_replacement_pass.h
@@ -188,21 +188,21 @@
   // Generates a load for each replacement variable and then creates a new
   // composite by combining all of the loads.
   //
-  // |load| must be a load.
-  void ReplaceWholeLoad(Instruction* load,
+  // |load| must be a load.  Returns true if successful.
+  bool ReplaceWholeLoad(Instruction* load,
                         const std::vector<Instruction*>& replacements);
 
   // Replaces the store to the entire composite.
   //
   // Generates a composite extract and store for each element in the scalarized
-  // variable from the original store data input.
-  void ReplaceWholeStore(Instruction* store,
+  // variable from the original store data input.  Returns true if successful.
+  bool ReplaceWholeStore(Instruction* store,
                          const std::vector<Instruction*>& replacements);
 
   // Replaces an access chain to the composite variable with either a direct use
   // of the appropriate replacement variable or another access chain with the
-  // replacement variable as the base and one fewer indexes. Returns false if
-  // the chain has an out of bounds access.
+  // replacement variable as the base and one fewer indexes. Returns true if
+  // successful.
   bool ReplaceAccessChain(Instruction* chain,
                           const std::vector<Instruction*>& replacements);
 
@@ -223,10 +223,16 @@
   // Maps type id to OpConstantNull for that type.
   std::unordered_map<uint32_t, uint32_t> type_to_null_;
 
+  // Returns the number of elements in the variable |var_inst|.
+  uint64_t GetMaxLegalIndex(const Instruction* var_inst) const;
+
+  // Returns true if |length| is larger than limit on the size of the variable
+  // that we will be willing to split.
+  bool IsLargerThanSizeLimit(uint64_t length) const;
+
   // Limit on the number of members in an object that will be replaced.
   // 0 means there is no limit.
   uint32_t max_num_elements_;
-  bool IsLargerThanSizeLimit(uint64_t length) const;
   char name_[55];
 };
 
diff --git a/test/opt/scalar_replacement_test.cpp b/test/opt/scalar_replacement_test.cpp
index e043326..3cf46ca 100644
--- a/test/opt/scalar_replacement_test.cpp
+++ b/test/opt/scalar_replacement_test.cpp
@@ -1696,6 +1696,62 @@
   EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
 }
 
+// Test that id overflow is handled gracefully.
+TEST_F(ScalarReplacementTest, IdBoundOverflow3) {
+  const std::string text = R"(
+OpCapability InterpolationFunction
+OpExtension "z"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %4 "main"
+OpExecutionMode %4 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeFloat 32
+%7 = OpTypeStruct %6 %6
+%9 = OpTypePointer Function %7
+%18 = OpTypeFunction %7 %9
+%21 = OpTypeInt 32 0
+%22 = OpConstant %21 4293000676
+%4194302 = OpConstantNull %6
+%4 = OpFunction %2 Inline|Pure %3
+%786464 = OpLabel
+%4194298 = OpVariable %9 Function
+%10 = OpVariable %9 Function
+%4194299 = OpUDiv %21 %22 %22
+%4194300 = OpLoad %7 %10
+%50959 = OpLoad %7 %4194298
+OpKill
+OpFunctionEnd
+%1 = OpFunction %7 None %18
+%19 = OpFunctionParameter %9
+%147667 = OpLabel
+%2044391 = OpUDiv %21 %22 %22
+%25 = OpLoad %7 %19
+OpReturnValue %25
+OpFunctionEnd
+%4194295 = OpFunction %2 None %3
+%4194296 = OpLabel
+OpKill
+OpFunctionEnd
+  )";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+  std::vector<Message> messages = {
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."},
+      {SPV_MSG_ERROR, "", 0, 0, "ID overflow. Try running compact-ids."}};
+  SetMessageConsumer(GetTestMessageConsumer(messages));
+  auto result = SinglePassRunToBinary<ScalarReplacementPass>(text, true, false);
+  EXPECT_EQ(Pass::Status::Failure, std::get<1>(result));
+}
+
 // Test that replacements for OpAccessChain do not go out of bounds.
 // https://github.com/KhronosGroup/SPIRV-Tools/issues/2609.
 TEST_F(ScalarReplacementTest, OutOfBoundOpAccessChain) {