Eliminate DSLFunction entirely.

This CL also makes a minor intentional change which lets the error
reporter point more accurately at function return types. Previously,
when complaining about a return type, the error would highlight the
whole function declaration; now, it highlights the start of the
function.

Change-Id: I3974db1e08b0f0b77ed3e804e7c7416c8f0559d9
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/700228
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/gn/sksl.gni b/gn/sksl.gni
index b93ef0f..c400752 100644
--- a/gn/sksl.gni
+++ b/gn/sksl.gni
@@ -117,8 +117,6 @@
   "$_src/sksl/codegen/SkSLVMCodeGenerator.h",
   "$_src/sksl/dsl/DSLExpression.cpp",
   "$_src/sksl/dsl/DSLExpression.h",
-  "$_src/sksl/dsl/DSLFunction.cpp",
-  "$_src/sksl/dsl/DSLFunction.h",
   "$_src/sksl/dsl/DSLModifiers.h",
   "$_src/sksl/dsl/DSLStatement.cpp",
   "$_src/sksl/dsl/DSLStatement.h",
diff --git a/public.bzl b/public.bzl
index 0b956f1..390c2c3 100644
--- a/public.bzl
+++ b/public.bzl
@@ -1508,8 +1508,6 @@
     "src/sksl/codegen/SkSLWGSLCodeGenerator.h",
     "src/sksl/dsl/DSLExpression.cpp",
     "src/sksl/dsl/DSLExpression.h",
-    "src/sksl/dsl/DSLFunction.cpp",
-    "src/sksl/dsl/DSLFunction.h",
     "src/sksl/dsl/DSLModifiers.h",
     "src/sksl/dsl/DSLStatement.cpp",
     "src/sksl/dsl/DSLStatement.h",
diff --git a/src/sksl/SkSLParser.cpp b/src/sksl/SkSLParser.cpp
index f393cf1..9be8964 100644
--- a/src/sksl/SkSLParser.cpp
+++ b/src/sksl/SkSLParser.cpp
@@ -19,7 +19,6 @@
 #include "src/sksl/SkSLOperator.h"
 #include "src/sksl/SkSLString.h"
 #include "src/sksl/SkSLThreadContext.h"
-#include "src/sksl/dsl/DSLFunction.h"
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBlock.h"
 #include "src/sksl/ir/SkSLBreakStatement.h"
@@ -31,6 +30,9 @@
 #include "src/sksl/ir/SkSLFieldAccess.h"
 #include "src/sksl/ir/SkSLForStatement.h"
 #include "src/sksl/ir/SkSLFunctionCall.h"
+#include "src/sksl/ir/SkSLFunctionDeclaration.h"
+#include "src/sksl/ir/SkSLFunctionDefinition.h"
+#include "src/sksl/ir/SkSLFunctionPrototype.h"
 #include "src/sksl/ir/SkSLIfStatement.h"
 #include "src/sksl/ir/SkSLIndexExpression.h"
 #include "src/sksl/ir/SkSLInterfaceBlock.h"
