spirv-opt: Fix OpCompositeInsert with Null Constant (#5008)

* spirv-opt: Unify GetConstId function names

* spirv-opt: Fix OpCompositeInsert with Null Constant

* spirv-opt: Improve GetNullCompositeConstant description
diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp
index 19b39d6..14f2208 100644
--- a/source/opt/const_folding_rules.cpp
+++ b/source/opt/const_folding_rules.cpp
@@ -136,32 +136,38 @@
     std::vector<const analysis::Constant*> chain;
     std::vector<const analysis::Constant*> components;
     const analysis::Type* type = nullptr;
+    const uint32_t final_index = (inst->NumInOperands() - 1);
 
-    // Work down hierarchy and add all the indexes, not including the final
-    // index.
+    // Work down hierarchy of all indexes
     for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
-      if (composite->AsNullConstant()) {
-        // Return Null for the return type.
-        analysis::TypeManager* type_mgr = context->get_type_mgr();
-        return const_mgr->GetConstant(type_mgr->GetType(inst->type_id()), {});
-      }
+      type = composite->type();
 
-      if (i != inst->NumInOperands() - 1) {
-        chain.push_back(composite);
+      if (composite->AsNullConstant()) {
+        // Make new composite so it can be inserted in the index with the
+        // non-null value
+        const auto new_composite = const_mgr->GetNullCompositeConstant(type);
+        // Keep track of any indexes along the way to last index
+        if (i != final_index) {
+          chain.push_back(new_composite);
+        }
+        components = new_composite->AsCompositeConstant()->GetComponents();
+      } else {
+        // Keep track of any indexes along the way to last index
+        if (i != final_index) {
+          chain.push_back(composite);
+        }
+        components = composite->AsCompositeConstant()->GetComponents();
       }
       const uint32_t index = inst->GetSingleWordInOperand(i);
-      components = composite->AsCompositeConstant()->GetComponents();
-      type = composite->AsCompositeConstant()->type();
       composite = components[index];
     }
 
     // Final index in hierarchy is inserted with new object.
-    const uint32_t final_index =
-        inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
+    const uint32_t final_operand = inst->GetSingleWordInOperand(final_index);
     std::vector<uint32_t> ids;
     for (size_t i = 0; i < components.size(); i++) {
       const analysis::Constant* constant =
-          (i == final_index) ? object : components[i];
+          (i == final_operand) ? object : components[i];
       Instruction* member_inst = const_mgr->GetDefiningInstruction(constant);
       ids.push_back(member_inst->result_id());
     }
@@ -171,19 +177,16 @@
     for (size_t i = chain.size(); i > 0; i--) {
       // Need to insert any previous instruction into the module first.
       // Can't just insert in types_values_begin() because it will move above
-      // where the types are declared
-      for (Module::inst_iterator inst_iter = context->types_values_begin();
-           inst_iter != context->types_values_end(); ++inst_iter) {
-        Instruction* x = &*inst_iter;
-        if (inst->result_id() == x->result_id()) {
-          const_mgr->BuildInstructionAndAddToModule(new_constant, &inst_iter);
-          break;
-        }
-      }
+      // where the types are declared.
+      // Can't compare with location of inst because not all new added
+      // instructions are added to types_values_
+      auto iter = context->types_values_end();
+      Module::inst_iterator* pos = &iter;
+      const_mgr->BuildInstructionAndAddToModule(new_constant, pos);
 
       composite = chain[i - 1];
       components = composite->AsCompositeConstant()->GetComponents();
-      type = composite->AsCompositeConstant()->type();
+      type = composite->type();
       ids.clear();
       for (size_t k = 0; k < components.size(); k++) {
         const uint32_t index =
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index 9930b44..d70e27b 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -391,6 +391,43 @@
   return cst ? RegisterConstant(std::move(cst)) : nullptr;
 }
 
+const Constant* ConstantManager::GetNullCompositeConstant(const Type* type) {
+  std::vector<uint32_t> literal_words_or_id;
+
+  if (type->AsVector()) {
+    const Type* element_type = type->AsVector()->element_type();
+    const uint32_t null_id = GetNullConstId(element_type);
+    const uint32_t element_count = type->AsVector()->element_count();
+    for (uint32_t i = 0; i < element_count; i++) {
+      literal_words_or_id.push_back(null_id);
+    }
+  } else if (type->AsMatrix()) {
+    const Type* element_type = type->AsMatrix()->element_type();
+    const uint32_t null_id = GetNullConstId(element_type);
+    const uint32_t element_count = type->AsMatrix()->element_count();
+    for (uint32_t i = 0; i < element_count; i++) {
+      literal_words_or_id.push_back(null_id);
+    }
+  } else if (type->AsStruct()) {
+    // TODO (sfricke-lunarg) add proper struct support
+    return nullptr;
+  } else if (type->AsArray()) {
+    const Type* element_type = type->AsArray()->element_type();
+    const uint32_t null_id = GetNullConstId(element_type);
+    assert(type->AsArray()->length_info().words[0] ==
+               analysis::Array::LengthInfo::kConstant &&
+           "unexpected array length");
+    const uint32_t element_count = type->AsArray()->length_info().words[0];
+    for (uint32_t i = 0; i < element_count; i++) {
+      literal_words_or_id.push_back(null_id);
+    }
+  } else {
+    return nullptr;
+  }
+
+  return GetConstant(type, literal_words_or_id);
+}
+
 const Constant* ConstantManager::GetNumericVectorConstantWithWords(
     const Vector* type, const std::vector<uint32_t>& literal_words) {
   const auto* element_type = type->element_type();
@@ -445,18 +482,23 @@
   return c;
 }
 
-uint32_t ConstantManager::GetSIntConst(int32_t val) {
+uint32_t ConstantManager::GetSIntConstId(int32_t val) {
   Type* sint_type = context()->get_type_mgr()->GetSIntType();
   const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)});
   return GetDefiningInstruction(c)->result_id();
 }
 
