Check for recursion in Vulkan and WebGPU entry points (#2161)
Fixes #2061
Fixes #2160
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index abc2f31..5d0c624 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -175,14 +175,19 @@
// capability is being used.
// * No function can be targeted by both an OpEntryPoint instruction and an
// OpFunctionCall instruction.
+//
+// Additionally enforces that entry points for Vulkan and WebGPU should not have
+// recursion.
spv_result_t ValidateEntryPoints(ValidationState_t& _) {
_.ComputeFunctionToEntryPointMapping();
+ _.ComputeRecursiveEntryPoints();
if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) {
return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
<< "No OpEntryPoint instruction was found. This is only allowed if "
"the Linkage capability is being used.";
}
+
for (const auto& entry_point : _.entry_points()) {
if (_.IsFunctionCallTarget(entry_point)) {
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
@@ -190,6 +195,17 @@
<< ") may not be targeted by both an OpEntryPoint instruction and "
"an OpFunctionCall instruction.";
}
+
+ // For Vulkan and WebGPU, the static function-call graph for an entry point
+ // must not contain cycles.
+ if (spvIsWebGPUEnv(_.context()->target_env) ||
+ spvIsVulkanEnv(_.context()->target_env)) {
+ if (_.recursive_entry_points().find(entry_point) !=
+ _.recursive_entry_points().end()) {
+ return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
+ << "Entry points may not have a call graph with cycles.";
+ }
+ }
}
return SPV_SUCCESS;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 2bab46b..a10186f 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -919,6 +919,39 @@
}
}
+void ValidationState_t::ComputeRecursiveEntryPoints() {
+ for (const Function func : functions()) {
+ std::stack<uint32_t> call_stack;
+ std::set<uint32_t> visited;
+
+ for (const uint32_t new_call : func.function_call_targets()) {
+ call_stack.push(new_call);
+ }
+
+ while (!call_stack.empty()) {
+ const uint32_t called_func_id = call_stack.top();
+ call_stack.pop();
+
+ if (!visited.insert(called_func_id).second) continue;
+
+ if (called_func_id == func.id()) {
+ for (const uint32_t entry_point :
+ function_to_entry_points_[called_func_id])
+ recursive_entry_points_.insert(entry_point);
+ break;
+ }
+
+ const Function* called_func = function(called_func_id);
+ if (called_func) {
+ // Other checks should error out on this invalid SPIR-V.
+ for (const uint32_t new_call : called_func->function_call_targets()) {
+ call_stack.push(new_call);
+ }
+ }
+ }
+ }
+}
+
const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
uint32_t func) const {
auto iter = function_to_entry_points_.find(func);
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index cbd9a34..85229f2 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -222,6 +222,12 @@
/// Returns a list of entry point function ids
const std::vector<uint32_t>& entry_points() const { return entry_points_; }
+ /// Returns the set of entry points that root call graphs that contain
+ /// recursion.
+ const std::set<uint32_t>& recursive_entry_points() const {
+ return recursive_entry_points_;
+ }
+
/// Registers execution mode for the given entry point.
void RegisterExecutionModeForEntryPoint(uint32_t entry_point,
SpvExecutionMode execution_mode) {
@@ -261,6 +267,11 @@
/// Note: called after fully parsing the binary.
void ComputeFunctionToEntryPointMapping();
+ /// Traverse call tree and computes recursive_entry_points_.
+ /// Note: called after fully parsing the binary and calling
+ /// ComputeFunctionToEntryPointMapping.
+ void ComputeRecursiveEntryPoints();
+
/// Returns all the entry points that can call |func|.
const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
@@ -610,6 +621,10 @@
std::unordered_map<uint32_t, std::vector<EntryPointDescription>>
entry_point_descriptions_;
+ /// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call
+ /// graph that recurses.
+ std::set<uint32_t> recursive_entry_points_;
+
/// Functions IDs that are target of OpFunctionCall.
std::unordered_set<uint32_t> function_call_targets_;
diff --git a/test/val/val_validation_state_test.cpp b/test/val/val_validation_state_test.cpp
index 68504c5..beaeeb0 100644
--- a/test/val/val_validation_state_test.cpp
+++ b/test/val/val_validation_state_test.cpp
@@ -29,11 +29,17 @@
using ValidationStateTest = spvtest::ValidateBase<bool>;
-const char header[] =
+const char kHeader[] =
" OpCapability Shader"
" OpCapability Linkage"
" OpMemoryModel Logical GLSL450 ";
+const char kVulkanMemoryHeader[] =
+ " OpCapability Shader"
+ " OpCapability VulkanMemoryModelKHR"
+ " OpExtension \"SPV_KHR_vulkan_memory_model\""
+ " OpMemoryModel Logical VulkanKHR ";
+
const char kVoidFVoid[] =
" %void = OpTypeVoid"
" %void_f = OpTypeFunction %void"
@@ -42,9 +48,79 @@
" OpReturn"
" OpFunctionEnd ";
+// k*RecursiveBody examples originally from test/opt/function_test.cpp
+const char* kNonRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+const char* kDirectlyRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+const char* kIndirectlyRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+%14 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
// Tests that the instruction count in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumInstructions) {
- std::string spirv = std::string(header) + "%int = OpTypeInt 32 0";
+ std::string spirv = std::string(kHeader) + "%int = OpTypeInt 32 0";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
@@ -52,7 +128,7 @@
// Tests that the number of global variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumGlobalVars) {
- std::string spirv = std::string(header) + R"(
+ std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Input %int
%var_1 = OpVariable %_ptr_int Input
@@ -65,7 +141,7 @@
// Tests that the number of local variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumLocalVars) {
- std::string spirv = std::string(header) + R"(
+ std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Function %int
%voidt = OpTypeVoid
@@ -85,7 +161,7 @@
// Tests that the "id bound" in ValidationState is correct.
TEST_F(ValidationStateTest, CheckIdBound) {
- std::string spirv = std::string(header) + R"(
+ std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%voidt = OpTypeVoid
)";
@@ -96,7 +172,7 @@
// Tests that the entry_points in ValidationState is correct.
TEST_F(ValidationStateTest, CheckEntryPoints) {
- std::string spirv = std::string(header) +
+ std::string spirv = std::string(kHeader) +
" OpEntryPoint Vertex %func \"shader\"" +
std::string(kVoidFVoid);
CompileSuccessfully(spirv);
@@ -154,6 +230,79 @@
EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes);
}
+TEST_F(ValidationStateTest, CheckNonRecursiveBodyGood) {
+ std::string spirv = std::string(kHeader) + kNonRecursiveBody;
+ CompileSuccessfully(spirv);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanNonRecursiveBodyGood) {
+ std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_SUCCESS,
+ ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUNonRecursiveBodyGood) {
+ std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+}
+
+TEST_F(ValidationStateTest, CheckDirectlyRecursiveBodyGood) {
+ std::string spirv = std::string(kHeader) + kDirectlyRecursiveBody;
+ CompileSuccessfully(spirv);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanDirectlyRecursiveBodyBad) {
+ std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+ ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Entry points may not have a call graph with cycles.\n "
+ " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUDirectlyRecursiveBodyBad) {
+ std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+ ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Entry points may not have a call graph with cycles.\n "
+ " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckIndirectlyRecursiveBodyGood) {
+ std::string spirv = std::string(kHeader) + kIndirectlyRecursiveBody;
+ CompileSuccessfully(spirv);
+ EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanIndirectlyRecursiveBodyBad) {
+ std::string spirv =
+ std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+ ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Entry points may not have a call graph with cycles.\n "
+ " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUIndirectlyRecursiveBodyBad) {
+ std::string spirv =
+ std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
+ CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+ EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+ ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Entry points may not have a call graph with cycles.\n "
+ " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
} // namespace
} // namespace val
} // namespace spvtools