@@ -568,10 +570,11 @@
 /* (RPAREN | VOID RPAREN | parameter (COMMA parameter)* RPAREN) (block | SEMICOLON) */
 bool Parser::functionDeclarationEnd(Position start,
                                     DSLModifiers& modifiers,
-                                    DSLType type,
+                                    DSLType returnType,
                                     const Token& name) {
-    STArray<8, std::unique_ptr<Variable>> parameters;
     Token lookahead = this->peek();
+    bool validParams = true;
+    STArray<8, std::unique_ptr<Variable>> parameters;
     if (lookahead.fKind == Token::Kind::TK_RPAREN) {
         // `()` means no parameters at all.
     } else if (lookahead.fKind == Token::Kind::TK_IDENTIFIER && this->text(lookahead) == "void") {
@@ -583,6 +586,7 @@
             if (!this->parameter(&param)) {
                 return false;
             }
+            validParams = validParams && param;
             parameters.push_back(std::move(param));
             if (!this->checkNext(Token::Kind::TK_COMMA)) {
                 break;
@@ -593,22 +597,66 @@
         return false;
     }
 
-    DSLFunction result(this->text(name), modifiers, type, std::move(parameters),
-                       this->rangeFrom(start));
-
-    const bool hasFunctionBody = !this->checkNext(Token::Kind::TK_SEMICOLON);
-    if (hasFunctionBody) {
-        AutoSymbolTable symbols(this);
-        result.addParametersToSymbolTable(fCompiler.context());
-        Token bodyStart = this->peek();
-        std::optional<DSLStatement> body = this->block();
-        if (!body) {
-            return false;
-        }
-        result.define(std::move(*body), this->rangeFrom(bodyStart));
-    } else {
-        result.prototype();
+    SkSL::FunctionDeclaration* decl = nullptr;
+    if (validParams) {
+        decl = SkSL::FunctionDeclaration::Convert(ThreadContext::Context(),
+                                                  this->rangeFrom(start),
+                                                  modifiers.fPosition,
+                                                  &modifiers.fModifiers,
+                                                  this->text(name),
+                                                  std::move(parameters),
+                                                  start,
+                                                  &returnType.skslType());
     }
+
+    if (this->checkNext(Token::Kind::TK_SEMICOLON)) {
+        return this->prototypeFunction(decl);
+    } else {
+        return this->defineFunction(decl);
+    }
+}
+
+bool Parser::prototypeFunction(SkSL::FunctionDeclaration* decl) {
+    if (!decl) {
+        return false;
+    }
+    ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::FunctionPrototype>(
+            decl->fPosition, decl, fCompiler.context().fConfig->fIsBuiltinCode));
+    return true;
+}
+
+bool Parser::defineFunction(SkSL::FunctionDeclaration* decl) {
+    // Create a symbol table for the function parameters.
+    const Context& context = fCompiler.context();
+    AutoSymbolTable symbols(this);
+    if (decl) {
+        decl->addParametersToSymbolTable(context);
+    }
+
+    // Parse the function body.
+    Token bodyStart = this->peek();
+    std::optional<DSLStatement> body = this->block();
+
+    // If there was a problem with the declarations or body, don't actually create a definition.
+    if (!decl || !body) {
+        return false;
+    }
+
+    std::unique_ptr<SkSL::Statement> block = body->release();
+    SkASSERT(block->is<Block>());
+    Position pos = this->rangeFrom(bodyStart);
+    block->fPosition = pos;
+
+    std::unique_ptr<FunctionDefinition> function = FunctionDefinition::Convert(context,
+                                                                               pos,
+                                                                               *decl,
+                                                                               std::move(block),
+                                                                               /*builtin=*/false);
+    if (!function) {
+        return false;
+    }
+    decl->setDefinition(function.get());
+    ThreadContext::ProgramElements().push_back(std::move(function));
     return true;
 }
 
diff --git a/src/sksl/SkSLParser.h b/src/sksl/SkSLParser.h
index bc4913f..ee5c8ed 100644
--- a/src/sksl/SkSLParser.h
+++ b/src/sksl/SkSLParser.h
@@ -31,6 +31,7 @@
 class Compiler;
 class ErrorReporter;
 class Expression;
+class FunctionDeclaration;
 class SymbolTable;
 enum class ProgramKind : int8_t;
 struct Module;
@@ -148,9 +149,13 @@
 
     bool functionDeclarationEnd(Position start,
                                 dsl::DSLModifiers& modifiers,
-                                dsl::DSLType type,
+                                dsl::DSLType returnType,
                                 const Token& name);
 
+    bool prototypeFunction(SkSL::FunctionDeclaration* decl);
+
+    bool defineFunction(SkSL::FunctionDeclaration* decl);
+
     struct VarDeclarationsPrefix {
         Position fPosition;
         dsl::DSLModifiers fModifiers;
diff --git a/src/sksl/dsl/BUILD.bazel b/src/sksl/dsl/BUILD.bazel
index 669c87b..d2e32a9 100644
--- a/src/sksl/dsl/BUILD.bazel
+++ b/src/sksl/dsl/BUILD.bazel
@@ -9,8 +9,6 @@
     srcs = [
         "DSLExpression.cpp",
         "DSLExpression.h",
-        "DSLFunction.cpp",
-        "DSLFunction.h",
         "DSLModifiers.h",
         "DSLStatement.cpp",
         "DSLStatement.h",
diff --git a/src/sksl/dsl/DSLFunction.cpp b/src/sksl/dsl/DSLFunction.cpp
deleted file mode 100644
index 49e5a49..0000000
--- a/src/sksl/dsl/DSLFunction.cpp
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Copyright 2021 Google LLC.
- *
- * Use of this source code is governed by a BSD-style license that can be
- * found in the LICENSE file.
- */
-
-#include "src/sksl/dsl/DSLFunction.h"
-
-#include "include/core/SkTypes.h"
-#include "include/private/SkSLDefines.h"
-#include "include/private/base/SkTArray.h"
-#include "src/sksl/SkSLContext.h"
-#include "src/sksl/SkSLIntrinsicList.h"
-#include "src/sksl/SkSLProgramSettings.h"
-#include "src/sksl/SkSLString.h"
-#include "src/sksl/SkSLThreadContext.h"
-#include "src/sksl/dsl/DSLModifiers.h"
-#include "src/sksl/dsl/DSLType.h"
-#include "src/sksl/ir/SkSLBlock.h"
-#include "src/sksl/ir/SkSLExpression.h"
-#include "src/sksl/ir/SkSLFunctionCall.h"
-#include "src/sksl/ir/SkSLFunctionDeclaration.h"
-#include "src/sksl/ir/SkSLFunctionDefinition.h"
-#include "src/sksl/ir/SkSLFunctionPrototype.h"
-#include "src/sksl/ir/SkSLProgramElement.h"
-#include "src/sksl/ir/SkSLStatement.h"
-#include "src/sksl/ir/SkSLVariable.h"
-
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-using namespace skia_private;
-
-namespace SkSL::dsl {
-
-DSLFunction::DSLFunction(std::string_view name,
-                         const DSLModifiers& modifiers,
-                         const DSLType& returnType,
-                         TArray<std::unique_ptr<SkSL::Variable>> parameters,
-                         Position pos) {
-    this->init(modifiers, returnType, name, std::move(parameters), pos);
-}
-
-static bool is_intrinsic_in_module(const Context& context, std::string_view name) {
-    return context.fConfig->fIsBuiltinCode && SkSL::FindIntrinsicKind(name) != kNotIntrinsic;
-}
-
-void DSLFunction::init(DSLModifiers modifiers,
-                       const DSLType& returnType,
-                       std::string_view name,
-                       TArray<std::unique_ptr<SkSL::Variable>> params,
-                       Position pos) {
-    for (const std::unique_ptr<SkSL::Variable>& param : params) {
-        if (!param) {
-            // We failed to create one of the params; an error should already have been reported.
-            return;
-        }
-    }
-    fPosition = pos;
-    fDecl = SkSL::FunctionDeclaration::Convert(ThreadContext::Context(),
-                                               pos,
-                                               modifiers.fPosition,
-                                               &modifiers.fModifiers,
-                                               name,
-                                               std::move(params),
-                                               pos,
-                                               &returnType.skslType());
-}
-
-void DSLFunction::prototype() {
-    if (!fDecl) {
-        // We failed to create the declaration; error should already have been reported.
-        return;
-    }
-    ThreadContext::ProgramElements().push_back(std::make_unique<SkSL::FunctionPrototype>(
-            fDecl->fPosition, fDecl, ThreadContext::IsModule()));
-}
-
-void DSLFunction::define(DSLStatement block, Position pos) {
-    std::unique_ptr<SkSL::Statement> body = block.release();
-    SkASSERT(body->is<Block>());
-    body->fPosition = pos;
-    if (!fDecl) {
-        // We failed to create the declaration; error should already have been reported.
-        return;
-    }
-    // We don't allow modules to define actual functions with intrinsic names. (Those should be
-    // reserved for actual intrinsics.)
-    const Context& context = ThreadContext::Context();
-    if (is_intrinsic_in_module(context, fDecl->name())) {
-        ThreadContext::ReportError(
-                SkSL::String::printf("Intrinsic function '%.*s' should not have a definition",
-                                     (int)fDecl->name().size(),
-                                     fDecl->name().data()),
-                fDecl->fPosition);
-        return;
-    }
-
-    if (fDecl->definition()) {
-        ThreadContext::ReportError(SkSL::String::printf("function '%s' was already defined",
-                                                        fDecl->description().c_str()),
-                                   fDecl->fPosition);
-        return;
-    }
-    std::unique_ptr<FunctionDefinition> function = FunctionDefinition::Convert(
-            ThreadContext::Context(),
-            pos,
-            *fDecl,
-            std::move(body),
-            /*builtin=*/false);
-    fDecl->setDefinition(function.get());
-    ThreadContext::ProgramElements().push_back(std::move(function));
-}
-
-DSLExpression DSLFunction::call(ExpressionArray args, Position pos) {
-    std::unique_ptr<SkSL::Expression> result =
-            SkSL::FunctionCall::Convert(ThreadContext::Context(), pos, *fDecl, std::move(args));
-    return DSLExpression(std::move(result), pos);
-}
-
-void DSLFunction::addParametersToSymbolTable(const Context& context) {
-    if (fDecl) {
-        fDecl->addParametersToSymbolTable(context);
-    }
-}
-
-}  // namespace SkSL::dsl
diff --git a/src/sksl/dsl/DSLFunction.h b/src/sksl/dsl/DSLFunction.h
deleted file mode 100644
index bd35c5c..0000000
--- a/src/sksl/dsl/DSLFunction.h
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Copyright 2021 Google LLC.
- *
- * Use of this source code is governed by a BSD-style license that can be
- * found in the LICENSE file.
- */
-
-#ifndef SKSL_DSL_FUNCTION
-#define SKSL_DSL_FUNCTION
-
-#include "include/private/base/SkTArray.h"
-#include "src/sksl/SkSLPosition.h"
-#include "src/sksl/dsl/DSLExpression.h"
-#include "src/sksl/dsl/DSLStatement.h"
-
-#include <string_view>
-#include <memory>
-
-namespace SkSL {
-
-class Context;
-class ExpressionArray;
-class FunctionDeclaration;
-class Variable;
-
-namespace dsl {
-
-class DSLType;
-struct DSLModifiers;
-
-class DSLFunction {
-public:
-    DSLFunction(std::string_view name,
-                const DSLModifiers& modifiers,
-                const DSLType& returnType,
-                skia_private::TArray<std::unique_ptr<SkSL::Variable>> parameters,
-                Position pos = {});
-
-    DSLFunction(SkSL::FunctionDeclaration* decl)
-            : fDecl(decl) {}
-
-    virtual ~DSLFunction() = default;
-
-    void define(DSLStatement block, Position pos = {});
-
-    void prototype();
-
-    void addParametersToSymbolTable(const Context& context);
-
-    /**
-     * Invokes the function with the given arguments.
-     */
-    DSLExpression call(ExpressionArray args, Position pos = {});
-
-private:
-    void init(DSLModifiers modifiers,
-              const DSLType& returnType,
-              std::string_view name,
-              skia_private::TArray<std::unique_ptr<SkSL::Variable>> params,
-              Position pos);
-
-    SkSL::FunctionDeclaration* fDecl = nullptr;
-    SkSL::Position fPosition;
-};
-
-} // namespace dsl
-
-} // namespace SkSL
-
-#endif
diff --git a/src/sksl/ir/SkSLFunctionDefinition.cpp b/src/sksl/ir/SkSLFunctionDefinition.cpp
index 67aeb05..963dfc6 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.cpp
+++ b/src/sksl/ir/SkSLFunctionDefinition.cpp
@@ -18,6 +18,7 @@
 #include "src/sksl/SkSLErrorReporter.h"
 #include "src/sksl/SkSLOperator.h"
 #include "src/sksl/SkSLProgramSettings.h"
+#include "src/sksl/SkSLString.h"
 #include "src/sksl/SkSLThreadContext.h"
 #include "src/sksl/ir/SkSLBinaryExpression.h"
 #include "src/sksl/ir/SkSLBlock.h"
@@ -341,6 +342,27 @@
         using INHERITED = ProgramWriter;
     };
 
+    // We don't allow modules to define actual functions with intrinsic names. (Those should be
+    // reserved for actual intrinsics.)
+    if (function.isIntrinsic()) {
+        context.fErrors->error(function.fPosition,
+                               SkSL::String::printf("Intrinsic function '%.*s' should not have "
+                                                    "a definition",
+                                                    (int)function.name().size(),
+                                                    function.name().data()));
+        return nullptr;
+    }
+
+    // A function can't have more than one definition.
+    if (function.definition()) {
+        context.fErrors->error(function.fPosition,
+                               SkSL::String::printf("function '%s' was already defined",
+                                                    function.description().c_str()));
+        return nullptr;
+    }
+
+    // Run the function finalizer. This checks for illegal constructs and missing return statements,
+    // and also performs some simple code cleanup.
     Finalizer(context, function, pos).visitStatementPtr(body);
     if (function.isMain() && ProgramConfig::IsVertex(context.fConfig->fKind)) {
         append_rtadjust_fixup_to_vertex_main(context, function, body->as<Block>());
@@ -351,8 +373,6 @@
                                                 "' can exit without returning a value");
     }
 
-    SkASSERTF(!function.isIntrinsic(), "Intrinsic function '%.*s' should not have a definition",
-              (int)function.name().size(), function.name().data());
     return std::make_unique<FunctionDefinition>(pos, &function, builtin, std::move(body));
 }
 
diff --git a/tests/sksl/errors/ArrayReturnTypes.glsl b/tests/sksl/errors/ArrayReturnTypes.glsl
index cdb1721..5ee517d 100644
--- a/tests/sksl/errors/ArrayReturnTypes.glsl
+++ b/tests/sksl/errors/ArrayReturnTypes.glsl
@@ -2,8 +2,8 @@
 
 error: 1: functions may not return type 'float4x4[2]'
 float4x4[2] return_float4x4_2() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^
 error: 2: functions may not return type 'int[1]'
 int[1]      return_int_1()      {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^
 2 errors
diff --git a/tests/sksl/errors/ArrayUnspecifiedDimensions.asm.frag b/tests/sksl/errors/ArrayUnspecifiedDimensions.asm.frag
index 81c7e43..99fa1f7 100644
--- a/tests/sksl/errors/ArrayUnspecifiedDimensions.asm.frag
+++ b/tests/sksl/errors/ArrayUnspecifiedDimensions.asm.frag
@@ -62,8 +62,8 @@
                                  ^^
 error: 24: functions may not return type 'int[]'
 int[] unsized_in_return_type_a() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^
 error: 25: functions may not return type 'S[]'
 S[]   unsized_in_return_type_b() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^
 22 errors
diff --git a/tests/sksl/errors/ArrayUnspecifiedDimensions.glsl b/tests/sksl/errors/ArrayUnspecifiedDimensions.glsl
index 81c7e43..99fa1f7 100644
--- a/tests/sksl/errors/ArrayUnspecifiedDimensions.glsl
+++ b/tests/sksl/errors/ArrayUnspecifiedDimensions.glsl
@@ -62,8 +62,8 @@
                                  ^^
 error: 24: functions may not return type 'int[]'
 int[] unsized_in_return_type_a() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^
 error: 25: functions may not return type 'S[]'
 S[]   unsized_in_return_type_b() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^
 22 errors
diff --git a/tests/sksl/errors/ArrayUnspecifiedDimensions.metal b/tests/sksl/errors/ArrayUnspecifiedDimensions.metal
index 81c7e43..99fa1f7 100644
--- a/tests/sksl/errors/ArrayUnspecifiedDimensions.metal
+++ b/tests/sksl/errors/ArrayUnspecifiedDimensions.metal
@@ -62,8 +62,8 @@
                                  ^^
 error: 24: functions may not return type 'int[]'
 int[] unsized_in_return_type_a() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^
 error: 25: functions may not return type 'S[]'
 S[]   unsized_in_return_type_b() {}
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^
 22 errors
diff --git a/tests/sksl/errors/OverloadedBuiltin.glsl b/tests/sksl/errors/OverloadedBuiltin.glsl
index 6495f3b..32b1fa0 100644
--- a/tests/sksl/errors/OverloadedBuiltin.glsl
+++ b/tests/sksl/errors/OverloadedBuiltin.glsl
@@ -17,10 +17,10 @@
 ^^^^^^^^^^^^^^^^^
 error: 14: functions 'float2 cos(half2 a)' and '$pure $genHType cos($genHType angle)' differ only in return type
 float2 cos(half2 a) { return 0; /* error: overloads a builtin (despite return type mismatch) */ }
-^^^^^^^^^^^^^^^^^^^
+^^^^^^
 error: 15: functions 'int cos(out half3 a)' and '$pure $genHType cos($genHType angle)' differ only in return type
 int cos(out half3 a) { return 0; /* error: overloads a builtin (despite return type mismatch) */ }
-^^^^^^^^^^^^^^^^^^^^
+^^^
 error: 17: duplicate definition of 'float pow(float x, float y)'
 float pow(float x, float y) { return 0; /* error: overloads a builtin */ }
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/sksl/errors/ReturnDifferentType.glsl b/tests/sksl/errors/ReturnDifferentType.glsl
index 789f0e3..7dc309a 100644
--- a/tests/sksl/errors/ReturnDifferentType.glsl
+++ b/tests/sksl/errors/ReturnDifferentType.glsl
@@ -2,5 +2,5 @@
 
 error: 2: functions 'void func()' and 'int func()' differ only in return type
 void func() {}
-^^^^^^^^^^^
+^^^^
 1 error