spirv-opt: Fix stacked CompositeExtract constant folds (#4932)

This was spotted in the Validation Layers where OpSpecConstantOp %x CompositeExtract %y 0 was being folded to a constant, but anything that was using it wasn't recognizing it as a constant, the simple fix was to add a const_mgr->MapInst(new_const_inst); so the next instruction knew it was a const
diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp
index b903da6..315741a 100644
--- a/source/opt/fold.cpp
+++ b/source/opt/fold.cpp
@@ -627,8 +627,7 @@
     Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
   analysis::ConstantManager* const_mgr = context_->get_constant_mgr();
 
-  if (!inst->IsFoldableByFoldScalar() &&
-      !GetConstantFoldingRules().HasFoldingRule(inst)) {
+  if (!inst->IsFoldableByFoldScalar() && !HasConstFoldingRule(inst)) {
     return nullptr;
   }
   // Collect the values of the constant parameters.
diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
index 8d68850..7a51870 100644
--- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp
+++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp
@@ -28,6 +28,7 @@
 
 Pass::Status FoldSpecConstantOpAndCompositePass::Process() {
   bool modified = false;
+  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
   // Traverse through all the constant defining instructions. For Normal
   // Constants whose values are determined and do not depend on OpUndef
   // instructions, records their values in two internal maps: id_to_const_val_
@@ -62,8 +63,8 @@
     // used in OpSpecConstant{Composite|Op} instructions.
     // TODO(qining): If the constant or its type has decoration, we may need
     // to skip it.
-    if (context()->get_constant_mgr()->GetType(inst) &&
-        !context()->get_constant_mgr()->GetType(inst)->decoration_empty())
+    if (const_mgr->GetType(inst) &&
+        !const_mgr->GetType(inst)->decoration_empty())
       continue;
     switch (SpvOp opcode = inst->opcode()) {
       // Records the values of Normal Constants.
@@ -80,15 +81,14 @@
         // Constant will be turned in to a Normal Constant. In that case, a
         // Constant instance should also be created successfully and recorded
         // in the id_to_const_val_ and const_val_to_id_ mapps.
-        if (auto const_value =
-                context()->get_constant_mgr()->GetConstantFromInst(inst)) {
+        if (auto const_value = const_mgr->GetConstantFromInst(inst)) {
           // Need to replace the OpSpecConstantComposite instruction with a
           // corresponding OpConstantComposite instruction.
           if (opcode == SpvOp::SpvOpSpecConstantComposite) {
             inst->SetOpcode(SpvOp::SpvOpConstantComposite);
             modified = true;
           }
-          context()->get_constant_mgr()->MapConstantToInst(const_value, inst);
+          const_mgr->MapConstantToInst(const_value, inst);
         }
         break;
       }
@@ -146,6 +146,7 @@
 
 Instruction* FoldSpecConstantOpAndCompositePass::FoldWithInstructionFolder(
     Module::inst_iterator* inst_iter_ptr) {
+  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
   // If one of operands to the instruction is not a
   // constant, then we cannot fold this spec constant.
   for (uint32_t i = 1; i < (*inst_iter_ptr)->NumInOperands(); i++) {
@@ -155,7 +156,7 @@
       continue;
     }
     uint32_t id = operand.words[0];
-    if (context()->get_constant_mgr()->FindDeclaredConstant(id) == nullptr) {
+    if (const_mgr->FindDeclaredConstant(id) == nullptr) {
       return nullptr;
     }
   }
@@ -202,6 +203,7 @@
     new_const_inst->InsertAfter(insert_pos);
     get_def_use_mgr()->AnalyzeInstDefUse(new_const_inst);
   }
+  const_mgr->MapInst(new_const_inst);
   return new_const_inst;
 }
 
@@ -285,8 +287,8 @@
 Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
     Module::inst_iterator* pos) {
   const Instruction* inst = &**pos;
-  const analysis::Type* result_type =
-      context()->get_constant_mgr()->GetType(inst);
+  analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
+  const analysis::Type* result_type = const_mgr->GetType(inst);
   SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
   // Check and collect operands.
   std::vector<const analysis::Constant*> operands;
@@ -311,10 +313,9 @@
     // Scalar operation
     const uint32_t result_val =
         context()->get_instruction_folder().FoldScalars(spec_opcode, operands);
-    auto result_const = context()->get_constant_mgr()->GetConstant(
+    auto result_const = const_mgr->GetConstant(
         result_type, EncodeIntegerAsWords(*result_type, result_val));
-    return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-        result_const, pos);
+    return const_mgr->BuildInstructionAndAddToModule(result_const, pos);
   } else if (result_type->AsVector()) {
     // Vector operation
     const analysis::Type* element_type =
@@ -325,11 +326,10 @@
                                                         operands);
     std::vector<const analysis::Constant*> result_vector_components;
     for (const uint32_t r : result_vec) {
-      if (auto rc = context()->get_constant_mgr()->GetConstant(
+      if (auto rc = const_mgr->GetConstant(
               element_type, EncodeIntegerAsWords(*element_type, r))) {
         result_vector_components.push_back(rc);
-        if (!context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-                rc, pos)) {
+        if (!const_mgr->BuildInstructionAndAddToModule(rc, pos)) {
           assert(false &&
                  "Failed to build and insert constant declaring instruction "
                  "for the given vector component constant");
@@ -340,10 +340,8 @@
     }
     auto new_vec_const = MakeUnique<analysis::VectorConstant>(
         result_type->AsVector(), result_vector_components);
-    auto reg_vec_const = context()->get_constant_mgr()->RegisterConstant(
-        std::move(new_vec_const));
-    return context()->get_constant_mgr()->BuildInstructionAndAddToModule(
-        reg_vec_const, pos);
+    auto reg_vec_const = const_mgr->RegisterConstant(std::move(new_vec_const));
+    return const_mgr->BuildInstructionAndAddToModule(reg_vec_const, pos);
   } else {
     // Cannot process invalid component wise operation. The result of component
     // wise operation must be of integer or bool scalar or vector of
diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp
index 7eddf7e..c98a44c 100644
--- a/test/opt/fold_spec_const_op_composite_test.cpp
+++ b/test/opt/fold_spec_const_op_composite_test.cpp
@@ -105,6 +105,209 @@
       builder.GetCode(), builder.GetCode(), /* skip_nop = */ true);
 }
 
+// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
+// CompositeExtract
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedCompositeExtract) {
+  AssemblyBuilder builder;
+  builder.AppendTypesConstantsGlobals({
+      // clang-format off
+    "%uint = OpTypeInt 32 0",
+    "%v3uint = OpTypeVector %uint 3",
+    "%uint_2 = OpConstant %uint 2",
+    "%uint_3 = OpConstant %uint 3",
+    // Folding target:
+    "%composite_0 = OpSpecConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
+    "%op_0 = OpSpecConstantOp %uint CompositeExtract %composite_0 0",
+    "%op_1 = OpSpecConstantOp %uint CompositeExtract %composite_0 1",
+    "%op_2 = OpSpecConstantOp %uint IMul %op_0 %op_1",
+    "%composite_1 = OpSpecConstantComposite %v3uint %op_0 %op_1 %op_2",
+    "%op_3 = OpSpecConstantOp %uint CompositeExtract %composite_1 0",
+    "%op_4 = OpSpecConstantOp %uint IMul %op_2 %op_3",
+      // clang-format on
+  });
+
+  std::vector<const char*> expected = {
+      // clang-format off
+        "OpCapability Shader",
+        "OpCapability Float64",
+    "%1 = OpExtInstImport \"GLSL.std.450\"",
+        "OpMemoryModel Logical GLSL450",
+        "OpEntryPoint Vertex %main \"main\"",
+        "OpName %void \"void\"",
+        "OpName %main_func_type \"main_func_type\"",
+        "OpName %main \"main\"",
+        "OpName %main_func_entry_block \"main_func_entry_block\"",
+        "OpName %uint \"uint\"",
+        "OpName %v3uint \"v3uint\"",
+        "OpName %uint_2 \"uint_2\"",
+        "OpName %uint_3 \"uint_3\"",
+        "OpName %composite_0 \"composite_0\"",
+        "OpName %op_0 \"op_0\"",
+        "OpName %op_1 \"op_1\"",
+        "OpName %op_2 \"op_2\"",
+        "OpName %composite_1 \"composite_1\"",
+        "OpName %op_3 \"op_3\"",
+        "OpName %op_4 \"op_4\"",
+    "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+    "%uint = OpTypeInt 32 0",
+  "%v3uint = OpTypeVector %uint 3",
+  "%uint_2 = OpConstant %uint 2",
+  "%uint_3 = OpConstant %uint 3",
+"%composite_0 = OpConstantComposite %v3uint %uint_2 %uint_3 %uint_2",
+    "%op_0 = OpConstant %uint 2",
+    "%op_1 = OpConstant %uint 3",
+    "%op_2 = OpConstant %uint 6",
+"%composite_1 = OpConstantComposite %v3uint %op_0 %op_1 %op_2",
+"%op_3 = OpConstant %uint 2",
+ "%op_4 = OpConstant %uint 12",
+    "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+            "OpReturn",
+            "OpFunctionEnd",
+      // clang-format on
+  };
+  SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+      builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
+// Test where OpSpecConstantOp depends on another OpSpecConstantOp with
+// VectorShuffle
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, StackedVectorShuffle) {
+  AssemblyBuilder builder;
+  builder.AppendTypesConstantsGlobals({
+      // clang-format off
+    "%uint = OpTypeInt 32 0",
+    "%v3uint = OpTypeVector %uint 3",
+    "%uint_1 = OpConstant %uint 1",
+    "%uint_2 = OpConstant %uint 2",
+    "%uint_3 = OpConstant %uint 3",
+    "%uint_4 = OpConstant %uint 4",
+    "%uint_5 = OpConstant %uint 5",
+    "%uint_6 = OpConstant %uint 6",
+    // Folding target:
+    "%composite_0 = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
+    "%composite_1 = OpSpecConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
+    "%vecshuffle = OpSpecConstantOp %v3uint VectorShuffle %composite_0 %composite_1 0 5 3",
+    "%op = OpSpecConstantOp %uint CompositeExtract %vecshuffle 1",
+      // clang-format on
+  });
+
+  std::vector<const char*> expected = {
+      // clang-format off
+        "OpCapability Shader",
+        "OpCapability Float64",
+        "%1 = OpExtInstImport \"GLSL.std.450\"",
+        "OpMemoryModel Logical GLSL450",
+        "OpEntryPoint Vertex %main \"main\"",
+        "OpName %void \"void\"",
+        "OpName %main_func_type \"main_func_type\"",
+        "OpName %main \"main\"",
+        "OpName %main_func_entry_block \"main_func_entry_block\"",
+        "OpName %uint \"uint\"",
+        "OpName %v3uint \"v3uint\"",
+        "OpName %uint_1 \"uint_1\"",
+        "OpName %uint_2 \"uint_2\"",
+        "OpName %uint_3 \"uint_3\"",
+        "OpName %uint_4 \"uint_4\"",
+        "OpName %uint_5 \"uint_5\"",
+        "OpName %uint_6 \"uint_6\"",
+        "OpName %composite_0 \"composite_0\"",
+        "OpName %composite_1 \"composite_1\"",
+        "OpName %vecshuffle \"vecshuffle\"",
+        "OpName %op \"op\"",
+    "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+    "%uint = OpTypeInt 32 0",
+  "%v3uint = OpTypeVector %uint 3",
+  "%uint_1 = OpConstant %uint 1",
+  "%uint_2 = OpConstant %uint 2",
+  "%uint_3 = OpConstant %uint 3",
+  "%uint_4 = OpConstant %uint 4",
+  "%uint_5 = OpConstant %uint 5",
+  "%uint_6 = OpConstant %uint 6",
+"%composite_0 = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_3",
+"%composite_1 = OpConstantComposite %v3uint %uint_4 %uint_5 %uint_6",
+"%vecshuffle = OpConstantComposite %v3uint %uint_1 %uint_6 %uint_4",
+      "%op = OpConstant %uint 6",
+    "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+        "OpReturn",
+        "OpFunctionEnd",
+      // clang-format on
+  };
+  SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+      builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
+// Test CompositeExtract with matrix
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeExtractMaxtrix) {
+  AssemblyBuilder builder;
+  builder.AppendTypesConstantsGlobals({
+      // clang-format off
+    "%uint = OpTypeInt 32 0",
+    "%v3uint = OpTypeVector %uint 3",
+    "%mat3x3 = OpTypeMatrix %v3uint 3",
+    "%uint_1 = OpConstant %uint 1",
+    "%uint_2 = OpConstant %uint 2",
+    "%uint_3 = OpConstant %uint 3",
+    // Folding target:
+    "%a = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
+    "%b = OpSpecConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
+    "%c = OpSpecConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
+    "%op = OpSpecConstantComposite %mat3x3 %a %b %c",
+    "%x = OpSpecConstantOp %uint CompositeExtract %op 2 1",
+    "%y = OpSpecConstantOp %uint CompositeExtract %op 1 2",
+      // clang-format on
+  });
+
+  std::vector<const char*> expected = {
+      // clang-format off
+        "OpCapability Shader",
+        "OpCapability Float64",
+   "%1 = OpExtInstImport \"GLSL.std.450\"",
+        "OpMemoryModel Logical GLSL450",
+        "OpEntryPoint Vertex %main \"main\"",
+        "OpName %void \"void\"",
+        "OpName %main_func_type \"main_func_type\"",
+        "OpName %main \"main\"",
+        "OpName %main_func_entry_block \"main_func_entry_block\"",
+        "OpName %uint \"uint\"",
+        "OpName %v3uint \"v3uint\"",
+        "OpName %mat3x3 \"mat3x3\"",
+        "OpName %uint_1 \"uint_1\"",
+        "OpName %uint_2 \"uint_2\"",
+        "OpName %uint_3 \"uint_3\"",
+        "OpName %a \"a\"",
+        "OpName %b \"b\"",
+        "OpName %c \"c\"",
+        "OpName %op \"op\"",
+        "OpName %x \"x\"",
+        "OpName %y \"y\"",
+    "%void = OpTypeVoid",
+"%main_func_type = OpTypeFunction %void",
+    "%uint = OpTypeInt 32 0",
+  "%v3uint = OpTypeVector %uint 3",
+  "%mat3x3 = OpTypeMatrix %v3uint 3",
+  "%uint_1 = OpConstant %uint 1",
+  "%uint_2 = OpConstant %uint 2",
+  "%uint_3 = OpConstant %uint 3",
+       "%a = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1",
+       "%b = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_3",
+       "%c = OpConstantComposite %v3uint %uint_1 %uint_2 %uint_1",
+      "%op = OpConstantComposite %mat3x3 %a %b %c",
+       "%x = OpConstant %uint 2",
+       "%y = OpConstant %uint 3",
+    "%main = OpFunction %void None %main_func_type",
+"%main_func_entry_block = OpLabel",
+        "OpReturn",
+        "OpFunctionEnd",
+      // clang-format on
+  };
+  SinglePassRunAndCheck<FoldSpecConstantOpAndCompositePass>(
+      builder.GetCode(), JoinAllInsts(expected), /* skip_nop = */ true);
+}
+
 // All types and some common constants that are potentially required in
 // FoldSpecConstantOpAndCompositeTest.
 std::vector<std::string> CommonTypesAndConstants() {
@@ -199,7 +402,7 @@
 struct FoldSpecConstantOpAndCompositePassTestCase {
   // Original constants with unfolded spec constants.
   std::vector<std::string> original;
-  // Expected cosntants after folding.
+  // Expected constant after folding.
   std::vector<std::string> expected;
 };