Check for recursion in Vulkan and WebGPU entry points (#2161)

Fixes #2061
Fixes #2160
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index abc2f31..5d0c624 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -175,14 +175,19 @@
 //   capability is being used.
 // * No function can be targeted by both an OpEntryPoint instruction and an
 //   OpFunctionCall instruction.
+//
+// Additionally enforces that entry points for Vulkan and WebGPU should not have
+// recursion.
 spv_result_t ValidateEntryPoints(ValidationState_t& _) {
   _.ComputeFunctionToEntryPointMapping();
+  _.ComputeRecursiveEntryPoints();
 
   if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) {
     return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
            << "No OpEntryPoint instruction was found. This is only allowed if "
               "the Linkage capability is being used.";
   }
+
   for (const auto& entry_point : _.entry_points()) {
     if (_.IsFunctionCallTarget(entry_point)) {
       return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
@@ -190,6 +195,17 @@
              << ") may not be targeted by both an OpEntryPoint instruction and "
                 "an OpFunctionCall instruction.";
     }
+
+    // For Vulkan and WebGPU, the static function-call graph for an entry point
+    // must not contain cycles.
+    if (spvIsWebGPUEnv(_.context()->target_env) ||
+        spvIsVulkanEnv(_.context()->target_env)) {
+      if (_.recursive_entry_points().find(entry_point) !=
+          _.recursive_entry_points().end()) {
+        return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
+               << "Entry points may not have a call graph with cycles.";
+      }
+    }
   }
 
   return SPV_SUCCESS;
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 2bab46b..a10186f 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -919,6 +919,39 @@
   }
 }
 
+void ValidationState_t::ComputeRecursiveEntryPoints() {
+  for (const Function func : functions()) {
+    std::stack<uint32_t> call_stack;
+    std::set<uint32_t> visited;
+
+    for (const uint32_t new_call : func.function_call_targets()) {
+      call_stack.push(new_call);
+    }
+
+    while (!call_stack.empty()) {
+      const uint32_t called_func_id = call_stack.top();
+      call_stack.pop();
+
+      if (!visited.insert(called_func_id).second) continue;
+
+      if (called_func_id == func.id()) {
+        for (const uint32_t entry_point :
+             function_to_entry_points_[called_func_id])
+          recursive_entry_points_.insert(entry_point);
+        break;
+      }
+
+      const Function* called_func = function(called_func_id);
+      if (called_func) {
+        // Other checks should error out on this invalid SPIR-V.
+        for (const uint32_t new_call : called_func->function_call_targets()) {
+          call_stack.push(new_call);
+        }
+      }
+    }
+  }
+}
+
 const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
     uint32_t func) const {
   auto iter = function_to_entry_points_.find(func);
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index cbd9a34..85229f2 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -222,6 +222,12 @@
   /// Returns a list of entry point function ids
   const std::vector<uint32_t>& entry_points() const { return entry_points_; }
 
+  /// Returns the set of entry points that root call graphs that contain
+  /// recursion.
+  const std::set<uint32_t>& recursive_entry_points() const {
+    return recursive_entry_points_;
+  }
+
   /// Registers execution mode for the given entry point.
   void RegisterExecutionModeForEntryPoint(uint32_t entry_point,
                                           SpvExecutionMode execution_mode) {
@@ -261,6 +267,11 @@
   /// Note: called after fully parsing the binary.
   void ComputeFunctionToEntryPointMapping();
 
+  /// Traverse call tree and computes recursive_entry_points_.
+  /// Note: called after fully parsing the binary and calling
+  /// ComputeFunctionToEntryPointMapping.
+  void ComputeRecursiveEntryPoints();
+
   /// Returns all the entry points that can call |func|.
   const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
 
@@ -610,6 +621,10 @@
   std::unordered_map<uint32_t, std::vector<EntryPointDescription>>
       entry_point_descriptions_;
 
+  /// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call
+  /// graph that recurses.
+  std::set<uint32_t> recursive_entry_points_;
+
   /// Functions IDs that are target of OpFunctionCall.
   std::unordered_set<uint32_t> function_call_targets_;
 
diff --git a/test/val/val_validation_state_test.cpp b/test/val/val_validation_state_test.cpp
index 68504c5..beaeeb0 100644
--- a/test/val/val_validation_state_test.cpp
+++ b/test/val/val_validation_state_test.cpp
@@ -29,11 +29,17 @@
 
 using ValidationStateTest = spvtest::ValidateBase<bool>;
 
-const char header[] =
+const char kHeader[] =
     " OpCapability Shader"
     " OpCapability Linkage"
     " OpMemoryModel Logical GLSL450 ";
 
+const char kVulkanMemoryHeader[] =
+    " OpCapability Shader"
+    " OpCapability VulkanMemoryModelKHR"
+    " OpExtension \"SPV_KHR_vulkan_memory_model\""
+    " OpMemoryModel Logical VulkanKHR ";
+
 const char kVoidFVoid[] =
     " %void   = OpTypeVoid"
     " %void_f = OpTypeFunction %void"
@@ -42,9 +48,79 @@
     "           OpReturn"
     "           OpFunctionEnd ";
 
+// k*RecursiveBody examples originally from test/opt/function_test.cpp
+const char* kNonRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+const char* kDirectlyRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+const char* kIndirectlyRecursiveBody = R"(
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+%14 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
 // Tests that the instruction count in ValidationState is correct.
 TEST_F(ValidationStateTest, CheckNumInstructions) {
-  std::string spirv = std::string(header) + "%int = OpTypeInt 32 0";
+  std::string spirv = std::string(kHeader) + "%int = OpTypeInt 32 0";
   CompileSuccessfully(spirv);
   EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
   EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
@@ -52,7 +128,7 @@
 
 // Tests that the number of global variables in ValidationState is correct.
 TEST_F(ValidationStateTest, CheckNumGlobalVars) {
-  std::string spirv = std::string(header) + R"(
+  std::string spirv = std::string(kHeader) + R"(
      %int = OpTypeInt 32 0
 %_ptr_int = OpTypePointer Input %int
    %var_1 = OpVariable %_ptr_int Input
@@ -65,7 +141,7 @@
 
 // Tests that the number of local variables in ValidationState is correct.
 TEST_F(ValidationStateTest, CheckNumLocalVars) {
-  std::string spirv = std::string(header) + R"(
+  std::string spirv = std::string(kHeader) + R"(
  %int      = OpTypeInt 32 0
  %_ptr_int = OpTypePointer Function %int
  %voidt    = OpTypeVoid
@@ -85,7 +161,7 @@
 
 // Tests that the "id bound" in ValidationState is correct.
 TEST_F(ValidationStateTest, CheckIdBound) {
-  std::string spirv = std::string(header) + R"(
+  std::string spirv = std::string(kHeader) + R"(
  %int      = OpTypeInt 32 0
  %voidt    = OpTypeVoid
   )";
@@ -96,7 +172,7 @@
 
 // Tests that the entry_points in ValidationState is correct.
 TEST_F(ValidationStateTest, CheckEntryPoints) {
-  std::string spirv = std::string(header) +
+  std::string spirv = std::string(kHeader) +
                       " OpEntryPoint Vertex %func \"shader\"" +
                       std::string(kVoidFVoid);
   CompileSuccessfully(spirv);
@@ -154,6 +230,79 @@
   EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes);
 }
 
+TEST_F(ValidationStateTest, CheckNonRecursiveBodyGood) {
+  std::string spirv = std::string(kHeader) + kNonRecursiveBody;
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanNonRecursiveBodyGood) {
+  std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_SUCCESS,
+            ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUNonRecursiveBodyGood) {
+  std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+}
+
+TEST_F(ValidationStateTest, CheckDirectlyRecursiveBodyGood) {
+  std::string spirv = std::string(kHeader) + kDirectlyRecursiveBody;
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanDirectlyRecursiveBodyBad) {
+  std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Entry points may not have a call graph with cycles.\n "
+                        " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUDirectlyRecursiveBodyBad) {
+  std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Entry points may not have a call graph with cycles.\n "
+                        " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckIndirectlyRecursiveBodyGood) {
+  std::string spirv = std::string(kHeader) + kIndirectlyRecursiveBody;
+  CompileSuccessfully(spirv);
+  EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
+}
+
+TEST_F(ValidationStateTest, CheckVulkanIndirectlyRecursiveBodyBad) {
+  std::string spirv =
+      std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Entry points may not have a call graph with cycles.\n "
+                        " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
+TEST_F(ValidationStateTest, CheckWebGPUIndirectlyRecursiveBodyBad) {
+  std::string spirv =
+      std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
+  CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
+  EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
+            ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Entry points may not have a call graph with cycles.\n "
+                        " %1 = OpFunction %void Pure|Const %3\n"));
+}
+
 }  // namespace
 }  // namespace val
 }  // namespace spvtools