Add validation for SPV_NV_cooperative_matrix (#2404)

diff --git a/DEPS b/DEPS
index 5668c66..3c5361f 100644
--- a/DEPS
+++ b/DEPS
@@ -11,7 +11,7 @@
   'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036',
   'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59',
   're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f',
-  'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f',
+  'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab',
 }
 
 deps = {
diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp
index 4d98e3d..79f18ee 100644
--- a/source/assembly_grammar.cpp
+++ b/source/assembly_grammar.cpp
@@ -154,10 +154,11 @@
     CASE(InBoundsAccessChain),
     CASE(PtrAccessChain),
     CASE(InBoundsPtrAccessChain),
+    CASE(CooperativeMatrixLengthNV)
 };
 
-// The 59 is determined by counting the opcodes listed in the spec.
-static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+// The 60 is determined by counting the opcodes listed in the spec.
+static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
               "OpSpecConstantOp opcode table is incomplete");
 #undef CASE
 // clang-format on
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 78c2386..da096a4 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -260,6 +260,7 @@
     case SpvOpTypeMatrix:
     case SpvOpTypeArray:
     case SpvOpTypeStruct:
+    case SpvOpTypeCooperativeMatrixNV:
       return true;
     default:
       return false;
@@ -325,6 +326,7 @@
     case SpvOpTypePipeStorage:
     case SpvOpTypeNamedBarrier:
     case SpvOpTypeAccelerationStructureNV:
+    case SpvOpTypeCooperativeMatrixNV:
       return true;
     default:
       // In particular, OpTypeForwardPointer does not generate a type,
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index 79d90bd..8106442 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -45,7 +45,8 @@
 inline bool IsTypeInst(SpvOp opcode) {
   return (opcode >= SpvOpTypeVoid && opcode <= SpvOpTypeForwardPointer) ||
          opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier ||
-         opcode == SpvOpTypeAccelerationStructureNV;
+         opcode == SpvOpTypeAccelerationStructureNV ||
+         opcode == SpvOpTypeCooperativeMatrixNV;
 }
 inline bool IsConstantInst(SpvOp opcode) {
   return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp;
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index 2314e7d..433330d 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -39,8 +39,11 @@
     case SpvOpFRem:
     case SpvOpFMod:
     case SpvOpFNegate: {
+      bool supportsCoopMat =
+          (opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
       if (!_.IsFloatScalarType(result_type) &&
-          !_.IsFloatVectorType(result_type))
+          !_.IsFloatVectorType(result_type) &&
+          !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected floating scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
@@ -58,8 +61,11 @@
 
     case SpvOpUDiv:
     case SpvOpUMod: {
+      bool supportsCoopMat = (opcode == SpvOpUDiv);
       if (!_.IsUnsignedIntScalarType(result_type) &&
-          !_.IsUnsignedIntVectorType(result_type))
+          !_.IsUnsignedIntVectorType(result_type) &&
+          !(supportsCoopMat &&
+            _.IsUnsignedIntCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected unsigned int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
@@ -82,7 +88,10 @@
     case SpvOpSMod:
     case SpvOpSRem:
     case SpvOpSNegate: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+      bool supportsCoopMat =
+          (opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
+      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+          !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
@@ -94,7 +103,8 @@
            ++operand_index) {
         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
         if (!type_id ||
-            (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
+            (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
+             !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected int scalar or vector type as operand: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -176,7 +186,8 @@
     }
 
     case SpvOpMatrixTimesScalar: {
-      if (!_.IsFloatMatrixType(result_type))
+      if (!_.IsFloatMatrixType(result_type) &&
+          !_.IsCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected float matrix type as Result Type: "
                << spvOpcodeString(opcode);
@@ -442,6 +453,92 @@
       break;
     }
 
+    case SpvOpCooperativeMatrixMulAddNV: {
+      const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+      const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+      const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+      if (!_.IsCooperativeMatrixType(A_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as A Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixType(B_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as B Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixType(C_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as C Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixType(D_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+      }
+
+      const auto A = _.FindDef(A_type_id);
+      const auto B = _.FindDef(B_type_id);
+      const auto C = _.FindDef(C_type_id);
+      const auto D = _.FindDef(D_type_id);
+
+      std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+          A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+      A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+      B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+      C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+      D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+      A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+      B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+      C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+      D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+      A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+      B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+      C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+      D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+      const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+                               std::tuple<bool, bool, uint32_t> Y) {
+        return (std::get<1>(X) && std::get<1>(Y) &&
+                std::get<2>(X) != std::get<2>(Y));
+      };
+
+      if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+          notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+          notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix scopes must match: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+          notEqual(C_rows, D_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'M' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+          notEqual(C_cols, D_cols)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'N' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_cols, B_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'K' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+      break;
+    }
+
     default:
       break;
   }
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index ccc5587..de3210e 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -118,6 +118,10 @@
         *member_type = type_inst->word(component_index + 2);
         break;
       }
+      case SpvOpTypeCooperativeMatrixNV: {
+        *member_type = type_inst->word(2);
+        break;
+      }
       default:
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Reached non-composite type while indexes still remain to "
@@ -315,6 +319,26 @@
 
       break;
     }
+    case SpvOpTypeCooperativeMatrixNV: {
+      const auto result_type_inst = _.FindDef(result_type);
+      assert(result_type_inst);
+      const auto component_type_id =
+          result_type_inst->GetOperandAs<uint32_t>(1);
+
+      if (3 != num_operands) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected single constituent";
+      }
+
+      const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+      if (operand_type_id != component_type_id) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Constituent type to be equal to the component type";
+      }
+
+      break;
+    }
     default: {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Expected Result Type to be a composite type";
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index e2f20f6..c413b4f 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -247,6 +247,36 @@
         }
       }
     } break;
+    case SpvOpTypeCooperativeMatrixNV: {
+      if (1 != constituent_count) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << opcode_name << " Constituent <id> '"
+               << _.getIdName(inst->type_id()) << "' count must be one.";
+      }
+      const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
+      const auto constituent = _.FindDef(constituent_id);
+      if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << opcode_name << " Constituent <id> '"
+               << _.getIdName(constituent_id)
+               << "' is not a constant or undef.";
+      }
+      const auto constituent_type = _.FindDef(constituent->type_id());
+      if (!constituent_type) {
+        return _.diag(SPV_ERROR_INVALID_ID, constituent)
+               << "Result type is not defined.";
+      }
+
+      const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
+      const auto component_type = _.FindDef(component_type_id);
+      if (!component_type || component_type->id() != constituent_type->id()) {
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
+               << opcode_name << " Constituent <id> '"
+               << _.getIdName(constituent_id)
+               << "' type does not match the Result Type <id> '"
+               << _.getIdName(result_type->id()) << "'s component type.";
+      }
+    } break;
     default:
       break;
   }
@@ -285,6 +315,7 @@
       return true;
     case SpvOpTypeArray:
     case SpvOpTypeMatrix:
+    case SpvOpTypeCooperativeMatrixNV:
     case SpvOpTypeVector: {
       auto base_type = _.FindDef(instruction[2]);
       return base_type && IsTypeNullable(base_type->words(), _);
@@ -320,7 +351,7 @@
 
   // The binary parser already ensures that the op is valid for *some*
   // environment.  Here we check restrictions.
-  switch(op) {
+  switch (op) {
     case SpvOpQuantizeToF16:
       if (!_.HasCapability(SpvCapabilityShader)) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -365,7 +396,7 @@
       }
       break;
 
-  default:
+    default:
       break;
   }
 
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index 73da582..17af9f4 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -32,22 +32,31 @@
   switch (opcode) {
     case SpvOpConvertFToU: {
       if (!_.IsUnsignedIntScalarType(result_type) &&
-          !_.IsUnsignedIntVectorType(result_type))
+          !_.IsUnsignedIntVectorType(result_type) &&
+          !_.IsUnsignedIntCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected unsigned int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type || (!_.IsFloatScalarType(input_type) &&
-                          !_.IsFloatVectorType(input_type)))
+                          !_.IsFloatVectorType(input_type) &&
+                          !_.IsFloatCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be float scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -58,22 +67,31 @@
     }
 
     case SpvOpConvertFToS: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+          !_.IsIntCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type || (!_.IsFloatScalarType(input_type) &&
-                          !_.IsFloatVectorType(input_type)))
+                          !_.IsFloatVectorType(input_type) &&
+                          !_.IsFloatCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be float scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -86,22 +104,31 @@
     case SpvOpConvertSToF:
     case SpvOpConvertUToF: {
       if (!_.IsFloatScalarType(result_type) &&
-          !_.IsFloatVectorType(result_type))
+          !_.IsFloatVectorType(result_type) &&
+          !_.IsFloatCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected float scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type ||
-          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+           !_.IsIntCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be int scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -113,22 +140,31 @@
 
     case SpvOpUConvert: {
       if (!_.IsUnsignedIntScalarType(result_type) &&
-          !_.IsUnsignedIntVectorType(result_type))
+          !_.IsUnsignedIntVectorType(result_type) &&
+          !_.IsUnsignedIntCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected unsigned int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type ||
-          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+           !_.IsIntCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be int scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -139,22 +175,31 @@
     }
 
     case SpvOpSConvert: {
-      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+      if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+          !_.IsIntCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type ||
-          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+          (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+           !_.IsIntCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be int scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -166,22 +211,31 @@
 
     case SpvOpFConvert: {
       if (!_.IsFloatScalarType(result_type) &&
-          !_.IsFloatVectorType(result_type))
+          !_.IsFloatVectorType(result_type) &&
+          !_.IsFloatCooperativeMatrixType(result_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected float scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       const uint32_t input_type = _.GetOperandTypeId(inst, 2);
       if (!input_type || (!_.IsFloatScalarType(input_type) &&
-                          !_.IsFloatVectorType(input_type)))
+                          !_.IsFloatVectorType(input_type) &&
+                          !_.IsFloatCooperativeMatrixType(input_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be float scalar or vector: "
                << spvOpcodeString(opcode);
 
-      if (_.GetDimension(result_type) != _.GetDimension(input_type))
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Expected input to have the same dimension as Result Type: "
-               << spvOpcodeString(opcode);
+      if (_.IsCooperativeMatrixType(result_type) ||
+          _.IsCooperativeMatrixType(input_type)) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      } else {
+        if (_.GetDimension(result_type) != _.GetDimension(input_type))
+          return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                 << "Expected input to have the same dimension as Result Type: "
+                 << spvOpcodeString(opcode);
+      }
 
       if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index 21a0411..cb18e13 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -167,7 +167,10 @@
           const auto opcode = inst->opcode();
           if (spvOpcodeGeneratesType(def->opcode()) &&
               !spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
-              !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) {
+              !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction &&
+              opcode != SpvOpCooperativeMatrixLengthNV &&
+              !(opcode == SpvOpSpecConstantOp &&
+                inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " cannot be a type";
@@ -177,7 +180,10 @@
                      !spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi &&
                      opcode != SpvOpExtInst && opcode != SpvOpExtInstImport &&
                      opcode != SpvOpSelectionMerge &&
-                     opcode != SpvOpLoopMerge && opcode != SpvOpFunction) {
+                     opcode != SpvOpLoopMerge && opcode != SpvOpFunction &&
+                     opcode != SpvOpCooperativeMatrixLengthNV &&
+                     !(opcode == SpvOpSpecConstantOp &&
+                       inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " requires a type";
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 9e93cf1..f6127a1 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -197,17 +197,49 @@
   return false;
 }
 
+bool ContainsCooperativeMatrix(ValidationState_t& _,
+                               const Instruction* storage) {
+  const size_t elem_type_index = 1;
+  uint32_t elem_type_id;
+  Instruction* elem_type;
+
+  switch (storage->opcode()) {
+    case SpvOpTypeCooperativeMatrixNV:
+      return true;
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray:
+      elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
+      elem_type = _.FindDef(elem_type_id);
+      return ContainsCooperativeMatrix(_, elem_type);
+    case SpvOpTypeStruct:
+      for (size_t member_type_index = 1;
+           member_type_index < storage->operands().size();
+           ++member_type_index) {
+        auto member_type_id =
+            storage->GetOperandAs<uint32_t>(member_type_index);
+        auto member_type = _.FindDef(member_type_id);
+        if (ContainsCooperativeMatrix(_, member_type)) return true;
+      }
+      break;
+    default:
+      break;
+  }
+  return false;
+}
+
 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
     ValidationState_t& _, const Instruction* inst) {
   SpvStorageClass dst_sc = SpvStorageClassMax;
   SpvStorageClass src_sc = SpvStorageClassMax;
   switch (inst->opcode()) {
+    case SpvOpCooperativeMatrixLoadNV:
     case SpvOpLoad: {
       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
       auto load_pointer_type = _.FindDef(load_pointer->type_id());
       dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
       break;
     }
+    case SpvOpCooperativeMatrixStoreNV:
     case SpvOpStore: {
       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
       auto store_pointer_type = _.FindDef(store_pointer->type_id());
@@ -232,7 +264,8 @@
 }
 
 // This function is only called for OpLoad, OpStore, OpCopyMemory and
-// OpCopyMemorySized.
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
+// OpCooperativeMatrixStoreNV.
 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
   uint32_t offset = 1;
   if (mask & SpvMemoryAccessAlignedMask) ++offset;
@@ -245,6 +278,10 @@
     case SpvOpStore:
     case SpvOpCopyMemory:
       return inst->GetOperandAs<uint32_t>(2 + offset);
+    case SpvOpCooperativeMatrixLoadNV:
+      return inst->GetOperandAs<uint32_t>(5 + offset);
+    case SpvOpCooperativeMatrixStoreNV:
+      return inst->GetOperandAs<uint32_t>(4 + offset);
     default:
       assert(false && "unexpected opcode");
       break;
@@ -253,8 +290,9 @@
   return scope_id;
 }
 
-// This function is only called for OpLoad, OpStore, OpCopyMemory and
-// OpCopyMemorySized.
+// This function is only called for OpLoad, OpStore, OpCopyMemory,
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
+// OpCooperativeMatrixStoreNV.
 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
   uint32_t offset = 1;
   if (mask & SpvMemoryAccessAlignedMask) ++offset;
@@ -268,6 +306,10 @@
     case SpvOpStore:
     case SpvOpCopyMemory:
       return inst->GetOperandAs<uint32_t>(2 + offset);
+    case SpvOpCooperativeMatrixLoadNV:
+      return inst->GetOperandAs<uint32_t>(5 + offset);
+    case SpvOpCooperativeMatrixStoreNV:
+      return inst->GetOperandAs<uint32_t>(4 + offset);
     default:
       assert(false && "unexpected opcode");
       break;
@@ -302,7 +344,8 @@
 
   uint32_t mask = inst->GetOperandAs<uint32_t>(index);
   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
-    if (inst->opcode() == SpvOpLoad) {
+    if (inst->opcode() == SpvOpLoad ||
+        inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "MakePointerAvailableKHR cannot be used with OpLoad.";
     }
@@ -320,7 +363,8 @@
   }
 
   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
-    if (inst->opcode() == SpvOpStore) {
+    if (inst->opcode() == SpvOpStore ||
+        inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "MakePointerVisibleKHR cannot be used with OpStore.";
     }
@@ -672,6 +716,17 @@
     }
   }
 
+  // Cooperative matrix types can only be allocated in Function or Private
+  if ((storage_class != SpvStorageClassFunction &&
+       storage_class != SpvStorageClassPrivate) &&
+      ContainsCooperativeMatrix(_, pointee)) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Cooperative matrix types (or types containing them) can only be "
+              "allocated "
+           << "in Function or Private storage classes or as function "
+              "parameters";
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -1003,10 +1058,11 @@
     switch (type_pointee->opcode()) {
       case SpvOpTypeMatrix:
       case SpvOpTypeVector:
+      case SpvOpTypeCooperativeMatrixNV:
       case SpvOpTypeArray:
       case SpvOpTypeRuntimeArray: {
-        // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
-        // word 2 is the Element Type.
+        // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
+        // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
         type_pointee = _.FindDef(type_pointee->word(2));
         break;
       }
@@ -1136,6 +1192,140 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
+                                               const Instruction* inst) {
+  std::string instr_name =
+      "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
+
+  // Result type must be a 32-bit unsigned int.
+  auto result_type = state.FindDef(inst->type_id());
+  if (result_type->opcode() != SpvOpTypeInt ||
+      result_type->GetOperandAs<uint32_t>(1) != 32 ||
+      result_type->GetOperandAs<uint32_t>(2) != 0) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The Result Type of " << instr_name << " <id> '"
+           << state.getIdName(inst->id())
+           << "' must be OpTypeInt with width 32 and signedness 0.";
+  }
+
+  auto type_id = inst->GetOperandAs<uint32_t>(2);
+  auto type = state.FindDef(type_id);
+  if (type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The type in " << instr_name << " <id> '"
+           << state.getIdName(type_id)
+           << "' must be OpTypeCooperativeMatrixNV.";
+  }
+  return SPV_SUCCESS;
+}
+
+spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
+                                                  const Instruction* inst) {
+  uint32_t type_id;
+  const char* opname;
+  if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
+    type_id = inst->type_id();
+    opname = "SpvOpCooperativeMatrixLoadNV";
+  } else {
+    // get Object operand's type
+    type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
+    opname = "SpvOpCooperativeMatrixStoreNV";
+  }
+
+  auto matrix_type = _.FindDef(type_id);
+
+  if (matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    if (inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "SpvOpCooperativeMatrixLoadNV Result Type <id> '"
+             << _.getIdName(type_id) << "' is not a cooperative matrix type.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "SpvOpCooperativeMatrixStoreNV Object type <id> '"
+             << _.getIdName(type_id) << "' is not a cooperative matrix type.";
+    }
+  }
+
+  const bool uses_variable_pointers =
+      _.features().variable_pointers ||
+      _.features().variable_pointers_storage_buffer;
+  const auto pointer_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == SpvAddressingModelLogical) &&
+       ((!uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> '" << _.getIdName(pointer_id)
+           << "' is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
+           << "' is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<uint32_t>(storage_class_index);
+
+  if (storage_class != SpvStorageClassWorkgroup &&
+      storage_class != SpvStorageClassStorageBuffer &&
+      storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type <id> '"
+           << _.getIdName(pointer_type_id)
+           << "' is not Workgroup or StorageBuffer.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> '" << _.getIdName(pointer->id())
+           << "'s Type must be a scalar or vector type.";
+  }
+
+  const auto stride_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 3u : 2u;
+  const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+  const auto stride = _.FindDef(stride_id);
+  if (!stride || !_.IsIntScalarType(stride->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Stride operand <id> '" << _.getIdName(stride_id)
+           << "' must be a scalar integer type.";
+  }
+
+  const auto colmajor_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 4u : 3u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Column Major operand <id> '" << _.getIdName(colmajor_id)
+           << "' must be a boolean constant instruction.";
+  }
+
+  const auto memory_access_index =
+      (inst->opcode() == SpvOpCooperativeMatrixLoadNV) ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
@@ -1164,6 +1354,14 @@
     case SpvOpArrayLength:
       if (auto error = ValidateArrayLength(_, inst)) return error;
       break;
+    case SpvOpCooperativeMatrixLoadNV:
+    case SpvOpCooperativeMatrixStoreNV:
+      if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
+        return error;
+      break;
+    case SpvOpCooperativeMatrixLengthNV:
+      if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
+      break;
     case SpvOpImageTexelPointer:
     case SpvOpGenericPtrMemSemantics:
     default:
diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp
index b640131..2223a77 100644
--- a/source/val/validate_scopes.cpp
+++ b/source/val/validate_scopes.cpp
@@ -36,11 +36,19 @@
   }
 
   if (!is_const_int32) {
-    if (_.HasCapability(SpvCapabilityShader)) {
+    if (_.HasCapability(SpvCapabilityShader) &&
+        !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Scope ids must be OpConstant when Shader capability is "
              << "present";
     }
+    if (_.HasCapability(SpvCapabilityShader) &&
+        _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+        !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Scope ids must be constant or specialization constant when "
+             << "CooperativeMatrixNV capability is present";
+    }
     return SPV_SUCCESS;
   }
 
@@ -130,11 +138,19 @@
   }
 
   if (!is_const_int32) {
-    if (_.HasCapability(SpvCapabilityShader)) {
+    if (_.HasCapability(SpvCapabilityShader) &&
+        !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << "Scope ids must be OpConstant when Shader capability is "
              << "present";
     }
+    if (_.HasCapability(SpvCapabilityShader) &&
+        _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+        !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Scope ids must be constant or specialization constant when "
+             << "CooperativeMatrixNV capability is present";
+    }
     return SPV_SUCCESS;
   }
 
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index a5428d7..ad72a37 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -381,6 +381,53 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
+                                             const Instruction* inst) {
+  const auto component_type_index = 1;
+  const auto component_type_id =
+      inst->GetOperandAs<uint32_t>(component_type_index);
+  const auto component_type = _.FindDef(component_type_id);
+  if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
+                          SpvOpTypeInt != component_type->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Component Type <id> '"
+           << _.getIdName(component_type_id)
+           << "' is not a scalar numerical type.";
+  }
+
+  const auto scope_index = 2;
+  const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
+  const auto scope = _.FindDef(scope_id);
+  if (!scope || !_.IsIntScalarType(scope->type_id()) ||
+      !spvOpcodeIsConstant(scope->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+
+  const auto rows_index = 3;
+  const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
+  const auto rows = _.FindDef(rows_id);
+  if (!rows || !_.IsIntScalarType(rows->type_id()) ||
+      !spvOpcodeIsConstant(rows->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+
+  const auto cols_index = 4;
+  const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
+  const auto cols = _.FindDef(cols_id);
+  if (!cols || !_.IsIntScalarType(cols->type_id()) ||
+      !spvOpcodeIsConstant(cols->opcode())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(rows_id)
+           << "' is not a constant instruction with scalar integer type.";
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@@ -416,6 +463,9 @@
     case SpvOpTypeForwardPointer:
       if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
       break;
+    case SpvOpTypeCooperativeMatrixNV:
+      if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+      break;
     default:
       break;
   }
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 2633963..e6e5e26 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -610,6 +610,9 @@
     case SpvOpTypeMatrix:
       return GetComponentType(inst->word(2));
 
+    case SpvOpTypeCooperativeMatrixNV:
+      return inst->word(2);
+
     default:
       break;
   }
@@ -634,6 +637,10 @@
     case SpvOpTypeMatrix:
       return inst->word(3);
 
+    case SpvOpTypeCooperativeMatrixNV:
+      // Actual dimension isn't known, return 0
+      return 0;
+
     default:
       break;
   }
@@ -862,6 +869,86 @@
   return true;
 }
 
+bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  assert(inst);
+  return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
+}
+
+bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
+  if (!IsCooperativeMatrixType(id)) return false;
+  return IsFloatScalarType(FindDef(id)->word(2));
+}
+
+bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
+  if (!IsCooperativeMatrixType(id)) return false;
+  return IsIntScalarType(FindDef(id)->word(2));
+}
+
+bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
+  if (!IsCooperativeMatrixType(id)) return false;
+  return IsUnsignedIntScalarType(FindDef(id)->word(2));
+}
+
+spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
+    const Instruction* inst, uint32_t m1, uint32_t m2) {
+  const auto m1_type = FindDef(m1);
+  const auto m2_type = FindDef(m2);
+
+  if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
+      m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    return diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected cooperative matrix types";
+  }
+
+  uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
+  uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
+  uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
+
+  uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
+  uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
+  uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
+
+  bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
+       m2_is_const_int32 = false;
+  uint32_t m1_value = 0, m2_value = 0;
+
+  std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+      EvalInt32IfConst(m1_scope_id);
+  std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+      EvalInt32IfConst(m2_scope_id);
+
+  if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+    return diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected scopes of Matrix and Result Type to be "
+           << "identical";
+  }
+
+  std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+      EvalInt32IfConst(m1_rows_id);
+  std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+      EvalInt32IfConst(m2_rows_id);
+
+  if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+    return diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected rows of Matrix type and Result Type to be "
+           << "identical";
+  }
+
+  std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+      EvalInt32IfConst(m1_cols_id);
+  std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+      EvalInt32IfConst(m2_cols_id);
+
+  if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+    return diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Expected columns of Matrix type and Result Type to be "
+           << "identical";
+  }
+
+  return SPV_SUCCESS;
+}
+
 uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
                                              size_t operand_index) const {
   return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
@@ -890,7 +977,7 @@
 }
 
 std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
