spirv-opt: Add const folding for CompositeInsert (#4943)

* spirv-opt: Add const folding pass for CompositeInsert

* spirv-opt: Fix anas stack-use-after-scope
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 64475a6..6d80fbb 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -120,6 +120,83 @@
   };
 }
 
+// Folds an OpcompositeInsert where input is a composite constant.
+ConstantFoldingRule FoldInsertWithConstants() {
+  return [](IRContext* context, Instruction* inst,
+            const std::vector<const analysis::Constant*>& constants)
+             -> const analysis::Constant* {
+    analysis::ConstantManager* const_mgr = context->get_constant_mgr();
+    const analysis::Constant* object = constants[0];
+    const analysis::Constant* composite = constants[1];
+    if (object == nullptr || composite == nullptr) {
+      return nullptr;
+    }
+
+    // If there is more than 1 index, then each additional constant used by the
+    // index will need to be recreated to use the inserted object.
+    std::vector<const analysis::Constant*> chain;
+    std::vector<const analysis::Constant*> components;
+    const analysis::Type* type = nullptr;
+
+    // Work down hierarchy and add all the indexes, not including the final
+    // index.
+    for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
+      if (i != inst->NumInOperands() - 1) {
+        chain.push_back(composite);
+      }
+      const uint32_t index = inst->GetSingleWordInOperand(i);
+      components = composite->AsCompositeConstant()->GetComponents();
+      type = composite->AsCompositeConstant()->type();
+      composite = components[index];
+    }
+
+    // Final index in hierarchy is inserted with new object.
+    const uint32_t final_index =
+        inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+    std::vector<uint32_t> ids;
+    for (size_t i = 0; i < components.size(); i++) {
+      const analysis::Constant* constant =
+          (i == final_index) ? object : components[i];
+      Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
+      ids.push_back(member_inst->result_id());
+    }
+    const analysis::Constant* new_constant = const_mgr->GetConstant(type, ids);
+
+    // Work backwards up the chain and replace each index with new constant.
+    for (size_t i = chain.size(); i > 0; i--) {
+      // Need to insert any previous instruction into the module first.
+      // Can't just insert in types_values_begin() because it will move above
+      // where the types are declared
+      for (Module::inst_iterator inst_iter = context->types_values_begin();
+           inst_iter != context->types_values_end(); ++inst_iter) {
+        Instruction* x = &*inst_iter;
+        if (inst->result_id() == x->result_id()) {
+          const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
+          break;
+        }
+      }
+
+      composite = chain[i - 1];
+      components = composite->AsCompositeConstant()->GetComponents();
+      type = composite->AsCompositeConstant()->type();
+      ids.clear();
+      for (size_t k = 0; k < components.size(); k++) {
+        const uint32_t index =
+            inst->GetSingleWordInOperand(1 + static_cast<uint32_t>(i));
+        const analysis::Constant* constant =
+            (k == index) ? new_constant : components[k];
+        const uint32_t constant_id =
+            const_mgr->FindDeclaredConstant(constant, 0);
+        ids.push_back(constant_id);
+      }
+      new_constant = const_mgr->GetConstant(type, ids);
+    }
+
+    // If multiple constants were created, only need to return the top index.
+    return new_constant;
+  };
+}
+
 ConstantFoldingRule FoldVectorShuffleWithConstants() {
   return [](IRContext* context, Instruction* inst,
             const std::vector<const analysis::Constant*>& constants)
@@ -1410,6 +1487,7 @@
   rules_[spv::Op::OpCompositeConstruct].push_back(FoldCompositeWithConstants());
 
   rules_[spv::Op::OpCompositeExtract].push_back(FoldExtractWithConstants());
+  rules_[spv::Op::OpCompositeInsert].push_back(FoldInsertWithConstants());
 
   rules_[spv::Op::OpConvertFToS].push_back(FoldFToI());
   rules_[spv::Op::OpConvertFToU].push_back(FoldFToI());
diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp
index c98a44c..e2374c5 100644
--- a/test/opt/fold_spec_const_op_composite_test.cpp
+++ b/test/opt/fold_spec_const_op_composite_test.cpp
@@ -308,6 +308,72 @@
       builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
 }
 
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVector) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+       %uint = OpTypeInt 32 0
+     %v3uint = OpTypeVector %uint 3
+     %uint_2 = OpConstant %uint 2
+     %uint_3 = OpConstant %uint 3
+          %8 = OpConstantNull %uint
+          %9 = OpSpecConstantComposite %v3uint %uint_2 %uint_2 %uint_2
+ ; CHECK: %15 = OpConstantComposite %v3uint %uint_3 %uint_2 %uint_2
+ ; CHECK: %uint_3_0 = OpConstant %uint 3
+ ; CHECK: %17 = OpConstantComposite %v3uint %8 %uint_2 %uint_2
+ ; CHECK: %18 = OpConstantNull %uint
+         %10 = OpSpecConstantOp %v3uint CompositeInsert %uint_3 %9 0
+         %11 = OpSpecConstantOp %uint CompositeExtract %10 0
+         %12 = OpSpecConstantOp %v3uint CompositeInsert %8 %9 0
+         %13 = OpSpecConstantOp %uint CompositeExtract %12 0
+          %1 = OpFunction %void None %3
+         %14 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+%mat3v3float = OpTypeMatrix %v3float 3
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+          %9 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
+         %10 = OpSpecConstantComposite %v3float %float_1 %float_1 %float_1
+         %11 = OpSpecConstantComposite %v3float %float_1 %float_2 %float_1
+         %12 = OpSpecConstantComposite %mat3v3float %9 %10 %11
+ ; CHECK: %float_2_0 = OpConstant %float 2
+ ; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_2
+ ; CHECK: %19 = OpConstantComposite %mat3v3float %9 %18 %11
+ ; CHECK: %float_2_1 = OpConstant %float 2
+         %13 = OpSpecConstantOp %float CompositeExtract %12 2 1
+         %14 = OpSpecConstantOp %mat3v3float CompositeInsert %13 %12 1 2
+         %15 = OpSpecConstantOp %float CompositeExtract %14 1 2
+          %1 = OpFunction %void None %3
+         %16 = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
 // All types and some common constants that are potentially required in
 // FoldSpecConstantOpAndCompositeTest.
 std::vector<std::string> CommonTypesAndConstants() {