spirv-val: Vulkan Storage Class for Execution Model (#4212)

diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index db86fd2..14ee3b2 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -520,17 +520,39 @@
 void ValidationState_t::RegisterInstruction(Instruction* inst) {
   if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
 
-  // If the instruction is using an OpTypeSampledImage as an operand, it should
-  // be recorded. The validator will ensure that all usages of an
-  // OpTypeSampledImage and its definition are in the same basic block.
+  // Some validation checks are easier by getting all the consumers
   for (uint16_t i = 0; i < inst->operands().size(); ++i) {
     const spv_parsed_operand_t& operand = inst->operand(i);
-    if (SPV_OPERAND_TYPE_ID == operand.type) {
+    if ((SPV_OPERAND_TYPE_ID == operand.type) ||
+        (SPV_OPERAND_TYPE_TYPE_ID == operand.type)) {
       const uint32_t operand_word = inst->word(operand.offset);
       Instruction* operand_inst = FindDef(operand_word);
-      if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
+      if (!operand_inst) {
+        continue;
+      }
+
+      // If the instruction is using an OpTypeSampledImage as an operand, it
+      // should be recorded. The validator will ensure that all usages of an
+      // OpTypeSampledImage and its definition are in the same basic block.
+      if ((SPV_OPERAND_TYPE_ID == operand.type) &&
+          (SpvOpSampledImage == operand_inst->opcode())) {
         RegisterSampledImageConsumer(operand_word, inst);
       }
+
+      // In order to track storage classes (not Function) used per execution
+      // model we can't use RegisterExecutionModelLimitation on instructions
+      // like OpTypePointer which are going to be in the pre-function section.
+      // Instead just need to register storage class usage for consumers in a
+      // function block.
+      if (inst->function()) {
+        if (operand_inst->opcode() == SpvOpTypePointer) {
+          RegisterStorageClassConsumer(
+              operand_inst->GetOperandAs<SpvStorageClass>(1), inst);
+        } else if (operand_inst->opcode() == SpvOpVariable) {
+          RegisterStorageClassConsumer(
+              operand_inst->GetOperandAs<SpvStorageClass>(2), inst);
+        }
+      }
     }
   }
 }
@@ -550,6 +572,56 @@
   sampled_image_consumers_[sampled_image_id].push_back(consumer);
 }
 
+void ValidationState_t::RegisterStorageClassConsumer(
+    SpvStorageClass storage_class, Instruction* consumer) {
+  if (spvIsVulkanEnv(context()->target_env)) {
+    if (storage_class == SpvStorageClassOutput) {
+      std::string errorVUID = VkErrorID(4644);
+      function(consumer->function()->id())
+          ->RegisterExecutionModelLimitation([errorVUID](
+              SpvExecutionModel model, std::string* message) {
+            if (model == SpvExecutionModelGLCompute ||
+                model == SpvExecutionModelRayGenerationKHR ||
+                model == SpvExecutionModelIntersectionKHR ||
+                model == SpvExecutionModelAnyHitKHR ||
+                model == SpvExecutionModelClosestHitKHR ||
+                model == SpvExecutionModelMissKHR ||
+                model == SpvExecutionModelCallableKHR) {
+              if (message) {
+                *message =
+                    errorVUID +
+                    "in Vulkan evironment, Output Storage Class must not be "
+                    "used in RayGenerationKHR, IntersectionKHR, AnyHitKHR, "
+                    "ClosestHitKHR, MissKHR, or CallableKHR execution models";
+              }
+              return false;
+            }
+            return true;
+          });
+    }
+
+    if (storage_class == SpvStorageClassWorkgroup) {
+      std::string errorVUID = VkErrorID(4645);
+      function(consumer->function()->id())
+          ->RegisterExecutionModelLimitation([errorVUID](
+              SpvExecutionModel model, std::string* message) {
+            if (model != SpvExecutionModelGLCompute &&
+                model != SpvExecutionModelTaskNV &&
+                model != SpvExecutionModelMeshNV) {
+              if (message) {
+                *message =
+                    errorVUID +
+                    "in Vulkan evironment, Workgroup Storage Class is limited "
+                    "to MeshNV, TaskNV, and GLCompute execution model";
+              }
+              return false;
+            }
+            return true;
+          });
+    }
+  }
+}
+
 uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
 
 void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
@@ -1696,6 +1768,10 @@
       return VUID_WRAP(VUID-StandaloneSpirv-None-04642);
     case 4643:
       return VUID_WRAP(VUID-StandaloneSpirv-None-04643);
+    case 4644:
+      return VUID_WRAP(VUID-StandaloneSpirv-None-04644);
+    case 4645:
+      return VUID_WRAP(VUID-StandaloneSpirv-None-04645);
     case 4651:
       return VUID_WRAP(VUID-StandaloneSpirv-OpVariable-04651);
     case 4652:
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 8511139..57634bf 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -465,6 +465,10 @@
   void RegisterSampledImageConsumer(uint32_t sampled_image_id,
                                     Instruction* consumer);
 
+  // Record a function's storage class consumer instruction
+  void RegisterStorageClassConsumer(SpvStorageClass storage_class,
+                                    Instruction* consumer);
+
   /// Returns the set of Global Variables.
   std::unordered_set<uint32_t>& global_vars() { return global_vars_; }
 
diff --git a/test/val/val_atomics_test.cpp b/test/val/val_atomics_test.cpp
index e2ca71f..fc3aedb 100644
--- a/test/val/val_atomics_test.cpp
+++ b/test/val/val_atomics_test.cpp
@@ -101,7 +101,7 @@
 OpEntryPoint Fragment %main "main"
 OpExecutionMode %main OriginUpperLeft
 )";
