Add validation for SPV_NV_cooperative_matrix (#2404)
diff --git a/DEPS b/DEPS
index 5668c66..3c5361f 100644
--- a/DEPS
+++ b/DEPS
@@ -11,7 +11,7 @@
'googletest_revision': '98a0d007d7092b72eea0e501bb9ad17908a1a036',
'testing_revision': '340252637e2e7c72c0901dcbeeacfff419e19b59',
're2_revision': '6cf8ccd82dbaab2668e9b13596c68183c9ecd13f',
- 'spirv_headers_revision': '79b6681aadcb53c27d1052e5f8a0e82a981dbf2f',
+ 'spirv_headers_revision': 'e74c389f81915d0a48d6df1af83c3862c5ad85ab',
}
deps = {
diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp
index 4d98e3d..79f18ee 100644
--- a/source/assembly_grammar.cpp
+++ b/source/assembly_grammar.cpp
@@ -154,10 +154,11 @@
CASE(InBoundsAccessChain),
CASE(PtrAccessChain),
CASE(InBoundsPtrAccessChain),
+ CASE(CooperativeMatrixLengthNV)
};
-// The 59 is determined by counting the opcodes listed in the spec.
-static_assert(59 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
+// The 60 is determined by counting the opcodes listed in the spec.
+static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
"OpSpecConstantOp opcode table is incomplete");
#undef CASE
// clang-format on
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 78c2386..da096a4 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -260,6 +260,7 @@
case SpvOpTypeMatrix:
case SpvOpTypeArray:
case SpvOpTypeStruct:
+ case SpvOpTypeCooperativeMatrixNV:
return true;
default:
return false;
@@ -325,6 +326,7 @@
case SpvOpTypePipeStorage:
case SpvOpTypeNamedBarrier:
case SpvOpTypeAccelerationStructureNV:
+ case SpvOpTypeCooperativeMatrixNV:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index 79d90bd..8106442 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -45,7 +45,8 @@
inline bool IsTypeInst(SpvOp opcode) {
return (opcode >= SpvOpTypeVoid && opcode <= SpvOpTypeForwardPointer) ||
opcode == SpvOpTypePipeStorage || opcode == SpvOpTypeNamedBarrier ||
- opcode == SpvOpTypeAccelerationStructureNV;
+ opcode == SpvOpTypeAccelerationStructureNV ||
+ opcode == SpvOpTypeCooperativeMatrixNV;
}
inline bool IsConstantInst(SpvOp opcode) {
return opcode >= SpvOpConstantTrue && opcode <= SpvOpSpecConstantOp;
diff --git a/source/val/validate_arithmetics.cpp b/source/val/validate_arithmetics.cpp
index 2314e7d..433330d 100644
--- a/source/val/validate_arithmetics.cpp
+++ b/source/val/validate_arithmetics.cpp
@@ -39,8 +39,11 @@
case SpvOpFRem:
case SpvOpFMod:
case SpvOpFNegate: {
+ bool supportsCoopMat =
+ (opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
if (!_.IsFloatScalarType(result_type) &&
- !_.IsFloatVectorType(result_type))
+ !_.IsFloatVectorType(result_type) &&
+ !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@@ -58,8 +61,11 @@
case SpvOpUDiv:
case SpvOpUMod: {
+ bool supportsCoopMat = (opcode == SpvOpUDiv);
if (!_.IsUnsignedIntScalarType(result_type) &&
- !_.IsUnsignedIntVectorType(result_type))
+ !_.IsUnsignedIntVectorType(result_type) &&
+ !(supportsCoopMat &&
+ _.IsUnsignedIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@@ -82,7 +88,10 @@
case SpvOpSMod:
case SpvOpSRem:
case SpvOpSNegate: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+ bool supportsCoopMat =
+ (opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
+ if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
@@ -94,7 +103,8 @@
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!type_id ||
- (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
+ (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
+ !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
@@ -176,7 +186,8 @@
}
case SpvOpMatrixTimesScalar: {
- if (!_.IsFloatMatrixType(result_type))
+ if (!_.IsFloatMatrixType(result_type) &&
+ !_.IsCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float matrix type as Result Type: "
<< spvOpcodeString(opcode);
@@ -442,6 +453,92 @@
break;
}
+ case SpvOpCooperativeMatrixMulAddNV: {
+ const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
+ const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
+ const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
+ const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
+
+ if (!_.IsCooperativeMatrixType(A_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as A Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixType(B_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as B Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixType(C_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as C Type: "
+ << spvOpcodeString(opcode);
+ }
+ if (!_.IsCooperativeMatrixType(D_type_id)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix type as Result Type: "
+ << spvOpcodeString(opcode);
+ }
+
+ const auto A = _.FindDef(A_type_id);
+ const auto B = _.FindDef(B_type_id);
+ const auto C = _.FindDef(C_type_id);
+ const auto D = _.FindDef(D_type_id);
+
+ std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
+ A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
+
+ A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
+ B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
+ C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
+ D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
+
+ A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
+ B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
+ C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
+ D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
+
+ A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
+ B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
+ C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
+ D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
+
+ const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
+ std::tuple<bool, bool, uint32_t> Y) {
+ return (std::get<1>(X) && std::get<1>(Y) &&
+ std::get<2>(X) != std::get<2>(Y));
+ };
+
+ if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
+ notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
+ notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix scopes must match: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
+ notEqual(C_rows, D_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'M' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
+ notEqual(C_cols, D_cols)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'N' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+
+ if (notEqual(A_cols, B_rows)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Cooperative matrix 'K' mismatch: "
+ << spvOpcodeString(opcode);
+ }
+ break;
+ }
+
default:
break;
}
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index ccc5587..de3210e 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -118,6 +118,10 @@
*member_type = type_inst->word(component_index + 2);
break;
}
+ case SpvOpTypeCooperativeMatrixNV: {
+ *member_type = type_inst->word(2);
+ break;
+ }
default:
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Reached non-composite type while indexes still remain to "
@@ -315,6 +319,26 @@
break;
}
+ case SpvOpTypeCooperativeMatrixNV: {
+ const auto result_type_inst = _.FindDef(result_type);
+ assert(result_type_inst);
+ const auto component_type_id =
+ result_type_inst->GetOperandAs<uint32_t>(1);
+
+ if (3 != num_operands) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected single constituent";
+ }
+
+ const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
+
+ if (operand_type_id != component_type_id) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected Constituent type to be equal to the component type";
+ }
+
+ break;
+ }
default: {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be a composite type";
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index e2f20f6..c413b4f 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -247,6 +247,36 @@
}
}
} break;
+ case SpvOpTypeCooperativeMatrixNV: {
+ if (1 != constituent_count) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << opcode_name << " Constituent <id> '"
+ << _.getIdName(inst->type_id()) << "' count must be one.";
+ }
+ const auto constituent_id = inst->GetOperandAs<uint32_t>(2);
+ const auto constituent = _.FindDef(constituent_id);
+ if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << opcode_name << " Constituent <id> '"
+ << _.getIdName(constituent_id)
+ << "' is not a constant or undef.";
+ }
+ const auto constituent_type = _.FindDef(constituent->type_id());
+ if (!constituent_type) {
+ return _.diag(SPV_ERROR_INVALID_ID, constituent)
+ << "Result type is not defined.";
+ }
+
+ const auto component_type_id = result_type->GetOperandAs<uint32_t>(1);
+ const auto component_type = _.FindDef(component_type_id);
+ if (!component_type || component_type->id() != constituent_type->id()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << opcode_name << " Constituent <id> '"
+ << _.getIdName(constituent_id)
+ << "' type does not match the Result Type <id> '"
+ << _.getIdName(result_type->id()) << "'s component type.";
+ }
+ } break;
default:
break;
}
@@ -285,6 +315,7 @@
return true;
case SpvOpTypeArray:
case SpvOpTypeMatrix:
+ case SpvOpTypeCooperativeMatrixNV:
case SpvOpTypeVector: {
auto base_type = _.FindDef(instruction[2]);
return base_type && IsTypeNullable(base_type->words(), _);
@@ -320,7 +351,7 @@
// The binary parser already ensures that the op is valid for *some*
// environment. Here we check restrictions.
- switch(op) {
+ switch (op) {
case SpvOpQuantizeToF16:
if (!_.HasCapability(SpvCapabilityShader)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -365,7 +396,7 @@
}
break;
- default:
+ default:
break;
}
diff --git a/source/val/validate_conversion.cpp b/source/val/validate_conversion.cpp
index 73da582..17af9f4 100644
--- a/source/val/validate_conversion.cpp
+++ b/source/val/validate_conversion.cpp
@@ -32,22 +32,31 @@
switch (opcode) {
case SpvOpConvertFToU: {
if (!_.IsUnsignedIntScalarType(result_type) &&
- !_.IsUnsignedIntVectorType(result_type))
+ !_.IsUnsignedIntVectorType(result_type) &&
+ !_.IsUnsignedIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
- !_.IsFloatVectorType(input_type)))
+ !_.IsFloatVectorType(input_type) &&
+ !_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -58,22 +67,31 @@
}
case SpvOpConvertFToS: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+ if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+ !_.IsIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
- !_.IsFloatVectorType(input_type)))
+ !_.IsFloatVectorType(input_type) &&
+ !_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (!_.features().use_int8_type && (8 == _.GetBitWidth(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -86,22 +104,31 @@
case SpvOpConvertSToF:
case SpvOpConvertUToF: {
if (!_.IsFloatScalarType(result_type) &&
- !_.IsFloatVectorType(result_type))
+ !_.IsFloatVectorType(result_type) &&
+ !_.IsFloatCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
- (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+ (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+ !_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (!_.features().use_int8_type && (8 == _.GetBitWidth(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -113,22 +140,31 @@
case SpvOpUConvert: {
if (!_.IsUnsignedIntScalarType(result_type) &&
- !_.IsUnsignedIntVectorType(result_type))
+ !_.IsUnsignedIntVectorType(result_type) &&
+ !_.IsUnsignedIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
- (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+ (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+ !_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -139,22 +175,31 @@
}
case SpvOpSConvert: {
- if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
+ if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
+ !_.IsIntCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type ||
- (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type)))
+ (!_.IsIntScalarType(input_type) && !_.IsIntVectorType(input_type) &&
+ !_.IsIntCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be int scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
@@ -166,22 +211,31 @@
case SpvOpFConvert: {
if (!_.IsFloatScalarType(result_type) &&
- !_.IsFloatVectorType(result_type))
+ !_.IsFloatVectorType(result_type) &&
+ !_.IsFloatCooperativeMatrixType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || (!_.IsFloatScalarType(input_type) &&
- !_.IsFloatVectorType(input_type)))
+ !_.IsFloatVectorType(input_type) &&
+ !_.IsFloatCooperativeMatrixType(input_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected input to be float scalar or vector: "
<< spvOpcodeString(opcode);
- if (_.GetDimension(result_type) != _.GetDimension(input_type))
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << "Expected input to have the same dimension as Result Type: "
- << spvOpcodeString(opcode);
+ if (_.IsCooperativeMatrixType(result_type) ||
+ _.IsCooperativeMatrixType(input_type)) {
+ spv_result_t ret =
+ _.CooperativeMatrixShapesMatch(inst, result_type, input_type);
+ if (ret != SPV_SUCCESS) return ret;
+ } else {
+ if (_.GetDimension(result_type) != _.GetDimension(input_type))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected input to have the same dimension as Result Type: "
+ << spvOpcodeString(opcode);
+ }
if (_.GetBitWidth(result_type) == _.GetBitWidth(input_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index 21a0411..cb18e13 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -167,7 +167,10 @@
const auto opcode = inst->opcode();
if (spvOpcodeGeneratesType(def->opcode()) &&
!spvOpcodeGeneratesType(opcode) && !spvOpcodeIsDebug(opcode) &&
- !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction) {
+ !spvOpcodeIsDecoration(opcode) && opcode != SpvOpFunction &&
+ opcode != SpvOpCooperativeMatrixLengthNV &&
+ !(opcode == SpvOpSpecConstantOp &&
+ inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " cannot be a type";
@@ -177,7 +180,10 @@
!spvOpcodeIsBranch(opcode) && opcode != SpvOpPhi &&
opcode != SpvOpExtInst && opcode != SpvOpExtInstImport &&
opcode != SpvOpSelectionMerge &&
- opcode != SpvOpLoopMerge && opcode != SpvOpFunction) {
+ opcode != SpvOpLoopMerge && opcode != SpvOpFunction &&
+ opcode != SpvOpCooperativeMatrixLengthNV &&
+ !(opcode == SpvOpSpecConstantOp &&
+ inst->word(3) == SpvOpCooperativeMatrixLengthNV)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Operand " << _.getIdName(operand_word)
<< " requires a type";
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 9e93cf1..f6127a1 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -197,17 +197,49 @@
return false;
}
+// Returns true if |storage| is an OpTypeCooperativeMatrixNV, or an
+// OpTypeArray / OpTypeRuntimeArray / OpTypeStruct that (transitively)
+// contains one. Returns false for every other type opcode, and for a null
+// |storage|.
+bool ContainsCooperativeMatrix(ValidationState_t& _,
+                               const Instruction* storage) {
+  // FindDef() can return nullptr for an id without a definition; the
+  // recursive calls below forward its result unchecked, so guard here rather
+  // than dereference a null instruction.
+  if (!storage) return false;
+
+  switch (storage->opcode()) {
+    case SpvOpTypeCooperativeMatrixNV:
+      return true;
+    case SpvOpTypeArray:
+    case SpvOpTypeRuntimeArray: {
+      // Operand 1 of both array type instructions is read as the element
+      // type id.
+      const size_t elem_type_index = 1;
+      const auto elem_type_id =
+          storage->GetOperandAs<uint32_t>(elem_type_index);
+      return ContainsCooperativeMatrix(_, _.FindDef(elem_type_id));
+    }
+    case SpvOpTypeStruct:
+      // Operands from index 1 onwards are read as member type ids.
+      for (size_t member_type_index = 1;
+           member_type_index < storage->operands().size();
+           ++member_type_index) {
+        const auto member_type_id =
+            storage->GetOperandAs<uint32_t>(member_type_index);
+        if (ContainsCooperativeMatrix(_, _.FindDef(member_type_id)))
+          return true;
+      }
+      break;
+    default:
+      break;
+  }
+  return false;
+}
+
+
std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
ValidationState_t& _, const Instruction* inst) {
SpvStorageClass dst_sc = SpvStorageClassMax;
SpvStorageClass src_sc = SpvStorageClassMax;
switch (inst->opcode()) {
+ case SpvOpCooperativeMatrixLoadNV:
case SpvOpLoad: {
auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
auto load_pointer_type = _.FindDef(load_pointer->type_id());
dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
break;
}
+ case SpvOpCooperativeMatrixStoreNV:
case SpvOpStore: {
auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
auto store_pointer_type = _.FindDef(store_pointer->type_id());
@@ -232,7 +264,8 @@
}
-// This function is only called for OpLoad, OpStore, OpCopyMemory and
-// OpCopyMemorySized.
+// This function is only called for OpLoad, OpStore, OpCopyMemory,
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
+// OpCooperativeMatrixStoreNV.
uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
uint32_t offset = 1;
if (mask & SpvMemoryAccessAlignedMask) ++offset;
@@ -245,6 +278,10 @@
case SpvOpStore:
case SpvOpCopyMemory:
return inst->GetOperandAs<uint32_t>(2 + offset);
+ case SpvOpCooperativeMatrixLoadNV:
+ return inst->GetOperandAs<uint32_t>(5 + offset);
+ case SpvOpCooperativeMatrixStoreNV:
+ return inst->GetOperandAs<uint32_t>(4 + offset);
default:
assert(false && "unexpected opcode");
break;
@@ -253,8 +290,9 @@
return scope_id;
}
-// This function is only called for OpLoad, OpStore, OpCopyMemory and
-// OpCopyMemorySized.
+// This function is only called for OpLoad, OpStore, OpCopyMemory,
+// OpCopyMemorySized, OpCooperativeMatrixLoadNV, and
+// OpCooperativeMatrixStoreNV.
uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
uint32_t offset = 1;
if (mask & SpvMemoryAccessAlignedMask) ++offset;
@@ -268,6 +306,10 @@
case SpvOpStore:
case SpvOpCopyMemory:
return inst->GetOperandAs<uint32_t>(2 + offset);
+ case SpvOpCooperativeMatrixLoadNV:
+ return inst->GetOperandAs<uint32_t>(5 + offset);
+ case SpvOpCooperativeMatrixStoreNV:
+ return inst->GetOperandAs<uint32_t>(4 + offset);
default:
assert(false && "unexpected opcode");
break;
@@ -302,7 +344,8 @@
uint32_t mask = inst->GetOperandAs<uint32_t>(index);
if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
- if (inst->opcode() == SpvOpLoad) {
+ if (inst->opcode() == SpvOpLoad ||
+ inst->opcode() == SpvOpCooperativeMatrixLoadNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerAvailableKHR cannot be used with OpLoad.";
}
@@ -320,7 +363,8 @@
}
if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
- if (inst->opcode() == SpvOpStore) {
+ if (inst->opcode() == SpvOpStore ||
+ inst->opcode() == SpvOpCooperativeMatrixStoreNV) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MakePointerVisibleKHR cannot be used with OpStore.";
}
@@ -672,6 +716,17 @@
}
}
+ // Cooperative matrix types can only be allocated in Function or Private
+ if ((storage_class != SpvStorageClassFunction &&
+ storage_class != SpvStorageClassPrivate) &&
+ ContainsCooperativeMatrix(_, pointee)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Cooperative matrix types (or types containing them) can only be "
+ "allocated "
+ << "in Function or Private storage classes or as function "
+ "parameters";
+ }
+
return SPV_SUCCESS;
}
@@ -1003,10 +1058,11 @@
switch (type_pointee->opcode()) {
case SpvOpTypeMatrix:
case SpvOpTypeVector:
+ case SpvOpTypeCooperativeMatrixNV:
case SpvOpTypeArray:
case SpvOpTypeRuntimeArray: {
- // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
- // word 2 is the Element Type.
+ // In OpTypeMatrix, OpTypeVector, SpvOpTypeCooperativeMatrixNV,
+ // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
type_pointee = _.FindDef(type_pointee->word(2));
break;
}
@@ -1136,6 +1192,140 @@
return SPV_SUCCESS;
}
+// Validates OpCooperativeMatrixLengthNV: the Result Type must be an
+// OpTypeInt with width 32 and signedness 0, and the Type operand (operand
+// index 2) must name an OpTypeCooperativeMatrixNV. Returns SPV_SUCCESS, or
+// an SPV_ERROR_INVALID_ID diagnostic for the first rule that fails.
+spv_result_t ValidateCooperativeMatrixLengthNV(ValidationState_t& state,
+                                               const Instruction* inst) {
+  std::string instr_name =
+      "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
+
+  // Result type must be a 32-bit unsigned int. An undefined result type id
+  // (null lookup) is reported through the same diagnostic instead of being
+  // dereferenced.
+  const auto result_type = state.FindDef(inst->type_id());
+  if (!result_type || result_type->opcode() != SpvOpTypeInt ||
+      result_type->GetOperandAs<uint32_t>(1) != 32 ||
+      result_type->GetOperandAs<uint32_t>(2) != 0) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The Result Type of " << instr_name << " <id> '"
+           << state.getIdName(inst->id())
+           << "' must be OpTypeInt with width 32 and signedness 0.";
+  }
+
+  // The queried type must be a cooperative matrix type; an undefined id is
+  // rejected with the same message.
+  const auto type_id = inst->GetOperandAs<uint32_t>(2);
+  const auto type = state.FindDef(type_id);
+  if (!type || type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    return state.diag(SPV_ERROR_INVALID_ID, inst)
+           << "The type in " << instr_name << " <id> '"
+           << state.getIdName(type_id)
+           << "' must be OpTypeCooperativeMatrixNV.";
+  }
+  return SPV_SUCCESS;
+}
+
+
+// Validates OpCooperativeMatrixLoadNV and OpCooperativeMatrixStoreNV.
+//
+// Checks, in order:
+//   1. the load's Result Type / the store's Object type is an
+//      OpTypeCooperativeMatrixNV;
+//   2. the Pointer operand is defined and, under the Logical addressing
+//      model, is produced by an opcode that returns a logical pointer;
+//   3. the pointer's type is an OpTypePointer whose storage class is
+//      Workgroup, StorageBuffer, or PhysicalStorageBufferEXT;
+//   4. the pointee is an int or float scalar or vector type;
+//   5. Stride has a scalar integer type;
+//   6. Column Major is a boolean constant or specialization constant;
+//   7. any trailing Memory Access operands pass CheckMemoryAccess().
+//
+// Operand indices used below:
+//   load:  pointer 2, stride 3, column major 4, memory access 5
+//   store: pointer 0, object 1, stride 2, column major 3, memory access 4
+//
+// Returns SPV_SUCCESS, or the first diagnostic encountered.
+spv_result_t ValidateCooperativeMatrixLoadStoreNV(ValidationState_t& _,
+                                                  const Instruction* inst) {
+  const bool is_load = inst->opcode() == SpvOpCooperativeMatrixLoadNV;
+  const char* opname = is_load ? "SpvOpCooperativeMatrixLoadNV"
+                               : "SpvOpCooperativeMatrixStoreNV";
+
+  uint32_t type_id = 0;
+  if (is_load) {
+    type_id = inst->type_id();
+  } else {
+    // Get the Object operand's type. If the object id has no definition the
+    // type id stays 0 and the check below reports it, rather than
+    // dereferencing a null instruction.
+    const auto object = _.FindDef(inst->GetOperandAs<uint32_t>(1));
+    if (object) type_id = object->type_id();
+  }
+
+  const auto matrix_type = _.FindDef(type_id);
+  if (!matrix_type || matrix_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << (is_load ? " Result Type <id> '" : " Object type <id> '")
+           << _.getIdName(type_id) << "' is not a cooperative matrix type.";
+  }
+
+  const bool uses_variable_pointers =
+      _.features().variable_pointers ||
+      _.features().variable_pointers_storage_buffer;
+  const auto pointer_index = is_load ? 2u : 0u;
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer = _.FindDef(pointer_id);
+  if (!pointer ||
+      ((_.addressing_model() == SpvAddressingModelLogical) &&
+       ((!uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
+        (uses_variable_pointers &&
+         !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> '" << _.getIdName(pointer_id)
+           << "' is not a logical pointer.";
+  }
+
+  const auto pointer_type_id = pointer->type_id();
+  const auto pointer_type = _.FindDef(pointer_type_id);
+  if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " type for pointer <id> '" << _.getIdName(pointer_id)
+           << "' is not a pointer type.";
+  }
+
+  const auto storage_class_index = 1u;
+  const auto storage_class =
+      pointer_type->GetOperandAs<uint32_t>(storage_class_index);
+
+  // NOTE(review): PhysicalStorageBufferEXT is accepted here but is not named
+  // in the diagnostic text; the message is left unchanged because callers or
+  // tests may match on it.
+  if (storage_class != SpvStorageClassWorkgroup &&
+      storage_class != SpvStorageClassStorageBuffer &&
+      storage_class != SpvStorageClassPhysicalStorageBufferEXT) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " storage class for pointer type <id> '"
+           << _.getIdName(pointer_type_id)
+           << "' is not Workgroup or StorageBuffer.";
+  }
+
+  const auto pointee_id = pointer_type->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_id);
+  if (!pointee_type || !(_.IsIntScalarOrVectorType(pointee_id) ||
+                         _.IsFloatScalarOrVectorType(pointee_id))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << opname << " Pointer <id> '" << _.getIdName(pointer->id())
+           << "'s Type must be a scalar or vector type.";
+  }
+
+  const auto stride_index = is_load ? 3u : 2u;
+  const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
+  const auto stride = _.FindDef(stride_id);
+  if (!stride || !_.IsIntScalarType(stride->type_id())) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Stride operand <id> '" << _.getIdName(stride_id)
+           << "' must be a scalar integer type.";
+  }
+
+  const auto colmajor_index = is_load ? 4u : 3u;
+  const auto colmajor_id = inst->GetOperandAs<uint32_t>(colmajor_index);
+  const auto colmajor = _.FindDef(colmajor_id);
+  if (!colmajor || !_.IsBoolScalarType(colmajor->type_id()) ||
+      !(spvOpcodeIsConstant(colmajor->opcode()) ||
+        spvOpcodeIsSpecConstant(colmajor->opcode()))) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Column Major operand <id> '" << _.getIdName(colmajor_id)
+           << "' must be a boolean constant instruction.";
+  }
+
+  // Memory Access operands are optional; validate them only when present.
+  const auto memory_access_index = is_load ? 5u : 4u;
+  if (inst->operands().size() > memory_access_index) {
+    if (auto error = CheckMemoryAccess(_, inst, memory_access_index))
+      return error;
+  }
+
+  return SPV_SUCCESS;
+}
+
+
} // namespace
spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
@@ -1164,6 +1354,14 @@
case SpvOpArrayLength:
if (auto error = ValidateArrayLength(_, inst)) return error;
break;
+ case SpvOpCooperativeMatrixLoadNV:
+ case SpvOpCooperativeMatrixStoreNV:
+ if (auto error = ValidateCooperativeMatrixLoadStoreNV(_, inst))
+ return error;
+ break;
+ case SpvOpCooperativeMatrixLengthNV:
+ if (auto error = ValidateCooperativeMatrixLengthNV(_, inst)) return error;
+ break;
case SpvOpImageTexelPointer:
case SpvOpGenericPtrMemSemantics:
default:
diff --git a/source/val/validate_scopes.cpp b/source/val/validate_scopes.cpp
index b640131..2223a77 100644
--- a/source/val/validate_scopes.cpp
+++ b/source/val/validate_scopes.cpp
@@ -36,11 +36,19 @@
}
if (!is_const_int32) {
- if (_.HasCapability(SpvCapabilityShader)) {
+ if (_.HasCapability(SpvCapabilityShader) &&
+ !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be OpConstant when Shader capability is "
<< "present";
}
+ if (_.HasCapability(SpvCapabilityShader) &&
+ _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+ !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Scope ids must be constant or specialization constant when "
+ << "CooperativeMatrixNV capability is present";
+ }
return SPV_SUCCESS;
}
@@ -130,11 +138,19 @@
}
if (!is_const_int32) {
- if (_.HasCapability(SpvCapabilityShader)) {
+ if (_.HasCapability(SpvCapabilityShader) &&
+ !_.HasCapability(SpvCapabilityCooperativeMatrixNV)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Scope ids must be OpConstant when Shader capability is "
<< "present";
}
+ if (_.HasCapability(SpvCapabilityShader) &&
+ _.HasCapability(SpvCapabilityCooperativeMatrixNV) &&
+ !spvOpcodeIsConstant(_.GetIdOpcode(scope))) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Scope ids must be constant or specialization constant when "
+ << "CooperativeMatrixNV capability is present";
+ }
return SPV_SUCCESS;
}
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index a5428d7..ad72a37 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -381,6 +381,53 @@
return SPV_SUCCESS;
}
+// Validates the operands of an OpTypeCooperativeMatrixNV instruction:
+//  - Component Type must be defined by OpTypeFloat or OpTypeInt;
+//  - Scope, Rows and Cols must each be a constant instruction (as classified
+//    by spvOpcodeIsConstant) whose type is a scalar integer.
+// Returns SPV_SUCCESS, or an SPV_ERROR_INVALID_ID diagnostic that names the
+// first offending operand.
+spv_result_t ValidateTypeCooperativeMatrixNV(ValidationState_t& _,
+ const Instruction* inst) {
+ const auto component_type_index = 1;
+ const auto component_type_id =
+ inst->GetOperandAs<uint32_t>(component_type_index);
+ const auto component_type = _.FindDef(component_type_id);
+ if (!component_type || (SpvOpTypeFloat != component_type->opcode() &&
+ SpvOpTypeInt != component_type->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrixNV Component Type <id> '"
+ << _.getIdName(component_type_id)
+ << "' is not a scalar numerical type.";
+ }
+
+ const auto scope_index = 2;
+ const auto scope_id = inst->GetOperandAs<uint32_t>(scope_index);
+ const auto scope = _.FindDef(scope_id);
+ if (!scope || !_.IsIntScalarType(scope->type_id()) ||
+ !spvOpcodeIsConstant(scope->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrixNV Scope <id> '" << _.getIdName(scope_id)
+ << "' is not a constant instruction with scalar integer type.";
+ }
+
+ const auto rows_index = 3;
+ const auto rows_id = inst->GetOperandAs<uint32_t>(rows_index);
+ const auto rows = _.FindDef(rows_id);
+ if (!rows || !_.IsIntScalarType(rows->type_id()) ||
+ !spvOpcodeIsConstant(rows->opcode())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrixNV Rows <id> '" << _.getIdName(rows_id)
+ << "' is not a constant instruction with scalar integer type.";
+ }
+
+ const auto cols_index = 4;
+ const auto cols_id = inst->GetOperandAs<uint32_t>(cols_index);
+ const auto cols = _.FindDef(cols_id);
+ if (!cols || !_.IsIntScalarType(cols->type_id()) ||
+ !spvOpcodeIsConstant(cols->opcode())) {
+ // Name the Cols operand itself in the diagnostic (not the Rows operand).
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpTypeCooperativeMatrixNV Cols <id> '" << _.getIdName(cols_id)
+ << "' is not a constant instruction with scalar integer type.";
+ }
+
+ return SPV_SUCCESS;
+}
+
} // namespace
spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
@@ -416,6 +463,9 @@
case SpvOpTypeForwardPointer:
if (auto error = ValidateTypeForwardPointer(_, inst)) return error;
break;
+ case SpvOpTypeCooperativeMatrixNV:
+ if (auto error = ValidateTypeCooperativeMatrixNV(_, inst)) return error;
+ break;
default:
break;
}
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 2633963..e6e5e26 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -610,6 +610,9 @@
case SpvOpTypeMatrix:
return GetComponentType(inst->word(2));
+ case SpvOpTypeCooperativeMatrixNV:
+ return inst->word(2);
+
default:
break;
}
@@ -634,6 +637,10 @@
case SpvOpTypeMatrix:
return inst->word(3);
+ case SpvOpTypeCooperativeMatrixNV:
+ // Actual dimension isn't known, return 0
+ return 0;
+
default:
break;
}
@@ -862,6 +869,86 @@
return true;
}
+// Returns true when |id| is defined by an OpTypeCooperativeMatrixNV
+// instruction. |id| is expected to resolve to a definition (asserted).
+bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
+ const Instruction* const type_def = FindDef(id);
+ assert(type_def);
+ return SpvOpTypeCooperativeMatrixNV == type_def->opcode();
+}
+
+// Returns true when |id| is a cooperative matrix type and word 2 of its type
+// instruction names a floating-point scalar type.
+bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
+ return IsCooperativeMatrixType(id) &&
+ IsFloatScalarType(FindDef(id)->word(2));
+}
+
+// Returns true when |id| is a cooperative matrix type and word 2 of its type
+// instruction names an integer scalar type.
+bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
+ return IsCooperativeMatrixType(id) &&
+ IsIntScalarType(FindDef(id)->word(2));
+}
+
+// Returns true when |id| is a cooperative matrix type and word 2 of its type
+// instruction names an unsigned integer scalar type.
+bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
+ return IsCooperativeMatrixType(id) &&
+ IsUnsignedIntScalarType(FindDef(id)->word(2));
+}
+
+// Verifies that the cooperative matrix types |m1| and |m2| agree in scope,
+// rows and columns. Each pair of operands is compared only when both evaluate
+// to 32-bit integer constants via EvalInt32IfConst; otherwise that pair is
+// assumed to match. Returns SPV_SUCCESS, or an SPV_ERROR_INVALID_DATA
+// diagnostic attached to |inst| on the first mismatch or when either id is not
+// an OpTypeCooperativeMatrixNV.
+spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
+ const Instruction* inst, uint32_t m1, uint32_t m2) {
+ const auto m1_type = FindDef(m1);
+ const auto m2_type = FindDef(m2);
+
+ // An id with no definition cannot be a cooperative matrix type; report it
+ // rather than dereferencing a null definition.
+ if (!m1_type || !m2_type ||
+ m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
+ m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
+ return diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected cooperative matrix types";
+ }
+
+ // Operands 2, 3 and 4 of the type instruction are Scope, Rows and Cols.
+ uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
+ uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
+ uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
+
+ uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
+ uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
+ uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
+
+ // Scratch results of EvalInt32IfConst, reused for each operand pair.
+ bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
+ m2_is_const_int32 = false;
+ uint32_t m1_value = 0, m2_value = 0;
+
+ // Scope.
+ std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+ EvalInt32IfConst(m1_scope_id);
+ std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+ EvalInt32IfConst(m2_scope_id);
+
+ if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+ return diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected scopes of Matrix and Result Type to be "
+ << "identical";
+ }
+
+ // Rows.
+ std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+ EvalInt32IfConst(m1_rows_id);
+ std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+ EvalInt32IfConst(m2_rows_id);
+
+ if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+ return diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected rows of Matrix type and Result Type to be "
+ << "identical";
+ }
+
+ // Columns.
+ std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
+ EvalInt32IfConst(m1_cols_id);
+ std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
+ EvalInt32IfConst(m2_cols_id);
+
+ if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
+ return diag(SPV_ERROR_INVALID_DATA, inst)
+ << "Expected columns of Matrix type and Result Type to be "
+ << "identical";
+ }
+
+ return SPV_SUCCESS;
+}
+
uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
size_t operand_index) const {
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
@@ -890,7 +977,7 @@
}
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
- uint32_t id) {
+ uint32_t id) const {
const Instruction* const inst = FindDef(id);
assert(inst);
const uint32_t type = inst->type_id();
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 55005a6..94fa945 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -552,6 +552,10 @@
bool IsBoolVectorType(uint32_t id) const;
bool IsBoolScalarOrVectorType(uint32_t id) const;
bool IsPointerType(uint32_t id) const;
+ bool IsCooperativeMatrixType(uint32_t id) const;
+ bool IsFloatCooperativeMatrixType(uint32_t id) const;
+ bool IsIntCooperativeMatrixType(uint32_t id) const;
+ bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
// Gets value from OpConstant and OpSpecConstant as uint64.
// Returns false on failure (no instruction, wrong instruction, not int).
@@ -635,7 +639,7 @@
// Returns tuple <is_int32, is_const_int32, value>.
// OpSpecConstant* return |is_const_int32| as false since their values cannot
// be relied upon during validation.
- std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id);
+ std::tuple<bool, bool, uint32_t> EvalInt32IfConst(uint32_t id) const;
// Returns the disassembly string for the given instruction.
std::string Disassemble(const Instruction& inst) const;
@@ -643,6 +647,12 @@
// Returns the disassembly string for the given instruction.
std::string Disassemble(const uint32_t* words, uint16_t num_words) const;
+ // Returns SPV_SUCCESS if types m1 and m2 are cooperative matrices with the
+ // same "shape" (scope, rows, cols), else a diagnostic on inst. Specialization
+ // constants are assumed to match because we can't prove they don't.
+ spv_result_t CooperativeMatrixShapesMatch(const Instruction* inst,
+ uint32_t m1, uint32_t m2);
+
private:
ValidationState_t(const ValidationState_t&);
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index 87e006c..b82fc97 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -1165,6 +1165,150 @@
"vector size of the right operand: OuterProduct"));
}
+std::string GenerateCoopMatCode(const std::string& extra_types,  // Builds SPIR-V assembly text for the cooperative matrix arithmetic tests.
+ const std::string& main_body) {  // Result: prefix + extra_types + func_begin + main_body + suffix.
+ const std::string prefix =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_16 = OpConstant %u32 16
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
+%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u32_1 = OpConstant %u32 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%u32_c1 = OpSpecConstant %u32 1
+%u32_c2 = OpSpecConstant %u32 2
+
+%f16matc = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_c1 %u32_c2
+%f16matc_1 = OpConstantComposite %f16matc %f16_1
+
+%mat16x4 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_4
+%mat4x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_4 %u32_16
+%mat16x16 = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_16 %u32_16
+%f16mat_16x4_1 = OpConstantComposite %mat16x4 %f16_1
+%f16mat_4x16_1 = OpConstantComposite %mat4x16 %f16_1
+%f16mat_16x16_1 = OpConstantComposite %mat16x16 %f16_1)";
+ // Opens the %main entry-point function; |main_body| is appended after its label.
+ const std::string func_begin =
+ R"(
+%main = OpFunction %void None %func
+%main_entry = OpLabel)";
+ // Terminates the function opened by |func_begin|.
+ const std::string suffix =
+ R"(
+OpReturn
+OpFunctionEnd)";
+ // |extra_types| lands after the shared declarations and before the function.
+ return prefix + extra_types + func_begin + main_body + suffix;
+}
+
+TEST_F(ValidateArithmetics, CoopMatSuccess) {  // Arithmetic on cooperative matrix operands that must validate cleanly.
+ const std::string body = R"(
+%val1 = OpFAdd %f16mat %f16mat_1 %f16mat_1
+%val2 = OpFSub %f16mat %f16mat_1 %f16mat_1
+%val3 = OpFDiv %f16mat %f16mat_1 %f16mat_1
+%val4 = OpFNegate %f16mat %f16mat_1
+%val5 = OpIAdd %u32mat %u32mat_1 %u32mat_1
+%val6 = OpISub %u32mat %u32mat_1 %u32mat_1
+%val7 = OpUDiv %u32mat %u32mat_1 %u32mat_1
+%val8 = OpIAdd %s32mat %s32mat_1 %s32mat_1
+%val9 = OpISub %s32mat %s32mat_1 %s32mat_1
+%val10 = OpSDiv %s32mat %s32mat_1 %s32mat_1
+%val11 = OpSNegate %s32mat %s32mat_1
+%val12 = OpMatrixTimesScalar %f16mat %f16mat_1 %f16_1
+%val13 = OpMatrixTimesScalar %u32mat %u32mat_1 %u32_1
+%val14 = OpMatrixTimesScalar %s32mat %s32mat_1 %s32_1
+%val15 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16mat_16x16_1
+%val16 = OpCooperativeMatrixMulAddNV %f16matc %f16matc_1 %f16matc_1 %f16matc_1
+)";
+ // %val16 uses %f16matc, whose rows/cols are specialization constants.
+ CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, CoopMatFMulFail) {  // OpFMul with a cooperative matrix result type must be rejected.
+ const std::string body = R"(
+%val1 = OpFMul %f16mat %f16mat_1 %f16mat_1
+)";
+ // Expect the generic FMul result-type diagnostic.
+ CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Expected floating scalar or vector type as Result Type: FMul"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatMatrixTimesScalarMismatchFail) {  // Scalar type (%f32) differs from the matrix component type (%f16).
+ const std::string body = R"(
+%val1 = OpMatrixTimesScalar %f16mat %f16mat_1 %f32_1
+)";
+ // Expect the component-type mismatch diagnostic.
+ CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected scalar operand type to be equal to the component "
+ "type of the matrix operand: MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatScopeFail) {  // MulAdd operand whose matrix type uses a different scope constant must be rejected.
+ const std::string types = R"(
+%workgroup = OpConstant %u32 2
+
+%mat16x16_wg = OpTypeCooperativeMatrixNV %f16 %workgroup %u32_16 %u32_16
+%f16matwg_16x16_1 = OpConstantComposite %mat16x16_wg %f16_1
+)";
+ // The last operand uses the %workgroup-scoped type; the others use %subgroup.
+ const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_16x4_1 %f16mat_4x16_1 %f16matwg_16x16_1
+)";
+
+ CompileSuccessfully(GenerateCoopMatCode(types, body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Cooperative matrix scopes must match: CooperativeMatrixMulAddNV"));
+}
+
+TEST_F(ValidateArithmetics, CoopMatDimFail) {  // MulAdd with the 4x16 and 16x4 operands swapped must be rejected.
+ const std::string body = R"(
+%val1 = OpCooperativeMatrixMulAddNV %mat16x16 %f16mat_4x16_1 %f16mat_16x4_1 %f16mat_16x16_1
+)";
+ // Expect the 'M' dimension mismatch diagnostic.
+ CompileSuccessfully(GenerateCoopMatCode("", body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Cooperative matrix 'M' mismatch: CooperativeMatrixMulAddNV"));
+}
+
TEST_F(ValidateArithmetics, IAddCarrySuccess) {
const std::string body = R"(
%val1 = OpIAddCarry %struct_u32_u32 %u32_0 %u32_1
diff --git a/test/val/val_composites_test.cpp b/test/val/val_composites_test.cpp
index bf7f15d..db6ff5b 100644
--- a/test/val/val_composites_test.cpp
+++ b/test/val/val_composites_test.cpp
@@ -1467,6 +1467,83 @@
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
+TEST_F(ValidateComposites, CoopMatConstantCompositeMismatchFail) {  // OpConstantComposite of an f16 cooperative matrix from an f32 constituent must fail.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%f32_1 = OpConstant %f32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+ // The expected diagnostic names the constituent and result type by id.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpConstantComposite Constituent <id> '11[%float_1]' type does "
+ "not match the Result Type <id> '10[%10]'s component type."));
+}
+
+TEST_F(ValidateComposites, CoopMatCompositeConstructMismatchFail) {  // OpCompositeConstruct of an f16 cooperative matrix from an f32 constituent must fail.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%f32_1 = OpConstant %f32 1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%f16mat_1 = OpCompositeConstruct %f16mat %f32_1
+
+OpReturn
+OpFunctionEnd)";
+ // Same mismatch as the constant case, but inside the function body.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Expected Constituent type to be equal to the component type"));
+}
+
TEST_F(ValidateComposites, ExtractDynamicLabelIndex) {
const std::string spirv = R"(
OpCapability Shader
diff --git a/test/val/val_conversion_test.cpp b/test/val/val_conversion_test.cpp
index 5e4ad49..f905657 100644
--- a/test/val/val_conversion_test.cpp
+++ b/test/val/val_conversion_test.cpp
@@ -1184,6 +1184,172 @@
"GenericCastToPtrExplicit"));
}
+TEST_F(ValidateConversion, CoopMatConversionSuccess) {  // Conversions between same-shaped cooperative matrices of different component types must validate.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_8 %u32_8
+%u16mat = OpTypeCooperativeMatrixNV %u16 %subgroup %u32_8 %u32_8
+%u32mat = OpTypeCooperativeMatrixNV %u32 %subgroup %u32_8 %u32_8
+%s16mat = OpTypeCooperativeMatrixNV %s16 %subgroup %u32_8 %u32_8
+%s32mat = OpTypeCooperativeMatrixNV %s32 %subgroup %u32_8 %u32_8
+
+%f16_1 = OpConstant %f16 1
+%f32_1 = OpConstant %f32 1
+%u16_1 = OpConstant %u16 1
+%u32_1 = OpConstant %u32 1
+%s16_1 = OpConstant %s16 1
+%s32_1 = OpConstant %s32 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+%f32mat_1 = OpConstantComposite %f32mat %f32_1
+%u16mat_1 = OpConstantComposite %u16mat %u16_1
+%u32mat_1 = OpConstantComposite %u32mat %u32_1
+%s16mat_1 = OpConstantComposite %s16mat %s16_1
+%s32mat_1 = OpConstantComposite %s32mat %s32_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val11 = OpConvertFToU %u16mat %f16mat_1
+%val12 = OpConvertFToU %u32mat %f16mat_1
+%val13 = OpConvertFToS %s16mat %f16mat_1
+%val14 = OpConvertFToS %s32mat %f16mat_1
+%val15 = OpFConvert %f32mat %f16mat_1
+
+%val21 = OpConvertFToU %u16mat %f32mat_1
+%val22 = OpConvertFToU %u32mat %f32mat_1
+%val23 = OpConvertFToS %s16mat %f32mat_1
+%val24 = OpConvertFToS %s32mat %f32mat_1
+%val25 = OpFConvert %f16mat %f32mat_1
+
+%val31 = OpConvertUToF %f16mat %u16mat_1
+%val32 = OpConvertUToF %f32mat %u16mat_1
+%val33 = OpUConvert %u32mat %u16mat_1
+%val34 = OpSConvert %s32mat %u16mat_1
+
+%val41 = OpConvertSToF %f16mat %s16mat_1
+%val42 = OpConvertSToF %f32mat %s16mat_1
+%val43 = OpUConvert %u32mat %s16mat_1
+%val44 = OpSConvert %s32mat %s16mat_1
+
+OpReturn
+OpFunctionEnd)";
+ // Every matrix type above shares the same scope, rows and columns constants.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateConversion, CoopMatConversionShapesMismatchFail) {  // OpFConvert from an 8x8 to a 4x4 cooperative matrix (all OpConstant dims) must fail.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val15 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+ // The rows mismatch is the diagnostic expected here.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Expected rows of Matrix type and Result Type to be identical"));
+}
+
+TEST_F(ValidateConversion, CoopMatConversionShapesMismatchPass) {  // Same conversion as the Fail case, but %u32_4 is an OpSpecConstant, so validation succeeds.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability Int16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%f16 = OpTypeFloat 16
+%f32 = OpTypeFloat 32
+%u16 = OpTypeInt 16 0
+%u32 = OpTypeInt 32 0
+%s16 = OpTypeInt 16 1
+%s32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%u32_4 = OpSpecConstant %u32 4
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+%f32mat = OpTypeCooperativeMatrixNV %f32 %subgroup %u32_4 %u32_4
+
+%f16_1 = OpConstant %f16 1
+
+%f16mat_1 = OpConstantComposite %f16mat %f16_1
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%val15 = OpFConvert %f32mat %f16mat_1
+
+OpReturn
+OpFunctionEnd)";
+ // No shape diagnostic is expected for this module.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateConversion, BitcastSuccess) {
const std::string body = R"(
%ptr = OpVariable %f32ptr_func Function
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index b567a7b..246b85e 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -1774,6 +1774,372 @@
HasSubstr("PhysicalStorageBufferEXT must not be used with OpVariable"));
}
+std::string GenCoopMatLoadStoreShader(const std::string& storeMemoryAccess,  // Spliced into the OpCooperativeMatrixStoreNV instruction.
+ const std::string& loadMemoryAccess) {  // Spliced into the first OpCooperativeMatrixLoadNV only; the other two loads are fixed.
+ std::string s = R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpCapability VulkanMemoryModelKHR
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpExtension "SPV_NV_cooperative_matrix"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint GLCompute %4 "main" %11 %21
+OpExecutionMode %4 LocalSize 1 1 1
+OpDecorate %11 BuiltIn SubgroupId
+OpDecorate %21 BuiltIn WorkgroupId
+OpDecorate %74 ArrayStride 4
+OpMemberDecorate %75 0 Offset 0
+OpDecorate %75 Block
+OpDecorate %77 DescriptorSet 0
+OpDecorate %77 Binding 0
+OpDecorate %92 ArrayStride 4
+OpMemberDecorate %93 0 Offset 0
+OpDecorate %93 Block
+OpDecorate %95 DescriptorSet 0
+OpDecorate %95 Binding 1
+OpDecorate %102 ArrayStride 4
+OpMemberDecorate %103 0 Offset 0
+OpDecorate %103 Block
+OpDecorate %105 DescriptorSet 0
+OpDecorate %105 Binding 2
+OpDecorate %117 ArrayStride 4
+OpMemberDecorate %118 0 Offset 0
+OpDecorate %118 Block
+OpDecorate %120 DescriptorSet 0
+OpDecorate %120 Binding 3
+OpDecorate %123 SpecId 2
+OpDecorate %124 SpecId 3
+OpDecorate %125 SpecId 4
+OpDecorate %126 SpecId 5
+OpDecorate %127 SpecId 0
+OpDecorate %128 SpecId 1
+OpDecorate %129 BuiltIn WorkgroupSize
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%6 = OpTypeInt 32 0
+%7 = OpTypeVector %6 2
+%8 = OpTypePointer Function %7
+%10 = OpTypePointer Input %6
+%11 = OpVariable %10 Input
+%13 = OpConstant %6 2
+%19 = OpTypeVector %6 3
+%20 = OpTypePointer Input %19
+%21 = OpVariable %20 Input
+%27 = OpConstantComposite %7 %13 %13
+%31 = OpTypePointer Function %6
+%33 = OpConstant %6 1024
+%34 = OpConstant %6 1
+%38 = OpConstant %6 8
+%39 = OpConstant %6 0
+%68 = OpTypeFloat 32
+%69 = OpConstant %6 16
+%70 = OpConstant %6 3
+%71 = OpTypeCooperativeMatrixNV %68 %70 %69 %38
+%72 = OpTypePointer Function %71
+%74 = OpTypeRuntimeArray %68
+%75 = OpTypeStruct %74
+%76 = OpTypePointer StorageBuffer %75
+%77 = OpVariable %76 StorageBuffer
+%78 = OpTypeInt 32 1
+%79 = OpConstant %78 0
+%81 = OpConstant %6 5
+%82 = OpTypePointer StorageBuffer %68
+%84 = OpConstant %6 64
+%85 = OpTypeBool
+%86 = OpConstantFalse %85
+%88 = OpTypePointer Private %71
+%89 = OpVariable %88 Private
+%92 = OpTypeRuntimeArray %68
+%93 = OpTypeStruct %92
+%94 = OpTypePointer StorageBuffer %93
+%95 = OpVariable %94 StorageBuffer
+%99 = OpVariable %88 Private
+%102 = OpTypeRuntimeArray %68
+%103 = OpTypeStruct %102
+%104 = OpTypePointer StorageBuffer %103
+%105 = OpVariable %104 StorageBuffer
+%109 = OpVariable %88 Private
+%111 = OpVariable %88 Private
+%112 = OpSpecConstantOp %6 CooperativeMatrixLengthNV %71
+%113 = OpSpecConstantOp %78 IAdd %112 %79
+%117 = OpTypeRuntimeArray %68
+%118 = OpTypeStruct %117
+%119 = OpTypePointer StorageBuffer %118
+%120 = OpVariable %119 StorageBuffer
+%123 = OpSpecConstant %78 1
+%124 = OpSpecConstant %78 1
+%125 = OpSpecConstant %78 1
+%126 = OpSpecConstant %78 1
+%127 = OpSpecConstant %6 1
+%128 = OpSpecConstant %6 1
+%129 = OpSpecConstantComposite %19 %127 %128 %34
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%9 = OpVariable %8 Function
+%18 = OpVariable %8 Function
+%32 = OpVariable %31 Function
+%44 = OpVariable %31 Function
+%52 = OpVariable %31 Function
+%60 = OpVariable %31 Function
+%73 = OpVariable %72 Function
+%91 = OpVariable %72 Function
+%101 = OpVariable %72 Function
+%12 = OpLoad %6 %11
+%14 = OpUMod %6 %12 %13
+%15 = OpLoad %6 %11
+%16 = OpUDiv %6 %15 %13
+%17 = OpCompositeConstruct %7 %14 %16
+OpStore %9 %17
+%22 = OpLoad %19 %21
+%23 = OpVectorShuffle %7 %22 %22 0 1
+%24 = OpCompositeExtract %6 %23 0
+%25 = OpCompositeExtract %6 %23 1
+%26 = OpCompositeConstruct %7 %24 %25
+%28 = OpIMul %7 %26 %27
+%29 = OpLoad %7 %9
+%30 = OpIAdd %7 %28 %29
+OpStore %18 %30
+%35 = OpAccessChain %31 %18 %34
+%36 = OpLoad %6 %35
+%37 = OpIMul %6 %33 %36
+%40 = OpAccessChain %31 %18 %39
+%41 = OpLoad %6 %40
+%42 = OpIMul %6 %38 %41
+%43 = OpIAdd %6 %37 %42
+OpStore %32 %43
+%45 = OpAccessChain %31 %18 %34
+%46 = OpLoad %6 %45
+%47 = OpIMul %6 %33 %46
+%48 = OpAccessChain %31 %18 %39
+%49 = OpLoad %6 %48
+%50 = OpIMul %6 %38 %49
+%51 = OpIAdd %6 %47 %50
+OpStore %44 %51
+%53 = OpAccessChain %31 %18 %34
+%54 = OpLoad %6 %53
+%55 = OpIMul %6 %33 %54
+%56 = OpAccessChain %31 %18 %39
+%57 = OpLoad %6 %56
+%58 = OpIMul %6 %38 %57
+%59 = OpIAdd %6 %55 %58
+OpStore %52 %59
+%61 = OpAccessChain %31 %18 %34
+%62 = OpLoad %6 %61
+%63 = OpIMul %6 %33 %62
+%64 = OpAccessChain %31 %18 %39
+%65 = OpLoad %6 %64
+%66 = OpIMul %6 %38 %65
+%67 = OpIAdd %6 %63 %66
+OpStore %60 %67
+%80 = OpLoad %6 %32
+%83 = OpAccessChain %82 %77 %79 %80
+%87 = OpCooperativeMatrixLoadNV %71 %83 %84 %86 )" +
+ loadMemoryAccess + R"( %81
+OpStore %73 %87
+%90 = OpLoad %71 %73
+OpStore %89 %90
+%96 = OpLoad %6 %44
+%97 = OpAccessChain %82 %95 %79 %96
+%98 = OpCooperativeMatrixLoadNV %71 %97 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
+OpStore %91 %98
+%100 = OpLoad %71 %91
+OpStore %99 %100
+%106 = OpLoad %6 %52
+%107 = OpAccessChain %82 %105 %79 %106
+%108 = OpCooperativeMatrixLoadNV %71 %107 %84 %86 MakePointerVisibleKHR|NonPrivatePointerKHR %81
+OpStore %101 %108
+%110 = OpLoad %71 %101
+OpStore %109 %110
+%114 = OpConvertSToF %68 %113
+%115 = OpCompositeConstruct %71 %114
+OpStore %111 %115
+%116 = OpLoad %71 %111
+%121 = OpLoad %6 %60
+%122 = OpAccessChain %82 %120 %79 %121
+OpCooperativeMatrixStoreNV %122 %116 %84 %86 )" + storeMemoryAccess + R"( %81
+OpReturn
+OpFunctionEnd
+)";
+ // Returns the complete assembly text with both access strings substituted.
+ return s;
+}
+
+TEST_F(ValidateMemory, CoopMatLoadStoreSuccess) {  // Store uses MakePointerAvailableKHR, load uses MakePointerVisibleKHR: must validate.
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+ "MakePointerVisibleKHR|NonPrivatePointerKHR");
+ // First argument is the store access, second the load access.
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidateMemory, CoopMatStoreMemoryAccessFail) {  // MakePointerVisibleKHR on the cooperative matrix store must be rejected.
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerVisibleKHR|NonPrivatePointerKHR",
+ "MakePointerVisibleKHR|NonPrivatePointerKHR");
+ // The expected diagnostic text refers to OpStore.
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("MakePointerVisibleKHR cannot be used with OpStore"));
+}
+
+TEST_F(ValidateMemory, CoopMatLoadMemoryAccessFail) {  // MakePointerAvailableKHR on the cooperative matrix load must be rejected.
+ std::string spirv =
+ GenCoopMatLoadStoreShader("MakePointerAvailableKHR|NonPrivatePointerKHR",
+ "MakePointerAvailableKHR|NonPrivatePointerKHR");
+ // The expected diagnostic text refers to OpLoad.
+ CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("MakePointerAvailableKHR cannot be used with OpLoad"));
+}
+
+TEST_F(ValidateMemory, CoopMatInvalidStorageClassFail) {  // A Workgroup variable of a struct containing a cooperative matrix must be rejected.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%str = OpTypeStruct %f16mat
+%str_ptr = OpTypePointer Workgroup %str
+%sh = OpVariable %str_ptr Workgroup
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+OpReturn
+OpFunctionEnd)";
+ // The matrix is nested inside %str, exercising the "types containing them" wording.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Cooperative matrix types (or types containing them) can only be "
+ "allocated in Function or Private storage classes or as function "
+ "parameters"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthResultTypeBad) {  // OpCooperativeMatrixLengthNV with a signed 32-bit result type must be rejected.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %i32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+ // The result type here is %i32 (signedness 1).
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The Result Type of OpCooperativeMatrixLengthNV <id> "
+ "'11[%11]' must be OpTypeInt with width 32 and signedness 0"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthOperandTypeBad) {  // OpCooperativeMatrixLengthNV whose type operand is not a cooperative matrix type must be rejected.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %u32 %u32
+
+OpReturn
+OpFunctionEnd)";
+ // The type operand here is %u32 instead of %f16mat.
+ CompileSuccessfully(body.c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("The type in OpCooperativeMatrixLengthNV <id> '5[%uint]' "
+ "must be OpTypeCooperativeMatrixNV"));
+}
+
+TEST_F(ValidateMemory, CoopMatMatrixLengthGood) {  // OpCooperativeMatrixLengthNV with a %u32 result and a cooperative matrix type operand must validate.
+ const std::string body =
+ R"(
+OpCapability Shader
+OpCapability Float16
+OpCapability CooperativeMatrixNV
+OpExtension "SPV_NV_cooperative_matrix"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%f16 = OpTypeFloat 16
+%u32 = OpTypeInt 32 0
+%i32 = OpTypeInt 32 1
+
+%u32_8 = OpConstant %u32 8
+%subgroup = OpConstant %u32 3
+
+%f16mat = OpTypeCooperativeMatrixNV %f16 %subgroup %u32_8 %u32_8
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+
+%1 = OpCooperativeMatrixLengthNV %u32 %f16mat
+
+OpReturn
+OpFunctionEnd)";
+ // Positive counterpart of the two *Bad tests above.
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
TEST_F(ValidateMemory, VulkanRTAOutsideOfStructBad) {
std::string spirv = R"(
OpCapability Shader