Refactor the InstructionPass (#2924)

* move checks to more appropriate locations
  * remove some duplicated checks
* New function to check valid storage classes
* updated tests
diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp
index fecc351..6478b3c 100644
--- a/source/val/validate_instruction.cpp
+++ b/source/val/validate_instruction.cpp
@@ -55,18 +55,6 @@
   return ss.str();
 }
 
-bool IsValidWebGPUStorageClass(SpvStorageClass storage_class) {
-  return storage_class == SpvStorageClassUniformConstant ||
-         storage_class == SpvStorageClassUniform ||
-         storage_class == SpvStorageClassStorageBuffer ||
-         storage_class == SpvStorageClassInput ||
-         storage_class == SpvStorageClassOutput ||
-         storage_class == SpvStorageClassImage ||
-         storage_class == SpvStorageClassWorkgroup ||
-         storage_class == SpvStorageClassPrivate ||
-         storage_class == SpvStorageClassFunction;
-}
-
 // Returns capabilities that enable an opcode.  An empty result is interpreted
 // as no prohibition of use of the opcode.  If the result is non-empty, then
 // the opcode may only be used if at least one of the capabilities is specified
@@ -249,23 +237,6 @@
   return SPV_SUCCESS;
 }
 
-// Returns SPV_ERROR_INVALID_BINARY and emits a diagnostic if the instruction
-// is invalid because of an execution environment constraint.
-spv_result_t EnvironmentCheck(ValidationState_t& _, const Instruction* inst) {
-  const SpvOp opcode = inst->opcode();
-  switch (opcode) {
-    case SpvOpUndef:
-      if (_.features().bans_op_undef) {
-        return _.diag(SPV_ERROR_INVALID_BINARY, inst)
-               << "OpUndef is disallowed";
-      }
-      break;
-    default:
-      break;
-  }
-  return SPV_SUCCESS;
-}
-
 // Returns SPV_ERROR_INVALID_CAPABILITY and emits a diagnostic if the
 // instruction is invalid because the required capability isn't declared
 // in the module.
@@ -499,38 +470,6 @@
     }
     _.set_addressing_model(inst->GetOperandAs<SpvAddressingModel>(0));
     _.set_memory_model(inst->GetOperandAs<SpvMemoryModel>(1));
-
-    if (_.memory_model() != SpvMemoryModelVulkanKHR &&
-        _.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) {
-      return _.diag(SPV_ERROR_INVALID_DATA, inst)
-             << "VulkanMemoryModelKHR capability must only be specified if "
-                "the "
-                "VulkanKHR memory model is used.";
-    }
-
-    if (spvIsWebGPUEnv(_.context()->target_env)) {
-      if (_.addressing_model() != SpvAddressingModelLogical) {
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Addressing model must be Logical for WebGPU environment.";
-      }
-      if (_.memory_model() != SpvMemoryModelVulkanKHR) {
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Memory model must be VulkanKHR for WebGPU environment.";
-      }
-    }
-
-    if (spvIsOpenCLEnv(_.context()->target_env)) {
-      if ((_.addressing_model() != SpvAddressingModelPhysical32) &&
-          (_.addressing_model() != SpvAddressingModelPhysical64)) {
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Addressing model must be Physical32 or Physical64 "
-               << "in the OpenCL environment.";
-      }
-      if (_.memory_model() != SpvMemoryModelOpenCL) {
-        return _.diag(SPV_ERROR_INVALID_DATA, inst)
-               << "Memory model must be OpenCL in the OpenCL environment.";
-      }
-    }
   } else if (opcode == SpvOpExecutionMode) {
     const uint32_t entry_point = inst->word(1);
     _.RegisterExecutionModeForEntryPoint(entry_point,
@@ -540,61 +479,9 @@
     if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {
       return error;
     }
-
-    if (spvIsWebGPUEnv(_.context()->target_env) &&
-        !IsValidWebGPUStorageClass(storage_class)) {
-      return _.diag(SPV_ERROR_INVALID_BINARY, inst)
-             << "For WebGPU, OpVariable storage class must be one of "
-                "UniformConstant, Uniform, StorageBuffer, Input, Output, "
-                "Image, Workgroup, Private, Function for WebGPU";
-    }
-
-    if (storage_class == SpvStorageClassGeneric)
-      return _.diag(SPV_ERROR_INVALID_BINARY, inst)
-             << "OpVariable storage class cannot be Generic";
-    if (_.current_layout_section() == kLayoutFunctionDefinitions) {
-      if (storage_class != SpvStorageClassFunction) {
-        return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
-               << "Variables must have a function[7] storage class inside"
-                  " of a function";
-      }
-      if (_.current_function().IsFirstBlock(
-              _.current_function().current_block()->id()) == false) {
-        return _.diag(SPV_ERROR_INVALID_CFG, inst)
-               << "Variables can only be defined "
-                  "in the first block of a "
-                  "function";
-      }
-    } else {
-      if (storage_class == SpvStorageClassFunction) {
-        return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
-               << "Variables can not have a function[7] storage class "
-                  "outside of a function";
-      }
-    }
-  } else if (opcode == SpvOpTypePointer) {
-    const auto storage_class = inst->GetOperandAs<SpvStorageClass>(1);
-    if (spvIsWebGPUEnv(_.context()->target_env) &&
-        !IsValidWebGPUStorageClass(storage_class)) {
-      return _.diag(SPV_ERROR_INVALID_BINARY, inst)
-             << "For WebGPU, OpTypePointer storage class must be one of "
-                "UniformConstant, Uniform, StorageBuffer, Input, Output, "
-                "Image, Workgroup, Private, Function";
-    }
-  }
-
-  // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The
-  // Signedness in OpTypeInt must always be 0.
-  if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) &&
-      inst->GetOperandAs<uint32_t>(2) != 0u) {
-    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
-           << "The Signedness in OpTypeInt "
-              "must always be 0 when Kernel "
-              "capability is used.";
   }
 
   if (auto error = ReservedCheck(_, inst)) return error;
-  if (auto error = EnvironmentCheck(_, inst)) return error;
   if (auto error = CapabilityCheck(_, inst)) return error;
   if (auto error = LimitCheckIdBound(_, inst)) return error;
   if (auto error = LimitCheckStruct(_, inst)) return error;
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 8e22097..59cdbb3 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -461,6 +461,28 @@
     }
   }
 
+  if (!_.IsValidStorageClass(storage_class)) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
+           << "Invalid storage class for target environment";
+  }
+
+  if (storage_class == SpvStorageClassGeneric) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
+           << "OpVariable storage class cannot be Generic";
+  }
+
+  if (inst->function() && storage_class != SpvStorageClassFunction) {
+    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+           << "Variables must have a function[7] storage class inside"
+              " of a function";
+  }
+
+  if (!inst->function() && storage_class == SpvStorageClassFunction) {
+    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
+           << "Variables can not have a function[7] storage class "
+              "outside of a function";
+  }
+
   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
   // storage class.
   const auto result_storage_class_index = 1;
diff --git a/source/val/validate_misc.cpp b/source/val/validate_misc.cpp
index 28e3fc6..0239593 100644
--- a/source/val/validate_misc.cpp
+++ b/source/val/validate_misc.cpp
@@ -32,6 +32,10 @@
            << "Cannot create undefined values with 8- or 16-bit types";
   }
 
+  if (spvIsWebGPUEnv(_.context()->target_env)) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst) << "OpUndef is disallowed";
+  }
+
   return SPV_SUCCESS;
 }
 
diff --git a/source/val/validate_mode_setting.cpp b/source/val/validate_mode_setting.cpp
index cbcf11a..e020f5a 100644
--- a/source/val/validate_mode_setting.cpp
+++ b/source/val/validate_mode_setting.cpp
@@ -485,6 +485,44 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateMemoryModel(ValidationState_t& _,
+                                 const Instruction* inst) {
+  // Already produced an error if multiple memory model instructions are
+  // present.
+  if (_.memory_model() != SpvMemoryModelVulkanKHR &&
+      _.HasCapability(SpvCapabilityVulkanMemoryModelKHR)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "VulkanMemoryModelKHR capability must only be specified if "
+              "the VulkanKHR memory model is used.";
+  }
+
+  if (spvIsWebGPUEnv(_.context()->target_env)) {
+    if (_.addressing_model() != SpvAddressingModelLogical) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Addressing model must be Logical for WebGPU environment.";
+    }
+    if (_.memory_model() != SpvMemoryModelVulkanKHR) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Memory model must be VulkanKHR for WebGPU environment.";
+    }
+  }
+
+  if (spvIsOpenCLEnv(_.context()->target_env)) {
+    if ((_.addressing_model() != SpvAddressingModelPhysical32) &&
+        (_.addressing_model() != SpvAddressingModelPhysical64)) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Addressing model must be Physical32 or Physical64 "
+             << "in the OpenCL environment.";
+    }
+    if (_.memory_model() != SpvMemoryModelOpenCL) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Memory model must be OpenCL in the OpenCL environment.";
+    }
+  }
+
+  return SPV_SUCCESS;
+}
+
 }  // namespace
 
 spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst) {
@@ -496,6 +534,9 @@
     case SpvOpExecutionModeId:
       if (auto error = ValidateExecutionMode(_, inst)) return error;
       break;
+    case SpvOpMemoryModel:
+      if (auto error = ValidateMemoryModel(_, inst)) return error;
+      break;
     default:
       break;
   }
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index d3872da..4d673b4 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -106,6 +106,17 @@
     return _.diag(SPV_ERROR_INVALID_VALUE, inst)
            << "OpTypeInt has invalid signedness:";
   }