-  const std::string defintions = R"(
+  const std::string definitions = R"(
 %u64 = OpTypeInt 64 0
 %s64 = OpTypeInt 64 1
 
@@ -115,19 +115,19 @@
 )";
   return GenerateShaderCodeImpl(
       body, "OpCapability Int64\n" + capabilities_and_extensions,
-      defintions + extra_defs,
-      memory_model, execution);
+      definitions + extra_defs, memory_model, execution);
 }
 
 std::string GenerateShaderComputeCode(
     const std::string& body,
     const std::string& capabilities_and_extensions = "",
+    const std::string& extra_defs = "",
     const std::string& memory_model = "GLSL450") {
   const std::string execution = R"(
 OpEntryPoint GLCompute %main "main"
 OpExecutionMode %main LocalSize 32 1 1
 )";
-  const std::string defintions = R"(
+  const std::string definitions = R"(
 %u64 = OpTypeInt 64 0
 %s64 = OpTypeInt 64 1
 
@@ -140,8 +140,8 @@
 %s64_var = OpVariable %s64_ptr Workgroup
 )";
   return GenerateShaderCodeImpl(
-      body, "OpCapability Int64\n" + capabilities_and_extensions, defintions,
-      memory_model, execution);
+      body, "OpCapability Int64\n" + capabilities_and_extensions,
+      definitions + extra_defs, memory_model, execution);
 }
 
 std::string GenerateKernelCode(
@@ -269,6 +269,21 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
+TEST_F(ValidateAtomics, AtomicLoadVulkanWrongStorageClass) {
+  const std::string body = R"(
+%val1 = OpAtomicLoad %u32 %u32_var %device %relaxed
+)";
+
+  CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-None-04645"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("in Vulkan evironment, Workgroup Storage Class is limited to "
+                "MeshNV, TaskNV, and GLCompute execution model"));
+}
+
 TEST_F(ValidateAtomics, AtomicAddIntVulkanWrongType1) {
   const std::string body = R"(
 %val1 = OpAtomicIAdd %f32 %f32_var %device %relaxed %f32_1
@@ -534,7 +549,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_add"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -554,7 +570,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra, defs), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra, defs),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -574,7 +591,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra, defs), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra, defs),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -587,7 +605,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -600,7 +619,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -620,7 +640,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra, defs), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra, defs),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -640,7 +661,8 @@
 OpExtension "SPV_EXT_shader_atomic_float_min_max"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra, defs), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra, defs),
+                      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -654,12 +676,27 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
-TEST_F(ValidateAtomics, AtomicStoreFloatVulkan) {
+TEST_F(ValidateAtomics, AtomicStoreVulkanWrongStorageClass) {
   const std::string body = R"(
 OpAtomicStore %f32_var %device %relaxed %f32_1
 )";
 
   CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-None-04645"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("in Vulkan evironment, Workgroup Storage Class is limited to "
+                "MeshNV, TaskNV, and GLCompute execution model"));
+}
+
+TEST_F(ValidateAtomics, AtomicStoreFloatVulkan) {
+  const std::string body = R"(
+OpAtomicStore %f32_var %device %relaxed %f32_1
+)";
+
+  CompileSuccessfully(GenerateShaderComputeCode(body), SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -668,7 +705,7 @@
 %val2 = OpAtomicExchange %f32 %f32_var %device %relaxed %f32_0
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body), SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -903,8 +940,9 @@
 OpAtomicStore %s64_var %device %relaxed %s64_1
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, "OpCapability Int64Atomics\n"),
-                      SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(
+      GenerateShaderComputeCode(body, "OpCapability Int64Atomics\n"),
+      SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -1007,7 +1045,7 @@
 OpAtomicStore %u32_var %invocation %relaxed %u32_1
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body), SPV_ENV_VULKAN_1_0);
+  CompileSuccessfully(GenerateShaderComputeCode(body), SPV_ENV_VULKAN_1_0);
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_0));
 }
 
@@ -2297,7 +2335,7 @@
 OpExtension "SPV_KHR_vulkan_memory_model"
 )";
 
-  CompileSuccessfully(GenerateShaderCode(body, extra, "", "VulkanKHR"),
+  CompileSuccessfully(GenerateShaderComputeCode(body, extra, "", "VulkanKHR"),
                       SPV_ENV_VULKAN_1_1);
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
 }
diff --git a/test/val/val_storage_test.cpp b/test/val/val_storage_test.cpp
index e6f98bf..2ab265d 100644
--- a/test/val/val_storage_test.cpp
+++ b/test/val/val_storage_test.cpp
@@ -30,6 +30,7 @@
 using ValidateStorage = spvtest::ValidateBase<std::string>;
 using ValidateStorageClass =
     spvtest::ValidateBase<std::tuple<std::string, bool, bool, std::string>>;
+using ValidateStorageExecutionModel = spvtest::ValidateBase<std::string>;
 
 TEST_F(ValidateStorage, FunctionStorageInsideFunction) {
   char str[] = R"(
@@ -250,6 +251,46 @@
               HasSubstr("OpFunctionCall Argument <id> '"));
 }
 
+TEST_P(ValidateStorageExecutionModel, VulkanOutsideStoreFailure) {
+  std::stringstream ss;
+  ss << R"(
+              OpCapability Shader
+              OpCapability RayTracingKHR
+              OpExtension "SPV_KHR_ray_tracing"
+              OpMemoryModel Logical GLSL450
+              OpEntryPoint )"
+     << GetParam() << R"(  %func "func" %output
+              OpDecorate %output Location 0
+%intt       = OpTypeInt 32 0
+%int0       = OpConstant %intt 0
+%voidt      = OpTypeVoid
+%vfunct     = OpTypeFunction %voidt
+%outputptrt = OpTypePointer Output %intt
+%output     = OpVariable %outputptrt Output
+%func       = OpFunction %voidt None %vfunct
+%funcl      = OpLabel
+              OpStore %output %int0
+              OpReturn
+              OpFunctionEnd
+)";
+
+  CompileSuccessfully(ss.str(), SPV_ENV_VULKAN_1_0);
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
+  EXPECT_THAT(getDiagnosticString(),
+              AnyVUID("VUID-StandaloneSpirv-None-04644"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("in Vulkan evironment, Output Storage Class must not be used "
+                "in RayGenerationKHR, IntersectionKHR, AnyHitKHR, "
+                "ClosestHitKHR, MissKHR, or CallableKHR execution models"));
+}
+
+INSTANTIATE_TEST_SUITE_P(MatrixExecutionModel, ValidateStorageExecutionModel,
+                         ::testing::Values("RayGenerationKHR",
+                                           "IntersectionKHR", "AnyHitKHR",
+                                           "ClosestHitKHR", "MissKHR",
+                                           "CallableKHR"));
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools