Add support for SPV_KHR_cooperative_matrix (#5286)

* SPV_KHR_cooperative_matrix

* Update DEPS with headers

* Update according to review recommendations

* Bugfix and formatting

* Fix formatting missed or damaged by VS2022
diff --git a/DEPS b/DEPS
index 284797f..5727bad 100644
--- a/DEPS
+++ b/DEPS
@@ -13,7 +13,7 @@
   'protobuf_revision': 'v21.12',
 
   're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae',
-  'spirv_headers_revision': '10db9d4e194246a020a4148e220837ac7c68cfd9',
+  'spirv_headers_revision': '3469b164e25cee24435029a569933cb42578db5d',
 }
 
 deps = {
diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h
index 542b745..8ecaf0a 100644
--- a/include/spirv-tools/libspirv.h
+++ b/include/spirv-tools/libspirv.h
@@ -285,6 +285,13 @@
   // An optional packed vector format
   SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
 
+  // Operand types for SPV_KHR_cooperative_matrix instructions.
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
+  // An optional cooperative matrix operands mask.
+  SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
+  SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
+
   // This is a sentinel value, and does not represent an operand type.
   // It should come last.
   SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp
index 6df823e..56c7964 100644
--- a/source/assembly_grammar.cpp
+++ b/source/assembly_grammar.cpp
@@ -154,11 +154,12 @@
     CASE(InBoundsAccessChain),
     CASE(PtrAccessChain),
     CASE(InBoundsPtrAccessChain),
-    CASE(CooperativeMatrixLengthNV)
+    CASE(CooperativeMatrixLengthNV),
+    CASE(CooperativeMatrixLengthKHR)
 };
 
 // The 60 is determined by counting the opcodes listed in the spec.
-static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
               "OpSpecConstantOp opcode table is incomplete");
 #undef CASE
 // clang-format on
diff --git a/source/binary.cpp b/source/binary.cpp
index beb56be..207d4a9 100644
--- a/source/binary.cpp
+++ b/source/binary.cpp
@@ -691,7 +691,9 @@
     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
     case SPV_OPERAND_TYPE_SELECTION_CONTROL:
     case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
-    case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: {
+    case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
       // This operand is a mask.
 
       // Map an optional operand type to its corresponding concrete type.
@@ -699,6 +701,8 @@
         parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
       else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
         parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+      else if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
+        parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
 
       // Check validity of set mask bits. Also prepare for operands for those
       // masks if they have any.  To get operand order correct, scan from
diff --git a/source/opcode.cpp b/source/opcode.cpp
index d26024a..ffbb2e8 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -274,6 +274,7 @@
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeStruct:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return true;
     default:
       return false;
@@ -340,6 +341,7 @@
     case spv::Op::OpTypeNamedBarrier:
     case spv::Op::OpTypeAccelerationStructureNV:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     // case spv::Op::OpTypeAccelerationStructureKHR: covered by
     // spv::Op::OpTypeAccelerationStructureNV
     case spv::Op::OpTypeRayQueryKHR:
diff --git a/source/operand.cpp b/source/operand.cpp
index 31a6c59..a78191b 100644
--- a/source/operand.cpp
+++ b/source/operand.cpp
@@ -236,6 +236,13 @@
     case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
     case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
       return "packed vector format";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
+      return "cooperative matrix operands";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+      return "cooperative matrix layout";
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
+      return "cooperative matrix use";
     case SPV_OPERAND_TYPE_IMAGE:
     case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
       return "image";
@@ -369,6 +376,8 @@
     case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
     case SPV_OPERAND_TYPE_OVERFLOW_MODES:
     case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
       return true;
     default:
       break;
@@ -387,6 +396,7 @@
     case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
     case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
     case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
       return true;
     default:
       break;
@@ -405,6 +415,7 @@
     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
     case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
     case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
     case SPV_OPERAND_TYPE_OPTIONAL_CIV:
       return true;
     default:
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 1b1aead..2dcc259 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -423,6 +423,23 @@
               {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
       break;
     }
+    case Type::kCooperativeMatrixKHR: {
+      auto coop_mat = type->AsCooperativeMatrixKHR();
+      uint32_t const component_type =
+          GetTypeInstruction(coop_mat->component_type());
+      if (component_type == 0) {
+        return 0;
+      }
+      typeInst = MakeUnique<Instruction>(
+          context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
+          std::initializer_list<Operand>{
+              {SPV_OPERAND_TYPE_ID, {component_type}},
+              {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
+              {SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
+      break;
+    }
     default:
       assert(false && "Unexpected type");
       break;
@@ -628,6 +645,14 @@
           cm_type->columns_id());
       break;
     }
+    case Type::kCooperativeMatrixKHR: {
+      const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
+      const Type* component_type = cm_type->component_type();
+      rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
+          RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
+          cm_type->columns_id(), cm_type->use_id());
+      break;
+    }
     default:
       assert(false && "Unhandled type");
       return nullptr;
@@ -863,6 +888,12 @@
                                      inst.GetSingleWordInOperand(2),
                                      inst.GetSingleWordInOperand(3));
       break;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
+      type = new CooperativeMatrixKHR(
+          GetType(inst.GetSingleWordInOperand(0)),
+          inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
+          inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
+      break;
     case spv::Op::OpTypeRayQueryKHR:
       type = new RayQueryKHR();
       break;
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 49eec9b..b18b8cb 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -128,6 +128,7 @@
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -175,6 +176,7 @@
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -230,6 +232,7 @@
     DeclareKindCase(NamedBarrier);
     DeclareKindCase(AccelerationStructureNV);
     DeclareKindCase(CooperativeMatrixNV);
+    DeclareKindCase(CooperativeMatrixKHR);
     DeclareKindCase(RayQueryKHR);
     DeclareKindCase(HitObjectNV);
 #undef DeclareKindCase
@@ -708,6 +711,49 @@
          columns_id_ == mt->columns_id_ && HasSameDecorations(that);
 }
 
+CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
+                                           const uint32_t scope,
+                                           const uint32_t rows,
+                                           const uint32_t columns,
+                                           const uint32_t use)
+    : Type(kCooperativeMatrixKHR),
+      component_type_(type),
+      scope_id_(scope),
+      rows_id_(rows),
+      columns_id_(columns),
+      use_id_(use) {
+  assert(type != nullptr);
+  assert(scope != 0);
+  assert(rows != 0);
+  assert(columns != 0);
+  assert(use != 0);
+}
+
+std::string CooperativeMatrixKHR::str() const {
+  std::ostringstream oss;
+  oss << "<" << component_type_->str() << ", " << scope_id_ << ", " << rows_id_
+      << ", " << columns_id_ << ", " << use_id_ << ">";
+  return oss.str();
+}
+
+size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
+                                                   SeenTypes* seen) const {
+  hash = hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
+  return component_type_->ComputeHashValue(hash, seen);
+}
+
+bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
+                                      IsSameCache* seen) const {
+  const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
+  if (!mt) return false;
+  // use_id_ is hashed above, so it must take part in equality too; otherwise
+  // matrices differing only in Use would compare equal but hash differently.
+  return component_type_->IsSameImpl(mt->component_type_, seen) &&
+         scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
+         columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ &&
+         HasSameDecorations(that);
+}
+
 }  // namespace analysis
 }  // namespace opt
 }  // namespace spvtools
diff --git a/source/opt/types.h b/source/opt/types.h
index 26c058c..16a948c 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -60,6 +60,7 @@
 class NamedBarrier;
 class AccelerationStructureNV;
 class CooperativeMatrixNV;
+class CooperativeMatrixKHR;
 class RayQueryKHR;
 class HitObjectNV;
 
@@ -100,6 +101,7 @@
     kNamedBarrier,
     kAccelerationStructureNV,
     kCooperativeMatrixNV,
+    kCooperativeMatrixKHR,
     kRayQueryKHR,
     kHitObjectNV,
     kLast
@@ -201,6 +203,7 @@
   DeclareCastMethod(NamedBarrier)
   DeclareCastMethod(AccelerationStructureNV)
   DeclareCastMethod(CooperativeMatrixNV)
+  DeclareCastMethod(CooperativeMatrixKHR)
   DeclareCastMethod(RayQueryKHR)
   DeclareCastMethod(HitObjectNV)
 #undef DeclareCastMethod
@@ -624,6 +627,38 @@
   const uint32_t columns_id_;
 };
 
+class CooperativeMatrixKHR : public Type {
+ public:
+  CooperativeMatrixKHR(const Type* type, const uint32_t scope,
+                       const uint32_t rows, const uint32_t columns,
+                       const uint32_t use);
+  CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
+
+  std::string str() const override;
+
+  CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
+  const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
+    return this;
+  }
+
+  size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+  const Type* component_type() const { return component_type_; }
+  uint32_t scope_id() const { return scope_id_; }
+  uint32_t rows_id() const { return rows_id_; }
+  uint32_t columns_id() const { return columns_id_; }
+  uint32_t use_id() const { return use_id_; }
+
+ private:
+  bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+  const Type* component_type_;
+  const uint32_t scope_id_;
+  const uint32_t rows_id_;
+  const uint32_t columns_id_;
+  const uint32_t use_id_;
+};
+
 #define DefineParameterlessType(type, name)                                \
   class type : public Type {                                               \
    public:                                                                 \
diff --git a/source/text.cpp b/source/text.cpp
index 9c77422..eb7f96b 100644
--- a/source/text.cpp
+++ b/source/text.cpp
@@ -402,7 +402,8 @@
     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
     case SPV_OPERAND_TYPE_SELECTION_CONTROL:
     case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
-    case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: {
+    case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+    case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
       uint32_t value;
       if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
         return context->diagnostic(error)
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index 4e7dd5e..b608a85 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -42,14 +42,29 @@
            opcode != spv::Op::OpFMod);
       if (!_.IsFloatScalarType(result_type) &&
           !_.IsFloatVectorType(result_type) &&
-          !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
+          !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
+          !(opcode == spv::Op::OpFMul &&
+            _.IsCooperativeMatrixKHRType(result_type) &&
+            _.IsFloatCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected floating scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
 
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
-        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsFloatCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected arithmetic operands to be of Result Type: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -71,7 +86,19 @@
 
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
-        if (_.GetOperandTypeId(inst, operand_index) != result_type)
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected arithmetic operands to be of Result Type: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -91,7 +118,10 @@
           (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
            opcode != spv::Op::OpSMod);
       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
-          !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
+          !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+          !(opcode == spv::Op::OpIMul &&
+            _.IsCooperativeMatrixKHRType(result_type) &&
+            _.IsIntCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected int scalar or vector type as Result Type: "
                << spvOpcodeString(opcode);
@@ -102,9 +132,26 @@
       for (size_t operand_index = 2; operand_index < inst->operands().size();
            ++operand_index) {
         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+
+        if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+          if (!_.IsCooperativeMatrixKHRType(type_id) ||
+              !_.IsIntCooperativeMatrixType(type_id)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                   << "Expected arithmetic operands to be of Result Type: "
+                   << spvOpcodeString(opcode) << " operand index "
+                   << operand_index;
+          }
+          spv_result_t ret =
+              _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+          if (ret != SPV_SUCCESS) return ret;
+        }
+
         if (!type_id ||
             (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
-             !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
+             !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+             !(opcode == spv::Op::OpIMul &&
+               _.IsCooperativeMatrixKHRType(result_type) &&
+               _.IsIntCooperativeMatrixType(result_type))))
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << "Expected int scalar or vector type as operand: "
                  << spvOpcodeString(opcode) << " operand index "
@@ -187,7 +234,7 @@
 
     case spv::Op::OpMatrixTimesScalar: {
       if (!_.IsFloatMatrixType(result_type) &&
-          !_.IsCooperativeMatrixType(result_type))
+          !(_.IsCooperativeMatrixType(result_type)))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected float matrix type as Result Type: "
                << spvOpcodeString(opcode);
@@ -459,22 +506,108 @@
       const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
       const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
 
-      if (!_.IsCooperativeMatrixType(A_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(A_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as A Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(B_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(B_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as B Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(C_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(C_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as C Type: "
                << spvOpcodeString(opcode);
       }
-      if (!_.IsCooperativeMatrixType(D_type_id)) {
+      if (!_.IsCooperativeMatrixNVType(D_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected cooperative matrix type as Result Type: "
+               << spvOpcodeString(opcode);
+      }
+
+      const auto A = _.FindDef(A_type_id);
+      const auto B = _.FindDef(B_type_id);
+      const auto C = _.FindDef(C_type_id);
+      const auto D = _.FindDef(D_type_id);
+
+      std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+          A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+      A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+      B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+      C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+      D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+      A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+      B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+      C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+      D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+      A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+      B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+      C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+      D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+      const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+                               std::tuple<bool, bool, uint32_t> Y) {
+        return (std::get<1>(X) && std::get<1>(Y) &&
+                std::get<2>(X) != std::get<2>(Y));
+      };
+
+      if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+          notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+          notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix scopes must match: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+          notEqual(C_rows, D_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'M' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+          notEqual(C_cols, D_cols)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'N' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+
+      if (notEqual(A_cols, B_rows)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix 'K' mismatch: "
+               << spvOpcodeString(opcode);
+      }
+      break;
+    }
+
+    case spv::Op::OpCooperativeMatrixMulAddKHR: {
+      const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+      const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+      const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+      const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+      if (!_.IsCooperativeMatrixAType(A_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be A Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixBType(B_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be B Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixAccType(C_type_id)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix type must be Accumulator Type: "
+               << spvOpcodeString(opcode);
+      }
+      if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected cooperative matrix type as Result Type: "
                << spvOpcodeString(opcode);
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index 2b83c63..ed043b6 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -122,6 +122,7 @@
         *member_type = type_inst->word(component_index + 2);
         break;
       }
+      case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeCooperativeMatrixNV: {
         *member_type = type_inst->word(2);
         break;
@@ -335,6 +336,25 @@
 
       break;
     }
+    case spv::Op::OpTypeCooperativeMatrixKHR: {
+      const auto result_type_inst = _.FindDef(result_type);
+      assert(result_type_inst);
+      const auto component_type_id =
+          result_type_inst->GetOperandAs<uint32_t>(1);
+
+      if (3 != num_operands) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Must be only one constituent";
+      }
+
+      const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+      if (operand_type_id != component_type_id) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Constituent type to be equal to the component type";
+      }
+      break;
+    }
     case spv::Op::OpTypeCooperativeMatrixNV: {
       const auto result_type_inst = _.FindDef(result_type);
       assert(result_type_inst);
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index 006e504..4deaa49 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -243,6 +243,7 @@
         }
       }
     } break;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     case spv::Op::OpTypeCooperativeMatrixNV: {
       if (1 != constituent_count) {
         return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -310,6 +311,7 @@
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeMatrix:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
     case spv::Op::OpTypeVector: {
       auto base_type = _.FindDef(instruction[2]);
       return base_type && IsTypeNullable(base_type->words(), _);
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index 476c1fe..b2892a8 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -473,7 +473,10 @@
       const bool input_is_pointer = _.IsPointerType(input_type);
       const bool input_is_int_scalar = _.IsIntScalarType(input_type);
 
-      if (!result_is_pointer && !result_is_int_scalar &&
+      const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
+      const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
+
+      if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
           !_.IsIntVectorType(result_type) &&
           !_.IsFloatScalarType(result_type) &&
           !_.IsFloatVectorType(result_type))
@@ -481,13 +484,24 @@
                << "Expected Result Type to be a pointer or int or float vector "
                << "or scalar type: " << spvOpcodeString(opcode);
 
-      if (!input_is_pointer && !input_is_int_scalar &&
+      if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
           !_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
           !_.IsFloatVectorType(input_type))
         return _.diag(SPV_ERROR_INVALID_DATA, inst)
                << "Expected input to be a pointer or int or float vector "
                << "or scalar: " << spvOpcodeString(opcode);
 
+      if (result_is_coopmat != input_is_coopmat)
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cooperative matrix can only be cast to another cooperative "
+               << "matrix: " << spvOpcodeString(opcode);
+
+      if (result_is_coopmat) {
+        spv_result_t ret =
+            _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+        if (ret != SPV_SUCCESS) return ret;
+      }
+
       if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
           _.HasExtension(kSPV_KHR_physical_storage_buffer)) {
         const bool result_is_int_vector = _.IsIntVectorType(result_type);
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index 92a4e8e..bcfeb59 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -163,9 +163,12 @@
               !inst->IsDebugInfo() && !inst->IsNonSemantic() &&
               !spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
               opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+              opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
               !(opcode == spv::Op::OpSpecConstantOp &&
-                spv::Op(inst->word(3)) ==
-                    spv::Op::OpCooperativeMatrixLengthNV)) {
+                (spv::Op(inst->word(3)) ==
+                     spv::Op::OpCooperativeMatrixLengthNV ||
+                 spv::Op(inst->word(3)) ==
+                     spv::Op::OpCooperativeMatrixLengthKHR))) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " cannot be a type";
@@ -179,9 +182,12 @@
                      opcode != spv::Op::OpLoopMerge &&
                      opcode != spv::Op::OpFunction &&
                      opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+                     opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
                      !(opcode == spv::Op::OpSpecConstantOp &&
-                       spv::Op(inst->word(3)) ==
-                           spv::Op::OpCooperativeMatrixLengthNV)) {
+                       (spv::Op(inst->word(3)) ==
+                            spv::Op::OpCooperativeMatrixLengthNV ||
+                        spv::Op(inst->word(3)) ==
+                            spv::Op::OpCooperativeMatrixLengthKHR))) {
             return _.diag(SPV_ERROR_INVALID_ID, inst)
                    << "Operand " << _.getIdName(operand_word)
                    << " requires a type";
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 975a55c..f039496 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -204,6 +204,7 @@
 
   switch (storage->opcode()) {
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return true;
     case spv::Op::OpTypeArray:
     case spv::Op::OpTypeRuntimeArray:
@@ -232,6 +233,7 @@
   spv::StorageClass src_sc = spv::StorageClass::Max;
   switch (inst->opcode()) {
     case spv::Op::OpCooperativeMatrixLoadNV:
+    case spv::Op::OpCooperativeMatrixLoadKHR:
     case spv::Op::OpLoad: {
       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
       auto load_pointer_type = _.FindDef(load_pointer->type_id());
@@ -239,6 +241,7 @@
       break;
     }
     case spv::Op::OpCooperativeMatrixStoreNV:
+    case spv::Op::OpCooperativeMatrixStoreKHR:
     case spv::Op::OpStore: {
       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
       auto store_pointer_type = _.FindDef(store_pointer->type_id());
@@ -326,7 +329,8 @@
   const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
   if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
     if (inst->opcode() == spv::Op::OpLoad ||
-        inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
+        inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
+        inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "MakePointerAvailableKHR cannot be used with OpLoad.";
     }
@@ -1357,6 +1361,7 @@
       case spv::Op::OpTypeMatrix:
       case spv::Op::OpTypeVector:
       case spv::Op::OpTypeCooperativeMatrixNV:
+      case spv::Op::OpTypeCooperativeMatrixKHR:
       case spv::Op::OpTypeArray:
       case spv::Op::OpTypeRuntimeArray: {
         // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
@@ -1554,9 +1559,15 @@
            << " must be OpTypeInt with width 32 and signedness 0.";
   }
 
+  bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
   auto type_id = inst->GetOperandAs<uint32_t>(2);
   auto type = state.FindDef(type_id);
-  if (type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+  if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The type in " << instr_name << " <id> "
+           << state.getIdName(type_id)
+           << " must be OpTypeCooperativeMatrixKHR.";
+  } else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
     return state.diag(SPV_ERROR_INVALID_ID, inst)
            << "The type in " << instr_name << " <id> "
            << state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
@@ -1668,6 +1679,116 @@
   return SPV_SUCCESS;
 }
 
+// Validates OpCooperativeMatrixLoadKHR and OpCooperativeMatrixStoreKHR: the
+// matrix type, the Pointer operand (logical, storage class, pointee type),
+// MemoryLayout, and the optional Stride and memory operands.
+spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
+                                                   const Instruction* inst) {
+  uint32_t type_id;
+  const char* opname;
+  if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+    type_id = inst->type_id();
+    opname = "spv::Op::OpCooperativeMatrixLoadKHR";
+  } else {
+    // get Object operand's type
+    type_id = _.FindDef(inst->GetOperandAs<uint32_t>(1))->type_id();
+    opname = "spv::Op::OpCooperativeMatrixStoreKHR";
+  }
+
+  auto matrix_type = _.FindDef(type_id);
+
+  if (matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    }
+  }
+
+  const auto pointer_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == spv::AddressingModel::Logical) &&
+       ((!_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer_id)
+           << " is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer <id> " << _.getIdName(pointer_id)
+           << " is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
+
+  if (storage_class != spv::StorageClass::Workgroup &&
+      storage_class != spv::StorageClass::StorageBuffer &&
+      storage_class != spv::StorageClass::PhysicalStorageBuffer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type <id> "
+           << _.getIdName(pointer_type_id)
+           << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer->id())
+           << "s Type must be a scalar or vector type.";
+  }
+
+  const auto layout_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
+           << " must be a 32-bit integer constant instruction.";
+  }
+
+  // Stride and memory operands are optional; check them only if present.
+  const auto stride_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
+  if (inst->operands().size() > stride_index) {
+    const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+    const auto stride = _.FindDef(stride_id);
+    if (!stride || !_.IsIntScalarType(stride->type_id())) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Stride operand <id> " << _.getIdName(stride_id)
+             << " must be a scalar integer type.";
+    }
+  }
+
+  const auto memory_access_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+
+  return SPV_SUCCESS;
+}
+
 spv_result_t ValidatePtrComparison(ValidationState_t& _,
                                    const Instruction* inst) {
   if (_.addressing_model() == spv::AddressingModel::Logical &&
@@ -1757,9 +1878,15 @@
       if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
         return error;
       break;
+    case spv::Op::OpCooperativeMatrixLengthKHR:
     case spv::Op::OpCooperativeMatrixLengthNV:
       if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
       break;
+    case spv::Op::OpCooperativeMatrixLoadKHR:
+    case spv::Op::OpCooperativeMatrixStoreKHR:
+      if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
+        return error;
+      break;
     case spv::Op::OpPtrEqual:
     case spv::Op::OpPtrNotEqual:
     case spv::Op::OpPtrDiff:
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 430d819..7edd12f 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -552,8 +552,9 @@
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
-                                             const Instruction* inst) {
+// Validates both OpTypeCooperativeMatrixNV and OpTypeCooperativeMatrixKHR.
+spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
+                                           const Instruction* inst) {
   const auto component_type_index = 1;
   const auto component_type_id =
       inst->GetOperandAs<uint32_t>(component_type_index);
@@ -561,7 +562,7 @@
   if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
                           spv::Op::OpTypeInt != component_type->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Component Type <id> "
+           << "OpTypeCooperativeMatrix Component Type <id> "
            << _.getIdName(component_type_id)
            << " is not a scalar numerical type.";
   }
@@ -572,7 +573,7 @@
   if (!scope || !_.IsIntScalarType(scope->type_id()) ||
       !spvOpcodeIsConstant(scope->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Scope <id> " << _.getIdName(scope_id)
+           << "OpTypeCooperativeMatrix Scope <id> " << _.getIdName(scope_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
@@ -582,7 +583,7 @@
   if (!rows || !_.IsIntScalarType(rows->type_id()) ||
       !spvOpcodeIsConstant(rows->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Rows <id> " << _.getIdName(rows_id)
+           << "OpTypeCooperativeMatrix Rows <id> " << _.getIdName(rows_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
@@ -592,10 +593,24 @@
   if (!cols || !_.IsIntScalarType(cols->type_id()) ||
       !spvOpcodeIsConstant(cols->opcode())) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeCooperativeMatrixNV Cols <id> " << _.getIdName(cols_id)
+           << "OpTypeCooperativeMatrix Cols <id> " << _.getIdName(cols_id)
            << " is not a constant instruction with scalar integer type.";
   }
 
+  // Only the KHR type has a Use operand (operand 5). Its form is checked
+  // here; its value is not checked against the CooperativeMatrixUse enum.
+  if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    const auto use_index = 5;
+    const auto use_id = inst->GetOperandAs<uint32_t>(use_index);
+    const auto use = _.FindDef(use_id);
+    if (!use || !_.IsIntScalarType(use->type_id()) ||
+        !spvOpcodeIsConstant(use->opcode())) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "OpTypeCooperativeMatrixKHR Use <id> " << _.getIdName(use_id)
+             << " is not a constant instruction with scalar integer type.";
+    }
+  }
+
   return SPV_SUCCESS;
 }
 }  // namespace
@@ -640,7 +655,8 @@
       if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
       break;
     case spv::Op::OpTypeCooperativeMatrixNV:
-      if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+    case spv::Op::OpTypeCooperativeMatrixKHR:
+      if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
       break;
     default:
       break;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 5a138d9..cde8aaa 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -859,6 +859,7 @@
       return GetComponentType(inst->word(2));
 
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return inst->word(2);
 
     default:
@@ -886,6 +887,7 @@
       return inst->word(3);
 
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       // Actual dimension isn't known, return 0
       return 0;
 
@@ -1143,21 +1145,69 @@
 
+// True for both OpTypeCooperativeMatrixNV and OpTypeCooperativeMatrixKHR.
 bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
+  return inst && (inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV ||
+                  inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR);
+}
+
+bool ValidationState_t::IsCooperativeMatrixNVType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
   return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
 }
 
+bool ValidationState_t::IsCooperativeMatrixKHRType(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR;
+}
+
+// IsCooperativeMatrix{A,B,Acc}Type inspect the Use operand (word 6) of an
+// OpTypeCooperativeMatrixKHR. They return false for NV types and whenever
+// GetConstantValUint64 cannot evaluate the Use operand.
+bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse ==
+           static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
+  }
+  return false;
+}
+
+bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse ==
+           static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
+  }
+  return false;
+}
+
+bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  const Instruction* inst = FindDef(id);
+  uint64_t matrixUse = 0;
+  if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+    return matrixUse == static_cast<uint64_t>(
+                            spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
+  }
+  return false;
+}
+
 bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
   if (!IsCooperativeMatrixType(id)) return false;
   return IsFloatScalarType(FindDef(id)->word(2));
 }
 
 bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
   if (!IsCooperativeMatrixType(id)) return false;
   return IsIntScalarType(FindDef(id)->word(2));
 }
 
 bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
   if (!IsCooperativeMatrixType(id)) return false;
   return IsUnsignedIntScalarType(FindDef(id)->word(2));
 }
 