+
+  // SPIR-V Spec 2.16.3: Validation Rules for Kernel Capabilities: The
+  // Signedness in OpTypeInt must always be 0.
+  if (SpvOpTypeInt == inst->opcode() && _.HasCapability(SpvCapabilityKernel) &&
+      inst->GetOperandAs<uint32_t>(2) != 0u) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
+           << "The Signedness in OpTypeInt "
+              "must always be 0 when Kernel "
+              "capability is used.";
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -445,6 +456,12 @@
       if (sampled == 2) _.RegisterPointerToStorageImage(inst->id());
     }
   }
+
+  if (!_.IsValidStorageClass(storage_class)) {
+    return _.diag(SPV_ERROR_INVALID_BINARY, inst)
+           << "Invalid storage class for target environment";
+  }
+
   return SPV_SUCCESS;
 }
 
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 794d0f7..20eaf88 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -212,14 +212,6 @@
     }
   }
 
-  switch (env) {
-    case SPV_ENV_WEBGPU_0:
-      features_.bans_op_undef = true;
-      break;
-    default:
-      break;
-  }
-
   // Only attempt to count if we have words, otherwise let the other validation
   // fail and generate an error.
   if (num_words > 0) {
@@ -1277,5 +1269,52 @@
   return false;
 }
 
