Add ProgramWriter, a non-const version of ProgramVisitor.
This allows us to traverse a program's hierarchy and make changes (as
long as the structure remains intact). It's the caller's responsibility
to make sure they don't invalidate any iterators of the ProgramWriter.
Change-Id: Icfc651134d916e19b92004c92fe09880bb96600b
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/320717
Commit-Queue: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index 46e9a90..6ad95a7 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -338,8 +338,9 @@
////////////////////////////////////////////////////////////////////////////////
// ProgramVisitor
-bool ProgramVisitor::visit(const Program& program) {
- for (const ProgramElement& pe : program) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visit(PROG program) {
+ for (ELEM pe : program) {
if (this->visitProgramElement(pe)) {
return true;
}
@@ -347,8 +348,9 @@
return false;
}
-bool ProgramVisitor::visitExpression(const Expression& e) {
- switch(e.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
+ switch (e.kind()) {
case Expression::Kind::kBoolLiteral:
case Expression::Kind::kDefined:
case Expression::Kind::kExternalValue:
@@ -361,50 +363,61 @@
case Expression::Kind::kVariableReference:
// Leaf expressions return false
return false;
+
case Expression::Kind::kBinary: {
- const BinaryExpression& b = e.as<BinaryExpression>();
- return this->visitExpression(b.left()) || this->visitExpression(b.right()); }
+ auto& b = e.template as<BinaryExpression>();
+ return this->visitExpression(b.left()) || this->visitExpression(b.right());
+ }
case Expression::Kind::kConstructor: {
- const Constructor& c = e.as<Constructor>();
- for (const auto& arg : c.arguments()) {
+ auto& c = e.template as<Constructor>();
+ for (auto& arg : c.arguments()) {
if (this->visitExpression(*arg)) { return true; }
}
- return false; }
+ return false;
+ }
case Expression::Kind::kExternalFunctionCall: {
- const ExternalFunctionCall& c = e.as<ExternalFunctionCall>();
- for (const auto& arg : c.fArguments) {
+ auto& c = e.template as<ExternalFunctionCall>();
+ for (auto& arg : c.fArguments) {
if (this->visitExpression(*arg)) { return true; }
}
- return false; }
+ return false;
+ }
case Expression::Kind::kFieldAccess:
- return this->visitExpression(*e.as<FieldAccess>().fBase);
+ return this->visitExpression(*e.template as<FieldAccess>().fBase);
+
case Expression::Kind::kFunctionCall: {
- const FunctionCall& c = e.as<FunctionCall>();
- for (const auto& arg : c.fArguments) {
+ auto& c = e.template as<FunctionCall>();
+ for (auto& arg : c.fArguments) {
if (this->visitExpression(*arg)) { return true; }
}
- return false; }
+ return false;
+ }
case Expression::Kind::kIndex: {
- const IndexExpression& i = e.as<IndexExpression>();
- return this->visitExpression(*i.fBase) || this->visitExpression(*i.fIndex); }
+ auto& i = e.template as<IndexExpression>();
+ return this->visitExpression(*i.fBase) || this->visitExpression(*i.fIndex);
+ }
case Expression::Kind::kPostfix:
- return this->visitExpression(*e.as<PostfixExpression>().fOperand);
+ return this->visitExpression(*e.template as<PostfixExpression>().fOperand);
+
case Expression::Kind::kPrefix:
- return this->visitExpression(*e.as<PrefixExpression>().fOperand);
+ return this->visitExpression(*e.template as<PrefixExpression>().fOperand);
+
case Expression::Kind::kSwizzle:
- return this->visitExpression(*e.as<Swizzle>().fBase);
+ return this->visitExpression(*e.template as<Swizzle>().fBase);
+
case Expression::Kind::kTernary: {
- const TernaryExpression& t = e.as<TernaryExpression>();
- return this->visitExpression(*t.fTest) ||
- this->visitExpression(*t.fIfTrue) ||
- this->visitExpression(*t.fIfFalse); }
+ auto& t = e.template as<TernaryExpression>();
+ return this->visitExpression(*t.fTest) || this->visitExpression(*t.fIfTrue) ||
+ this->visitExpression(*t.fIfFalse);
+ }
default:
SkUNREACHABLE;
}
}
-bool ProgramVisitor::visitStatement(const Statement& s) {
- switch(s.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
+ switch (s.kind()) {
case Statement::Kind::kBreak:
case Statement::Kind::kContinue:
case Statement::Kind::kDiscard:
@@ -412,81 +425,114 @@
case Statement::Kind::kNop:
// Leaf statements just return false
return false;
+
case Statement::Kind::kBlock:
- for (const std::unique_ptr<Statement>& stmt : s.as<Block>().children()) {
+ for (auto& stmt : s.template as<Block>().children()) {
if (this->visitStatement(*stmt)) {
return true;
}
}
return false;
+
case Statement::Kind::kDo: {
- const DoStatement& d = s.as<DoStatement>();
- return this->visitExpression(*d.test()) || this->visitStatement(*d.statement()); }
+ auto& d = s.template as<DoStatement>();
+ return this->visitExpression(*d.test()) || this->visitStatement(*d.statement());
+ }
case Statement::Kind::kExpression:
- return this->visitExpression(*s.as<ExpressionStatement>().expression());
+ return this->visitExpression(*s.template as<ExpressionStatement>().expression());
+
case Statement::Kind::kFor: {
- const ForStatement& f = s.as<ForStatement>();
+ auto& f = s.template as<ForStatement>();
return (f.fInitializer && this->visitStatement(*f.fInitializer)) ||
(f.fTest && this->visitExpression(*f.fTest)) ||
(f.fNext && this->visitExpression(*f.fNext)) ||
- this->visitStatement(*f.fStatement); }
+ this->visitStatement(*f.fStatement);
+ }
case Statement::Kind::kIf: {
- const IfStatement& i = s.as<IfStatement>();
+ auto& i = s.template as<IfStatement>();
return this->visitExpression(*i.fTest) ||
this->visitStatement(*i.fIfTrue) ||
- (i.fIfFalse && this->visitStatement(*i.fIfFalse)); }
+ (i.fIfFalse && this->visitStatement(*i.fIfFalse));
+ }
case Statement::Kind::kReturn: {
- const ReturnStatement& r = s.as<ReturnStatement>();
- return r.fExpression && this->visitExpression(*r.fExpression); }
+ auto& r = s.template as<ReturnStatement>();
+ return r.fExpression && this->visitExpression(*r.fExpression);
+ }
case Statement::Kind::kSwitch: {
- const SwitchStatement& sw = s.as<SwitchStatement>();
- if (this->visitExpression(*sw.fValue)) { return true; }
- for (const auto& c : sw.fCases) {
- if (c->fValue && this->visitExpression(*c->fValue)) { return true; }
- for (const std::unique_ptr<Statement>& st : c->fStatements) {
- if (this->visitStatement(*st)) { return true; }
+ auto& sw = s.template as<SwitchStatement>();
+ if (this->visitExpression(*sw.fValue)) {
+ return true;
+ }
+ for (auto& c : sw.fCases) {
+ if (c->fValue && this->visitExpression(*c->fValue)) {
+ return true;
+ }
+ for (auto& st : c->fStatements) {
+ if (this->visitStatement(*st)) {
+ return true;
+ }
}
}
- return false; }
+ return false;
+ }
case Statement::Kind::kVarDeclaration: {
- const VarDeclaration& v = s.as<VarDeclaration>();
- for (const std::unique_ptr<Expression>& sizeExpr : v.fSizes) {
- if (sizeExpr && this->visitExpression(*sizeExpr)) { return true; }
+ auto& v = s.template as<VarDeclaration>();
+ for (auto& sizeExpr : v.fSizes) {
+ if (sizeExpr && this->visitExpression(*sizeExpr)) {
+ return true;
+ }
}
- return v.fValue && this->visitExpression(*v.fValue); }
+ return v.fValue && this->visitExpression(*v.fValue);
+ }
case Statement::Kind::kVarDeclarations:
- return this->visitProgramElement(*s.as<VarDeclarationsStatement>().fDeclaration);
+ return this->visitProgramElement(
+ *s.template as<VarDeclarationsStatement>().fDeclaration);
+
case Statement::Kind::kWhile: {
- const WhileStatement& w = s.as<WhileStatement>();
- return this->visitExpression(*w.fTest) || this->visitStatement(*w.fStatement); }
+ auto& w = s.template as<WhileStatement>();
+ return this->visitExpression(*w.fTest) || this->visitStatement(*w.fStatement);
+ }
default:
SkUNREACHABLE;
}
}
-bool ProgramVisitor::visitProgramElement(const ProgramElement& pe) {
- switch(pe.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
+ switch (pe.kind()) {
case ProgramElement::Kind::kEnum:
case ProgramElement::Kind::kExtension:
case ProgramElement::Kind::kModifiers:
case ProgramElement::Kind::kSection:
// Leaf program elements just return false by default
return false;
+
case ProgramElement::Kind::kFunction:
- return this->visitStatement(*pe.as<FunctionDefinition>().fBody);
+ return this->visitStatement(*pe.template as<FunctionDefinition>().fBody);
+
case ProgramElement::Kind::kInterfaceBlock:
- for (const auto& e : pe.as<InterfaceBlock>().fSizes) {
- if (this->visitExpression(*e)) { return true; }
+ for (auto& e : pe.template as<InterfaceBlock>().fSizes) {
+ if (this->visitExpression(*e)) {
+ return true;
+ }
}
return false;
+
case ProgramElement::Kind::kVar:
- for (const auto& v : pe.as<VarDeclarations>().fVars) {
- if (this->visitStatement(*v)) { return true; }
+ for (auto& v : pe.template as<VarDeclarations>().fVars) {
+ if (this->visitStatement(*v)) {
+ return true;
+ }
}
return false;
+
default:
SkUNREACHABLE;
}
}
+template class TProgramVisitor<const Program&, const Expression&,
+ const Statement&, const ProgramElement&>;
+template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+
} // namespace SkSL
diff --git a/src/sksl/SkSLAnalysis.h b/src/sksl/SkSLAnalysis.h
index b9629ee..95c7ce1 100644
--- a/src/sksl/SkSLAnalysis.h
+++ b/src/sksl/SkSLAnalysis.h
@@ -57,18 +57,35 @@
* stack.
*/
-class ProgramVisitor {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+class TProgramVisitor {
public:
- virtual ~ProgramVisitor() = default;
+ virtual ~TProgramVisitor() = default;
- bool visit(const Program&);
+ bool visit(PROG program);
protected:
- virtual bool visitExpression(const Expression&);
- virtual bool visitStatement(const Statement&);
- virtual bool visitProgramElement(const ProgramElement&);
+ virtual bool visitExpression(EXPR expression);
+ virtual bool visitStatement(STMT statement);
+ virtual bool visitProgramElement(ELEM programElement);
};
+// Squelch bogus Clang warning about template vtables: https://bugs.llvm.org/show_bug.cgi?id=18733
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wweak-template-vtables"
+#endif
+extern template class TProgramVisitor<const Program&, const Expression&,
+ const Statement&, const ProgramElement&>;
+extern template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+
+using ProgramVisitor = TProgramVisitor<const Program&, const Expression&,
+ const Statement&, const ProgramElement&>;
+using ProgramWriter = TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+
} // namespace SkSL
#endif
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 9062d34..5f3f026 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -222,13 +222,12 @@
std::unique_ptr<Expression> clone_with_ref_kind(const Expression& expr,
VariableReference::RefKind refKind) {
std::unique_ptr<Expression> clone = expr.clone();
- class SetRefKindInExpression : public ProgramVisitor {
+ class SetRefKindInExpression : public ProgramWriter {
public:
SetRefKindInExpression(VariableReference::RefKind refKind) : fRefKind(refKind) {}
- bool visitExpression(const Expression& expr) override {
+ bool visitExpression(Expression& expr) override {
if (expr.is<VariableReference>()) {
- // TODO: create a const-savvy ProgramVisitor and remove const_cast
- const_cast<VariableReference&>(expr.as<VariableReference>()).setRefKind(fRefKind);
+ expr.as<VariableReference>().setRefKind(fRefKind);
}
return INHERITED::visitExpression(expr);
}
@@ -236,7 +235,7 @@
private:
VariableReference::RefKind fRefKind;
- using INHERITED = ProgramVisitor;
+ using INHERITED = ProgramWriter;
};
SetRefKindInExpression{refKind}.visitExpression(*clone);