@@ -1173,8 +1223,8 @@
   const auto m1_type = FindDef(m1);
   const auto m2_type = FindDef(m2);
 
-  if (m1_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV ||
-      m2_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+  // Both operands must be cooperative matrices of the same flavor (NV or KHR).
+  if (!IsCooperativeMatrixType(m1) || m1_type->opcode() != m2_type->opcode()) {
     return diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected cooperative matrix types";
   }
@@ -1224,6 +1274,21 @@
            << "identical";
   }
 
+  if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+    uint32_t m1_use_id = m1_type->GetOperandAs<uint32_t>(5);
+    uint32_t m2_use_id = m2_type->GetOperandAs<uint32_t>(5);
+    std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+        EvalInt32IfConst(m1_use_id);
+    std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+        EvalInt32IfConst(m2_use_id);
+
+    if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+      return diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Expected Use of Matrix type and Result Type to be "
+             << "identical";
+    }
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -1489,6 +1554,7 @@
     case spv::Op::OpTypeImage:
     case spv::Op::OpTypeSampledImage:
     case spv::Op::OpTypeCooperativeMatrixNV:
+    case spv::Op::OpTypeCooperativeMatrixKHR:
       return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
                           traverse_all_types);
     case spv::Op::OpTypePointer:
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 4d5ac00..bfae821 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -610,6 +610,13 @@
   bool IsPointerType(uint32_t id) const;
   bool IsAccelerationStructureType(uint32_t id) const;
   bool IsCooperativeMatrixType(uint32_t id) const;