+bool ValidationState_t::IsValidStorageClass(
+    SpvStorageClass storage_class) const {
+  if (spvIsWebGPUEnv(context()->target_env)) {
+    switch (storage_class) {
+      case SpvStorageClassUniformConstant:
+      case SpvStorageClassUniform:
+      case SpvStorageClassStorageBuffer:
+      case SpvStorageClassInput:
+      case SpvStorageClassOutput:
+      case SpvStorageClassImage:
+      case SpvStorageClassWorkgroup:
+      case SpvStorageClassPrivate:
+      case SpvStorageClassFunction:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  if (spvIsVulkanEnv(context()->target_env)) {
+    switch (storage_class) {
+      case SpvStorageClassUniformConstant:
+      case SpvStorageClassUniform:
+      case SpvStorageClassStorageBuffer:
+      case SpvStorageClassInput:
+      case SpvStorageClassOutput:
+      case SpvStorageClassImage:
+      case SpvStorageClassWorkgroup:
+      case SpvStorageClassPrivate:
+      case SpvStorageClassFunction:
+      case SpvStorageClassPushConstant:
+      case SpvStorageClassPhysicalStorageBuffer:
+      case SpvStorageClassRayPayloadNV:
+      case SpvStorageClassIncomingRayPayloadNV:
+      case SpvStorageClassHitAttributeNV:
+      case SpvStorageClassCallableDataNV:
+      case SpvStorageClassIncomingCallableDataNV:
+      case SpvStorageClassShaderRecordBufferNV:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  return true;
+}
+
 }  // namespace val
 }  // namespace spvtools
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index e650d2e..e5d31ac 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -79,9 +79,6 @@
     // Permit group oerations Reduce, InclusiveScan, ExclusiveScan
     bool group_ops_reduce_and_scans = false;
 
-    // Disallows the use of OpUndef
-    bool bans_op_undef = false;
-
     // Allow OpTypeInt with 8 bit width?
     bool declare_int8_type = false;
 
@@ -707,6 +704,9 @@
   // * OpCopyObject
   const Instruction* TracePointer(const Instruction* inst) const;
 
+  // Validates the storage class for the target environment.
+  bool IsValidStorageClass(SpvStorageClass storage_class) const;
+
  private:
   ValidationState_t(const ValidationState_t&);
 
diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp
index b22db06..547eb57 100644
--- a/test/val/val_cfg_test.cpp
+++ b/test/val/val_cfg_test.cpp
@@ -342,11 +342,10 @@
   str += "OpFunctionEnd\n";
 
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions());
-  EXPECT_THAT(
-      getDiagnosticString(),
-      HasSubstr(
-          "Variables can only be defined in the first block of a function"));
+  ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("All OpVariable instructions in a function must be the "
+                        "first instructions in the first block"));
 }
 
 TEST_P(ValidateCFG, BlockSelfLoopIsOk) {
diff --git a/test/val/val_opencl_test.cpp b/test/val/val_opencl_test.cpp
index 18b2f71..1064158 100644
--- a/test/val/val_opencl_test.cpp
+++ b/test/val/val_opencl_test.cpp
@@ -45,15 +45,18 @@
 TEST_F(ValidateOpenCL, NonOpenCLMemoryModelBad) {
   std::string spirv = R"(
      OpCapability Kernel
-     OpMemoryModel Physical32 GLSL450
+     OpCapability Addresses
+     OpCapability VulkanMemoryModelKHR
+     OpExtension "SPV_KHR_vulkan_memory_model"
+     OpMemoryModel Physical32 VulkanKHR
 )";
 
   CompileSuccessfully(spirv);
 
   EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_OPENCL_1_2));
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Memory model must be OpenCL in the OpenCL environment."
-                        "\n  OpMemoryModel Physical32 GLSL450\n"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Memory model must be OpenCL in the OpenCL environment."));
 }
 
 TEST_F(ValidateOpenCL, NonVoidSampledTypeImageBad) {
diff --git a/test/val/val_storage_test.cpp b/test/val/val_storage_test.cpp
index f54b425..fe37a93 100644
--- a/test/val/val_storage_test.cpp
+++ b/test/val/val_storage_test.cpp
@@ -157,6 +157,7 @@
   const auto str = R"(
           OpCapability Kernel
           OpCapability Linkage
+          OpCapability GenericPointer
           OpMemoryModel Logical OpenCL
 %intt   = OpTypeInt 32 0
 %ptrt   = OpTypePointer Function %intt
@@ -172,6 +173,7 @@
   const auto str = R"(
           OpCapability Shader
           OpCapability Linkage
+          OpCapability GenericPointer
           OpMemoryModel Logical GLSL450
 %intt   = OpTypeInt 32 1
 %voidt  = OpTypeVoid
@@ -184,7 +186,7 @@
           OpFunctionEnd
 )";
   CompileSuccessfully(str);
-  ASSERT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
               HasSubstr("OpVariable storage class cannot be Generic"));
 }
@@ -307,12 +309,10 @@
            std::make_tuple("Workgroup", false, true, ""),
            std::make_tuple("Private", false, true, ""),
            std::make_tuple("Function", true, true, ""),
-           std::make_tuple(
-               "CrossWorkgroup", false, false,
-               "For WebGPU, OpTypePointer storage class must be one of"),
-           std::make_tuple(
-               "PushConstant", false, false,
-               "For WebGPU, OpTypePointer storage class must be one of")));
+           std::make_tuple("CrossWorkgroup", false, false,
+                           "Invalid storage class for target environment"),
+           std::make_tuple("PushConstant", false, false,
+                           "Invalid storage class for target environment")));
 
 }  // namespace
 }  // namespace val
diff --git a/test/val/val_webgpu_test.cpp b/test/val/val_webgpu_test.cpp
index 1eae0d3..8f62555 100644
--- a/test/val/val_webgpu_test.cpp
+++ b/test/val/val_webgpu_test.cpp
@@ -187,23 +187,6 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_WEBGPU_0));
 }
 
-TEST_F(ValidateWebGPU, NonLogicalAddressingModelBad) {
-  std::string spirv = R"(
-     OpCapability Shader
-     OpCapability VulkanMemoryModelKHR
-     OpExtension "SPV_KHR_vulkan_memory_model"
-     OpMemoryModel Physical32 VulkanKHR
-)";
-
-  CompileSuccessfully(spirv);
-
-  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_WEBGPU_0));
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Addressing model must be Logical for WebGPU "
-                        "environment.\n  OpMemoryModel Physical32 "
-                        "Vulkan\n"));
-}
-
 TEST_F(ValidateWebGPU, NonVulkanKHRMemoryModelBad) {
   std::string spirv = R"(
      OpCapability Shader