Validate variable pointer related function call rules (#2270)

Fixes #2105

* Check storage class validity
* Check memory object declaration validity

diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp
index de41b27..96c4776 100644
--- a/source/val/validate_function.cpp
+++ b/source/val/validate_function.cpp
@@ -247,6 +247,49 @@
              << "'s type does not match Function <id> '"
              << _.getIdName(parameter_type_id) << "'s parameter type.";
     }
+
+    if (_.addressing_model() == SpvAddressingModelLogical) {
+      if (parameter_type->opcode() == SpvOpTypePointer) {
+        SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
+        // Validate which storage classes can be pointer operands.
+        switch (sc) {
+          case SpvStorageClassUniformConstant:
+          case SpvStorageClassFunction:
+          case SpvStorageClassPrivate:
+          case SpvStorageClassWorkgroup:
+          case SpvStorageClassAtomicCounter:
+            // These are always allowed.
+            break;
+          case SpvStorageClassStorageBuffer:
+            if (!_.features().variable_pointers_storage_buffer) {
+              return _.diag(SPV_ERROR_INVALID_ID, inst)
+                     << "StorageBuffer pointer operand "
+                     << _.getIdName(argument_id)
+                     << " requires a variable pointers capability";
+            }
+            break;
+          default:
+            return _.diag(SPV_ERROR_INVALID_ID, inst)
+                   << "Invalid storage class for pointer operand "
+                   << _.getIdName(argument_id);
+        }
+
+        // Validate memory object declaration requirements.
+        if (argument->opcode() != SpvOpVariable &&
+            argument->opcode() != SpvOpFunctionParameter) {
+          const bool ssbo_vptr =
+              _.features().variable_pointers_storage_buffer &&
+              sc == SpvStorageClassStorageBuffer;
+          const bool wg_vptr =
+              _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
+          if (!ssbo_vptr && !wg_vptr) {
+            return _.diag(SPV_ERROR_INVALID_ID, inst)
+                   << "Pointer operand " << _.getIdName(argument_id)
+                   << " must be a memory object declaration";
+          }
+        }
+      }
+    }
   }
   return SPV_SUCCESS;
 }
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index b4fad28..9f538c2 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -51,8 +51,9 @@
   PCH_FILE pch_test_val
 )
 
-add_spvtools_unittest(TARGET val_ijklmnop
+add_spvtools_unittest(TARGET val_fghijklmnop
   SRCS
+       val_function_test.cpp
        val_id_test.cpp
        val_image_test.cpp
        val_interfaces_test.cpp
diff --git a/test/val/val_function_test.cpp b/test/val/val_function_test.cpp
new file mode 100644
index 0000000..f3dd15e
--- /dev/null
+++ b/test/val/val_function_test.cpp
@@ -0,0 +1,424 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "test/test_fixture.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::Combine;
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateFunctionCall = spvtest::ValidateBase<std::string>;
+
+std::string GenerateShader(const std::string& storage_class,
+                           const std::string& capabilities,
+                           const std::string& extensions) {
+  std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+                      extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %var "var"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%ptr = OpTypePointer )" + storage_class + R"( %int
+%caller_ty = OpTypeFunction %void
+%callee_ty = OpTypeFunction %void %ptr
+)";
+
+  if (storage_class != "Function") {
+    spirv += "%var = OpVariable %ptr " + storage_class;
+  }
+
+  spirv += R"(
+%caller = OpFunction %void None %caller_ty
+%1 = OpLabel
+)";
+
+  if (storage_class == "Function") {
+    spirv += "%var = OpVariable %ptr Function";
+  }
+
+  spirv += R"(
+%call = OpFunctionCall %void %callee %var
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %callee_ty
+%param = OpFunctionParameter %ptr
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  return spirv;
+}
+
+std::string GenerateShaderParameter(const std::string& storage_class,
+                                    const std::string& capabilities,
+                                    const std::string& extensions) {
+  std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+                      extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %p "p"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%ptr = OpTypePointer )" + storage_class + R"( %int
+%func_ty = OpTypeFunction %void %ptr
+%caller = OpFunction %void None %func_ty
+%p = OpFunctionParameter %ptr
+%1 = OpLabel
+%call = OpFunctionCall %void %callee %p
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %func_ty
+%param = OpFunctionParameter %ptr
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  return spirv;
+}
+
+std::string GenerateShaderAccessChain(const std::string& storage_class,
+                                      const std::string& capabilities,
+                                      const std::string& extensions) {
+  std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+                      extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %var "var"
+OpName %gep "gep"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int2 = OpTypeVector %int 2
+%int_0 = OpConstant %int 0
+%ptr = OpTypePointer )" + storage_class + R"( %int2
+%ptr2 = OpTypePointer )" +
+                      storage_class + R"( %int
+%caller_ty = OpTypeFunction %void
+%callee_ty = OpTypeFunction %void %ptr2
+)";
+
+  if (storage_class != "Function") {
+    spirv += "%var = OpVariable %ptr " + storage_class;
+  }
+
+  spirv += R"(
+%caller = OpFunction %void None %caller_ty
+%1 = OpLabel
+)";
+
+  if (storage_class == "Function") {
+    spirv += "%var = OpVariable %ptr Function";
+  }
+
+  spirv += R"(
+%gep = OpAccessChain %ptr2 %var %int_0
+%call = OpFunctionCall %void %callee %gep
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %callee_ty
+%param = OpFunctionParameter %ptr2
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  return spirv;
+}
+
+TEST_P(ValidateFunctionCall, VariableNoVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShader(storage_class, "", "");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    if (storage_class == "StorageBuffer") {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("StorageBuffer pointer operand 1[%var] requires a "
+                            "variable pointers capability"));
+    } else {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+    }
+  }
+}
+
+TEST_P(ValidateFunctionCall, VariableVariablePointersStorageClass) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShader(
+      storage_class, "OpCapability VariablePointersStorageBuffer",
+      "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+  }
+}
+
+TEST_P(ValidateFunctionCall, VariableVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv =
+      GenerateShader(storage_class, "OpCapability VariablePointers",
+                     "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+  }
+}
+
+TEST_P(ValidateFunctionCall, ParameterNoVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShaderParameter(storage_class, "", "");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    if (storage_class == "StorageBuffer") {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("StorageBuffer pointer operand 1[%p] requires a "
+                            "variable pointers capability"));
+    } else {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+    }
+  }
+}
+
+TEST_P(ValidateFunctionCall, ParameterVariablePointersStorageBuffer) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShaderParameter(
+      storage_class, "OpCapability VariablePointersStorageBuffer",
+      "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+  }
+}
+
+TEST_P(ValidateFunctionCall, ParameterVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv =
+      GenerateShaderParameter(storage_class, "OpCapability VariablePointers",
+                              "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  if (valid) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    EXPECT_THAT(getDiagnosticString(),
+                HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+  }
+}
+
+TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationNoVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShaderAccessChain(storage_class, "", "");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+  bool valid_sc =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  if (valid_sc) {
+    EXPECT_THAT(
+        getDiagnosticString(),
+        HasSubstr(
+            "Pointer operand 2[%gep] must be a memory object declaration"));
+  } else {
+    if (storage_class == "StorageBuffer") {
+      EXPECT_THAT(getDiagnosticString(),
+                  HasSubstr("StorageBuffer pointer operand 2[%gep] requires a "
+                            "variable pointers capability"));
+    } else {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+    }
+  }
+}
+
+TEST_P(ValidateFunctionCall,
+       NonMemoryObjectDeclarationVariablePointersStorageBuffer) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv = GenerateShaderAccessChain(
+      storage_class, "OpCapability VariablePointersStorageBuffer",
+      "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid_sc =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+  bool validate = storage_class == "StorageBuffer";
+
+  CompileSuccessfully(spirv);
+  if (validate) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    if (valid_sc) {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr(
+              "Pointer operand 2[%gep] must be a memory object declaration"));
+    } else {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+    }
+  }
+}
+
+TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationVariablePointers) {
+  const std::string storage_class = GetParam();
+
+  std::string spirv =
+      GenerateShaderAccessChain(storage_class, "OpCapability VariablePointers",
+                                "OpExtension \"SPV_KHR_variable_pointers\"");
+
+  const std::vector<std::string> valid_storage_classes = {
+      "UniformConstant", "Function",      "Private",
+      "Workgroup",       "StorageBuffer", "AtomicCounter"};
+  bool valid_sc =
+      std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+                storage_class) != valid_storage_classes.end();
+  bool validate =
+      storage_class == "StorageBuffer" || storage_class == "Workgroup";
+
+  CompileSuccessfully(spirv);
+  if (validate) {
+    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+  } else {
+    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+    if (valid_sc) {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr(
+              "Pointer operand 2[%gep] must be a memory object declaration"));
+    } else {
+      EXPECT_THAT(
+          getDiagnosticString(),
+          HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+    }
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(StorageClass, ValidateFunctionCall,
+                         Values("UniformConstant", "Input", "Uniform", "Output",
+                                "Workgroup", "Private", "Function",
+                                "PushConstant", "Image", "StorageBuffer",
+                                "AtomicCounter"));
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 9910a09..df177f2 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -2274,6 +2274,7 @@
     *spirv << "OpCapability VariablePointers ";
     *spirv << "OpExtension \"SPV_KHR_variable_pointers\" ";
   }
+  *spirv << "OpExtension \"SPV_KHR_storage_buffer_storage_class\" ";
   *spirv << R"(
     OpMemoryModel Logical GLSL450
     OpEntryPoint GLCompute %main "main"
@@ -2282,12 +2283,12 @@
     %bool      = OpTypeBool
     %i32       = OpTypeInt 32 1
     %f32       = OpTypeFloat 32
-    %f32ptr    = OpTypePointer Uniform %f32
+    %f32ptr    = OpTypePointer StorageBuffer %f32
     %i         = OpConstant %i32 1
     %zero      = OpConstant %i32 0
     %float_1   = OpConstant %f32 1.0
-    %ptr1      = OpVariable %f32ptr Uniform
-    %ptr2      = OpVariable %f32ptr Uniform
+    %ptr1      = OpVariable %f32ptr StorageBuffer
+    %ptr2      = OpVariable %f32ptr StorageBuffer
   )";
   if (add_helper_function) {
     *spirv << R"(