SPV_NV_shader_atomic_fp16_vector (#5581)
diff --git a/DEPS b/DEPS
index 2cd4493..136f5da 100644
--- a/DEPS
+++ b/DEPS
@@ -13,7 +13,7 @@
'protobuf_revision': 'v21.12',
're2_revision': 'b4c6fe091b74b65f706ff9c9ff369b396c2a3177',
- 'spirv_headers_revision': 'd3c2a6fa95ad463ca8044d7fc45557db381a6a64',
+ 'spirv_headers_revision': '05cc486580771e4fa7ddc89f5c9ee1e97382689a',
}
deps = {
diff --git a/source/val/validate_atomics.cpp b/source/val/validate_atomics.cpp
index b745a9e..8ddef17 100644
--- a/source/val/validate_atomics.cpp
+++ b/source/val/validate_atomics.cpp
@@ -144,12 +144,13 @@
case spv::Op::OpAtomicFlagClear: {
const uint32_t result_type = inst->type_id();
- // All current atomics only are scalar result
// Validate return type first so can just check if pointer type is same
// (if applicable)
if (HasReturnType(opcode)) {
if (HasOnlyFloatReturnType(opcode) &&
- !_.IsFloatScalarType(result_type)) {
+ (!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(result_type)) &&
+ !_.IsFloatScalarType(result_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": expected Result Type to be float scalar type";
@@ -160,6 +161,9 @@
<< ": expected Result Type to be integer scalar type";
} else if (HasIntOrFloatReturnType(opcode) &&
!_.IsFloatScalarType(result_type) &&
+ !(opcode == spv::Op::OpAtomicExchange &&
+ _.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(result_type)) &&
!_.IsIntScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
@@ -222,12 +226,21 @@
if (opcode == spv::Op::OpAtomicFAddEXT) {
// result type being float checked already
- if ((_.GetBitWidth(result_type) == 16) &&
- (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << spvOpcodeString(opcode)
- << ": float add atomics require the AtomicFloat32AddEXT "
- "capability";
+ // 16-bit float adds are gated by result shape: 2-/4-component vectors need
+ // AtomicFloat16VectorNV, any other 16-bit result needs AtomicFloat16AddEXT.
+ if (_.GetBitWidth(result_type) == 16) {
+   if (_.IsFloat16Vector2Or4Type(result_type)) {
+     if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV)) {
+       return _.diag(SPV_ERROR_INVALID_DATA, inst)
+              << spvOpcodeString(opcode)
+              << ": float vector atomics require the "
+                 "AtomicFloat16VectorNV capability";
+     }
+   } else {
+     if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
+       // The diagnostic names the capability that is actually tested above.
+       return _.diag(SPV_ERROR_INVALID_DATA, inst)
+              << spvOpcodeString(opcode)
+              << ": float add atomics require the AtomicFloat16AddEXT "
+                 "capability";
+     }
+   }
+ }
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
@@ -245,12 +258,21 @@
}
} else if (opcode == spv::Op::OpAtomicFMinEXT ||
opcode == spv::Op::OpAtomicFMaxEXT) {
- if ((_.GetBitWidth(result_type) == 16) &&
- (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
- return _.diag(SPV_ERROR_INVALID_DATA, inst)
- << spvOpcodeString(opcode)
- << ": float min/max atomics require the "
- "AtomicFloat16MinMaxEXT capability";
+ if (_.GetBitWidth(result_type) == 16) {
+ if (_.IsFloat16Vector2Or4Type(result_type)) {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float vector atomics require the "
+ "AtomicFloat16VectorNV capability";
+ } else {
+ if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << spvOpcodeString(opcode)
+ << ": float min/max atomics require the "
+ "AtomicFloat16MinMaxEXT capability";
+ }
+ }
}
if ((_.GetBitWidth(result_type) == 32) &&
(!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp
index 39eeb4b..46a32f2 100644
--- a/source/val/validate_image.cpp
+++ b/source/val/validate_image.cpp
@@ -1118,7 +1118,10 @@
const auto ptr_type = result_type->GetOperandAs<uint32_t>(2);
const auto ptr_opcode = _.GetIdOpcode(ptr_type);
if (ptr_opcode != spv::Op::OpTypeInt && ptr_opcode != spv::Op::OpTypeFloat &&
- ptr_opcode != spv::Op::OpTypeVoid) {
+ ptr_opcode != spv::Op::OpTypeVoid &&
+ !(ptr_opcode == spv::Op::OpTypeVector &&
+ _.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(ptr_type))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be OpTypePointer whose Type operand "
"must be a scalar numerical type or OpTypeVoid";
@@ -1142,7 +1145,14 @@
<< "Corrupt image type definition";
}
- if (info.sampled_type != ptr_type) {
+ if (info.sampled_type != ptr_type &&
+ !(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+ _.IsFloat16Vector2Or4Type(ptr_type) &&
+ _.GetIdOpcode(info.sampled_type) == spv::Op::OpTypeFloat &&
+ ((_.GetDimension(ptr_type) == 2 &&
+ info.format == spv::ImageFormat::Rg16f) ||
+ (_.GetDimension(ptr_type) == 4 &&
+ info.format == spv::ImageFormat::Rgba16f)))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Image 'Sampled Type' to be the same as the Type "
"pointed to by Result Type";
@@ -1213,7 +1223,10 @@
(info.format != spv::ImageFormat::R64ui) &&
(info.format != spv::ImageFormat::R32f) &&
(info.format != spv::ImageFormat::R32i) &&
- (info.format != spv::ImageFormat::R32ui)) {
+ (info.format != spv::ImageFormat::R32ui) &&
+ !((info.format == spv::ImageFormat::Rg16f ||
+ info.format == spv::ImageFormat::Rgba16f) &&
+ _.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4658)
<< "Expected the Image Format in Image to be R64i, R64ui, R32f, "
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 971e031..25b374d 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -954,6 +954,20 @@
return false;
}
+// Returns true only when |id| is an OpTypeVector with 2 or 4 components whose
+// component type is a 16-bit float scalar. Returns false for every other type
+// and for an |id| that has no definition.
+bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  // Handle an unknown id with a null check, as IsFloatScalarOrVectorType
+  // does, rather than an assert that disappears in release builds.
+  if (!inst || inst->opcode() != spv::Op::OpTypeVector) {
+    return false;
+  }
+
+  const uint32_t component_count = GetDimension(id);
+  if (component_count != 2 && component_count != 4) {
+    return false;
+  }
+
+  const uint32_t component_type = GetComponentType(id);
+  return IsFloatScalarType(component_type) &&
+         GetBitWidth(component_type) == 16;
+}
+
bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
const Instruction* inst = FindDef(id);
if (!inst) {
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 0cd6c78..46a8cbf 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -602,6 +602,7 @@
bool IsVoidType(uint32_t id) const;
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
+ // True if |id| is an OpTypeVector of 2 or 4 16-bit float components.
+ bool IsFloat16Vector2Or4Type(uint32_t id) const;
bool IsFloatScalarOrVectorType(uint32_t id) const;
bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
diff --git a/test/val/val_atomics_test.cpp b/test/val/val_atomics_test.cpp
index b266ad6..0f65634 100644
--- a/test/val/val_atomics_test.cpp
+++ b/test/val/val_atomics_test.cpp
@@ -318,7 +318,8 @@
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFAddEXT requires one of these capabilities: "
- "AtomicFloat32AddEXT AtomicFloat64AddEXT AtomicFloat16AddEXT"));
+ "AtomicFloat16VectorNV AtomicFloat32AddEXT AtomicFloat64AddEXT "
+ "AtomicFloat16AddEXT"));
}
TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
@@ -331,7 +332,8 @@
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Opcode AtomicFMinEXT requires one of these capabilities: "
- "AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
+ "AtomicFloat16VectorNV AtomicFloat32MinMaxEXT "
+ "AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
}
TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
@@ -343,8 +345,10 @@
ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
- HasSubstr("Opcode AtomicFMaxEXT requires one of these capabilities: "
- "AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
+ HasSubstr(
+ "Opcode AtomicFMaxEXT requires one of these capabilities: "
+ "AtomicFloat16VectorNV AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT "
+ "AtomicFloat16MinMaxEXT"));
}
TEST_F(ValidateAtomics, AtomicAddFloatVulkanWrongType1) {
@@ -2713,6 +2717,136 @@
"value of type Result Type"));
}
+// Runs OpAtomicFMinEXT, OpAtomicFMaxEXT, OpAtomicFAddEXT and OpAtomicExchange
+// on 2- and 4-component f16 vectors with AtomicFloat16VectorNV declared and
+// expects the module to validate.
+TEST_F(ValidateAtomics, AtomicFloat16VectorSuccess) {
+  const std::string types_and_variables = R"(
+%f16 = OpTypeFloat 16
+%f16vec2 = OpTypeVector %f16 2
+%f16vec4 = OpTypeVector %f16 4
+
+%f16_1 = OpConstant %f16 1
+%f16vec2_1 = OpConstantComposite %f16vec2 %f16_1 %f16_1
+%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1
+
+%f16vec2_ptr = OpTypePointer Workgroup %f16vec2
+%f16vec4_ptr = OpTypePointer Workgroup %f16vec4
+%f16vec2_var = OpVariable %f16vec2_ptr Workgroup
+%f16vec4_var = OpVariable %f16vec4_ptr Workgroup
+)";
+
+  const std::string atomic_ops = R"(
+%val3 = OpAtomicFMinEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val4 = OpAtomicFMaxEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val8 = OpAtomicFAddEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val9 = OpAtomicExchange %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+
+%val11 = OpAtomicFMinEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val12 = OpAtomicFMaxEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val18 = OpAtomicFAddEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val19 = OpAtomicExchange %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+
+)";
+
+  const auto spirv_text = GenerateShaderComputeCode(
+      atomic_ops,
+      "OpCapability Float16\n"
+      "OpCapability AtomicFloat16VectorNV\n"
+      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+      types_and_variables);
+  CompileSuccessfully(spirv_text, SPV_ENV_VULKAN_1_0);
+
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
+// SPIR-V type, constant and Workgroup-variable definitions for a 3-component
+// f16 vector, used as the definitions section of the
+// AtomicFloat16Vector3*Fail tests.
+static constexpr char Float16Vector3Defs[] = R"(
+%f16 = OpTypeFloat 16
+%f16vec3 = OpTypeVector %f16 3
+
+%f16_1 = OpConstant %f16 1
+%f16vec3_1 = OpConstantComposite %f16vec3 %f16_1 %f16_1 %f16_1
+
+%f16vec3_ptr = OpTypePointer Workgroup %f16vec3
+%f16vec3_var = OpVariable %f16vec3_ptr Workgroup
+)";
+
+// OpAtomicFMinEXT with a 3-component f16 vector result: AtomicFloat16VectorNV
+// is declared, and the result-type diagnostic is still expected.
+TEST_F(ValidateAtomics, AtomicFloat16Vector3MinFail) {
+  const char* const capabilities_and_extensions =
+      "OpCapability Float16\n"
+      "OpCapability AtomicFloat16VectorNV\n"
+      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
+  const std::string vec3_definitions(Float16Vector3Defs);
+  const std::string atomic_op = R"(
+%val11 = OpAtomicFMinEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  const auto spirv_text = GenerateShaderComputeCode(
+      atomic_op, capabilities_and_extensions, vec3_definitions);
+  CompileSuccessfully(spirv_text, SPV_ENV_VULKAN_1_0);
+
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFMinEXT: expected Result Type to be float scalar type"));
+}
+
+// OpAtomicFMaxEXT with a 3-component f16 vector result: AtomicFloat16VectorNV
+// is declared, and the result-type diagnostic is still expected.
+TEST_F(ValidateAtomics, AtomicFloat16Vector3MaxFail) {
+  const char* const capabilities_and_extensions =
+      "OpCapability Float16\n"
+      "OpCapability AtomicFloat16VectorNV\n"
+      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
+  const std::string vec3_definitions(Float16Vector3Defs);
+  const std::string atomic_op = R"(
+%val12 = OpAtomicFMaxEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  const auto spirv_text = GenerateShaderComputeCode(
+      atomic_op, capabilities_and_extensions, vec3_definitions);
+  CompileSuccessfully(spirv_text, SPV_ENV_VULKAN_1_0);
+
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFMaxEXT: expected Result Type to be float scalar type"));
+}
+
+// OpAtomicFAddEXT with a 3-component f16 vector result: AtomicFloat16VectorNV
+// is declared, and the result-type diagnostic is still expected.
+TEST_F(ValidateAtomics, AtomicFloat16Vector3AddFail) {
+  const char* const capabilities_and_extensions =
+      "OpCapability Float16\n"
+      "OpCapability AtomicFloat16VectorNV\n"
+      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
+  const std::string vec3_definitions(Float16Vector3Defs);
+  const std::string atomic_op = R"(
+%val18 = OpAtomicFAddEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  const auto spirv_text = GenerateShaderComputeCode(
+      atomic_op, capabilities_and_extensions, vec3_definitions);
+  CompileSuccessfully(spirv_text, SPV_ENV_VULKAN_1_0);
+
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFAddEXT: expected Result Type to be float scalar type"));
+}
+
+// OpAtomicExchange with a 3-component f16 vector result: AtomicFloat16VectorNV
+// is declared, and the integer-or-float result-type diagnostic is still
+// expected.
+TEST_F(ValidateAtomics, AtomicFloat16Vector3ExchangeFail) {
+  const char* const capabilities_and_extensions =
+      "OpCapability Float16\n"
+      "OpCapability AtomicFloat16VectorNV\n"
+      "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n";
+  const std::string vec3_definitions(Float16Vector3Defs);
+  const std::string atomic_op = R"(
+%val19 = OpAtomicExchange %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  const auto spirv_text = GenerateShaderComputeCode(
+      atomic_op, capabilities_and_extensions, vec3_definitions);
+  CompileSuccessfully(spirv_text, SPV_ENV_VULKAN_1_0);
+
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "AtomicExchange: expected Result Type to be integer or float scalar "
+          "type"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools