Add more folding for composite instructions (#4802)

* Add move folding for composite instructions

Fold chains of insert into construct

If a chain of OpCompositeInsert instruction write to every element of a
composite object, then we can replace it with an OpCompositeConstruct.

Fold a construct fed by extracts to a single extract

We already fold an OpCompositeConstruct when it is simlpy reconstructing
an object that was decomposed by a series of OpCompositeExtract
instructions.  However, we do not do that if that object is an element
of a larger object.

I have updated the rule, so that if the original object is a an element
of a larger object, then the OpCompositeConstruct is replaced with a
single OpCompositeExtract from the larger object.

Fixes #4371.
diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp
index d15ad04..ab7a20e 100644
--- a/source/opt/folding_rules.cpp
+++ b/source/opt/folding_rules.cpp
@@ -1631,6 +1631,57 @@
   return true;
 }
 
+// Walks the indexes chain from |start| to |end| of an OpCompositeInsert or
+// OpCompositeExtract instruction, and returns the type of the final element
+// being accessed.
+const analysis::Type* GetElementType(uint32_t type_id,
+                                     Instruction::iterator start,
+                                     Instruction::iterator end,
+                                     const analysis::TypeManager* type_mgr) {
+  const analysis::Type* type = type_mgr->GetType(type_id);
+  for (auto index : make_range(std::move(start), std::move(end))) {
+    assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
+           index.words.size() == 1);
+    if (auto* array_type = type->AsArray()) {
+      type = array_type->element_type();
+    } else if (auto* matrix_type = type->AsMatrix()) {
+      type = matrix_type->element_type();
+    } else if (auto* struct_type = type->AsStruct()) {
+      type = struct_type->element_types()[index.words[0]];
+    } else {
+      type = nullptr;
+    }
+  }
+  return type;
+}
+
+// Returns true of |inst_1| and |inst_2| have the same indexes that will be used
+// to index into a composite object, excluding the last index.  The two
+// instructions must have the same opcode, and be either OpCompositeExtract or
+// OpCompositeInsert instructions.
+bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
+  assert(inst_1->opcode() == inst_2->opcode() &&
+         "Expecting the opcodes to be the same.");
+  assert((inst_1->opcode() == SpvOpCompositeInsert ||
+          inst_1->opcode() == SpvOpCompositeExtract) &&
+         "Instructions must be OpCompositeInsert or OpCompositeExtract.");
+
+  if (inst_1->NumInOperands() != inst_2->NumInOperands()) {
+    return false;
+  }
+
+  uint32_t first_index_position =
+      (inst_1->opcode() == SpvOpCompositeInsert ? 2 : 1);
+  for (uint32_t i = first_index_position; i < inst_1->NumInOperands() - 1;
+       i++) {
+    if (inst_1->GetSingleWordInOperand(i) !=
+        inst_2->GetSingleWordInOperand(i)) {
+      return false;
+    }
+  }
+  return true;
+}
+
 // If the OpCompositeConstruct is simply putting back together elements that
 // where extracted from the same source, we can simply reuse the source.
 //
@@ -1653,19 +1704,24 @@
   // - extractions
   // - extracting the same position they are inserting
   // - all extract from the same id.
+  Instruction* first_element_inst = nullptr;
   for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
     const uint32_t element_id = inst->GetSingleWordInOperand(i);
     Instruction* element_inst = def_use_mgr->GetDef(element_id);
+    if (first_element_inst == nullptr) {
+      first_element_inst = element_inst;
+    }
 
     if (element_inst->opcode() != SpvOpCompositeExtract) {
       return false;
     }
 
-    if (element_inst->NumInOperands() != 2) {
+    if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
       return false;
     }
 
-    if (element_inst->GetSingleWordInOperand(1) != i) {
+    if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
+                                             1) != i) {
       return false;
     }
 
@@ -1681,13 +1737,31 @@
   // The last check it to see that the object being extracted from is the
   // correct type.
   Instruction* original_inst = def_use_mgr->GetDef(original_id);
-  if (original_inst->type_id() != inst->type_id()) {
+  analysis::TypeManager* type_mgr = context->get_type_mgr();
+  const analysis::Type* original_type =
+      GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
+                     first_element_inst->end() - 1, type_mgr);
+
+  if (original_type == nullptr) {
     return false;
   }
 
-  // Simplify by using the original object.
-  inst->SetOpcode(SpvOpCopyObject);
-  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+  if (inst->type_id() != type_mgr->GetId(original_type)) {
+    return false;
+  }
+
+  if (first_element_inst->NumInOperands() == 2) {
+    // Simplify by using the original object.
+    inst->SetOpcode(SpvOpCopyObject);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
+    return true;
+  }
+
+  // Copies the original id and all indexes except for the last to the new
+  // extract instruction.
+  inst->SetOpcode(SpvOpCompositeExtract);
+  inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
+                                           first_element_inst->end() - 1));
   return true;
 }
 
@@ -1891,6 +1965,139 @@
   };
 }
 
+// Returns the number of elements in the composite type |type|.  Returns 0 if
+// |type| is a scalar value.
+uint32_t GetNumberOfElements(const analysis::Type* type) {
+  if (auto* vector_type = type->AsVector()) {
+    return vector_type->element_count();
+  }
+  if (auto* matrix_type = type->AsMatrix()) {
+    return matrix_type->element_count();
+  }
+  if (auto* struct_type = type->AsStruct()) {
+    return static_cast<uint32_t>(struct_type->element_types().size());
+  }
+  if (auto* array_type = type->AsArray()) {
+    return array_type->length_info().words[0];
+  }
+  return 0;
+}
+
+// Returns a map with the set of values that were inserted into an object by
+// the chain of OpCompositeInsertInstruction starting with |inst|.
+// The map will map the index to the value inserted at that index.
+std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
+  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
+  std::map<uint32_t, uint32_t> values_inserted;
+  Instruction* current_inst = inst;
+  while (current_inst->opcode() == SpvOpCompositeInsert) {
+    if (current_inst->NumInOperands() > inst->NumInOperands()) {
+      // This is the catch the case
+      //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
+      //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
+      //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
+      // In this case we cannot do a single construct to get the matrix.
+      uint32_t partially_inserted_element_index =
+          current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+      if (values_inserted.count(partially_inserted_element_index) == 0)
+        return {};
+    }
+    if (HaveSameIndexesExceptForLast(inst, current_inst)) {
+      values_inserted.insert(
+          {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
+                                                1),
+           current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
+    }
+    current_inst = def_use_mgr->GetDef(
+        current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
+  }
+  return values_inserted;
+}
+
+// Returns true of there is an entry in |values_inserted| for every element of
+// |Type|.
+bool DoInsertedValuesCoverEntireObject(
+    const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
+  uint32_t container_size = GetNumberOfElements(type);
+  if (container_size != values_inserted.size()) {
+    return false;
+  }
+
+  if (values_inserted.rbegin()->first >= container_size) {
+    return false;
+  }
+  return true;
+}
+
+// Returns the type of the element that immediately contains the element being
+// inserted by the OpCompositeInsert instruction |inst|.
+const analysis::Type* GetContainerType(Instruction* inst) {
+  assert(inst->opcode() == SpvOpCompositeInsert);
+  analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
+  return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1,
+                        type_mgr);
+}
+
+// Returns an OpCompositeConstruct instruction that build an object with
+// |type_id| out of the values in |values_inserted|.  Each value will be
+// placed at the index corresponding to the value.  The new instruction will
+// be placed before |insert_before|.
+Instruction* BuildCompositeConstruct(
+    uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
+    Instruction* insert_before) {
+  InstructionBuilder ir_builder(
+      insert_before->context(), insert_before,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  std::vector<uint32_t> ids_in_order;
+  for (auto it : values_inserted) {
+    ids_in_order.push_back(it.second);
+  }
+  Instruction* construct =
+      ir_builder.AddCompositeConstruct(type_id, ids_in_order);
+  return construct;
+}
+
+// Replaces the OpCompositeInsert |inst| that inserts |construct| into the same
+// object as |inst| with final index removed.  If the resulting
+// OpCompositeInsert instruction would have no remaining indexes, the
+// instruction is replaced with an OpCopyObject instead.
+void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
+  if (inst->NumInOperands() == 3) {
+    inst->SetOpcode(SpvOpCopyObject);
+    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {construct->result_id()}}});
+  } else {
+    inst->SetInOperand(kInsertObjectIdInIdx, {construct->result_id()});
+    inst->RemoveOperand(inst->NumOperands() - 1);
+  }
+}
+
+// Replaces a series of |OpCompositeInsert| instruction that cover the entire
+// object with an |OpCompositeConstruct|.
+bool CompositeInsertToCompositeConstruct(
+    IRContext* context, Instruction* inst,
+    const std::vector<const analysis::Constant*>&) {
+  assert(inst->opcode() == SpvOpCompositeInsert &&
+         "Wrong opcode.  Should be OpCompositeInsert.");
+  if (inst->NumInOperands() < 3) return false;
+
+  std::map<uint32_t, uint32_t> values_inserted = GetInsertedValues(inst);
+  const analysis::Type* container_type = GetContainerType(inst);
+  if (container_type == nullptr) {
+    return false;
+  }
+
+  if (!DoInsertedValuesCoverEntireObject(container_type, values_inserted)) {
+    return false;
+  }
+
+  analysis::TypeManager* type_mgr = context->get_type_mgr();
+  Instruction* construct = BuildCompositeConstruct(
+      type_mgr->GetId(container_type), values_inserted, inst);
+  InsertConstructedObject(inst, construct);
+  return true;
+}
+
 FoldingRule RedundantPhi() {
   // An OpPhi instruction where all values are the same or the result of the phi
   // itself, can be replaced by the value itself.
@@ -2591,6 +2798,8 @@
   rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
   rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
 
+  rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
+
   rules_[SpvOpDot].push_back(DotProductDoingExtract());
 
   rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());
diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp
index 2ca3256..e2240b8 100644
--- a/test/opt/fold_test.cpp
+++ b/test/opt/fold_test.cpp
@@ -147,6 +147,7 @@
 %v2double = OpTypeVector %double 2
 %v2half = OpTypeVector %half 2
 %v2bool = OpTypeVector %bool 2
+%m2x2int = OpTypeMatrix %v2int 2
 %struct_v2int_int_int = OpTypeStruct %v2int %int %int
 %_ptr_int = OpTypePointer Function %int
 %_ptr_uint = OpTypePointer Function %uint
@@ -218,7 +219,9 @@
 %struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
 %v2int_null = OpConstantNull %v2int
 %102 = OpConstantComposite %v2int %103 %103
+%v4int_undef = OpUndef %v4int
 %v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
+%m2x2int_undef = OpUndef %m2x2int
 %struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
 %float_n1 = OpConstant %float -1
 %104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
@@ -6862,7 +6865,7 @@
       4, true)
 ));
 
-INSTANTIATE_TEST_SUITE_P(CompositeExtractMatchingTest, MatchingInstructionFoldingTest,
+INSTANTIATE_TEST_SUITE_P(CompositeExtractOrInsertMatchingTest, MatchingInstructionFoldingTest,
 ::testing::Values(
     // Test case 0: Extracting from result of consecutive shuffles of differing
     // size.
@@ -7002,7 +7005,145 @@
             "%4 = OpCompositeExtract %int %3 1\n" +
             "OpReturn\n" +
             "OpFunctionEnd",
-        4, true)
+        4, true),
+    // Test case 8: Inserting every element of a vector turns into a composite construct.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+            "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+            "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+            "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+            "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+            "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+            "%3 = OpCompositeInsert %v4int %int_1 %2 1\n" +
+            "%4 = OpCompositeInsert %v4int %int_2 %3 2\n" +
+            "%5 = OpCompositeInsert %v4int %int_3 %4 3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        5, true),
+    // Test case 9: Inserting every element of a vector turns into a composite construct in a different order.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+            "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+            "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+            "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+            "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+            "; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+            "%4 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+            "%3 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+            "%5 = OpCompositeInsert %v4int %int_3 %3 3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        5, true),
+    // Test case 10: Check multiple inserts to the same position are handled correctly.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
+            "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+            "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+            "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+            "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
+            "; CHECK: %6 = OpCopyObject [[v4]] [[construct]]\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
+            "%3 = OpCompositeInsert %v4int %int_2 %2 2\n" +
+            "%4 = OpCompositeInsert %v4int %int_4 %3 1\n" +
+            "%5 = OpCompositeInsert %v4int %int_1 %4 1\n" +
+            "%6 = OpCompositeInsert %v4int %int_3 %5 3\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        6, true),
+    // Test case 11: The last indexes are 0 and 1, but they have different first indexes.  This should not be folded.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+            "%3 = OpCompositeInsert %m2x2int %int_1 %2 1 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        3, false),
+    // Test case 12: Don't fold when there is a partial insertion.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0\n" +
+            "%3 = OpCompositeInsert %m2x2int %int_4 %2 0 0\n" +
+            "%4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        4, false),
+    // Test case 13: Insert into a column of a matrix
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+            "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+            "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+            "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+// We keep this insert in the chain.  DeadInsertElimPass should remove it.
+            "; CHECK: [[insert:%\\w+]] = OpCompositeInsert [[m2x2]] %100 [[m2x2_undef]] 0 0\n" +
+            "; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+            "; CHECK: %3 = OpCompositeInsert [[m2x2]] [[construct]] [[insert]] 0\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
+            "%3 = OpCompositeInsert %m2x2int %int_1 %2 0 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        3, true),
+    // Test case 14: Insert all elements of the matrix.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+            "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+            "; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+            "; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
+            "; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
+            "; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
+            "; CHECK: [[c0:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
+            "; CHECK: [[c1:%\\w+]] = OpCompositeConstruct [[v2]] [[int2]] [[int3]]\n" +
+            "; CHECK: [[matrix:%\\w+]] = OpCompositeConstruct [[m2x2]] [[c0]] [[c1]]\n" +
+            "; CHECK: %5 = OpCopyObject [[m2x2]] [[matrix]]\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%2 = OpCompositeConstruct %v2int %100 %int_1\n" +
+            "%3 = OpCompositeInsert %m2x2int %2 %m2x2int_undef 0\n" +
+            "%4 = OpCompositeInsert %m2x2int %int_2 %3 1 0\n" +
+            "%5 = OpCompositeInsert %m2x2int %int_3 %4 1 1\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        5, true),
+    // Test case 15: Replace construct with extract when reconstructing a member
+    // of another object.
+    InstructionFoldingCase<bool>(
+        Header() +
+            "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
+            "; CHECK: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
+            "; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
+            "; CHECK: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
+            "; CHECK: %5 = OpCompositeExtract [[v2]] [[m2x2_undef]]\n" +
+            "%main = OpFunction %void None %void_func\n" +
+            "%main_lab = OpLabel\n" +
+            "%3 = OpCompositeExtract %int %m2x2int_undef 1 0\n" +
+            "%4 = OpCompositeExtract %int %m2x2int_undef 1 1\n" +
+            "%5 = OpCompositeConstruct %v2int %3 %4\n" +
+            "OpReturn\n" +
+            "OpFunctionEnd",
+        5, true)
 ));
 
 INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,