Fixes #1480. Validate group non-uniform scopes.
* Adds new pass for validating non-uniform group instructions
* Currently on checks execution scope for Vulkan 1.1 and SPIR-V 1.3
* Added test framework
diff --git a/Android.mk b/Android.mk
index 30be9b7..f6ac829 100644
--- a/Android.mk
+++ b/Android.mk
@@ -54,6 +54,7 @@
source/validate_layout.cpp \
source/validate_literals.cpp \
source/validate_logicals.cpp \
+ source/validate_non_uniform.cpp \
source/validate_primitives.cpp \
source/validate_type_unique.cpp
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index d9f5955..df376ca 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -301,6 +301,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/validate_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validate_literals.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validate_logicals.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/validate_non_uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validate_primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/validate_type_unique.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
diff --git a/source/opcode.cpp b/source/opcode.cpp
index c73f14d..98c4bb9 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -454,3 +454,45 @@
return false;
}
}
+
+bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode) {
+ switch (opcode) {
+ case SpvOpGroupNonUniformElect:
+ case SpvOpGroupNonUniformAll:
+ case SpvOpGroupNonUniformAny:
+ case SpvOpGroupNonUniformAllEqual:
+ case SpvOpGroupNonUniformBroadcast:
+ case SpvOpGroupNonUniformBroadcastFirst:
+ case SpvOpGroupNonUniformBallot:
+ case SpvOpGroupNonUniformInverseBallot:
+ case SpvOpGroupNonUniformBallotBitExtract:
+ case SpvOpGroupNonUniformBallotBitCount:
+ case SpvOpGroupNonUniformBallotFindLSB:
+ case SpvOpGroupNonUniformBallotFindMSB:
+ case SpvOpGroupNonUniformShuffle:
+ case SpvOpGroupNonUniformShuffleXor:
+ case SpvOpGroupNonUniformShuffleUp:
+ case SpvOpGroupNonUniformShuffleDown:
+ case SpvOpGroupNonUniformIAdd:
+ case SpvOpGroupNonUniformFAdd:
+ case SpvOpGroupNonUniformIMul:
+ case SpvOpGroupNonUniformFMul:
+ case SpvOpGroupNonUniformSMin:
+ case SpvOpGroupNonUniformUMin:
+ case SpvOpGroupNonUniformFMin:
+ case SpvOpGroupNonUniformSMax:
+ case SpvOpGroupNonUniformUMax:
+ case SpvOpGroupNonUniformFMax:
+ case SpvOpGroupNonUniformBitwiseAnd:
+ case SpvOpGroupNonUniformBitwiseOr:
+ case SpvOpGroupNonUniformBitwiseXor:
+ case SpvOpGroupNonUniformLogicalAnd:
+ case SpvOpGroupNonUniformLogicalOr:
+ case SpvOpGroupNonUniformLogicalXor:
+ case SpvOpGroupNonUniformQuadBroadcast:
+ case SpvOpGroupNonUniformQuadSwap:
+ return true;
+ default:
+ return false;
+ }
+}
diff --git a/source/opcode.h b/source/opcode.h
index 9b58513..7aadf30 100644
--- a/source/opcode.h
+++ b/source/opcode.h
@@ -118,4 +118,7 @@
// Returns true if the given opcode always defines an opaque type.
bool spvOpcodeIsBaseOpaqueType(SpvOp opcode);
+
+// Returns true if the given opcode is a non-uniform group operation.
+bool spvOpcodeIsNonUniformGroupOperation(SpvOp opcode);
#endif // LIBSPIRV_OPCODE_H_
diff --git a/source/validate.cpp b/source/validate.cpp
index ea73004..953aad1 100644
--- a/source/validate.cpp
+++ b/source/validate.cpp
@@ -189,6 +189,7 @@
if (auto error = BarriersPass(_, inst)) return error;
if (auto error = PrimitivesPass(_, inst)) return error;
if (auto error = LiteralsPass(_, inst)) return error;
+ if (auto error = NonUniformPass(_, inst)) return error;
return SPV_SUCCESS;
}
diff --git a/source/validate.h b/source/validate.h
index a4f6dde..983b30d 100644
--- a/source/validate.h
+++ b/source/validate.h
@@ -170,6 +170,10 @@
spv_result_t ExtInstPass(ValidationState_t& _,
const spv_parsed_instruction_t* inst);
+/// Validates correctness of non-uniform group instructions.
+spv_result_t NonUniformPass(ValidationState_t& _,
+ const spv_parsed_instruction_t* inst);
+
// Validates that capability declarations use operands allowed in the current
// context.
spv_result_t CapabilityPass(ValidationState_t& _,
diff --git a/source/validate_non_uniform.cpp b/source/validate_non_uniform.cpp
new file mode 100644
index 0000000..66c2b42
--- /dev/null
+++ b/source/validate_non_uniform.cpp
@@ -0,0 +1,84 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates correctness of barrier SPIR-V instructions.
+
+#include "validate.h"
+
+#include "diagnostic.h"
+#include "opcode.h"
+#include "spirv_constant.h"
+#include "spirv_target_env.h"
+#include "util/bitutils.h"
+#include "val/instruction.h"
+#include "val/validation_state.h"
+
+namespace libspirv {
+
+namespace {
+
+spv_result_t ValidateExecutionScope(ValidationState_t& _,
+ const spv_parsed_instruction_t* inst,
+ uint32_t scope) {
+ SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+ bool is_int32 = false, is_const_int32 = false;
+ uint32_t value = 0;
+ std::tie(is_int32, is_const_int32, value) = _.EvalInt32IfConst(scope);
+
+ if (!is_int32) {
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << spvOpcodeString(opcode)
+ << ": expected Execution Scope to be a 32-bit int";
+ }
+
+ if (!is_const_int32) {
+ return SPV_SUCCESS;
+ }
+
+ if (spvIsVulkanEnv(_.context()->target_env) &&
+ _.context()->target_env != SPV_ENV_VULKAN_1_0 &&
+ value != SpvScopeSubgroup) {
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << spvOpcodeString(opcode)
+ << ": in Vulkan environment Execution scope is limited to "
+ "Subgroup";
+ }
+
+ if (value != SpvScopeSubgroup && value != SpvScopeWorkgroup) {
+ return _.diag(SPV_ERROR_INVALID_DATA) << spvOpcodeString(opcode)
+ << ": Execution scope is limited to "
+ "Subgroup or Workgroup";
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace
+
+// Validates correctness of non-uniform group instructions.
+spv_result_t NonUniformPass(ValidationState_t& _,
+ const spv_parsed_instruction_t* inst) {
+ const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
+ if (spvOpcodeIsNonUniformGroupOperation(opcode)) {
+ const uint32_t execution_scope = inst->words[3];
+ if (auto error = ValidateExecutionScope(_, inst, execution_scope)) {
+ return error;
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace libspirv
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 093a04a..86d5470 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -186,3 +186,9 @@
${VAL_TEST_COMMON_SRCS}
LIBS ${SPIRV_TOOLS}
)
+
+add_spvtools_unittest(TARGET val_non_uniform
+ SRCS val_non_uniform_test.cpp
+ ${VAL_TEST_COMMON_SRCS}
+ LIBS ${SPIRV_TOOLS}
+)
diff --git a/test/val/val_non_uniform_test.cpp b/test/val/val_non_uniform_test.cpp
new file mode 100644
index 0000000..1548f8f
--- /dev/null
+++ b/test/val/val_non_uniform_test.cpp
@@ -0,0 +1,247 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "unit_spirv.h"
+#include "val_fixtures.h"
+
+namespace {
+
+using ::testing::Combine;
+using ::testing::HasSubstr;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+std::string GenerateShaderCode(
+ const std::string& body,
+ const std::string& capabilities_and_extensions = "",
+ const std::string& execution_model = "GLCompute") {
+ std::ostringstream ss;
+ ss << R"(
+OpCapability Shader
+OpCapability GroupNonUniform
+OpCapability GroupNonUniformVote
+OpCapability GroupNonUniformBallot
+OpCapability GroupNonUniformShuffle
+OpCapability GroupNonUniformShuffleRelative
+OpCapability GroupNonUniformArithmetic
+OpCapability GroupNonUniformClustered
+OpCapability GroupNonUniformQuad
+)";
+
+ ss << capabilities_and_extensions;
+ ss << "OpMemoryModel Logical GLSL450\n";
+ ss << "OpEntryPoint " << execution_model << " %main \"main\"\n";
+
+ ss << R"(
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%bool = OpTypeBool
+%u32 = OpTypeInt 32 0
+%float = OpTypeFloat 32
+%u32vec4 = OpTypeVector %u32 4
+
+%true = OpConstantTrue %bool
+%false = OpConstantFalse %bool
+
+%u32_0 = OpConstant %u32 0
+
+%float_0 = OpConstant %float 0
+
+%u32vec4_null = OpConstantComposite %u32vec4 %u32_0 %u32_0 %u32_0 %u32_0
+
+%cross_device = OpConstant %u32 0
+%device = OpConstant %u32 1
+%workgroup = OpConstant %u32 2
+%subgroup = OpConstant %u32 3
+%invocation = OpConstant %u32 4
+
+%reduce = OpConstant %u32 0
+%inclusive_scan = OpConstant %u32 1
+%exclusive_scan = OpConstant %u32 2
+%clustered_reduce = OpConstant %u32 3
+
+%main = OpFunction %void None %func
+%main_entry = OpLabel
+)";
+
+ ss << body;
+
+ ss << R"(
+OpReturn
+OpFunctionEnd)";
+
+ return ss.str();
+}
+
+SpvScope scopes[] = {SpvScopeCrossDevice, SpvScopeDevice, SpvScopeWorkgroup,
+ SpvScopeSubgroup, SpvScopeInvocation};
+
+using GroupNonUniformScope = spvtest::ValidateBase<
+ std::tuple<std::string, std::string, SpvScope, std::string>>;
+
+std::string ConvertScope(SpvScope scope) {
+ switch (scope) {
+ case SpvScopeCrossDevice:
+ return "%cross_device";
+ case SpvScopeDevice:
+ return "%device";
+ case SpvScopeWorkgroup:
+ return "%workgroup";
+ case SpvScopeSubgroup:
+ return "%subgroup";
+ case SpvScopeInvocation:
+ return "%invocation";
+ default:
+ return "";
+ }
+}
+
+TEST_P(GroupNonUniformScope, Vulkan1p1) {
+ std::string opcode = std::get<0>(GetParam());
+ std::string type = std::get<1>(GetParam());
+ SpvScope execution_scope = std::get<2>(GetParam());
+ std::string args = std::get<3>(GetParam());
+
+ std::ostringstream sstr;
+ sstr << "%result = " << opcode << " ";
+ sstr << type << " ";
+ sstr << ConvertScope(execution_scope) << " ";
+ sstr << args << "\n";
+
+ CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_VULKAN_1_1);
+ spv_result_t result = ValidateInstructions(SPV_ENV_VULKAN_1_1);
+ if (execution_scope == SpvScopeSubgroup) {
+ EXPECT_EQ(SPV_SUCCESS, result);
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "in Vulkan environment Execution scope is limited to Subgroup"));
+ }
+}
+
+TEST_P(GroupNonUniformScope, Spirv1p3) {
+ std::string opcode = std::get<0>(GetParam());
+ std::string type = std::get<1>(GetParam());
+ SpvScope execution_scope = std::get<2>(GetParam());
+ std::string args = std::get<3>(GetParam());
+
+ std::ostringstream sstr;
+ sstr << "%result = " << opcode << " ";
+ sstr << type << " ";
+ sstr << ConvertScope(execution_scope) << " ";
+ sstr << args << "\n";
+
+ CompileSuccessfully(GenerateShaderCode(sstr.str()), SPV_ENV_UNIVERSAL_1_3);
+ spv_result_t result = ValidateInstructions(SPV_ENV_UNIVERSAL_1_3);
+ if (execution_scope == SpvScopeSubgroup ||
+ execution_scope == SpvScopeWorkgroup) {
+ EXPECT_EQ(SPV_SUCCESS, result);
+ } else {
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, result);
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Execution scope is limited to Subgroup or Workgroup"));
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformElect, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformElect"),
+ Values("%bool"), ValuesIn(scopes), Values("")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformVote, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformAll",
+ "OpGroupNonUniformAny",
+ "OpGroupNonUniformAllEqual"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("%true")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcast, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBroadcast"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("%true %u32_0")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBroadcastFirst, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBroadcastFirst"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("%true")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallot, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBallot"),
+ Values("%u32vec4"), ValuesIn(scopes),
+ Values("%true")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformInverseBallot, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformInverseBallot"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("%u32vec4_null")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitExtract, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBallotBitExtract"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("%u32vec4_null %u32_0")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotBitCount, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBallotBitCount"),
+ Values("%u32"), ValuesIn(scopes),
+ Values("Reduce %u32vec4_null")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformBallotFind, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformBallotFindLSB",
+ "OpGroupNonUniformBallotFindMSB"),
+ Values("%u32"), ValuesIn(scopes),
+ Values("%u32vec4_null")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformShuffle, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformShuffle",
+ "OpGroupNonUniformShuffleXor",
+ "OpGroupNonUniformShuffleUp",
+ "OpGroupNonUniformShuffleDown"),
+ Values("%u32"), ValuesIn(scopes),
+ Values("%u32_0 %u32_0")));
+
+INSTANTIATE_TEST_CASE_P(
+ GroupNonUniformIntegerArithmetic, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformIAdd", "OpGroupNonUniformIMul",
+ "OpGroupNonUniformSMin", "OpGroupNonUniformUMin",
+ "OpGroupNonUniformSMax", "OpGroupNonUniformUMax",
+ "OpGroupNonUniformBitwiseAnd", "OpGroupNonUniformBitwiseOr",
+ "OpGroupNonUniformBitwiseXor"),
+ Values("%u32"), ValuesIn(scopes), Values("Reduce %u32_0")));
+
+INSTANTIATE_TEST_CASE_P(
+ GroupNonUniformFloatArithmetic, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformFAdd", "OpGroupNonUniformFMul",
+ "OpGroupNonUniformFMin", "OpGroupNonUniformFMax"),
+ Values("%float"), ValuesIn(scopes), Values("Reduce %float_0")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformLogicalArithmetic, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformLogicalAnd",
+ "OpGroupNonUniformLogicalOr",
+ "OpGroupNonUniformLogicalXor"),
+ Values("%bool"), ValuesIn(scopes),
+ Values("Reduce %true")));
+
+INSTANTIATE_TEST_CASE_P(GroupNonUniformQuad, GroupNonUniformScope,
+ Combine(Values("OpGroupNonUniformQuadBroadcast",
+ "OpGroupNonUniformQuadSwap"),
+ Values("%u32"), ValuesIn(scopes),
+ Values("%u32_0 %u32_0")));
+
+} // anonymous namespace