Eliminated unnecessary arrays in DSLParser variable declarations
Change-Id: Ia14371ef3bb83928f6ee93120bcc29de9db6c020
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/448269
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLDSLParser.cpp b/src/sksl/SkSLDSLParser.cpp
index 716589d..0c4183b 100644
--- a/src/sksl/SkSLDSLParser.cpp
+++ b/src/sksl/SkSLDSLParser.cpp
@@ -9,6 +9,7 @@
#include "include/private/SkSLString.h"
#include "src/sksl/SkSLCompiler.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
#include <memory>
@@ -298,8 +299,7 @@
return true;
}
if (lookahead.fKind == Token::Kind::TK_STRUCT) {
- SkTArray<DSLGlobalVar> result = this->structVarDeclaration(modifiers);
- Declare(result);
+ this->structVarDeclaration(modifiers);
return true;
}
skstd::optional<DSLType> type = this->type(modifiers);
@@ -313,10 +313,7 @@
if (this->checkNext(Token::Kind::TK_LPAREN)) {
return this->functionDeclarationEnd(modifiers, *type, name);
} else {
- SkTArray<DSLGlobalVar> result = this->varDeclarationEnd<DSLGlobalVar>(this->position(name),
- modifiers, *type,
- this->text(name));
- Declare(result);
+ this->globalVarDeclarationEnd(this->position(name), modifiers, *type, this->text(name));
return true;
}
}
@@ -366,14 +363,6 @@
return true;
}
-static skstd::optional<DSLStatement> declaration_statements(SkTArray<DSLVar> vars,
- SymbolTable& symbols) {
- if (vars.empty()) {
- return skstd::nullopt;
- }
- return Declare(vars);
-}
-
static bool is_valid(const skstd::optional<DSLWrapper<DSLExpression>>& expr) {
return expr && expr->get().isValid();
}
@@ -407,63 +396,102 @@
}
}
-template<class T>
-SkTArray<T> DSLParser::varDeclarationEnd(PositionInfo pos, const dsl::DSLModifiers& mods,
- dsl::DSLType baseType, skstd::string_view name) {
- using namespace dsl;
- SkTArray<T> result;
- int offset = this->peek().fOffset;
- auto parseArrayDimensions = [&](DSLType* type) -> bool {
- while (this->checkNext(Token::Kind::TK_LBRACKET)) {
- if (this->checkNext(Token::Kind::TK_RBRACKET)) {
- this->error(offset, "expected array dimension");
- } else {
- *type = Array(*type, this->arraySize(), pos);
- if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) {
- return {};
- }
- }
- }
- return true;
- };
- auto parseInitializer = [this](DSLExpression* initializer) -> bool {
- if (this->checkNext(Token::Kind::TK_EQ)) {
- skstd::optional<DSLWrapper<DSLExpression>> value = this->assignmentExpression();
- if (!value) {
+bool DSLParser::parseArrayDimensions(int offset, DSLType* type) {
+ while (this->checkNext(Token::Kind::TK_LBRACKET)) {
+ if (this->checkNext(Token::Kind::TK_RBRACKET)) {
+ this->error(offset, "expected array dimension");
+ } else {
+ *type = Array(*type, this->arraySize(), this->position(offset));
+ if (!this->expect(Token::Kind::TK_RBRACKET, "']'")) {
return false;
}
- initializer->swap(**value);
}
- return true;
- };
+ }
+ return true;
+}
+bool DSLParser::parseInitializer(int offset, DSLExpression* initializer) {
+ if (this->checkNext(Token::Kind::TK_EQ)) {
+ skstd::optional<DSLWrapper<DSLExpression>> value = this->assignmentExpression();
+ if (!value) {
+ return false;
+ }
+ initializer->swap(**value);
+ }
+ return true;
+}
+
+/* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
+ (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
+void DSLParser::globalVarDeclarationEnd(PositionInfo pos, const dsl::DSLModifiers& mods,
+ dsl::DSLType baseType, skstd::string_view name) {
+ using namespace dsl;
+ int offset = this->peek().fOffset;
DSLType type = baseType;
DSLExpression initializer;
- if (!parseArrayDimensions(&type)) {
- return {};
+ if (!this->parseArrayDimensions(offset, &type)) {
+ return;
}
- parseInitializer(&initializer);
- result.push_back(T(mods, type, name, std::move(initializer), pos));
- AddToSymbolTable(result.back());
+ this->parseInitializer(offset, &initializer);
+ DSLGlobalVar first(mods, type, name, std::move(initializer), pos);
+ Declare(first);
+ AddToSymbolTable(first);
while (this->checkNext(Token::Kind::TK_COMMA)) {
type = baseType;
Token identifierName;
if (!this->expectIdentifier(&identifierName)) {
- return result;
+ return;
}
- if (!parseArrayDimensions(&type)) {
- return result;
+ if (!this->parseArrayDimensions(offset, &type)) {
+ return;
}
DSLExpression anotherInitializer;
- if (!parseInitializer(&anotherInitializer)) {
- return result;
+ if (!this->parseInitializer(offset, &anotherInitializer)) {
+ return;
}
- result.push_back(T(mods, type, this->text(identifierName), std::move(anotherInitializer)));
- AddToSymbolTable(result.back());
+ DSLGlobalVar next(mods, type, this->text(identifierName), std::move(anotherInitializer));
+ Declare(next);
+ AddToSymbolTable(next);
}
this->expect(Token::Kind::TK_SEMICOLON, "';'");
- return result;
+}
+
+/* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
+ (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
+skstd::optional<DSLStatement> DSLParser::localVarDeclarationEnd(PositionInfo pos,
+ const dsl::DSLModifiers& mods, dsl::DSLType baseType, skstd::string_view name) {
+ using namespace dsl;
+ int offset = this->peek().fOffset;
+ DSLType type = baseType;
+ DSLExpression initializer;
+ if (!this->parseArrayDimensions(offset, &type)) {
+ return skstd::nullopt;
+ }
+ this->parseInitializer(offset, &initializer);
+ DSLVar first(mods, type, name, std::move(initializer), pos);
+ DSLStatement result = Declare(first);
+ AddToSymbolTable(first);
+
+ while (this->checkNext(Token::Kind::TK_COMMA)) {
+ type = baseType;
+ Token identifierName;
+ if (!this->expectIdentifier(&identifierName)) {
+ return {std::move(result)};
+ }
+ if (!this->parseArrayDimensions(offset, &type)) {
+ return {std::move(result)};
+ }
+ DSLExpression anotherInitializer;
+ if (!this->parseInitializer(offset, &anotherInitializer)) {
+ return {std::move(result)};
+ }
+ DSLVar next(mods, type, this->text(identifierName), std::move(anotherInitializer));
+ DSLWriter::AddVarDeclaration(result, next);
+ AddToSymbolTable(next);
+ }
+ this->expect(Token::Kind::TK_SEMICOLON, "';'");
+ return {std::move(result)};
}
/* (varDeclarations | expressionStatement) */
@@ -486,11 +514,8 @@
VarDeclarationsPrefix prefix;
if (this->varDeclarationsPrefix(&prefix)) {
checkpoint.accept();
- return declaration_statements(this->varDeclarationEnd<DSLVar>(prefix.fPosition,
- prefix.fModifiers,
- prefix.fType,
- this->text(prefix.fName)),
- this->symbols());
+ return this->localVarDeclarationEnd(prefix.fPosition, prefix.fModifiers, prefix.fType,
+ this->text(prefix.fName));
}
// If this statement wasn't actually a vardecl after all, rewind and try parsing it as an
@@ -519,11 +544,8 @@
if (!this->varDeclarationsPrefix(&prefix)) {
return skstd::nullopt;
}
- return declaration_statements(this->varDeclarationEnd<DSLVar>(prefix.fPosition,
- prefix.fModifiers,
- prefix.fType,
- this->text(prefix.fName)),
- this->symbols());
+ return this->localVarDeclarationEnd(prefix.fPosition, prefix.fModifiers, prefix.fType,
+ this->text(prefix.fName));
}
/* STRUCT IDENTIFIER LBRACE varDeclaration* RBRACE */
@@ -586,8 +608,8 @@
}
Token name;
if (this->checkNext(Token::Kind::TK_IDENTIFIER, &name)) {
- return this->varDeclarationEnd<DSLGlobalVar>(this->position(name), modifiers,
- std::move(*type), this->text(name));
+ this->globalVarDeclarationEnd(this->position(name), modifiers, std::move(*type),
+ this->text(name));
}
this->expect(Token::Kind::TK_SEMICOLON, "';'");
return {};
diff --git a/src/sksl/SkSLDSLParser.h b/src/sksl/SkSLDSLParser.h
index e283e69..485b0f9 100644
--- a/src/sksl/SkSLDSLParser.h
+++ b/src/sksl/SkSLDSLParser.h
@@ -163,14 +163,15 @@
SkTArray<dsl::DSLGlobalVar> structVarDeclaration(const dsl::DSLModifiers& modifiers);
- /* (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)? (COMMA IDENTIFER
- (LBRACKET expression? RBRACKET)* (EQ assignmentExpression)?)* SEMICOLON */
- template<class T>
- SkTArray<T> varDeclarationEnd(PositionInfo position, const dsl::DSLModifiers& mods,
- dsl::DSLType baseType, skstd::string_view name);
+ bool parseArrayDimensions(int offset, dsl::DSLType* type);
- SkTArray<dsl::DSLGlobalVar> globalVarDeclarationEnd(const dsl::DSLModifiers& modifiers,
- dsl::DSLType type, skstd::string_view name);
+ bool parseInitializer(int offset, dsl::DSLExpression* initializer);
+
+ void globalVarDeclarationEnd(PositionInfo position, const dsl::DSLModifiers& mods,
+ dsl::DSLType baseType, skstd::string_view name);
+
+ skstd::optional<dsl::DSLStatement> localVarDeclarationEnd(PositionInfo position,
+ const dsl::DSLModifiers& mods, dsl::DSLType baseType, skstd::string_view name);
skstd::optional<dsl::DSLWrapper<dsl::DSLParameter>> parameter();
diff --git a/src/sksl/dsl/priv/DSLWriter.cpp b/src/sksl/dsl/priv/DSLWriter.cpp
index 4372ecd..f753c68 100644
--- a/src/sksl/dsl/priv/DSLWriter.cpp
+++ b/src/sksl/dsl/priv/DSLWriter.cpp
@@ -135,6 +135,21 @@
}
#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+void DSLWriter::AddVarDeclaration(DSLStatement& existing, DSLVar& additional) {
+ if (existing.fStatement->is<Block>()) {
+ SkSL::Block& block = existing.fStatement->as<Block>();
+ SkASSERT(!block.isScope());
+ block.children().push_back(Declare(additional).release());
+ } else {
+ SkASSERT(existing.fStatement->is<VarDeclaration>());
+ StatementArray stmts;
+ stmts.reserve_back(2);
+ stmts.push_back(std::move(existing.fStatement));
+ stmts.push_back(Declare(additional).release());
+ existing.fStatement = SkSL::Block::MakeUnscoped(/*offset=*/-1, std::move(stmts));
+ }
+}
+
std::unique_ptr<SkSL::Expression> DSLWriter::Call(const FunctionDeclaration& function,
ExpressionArray arguments,
PositionInfo pos) {
diff --git a/src/sksl/dsl/priv/DSLWriter.h b/src/sksl/dsl/priv/DSLWriter.h
index b57bda8..ee981d3 100644
--- a/src/sksl/dsl/priv/DSLWriter.h
+++ b/src/sksl/dsl/priv/DSLWriter.h
@@ -190,6 +190,13 @@
}
#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+ /**
+ * Adds a new declaration into an existing declaration statement. This either turns the original
+ * declaration into an unscoped block or, if it already was, appends a new statement to the end
+ * of it.
+ */
+ static void AddVarDeclaration(DSLStatement& existing, DSLVar& additional);
+
static std::unique_ptr<SkSL::Expression> Call(const FunctionDeclaration& function,
ExpressionArray arguments,
PositionInfo pos = PositionInfo::Capture());