+  bool IsCooperativeMatrixNVType(uint32_t id) const;
+  bool IsCooperativeMatrixKHRType(uint32_t id) const;
+  // The following check the Use operand of an OpTypeCooperativeMatrixKHR.
+  // They return false for NV types or when GetConstantValUint64 fails.
+  bool IsCooperativeMatrixAType(uint32_t id) const;
+  bool IsCooperativeMatrixBType(uint32_t id) const;
+  bool IsCooperativeMatrixAccType(uint32_t id) const;
   bool IsFloatCooperativeMatrixType(uint32_t id) const;
   bool IsIntCooperativeMatrixType(uint32_t id) const;
   bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp
index 563eb74..cb30171 100644
--- a/test/opt/type_manager_test.cpp
+++ b/test/opt/type_manager_test.cpp
@@ -171,6 +171,7 @@
   types.emplace_back(new NamedBarrier());
   types.emplace_back(new AccelerationStructureNV());
   types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
+  types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002));
   types.emplace_back(new RayQueryKHR());
   types.emplace_back(new HitObjectNV());
 
@@ -237,6 +238,8 @@
     %arr_long_constant = OpTypeArray %s32 %long_constant
     %arr_spec_const_op = OpTypeArray %s32 %spec_const_op
     %cm   = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
+    %id2    = OpConstant %u32 2
+    %cmkhr  = OpTypeCooperativeMatrixKHR %f64 %id4 %id4 %id4 %id2
   )";
 
   std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@@ -275,6 +278,7 @@
       {37, "[sint32, id(33), words(0,705032704,1)]"},
       {38, "[sint32, id(34), words(2,34)]"},
       {39, "<float64, 6, 6, 6>"},
+      {41, "<float64, 6, 6, 6, 40>"},
   };
 
   std::unique_ptr<IRContext> context =
@@ -1030,6 +1034,8 @@
 ; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
 ; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
 ; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
+; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
+; CHECK: [[uint8:%\w+]] = OpConstant [[uint]] 8
 ; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
 ; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
 ; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
@@ -1085,6 +1091,7 @@
 ; CHECK: OpTypeNamedBarrier
 ; CHECK: OpTypeAccelerationStructureKHR
 ; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
+; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
 ; CHECK: OpTypeRayQueryKHR
 ; CHECK: OpTypeHitObjectNV
 OpCapability Shader
@@ -1094,6 +1101,8 @@
 %uint = OpTypeInt 32 0
 %1 = OpTypePointer Input %uint
 %2 = OpTypePointer Uniform %uint
+%1002 = OpConstant %uint 2
+%8 = OpConstant %uint 8
 %24 = OpConstant %uint 24
 %42 = OpConstant %uint 42
 %100 = OpConstant %uint 100
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 631375e..06c4e39 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1318,7 +1318,7 @@
   CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> "
+              HasSubstr("OpTypeCooperativeMatrix Component Type <id> "
                         "'4[%bool]' is not a scalar numerical type."));
 }
 
@@ -1331,7 +1331,7 @@
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
+      HasSubstr("OpTypeCooperativeMatrix Scope <id> '17[%float_1]' is not a "
                 "constant instruction with scalar integer type."));
 }
 
@@ -1344,7 +1344,7 @@
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
+      HasSubstr("OpTypeCooperativeMatrix Rows <id> '17[%float_1]' is not a "
                 "constant instruction with scalar integer type."));
 }
 
@@ -1357,7 +1357,7 @@
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
+      HasSubstr("OpTypeCooperativeMatrix Cols <id> '17[%float_1]' is not a "
                 "constant instruction with scalar integer type."));
 }
 
@@ -1469,6 +1469,152 @@
                 "SMulExtended"));
 }
 
+// Builds a compute shader declaring SPV_KHR_cooperative_matrix types and
+// constants (A/B/C uses), then |extra_types|, then %main with |main_body|.
+std::string GenerateCoopMatKHRCode(const std::string& extra_types,
+                                   const std::string& main_body) {
+  const std::string prefix = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+
+%u32_16 = OpConstant %u32 16
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%useA = OpConstant %u32 0
+%useB = OpConstant %u32 1
+%useC = OpConstant %u32 2
+
+%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useA
+%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useA
+
+%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useB
+%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useB
+%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useB
+
+%f16matC = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useC
+%f32matC = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useC
+%u32matC = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useC
+%s32matC = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useC
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u32_1 = OpConstant %u32 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_A_1 = OpConstantComposite %f16matA %f16_1
+%u32mat_A_1 = OpConstantComposite %u32matA %u32_1
+%s32mat_A_1 = OpConstantComposite %s32matA %s32_1
+
+%f16mat_B_1 = OpConstantComposite %f16matB %f16_1
+%u32mat_B_1 = OpConstantComposite %u32matB %u32_1
+%s32mat_B_1 = OpConstantComposite %s32matB %s32_1
+
+%f16mat_C_1 = OpConstantComposite %f16matC %f16_1
+%u32mat_C_1 = OpConstantComposite %u32matC %u32_1
+%s32mat_C_1 = OpConstantComposite %s32matC %s32_1
+
+)";
+
+  const std::string func_begin = R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+  const std::string suffix = R"(
+OpReturn
+OpFunctionEnd)";
+
+  return prefix + extra_types + func_begin + main_body + suffix;
+}
+
+// Arithmetic, MatrixTimesScalar and MulAdd on well-formed KHR matrices.
+TEST_F(ValidateArithmetics, CoopMatKHRSuccess) {
+  const std::string body = R"(
+%val1 = OpFAdd %f16matA %f16mat_A_1 %f16mat_A_1
+%val2 = OpFSub %f16matA %f16mat_A_1 %f16mat_A_1
+%val3 = OpFMul %f16matA %f16mat_A_1 %f16mat_A_1
+%val4 = OpFDiv %f16matA %f16mat_A_1 %f16mat_A_1
+%val5 = OpFNegate %f16matA %f16mat_A_1
+%val6 = OpIAdd %u32matA %u32mat_A_1 %u32mat_A_1
+%val7 = OpISub %u32matA %u32mat_A_1 %u32mat_A_1
+%val8 = OpUDiv %u32matA %u32mat_A_1 %u32mat_A_1
+%val9 = OpIAdd %s32matA %s32mat_A_1 %s32mat_A_1
+%val10 = OpISub %s32matA %s32mat_A_1 %s32mat_A_1
+%val11 = OpSDiv %s32matA %s32mat_A_1 %s32mat_A_1
+%val12 = OpSNegate %s32matA %s32mat_A_1
+%val13 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f16_1
+%val14 = OpMatrixTimesScalar %u32matA %u32mat_A_1 %u32_1
+%val15 = OpMatrixTimesScalar %s32matA %s32mat_A_1 %s32_1
+%val16 = OpCooperativeMatrixMulAddKHR %f32matC %f16mat_A_1 %f16mat_B_1 %f16mat_C_1
+%val17 = OpCooperativeMatrixMulAddKHR %s32matC %s32mat_A_1 %s32mat_B_1 %s32mat_C_1)";
+
+  CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// The scalar operand's type must equal the matrix component type.
+TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
+  const std::string body = R"(
+%val1 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f32_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected scalar operand type to be equal to the component "
+                "type of the matrix operand: MatrixTimesScalar"));
+}
+
+// All operands must share the Result Type's scope (Workgroup vs Subgroup).
+TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
+  const std::string types = R"(
+%workgroup = OpConstant %u32 2
+%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC
+%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
+)";
+
+  const std::string body = R"(
+%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
+}
+
+// A 16x4 result/accumulator does not fit 16x16 A and B: 'N' mismatch.
+TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
+  const std::string types = R"(
+%mat16x4 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_4 %useC
+%mat16x4_C_1 = OpConstantComposite %mat16x4 %f16_1
+)";
+
+  const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddKHR %mat16x4 %f16mat_A_1 %f16mat_B_1 %mat16x4_C_1
+)";
+
+  CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools
diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp
index 0fd1ed6..6e0d7c0 100644
--- a/test/val/val_composites_test.cpp
+++ b/test/val/val_composites_test.cpp
@@ -1486,8 +1486,7 @@
 }
 
 TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
