spirv-opt: fix copy-propagate-arrays index opti on structs. (#4891)

* spirv-opt: fix copy-propagate-arrays index opti on structs.

As per SPIR-V spec:
OpAccessChain indices must be OpConstant when indexing into a structure.

This optimization tries to remove cascaded loads, but it failed in some
scenarios:

```c
cbuffer MyStruct {
    uint my_field;
};

uint main(uint index) {
    const uint my_array[1] = { my_field };
    return my_array[index];
}
```

This is valid as the struct is indexed with a constant index, and then
the array is indexed using a dynamic index.
The optimization would consider the local array to be useless and
generate a load directly into the struct, indexed with the dynamic index.

* spirv-opt: prevent creation of unused instructions

The copy-propagate-arrays optimization pass would create unused constants,
even if the optimization was not completed.
This was caused by the way we handled OpAccessChain squashing: we
only referenced constants, and had to create them upfront.

Fixes #4887
Signed-off-by: Nathan Gauër <brioche@google.com>
diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp
index 1c30138..0b23562 100644
--- a/source/opt/copy_prop_arrays.cpp
+++ b/source/opt/copy_prop_arrays.cpp
@@ -151,9 +151,17 @@
     return source->GetVariable();
   }
 
+  source->BuildConstants();
+  std::vector<uint32_t> access_ids(source->AccessChain().size());
+  std::transform(
+      source->AccessChain().cbegin(), source->AccessChain().cend(),
+      access_ids.begin(), [](const AccessChainEntry& entry) {
+        assert(entry.is_result_id && "Constants needs to be built first.");
+        return entry.result_id;
+      });
+
   return builder.AddAccessChain(source->GetPointerTypeId(this),
-                                source->GetVariable()->result_id(),
-                                source->AccessChain());
+                                source->GetVariable()->result_id(), access_ids);
 }
 
 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
@@ -270,30 +278,20 @@
 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
   assert(extract_inst->opcode() == SpvOpCompositeExtract &&
          "Expecting an OpCompositeExtract instruction.");
-  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
-
   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
 
-  if (result) {
-    analysis::Integer int_type(32, false);
-    const analysis::Type* uint32_type =
-        context()->get_type_mgr()->GetRegisteredType(&int_type);
-
-    std::vector<uint32_t> components;
-    // Convert the indices in the extract instruction to a series of ids that
-    // can be used by the |OpAccessChain| instruction.
-    for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
-      uint32_t index = extract_inst->GetSingleWordInOperand(i);
-      const analysis::Constant* index_const =
-          const_mgr->GetConstant(uint32_type, {index});
-      components.push_back(
-          const_mgr->GetDefiningInstruction(index_const)->result_id());
-    }
-    result->GetMember(components);
-    return result;
+  if (!result) {
+    return nullptr;
   }
-  return nullptr;
+
+  // Copy the indices of the extract instruction to |OpAccessChain| indices.
+  std::vector<AccessChainEntry> components;
+  for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
+    components.push_back({false, {extract_inst->GetSingleWordInOperand(i)}});
+  }
+  result->PushIndirection(components);
+  return result;
 }
 
 std::unique_ptr<CopyPropagateArrays::MemoryObject>
@@ -317,19 +315,12 @@
     return nullptr;
   }
 
-  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
-  const analysis::Constant* last_access =
-      const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
-  if (!last_access || !last_access->type()->AsInteger()) {
+  AccessChainEntry last_access = memory_object->AccessChain().back();
+  if (!IsAccessChainIndexValidAndEqualTo(last_access, 0)) {
     return nullptr;
   }
 
-  if (last_access->GetU32() != 0) {
-    return nullptr;
-  }
-
-  memory_object->GetParent();
-
+  memory_object->PopIndirection();
   if (memory_object->GetNumberOfMembers() !=
       conststruct_inst->NumInOperands()) {
     return nullptr;
@@ -351,13 +342,8 @@
       return nullptr;
     }
 
-    last_access =
-        const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
-    if (!last_access || !last_access->type()->AsInteger()) {
-      return nullptr;
-    }
-
-    if (last_access->GetU32() != i) {
+    last_access = member_object->AccessChain().back();
+    if (!IsAccessChainIndexValidAndEqualTo(last_access, i)) {
       return nullptr;
     }
   }
@@ -411,17 +397,12 @@
     return nullptr;
   }
 
-  const analysis::Constant* last_access =
-      const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
-  if (!last_access || !last_access->type()->AsInteger()) {
+  AccessChainEntry last_access = memory_object->AccessChain().back();
+  if (!IsAccessChainIndexValidAndEqualTo(last_access, number_of_elements - 1)) {
     return nullptr;
   }
 
-  if (last_access->GetU32() != number_of_elements - 1) {
-    return nullptr;
-  }
-
-  memory_object->GetParent();
+  memory_object->PopIndirection();
 
   Instruction* current_insert =
       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
@@ -458,14 +439,9 @@
       return nullptr;
     }
 
-    const analysis::Constant* current_last_access =
-        const_mgr->FindDeclaredConstant(
-            current_memory_object->AccessChain().back());
-    if (!current_last_access || !current_last_access->type()->AsInteger()) {
-      return nullptr;
-    }
-
-    if (current_last_access->GetU32() != i - 1) {
+    AccessChainEntry current_last_access =
+        current_memory_object->AccessChain().back();
+    if (!IsAccessChainIndexValidAndEqualTo(current_last_access, i - 1)) {
       return nullptr;
     }
     current_insert =
@@ -475,6 +451,21 @@
   return memory_object;
 }
 
+bool CopyPropagateArrays::IsAccessChainIndexValidAndEqualTo(
+    const AccessChainEntry& entry, uint32_t value) const {
+  if (!entry.is_result_id) {
+    return entry.immediate == value;
+  }
+
+  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+  const analysis::Constant* constant =
+      const_mgr->FindDeclaredConstant(entry.result_id);
+  if (!constant || !constant->type()->AsInteger()) {
+    return false;
+  }
+  return constant->GetU32() == value;
+}
+
 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
   analysis::TypeManager* type_mgr = context()->get_type_mgr();
   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
@@ -532,6 +523,12 @@
             // Variable index means the type is a type where every element
             // is the same type.  Use element 0 to get the type.
             access_chain.push_back(0);
+
+            // We are trying to access a struct with variable indices.
+            // This cannot happen.
+            if (pointee_type->kind() == analysis::Type::kStruct) {
+              return false;
+            }
           }
         }
 
@@ -787,8 +784,8 @@
   return id;
 }
 
-void CopyPropagateArrays::MemoryObject::GetMember(
-    const std::vector<uint32_t>& access_chain) {
+void CopyPropagateArrays::MemoryObject::PushIndirection(
+    const std::vector<AccessChainEntry>& access_chain) {
   access_chain_.insert(access_chain_.end(), access_chain.begin(),
                        access_chain.end());
 }
@@ -823,23 +820,29 @@
 template <class iterator>
 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
                                                 iterator begin, iterator end)
-    : variable_inst_(var_inst), access_chain_(begin, end) {}
+    : variable_inst_(var_inst) {
+  std::transform(begin, end, std::back_inserter(access_chain_),
+                 [](uint32_t id) {
+                   return AccessChainEntry{true, {id}};
+                 });
+}
 
 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
   analysis::ConstantManager* const_mgr =
       variable_inst_->context()->get_constant_mgr();
 
-  std::vector<uint32_t> access_indices;
-  for (uint32_t id : AccessChain()) {
-    const analysis::Constant* element_index_const =
-        const_mgr->FindDeclaredConstant(id);
-    if (!element_index_const) {
-      access_indices.push_back(0);
-    } else {
-      access_indices.push_back(element_index_const->GetU32());
-    }
-  }
-  return access_indices;
+  std::vector<uint32_t> indices(AccessChain().size());
+  std::transform(AccessChain().cbegin(), AccessChain().cend(), indices.begin(),
+                 [&const_mgr](const AccessChainEntry& entry) {
+                   if (entry.is_result_id) {
+                     const analysis::Constant* constant =
+                         const_mgr->FindDeclaredConstant(entry.result_id);
+                     return constant == nullptr ? 0 : constant->GetU32();
+                   }
+
+                   return entry.immediate;
+                 });
+  return indices;
 }
 
 bool CopyPropagateArrays::MemoryObject::Contains(
@@ -860,5 +863,24 @@
   return true;
 }
 
+void CopyPropagateArrays::MemoryObject::BuildConstants() {
+  for (auto& entry : access_chain_) {
+    if (entry.is_result_id) {
+      continue;
+    }
+
+    auto context = variable_inst_->context();
+    analysis::Integer int_type(32, false);
+    const analysis::Type* uint32_type =
+        context->get_type_mgr()->GetRegisteredType(&int_type);
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Constant* index_const =
+        const_mgr->GetConstant(uint32_type, {entry.immediate});
+    entry.result_id =
+        const_mgr->GetDefiningInstruction(index_const)->result_id();
+    entry.is_result_id = true;
+  }
+}
+
 }  // namespace opt
 }  // namespace spvtools
diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h
index 07747c1..9e7641f 100644
--- a/source/opt/copy_prop_arrays.h
+++ b/source/opt/copy_prop_arrays.h
@@ -52,6 +52,22 @@
   }
 
  private:
+  // Represents one index in the OpAccessChain instruction. It can be either
+  // an instruction's result_id (e.g. an OpConstant), or an immediate value.
+  // Immediate values are used to prepare the final access chain without
+  // creating OpConstant instructions until done.
+  struct AccessChainEntry {
+    bool is_result_id;
+    union {
+      uint32_t result_id;
+      uint32_t immediate;
+    };
+
+    bool operator!=(const AccessChainEntry& other) const {
+      return other.is_result_id != is_result_id || other.result_id != result_id;
+    }
+  };
+
   // The class used to identify a particular memory object.  This memory object
   // will be owned by a particular variable, meaning that the memory is part of
   // that variable.  It could be the entire variable or a member of the
