spirv-val: Fix Vulkan image sampled check (#4085)

* Fix SampledType logic
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp
index e8f65cf..2d29314 100644
--- a/source/val/validate_image.cpp
+++ b/source/val/validate_image.cpp
@@ -746,16 +746,26 @@
            << "Corrupt image type definition";
   }
 
+  if (_.IsIntScalarType(info.sampled_type) &&
+      (64 == _.GetBitWidth(info.sampled_type)) &&
+      !_.HasCapability(SpvCapabilityInt64ImageEXT)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Capability Int64ImageEXT is required when using Sampled Type of "
+              "64-bit int";
+  }
+
   const auto target_env = _.context()->target_env;
   if (spvIsVulkanEnv(target_env)) {
     if ((!_.IsFloatScalarType(info.sampled_type) &&
          !_.IsIntScalarType(info.sampled_type)) ||
-        (32 != _.GetBitWidth(info.sampled_type) &&
-         (64 != _.GetBitWidth(info.sampled_type) ||
-          !_.HasCapability(SpvCapabilityInt64ImageEXT)))) {
+        ((32 != _.GetBitWidth(info.sampled_type)) &&
+         (64 != _.GetBitWidth(info.sampled_type))) ||
+        ((64 == _.GetBitWidth(info.sampled_type)) &&
+         _.IsFloatScalarType(info.sampled_type))) {
       return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "Expected Sampled Type to be a 32-bit int or float "
-                "scalar type for Vulkan environment";
+             << _.VkErrorID(4656)
+             << "Expected Sampled Type to be a 32-bit int, 64-bit int or "
+                "32-bit float scalar type for Vulkan environment";
     }
   } else if (spvIsOpenCLEnv(target_env)) {
     if (!_.IsVoidType(info.sampled_type)) {
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 0fe9082..ae34185 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1667,6 +1667,8 @@
       return VUID_WRAP(VUID-ShadingRateKHR-ShadingRateKHR-04492);
     case 4633:
       return VUID_WRAP(VUID-StandaloneSpirv-None-04633);
+    case 4656:
+      return VUID_WRAP(VUID-StandaloneSpirv-OpTypeImage-04656);
     case 4658:
       return VUID_WRAP(VUID-StandaloneSpirv-OpImageTexelPointer-04658);
     case 4685:
diff --git a/test/val/val_image_test.cpp b/test/val/val_image_test.cpp
index 7666348..c5c92c4 100644
--- a/test/val/val_image_test.cpp
+++ b/test/val/val_image_test.cpp
@@ -480,6 +480,7 @@
   ss << R"(
 OpCapability Shader
 OpCapability Int64
+OpCapability Float64
 )";
 
   ss << capabilities_and_extensions;
@@ -500,9 +501,11 @@
 %func = OpTypeFunction %void
 %bool = OpTypeBool
 %f32 = OpTypeFloat 32
+%f64 = OpTypeFloat 64
 %u32 = OpTypeInt 32 0
 %u64 = OpTypeInt 64 0
 %s32 = OpTypeInt 32 1
+%s64 = OpTypeInt 64 1
 )";
 
   return ss.str();
@@ -524,8 +527,7 @@
 TEST_F(ValidateImage, TypeImageVoidSampledTypeVulkan) {
   const std::string code = GetShaderHeader() + R"(
 %img_type = OpTypeImage %void 2D 0 0 0 1 Unknown
-%void_func = OpTypeFunction %void
-%main = OpFunction %void None %void_func
+%main = OpFunction %void None %func
 %main_lab = OpLabel
 OpReturn
 OpFunctionEnd
@@ -535,15 +537,131 @@
   CompileSuccessfully(code, env);
   ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env));
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Expected Sampled Type to be a 32-bit int "
-                        "or float scalar type for Vulkan environment"));
+              AnyVUID("VUID-StandaloneSpirv-OpTypeImage-04656"));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected Sampled Type to be a 32-bit int, 64-bit int "
+                        "or 32-bit float scalar type for Vulkan environment"));
+}
+
+TEST_F(ValidateImage, TypeImageU32SampledTypeVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %u32 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateImage, TypeImageI32SampledTypeVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %s32 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateImage, TypeImageI64SampledTypeNoCapabilityVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %s64 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Capability Int64ImageEXT is required when using "
+                        "Sampled Type of 64-bit int"));
+}
+
+TEST_F(ValidateImage, TypeImageI64SampledTypeVulkan) {
+  const std::string code = GetShaderHeader(
+                               "OpCapability Int64ImageEXT\nOpExtension "
+                               "\"SPV_EXT_shader_image_int64\"\n") +
+                           R"(
+%img_type = OpTypeImage %s64 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateImage, TypeImageU64SampledTypeNoCapabilityVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %u64 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Capability Int64ImageEXT is required when using "
+                        "Sampled Type of 64-bit int"));
 }
 
 TEST_F(ValidateImage, TypeImageU64SampledTypeVulkan) {
-  const std::string code = GetShaderHeader() + R"(
+  const std::string code = GetShaderHeader(
+                               "OpCapability Int64ImageEXT\nOpExtension "
+                               "\"SPV_EXT_shader_image_int64\"\n") +
+                           R"(
 %img_type = OpTypeImage %u64 2D 0 0 0 1 Unknown
-%void_func = OpTypeFunction %void
-%main = OpFunction %void None %void_func
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateImage, TypeImageF32SampledTypeVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %f32 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(), Eq(""));
+}
+
+TEST_F(ValidateImage, TypeImageF64SampledTypeVulkan) {
+  const std::string code = GetShaderHeader() + R"(
+%img_type = OpTypeImage %f64 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
 %main_lab = OpLabel
 OpReturn
 OpFunctionEnd
@@ -553,8 +671,32 @@
   CompileSuccessfully(code, env);
   ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env));
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Expected Sampled Type to be a 32-bit int "
-                        "or float scalar type for Vulkan environment"));
+              AnyVUID("VUID-StandaloneSpirv-OpTypeImage-04656"));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected Sampled Type to be a 32-bit int, 64-bit int "
+                        "or 32-bit float scalar type for Vulkan environment"));
+}
+
+TEST_F(ValidateImage, TypeImageF64SampledTypeWithInt64Vulkan) {
+  const std::string code = GetShaderHeader(
+                               "OpCapability Int64ImageEXT\nOpExtension "
+                               "\"SPV_EXT_shader_image_int64\"\n") +
+                           R"(
+%img_type = OpTypeImage %f64 2D 0 0 0 1 Unknown
+%main = OpFunction %void None %func
+%main_lab = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  const spv_target_env env = SPV_ENV_VULKAN_1_0;
+  CompileSuccessfully(code, env);
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(env));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-OpTypeImage-04656"));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected Sampled Type to be a 32-bit int, 64-bit int "
+                        "or 32-bit float scalar type for Vulkan environment"));
 }
 
 TEST_F(ValidateImage, TypeImageWrongDepth) {