Explicitly detect static recursion in SkSL

This relaxes our rules to allow calls to declared (but not yet defined)
functions. With that rule change, we have to specifically detect static
recursion and produce an error.

Bug: skia:12137
Change-Id: I39cc281fcd73fb30014bc7b43043552623727e03
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/431537
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/gn/sksl_tests.gni b/gn/sksl_tests.gni
index cd107b5..e118aa9 100644
--- a/gn/sksl_tests.gni
+++ b/gn/sksl_tests.gni
@@ -529,7 +529,9 @@
   "/sksl/runtime_errors/IllegalArrayOps.rts",
   "/sksl/runtime_errors/IllegalIndexing.rts",
   "/sksl/runtime_errors/IllegalOperators.rts",
-  "/sksl/runtime_errors/IllegalRecursion.rts",
+  "/sksl/runtime_errors/IllegalRecursionComplex.rts",
+  "/sksl/runtime_errors/IllegalRecursionMutual.rts",
+  "/sksl/runtime_errors/IllegalRecursionSimple.rts",
   "/sksl/runtime_errors/IllegalShaderUse.rts",
   "/sksl/runtime_errors/IllegalStatements.rts",
   "/sksl/runtime_errors/InvalidBlendMain.rtb",
diff --git a/resources/sksl/runtime_errors/IllegalRecursion.rts b/resources/sksl/runtime_errors/IllegalRecursion.rts
deleted file mode 100644
index ba14cf8..0000000
--- a/resources/sksl/runtime_errors/IllegalRecursion.rts
+++ /dev/null
@@ -1,14 +0,0 @@
-// Expect 3 errors
-
-// TODO(skia:12137) Today, we detect these as errors because we do not allow calls to undefined
-// functions. That produces three errors (one for each function calling an undefined function).
-// After we support calling declared (but not defined) functions, we should instead emit one
-// error per cycle.
-
-// Simple recursion is not allowed, even with branching:
-int fibonacci(int n) { return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2); }
-
-// We also detect more complex cycles in the call-graph of functions:
-bool is_even(int n);
-bool is_odd (int n) { return n == 0 ? false : is_even(n - 1); }
-bool is_even(int n) { return n == 0 ? true  : is_odd (n - 1); }
diff --git a/resources/sksl/runtime_errors/IllegalRecursionComplex.rts b/resources/sksl/runtime_errors/IllegalRecursionComplex.rts
new file mode 100644
index 0000000..f98a6d6
--- /dev/null
+++ b/resources/sksl/runtime_errors/IllegalRecursionComplex.rts
@@ -0,0 +1,34 @@
+// Expect 1 errors (with f_one(int), f_two, f_three in cycle)
+
+// Complex recursion spanning several functions with overloads, etc.
+
+void f_one(bool b);
+void f_one(int n);
+void f_two(int n);
+void f_three(int n);
+void f_four(int n);
+
+void f_one(bool b) {
+    int n = b ? 1 : 0;
+    f_one(n);
+}
+
+void f_one(int n) {
+    if (n > 0) {
+        f_four(n);
+    } else {
+        f_two(n);
+    }
+}
+
+void f_two(int n) {
+    for (int i = 0; i < 4; ++i) {
+        f_three(n);
+    }
+}
+
+void f_three(int n) {
+    f_one(n);
+}
+
+void f_four(int n) {}
diff --git a/resources/sksl/runtime_errors/IllegalRecursionMutual.rts b/resources/sksl/runtime_errors/IllegalRecursionMutual.rts
new file mode 100644
index 0000000..de43582
--- /dev/null
+++ b/resources/sksl/runtime_errors/IllegalRecursionMutual.rts
@@ -0,0 +1,6 @@
+// Expect 1 error
+
+// Straightforward mutual recursion (not allowed)
+bool is_even(int n);
+bool is_odd (int n) { return n == 0 ? false : is_even(n - 1); }
+bool is_even(int n) { return n == 0 ? true  : is_odd (n - 1); }
diff --git a/resources/sksl/runtime_errors/IllegalRecursionSimple.rts b/resources/sksl/runtime_errors/IllegalRecursionSimple.rts
new file mode 100644
index 0000000..cdffbc3
--- /dev/null
+++ b/resources/sksl/runtime_errors/IllegalRecursionSimple.rts
@@ -0,0 +1,4 @@
+// Expect 1 error
+
+// Simple recursion is not allowed, even with branching:
+int fibonacci(int n) { return n <= 1 ? n : fibonacci(n - 1) + fibonacci(n - 2); }
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index dcfb88a..06fed37 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -619,6 +619,122 @@
     return visitor.visit(program);
 }
 
+bool Analysis::DetectStaticRecursion(SkSpan<std::unique_ptr<ProgramElement>> programElements,
+                                     ErrorReporter& errors) {
+    using Function = const FunctionDeclaration;
+    using CallSet = std::unordered_set<Function*>;
+    using CallGraph = std::unordered_map<Function*, CallSet>;
+
+    class CallGraphVisitor : public ProgramVisitor {
+    public:
+        CallGraphVisitor(CallGraph* calls) : fCallGraph(calls), fCurrentFunctionCalls(nullptr) {}
+
+        bool visitExpression(const Expression& e) override {
+            if (e.is<FunctionCall>()) {
+                fCurrentFunctionCalls->insert(&e.as<FunctionCall>().function());
+            }
+            return INHERITED::visitExpression(e);
+        }
+
+        bool visitProgramElement(const ProgramElement& p) override {
+            if (p.is<FunctionDefinition>()) {
+                Function* fn = &p.as<FunctionDefinition>().declaration();
+                SkASSERT(fCallGraph->count(fn) == 0);
+
+                SkASSERT(fCurrentFunctionCalls == nullptr);
+                CallSet currentFunctionCalls;
+                fCurrentFunctionCalls = &currentFunctionCalls;
+
+                INHERITED::visitProgramElement(p);
+
+                fCurrentFunctionCalls = nullptr;
+                fCallGraph->insert({fn, std::move(currentFunctionCalls)});
+            }
+            return false;
+        }
+
+        CallGraph* fCallGraph;
+        CallSet*   fCurrentFunctionCalls;
+
+        using INHERITED = ProgramVisitor;
+    };
+
+    CallGraph callGraph;
+    CallGraphVisitor visitor{&callGraph};
+    for (const auto& pe : programElements) {
+        visitor.visitProgramElement(*pe);
+    }
+
+    class CycleFinder {
+    public:
+        CycleFinder(CallGraph* calls) : fCallGraph(calls) {}
+
+        bool containsCycle() {
+            for (const auto& [caller, callees] : *fCallGraph) {
+                SkASSERT(fStack.empty());
+                if (this->dfsHelper(caller)) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        const std::vector<Function*>& cycle() const { return fStack; }
+
+    private:
+        bool dfsHelper(Function* fn) {
+            SkASSERT(std::find(fStack.begin(), fStack.end(), fn) == fStack.end());
+            fStack.push_back(fn);
+
+            const CallSet& calls = (*fCallGraph)[fn];
+            for (Function* calledFn : calls) {
+                auto it = std::find(fStack.begin(), fStack.end(), calledFn);
+                if (it != fStack.end()) {
+                    // Cycle detected. It includes the functions from 'it' to the end of fStack
+                    fStack.erase(fStack.begin(), it);
+                    return true;
+                }
+                if (this->dfsHelper(calledFn)) {
+                    return true;
+                }
+            }
+
+            fStack.pop_back();
+            return false;
+        }
+
+        CallGraph*             fCallGraph;
+        std::vector<Function*> fStack;
+    };
+
+    CycleFinder cycleFinder{&callGraph};
+    if (cycleFinder.containsCycle()) {
+        // Get the description of each function participating in the cycle
+        std::vector<String> fnNames;
+        for (Function* fn : cycleFinder.cycle()) {
+            fnNames.push_back(fn->description());
+        }
+
+        // Find the lexicographically first function description, so we generate stable errors
+        std::vector<String>::iterator cycleStart = std::min_element(fnNames.begin(), fnNames.end());
+        ptrdiff_t startIndex = std::distance(fnNames.begin(), cycleStart);
+
+        // Construct a list of the functions participating in the cycle (including the "start"
+        // at both the beginning and end):
+        String cycleDescription;
+        for (size_t i = 0; i <= fnNames.size(); ++i) {
+            cycleDescription += "\n\t" + fnNames[(i + startIndex) % fnNames.size()];
+        }
+
+        // Go back to the original data to find the offset of the cycle start's declaration
+        Function* cycleStartFn = cycleFinder.cycle()[startIndex];
+        errors.error(cycleStartFn->fOffset,
+                     "potential recursion (function call cycle) not allowed:" + cycleDescription);
+        return true;
+    }
+    return false;
+}
+
 int Analysis::NodeCountUpToLimit(const FunctionDefinition& function, int limit) {
     return NodeCountVisitor{limit}.visit(*function.body());
 }
diff --git a/src/sksl/SkSLAnalysis.h b/src/sksl/SkSLAnalysis.h
index 596bfcb..c2126db 100644
--- a/src/sksl/SkSLAnalysis.h
+++ b/src/sksl/SkSLAnalysis.h
@@ -8,6 +8,7 @@
 #ifndef SkSLAnalysis_DEFINED
 #define SkSLAnalysis_DEFINED
 
+#include "include/core/SkSpan.h"
 #include "include/private/SkSLDefines.h"
 #include "include/private/SkSLSampleUsage.h"
 
@@ -52,6 +53,12 @@
 
     static bool CallsSampleOutsideMain(const Program& program);
 
+    /*
+     * Does the function call graph of the program include any cycles? If so, emits an error.
+     */
+    static bool DetectStaticRecursion(SkSpan<std::unique_ptr<ProgramElement>> programElements,
+                                      ErrorReporter& errors);
+
     static int NodeCountUpToLimit(const FunctionDefinition& function, int limit);
 
     /**
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index e8392b9..0f4bd7e 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1687,6 +1687,10 @@
         }
     }
 
+    if (this->strictES2Mode()) {
+        Analysis::DetectStaticRecursion(SkMakeSpan(*fProgramElements), this->errorReporter());
+    }
+
     return IRBundle{std::move(*fProgramElements),
                     std::move(*fSharedElements),
                     std::move(fSymbolTable),
diff --git a/src/sksl/codegen/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/codegen/SkSLPipelineStageCodeGenerator.cpp
index bcd8035..c6e2c90 100644
--- a/src/sksl/codegen/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLPipelineStageCodeGenerator.cpp
@@ -21,6 +21,7 @@
 #include "src/sksl/ir/SkSLFunctionCall.h"
 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
 #include "src/sksl/ir/SkSLFunctionDefinition.h"
+#include "src/sksl/ir/SkSLFunctionPrototype.h"
 #include "src/sksl/ir/SkSLIfStatement.h"
 #include "src/sksl/ir/SkSLIndexExpression.h"
 #include "src/sksl/ir/SkSLPostfixExpression.h"
@@ -63,7 +64,9 @@
     String typeName(const Type& type);
     void writeType(const Type& type);
 
+    String functionName(const FunctionDeclaration& decl);
     void writeFunction(const FunctionDefinition& f);
+    void writeFunctionPrototype(const FunctionPrototype& f);
 
     String modifierString(const Modifiers& modifiers);
 
@@ -251,6 +254,21 @@
     this->write(";");
 }
 
+String PipelineStageCodeGenerator::functionName(const FunctionDeclaration& decl) {
+    if (decl.isMain()) {
+        return String(decl.name());
+    }
+
+    auto it = fFunctionNames.find(&decl);
+    if (it != fFunctionNames.end()) {
+        return it->second;
+    }
+
+    String mangledName = fCallbacks->getMangledName(String(decl.name()).c_str());
+    fFunctionNames.insert({&decl, mangledName});
+    return mangledName;
+}
+
 void PipelineStageCodeGenerator::writeFunction(const FunctionDefinition& f) {
     AutoOutputBuffer body(this);
 
@@ -273,8 +291,7 @@
         fCastReturnsToHalf = false;
     }
 
-    String fnName = decl.isMain() ? String(decl.name())
-                                  : fCallbacks->getMangledName(String(decl.name()).c_str());
+    String fnName = this->functionName(decl);
 
     // This is similar to decl.description(), but substitutes a mangled name, and handles modifiers
     // on the function (e.g. `inline`) and its parameters (e.g. `inout`).
@@ -296,10 +313,14 @@
     }
     declString.append(")");
 
-    fFunctionNames.insert({&decl, std::move(fnName)});
     fCallbacks->defineFunction(declString.c_str(), body.fBuffer.str().c_str(), decl.isMain());
 }
 
+void PipelineStageCodeGenerator::writeFunctionPrototype(const FunctionPrototype& f) {
+    const FunctionDeclaration& decl = f.declaration();
+    (void)this->functionName(decl);
+}
+
 void PipelineStageCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& g) {
     const VarDeclaration& decl = g.declaration()->as<VarDeclaration>();
     const Variable& var = decl.var();
@@ -347,8 +368,7 @@
             this->writeFunction(e.as<FunctionDefinition>());
             break;
         case ProgramElement::Kind::kFunctionPrototype:
-            // Runtime effects don't allow calls to undefined functions, so prototypes are never
-            // necessary. If we do support them, they should emit calls to emitFunctionPrototype.
+            this->writeFunctionPrototype(e.as<FunctionPrototype>());
             break;
         case ProgramElement::Kind::kStructDefinition:
             this->writeStructDefinition(e.as<StructDefinition>());
diff --git a/src/sksl/ir/SkSLFunctionCall.cpp b/src/sksl/ir/SkSLFunctionCall.cpp
index a2c1e82..22e45fb 100644
--- a/src/sksl/ir/SkSLFunctionCall.cpp
+++ b/src/sksl/ir/SkSLFunctionCall.cpp
@@ -678,14 +678,6 @@
         return nullptr;
     }
 
-    // GLSL ES 1.0 requires static recursion be rejected by the compiler. Also, our CPU back-end
-    // cannot handle recursion (and is tied to strictES2Mode front-ends). The safest way to reject
-    // all (potentially) recursive code is to disallow calls to functions before they're defined.
-    if (context.fConfig->strictES2Mode() && !function.definition() && !function.isBuiltin()) {
-        context.fErrors.error(offset, "call to undefined function '" + function.name() + "'");
-        return nullptr;
-    }
-
     // Resolve generic types.
     FunctionDeclaration::ParamTypes types;
     const Type* returnType;
@@ -729,7 +721,6 @@
                                                const FunctionDeclaration& function,
                                                ExpressionArray arguments) {
     SkASSERT(function.parameters().size() == arguments.size());
-    SkASSERT(function.definition() || function.isBuiltin() || !context.fConfig->strictES2Mode());
 
     if (context.fConfig->fSettings.fOptimize) {
         // We might be able to optimize built-in intrinsics.
diff --git a/tests/SkRuntimeEffectTest.cpp b/tests/SkRuntimeEffectTest.cpp
index ec784e5..9e2ebda 100644
--- a/tests/SkRuntimeEffectTest.cpp
+++ b/tests/SkRuntimeEffectTest.cpp
@@ -52,7 +52,7 @@
 
 DEF_TEST(SkRuntimeEffectInvalid_UndefinedFunction, r) {
     test_invalid_effect(r, "half4 missing(); half4 main(float2 p) { return missing(); }",
-                           "undefined function");
+                           "function 'half4 missing()' is not defined");
 }
 
 DEF_TEST(SkRuntimeEffectInvalid_UndefinedMain, r) {
diff --git a/tests/SkSLTest.cpp b/tests/SkSLTest.cpp
index 5c32142..9d1565b 100644
--- a/tests/SkSLTest.cpp
+++ b/tests/SkSLTest.cpp
@@ -245,6 +245,7 @@
 SKSL_TEST(SkSLFunctionArgTypeMatch,            "shared/FunctionArgTypeMatch.sksl")
 SKSL_TEST(SkSLFunctionReturnTypeMatch,         "shared/FunctionReturnTypeMatch.sksl")
 SKSL_TEST(SkSLFunctions,                       "shared/Functions.sksl")
+SKSL_TEST(SkSLFunctionPrototype,               "shared/FunctionPrototype.sksl")
 SKSL_TEST(SkSLGeometricIntrinsics,             "shared/GeometricIntrinsics.sksl")
 SKSL_TEST(SkSLHelloWorld,                      "shared/HelloWorld.sksl")
 SKSL_TEST(SkSLHex,                             "shared/Hex.sksl")
@@ -292,14 +293,6 @@
 SKSL_TEST_ES3(SkSLWhileLoopControlFlow,        "shared/WhileLoopControlFlow.sksl")
 
 /*
-// Incompatible with Runtime Effects because calling a function before its definition is disallowed.
-// (This was done to prevent recursion, as required by ES2.)
-// TODO(skia:12137) Enable this test once we specifically detect recursion, rather than just
-// calling functions before definition.
-SKSL_TEST(SkSLFunctionPrototype,               "shared/FunctionPrototype.sksl")
-*/
-
-/*
 TODO(skia:11209): enable these tests when Runtime Effects have support for ES3
 
 SKSL_TEST(SkSLMatrixFoldingES3,                "folding/MatrixFoldingES3.sksl")
diff --git a/tests/sksl/runtime_errors/IllegalRecursion.skvm b/tests/sksl/runtime_errors/IllegalRecursion.skvm
deleted file mode 100644
index 69dae42..0000000
--- a/tests/sksl/runtime_errors/IllegalRecursion.skvm
+++ /dev/null
@@ -1,6 +0,0 @@
-### Compilation failed:
-
-error: 9: call to undefined function 'fibonacci'
-error: 13: call to undefined function 'is_even'
-error: 14: call to undefined function 'is_odd'
-3 errors
diff --git a/tests/sksl/runtime_errors/IllegalRecursionComplex.skvm b/tests/sksl/runtime_errors/IllegalRecursionComplex.skvm
new file mode 100644
index 0000000..ece80d3
--- /dev/null
+++ b/tests/sksl/runtime_errors/IllegalRecursionComplex.skvm
@@ -0,0 +1,8 @@
+### Compilation failed:
+
+error: 6: potential recursion (function call cycle) not allowed:
+	void f_one(int n)
+	void f_two(int n)
+	void f_three(int n)
+	void f_one(int n)
+1 error
diff --git a/tests/sksl/runtime_errors/IllegalRecursionMutual.skvm b/tests/sksl/runtime_errors/IllegalRecursionMutual.skvm
new file mode 100644
index 0000000..84a613b
--- /dev/null
+++ b/tests/sksl/runtime_errors/IllegalRecursionMutual.skvm
@@ -0,0 +1,7 @@
+### Compilation failed:
+
+error: 4: potential recursion (function call cycle) not allowed:
+	bool is_even(int n)
+	bool is_odd(int n)
+	bool is_even(int n)
+1 error
diff --git a/tests/sksl/runtime_errors/IllegalRecursionSimple.skvm b/tests/sksl/runtime_errors/IllegalRecursionSimple.skvm
new file mode 100644
index 0000000..6659d20
--- /dev/null
+++ b/tests/sksl/runtime_errors/IllegalRecursionSimple.skvm
@@ -0,0 +1,6 @@
+### Compilation failed:
+
+error: 4: potential recursion (function call cycle) not allowed:
+	int fibonacci(int n)
+	int fibonacci(int n)
+1 error