-    uint32_t id) {
+    uint32_t id) const {
   const Instruction* const inst = FindDef(id);
   assert(inst);
   const uint32_t type = inst->type_id();
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 55005a6..94fa945 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -552,6 +552,10 @@
   bool IsBoolVectorType(uint32_t id) const;
   bool IsBoolScalarOrVectorType(uint32_t id) const;
   bool IsPointerType(uint32_t id) const;
+  bool IsCooperativeMatrixType(uint32_t id) const;
+  bool IsFloatCooperativeMatrixType(uint32_t id) const;
+  bool IsIntCooperativeMatrixType(uint32_t id) const;
+  bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
 
   // Gets value from OpConstant and OpSpecConstant as uint64.
   // Returns false on failure (no instruction, wrong instruction, not int).
@@ -635,7 +639,7 @@
   // Returns tuple <is_int32, is_const_int32, value>.
   // OpSpecConstant* return |is_const_int32| as false since their values cannot
   // be relied upon during validation.
-  std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id);
+  std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id) const;
 
   // Returns the disassembly string for the given instruction.
   std::string Disassemble(const Instruction& inst) const;
@@ -643,6 +647,12 @@
   // Returns the disassembly string for the given instruction.
   std::string Disassemble(const uint32_t* words, uint16_t num_words) const;
 
+  // Returns whether type m1 and type m2 are cooperative matrices with
+  // the same "shape" (matching scope, rows, cols). If any are specialization
+  // constants, we assume they can match because we can't prove they don't.
+  spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst,
+                                            uint32_t m1, uint32_t m2);
+
  private:
   ValidationState_t(const ValidationState_t&);
 
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 87e006c..b82fc97 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1165,6 +1165,150 @@
                 "vector size of the right operand: OuterProduct"));
 }
 
+std::string GenerateCoopMatCode(const std::string& extra_types,
+                                const std::string& main_body) {
+  const std::string prefix =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_16 = OpConstant %u32 16
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
+%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u32_1 = OpConstant %u32 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%u32_c1 = OpSpecConstant %u32 1
+%u32_c2 = OpSpecConstant %u32 2
+
+%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
+%f16matc_1 = OpConstantComposite %f16matc %f16_1
+
+%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
+%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
+%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
+%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
+%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
+%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";
+
+  const std::string func_begin =
+      R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+  const std::string suffix =
+      R"(
+OpReturn
+OpFunctionEnd)";
+
+  return prefix + extra_types + func_begin + main_body + suffix;
+}
+
+TEST_F(ValidateArithmetics, CoopMatSuccess) {
+  const std::string body = R"(
+%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
+%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
+%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
+%val4 = OpFNegate %f16mat %f16mat_1
+%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
+%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
+%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
+%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
+%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
+%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
+%val11 = OpSNegate %s32mat %s32mat_1
+%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
+%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
+%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
+%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
+%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, CoopMatFMulFail) {
+  const std::string body = R"(
+%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Expected floating scalar or vector type as Result Type: FMul"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {
+  const std::string body = R"(
+%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected scalar operand type to be equal to the component "
+                "type of the matrix operand: MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatScopeFail) {
+  const std::string types = R"(
+%workgroup = OpConstant %u32 2
+
+%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
+%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
+)";
+
+  const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatDimFail) {
+  const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
+}
+
 TEST_F(ValidateArithmetics, IAddCarrySuccess) {
   const std::string body = R"(
 %val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp
index bf7f15d..db6ff5b 100644
--- a/test/val/val_composites_test.cpp
+++ b/test/val/val_composites_test.cpp
@@ -1467,6 +1467,83 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%f32_1 = OpConstant %f32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpConstantComposite Constituent <id> '11[%float_1]' type does "
+                "not match the Result Type <id> '10[%10]'s component type."));
+}
+
+TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%f32_1 = OpConstant %f32 1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected Constituent type to be equal to the component type"));
+}
+
 TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
   const std::string spirv = R"(
 OpCapability Shader
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 5e4ad49..f905657 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -1184,6 +1184,172 @@
                 "GenericCastToPtrExplicit"));
 }
 
