spirv-val: Add OpPtrAccessChain Base checks (#4965)

diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 074bdb8..8a66bee 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1410,7 +1410,51 @@
              << "VariablePointers or VariablePointersStorageBuffer";
     }
   }
-  return ValidateAccessChain(_, inst);
+
+  // Need to call first, will make sure Base is a valid ID
+  if (auto error = ValidateAccessChain(_, inst)) return error;
+
+  const auto base_id = inst->GetOperandAs<uint32_t>(2);
+  const auto base = _.FindDef(base_id);
+  const auto base_type = _.FindDef(base->type_id());
+  const auto base_type_storage_class = base_type->word(2);
+
+  if (_.HasCapability(SpvCapabilityShader) &&
+      (base_type_storage_class == SpvStorageClassUniform ||
+       base_type_storage_class == SpvStorageClassStorageBuffer ||
+       base_type_storage_class == SpvStorageClassPhysicalStorageBuffer ||
+       base_type_storage_class == SpvStorageClassPushConstant ||
+       (_.HasCapability(SpvCapabilityWorkgroupMemoryExplicitLayoutKHR) &&
+        base_type_storage_class == SpvStorageClassWorkgroup)) &&
+      !_.HasDecoration(base_type->id(), SpvDecorationArrayStride)) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "OpPtrAccessChain must have a Base whose type is decorated "
+              "with ArrayStride";
+  }
+
+  if (spvIsVulkanEnv(_.context()->target_env)) {
+    if (base_type_storage_class == SpvStorageClassWorkgroup) {
+      if (!_.HasCapability(SpvCapabilityVariablePointers)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "OpPtrAccessChain Base operand pointing to Workgroup "
+                  "storage class must use VariablePointers capability";
+      }
+    } else if (base_type_storage_class == SpvStorageClassStorageBuffer) {
+      if (!_.features().variable_pointers) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "OpPtrAccessChain Base operand pointing to StorageBuffer "
+                  "storage class must use VariablePointers or "
+                  "VariablePointersStorageBuffer capability";
+      }
+    } else if (base_type_storage_class !=
+               SpvStorageClassPhysicalStorageBuffer) {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "OpPtrAccessChain Base operand must point to Workgroup, "
+                "StorageBuffer, or PhysicalStorageBuffer storage class";
+    }
+  }
+
+  return SPV_SUCCESS;
 }
 
 spv_result_t ValidateArrayLength(ValidationState_t& state,
diff --git a/test/opt/eliminate_dead_member_test.cpp b/test/opt/eliminate_dead_member_test.cpp
index e277999..4438f3d 100644
--- a/test/opt/eliminate_dead_member_test.cpp
+++ b/test/opt/eliminate_dead_member_test.cpp
@@ -978,6 +978,7 @@
                OpMemberDecorate %type__Globals 1 Offset 4
                OpMemberDecorate %type__Globals 2 Offset 16
                OpDecorate %type__Globals Block
+               OpDecorate %_ptr_Uniform_type__Globals ArrayStride 8
        %uint = OpTypeInt 32 0
      %uint_0 = OpConstant %uint 0
      %uint_1 = OpConstant %uint 1
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index dd040b3..a457980 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -3981,8 +3981,11 @@
 TEST_P(AccessChainInstructionTest, AccessChainTooManyIndexesGood) {
   const std::string instr = GetParam();
   const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
+  const std::string arrayStride =
+      " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 ";
   int depth = 255;
-  std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup;
+  std::string header =
+      kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup;
   header.erase(header.find("%func"));
   std::ostringstream spirv;
   spirv << header << "\n";
@@ -4044,8 +4047,11 @@
 TEST_P(AccessChainInstructionTest, CustomizedAccessChainTooManyIndexesGood) {
   const std::string instr = GetParam();
   const std::string elem = AccessChainRequiresElemId(instr) ? " %int_0 " : "";
+  const std::string arrayStride =
+      " OpDecorate %_ptr_Uniform_deep_struct ArrayStride 8 ";
   int depth = 10;
-  std::string header = kGLSL450MemoryModel + kDeeplyNestedStructureSetup;
+  std::string header =
+      kGLSL450MemoryModel + arrayStride + kDeeplyNestedStructureSetup;
   header.erase(header.find("%func"));
   std::ostringstream spirv;
   spirv << header << "\n";
@@ -4217,8 +4223,11 @@
   // 0 will select the element at the index 0 of the vector. (which is a float).
   const std::string instr = GetParam();
   const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
+  const std::string arrayStride =
+      " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 ";
   std::ostringstream spirv;
-  spirv << kGLSL450MemoryModel << kDeeplyNestedStructureSetup << std::endl;
+  spirv << kGLSL450MemoryModel << arrayStride << kDeeplyNestedStructureSetup
+        << std::endl;
   spirv << "%ss = " << instr << " %_ptr_Uniform_struct_s %blockName_var "
         << elem << "%int_0" << std::endl;
   spirv << "%sa = " << instr << " %_ptr_Uniform_array5_mat4x3 %blockName_var "
@@ -4241,9 +4250,12 @@
 TEST_P(AccessChainInstructionTest, AccessChainIndexIntoRuntimeArrayGood) {
   const std::string instr = GetParam();
   const std::string elem = AccessChainRequiresElemId(instr) ? "%int_0 " : "";
-  std::string spirv = kGLSL450MemoryModel + kDeeplyNestedStructureSetup + R"(
-%runtime_arr_entry = )" +
-                      instr + R"( %_ptr_Uniform_float %blockName_var )" + elem +
+  const std::string arrayStride =
+      " OpDecorate %_ptr_Uniform_blockName ArrayStride 8 ";
+  std::string spirv = kGLSL450MemoryModel + arrayStride +
+                      kDeeplyNestedStructureSetup + R"(
+%runtime_arr_entry = )" + instr +
+                      R"( %_ptr_Uniform_float %blockName_var )" + elem +
                       R"(%int_2 %int_0
 OpReturn
 OpFunctionEnd
@@ -5631,6 +5643,7 @@
      OpExtension "SPV_KHR_variable_pointers"
      OpMemoryModel Logical GLSL450
      OpEntryPoint GLCompute %3 ""
+     OpDecorate %ptr ArrayStride 8
 %int = OpTypeInt 32 0
 %int_2 = OpConstant %int 2
 %int_4 = OpConstant %int 4
diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp
index 4299eda..780aeed 100644
--- a/test/val/val_memory_test.cpp
+++ b/test/val/val_memory_test.cpp
@@ -4712,6 +4712,220 @@
   EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
+TEST_F(ValidateMemory, PtrAccessChainArrayStrideBad) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability VariablePointersStorageBuffer
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+               OpDecorate %var DescriptorSet 0
+               OpDecorate %var Binding 0
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+        %ptr = OpTypePointer StorageBuffer %uint
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+        %var = OpVariable %ptr StorageBuffer
+       %main = OpFunction %void None %func
+      %label = OpLabel
+     %access = OpAccessChain %ptr %var
+ %ptr_access = OpPtrAccessChain %ptr %access %uint_1
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA,
+            ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpPtrAccessChain must have a Base whose type is "
+                        "decorated with ArrayStride"));
+}
+
+TEST_F(ValidateMemory, PtrAccessChainArrayStrideSuccess) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability VariablePointersStorageBuffer
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+               OpDecorate %var DescriptorSet 0
+               OpDecorate %var Binding 00
+               OpDecorate %ptr ArrayStride 4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+        %ptr = OpTypePointer StorageBuffer %uint
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+        %var = OpVariable %ptr StorageBuffer
+       %main = OpFunction %void None %func
+      %label = OpLabel
+     %access = OpAccessChain %ptr %var
+ %ptr_access = OpPtrAccessChain %ptr %access %uint_1
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_UNIVERSAL_1_5);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_UNIVERSAL_1_5));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferSuccess) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability VariablePointersStorageBuffer
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+               OpDecorate %_runtimearr_uint ArrayStride 4
+               OpMemberDecorate %_struct_10 0 Offset 0
+               OpDecorate %_struct_10 Block
+               OpDecorate %var DescriptorSet 0
+               OpDecorate %var Binding 0
+               OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+ %_struct_10 = OpTypeStruct %_runtimearr_uint
+%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+       %void = OpTypeVoid
+      %func2 = OpTypeFunction %void %_ptr_StorageBuffer_uint
+      %func1 = OpTypeFunction %void
+         %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
+     %called = OpFunction %void None %func2
+      %param = OpFunctionParameter %_ptr_StorageBuffer_uint
+     %label2 = OpLabel
+ %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %param %uint_1
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %func1
+     %label1 = OpLabel
+     %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0
+       %call = OpFunctionCall %void %called %access
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainStorageBufferCapability) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability PhysicalStorageBufferAddresses
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel PhysicalStorageBuffer64 GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+               OpDecorate %_runtimearr_uint ArrayStride 4
+               OpMemberDecorate %_struct_10 0 Offset 0
+               OpDecorate %_struct_10 Block
+               OpDecorate %var DescriptorSet 0
+               OpDecorate %var Binding 0
+               OpDecorate %_ptr_StorageBuffer_uint ArrayStride 4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+%_runtimearr_uint = OpTypeRuntimeArray %uint
+ %_struct_10 = OpTypeStruct %_runtimearr_uint
+%_ptr_StorageBuffer__struct_10 = OpTypePointer StorageBuffer %_struct_10
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+         %var = OpVariable %_ptr_StorageBuffer__struct_10 StorageBuffer
+       %main = OpFunction %void None %func
+      %label = OpLabel
+     %access = OpAccessChain %_ptr_StorageBuffer_uint %var %uint_0 %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_StorageBuffer_uint %access %uint_1
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpPtrAccessChain Base operand pointing to "
+                        "StorageBuffer storage class must use VariablePointers "
+                        "or VariablePointersStorageBuffer capability"));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupCapability) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability VariablePointersStorageBuffer
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+               OpDecorate %_ptr_Workgroup_uint ArrayStride 4
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+%_arr_uint = OpTypeArray %uint %uint_1
+%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+        %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup
+       %main = OpFunction %void None %func
+      %label = OpLabel
+     %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpPtrAccessChain Base operand pointing to Workgroup "
+                        "storage class must use VariablePointers capability"));
+}
+
+TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupNoArrayStrideSuccess) {
+  const std::string spirv = R"(
+               OpCapability Shader
+               OpCapability VariablePointers
+               OpExtension "SPV_KHR_storage_buffer_storage_class"
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "foo" %var
+               OpExecutionMode %main LocalSize 1 1 1
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+     %uint_1 = OpConstant %uint 1
+%_arr_uint = OpTypeArray %uint %uint_1
+%_ptr_Workgroup__arr_uint = OpTypePointer Workgroup %_arr_uint
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+       %void = OpTypeVoid
+       %func = OpTypeFunction %void
+        %var = OpVariable %_ptr_Workgroup__arr_uint Workgroup
+       %main = OpFunction %void None %func
+      %label = OpLabel
+     %access = OpAccessChain %_ptr_Workgroup_uint %var %uint_0
+ %ptr_access = OpPtrAccessChain %_ptr_Workgroup_uint %access %uint_1
+               OpReturn
+               OpFunctionEnd
+)";
+
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools