SPV_KHR_cooperative_matrix (#5286)
* SPV_KHR_cooperative_matrix
* Update DEPS with headers
* Update according to review recommendations
* Bugfix and formatting
* Formatting missed or damaged by VS2022
diff --git a/DEPS b/DEPS
index 284797f..5727bad 100644
--- a/DEPS
+++ b/DEPS
@@ -13,7 +13,7 @@
'protobuf_revision': 'v21.12',
're2_revision': '7c5e396af825562ec8321fdbf2f1cf276b26e3ae',
- 'spirv_headers_revision': '10db9d4e194246a020a4148e220837ac7c68cfd9',
+ 'spirv_headers_revision': '3469b164e25cee24435029a569933cb42578db5d',
}
deps = {
diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h
index 542b745..8ecaf0a 100644
--- a/include/spirv-tools/libspirv.h
+++ b/include/spirv-tools/libspirv.h
@@ -285,6 +285,13 @@
// An optional packed vector format
SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT,
+ // Concrete operand types for cooperative matrix.
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS,
+ // An optional cooperative matrix operands
+ SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS,
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT,
+ SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE,
+
// This is a sentinel value, and does not represent an operand type.
// It should come last.
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp
index 6df823e..56c7964 100644
--- a/source/assembly_grammar.cpp
+++ b/source/assembly_grammar.cpp
@@ -154,11 +154,12 @@
CASE(InBoundsAccessChain),
CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain),
- CASE(CooperativeMatrixLengthNV)
+ CASE(CooperativeMatrixLengthNV),
+ CASE(CooperativeMatrixLengthKHR)
};
-// The 60 is determined by counting the opcodes listed in the spec.
+// The 61 is determined by counting the opcodes listed in the spec.
-static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+static_assert(61 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete");
#undef CASE
// clang-format on
diff --git a/source/binary.cpp b/source/binary.cpp
index beb56be..207d4a9 100644
--- a/source/binary.cpp
+++ b/source/binary.cpp
@@ -691,7 +691,9 @@
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
- case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS: {
+ case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
// This operand is a mask.
// Map an optional operand type to its corresponding concrete type.
@@ -699,6 +701,8 @@
parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS)
+ parsed_operand.type = SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS;
// Check validity of set mask bits. Also prepare for operands for those
// masks if they have any. To get operand order correct, scan from
diff --git a/source/opcode.cpp b/source/opcode.cpp
index d26024a..ffbb2e8 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -274,6 +274,7 @@
case spv::Op::OpTypeArray:
case spv::Op::OpTypeStruct:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
return false;
@@ -340,6 +341,7 @@
case spv::Op::OpTypeNamedBarrier:
case spv::Op::OpTypeAccelerationStructureNV:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
// spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR:
diff --git a/source/operand.cpp b/source/operand.cpp
index 31a6c59..a78191b 100644
--- a/source/operand.cpp
+++ b/source/operand.cpp
@@ -236,6 +236,13 @@
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
return "packed vector format";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
+ return "cooperative matrix operands";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+ return "cooperative matrix layout";
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
+ return "cooperative matrix use";
case SPV_OPERAND_TYPE_IMAGE:
case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
return "image";
@@ -369,6 +376,8 @@
case SPV_OPERAND_TYPE_QUANTIZATION_MODES:
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
case SPV_OPERAND_TYPE_PACKED_VECTOR_FORMAT:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_LAYOUT:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_USE:
return true;
default:
break;
@@ -387,6 +396,7 @@
case SPV_OPERAND_TYPE_FRAGMENT_SHADING_RATE:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_COOPERATIVE_MATRIX_OPERANDS:
return true;
default:
break;
@@ -405,6 +415,7 @@
case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING:
case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
case SPV_OPERAND_TYPE_OPTIONAL_PACKED_VECTOR_FORMAT:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS:
case SPV_OPERAND_TYPE_OPTIONAL_CIV:
return true;
default:
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index 1b1aead..2dcc259 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -423,6 +423,23 @@
{SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}});
break;
}
+ case Type::kCooperativeMatrixKHR: {
+ auto coop_mat = type->AsCooperativeMatrixKHR();
+ uint32_t const component_type =
+ GetTypeInstruction(coop_mat->component_type());
+ if (component_type == 0) {
+ return 0;
+ }
+ typeInst = MakeUnique<Instruction>(
+ context(), spv::Op::OpTypeCooperativeMatrixKHR, 0, id,
+ std::initializer_list<Operand>{
+ {SPV_OPERAND_TYPE_ID, {component_type}},
+ {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}},
+ {SPV_OPERAND_TYPE_ID, {coop_mat->use_id()}}});
+ break;
+ }
default:
assert(false && "Unexpected type");
break;
@@ -628,6 +645,14 @@
cm_type->columns_id());
break;
}
+ case Type::kCooperativeMatrixKHR: {
+ const CooperativeMatrixKHR* cm_type = type.AsCooperativeMatrixKHR();
+ const Type* component_type = cm_type->component_type();
+ rebuilt_ty = MakeUnique<CooperativeMatrixKHR>(
+ RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(),
+ cm_type->columns_id(), cm_type->use_id());
+ break;
+ }
default:
assert(false && "Unhandled type");
return nullptr;
@@ -863,6 +888,12 @@
inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3));
break;
+ case spv::Op::OpTypeCooperativeMatrixKHR:
+ type = new CooperativeMatrixKHR(
+ GetType(inst.GetSingleWordInOperand(0)),
+ inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
+ inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
+ break;
case spv::Op::OpTypeRayQueryKHR:
type = new RayQueryKHR();
break;
diff --git a/source/opt/types.cpp b/source/opt/types.cpp
index 49eec9b..b18b8cb 100644
--- a/source/opt/types.cpp
+++ b/source/opt/types.cpp
@@ -128,6 +128,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -175,6 +176,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -230,6 +232,7 @@
DeclareKindCase(NamedBarrier);
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
+ DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
@@ -708,6 +711,45 @@
columns_id_ == mt->columns_id_ && HasSameDecorations(that);
}
+// Builds a KHR cooperative matrix type from its component type and the values
+// given for scope, rows, columns and use; each argument is stored unchanged in
+// the member of the same name.
+CooperativeMatrixKHR::CooperativeMatrixKHR(const Type* type,
+ const uint32_t scope,
+ const uint32_t rows,
+ const uint32_t columns,
+ const uint32_t use)
+ : Type(kCooperativeMatrixKHR),
+ component_type_(type),
+ scope_id_(scope),
+ rows_id_(rows),
+ columns_id_(columns),
+ use_id_(use) {
+ // Sanity checks active only when asserts are enabled. Note that |use| is
+ // not checked here.
+ assert(type != nullptr);
+ assert(scope != 0);
+ assert(rows != 0);
+ assert(columns != 0);
+}
+
+// Renders the type as "<component, scope, rows, columns, use>": the component
+// is printed through its own str(), the remaining four fields as the stored
+// numeric values.
+std::string CooperativeMatrixKHR::str() const {
+  std::ostringstream text;
+  text << "<" << component_type_->str();
+  text << ", " << scope_id_;
+  text << ", " << rows_id_;
+  text << ", " << columns_id_;
+  text << ", " << use_id_;
+  text << ">";
+  return text.str();
+}
+
+// Mixes the scope, rows, columns and use values into |hash|, then hands the
+// result to the component type so that it contributes to the final value.
+size_t CooperativeMatrixKHR::ComputeExtraStateHash(size_t hash,
+                                                   SeenTypes* seen) const {
+  const size_t ids_hash =
+      hash_combine(hash, scope_id_, rows_id_, columns_id_, use_id_);
+  return component_type_->ComputeHashValue(ids_hash, seen);
+}
+
+// Returns true when |that| is also a CooperativeMatrixKHR with the same
+// component type, scope, rows, columns, use and decorations. |seen| is
+// forwarded to the component type comparison.
+bool CooperativeMatrixKHR::IsSameImpl(const Type* that,
+                                      IsSameCache* seen) const {
+  const CooperativeMatrixKHR* mt = that->AsCooperativeMatrixKHR();
+  if (!mt) return false;
+  // |use_id_| has to take part in the comparison: it is hashed by
+  // ComputeExtraStateHash and printed by str(), so two matrices that differ
+  // only in Use are distinct types and must not compare equal.
+  return component_type_->IsSameImpl(mt->component_type_, seen) &&
+         scope_id_ == mt->scope_id_ && rows_id_ == mt->rows_id_ &&
+         columns_id_ == mt->columns_id_ && use_id_ == mt->use_id_ &&
+         HasSameDecorations(that);
+}
+
} // namespace analysis
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/types.h b/source/opt/types.h
index 26c058c..16a948c 100644
--- a/source/opt/types.h
+++ b/source/opt/types.h
@@ -60,6 +60,7 @@
class NamedBarrier;
class AccelerationStructureNV;
class CooperativeMatrixNV;
+class CooperativeMatrixKHR;
class RayQueryKHR;
class HitObjectNV;
@@ -100,6 +101,7 @@
kNamedBarrier,
kAccelerationStructureNV,
kCooperativeMatrixNV,
+ kCooperativeMatrixKHR,
kRayQueryKHR,
kHitObjectNV,
kLast
@@ -201,6 +203,7 @@
DeclareCastMethod(NamedBarrier)
DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV)
+ DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
#undef DeclareCastMethod
@@ -624,6 +627,38 @@
const uint32_t columns_id_;
};
+// Type class for OpTypeCooperativeMatrixKHR. It holds the component type plus
+// four 32-bit values - scope, rows, columns and use - which the type manager
+// fills from the instruction's operand words and writes back as operands when
+// it recreates the instruction.
+class CooperativeMatrixKHR : public Type {
+ public:
+ // Stores the arguments as given; see the definition for the checks applied.
+ CooperativeMatrixKHR(const Type* type, const uint32_t scope,
+ const uint32_t rows, const uint32_t columns,
+ const uint32_t use);
+ CooperativeMatrixKHR(const CooperativeMatrixKHR&) = default;
+
+ // Human-readable form: "<component, scope, rows, columns, use>".
+ std::string str() const override;
+
+ // Checked-cast overrides: this object is a CooperativeMatrixKHR.
+ CooperativeMatrixKHR* AsCooperativeMatrixKHR() override { return this; }
+ const CooperativeMatrixKHR* AsCooperativeMatrixKHR() const override {
+ return this;
+ }
+
+ // Adds this type's own fields and its component type to |hash|.
+ size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
+
+ // Read-only accessors for the stored fields.
+ const Type* component_type() const { return component_type_; }
+ uint32_t scope_id() const { return scope_id_; }
+ uint32_t rows_id() const { return rows_id_; }
+ uint32_t columns_id() const { return columns_id_; }
+ uint32_t use_id() const { return use_id_; }
+
+ private:
+ // Structural equality against another type; defined in types.cpp.
+ bool IsSameImpl(const Type* that, IsSameCache*) const override;
+
+ // Not owned by this object (the class declares no destructor).
+ const Type* component_type_;
+ const uint32_t scope_id_;
+ const uint32_t rows_id_;
+ const uint32_t columns_id_;
+ const uint32_t use_id_;
+};
+
#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
diff --git a/source/text.cpp b/source/text.cpp
index 9c77422..eb7f96b 100644
--- a/source/text.cpp
+++ b/source/text.cpp
@@ -402,7 +402,8 @@
case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
case SPV_OPERAND_TYPE_SELECTION_CONTROL:
case SPV_OPERAND_TYPE_DEBUG_INFO_FLAGS:
- case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS: {
+ case SPV_OPERAND_TYPE_CLDEBUG100_DEBUG_INFO_FLAGS:
+ case SPV_OPERAND_TYPE_OPTIONAL_COOPERATIVE_MATRIX_OPERANDS: {
uint32_t value;
if (auto error = grammar.parseMaskOperand(type, textValue, &value)) {
return context->diagnostic(error)
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index 4e7dd5e..b608a85 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -42,14 +42,29 @@
opcode != spv::Op::OpFMod);
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type) &&
- !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
+ !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpFMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsFloatCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
- if (_.GetOperandTypeId(inst, operand_index) != result_type)
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsFloatCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
@@ -71,7 +86,19 @@
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
- if (_.GetOperandTypeId(inst, operand_index) != result_type)
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
@@ -91,7 +118,10 @@
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
opcode != spv::Op::OpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
- !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpIMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@@ -102,9 +132,26 @@
for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
+
+ if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
+ if (!_.IsCooperativeMatrixKHRType(type_id) ||
+ !_.IsIntCooperativeMatrixType(type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected arithmetic operands to be of Result Type: "
+ << spvOpcodeString(opcode) << " operand index "
+ << operand_index;
+ }
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, type_id, result_type);
+ if (ret != SPV_SUCCESS) return ret;
+ }
+
if (!type_id ||
(!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
- !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
+ !(opcode == spv::Op::OpIMul &&
+ _.IsCooperativeMatrixKHRType(result_type) &&
+ _.IsIntCooperativeMatrixType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
@@ -187,7 +234,7 @@
case spv::Op::OpMatrixTimesScalar: {
if (!_.IsFloatMatrixType(result_type) &&
- !_.IsCooperativeMatrixType(result_type))
+ !(_.IsCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float matrix type as Result Type: "
<< spvOpcodeString(opcode);
@@ -459,22 +506,108 @@
const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
- if (!_.IsCooperativeMatrixType(A_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(A_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as A Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(B_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(B_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as B Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(C_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(C_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as C Type: "
<< spvOpcodeString(opcode);
}
- if (!_.IsCooperativeMatrixType(D_type_id)) {
+ if (!_.IsCooperativeMatrixNVType(D_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as Result Type: "
+ << spvOpcodeString(opcode);
+ }
+
+ const auto A = _.FindDef(A_type_id);
+ const auto B = _.FindDef(B_type_id);
+ const auto C = _.FindDef(C_type_id);
+ const auto D = _.FindDef(D_type_id);
+
+ std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+ A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+ A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+ B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+ C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+ D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+ A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+ B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+ C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+ D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+ A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+ B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+ C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+ D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+ const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+ std::tuple<bool, bool, uint32_t> Y) {
+ return (std::get<1>(X) && std::get<1>(Y) &&
+ std::get<2>(X) != std::get<2>(Y));
+ };
+
+ if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+ notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+ notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix scopes must match: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+ notEqual(C_rows, D_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'M' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+ notEqual(C_cols, D_cols)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'N' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_cols, B_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'K' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+ break;
+ }
+
+ case spv::Op::OpCooperativeMatrixMulAddKHR: {
+ const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+ const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+ const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+ const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+ if (!_.IsCooperativeMatrixAType(A_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be A Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixBType(B_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be B Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixAccType(C_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix type must be Accumulator Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix type as Result Type: "
<< spvOpcodeString(opcode);
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index 2b83c63..ed043b6 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -122,6 +122,7 @@
*member_type = type_inst->word(component_index + 2);
break;
}
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: {
*member_type = type_inst->word(2);
break;
@@ -335,6 +336,25 @@
break;
}
+ case spv::Op::OpTypeCooperativeMatrixKHR: {
+ const auto result_type_inst = _.FindDef(result_type);
+ assert(result_type_inst);
+ const auto component_type_id =
+ result_type_inst->GetOperandAs<uint32_t>(1);
+
+ if (3 != num_operands) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Must be only one constituent";
+ }
+
+ const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+ if (operand_type_id != component_type_id) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Constituent type to be equal to the component type";
+ }
+ break;
+ }
case spv::Op::OpTypeCooperativeMatrixNV: {
const auto result_type_inst = _.FindDef(result_type);
assert(result_type_inst);
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index 006e504..4deaa49 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -243,6 +243,7 @@
}
}
} break;
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeMatrixNV: {
if (1 != constituent_count) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -310,6 +311,7 @@
case spv::Op::OpTypeArray:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeVector: {
auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _);
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index 476c1fe..b2892a8 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -473,7 +473,10 @@
const bool input_is_pointer = _.IsPointerType(input_type);
const bool input_is_int_scalar = _.IsIntScalarType(input_type);
- if (!result_is_pointer && !result_is_int_scalar &&
+ const bool result_is_coopmat = _.IsCooperativeMatrixType(result_type);
+ const bool input_is_coopmat = _.IsCooperativeMatrixType(input_type);
+
+ if (!result_is_pointer && !result_is_int_scalar && !result_is_coopmat &&
!_.IsIntVectorType(result_type) &&
!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type))
@@ -481,13 +484,24 @@
<< "Expected Result Type to be a pointer or int or float vector "
<< "or scalar type: " << spvOpcodeString(opcode);
- if (!input_is_pointer && !input_is_int_scalar &&
+ if (!input_is_pointer && !input_is_int_scalar && !input_is_coopmat &&
!_.IsIntVectorType(input_type) && !_.IsFloatScalarType(input_type) &&
!_.IsFloatVectorType(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be a pointer or int or float vector "
<< "or scalar: " << spvOpcodeString(opcode);
+ if (result_is_coopmat != input_is_coopmat)
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix can only be cast to another cooperative "
+ << "matrix: " << spvOpcodeString(opcode);
+
+ if (result_is_coopmat) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ }
+
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 5) ||
_.HasExtension(kSPV_KHR_physical_storage_buffer)) {
const bool result_is_int_vector = _.IsIntVectorType(result_type);
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index 92a4e8e..bcfeb59 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -163,9 +163,12 @@
!inst->IsDebugInfo() && !inst->IsNonSemantic() &&
!spvOpcodeIsDecoration(opcode) && opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+ opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp &&
- spv::Op(inst->word(3)) ==
- spv::Op::OpCooperativeMatrixLengthNV)) {
+ (spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthNV ||
+ spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " cannot be a type";
@@ -179,9 +182,12 @@
opcode != spv::Op::OpLoopMerge &&
opcode != spv::Op::OpFunction &&
opcode != spv::Op::OpCooperativeMatrixLengthNV &&
+ opcode != spv::Op::OpCooperativeMatrixLengthKHR &&
!(opcode == spv::Op::OpSpecConstantOp &&
- spv::Op(inst->word(3)) ==
- spv::Op::OpCooperativeMatrixLengthNV)) {
+ (spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthNV ||
+ spv::Op(inst->word(3)) ==
+ spv::Op::OpCooperativeMatrixLengthKHR))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " requires a type";
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 975a55c..f039496 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -204,6 +204,7 @@
switch (storage->opcode()) {
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray:
@@ -232,6 +233,7 @@
spv::StorageClass src_sc = spv::StorageClass::Max;
switch (inst->opcode()) {
case spv::Op::OpCooperativeMatrixLoadNV:
+ case spv::Op::OpCooperativeMatrixLoadKHR:
case spv::Op::OpLoad: {
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
auto load_pointer_type = _.FindDef(load_pointer->type_id());
@@ -239,6 +241,7 @@
break;
}
case spv::Op::OpCooperativeMatrixStoreNV:
+ case spv::Op::OpCooperativeMatrixStoreKHR:
case spv::Op::OpStore: {
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
auto store_pointer_type = _.FindDef(store_pointer->type_id());
@@ -326,7 +329,8 @@
const uint32_t mask = inst->GetOperandAs<uint32_t>(index);
if (mask & uint32_t(spv::MemoryAccessMask::MakePointerAvailableKHR)) {
if (inst->opcode() == spv::Op::OpLoad ||
- inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV) {
+ inst->opcode() == spv::Op::OpCooperativeMatrixLoadNV ||
+ inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
}
@@ -1357,6 +1361,7 @@
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray: {
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
@@ -1554,9 +1559,15 @@
<< " must be OpTypeInt with width 32 and signedness 0.";
}
+ bool isKhr = inst->opcode() == spv::Op::OpCooperativeMatrixLengthKHR;
auto type_id = inst->GetOperandAs<uint32_t>(2);
auto type = state.FindDef(type_id);
- if (type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+ if (isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+ return state.diag(SPV_ERROR_INVALID_ID, inst)
+ << "The type in " << instr_name << " <id> "
+ << state.getIdName(type_id)
+ << " must be OpTypeCooperativeMatrixKHR.";
+ } else if (!isKhr && type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
return state.diag(SPV_ERROR_INVALID_ID, inst)
<< "The type in " << instr_name << " <id> "
<< state.getIdName(type_id) << " must be OpTypeCooperativeMatrixNV.";
@@ -1668,6 +1679,112 @@
return SPV_SUCCESS;
}
+// Validates OpCooperativeMatrixLoadKHR and OpCooperativeMatrixStoreKHR.
+//
+// Checks, in order:
+//  * the Result Type (load) or the Object's type (store) is
+//    OpTypeCooperativeMatrixKHR;
+//  * the Pointer operand is a logical pointer under the Logical addressing
+//    model and its type is OpTypePointer;
+//  * the pointer's storage class is Workgroup, StorageBuffer or
+//    PhysicalStorageBuffer;
+//  * the pointee is an int/float scalar or vector type;
+//  * the MemoryLayout operand is a (spec) constant of integer scalar type;
+//  * the optional Stride operand has integer scalar type;
+//  * the optional memory access operands pass CheckMemoryAccess.
+//
+// Returns SPV_SUCCESS, or the diagnostic produced by the first failed check.
+spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
+                                                   const Instruction* inst) {
+  uint32_t type_id;
+  const char* opname;
+  if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+    type_id = inst->type_id();
+    opname = "spv::Op::OpCooperativeMatrixLoadKHR";
+  } else {
+    // get Object operand's type; an unresolved Object yields type id 0, which
+    // is rejected by the matrix type check below instead of being
+    // dereferenced here.
+    const auto object = _.FindDef(inst->GetOperandAs<uint32_t>(1));
+    type_id = object ? object->type_id() : 0;
+    opname = "spv::Op::OpCooperativeMatrixStoreKHR";
+  }
+
+  auto matrix_type = _.FindDef(type_id);
+
+  // Guard against a missing definition the same way the pointer checks below
+  // do, so a bad id produces a diagnostic rather than a null dereference.
+  if (!matrix_type ||
+      matrix_type->opcode() != spv::Op::OpTypeCooperativeMatrixKHR) {
+    if (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixLoadKHR Result Type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "spv::Op::OpCooperativeMatrixStoreKHR Object type <id> "
+             << _.getIdName(type_id) << " is not a cooperative matrix type.";
+    }
+  }
+
+  // The load has Result Type and Result Id in front of its Pointer operand;
+  // the store starts directly with Pointer. The same offset of two applies to
+  // the layout index below; stride and memory access are offset by one.
+  const auto pointer_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == spv::AddressingModel::Logical) &&
+       ((!_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (_.features().variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer_id)
+           << " is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != spv::Op::OpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer <id> " << _.getIdName(pointer_id)
+           << " is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<spv::StorageClass>(storage_class_index);
+
+  // The diagnostic lists every storage class the condition accepts.
+  if (storage_class != spv::StorageClass::Workgroup &&
+      storage_class != spv::StorageClass::StorageBuffer &&
+      storage_class != spv::StorageClass::PhysicalStorageBuffer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type <id> "
+           << _.getIdName(pointer_type_id)
+           << " is not Workgroup, StorageBuffer, or PhysicalStorageBuffer.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> " << _.getIdName(pointer->id())
+           << "s Type must be a scalar or vector type.";
+  }
+
+  const auto layout_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
+           << " must be a 32-bit integer constant instruction.";
+  }
+
+  // Stride is optional: only validate it when the operand is present.
+  const auto stride_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
+  if (inst->operands().size() > stride_index) {
+    const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+    const auto stride = _.FindDef(stride_id);
+    if (!stride || !_.IsIntScalarType(stride->type_id())) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Stride operand <id> " << _.getIdName(stride_id)
+             << " must be a scalar integer type.";
+    }
+  }
+
+  // Memory access operands are optional as well and follow the stride.
+  const auto memory_access_index =
+      (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+
+  return SPV_SUCCESS;
+}
+
spv_result_t ValidatePtrComparison(ValidationState_t& _,
const Instruction* inst) {
if (_.addressing_model() == spv::AddressingModel::Logical &&
@@ -1757,9 +1874,15 @@
if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
return error;
break;
+ case spv::Op::OpCooperativeMatrixLengthKHR:
case spv::Op::OpCooperativeMatrixLengthNV:
if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
break;
+ case spv::Op::OpCooperativeMatrixLoadKHR:
+ case spv::Op::OpCooperativeMatrixStoreKHR:
+ if (auto error = ValidateCooperativeMatrixLoadStoreKHR(_, inst))
+ return error;
+ break;
case spv::Op::OpPtrEqual:
case spv::Op::OpPtrNotEqual:
case spv::Op::OpPtrDiff:
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 430d819..7edd12f 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -552,8 +552,8 @@
return SPV_SUCCESS;
}
-spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
- const Instruction* inst) {
+spv_result_t ValidateTypeCooperativeMatrix(ValidationState_t& _,
+ const Instruction* inst) {
const auto component_type_index = 1;
const auto component_type_id =
inst->GetOperandAs<uint32_t>(component_type_index);
@@ -561,7 +561,7 @@
if (!component_type || (spv::Op::OpTypeFloat != component_type->opcode() &&
spv::Op::OpTypeInt != component_type->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeCooperativeMatrixNV Component Type <id> "
+ << "OpTypeCooperativeMatrix Component Type <id> "
<< _.getIdName(component_type_id)
<< " is not a scalar numerical type.";
}
@@ -572,7 +572,7 @@
if (!scope || !_.IsIntScalarType(scope->type_id()) ||
!spvOpcodeIsConstant(scope->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeCooperativeMatrixNV Scope <id> " << _.getIdName(scope_id)
+ << "OpTypeCooperativeMatrix Scope <id> " << _.getIdName(scope_id)
<< " is not a constant instruction with scalar integer type.";
}
@@ -582,7 +582,7 @@
if (!rows || !_.IsIntScalarType(rows->type_id()) ||
!spvOpcodeIsConstant(rows->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeCooperativeMatrixNV Rows <id> " << _.getIdName(rows_id)
+ << "OpTypeCooperativeMatrix Rows <id> " << _.getIdName(rows_id)
<< " is not a constant instruction with scalar integer type.";
}
@@ -592,10 +592,22 @@
if (!cols || !_.IsIntScalarType(cols->type_id()) ||
!spvOpcodeIsConstant(cols->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeCooperativeMatrixNV Cols <id> " << _.getIdName(cols_id)
+ << "OpTypeCooperativeMatrix Cols <id> " << _.getIdName(cols_id)
<< " is not a constant instruction with scalar integer type.";
}
+ if (inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+ const auto use_index = 5;
+ const auto use_id = inst->GetOperandAs<uint32_t>(use_index);
+ const auto use = _.FindDef(use_id);
+ if (!use || !_.IsIntScalarType(use->type_id()) ||
+ !spvOpcodeIsConstant(use->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrixKHR Use <id> " << _.getIdName(use_id)
+ << " is not a constant instruction with scalar integer type.";
+ }
+ }
+
return SPV_SUCCESS;
}
} // namespace
@@ -640,7 +652,8 @@
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
break;
case spv::Op::OpTypeCooperativeMatrixNV:
- if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+ case spv::Op::OpTypeCooperativeMatrixKHR:
+ if (auto error = ValidateTypeCooperativeMatrix(_, inst)) return error;
break;
default:
break;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 5a138d9..cde8aaa 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -859,6 +859,7 @@
return GetComponentType(inst->word(2));
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
return inst->word(2);
default:
@@ -886,6 +887,7 @@
return inst->word(3);
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
// Actual dimension isn't known, return 0
return 0;
@@ -1143,21 +1145,67 @@
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
const Instruction* inst = FindDef(id);
+ return inst && (inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV ||
+ inst->opcode() == spv::Op::OpTypeCooperativeMatrixKHR);
+}
+
+bool ValidationState_t::IsCooperativeMatrixNVType(uint32_t id) const {
+ const Instruction* inst = FindDef(id);
return inst && inst->opcode() == spv::Op::OpTypeCooperativeMatrixNV;
}
+bool ValidationState_t::IsCooperativeMatrixKHRType(uint32_t id) const {
+  const Instruction* def = FindDef(id);  // null-checked before the opcode test
+  return def != nullptr && def->opcode() == spv::Op::OpTypeCooperativeMatrixKHR;
+}
+
+bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
+  // The Use operand is read from word 6 of the OpTypeCooperativeMatrixKHR.
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  // A Use <id> that GetConstantValUint64 cannot evaluate never counts as
+  // MatrixA.
+  uint64_t use_value = 0;
+  const bool use_known = GetConstantValUint64(FindDef(id)->word(6), &use_value);
+  constexpr auto kUseA = spv::CooperativeMatrixUse::MatrixAKHR;
+  return use_known && use_value == static_cast<uint64_t>(kUseA);
+}
+
+bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
+  // The Use operand is read from word 6 of the OpTypeCooperativeMatrixKHR.
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  // A Use <id> that GetConstantValUint64 cannot evaluate never counts as
+  // MatrixB.
+  uint64_t use_value = 0;
+  const bool use_known = GetConstantValUint64(FindDef(id)->word(6), &use_value);
+  constexpr auto kUseB = spv::CooperativeMatrixUse::MatrixBKHR;
+  return use_known && use_value == static_cast<uint64_t>(kUseB);
+}
+bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
+  // The Use operand is read from word 6 of the OpTypeCooperativeMatrixKHR.
+  if (!IsCooperativeMatrixKHRType(id)) return false;
+  // A Use <id> that GetConstantValUint64 cannot evaluate never counts as
+  // MatrixAccumulator.
+  uint64_t use_value = 0;
+  const bool use_known = GetConstantValUint64(FindDef(id)->word(6), &use_value);
+  constexpr auto kUseAcc = spv::CooperativeMatrixUse::MatrixAccumulatorKHR;
+  return use_known && use_value == static_cast<uint64_t>(kUseAcc);
+}
+
bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
- if (!IsCooperativeMatrixType(id)) return false;
+ if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+ return false;
return IsFloatScalarType(FindDef(id)->word(2));
}
bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
- if (!IsCooperativeMatrixType(id)) return false;
+ if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+ return false;
return IsIntScalarType(FindDef(id)->word(2));
}
bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
- if (!IsCooperativeMatrixType(id)) return false;
+ if (!IsCooperativeMatrixNVType(id) && !IsCooperativeMatrixKHRType(id))
+ return false;
return IsUnsignedIntScalarType(FindDef(id)->word(2));
}
@@ -1173,8 +1221,7 @@
const auto m1_type = FindDef(m1);
const auto m2_type = FindDef(m2);
- if (m1_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV ||
- m2_type->opcode() != spv::Op::OpTypeCooperativeMatrixNV) {
+ if (m1_type->opcode() != m2_type->opcode()) {
return diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected cooperative matrix types";
}
@@ -1224,6 +1271,21 @@
<< "identical";
}
+ if (m1_type->opcode() == spv::Op::OpTypeCooperativeMatrixKHR) {
+ uint32_t m1_use_id = m1_type->GetOperandAs<uint32_t>(5);
+ uint32_t m2_use_id = m2_type->GetOperandAs<uint32_t>(5);
+ std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+ EvalInt32IfConst(m1_use_id);
+ std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+ EvalInt32IfConst(m2_use_id);
+
+ if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+ return diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Use of Matrix type and Result Type to be "
+ << "identical";
+ }
+ }
+
return SPV_SUCCESS;
}
@@ -1489,6 +1551,7 @@
case spv::Op::OpTypeImage:
case spv::Op::OpTypeSampledImage:
case spv::Op::OpTypeCooperativeMatrixNV:
+ case spv::Op::OpTypeCooperativeMatrixKHR:
return ContainsType(inst->GetOperandAs<uint32_t>(1u), f,
traverse_all_types);
case spv::Op::OpTypePointer:
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 4d5ac00..bfae821 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -610,6 +610,11 @@
bool IsPointerType(uint32_t id) const;
bool IsAccelerationStructureType(uint32_t id) const;
bool IsCooperativeMatrixType(uint32_t id) const;
+ bool IsCooperativeMatrixNVType(uint32_t id) const;
+ bool IsCooperativeMatrixKHRType(uint32_t id) const;
+ bool IsCooperativeMatrixAType(uint32_t id) const;
+ bool IsCooperativeMatrixBType(uint32_t id) const;
+ bool IsCooperativeMatrixAccType(uint32_t id) const;
bool IsFloatCooperativeMatrixType(uint32_t id) const;
bool IsIntCooperativeMatrixType(uint32_t id) const;
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
diff --git a/test/opt/type_manager_test.cpp b/test/opt/type_manager_test.cpp
index 563eb74..cb30171 100644
--- a/test/opt/type_manager_test.cpp
+++ b/test/opt/type_manager_test.cpp
@@ -171,6 +171,7 @@
types.emplace_back(new NamedBarrier());
types.emplace_back(new AccelerationStructureNV());
types.emplace_back(new CooperativeMatrixNV(f32, 24, 24, 24));
+ types.emplace_back(new CooperativeMatrixKHR(f32, 8, 8, 8, 1002));
types.emplace_back(new RayQueryKHR());
types.emplace_back(new HitObjectNV());
@@ -237,6 +238,8 @@
%arr_long_constant = OpTypeArray %s32 %long_constant
%arr_spec_const_op = OpTypeArray %s32 %spec_const_op
%cm = OpTypeCooperativeMatrixNV %f64 %id4 %id4 %id4
+ %id2 = OpConstant %u32 2
+ %cmkhr = OpTypeCooperativeMatrixKHR %f64 %id4 %id4 %id4 %id2
)";
std::vector<std::pair<uint32_t, std::string>> type_id_strs = {
@@ -275,6 +278,7 @@
{37, "[sint32, id(33), words(0,705032704,1)]"},
{38, "[sint32, id(34), words(2,34)]"},
{39, "<float64, 6, 6, 6>"},
+ {41, "<float64, 6, 6, 6, 40>"},
};
std::unique_ptr<IRContext> context =
@@ -940,12 +944,15 @@
std::vector<std::unique_ptr<Type>> types = GenerateAllTypes();
uint32_t id = 1u;
for (auto& t : types) {
+ std::cout << ". id " << id << std::endl;
context->get_type_mgr()->RegisterType(id, *t);
EXPECT_EQ(*t, *context->get_type_mgr()->GetType(id));
}
+ std::cout << "clear" << id << std::endl;
types.clear();
for (; id > 0; --id) {
+ std::cout << ". remove id " << id << std::endl;
context->get_type_mgr()->RemoveId(id);
EXPECT_EQ(nullptr, context->get_type_mgr()->GetType(id));
}
@@ -1030,6 +1037,8 @@
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
; CHECK: [[input_ptr:%\w+]] = OpTypePointer Input [[uint]]
; CHECK: [[uniform_ptr:%\w+]] = OpTypePointer Uniform [[uint]]
+; CHECK: [[uint2:%\w+]] = OpConstant [[uint]] 2
+; CHECK: [[uint8:%\w+]] = OpConstant [[uint]] 8
; CHECK: [[uint24:%\w+]] = OpConstant [[uint]] 24
; CHECK: [[uint42:%\w+]] = OpConstant [[uint]] 42
; CHECK: [[uint100:%\w+]] = OpConstant [[uint]] 100
@@ -1085,6 +1094,7 @@
; CHECK: OpTypeNamedBarrier
; CHECK: OpTypeAccelerationStructureKHR
; CHECK: OpTypeCooperativeMatrixNV [[f32]] [[uint24]] [[uint24]] [[uint24]]
+; CHECK: OpTypeCooperativeMatrixKHR [[f32]] [[uint8]] [[uint8]] [[uint8]] [[uint2]]
; CHECK: OpTypeRayQueryKHR
; CHECK: OpTypeHitObjectNV
OpCapability Shader
@@ -1094,6 +1104,8 @@
%uint = OpTypeInt 32 0
%1 = OpTypePointer Input %uint
%2 = OpTypePointer Uniform %uint
+%1002 = OpConstant %uint 2
+%8 = OpConstant %uint 8
%24 = OpConstant %uint 24
%42 = OpConstant %uint 42
%100 = OpConstant %uint 100
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 631375e..06c4e39 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1318,7 +1318,7 @@
CompileSuccessfully(GenerateCoopMatCode(types, "").c_str());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
- HasSubstr("OpTypeCooperativeMatrixNV Component Type <id> "
+ HasSubstr("OpTypeCooperativeMatrix Component Type <id> "
"'4[%bool]' is not a scalar numerical type."));
}
@@ -1331,7 +1331,7 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("OpTypeCooperativeMatrixNV Scope <id> '17[%float_1]' is not a "
+ HasSubstr("OpTypeCooperativeMatrix Scope <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
@@ -1344,7 +1344,7 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("OpTypeCooperativeMatrixNV Rows <id> '17[%float_1]' is not a "
+ HasSubstr("OpTypeCooperativeMatrix Rows <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
@@ -1357,7 +1357,7 @@
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("OpTypeCooperativeMatrixNV Cols <id> '17[%float_1]' is not a "
+ HasSubstr("OpTypeCooperativeMatrix Cols <id> '17[%float_1]' is not a "
"constant instruction with scalar integer type."));
}
@@ -1469,6 +1469,146 @@
"SMulExtended"));
}
+std::string GenerateCoopMatKHRCode(const std::string& extra_types,
+ const std::string& main_body) {
+ const std::string prefix = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+
+%u32_16 = OpConstant %u32 16
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%useA = OpConstant %u32 0
+%useB = OpConstant %u32 1
+%useC = OpConstant %u32 2
+
+%f16matA = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+%u32matA = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useA
+%s32matA = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useA
+
+%f16matB = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useB
+%u32matB = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useB
+%s32matB = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useB
+
+%f16matC = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useC
+%f32matC = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_16 %u32_16 %useC
+%u32matC = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_16 %u32_16 %useC
+%s32matC = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_16 %u32_16 %useC
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u32_1 = OpConstant %u32 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_A_1 = OpConstantComposite %f16matA %f16_1
+%u32mat_A_1 = OpConstantComposite %u32matA %u32_1
+%s32mat_A_1 = OpConstantComposite %s32matA %s32_1
+
+%f16mat_B_1 = OpConstantComposite %f16matB %f16_1
+%u32mat_B_1 = OpConstantComposite %u32matB %u32_1
+%s32mat_B_1 = OpConstantComposite %s32matB %s32_1
+
+%f16mat_C_1 = OpConstantComposite %f16matC %f16_1
+%u32mat_C_1 = OpConstantComposite %u32matC %u32_1
+%s32mat_C_1 = OpConstantComposite %s32matC %s32_1
+
+)";
+
+ const std::string func_begin = R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+
+ const std::string suffix = R"(
+OpReturn
+OpFunctionEnd)";
+
+ return prefix + extra_types + func_begin + main_body + suffix;
+}
+
+TEST_F(ValidateArithmetics, CoopMatKHRSuccess) {
+ const std::string body = R"(
+%val1 = OpFAdd %f16matA %f16mat_A_1 %f16mat_A_1
+%val2 = OpFSub %f16matA %f16mat_A_1 %f16mat_A_1
+%val3 = OpFMul %f16matA %f16mat_A_1 %f16mat_A_1
+%val4 = OpFDiv %f16matA %f16mat_A_1 %f16mat_A_1
+%val5 = OpFNegate %f16matA %f16mat_A_1
+%val6 = OpIAdd %u32matA %u32mat_A_1 %u32mat_A_1
+%val7 = OpISub %u32matA %u32mat_A_1 %u32mat_A_1
+%val8 = OpUDiv %u32matA %u32mat_A_1 %u32mat_A_1
+%val9 = OpIAdd %s32matA %s32mat_A_1 %s32mat_A_1
+%val10 = OpISub %s32matA %s32mat_A_1 %s32mat_A_1
+%val11 = OpSDiv %s32matA %s32mat_A_1 %s32mat_A_1
+%val12 = OpSNegate %s32matA %s32mat_A_1
+%val13 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f16_1
+%val14 = OpMatrixTimesScalar %u32matA %u32mat_A_1 %u32_1
+%val15 = OpMatrixTimesScalar %s32matA %s32mat_A_1 %s32_1
+%val16 = OpCooperativeMatrixMulAddKHR %f32matC %f16mat_A_1 %f16mat_B_1 %f16mat_C_1
+%val17 = OpCooperativeMatrixMulAddKHR %s32matC %s32mat_A_1 %s32mat_B_1 %s32mat_C_1)";
+
+ CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, CoopMatMatrixKHRTimesScalarMismatchFail) {
+ const std::string body = R"(
+%val1 = OpMatrixTimesScalar %f16matA %f16mat_A_1 %f32_1
+)";
+
+ CompileSuccessfully(GenerateCoopMatKHRCode("", body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected scalar operand type to be equal to the component "
+ "type of the matrix operand: MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatKHRScopeFail) {
+ const std::string types = R"(
+%workgroup = OpConstant %u32 2
+%mat16x16_wg = OpTypeCooperativeMatrixKHR %f16 %workgroup %u32_16 %u32_16 %useC
+%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
+)";
+
+ const std::string body = R"(
+%val1 = OpFAdd %f16matA %f16matwg_16x16_1 %f16mat_A_1
+)";
+
+ CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatKHRDimFail) {
+ const std::string types = R"(
+%mat16x4 = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_4 %useC
+%mat16x4_C_1 = OpConstantComposite %mat16x4 %f16_1
+)";
+
+ const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddKHR %mat16x4 %f16mat_A_1 %f16mat_B_1 %mat16x4_C_1
+)";
+
+ CompileSuccessfully(GenerateCoopMatKHRCode(types, body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Cooperative matrix 'N' mismatch: CooperativeMatrixMulAddKHR"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools
diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp
index 0fd1ed6..6e0d7c0 100644
--- a/test/val/val_composites_test.cpp
+++ b/test/val/val_composites_test.cpp
@@ -1486,8 +1486,7 @@
}
TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {
- const std::string body =
- R"(
+ const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
@@ -1525,8 +1524,7 @@
}
TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {
- const std::string body =
- R"(
+ const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixNV
@@ -1562,6 +1560,86 @@
HasSubstr("Expected Constituent type to be equal to the component type"));
}
+TEST_F(ValidateComposites, CoopMatKHRConstantCompositeMismatchFail) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+
+%f32_1 = OpConstant %f32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "OpConstantComposite Constituent <id> '12[%float_1]' type "
+ "does not match the Result Type <id> '11[%11]'s component type."));
+}
+
+TEST_F(ValidateComposites, CoopMatKHRCompositeConstructMismatchFail) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_16 = OpConstant %u32 16
+%useA = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_16 %u32_16 %useA
+
+%f32_1 = OpConstant %f32 1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected Constituent type to be equal to the component type"));
+}
+
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
const std::string spirv = R"(
OpCapability Shader
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 1f8c426..0128aa1 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -1149,8 +1149,7 @@
}
TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {
- const std::string body =
- R"(
+ const std::string body = R"(
OpCapability Shader
OpCapability Float16
OpCapability Int16
@@ -1191,6 +1190,179 @@
ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateConversion, CoopMatKHRConversionSuccess) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_A
+%u16mat = OpTypeCooperativeMatrixKHR %u16 %subgroup %u32_8 %u32_8 %use_A
+%u32mat = OpTypeCooperativeMatrixKHR %u32 %subgroup %u32_8 %u32_8 %use_A
+%s16mat = OpTypeCooperativeMatrixKHR %s16 %subgroup %u32_8 %u32_8 %use_A
+%s32mat = OpTypeCooperativeMatrixKHR %s32 %subgroup %u32_8 %u32_8 %use_A
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u16_1 = OpConstant %u16 1
+%u32_1 = OpConstant %u32 1
+%s16_1 = OpConstant %s16 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%f32mat_1 = OpConstantComposite %f32mat %f32_1
+%u16mat_1 = OpConstantComposite %u16mat %u16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s16mat_1 = OpConstantComposite %s16mat %s16_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val11 = OpConvertFToU %u16mat %f16mat_1
+%val12 = OpConvertFToU %u32mat %f16mat_1
+%val13 = OpConvertFToS %s16mat %f16mat_1
+%val14 = OpConvertFToS %s32mat %f16mat_1
+%val15 = OpFConvert %f32mat %f16mat_1
+
+%val21 = OpConvertFToU %u16mat %f32mat_1
+%val22 = OpConvertFToU %u32mat %f32mat_1
+%val23 = OpConvertFToS %s16mat %f32mat_1
+%val24 = OpConvertFToS %s32mat %f32mat_1
+%val25 = OpFConvert %f16mat %f32mat_1
+
+%val31 = OpConvertUToF %f16mat %u16mat_1
+%val32 = OpConvertUToF %f32mat %u16mat_1
+%val33 = OpUConvert %u32mat %u16mat_1
+%val34 = OpSConvert %s32mat %u16mat_1
+
+%val41 = OpConvertSToF %f16mat %s16mat_1
+%val42 = OpConvertSToF %f32mat %s16mat_1
+%val43 = OpUConvert %u32mat %s16mat_1
+%val44 = OpSConvert %s32mat %s16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateConversion, CoopMatKHRConversionUseMismatchFail) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%use_A = OpConstant %u32 0
+%use_B = OpConstant %u32 1
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %subgroup %u32_8 %u32_8 %use_B
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val1 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected Use of Matrix type and Result Type to be identical"));
+}
+
+TEST_F(ValidateConversion, CoopMatKHRConversionScopeMismatchFail) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+%workgroup = OpConstant %u32 2
+%use_A = OpConstant %u32 0
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+%f32mat = OpTypeCooperativeMatrixKHR %f32 %workgroup %u32_8 %u32_8 %use_A
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val1 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected scopes of Matrix and Result Type to be identical"));
+}
+
TEST_F(ValidateConversion, BitcastSuccess) {
const std::string body = R"(
%ptr = OpVariable %f32ptr_func Function
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index d575318..8d0a94d 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -23,12 +23,14 @@
#include "test/val/val_fixtures.h"
// For pretty-printing tuples with spv_target_env.
-std::ostream& operator<<(std::ostream& stream, spv_target_env target)
-{
+std::ostream& operator<<(std::ostream& stream, spv_target_env target) {
switch (target) {
- case SPV_ENV_UNIVERSAL_1_3: return stream << "SPV_ENV_UNIVERSAL_1_3";
- case SPV_ENV_UNIVERSAL_1_4: return stream << "SPV_ENV_UNIVERSAL_1_4";
- default: return stream << (unsigned)target;
+ case SPV_ENV_UNIVERSAL_1_3:
+ return stream << "SPV_ENV_UNIVERSAL_1_3";
+ case SPV_ENV_UNIVERSAL_1_4:
+ return stream << "SPV_ENV_UNIVERSAL_1_4";
+ default:
+ return stream << (unsigned)target;
}
}
@@ -2346,6 +2348,186 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+ "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
+ "MakePointerVisibleKHR|NonPrivatePointerKHR");
+
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRLoadMemoryAccessFail) {
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+ "MakePointerAvailableKHR|NonPrivatePointerKHR");
+
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
+}
+
+TEST_F(ValidateMemory, CoopMatKHRInvalidStorageClassFail) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%str = OpTypeStruct %f16mat
+%str_ptr = OpTypePointer Workgroup %str
+%sh = OpVariable %str_ptr Workgroup
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Cooperative matrix types (or types containing them) can only be "
+ "allocated in Function or Private storage classes or as function "
+ "parameters"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthResultTypeBad) {
+ const std::string body = R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %i32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The Result Type of OpCooperativeMatrixLengthKHR <id> "
+ "'12[%12]' must be OpTypeInt with width 32 and signedness 0"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthOperandTypeBad) {
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %u32 %u32
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The type in OpCooperativeMatrixLengthKHR <id> '5[%uint]' "
+ "must be OpTypeCooperativeMatrixKHR"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixKHRLengthGood) {
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixKHR
+OpExtension "SPV_KHR_cooperative_matrix"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%use_A = OpConstant %u32 0
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixKHR %f16 %subgroup %u32_8 %u32_8 %use_A
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthKHR %u32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
std::string spirv = R"(
OpCapability Shader
@@ -3765,9 +3947,8 @@
HasSubstr("In the Vulkan environment, cannot store to Uniform Blocks"));
}
-using ValidateSizedVariable =
- spvtest::ValidateBase<std::tuple<std::string, std::string,
- std::string, spv_target_env>>;
+using ValidateSizedVariable = spvtest::ValidateBase<
+ std::tuple<std::string, std::string, std::string, spv_target_env>>;
CodeGenerator GetSizedVariableCodeGenerator(bool is_8bit, bool buffer_block) {
CodeGenerator generator;
@@ -3777,7 +3958,8 @@
"\"SPV_KHR_8bit_storage\"\n";
generator.memory_model_ = "OpMemoryModel Logical GLSL450\n";
if (is_8bit) {
- generator.before_types_ = "OpMemberDecorate %char_buffer_block 0 Offset 0\n";
+ generator.before_types_ =
+ "OpMemberDecorate %char_buffer_block 0 Offset 0\n";
if (buffer_block)
generator.before_types_ += "OpDecorate %char_buffer_block BufferBlock\n";
diff --git a/utils/generate_grammar_tables.py b/utils/generate_grammar_tables.py
index 6b7167b..e6a1455 100755
--- a/utils/generate_grammar_tables.py
+++ b/utils/generate_grammar_tables.py
@@ -540,7 +540,7 @@
# We have a few operand kinds that require their optional counterpart to
# exist in the operand info table.
- optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat']
+ optional_enums = ['ImageOperands', 'AccessQualifier', 'MemoryAccess', 'PackedVectorFormat', 'CooperativeMatrixOperands']
optional_enums = [e for e in enums if e[0] in optional_enums]
enums.extend(optional_enums)