+TEST_F(ValidateConversion, CoopMatConversionSuccess) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
+%u16mat = OpTypeCooperativeMatrixNV %u16 %subgroup %u32_8 %u32_8
+%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
+%s16mat = OpTypeCooperativeMatrixNV %s16 %subgroup %u32_8 %u32_8
+%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u16_1 = OpConstant %u16 1
+%u32_1 = OpConstant %u32 1
+%s16_1 = OpConstant %s16 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%f32mat_1 = OpConstantComposite %f32mat %f32_1
+%u16mat_1 = OpConstantComposite %u16mat %u16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s16mat_1 = OpConstantComposite %s16mat %s16_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val11 = OpConvertFToU %u16mat %f16mat_1
+%val12 = OpConvertFToU %u32mat %f16mat_1
+%val13 = OpConvertFToS %s16mat %f16mat_1
+%val14 = OpConvertFToS %s32mat %f16mat_1
+%val15 = OpFConvert %f32mat %f16mat_1
+
+%val21 = OpConvertFToU %u16mat %f32mat_1
+%val22 = OpConvertFToU %u32mat %f32mat_1
+%val23 = OpConvertFToS %s16mat %f32mat_1
+%val24 = OpConvertFToS %s32mat %f32mat_1
+%val25 = OpFConvert %f16mat %f32mat_1
+
+%val31 = OpConvertUToF %f16mat %u16mat_1
+%val32 = OpConvertUToF %f32mat %u16mat_1
+%val33 = OpUConvert %u32mat %u16mat_1
+%val34 = OpSConvert %s32mat %u16mat_1
+
+%val41 = OpConvertSToF %f16mat %s16mat_1
+%val42 = OpConvertSToF %f32mat %s16mat_1
+%val43 = OpUConvert %u32mat %s16mat_1
+%val44 = OpSConvert %s32mat %s16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateConversion, CoopMatConversionShapesMismatchFail) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val15 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Expected rows of Matrix type and Result Type to be identical"));
+}
+
+TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpSpecConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val15 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 TEST_F(ValidateConversion, BitcastSuccess) {
   const std::string body = R"(
 %ptr = OpVariable %f32ptr_func Function
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index b567a7b..246b85e 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -1774,6 +1774,372 @@
       HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable"));
 }
 
