Validate variable pointer related function call rules (#2270)
Fixes #2105
* Check storage class validity
* Check memory object declaration validity
diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp
index de41b27..96c4776 100644
--- a/source/val/validate_function.cpp
+++ b/source/val/validate_function.cpp
@@ -247,6 +247,49 @@
<< "'s type does not match Function <id> '"
<< _.getIdName(parameter_type_id) << "'s parameter type.";
}
+
+ if (_.addressing_model() == SpvAddressingModelLogical) {
+ if (parameter_type->opcode() == SpvOpTypePointer) {
+ SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
+ // Validate which storage classes can be pointer operands.
+ switch (sc) {
+ case SpvStorageClassUniformConstant:
+ case SpvStorageClassFunction:
+ case SpvStorageClassPrivate:
+ case SpvStorageClassWorkgroup:
+ case SpvStorageClassAtomicCounter:
+ // These are always allowed.
+ break;
+ case SpvStorageClassStorageBuffer:
+ if (!_.features().variable_pointers_storage_buffer) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "StorageBuffer pointer operand "
+ << _.getIdName(argument_id)
+ << " requires a variable pointers capability";
+ }
+ break;
+ default:
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Invalid storage class for pointer operand "
+ << _.getIdName(argument_id);
+ }
+
+ // Validate memory object declaration requirements.
+ if (argument->opcode() != SpvOpVariable &&
+ argument->opcode() != SpvOpFunctionParameter) {
+ const bool ssbo_vptr =
+ _.features().variable_pointers_storage_buffer &&
+ sc == SpvStorageClassStorageBuffer;
+ const bool wg_vptr =
+ _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
+ if (!ssbo_vptr && !wg_vptr) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "Pointer operand " << _.getIdName(argument_id)
+ << " must be a memory object declaration";
+ }
+ }
+ }
+ }
}
return SPV_SUCCESS;
}
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index b4fad28..9f538c2 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -51,8 +51,9 @@
PCH_FILE pch_test_val
)
-add_spvtools_unittest(TARGET val_ijklmnop
+add_spvtools_unittest(TARGET val_fghijklmnop
SRCS
+ val_function_test.cpp
val_id_test.cpp
val_image_test.cpp
val_interfaces_test.cpp
diff --git a/test/val/val_function_test.cpp b/test/val/val_function_test.cpp
new file mode 100644
index 0000000..f3dd15e
--- /dev/null
+++ b/test/val/val_function_test.cpp
@@ -0,0 +1,424 @@
+// Copyright (c) 2019 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+#include <tuple>
+
+#include "gmock/gmock.h"
+#include "test/test_fixture.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::Combine;
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateFunctionCall = spvtest::ValidateBase<std::string>;
+
+std::string GenerateShader(const std::string& storage_class,
+ const std::string& capabilities,
+ const std::string& extensions) {
+ std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+ extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %var "var"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%ptr = OpTypePointer )" + storage_class + R"( %int
+%caller_ty = OpTypeFunction %void
+%callee_ty = OpTypeFunction %void %ptr
+)";
+
+ if (storage_class != "Function") {
+ spirv += "%var = OpVariable %ptr " + storage_class;
+ }
+
+ spirv += R"(
+%caller = OpFunction %void None %caller_ty
+%1 = OpLabel
+)";
+
+ if (storage_class == "Function") {
+ spirv += "%var = OpVariable %ptr Function";
+ }
+
+ spirv += R"(
+%call = OpFunctionCall %void %callee %var
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %callee_ty
+%param = OpFunctionParameter %ptr
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ return spirv;
+}
+
+std::string GenerateShaderParameter(const std::string& storage_class,
+ const std::string& capabilities,
+ const std::string& extensions) {
+ std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+ extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %p "p"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%ptr = OpTypePointer )" + storage_class + R"( %int
+%func_ty = OpTypeFunction %void %ptr
+%caller = OpFunction %void None %func_ty
+%p = OpFunctionParameter %ptr
+%1 = OpLabel
+%call = OpFunctionCall %void %callee %p
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %func_ty
+%param = OpFunctionParameter %ptr
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ return spirv;
+}
+
+std::string GenerateShaderAccessChain(const std::string& storage_class,
+ const std::string& capabilities,
+ const std::string& extensions) {
+ std::string spirv = R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability AtomicStorage
+)" + capabilities + R"(
+OpExtension "SPV_KHR_storage_buffer_storage_class"
+)" +
+ extensions + R"(
+OpMemoryModel Logical GLSL450
+OpName %var "var"
+OpName %gep "gep"
+%void = OpTypeVoid
+%int = OpTypeInt 32 0
+%int2 = OpTypeVector %int 2
+%int_0 = OpConstant %int 0
+%ptr = OpTypePointer )" + storage_class + R"( %int2
+%ptr2 = OpTypePointer )" +
+ storage_class + R"( %int
+%caller_ty = OpTypeFunction %void
+%callee_ty = OpTypeFunction %void %ptr2
+)";
+
+ if (storage_class != "Function") {
+ spirv += "%var = OpVariable %ptr " + storage_class;
+ }
+
+ spirv += R"(
+%caller = OpFunction %void None %caller_ty
+%1 = OpLabel
+)";
+
+ if (storage_class == "Function") {
+ spirv += "%var = OpVariable %ptr Function";
+ }
+
+ spirv += R"(
+%gep = OpAccessChain %ptr2 %var %int_0
+%call = OpFunctionCall %void %callee %gep
+OpReturn
+OpFunctionEnd
+%callee = OpFunction %void None %callee_ty
+%param = OpFunctionParameter %ptr2
+%2 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ return spirv;
+}
+
+TEST_P(ValidateFunctionCall, VariableNoVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShader(storage_class, "", "");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ if (storage_class == "StorageBuffer") {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("StorageBuffer pointer operand 1[%var] requires a "
+ "variable pointers capability"));
+ } else {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+ }
+ }
+}
+
+TEST_P(ValidateFunctionCall, VariableVariablePointersStorageClass) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShader(
+ storage_class, "OpCapability VariablePointersStorageBuffer",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+ }
+}
+
+TEST_P(ValidateFunctionCall, VariableVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv =
+ GenerateShader(storage_class, "OpCapability VariablePointers",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%var]"));
+ }
+}
+
+TEST_P(ValidateFunctionCall, ParameterNoVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShaderParameter(storage_class, "", "");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ if (storage_class == "StorageBuffer") {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("StorageBuffer pointer operand 1[%p] requires a "
+ "variable pointers capability"));
+ } else {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+ }
+ }
+}
+
+TEST_P(ValidateFunctionCall, ParameterVariablePointersStorageBuffer) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShaderParameter(
+ storage_class, "OpCapability VariablePointersStorageBuffer",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+ }
+}
+
+TEST_P(ValidateFunctionCall, ParameterVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv =
+ GenerateShaderParameter(storage_class, "OpCapability VariablePointers",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ if (valid) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 1[%p]"));
+ }
+}
+
+TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationNoVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShaderAccessChain(storage_class, "", "");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
+ bool valid_sc =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+
+ CompileSuccessfully(spirv);
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ if (valid_sc) {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Pointer operand 2[%gep] must be a memory object declaration"));
+ } else {
+ if (storage_class == "StorageBuffer") {
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("StorageBuffer pointer operand 2[%gep] requires a "
+ "variable pointers capability"));
+ } else {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+ }
+ }
+}
+
+TEST_P(ValidateFunctionCall,
+ NonMemoryObjectDeclarationVariablePointersStorageBuffer) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv = GenerateShaderAccessChain(
+ storage_class, "OpCapability VariablePointersStorageBuffer",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid_sc =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+ bool validate = storage_class == "StorageBuffer";
+
+ CompileSuccessfully(spirv);
+ if (validate) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ if (valid_sc) {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Pointer operand 2[%gep] must be a memory object declaration"));
+ } else {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+ }
+ }
+}
+
+TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationVariablePointers) {
+ const std::string storage_class = GetParam();
+
+ std::string spirv =
+ GenerateShaderAccessChain(storage_class, "OpCapability VariablePointers",
+ "OpExtension \"SPV_KHR_variable_pointers\"");
+
+ const std::vector<std::string> valid_storage_classes = {
+ "UniformConstant", "Function", "Private",
+ "Workgroup", "StorageBuffer", "AtomicCounter"};
+ bool valid_sc =
+ std::find(valid_storage_classes.begin(), valid_storage_classes.end(),
+ storage_class) != valid_storage_classes.end();
+ bool validate =
+ storage_class == "StorageBuffer" || storage_class == "Workgroup";
+
+ CompileSuccessfully(spirv);
+ if (validate) {
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ if (valid_sc) {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "Pointer operand 2[%gep] must be a memory object declaration"));
+ } else {
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Invalid storage class for pointer operand 2[%gep]"));
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(StorageClass, ValidateFunctionCall,
+ Values("UniformConstant", "Input", "Uniform", "Output",
+ "Workgroup", "Private", "Function",
+ "PushConstant", "Image", "StorageBuffer",
+ "AtomicCounter"));
+} // namespace
+} // namespace val
+} // namespace spvtools
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 9910a09..df177f2 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -2274,6 +2274,7 @@
*spirv << "OpCapability VariablePointers ";
*spirv << "OpExtension \"SPV_KHR_variable_pointers\" ";
}
+ *spirv << "OpExtension \"SPV_KHR_storage_buffer_storage_class\" ";
*spirv << R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
@@ -2282,12 +2283,12 @@
%bool = OpTypeBool
%i32 = OpTypeInt 32 1
%f32 = OpTypeFloat 32
- %f32ptr = OpTypePointer Uniform %f32
+ %f32ptr = OpTypePointer StorageBuffer %f32
%i = OpConstant %i32 1
%zero = OpConstant %i32 0
%float_1 = OpConstant %f32 1.0
- %ptr1 = OpVariable %f32ptr Uniform
- %ptr2 = OpVariable %f32ptr Uniform
+ %ptr1 = OpVariable %f32ptr StorageBuffer
+ %ptr2 = OpVariable %f32ptr StorageBuffer
)";
if (add_helper_function) {
*spirv << R"(