-  const std::string body =
-      R"(
+  const std::string body = R"(
 OpCapability Shader
 OpCapability Float16
 OpCapability CooperativeMatrixNV
@@ -1525,8 +1524,7 @@
 }
 
 TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
-  const std::string body =
-      R"(
+  const std::string body = R"(
 OpCapability Shader
 OpCapability Float16
 OpCapability CooperativeMatrixNV
@@ -1562,6 +1560,90 @@
       HasSubstr("Expected Constituent type to be equal to the component type"));
 }
 
+// OpConstantComposite must reject a constituent whose type differs from the
+// KHR matrix component type.
+TEST_F(ValidateComposites, CoopMatKHRConstantCompositeMismatchFail) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+
+%f32_1 = OpConstant %f32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpConstantComposite Constituent <id> '12[%float_1]' type "
+          "does not match the Result Type <id> '11[%11]'s component type."));
+}
+
+// OpCompositeConstruct must reject a constituent whose type differs from the
+// KHR matrix component type.
+TEST_F(ValidateComposites, CoopMatKHRCompositeConstructMismatchFail) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+
+%f32_1 = OpConstant %f32 1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected Constituent type to be equal to the component type"));
+}
+
 TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
   const std::string spirv = R"(
 OpCapability Shader
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 1f8c426..0128aa1 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -1149,8 +1149,7 @@
 }
 
 TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
-  const std::string body =
-      R"(
+  const std::string body = R"(
 OpCapability Shader
 OpCapability Float16
 OpCapability Int16
@@ -1191,6 +1190,182 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+// Conversions between KHR matrices that share scope, shape and Use are valid.
+TEST_F(ValidateConversion, CoopMatKHRConversionSuccess) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
+%u16mat = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A
+%u32mat = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A
+%s16mat = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A
+%s32mat = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u16_1 = OpConstant %u16 1
+%u32_1 = OpConstant %u32 1
+%s16_1 = OpConstant %s16 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%f32mat_1 = OpConstantComposite %f32mat %f32_1
+%u16mat_1 = OpConstantComposite %u16mat %u16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s16mat_1 = OpConstantComposite %s16mat %s16_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val11 = OpConvertFToU %u16mat %f16mat_1
+%val12 = OpConvertFToU %u32mat %f16mat_1
+%val13 = OpConvertFToS %s16mat %f16mat_1
+%val14 = OpConvertFToS %s32mat %f16mat_1
+%val15 = OpFConvert %f32mat %f16mat_1
+
+%val21 = OpConvertFToU %u16mat %f32mat_1
+%val22 = OpConvertFToU %u32mat %f32mat_1
+%val23 = OpConvertFToS %s16mat %f32mat_1
+%val24 = OpConvertFToS %s32mat %f32mat_1
+%val25 = OpFConvert %f16mat %f32mat_1
+
+%val31 = OpConvertUToF %f16mat %u16mat_1
+%val32 = OpConvertUToF %f32mat %u16mat_1
+%val33 = OpUConvert %u32mat %u16mat_1
+%val34 = OpSConvert %s32mat %u16mat_1
+
+%val41 = OpConvertSToF %f16mat %s16mat_1
+%val42 = OpConvertSToF %f32mat %s16mat_1
+%val43 = OpUConvert %u32mat %s16mat_1
+%val44 = OpSConvert %s32mat %s16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+// Converting between KHR matrices with different Use operands must fail.
+TEST_F(ValidateConversion, CoopMatKHRConversionUseMismatchFail) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%use_A = OpConstant %u32 0
+%use_B = OpConstant %u32 1
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val1 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected Use of Matrix type and Result Type to be identical"));
+}
+
+// Converting between KHR matrices with different scopes must fail.
+TEST_F(ValidateConversion, CoopMatKHRConversionScopeMismatchFail) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%workgroup = OpConstant %u32 2
+%use_A = OpConstant %u32 0
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val1 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
+}
+
 TEST_F(ValidateConversion, BitcastSuccess) {
   const std::string body = R"(
 %ptr = OpVariable %f32ptr_func Function
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index d575318..8d0a94d 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -23,12 +23,14 @@
 #include "test/val/val_fixtures.h"
 
 // For pretty-printing tuples with spv_target_env.
-std::ostream& operator<<(std::ostream& stream, spv_target_env target)
-{
+std::ostream& operator<<(std::ostream& stream, spv_target_env target) {
   switch (target) {
-    case SPV_ENV_UNIVERSAL_1_3: return stream << "SPV_ENV_UNIVERSAL_1_3";
-    case SPV_ENV_UNIVERSAL_1_4: return stream << "SPV_ENV_UNIVERSAL_1_4";
-    default:                    return stream << (unsigned)target;
+    case SPV_ENV_UNIVERSAL_1_3:
+      return stream << "SPV_ENV_UNIVERSAL_1_3";
+    case SPV_ENV_UNIVERSAL_1_4:
+      return stream << "SPV_ENV_UNIVERSAL_1_4";
+    default:
+      return stream << static_cast<unsigned>(target);  // Numeric fallback.
   }
 }
 
@@ -2346,6 +2348,186 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+                                "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);  // Valid operands.
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
+                                "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);  // Visible on store.
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) {
+  std::string spirv =
+      GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+                                "MakePointerAvailableKHR|NonPrivatePointerKHR");
+
+  CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);  // Available on load.
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRInvalidStorageClassFail) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%str = OpTypeStruct %f16mat
+%str_ptr = OpTypePointer Workgroup %str
+%sh = OpVariable %str_ptr Workgroup
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());  // %sh is in Workgroup storage.
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "Cooperative matrix types (or types containing them) can only be "
+          "allocated in Function or Private storage classes or as function "
+          "parameters"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthResultTypeBad) {
+  const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %i32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());  // %i32 result type is signed.
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The Result Type of OpCooperativeMatrixLengthKHR <id> "
+                "'12[%12]' must be OpTypeInt with width 32 and signedness 0"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthOperandTypeBad) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %u32 %u32
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());  // Operand is %u32, not a matrix type.
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The type in OpCooperativeMatrixLengthKHR <id> '5[%uint]' "
+                "must be OpTypeCooperativeMatrixKHR"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthGood) {
+  const std::string body =
+      R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %u32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+  CompileSuccessfully(body.c_str());  // Unsigned 32-bit result type.
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
 TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
   std::string spirv = R"(
 OpCapability Shader
@@ -3765,9 +3947,8 @@
       HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks"));
 }
 
-using ValidateSizedVariable =
-    spvtest::ValidateBase<std::tuple<std::string, std::string,
-                                     std::string, spv_target_env>>;
+using ValidateSizedVariable = spvtest::ValidateBase<
+    std::tuple<std::string, std::string, std::string, spv_target_env>>;
 
 CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
   CodeGenerator generator;
@@ -3777,7 +3958,8 @@
       "\"SPV_KHR_8bit_storage\"\n";
   generator.memory_model_ = "OpMemoryModel Logical GLSL450\n";
   if (is_8bit) {
-    generator.before_types_ = "OpMemberDecorate %char_buffer_block 0 Offset 0\n";
+    generator.before_types_ =
+        "OpMemberDecorate %char_buffer_block 0 Offset 0\n";
     if (buffer_block)
       generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n";
 
diff --git a/utils/generate_grammar_tables.py b/utils/generate_grammar_tables.py
index 6b7167b..e6a1455 100755
--- a/utils/generate_grammar_tables.py
+++ b/utils/generate_grammar_tables.py
@@ -540,7 +540,7 @@
 
     # We have a few operand kinds that require their optional counterpart to
     # exist in the operand info table.
-    optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat']
+    optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands']
     optional_enums = [e for e in enums if e[0] in optional_enums]
     enums.extend(optional_enums)