+std::string GenCoopMatLoadStoreShader(const std::string& storeMemoryAccess,
+                                      const std::string& loadMemoryAccess) {
+  std::string s = R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpCapability VulkanMemoryModelKHR
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpExtension "SPV_NV_cooperative_matrix"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint GLCompute %4 "main" %11 %21
+OpExecutionMode %4 LocalSize 1 1 1
+OpDecorate %11 BuiltIn SubgroupId
+OpDecorate %21 BuiltIn WorkgroupId
+OpDecorate %74 ArrayStride 4
+OpMemberDecorate %75 0 Offset 0
+OpDecorate %75 Block
+OpDecorate %77 DescriptorSet 0
+OpDecorate %77 Binding 0
+OpDecorate %92 ArrayStride 4
+OpMemberDecorate %93 0 Offset 0
+OpDecorate %93 Block
+OpDecorate %95 DescriptorSet 0
+OpDecorate %95 Binding 1
+OpDecorate %102 ArrayStride 4
+OpMemberDecorate %103 0 Offset 0
+OpDecorate %103 Block
+OpDecorate %105 DescriptorSet 0
+OpDecorate %105 Binding 2
+OpDecorate %117 ArrayStride 4
+OpMemberDecorate %118 0 Offset 0
+OpDecorate %118 Block
+OpDecorate %120 DescriptorSet 0
+OpDecorate %120 Binding 3
+OpDecorate %123 SpecId 2
+OpDecorate %124 SpecId 3
+OpDecorate %125 SpecId 4
+OpDecorate %126 SpecId 5
+OpDecorate %127 SpecId 0
+OpDecorate %128 SpecId 1
+OpDecorate %129 BuiltIn WorkgroupSize
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 0
+%7 = OpTypeVector %6 2
+%8 = OpTypePointer Function %7
+%10 = OpTypePointer Input %6
+%11 = OpVariable %10 Input
+%13 = OpConstant %6 2
+%19 = OpTypeVector %6 3
+%20 = OpTypePointer Input %19
+%21 = OpVariable %20 Input
+%27 = OpConstantComposite %7 %13 %13
+%31 = OpTypePointer Function %6
+%33 = OpConstant %6 1024
+%34 = OpConstant %6 1
+%38 = OpConstant %6 8
+%39 = OpConstant %6 0
+%68 = OpTypeFloat 32
+%69 = OpConstant %6 16
+%70 = OpConstant %6 3
+%71 = OpTypeCooperativeMatrixNV %68 %70 %69 %38
+%72 = OpTypePointer Function %71
+%74 = OpTypeRuntimeArray %68
+%75 = OpTypeStruct %74
+%76 = OpTypePointer StorageBuffer %75
+%77 = OpVariable %76 StorageBuffer
+%78 = OpTypeInt 32 1
+%79 = OpConstant %78 0
+%81 = OpConstant %6 5
+%82 = OpTypePointer StorageBuffer %68
+%84 = OpConstant %6 64
+%85 = OpTypeBool
+%86 = OpConstantFalse %85
+%88 = OpTypePointer Private %71
+%89 = OpVariable %88 Private
+%92 = OpTypeRuntimeArray %68
+%93 = OpTypeStruct %92
+%94 = OpTypePointer StorageBuffer %93
+%95 = OpVariable %94 StorageBuffer
+%99 = OpVariable %88 Private
+%102 = OpTypeRuntimeArray %68
+%103 = OpTypeStruct %102
+%104 = OpTypePointer StorageBuffer %103
+%105 = OpVariable %104 StorageBuffer
+%109 = OpVariable %88 Private
+%111 = OpVariable %88 Private
+%112 = OpSpecConstantOp %6 CooperativeMatrixLengthNV %71
+%113 = OpSpecConstantOp %78 IAdd %112 %79
+%117 = OpTypeRuntimeArray %68
+%118 = OpTypeStruct %117
+%119 = OpTypePointer StorageBuffer %118
+%120 = OpVariable %119 StorageBuffer
+%123 = OpSpecConstant %78 1
+%124 = OpSpecConstant %78 1
+%125 = OpSpecConstant %78 1
+%126 = OpSpecConstant %78 1
+%127 = OpSpecConstant %6 1
+%128 = OpSpecConstant %6 1
+%129 = OpSpecConstantComposite %19 %127 %128 %34
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%9 = OpVariable %8 Function
+%18 = OpVariable %8 Function
+%32 = OpVariable %31 Function
+%44 = OpVariable %31 Function
+%52 = OpVariable %31 Function
+%60 = OpVariable %31 Function
+%73 = OpVariable %72 Function
+%91 = OpVariable %72 Function
+%101 = OpVariable %72 Function
+%12 = OpLoad %6 %11
+%14 = OpUMod %6 %12 %13
+%15 = OpLoad %6 %11
+%16 = OpUDiv %6 %15 %13
+%17 = OpCompositeConstruct %7 %14 %16
+OpStore %9 %17
+%22 = OpLoad %19 %21
+%23 = OpVectorShuffle %7 %22 %22 0 1
+%24 = OpCompositeExtract %6 %23 0
+%25 = OpCompositeExtract %6 %23 1
+%26 = OpCompositeConstruct %7 %24 %25
+%28 = OpIMul %7 %26 %27
+%29 = OpLoad %7 %9
+%30 = OpIAdd %7 %28 %29
+OpStore %18 %30
+%35 = OpAccessChain %31 %18 %34
+%36 = OpLoad %6 %35
+%37 = OpIMul %6 %33 %36
+%40 = OpAccessChain %31 %18 %39
+%41 = OpLoad %6 %40
+%42 = OpIMul %6 %38 %41
+%43 = OpIAdd %6 %37 %42
+OpStore %32 %43
+%45 = OpAccessChain %31 %18 %34
+%46 = OpLoad %6 %45
+%47 = OpIMul %6 %33 %46
+%48 = OpAccessChain %31 %18 %39
+%49 = OpLoad %6 %48
+%50 = OpIMul %6 %38 %49
+%51 = OpIAdd %6 %47 %50
+OpStore %44 %51
+%53 = OpAccessChain %31 %18 %34
+%54 = OpLoad %6 %53
+%55 = OpIMul %6 %33 %54
+%56 = OpAccessChain %31 %18 %39
+%57 = OpLoad %6 %56
+%58 = OpIMul %6 %38 %57
+%59 = OpIAdd %6 %55 %58
+OpStore %52 %59
+%61 = OpAccessChain %31 %18 %34
+%62 = OpLoad %6 %61
+%63 = OpIMul %6 %33 %62
+%64 = OpAccessChain %31 %18 %39
+%65 = OpLoad %6 %64
+%66 = OpIMul %6 %38 %65
+%67 = OpIAdd %6 %63 %66
+OpStore %60 %67
+%80 = OpLoad %6 %32
+%83 = OpAccessChain %82 %77 %79 %80
+%87 = OpCooperativeMatrixLoadNV %71 %83 %84 %86 )" +
+                  loadMemoryAccess + R"( %81
+OpStore %73 %87
+%90 = OpLoad %71 %73
+OpStore %89 %90
+%96 = OpLoad %6 %44
+%97 = OpAccessChain %82 %95 %79 %96
+%98 = OpCooperativeMatrixLoadNV %71 %97 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
+OpStore %91 %98
+%100 = OpLoad %71 %91
+OpStore %99 %100
+%106 = OpLoad %6 %52
+%107 = OpAccessChain %82 %105 %79 %106
+%108 = OpCooperativeMatrixLoadNV %71 %107 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
+OpStore %101 %108
+%110 = OpLoad %71 %101
+OpStore %109 %110
+%114 = OpConvertSToF %68 %113
+%115 = OpCompositeConstruct %71 %114
+OpStore %111 %115
+%116 = OpLoad %71 %111
+%121 = OpLoad %6 %60
+%122 = OpAccessChain %82 %120 %79 %121
+OpCooperativeMatrixStoreNV %122 %116 %84 %86 )" + storeMemoryAccess + R"( %81
+OpReturn
+OpFunctionEnd
+)";
+
+  return s;
+}
+
+TEST_F(ValidateMemory, CoopMatLoadStoreSuccess) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+                                "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidateMemory, CoopMatStoreMemoryAccessFail) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
+                                "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
+}
+
+TEST_F(ValidateMemory, CoopMatLoadMemoryAccessFail) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+                                "MakePointerAvailableKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
+}
+
+TEST_F(ValidateMemory, CoopMatInvalidStorageClassFail) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%str = OpTypeStruct %f16mat
+%str_ptr = OpTypePointer Workgroup %str
+%sh = OpVariable %str_ptr Workgroup
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Cooperative matrix types (or types containing them) can only be "
+          "allocated in Function or Private storage classes or as function "
+          "parameters"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthResultTypeBad) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %i32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The Result Type of OpCooperativeMatrixLengthNV <id> "
+                "'11[%11]' must be OpTypeInt with width 32 and signedness 0"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthOperandTypeBad) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %u32 %u32
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The type in OpCooperativeMatrixLengthNV <id> '5[%uint]' "
+                "must be OpTypeCooperativeMatrixNV"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthGood) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %u32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
   std::string spirv = R"(
 OpCapability Shader