-uint32_t ConstantManager::GetUIntConst(uint32_t val) {
+uint32_t ConstantManager::GetUIntConstId(uint32_t val) {
   Type* uint_type = context()->get_type_mgr()->GetUIntType();
   const Constant* c = GetConstant(uint_type, {val});
   return GetDefiningInstruction(c)->result_id();
 }
 
+uint32_t ConstantManager::GetNullConstId(const Type* type) {
+  const Constant* c = GetConstant(type, {});
+  return GetDefiningInstruction(c)->result_id();
+}
+
 std::vector<const analysis::Constant*> Constant::GetVectorComponents(
     analysis::ConstantManager* const_mgr) const {
   std::vector<const analysis::Constant*> components;
diff --git a/source/opt/constants.h b/source/opt/constants.h
index 588ca3e..410304e 100644
--- a/source/opt/constants.h
+++ b/source/opt/constants.h
@@ -520,6 +520,14 @@
                                                    literal_words_or_ids.end()));
   }
 
+  // Takes a type and creates a OpConstantComposite
+  // This allows a
+  // OpConstantNull %composite_type
+  // to become a
+  // OpConstantComposite %composite_type %null %null ... etc
+  // Assumes type is a Composite already, otherwise returns null
+  const Constant* GetNullCompositeConstant(const Type* type);
+
   // Gets or creates a unique Constant instance of Vector type |type| with
   // numeric elements and a vector of constant defining words |literal_words|.
   // If a Constant instance existed already in the constant pool, it returns a
@@ -649,10 +657,13 @@
   const Constant* GetDoubleConst(double val);
 
   // Returns the id of a 32-bit signed integer constant with value |val|.
-  uint32_t GetSIntConst(int32_t val);
+  uint32_t GetSIntConstId(int32_t val);
 
   // Returns the id of a 32-bit unsigned integer constant with value |val|.
-  uint32_t GetUIntConst(uint32_t val);
+  uint32_t GetUIntConstId(uint32_t val);
+
+  // Returns the id of a OpConstantNull with type of |type|.
+  uint32_t GetNullConstId(const Type* type);
 
  private:
   // Creates a Constant instance with the given type and a vector of constant
diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp
index 0ec392f..1e614c6 100644
--- a/source/opt/debug_info_manager.cpp
+++ b/source/opt/debug_info_manager.cpp
@@ -235,7 +235,8 @@
           !context()->AreAnalysesValid(IRContext::Analysis::kAnalysisConstants))
         line_number = AddNewConstInGlobals(context(), line_number);
       else
-        line_number = context()->get_constant_mgr()->GetUIntConst(line_number);
+        line_number =
+            context()->get_constant_mgr()->GetUIntConstId(line_number);
     }
   }
 
@@ -344,7 +345,7 @@
              {static_cast<uint32_t>(OpenCLDebugInfo100Deref)}},
         }));
   } else {
-    uint32_t deref_id = context()->get_constant_mgr()->GetUIntConst(
+    uint32_t deref_id = context()->get_constant_mgr()->GetUIntConstId(
         NonSemanticShaderDebugInfo100Deref);
 
     deref_operation = std::unique_ptr<Instruction>(
diff --git a/source/opt/eliminate_dead_io_components_pass.cpp b/source/opt/eliminate_dead_io_components_pass.cpp
index df59645..e430c6d 100644
--- a/source/opt/eliminate_dead_io_components_pass.cpp
+++ b/source/opt/eliminate_dead_io_components_pass.cpp
@@ -197,7 +197,7 @@
       type_mgr->GetType(arr_var.type_id())->AsPointer();
   const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray();
   assert(arr_ty && "expecting array type");
-  uint32_t length_id = const_mgr->GetUIntConst(length);
+  uint32_t length_id = const_mgr->GetUIntConstId(length);
   analysis::Array new_arr_ty(arr_ty->element_type(),
                              arr_ty->GetConstantLengthInfo(length_id, length));
   analysis::Type* reg_new_arr_ty = type_mgr->GetRegisteredType(&new_arr_ty);
diff --git a/source/opt/interface_var_sroa.cpp b/source/opt/interface_var_sroa.cpp
index 8205c75..08477cb 100644
--- a/source/opt/interface_var_sroa.cpp
+++ b/source/opt/interface_var_sroa.cpp
@@ -489,7 +489,7 @@
     Instruction* insert_before) {
   uint32_t ptr_type_id =
       GetPointerType(component_type_id, GetStorageClass(var));
-  uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
+  uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(index);
   std::unique_ptr<Instruction> new_access_chain(new Instruction(
       context(), spv::Op::OpAccessChain, ptr_type_id, TakeNextId(),
       std::initializer_list<Operand>{
@@ -781,7 +781,7 @@
     uint32_t elem_type_id, uint32_t array_length) {
   analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
   uint32_t array_length_id =
-      context()->get_constant_mgr()->GetUIntConst(array_length);
+      context()->get_constant_mgr()->GetUIntConstId(array_length);
   analysis::Array array_type(
       elem_type,
       analysis::Array::LengthInfo{array_length_id, {0, array_length}});
diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp
index 93c77d3..59745e1 100644
--- a/source/opt/replace_desc_array_access_using_var_index.cpp
+++ b/source/opt/replace_desc_array_access_using_var_index.cpp
@@ -331,7 +331,7 @@
 void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain(
     Instruction* access_chain, uint32_t const_element_idx) const {
   uint32_t const_element_idx_id =
-      context()->get_constant_mgr()->GetUIntConst(const_element_idx);
+      context()->get_constant_mgr()->GetUIntConstId(const_element_idx);
   access_chain->SetInOperand(kOpAccessChainInOperandIndexes,
                              {const_element_idx_id});
 }
diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp
index 6045158..bfebb01 100644
--- a/source/opt/scalar_replacement_pass.cpp
+++ b/source/opt/scalar_replacement_pass.cpp
@@ -191,7 +191,7 @@
     if (added_dbg_value == nullptr) return false;
     added_dbg_value->AddOperand(
         {SPV_OPERAND_TYPE_ID,
-         {context()->get_constant_mgr()->GetSIntConst(idx)}});
+         {context()->get_constant_mgr()->GetSIntConstId(idx)}});
     added_dbg_value->SetOperand(kDebugValueOperandExpressionIndex,
                                 {deref_expr->result_id()});
     if (context()->AreAnalysesValid(IRContext::Analysis::kAnalysisDefUse)) {
@@ -217,7 +217,7 @@
     // Append 'Indexes' operand.
     new_dbg_value->AddOperand(
         {SPV_OPERAND_TYPE_ID,
-         {context()->get_constant_mgr()->GetSIntConst(idx)}});
+         {context()->get_constant_mgr()->GetSIntConstId(idx)}});
     // Insert the new DebugValue to the basic block.
     auto* added_instr = dbg_value->InsertBefore(std::move(new_dbg_value));
     get_def_use_mgr()->AnalyzeInstDefUse(added_instr);
diff --git a/test/opt/fold_spec_const_op_composite_test.cpp b/test/opt/fold_spec_const_op_composite_test.cpp
index aae9eb2..f83e86e 100644
--- a/test/opt/fold_spec_const_op_composite_test.cpp
+++ b/test/opt/fold_spec_const_op_composite_test.cpp
@@ -340,6 +340,41 @@
   SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
 }
 
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertVectorIntoMatrix) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v2float = OpTypeVector %float 2
+ %mat2v2float = OpTypeMatrix %v2float 2
+    %float_0 = OpConstant %float 0
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+ %v2float_01 = OpConstantComposite %v2float %float_0 %float_1
+ %v2float_12 = OpConstantComposite %v2float %float_1 %float_2
+
+; CHECK: %10 = OpConstantComposite %v2float %float_0 %float_1
+; CHECK: %11 = OpConstantComposite %v2float %float_1 %float_2
+; CHECK: %12 = OpConstantComposite %mat2v2float %11 %11
+%mat2v2float_1212 = OpConstantComposite %mat2v2float %v2float_12 %v2float_12
+
+; CHECK: %15 = OpConstantComposite %mat2v2float %10 %11
+     %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_01 %mat2v2float_1212 0
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
 TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrix) {
   const std::string test =
       R"(
@@ -374,7 +409,7 @@
   SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
 }
 
-TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertNull) {
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertFloatNull) {
   const std::string test =
       R"(
                OpCapability Shader
@@ -384,16 +419,254 @@
        %void = OpTypeVoid
           %3 = OpTypeFunction %void
       %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_1 = OpConstant %float 1
+
+; CHECK: %7 = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %7
+; CHECK: %12 = OpConstantComposite %v3float %7 %7 %float_1
+       %null = OpConstantNull %float
+     %spec_0 = OpConstantComposite %v3float %null %null %null
+     %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 2
+
+; CHECK: %float_1_0 = OpConstant %float 1
+     %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertFloatSetNull) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_1 = OpConstant %float 1
+
+; CHECK: %7 = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1
+; CHECK: %12 = OpConstantComposite %v3float %7 %7 %7
+       %null = OpConstantNull %float
+     %spec_0 = OpConstantComposite %v3float %null %null %float_1
+     %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2
+
+; CHECK: %13 = OpConstantNull %float
+     %spec_2 = OpSpecConstantOp %float CompositeExtract %spec_1 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertVectorNull) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_1 = OpConstant %float 1
+       %null = OpConstantNull %v3float
+
+; CHECK: %11 = OpConstantNull %float
+; CHECK: %12 = OpConstantComposite %v3float %11 %11 %float_1
+     %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2
+
+
+; CHECK: %float_1_0 = OpConstant %float 1
+     %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertNullVectorIntoMatrix) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v2float = OpTypeVector %float 2
+ %mat2v2float = OpTypeMatrix %v2float 2
+       %null = OpConstantNull %mat2v2float
+    %float_1 = OpConstant %float 1
+    %float_2 = OpConstant %float 2
+ %v2float_12 = OpConstantComposite %v2float %float_1 %float_2
+
+; CHECK: %13 = OpConstantNull %v2float
+; CHECK: %14 = OpConstantComposite %mat2v2float %10 %13
+     %spec_0 = OpSpecConstantOp %mat2v2float CompositeInsert %v2float_12 %null 0
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertVectorKeepNull) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_0 = OpConstant %float 0
+ %null_float = OpConstantNull %float
+   %null_vec = OpConstantNull %v3float
+
+; CHECK: %15 = OpConstantComposite %v3float %7 %7 %float_0
+     %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_0 %null_vec 2
+
+; CHECK: %float_0_0 = OpConstant %float 0
+     %spec_1 = OpSpecConstantOp %float CompositeExtract %spec_0 2
+
+; CHECK: %17 = OpConstantComposite %v3float %7 %7 %7
+     %spec_2 = OpSpecConstantOp %v3float CompositeInsert %null_float %null_vec 2
+
+; CHECK: %18 = OpConstantNull %float
+     %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+        %add = OpFAdd %float %spec_3 %spec_3
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertVectorChainNull) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_1 = OpConstant %float 1
+       %null = OpConstantNull %v3float
+
+; CHECK: %15 = OpConstantNull %float
+; CHECK: %16 = OpConstantComposite %v3float %15 %15 %float_1
+; CHECK: %17 = OpConstantComposite %v3float %15 %float_1 %float_1
+; CHECK: %18 = OpConstantComposite %v3float %float_1 %float_1 %float_1
+     %spec_0 = OpSpecConstantOp %v3float CompositeInsert %float_1 %null 2
+     %spec_1 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_0 1
+     %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 0
+
+; CHECK: %float_1_0 = OpConstant %float 1
+; CHECK: %float_1_1 = OpConstant %float 1
+; CHECK: %float_1_2 = OpConstant %float 1
+     %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 0
+     %spec_4 = OpSpecConstantOp %float CompositeExtract %spec_2 1
+     %spec_5 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest,
+       CompositeInsertVectorChainReset) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %1 "main"
+               OpExecutionMode %1 LocalSize 1 1 1
+       %void = OpTypeVoid
+          %3 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+    %float_1 = OpConstant %float 1
+       %null = OpConstantNull %float
+; CHECK: %8 = OpConstantComposite %v3float %7 %7 %float_1
+     %spec_0 = OpConstantComposite %v3float %null %null %float_1
+
+            ; set to null
+; CHECK: %13 = OpConstantComposite %v3float %7 %7 %7
+     %spec_1 = OpSpecConstantOp %v3float CompositeInsert %null %spec_0 2
+
+            ; set to back to original value
+; CHECK: %14 = OpConstantComposite %v3float %7 %7 %float_1
+     %spec_2 = OpSpecConstantOp %v3float CompositeInsert %float_1 %spec_1 2
+
+; CHECK: %float_1_0 = OpConstant %float 1
+     %spec_3 = OpSpecConstantOp %float CompositeExtract %spec_2 2
+          %1 = OpFunction %void None %3
+      %label = OpLabel
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FoldSpecConstantOpAndCompositePass>(test, false);
+}
+
+TEST_F(FoldSpecConstantOpAndCompositePassBasicTest, CompositeInsertMatrixNull) {
+  const std::string test =
+      R"(
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+      %float = OpTypeFloat 32
+        %int = OpTypeInt 32 0
 %v2float = OpTypeVector %float 2
 %mat2v2float = OpTypeMatrix %v2float 2
 %null = OpConstantNull %mat2v2float
     %float_1 = OpConstant %float 1
-  %v2float_1 = OpConstantComposite %v2float %float_1 %float_1
-   %mat2v2_1 = OpConstantComposite %mat2v2float %v2float_1 %v2float_1
- ; CHECK: %13 = OpConstantNull %mat2v2float
-         %14 = OpSpecConstantOp %mat2v2float CompositeInsert %mat2v2_1 %null 0 0
-          %1 = OpFunction %void None %3
-         %16 = OpLabel
+ ; CHECK: %13 = OpConstantNull %v2float
+ ; CHECK: %14 = OpConstantNull %float
+ ; CHECK: %15 = OpConstantComposite %v2float %float_1 %14
+ ; CHECK: %16 = OpConstantComposite %mat2v2float %13 %15
+       %spec = OpSpecConstantOp %mat2v2float CompositeInsert %float_1 %null 1 0
+; extra type def to make sure new type def are not just thrown at end
+      %v2int = OpTypeVector %int 2
+       %main = OpFunction %void None %func
+      %label = OpLabel
                OpReturn
                OpFunctionEnd
 )";