@@ -70,12 +86,12 @@
     // (starting from the current member).  The elements in |access_chain| are
     // interpreted the same as the indices in the |OpAccessChain|
     // instruction.
-    void GetMember(const std::vector<uint32_t>& access_chain);
+    void PushIndirection(const std::vector<AccessChainEntry>& access_chain);
 
     // Change |this| to now represent the first enclosing object to which it
     // belongs.  (Remove the last element off the access_chain). It is invalid
     // to call this function if |this| does not represent a member of its owner.
-    void GetParent() {
+    void PopIndirection() {
       assert(IsMember());
       access_chain_.pop_back();
     }
@@ -95,7 +111,13 @@
     // member that |this| represents starting from the owning variable.  These
     // values are to be interpreted the same way the indices are in an
     // |OpAccessChain| instruction.
-    const std::vector<uint32_t>& AccessChain() const { return access_chain_; }
+    const std::vector<AccessChainEntry>& AccessChain() const {
+      return access_chain_;
+    }
+
+    // Converts all immediate values in the AccessChain to their OpConstant
+    // equivalent.
+    void BuildConstants();
 
     // Returns the type id of the pointer type that can be used to point to this
     // memory object.
@@ -137,7 +159,7 @@
     // The access chain to reach the particular member the memory object
     // represents.  It should be interpreted the same way the indices in an
     // |OpAccessChain| are interpreted.
-    std::vector<uint32_t> access_chain_;
+    std::vector<AccessChainEntry> access_chain_;
     std::vector<uint32_t> GetAccessIds() const;
   };
 
@@ -192,6 +214,10 @@
   std::unique_ptr<MemoryObject> BuildMemoryObjectFromInsert(
       Instruction* insert_inst);
 
+  // Return true if the given entry can represent the given value.
+  bool IsAccessChainIndexValidAndEqualTo(const AccessChainEntry& entry,
+                                         uint32_t value) const;
+
   // Return true if |type_id| is a pointer type whose pointee type is an array.
   bool IsPointerToArrayType(uint32_t type_id);
 
diff --git a/test/opt/copy_prop_array_test.cpp b/test/opt/copy_prop_array_test.cpp
index f322f4a..d6e376e 100644
--- a/test/opt/copy_prop_array_test.cpp
+++ b/test/opt/copy_prop_array_test.cpp
@@ -1884,6 +1884,65 @@
   SetTargetEnv(SPV_ENV_UNIVERSAL_1_4);
   SinglePassRunAndMatch<CopyPropagateArrays>(before, false);
 }
+
+// As per SPIRV spec, struct cannot be indexed with non-constant indices
+// through OpAccessChain, only arrays.
+// The copy-propagate-array pass tries to remove superfluous copies when the
+// original array could be indexed instead of the copy.
+//
+// This test verifies we handle this case:
+//  struct SRC { int field1; ...; int fieldN }
+//  int tmp_arr[N] = { SRC.field1, ..., SRC.fieldN }
+//  return tmp_arr[index];
+//
+// In such case, we cannot optimize the access: this array was added to allow
+// dynamic indexing in the struct.
+TEST_F(CopyPropArrayPassTest, StructIndexCannotBecomeDynamic) {
+  const std::string text = R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %1 "main"
+OpDecorate %2 DescriptorSet 0
+OpDecorate %2 Binding 0
+OpMemberDecorate %_struct_3 0 Offset 0
+OpDecorate %_struct_3 Block
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_struct_3 = OpTypeStruct %v4float
+%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3
+%uint = OpTypeInt 32 0
+%void = OpTypeVoid
+%11 = OpTypeFunction %void
+%_ptr_Function_uint = OpTypePointer Function %uint
+%13 = OpTypeFunction %v4float %_ptr_Function_uint
+%uint_1 = OpConstant %uint 1
+%_arr_v4float_uint_1 = OpTypeArray %v4float %uint_1
+%_ptr_Function__arr_v4float_uint_1 = OpTypePointer Function %_arr_v4float_uint_1
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+%2 = OpVariable %_ptr_Uniform__struct_3 Uniform
+%19 = OpUndef %v4float
+%1 = OpFunction %void None %11
+%20 = OpLabel
+OpReturn
+OpFunctionEnd
+%21 = OpFunction %v4float None %13
+%22 = OpFunctionParameter %_ptr_Function_uint
+%23 = OpLabel
+%24 = OpVariable %_ptr_Function__arr_v4float_uint_1 Function
+%25 = OpAccessChain %_ptr_Uniform_v4float %2 %int_0
+%26 = OpLoad %v4float %25
+%27 = OpCompositeConstruct %_arr_v4float_uint_1 %26
+OpStore %24 %27
+%28 = OpLoad %uint %22
+%29 = OpAccessChain %_ptr_Function_v4float %24 %28
+OpReturnValue %19
+OpFunctionEnd
+)";
+
+  SinglePassRunAndCheck<CopyPropagateArrays>(text, text, false);
+}
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools