Eliminated unnecessary arrays in DSLParser variable declarations

Change-Id: Ia14371ef3bb83928f6ee93120bcc29de9db6c020
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/448269
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLDSLParser.cpp b/src/sksl/SkSLDSLParser.cpp
index 716589d..0c4183b 100644
--- a/src/sksl/SkSLDSLParser.cpp
+++ b/src/sksl/SkSLDSLParser.cpp
@@ -9,6 +9,7 @@
 
 #include "include/private/SkSLString.h"
 #include "src/sksl/SkSLCompiler.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
 
 #include <memory>
 
@@ -298,8 +299,7 @@
         return true;
     }
     if (lookahead.fKind == Token::Kind::TK_STRUCT) {
-        SkTArray<DSLGlobalVar> result = this->structVarDeclaration(modifiers);
-        Declare(result);
+        this->structVarDeclaration(modifiers);
         return true;
     }
     skstd::optional<DSLType> type = this->type(modifiers);
@@ -313,10 +313,7 @@
     if (this->checkNext(Token::Kind::TK_LPAREN)) {
         return this->functionDeclarationEnd(modifiers, *type, name);
     } else {
-        SkTArray<DSLGlobalVar> result = this->varDeclarationEnd<DSLGlobalVar>(this->position(name),
-                                                                              modifiers, *type,
-                                                                              this->text(name));
-        Declare(result);
+        this->globalVarDeclarationEnd(this->position(name), modifiers, *type, this->text(name));
         return true;
     }
 }
@@ -366,14 +363,6 @@
     return true;
 }
 
-static skstd::optional<DSLStatement> declaration_statements(SkTArray<DSLVar> vars,
-                                                            SymbolTable& symbols) {
-    if (vars.empty()) {
-        return skstd::nullopt;
-    }
-    return Declare(vars);
-}
-
 static bool is_valid(const skstd::optional<DSLWrapper<DSLExpression>>& expr) {
     return expr && expr->get().isValid();
 }
@@ -407,63 +396,102 @@
     }
 }
 
-template<class T>
-SkTArray<T> DSLParser::varDeclarationEnd(PositionInfo pos, const dsl::DSLModifiers& mods,
-                                         dsl::DSLType baseType, skstd::string_view name) {
-    using namespace dsl;
-    SkTArray<T> result;
-    int offset = this->peek().fOffset;
-    auto parseArrayDimensions = [&](DSLType* type) -> bool {
-        while (this->checkNext(Token::Kind::TK_LBRACKET)) {
-            if (this->checkNext(Token::Kind::TK_RBRACKET)) {
-                this->error(offset, "expected array dimension");
-            } else {
-                *type = Array(*type, this->arraySize(), pos);
-                if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) {
-                    return {};
-                }
-            }
-        }
-        return true;
-    };
-    auto parseInitializer = [this](DSLExpression* initializer) -> bool {
-        if (this->checkNext(Token::Kind::TK_EQ)) {
-            skstd::optional<DSLWrapper<DSLExpression>> value = this->assignmentExpression();
-            if (!value) {
+bool DSLParser::parseArrayDimensions(int offset, DSLType* type) {
+    while (this->checkNext(Token::Kind::TK_LBRACKET)) {
+        if (this->checkNext(Token::Kind::TK_RBRACKET)) {
+            this->error(offset, "expected array dimension");
+        } else {
+            *type = Array(*type, this->arraySize(), this->position(offset));
+            if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) {
                 return false;
             }
-            initializer->swap(**value);
         }
-        return true;
-    };
+    }
+    return true;
+}
 
+bool DSLParser::parseInitializer(int offset, DSLExpression* initializer) {
+    if (this->checkNext(Token::Kind::TK_EQ)) {
+        skstd::optional<DSLWrapper<DSLExpression>> value = this->assignmentExpression();
+        if (!value) {
+            return false;
+        }
+        initializer->swap(**value);
+    }
+    return true;
+}
+
+/* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
+   (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
+void DSLParser::globalVarDeclarationEnd(PositionInfo pos, const dsl::DSLModifiers& mods,
+        dsl::DSLType baseType, skstd::string_view name) {
+    using namespace dsl;
+    int offset = this->peek().fOffset;
     DSLType type = baseType;
     DSLExpression initializer;
-    if (!parseArrayDimensions(&type)) {
-        return {};
+    if (!this->parseArrayDimensions(offset, &type)) {
+        return;
     }
-    parseInitializer(&initializer);
-    result.push_back(T(mods, type, name, std::move(initializer), pos));
-    AddToSymbolTable(result.back());
+    this->parseInitializer(offset, &initializer);
+    DSLGlobalVar first(mods, type, name, std::move(initializer), pos);
+    Declare(first);
+    AddToSymbolTable(first);
 
     while (this->checkNext(Token::Kind::TK_COMMA)) {
         type = baseType;
         Token identifierName;
         if (!this->expectIdentifier(&identifierName)) {
-            return result;
+            return;
         }
-        if (!parseArrayDimensions(&type)) {
-            return result;
+        if (!this->parseArrayDimensions(offset, &type)) {
+            return;
         }
         DSLExpression anotherInitializer;
-        if (!parseInitializer(&anotherInitializer)) {
-            return result;
+        if (!this->parseInitializer(offset, &anotherInitializer)) {
+            return;
         }
-        result.push_back(T(mods, type, this->text(identifierName), std::move(anotherInitializer)));
-        AddToSymbolTable(result.back());
+        DSLGlobalVar next(mods, type, this->text(identifierName), std::move(anotherInitializer));
+        Declare(next);
+        AddToSymbolTable(next);
     }
     this->expect(Token::Kind::TK_SEMICOLON, "';'");
-    return result;
+}
+
+/* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
+   (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
+skstd::optional<DSLStatement> DSLParser::localVarDeclarationEnd(PositionInfo pos,
+        const dsl::DSLModifiers& mods, dsl::DSLType baseType, skstd::string_view name) {
+    using namespace dsl;
+    int offset = this->peek().fOffset;
+    DSLType type = baseType;
+    DSLExpression initializer;
+    if (!this->parseArrayDimensions(offset, &type)) {
+        return skstd::nullopt;
+    }
+    this->parseInitializer(offset, &initializer);
+    DSLVar first(mods, type, name, std::move(initializer), pos);
+    DSLStatement result = Declare(first);
+    AddToSymbolTable(first);
+
+    while (this->checkNext(Token::Kind::TK_COMMA)) {
+        type = baseType;
+        Token identifierName;
+        if (!this->expectIdentifier(&identifierName)) {
+            return {std::move(result)};
+        }
+        if (!this->parseArrayDimensions(offset, &type)) {
+            return {std::move(result)};
+        }
+        DSLExpression anotherInitializer;
+        if (!this->parseInitializer(offset, &anotherInitializer)) {
+            return {std::move(result)};
+        }
+        DSLVar next(mods, type, this->text(identifierName), std::move(anotherInitializer));
+        DSLWriter::AddVarDeclaration(result, next);
+        AddToSymbolTable(next);
+    }
+    this->expect(Token::Kind::TK_SEMICOLON, "';'");
+    return {std::move(result)};
 }
 
 /* (varDeclarations | expressionStatement) */
@@ -486,11 +514,8 @@
         VarDeclarationsPrefix prefix;
         if (this->varDeclarationsPrefix(&prefix)) {
             checkpoint.accept();
-            return declaration_statements(this->varDeclarationEnd<DSLVar>(prefix.fPosition,
-                                                                          prefix.fModifiers,
-                                                                          prefix.fType,
-                                                                          this->text(prefix.fName)),
-                                          this->symbols());
+            return this->localVarDeclarationEnd(prefix.fPosition, prefix.fModifiers, prefix.fType,
+                    this->text(prefix.fName));
         }
 
         // If this statement wasn't actually a vardecl after all, rewind and try parsing it as an
@@ -519,11 +544,8 @@
     if (!this->varDeclarationsPrefix(&prefix)) {
         return skstd::nullopt;
     }
-    return declaration_statements(this->varDeclarationEnd<DSLVar>(prefix.fPosition,
-                                                                  prefix.fModifiers,
-                                                                  prefix.fType,
-                                                                  this->text(prefix.fName)),
-                                  this->symbols());
+    return this->localVarDeclarationEnd(prefix.fPosition, prefix.fModifiers, prefix.fType,
+            this->text(prefix.fName));
 }
 
 /* STRUCT IDENTIFIER LBRACE varDeclaration* RBRACE */
@@ -586,8 +608,8 @@
     }
     Token name;
     if (this->checkNext(Token::Kind::TK_IDENTIFIER, &name)) {
-        return this->varDeclarationEnd<DSLGlobalVar>(this->position(name), modifiers,
-                std::move(*type), this->text(name));
+        this->globalVarDeclarationEnd(this->position(name), modifiers, std::move(*type),
+                this->text(name));
     }
     this->expect(Token::Kind::TK_SEMICOLON, "';'");
     return {};
diff --git a/src/sksl/SkSLDSLParser.h b/src/sksl/SkSLDSLParser.h
index e283e69..485b0f9 100644
--- a/src/sksl/SkSLDSLParser.h
+++ b/src/sksl/SkSLDSLParser.h
@@ -163,14 +163,15 @@
 
     SkTArray<dsl::DSLGlobalVar> structVarDeclaration(const dsl::DSLModifiers& modifiers);
 
-    /* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
-       (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
-    template<class T>
-    SkTArray<T> varDeclarationEnd(PositionInfo position, const dsl::DSLModifiers& mods,
-                                  dsl::DSLType baseType, skstd::string_view name);
+    bool parseArrayDimensions(int offset, dsl::DSLType* type);
 
-    SkTArray<dsl::DSLGlobalVar> globalVarDeclarationEnd(const dsl::DSLModifiers& modifiers,
-                                                        dsl::DSLType type, skstd::string_view name);
+    bool parseInitializer(int offset, dsl::DSLExpression* initializer);
+
+    void globalVarDeclarationEnd(PositionInfo position, const dsl::DSLModifiers& mods,
+            dsl::DSLType baseType, skstd::string_view name);
+
+    skstd::optional<dsl::DSLStatement> localVarDeclarationEnd(PositionInfo position,
+            const dsl::DSLModifiers& mods, dsl::DSLType baseType, skstd::string_view name);
 
     skstd::optional<dsl::DSLWrapper<dsl::DSLParameter>> parameter();
 
diff --git a/src/sksl/dsl/priv/DSLWriter.cpp b/src/sksl/dsl/priv/DSLWriter.cpp
index 4372ecd..f753c68 100644
--- a/src/sksl/dsl/priv/DSLWriter.cpp
+++ b/src/sksl/dsl/priv/DSLWriter.cpp
@@ -135,6 +135,21 @@
 }
 #endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
 
+void DSLWriter::AddVarDeclaration(DSLStatement& existing, DSLVar& additional) {
+    if (existing.fStatement->is<Block>()) {
+        SkSL::Block& block = existing.fStatement->as<Block>();
+        SkASSERT(!block.isScope());
+        block.children().push_back(Declare(additional).release());
+    } else {
+        SkASSERT(existing.fStatement->is<VarDeclaration>());
+        StatementArray stmts;
+        stmts.reserve_back(2);
+        stmts.push_back(std::move(existing.fStatement));
+        stmts.push_back(Declare(additional).release());
+        existing.fStatement = SkSL::Block::MakeUnscoped(/*offset=*/-1, std::move(stmts));
+    }
+}
+
 std::unique_ptr<SkSL::Expression> DSLWriter::Call(const FunctionDeclaration& function,
                                                   ExpressionArray arguments,
                                                   PositionInfo pos) {
diff --git a/src/sksl/dsl/priv/DSLWriter.h b/src/sksl/dsl/priv/DSLWriter.h
index b57bda8..ee981d3 100644
--- a/src/sksl/dsl/priv/DSLWriter.h
+++ b/src/sksl/dsl/priv/DSLWriter.h
@@ -190,6 +190,13 @@
     }
 #endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
 
+    /**
+     * Adds a new declaration into an existing declaration statement. This either turns the original
+     * declaration into an unscoped block or, if it already was, appends a new statement to the end
+     * of it.
+     */
+    static void AddVarDeclaration(DSLStatement& existing, DSLVar& additional);
+
     static std::unique_ptr<SkSL::Expression> Call(const FunctionDeclaration& function,
                                                   ExpressionArray arguments,
                                                   PositionInfo pos = PositionInfo::Capture());