SPV_NV_shader_atomic_fp16_vector (#5581)

diff --git a/DEPS b/DEPS
index 2cd4493..136f5da 100644
--- a/DEPS
+++ b/DEPS
@@ -13,7 +13,7 @@
   'protobuf_revision': 'v21.12',
 
   're2_revision': 'b4c6fe091b74b65f706ff9c9ff369b396c2a3177',
-  'spirv_headers_revision': 'd3c2a6fa95ad463ca8044d7fc45557db381a6a64',
+  'spirv_headers_revision': '05cc486580771e4fa7ddc89f5c9ee1e97382689a',
 }
 
 deps = {
diff --git a/source/val/validate_atomics.cpp b/source/val/validate_atomics.cpp
index b745a9e..8ddef17 100644
--- a/source/val/validate_atomics.cpp
+++ b/source/val/validate_atomics.cpp
@@ -144,12 +144,13 @@
     case spv::Op::OpAtomicFlagClear: {
       const uint32_t result_type = inst->type_id();
 
-      // All current atomics only are scalar result
       // Validate return type first so can just check if pointer type is same
       // (if applicable)
       if (HasReturnType(opcode)) {
         if (HasOnlyFloatReturnType(opcode) &&
-            !_.IsFloatScalarType(result_type)) {
+            (!(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+               _.IsFloat16Vector2Or4Type(result_type)) &&
+             !_.IsFloatScalarType(result_type))) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << spvOpcodeString(opcode)
                  << ": expected Result Type to be float scalar type";
@@ -160,6 +161,9 @@
                  << ": expected Result Type to be integer scalar type";
         } else if (HasIntOrFloatReturnType(opcode) &&
                    !_.IsFloatScalarType(result_type) &&
+                   !(opcode == spv::Op::OpAtomicExchange &&
+                     _.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+                     _.IsFloat16Vector2Or4Type(result_type)) &&
                    !_.IsIntScalarType(result_type)) {
           return _.diag(SPV_ERROR_INVALID_DATA, inst)
                  << spvOpcodeString(opcode)
@@ -222,12 +226,21 @@
 
         if (opcode == spv::Op::OpAtomicFAddEXT) {
           // result type being float checked already
-          if ((_.GetBitWidth(result_type) == 16) &&
-              (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT))) {
-            return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                   << spvOpcodeString(opcode)
-                   << ": float add atomics require the AtomicFloat32AddEXT "
-                      "capability";
+          if (_.GetBitWidth(result_type) == 16) {
+            if (_.IsFloat16Vector2Or4Type(result_type)) {
+              if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
+                return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                       << spvOpcodeString(opcode)
+                       << ": float vector atomics require the "
+                          "AtomicFloat16VectorNV capability";
+            } else {
+              if (!_.HasCapability(spv::Capability::AtomicFloat16AddEXT)) {
+                return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                       << spvOpcodeString(opcode)
+                       << ": float add atomics require the AtomicFloat32AddEXT "
+                          "capability";
+              }
+            }
           }
           if ((_.GetBitWidth(result_type) == 32) &&
               (!_.HasCapability(spv::Capability::AtomicFloat32AddEXT))) {
@@ -245,12 +258,21 @@
           }
         } else if (opcode == spv::Op::OpAtomicFMinEXT ||
                    opcode == spv::Op::OpAtomicFMaxEXT) {
-          if ((_.GetBitWidth(result_type) == 16) &&
-              (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT))) {
-            return _.diag(SPV_ERROR_INVALID_DATA, inst)
-                   << spvOpcodeString(opcode)
-                   << ": float min/max atomics require the "
-                      "AtomicFloat16MinMaxEXT capability";
+          if (_.GetBitWidth(result_type) == 16) {
+            if (_.IsFloat16Vector2Or4Type(result_type)) {
+              if (!_.HasCapability(spv::Capability::AtomicFloat16VectorNV))
+                return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                       << spvOpcodeString(opcode)
+                       << ": float vector atomics require the "
+                          "AtomicFloat16VectorNV capability";
+            } else {
+              if (!_.HasCapability(spv::Capability::AtomicFloat16MinMaxEXT)) {
+                return _.diag(SPV_ERROR_INVALID_DATA, inst)
+                       << spvOpcodeString(opcode)
+                       << ": float min/max atomics require the "
+                          "AtomicFloat16MinMaxEXT capability";
+              }
+            }
           }
           if ((_.GetBitWidth(result_type) == 32) &&
               (!_.HasCapability(spv::Capability::AtomicFloat32MinMaxEXT))) {
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp
index 39eeb4b..46a32f2 100644
--- a/source/val/validate_image.cpp
+++ b/source/val/validate_image.cpp
@@ -1118,7 +1118,10 @@
   const auto ptr_type = result_type->GetOperandAs<uint32_t>(2);
   const auto ptr_opcode = _.GetIdOpcode(ptr_type);
   if (ptr_opcode != spv::Op::OpTypeInt && ptr_opcode != spv::Op::OpTypeFloat &&
-      ptr_opcode != spv::Op::OpTypeVoid) {
+      ptr_opcode != spv::Op::OpTypeVoid &&
+      !(ptr_opcode == spv::Op::OpTypeVector &&
+        _.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+        _.IsFloat16Vector2Or4Type(ptr_type))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Result Type to be OpTypePointer whose Type operand "
               "must be a scalar numerical type or OpTypeVoid";
@@ -1142,7 +1145,14 @@
            << "Corrupt image type definition";
   }
 
-  if (info.sampled_type != ptr_type) {
+  if (info.sampled_type != ptr_type &&
+      !(_.HasCapability(spv::Capability::AtomicFloat16VectorNV) &&
+        _.IsFloat16Vector2Or4Type(ptr_type) &&
+        _.GetIdOpcode(info.sampled_type) == spv::Op::OpTypeFloat &&
+        ((_.GetDimension(ptr_type) == 2 &&
+          info.format == spv::ImageFormat::Rg16f) ||
+         (_.GetDimension(ptr_type) == 4 &&
+          info.format == spv::ImageFormat::Rgba16f)))) {
     return _.diag(SPV_ERROR_INVALID_DATA, inst)
            << "Expected Image 'Sampled Type' to be the same as the Type "
               "pointed to by Result Type";
@@ -1213,7 +1223,10 @@
         (info.format != spv::ImageFormat::R64ui) &&
         (info.format != spv::ImageFormat::R32f) &&
         (info.format != spv::ImageFormat::R32i) &&
-        (info.format != spv::ImageFormat::R32ui)) {
+        (info.format != spv::ImageFormat::R32ui) &&
+        !((info.format == spv::ImageFormat::Rg16f ||
+           info.format == spv::ImageFormat::Rgba16f) &&
+          _.HasCapability(spv::Capability::AtomicFloat16VectorNV))) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
              << _.VkErrorID(4658)
              << "Expected the Image Format in Image to be R64i, R64ui, R32f, "
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 971e031..25b374d 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -954,6 +954,20 @@
   return false;
 }
 
+bool ValidationState_t::IsFloat16Vector2Or4Type(uint32_t id) const {
+  const Instruction* inst = FindDef(id);
+  assert(inst);
+
+  if (inst->opcode() == spv::Op::OpTypeVector) {
+    uint32_t vectorDim = GetDimension(id);
+    return IsFloatScalarType(GetComponentType(id)) &&
+           (vectorDim == 2 || vectorDim == 4) &&
+           (GetBitWidth(GetComponentType(id)) == 16);
+  }
+
+  return false;
+}
+
 bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
   const Instruction* inst = FindDef(id);
   if (!inst) {
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 0cd6c78..46a8cbf 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -602,6 +602,7 @@
   bool IsVoidType(uint32_t id) const;
   bool IsFloatScalarType(uint32_t id) const;
   bool IsFloatVectorType(uint32_t id) const;
+  bool IsFloat16Vector2Or4Type(uint32_t id) const;
   bool IsFloatScalarOrVectorType(uint32_t id) const;
   bool IsFloatMatrixType(uint32_t id) const;
   bool IsIntScalarType(uint32_t id) const;
diff --git a/test/val/val_atomics_test.cpp b/test/val/val_atomics_test.cpp
index b266ad6..0f65634 100644
--- a/test/val/val_atomics_test.cpp
+++ b/test/val/val_atomics_test.cpp
@@ -318,7 +318,8 @@
   EXPECT_THAT(
       getDiagnosticString(),
       HasSubstr("Opcode AtomicFAddEXT requires one of these capabilities: "
-                "AtomicFloat32AddEXT AtomicFloat64AddEXT AtomicFloat16AddEXT"));
+                "AtomicFloat16VectorNV AtomicFloat32AddEXT AtomicFloat64AddEXT "
+                "AtomicFloat16AddEXT"));
 }
 
 TEST_F(ValidateAtomics, AtomicMinFloatVulkan) {
@@ -331,7 +332,8 @@
   EXPECT_THAT(
       getDiagnosticString(),
       HasSubstr("Opcode AtomicFMinEXT requires one of these capabilities: "
-                "AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
+                "AtomicFloat16VectorNV AtomicFloat32MinMaxEXT "
+                "AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
 }
 
 TEST_F(ValidateAtomics, AtomicMaxFloatVulkan) {
@@ -343,8 +345,10 @@
   ASSERT_EQ(SPV_ERROR_INVALID_CAPABILITY, ValidateInstructions());
   EXPECT_THAT(
       getDiagnosticString(),
-      HasSubstr("Opcode AtomicFMaxEXT requires one of these capabilities: "
-                "AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT AtomicFloat16MinMaxEXT"));
+      HasSubstr(
+          "Opcode AtomicFMaxEXT requires one of these capabilities: "
+          "AtomicFloat16VectorNV AtomicFloat32MinMaxEXT AtomicFloat64MinMaxEXT "
+          "AtomicFloat16MinMaxEXT"));
 }
 
 TEST_F(ValidateAtomics, AtomicAddFloatVulkanWrongType1) {
@@ -2713,6 +2717,136 @@
                         "value of type Result Type"));
 }
 
+TEST_F(ValidateAtomics, AtomicFloat16VectorSuccess) {
+  const std::string definitions = R"(
+%f16 = OpTypeFloat 16
+%f16vec2 = OpTypeVector %f16 2
+%f16vec4 = OpTypeVector %f16 4
+
+%f16_1 = OpConstant %f16 1
+%f16vec2_1 = OpConstantComposite %f16vec2 %f16_1 %f16_1
+%f16vec4_1 = OpConstantComposite %f16vec4 %f16_1 %f16_1 %f16_1 %f16_1
+
+%f16vec2_ptr = OpTypePointer Workgroup %f16vec2
+%f16vec4_ptr = OpTypePointer Workgroup %f16vec4
+%f16vec2_var = OpVariable %f16vec2_ptr Workgroup
+%f16vec4_var = OpVariable %f16vec4_ptr Workgroup
+)";
+
+  const std::string body = R"(
+%val3 = OpAtomicFMinEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val4 = OpAtomicFMaxEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val8 = OpAtomicFAddEXT %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+%val9 = OpAtomicExchange %f16vec2 %f16vec2_var %device %relaxed %f16vec2_1
+
+%val11 = OpAtomicFMinEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val12 = OpAtomicFMaxEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val18 = OpAtomicFAddEXT %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+%val19 = OpAtomicExchange %f16vec4 %f16vec4_var %device %relaxed %f16vec4_1
+
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(
+                          body,
+                          "OpCapability Float16\n"
+                          "OpCapability AtomicFloat16VectorNV\n"
+                          "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+                          definitions),
+                      SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+}
+
+static constexpr char Float16Vector3Defs[] = R"(
+%f16 = OpTypeFloat 16
+%f16vec3 = OpTypeVector %f16 3
+
+%f16_1 = OpConstant %f16 1
+%f16vec3_1 = OpConstantComposite %f16vec3 %f16_1 %f16_1 %f16_1
+
+%f16vec3_ptr = OpTypePointer Workgroup %f16vec3
+%f16vec3_var = OpVariable %f16vec3_ptr Workgroup
+)";
+
+TEST_F(ValidateAtomics, AtomicFloat16Vector3MinFail) {
+  const std::string definitions = Float16Vector3Defs;
+
+  const std::string body = R"(
+%val11 = OpAtomicFMinEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(
+                          body,
+                          "OpCapability Float16\n"
+                          "OpCapability AtomicFloat16VectorNV\n"
+                          "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+                          definitions),
+                      SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFMinEXT: expected Result Type to be float scalar type"));
+}
+
+TEST_F(ValidateAtomics, AtomicFloat16Vector3MaxFail) {
+  const std::string definitions = Float16Vector3Defs;
+
+  const std::string body = R"(
+%val12 = OpAtomicFMaxEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(
+                          body,
+                          "OpCapability Float16\n"
+                          "OpCapability AtomicFloat16VectorNV\n"
+                          "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+                          definitions),
+                      SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFMaxEXT: expected Result Type to be float scalar type"));
+}
+
+TEST_F(ValidateAtomics, AtomicFloat16Vector3AddFail) {
+  const std::string definitions = Float16Vector3Defs;
+
+  const std::string body = R"(
+%val18 = OpAtomicFAddEXT %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(
+                          body,
+                          "OpCapability Float16\n"
+                          "OpCapability AtomicFloat16VectorNV\n"
+                          "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+                          definitions),
+                      SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("AtomicFAddEXT: expected Result Type to be float scalar type"));
+}
+
+TEST_F(ValidateAtomics, AtomicFloat16Vector3ExchangeFail) {
+  const std::string definitions = Float16Vector3Defs;
+
+  const std::string body = R"(
+%val19 = OpAtomicExchange %f16vec3 %f16vec3_var %device %relaxed %f16vec3_1
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(
+                          body,
+                          "OpCapability Float16\n"
+                          "OpCapability AtomicFloat16VectorNV\n"
+                          "OpExtension \"SPV_NV_shader_atomic_fp16_vector\"\n",
+                          definitions),
+                      SPV_ENV_VULKAN_1_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("AtomicExchange: expected Result Type to be integer or "
+                        "float scalar type"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools