Add ProgramWriter, a non-const version of ProgramVisitor.

This allows us to traverse a program's hierarchy and make changes (as
long as the structure remains intact). It's the caller's responsibility
to make sure they don't invalidate any iterators of the ProgramWriter.

Change-Id: Icfc651134d916e19b92004c92fe09880bb96600b
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/320717
Commit-Queue: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index 46e9a90..6ad95a7 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -338,8 +338,9 @@
 ////////////////////////////////////////////////////////////////////////////////
 // ProgramVisitor
 
-bool ProgramVisitor::visit(const Program& program) {
-    for (const ProgramElement& pe : program) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visit(PROG program) {
+    for (ELEM pe : program) {
         if (this->visitProgramElement(pe)) {
             return true;
         }
@@ -347,8 +348,9 @@
     return false;
 }
 
-bool ProgramVisitor::visitExpression(const Expression& e) {
-    switch(e.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitExpression(EXPR e) {
+    switch (e.kind()) {
         case Expression::Kind::kBoolLiteral:
         case Expression::Kind::kDefined:
         case Expression::Kind::kExternalValue:
@@ -361,50 +363,61 @@
         case Expression::Kind::kVariableReference:
             // Leaf expressions return false
             return false;
+
         case Expression::Kind::kBinary: {
-            const BinaryExpression& b = e.as<BinaryExpression>();
-            return this->visitExpression(b.left()) || this->visitExpression(b.right()); }
+            auto& b = e.template as<BinaryExpression>();
+            return this->visitExpression(b.left()) || this->visitExpression(b.right());
+        }
         case Expression::Kind::kConstructor: {
-            const Constructor& c = e.as<Constructor>();
-            for (const auto& arg : c.arguments()) {
+            auto& c = e.template as<Constructor>();
+            for (auto& arg : c.arguments()) {
                 if (this->visitExpression(*arg)) { return true; }
             }
-            return false; }
+            return false;
+        }
         case Expression::Kind::kExternalFunctionCall: {
-            const ExternalFunctionCall& c = e.as<ExternalFunctionCall>();
-            for (const auto& arg : c.fArguments) {
+            auto& c = e.template as<ExternalFunctionCall>();
+            for (auto& arg : c.fArguments) {
                 if (this->visitExpression(*arg)) { return true; }
             }
-            return false; }
+            return false;
+        }
         case Expression::Kind::kFieldAccess:
-            return this->visitExpression(*e.as<FieldAccess>().fBase);
+            return this->visitExpression(*e.template as<FieldAccess>().fBase);
+
         case Expression::Kind::kFunctionCall: {
-            const FunctionCall& c = e.as<FunctionCall>();
-            for (const auto& arg : c.fArguments) {
+            auto& c = e.template as<FunctionCall>();
+            for (auto& arg : c.fArguments) {
                 if (this->visitExpression(*arg)) { return true; }
             }
-            return false; }
+            return false;
+        }
         case Expression::Kind::kIndex: {
-            const IndexExpression& i = e.as<IndexExpression>();
-            return this->visitExpression(*i.fBase) || this->visitExpression(*i.fIndex); }
+            auto& i = e.template as<IndexExpression>();
+            return this->visitExpression(*i.fBase) || this->visitExpression(*i.fIndex);
+        }
         case Expression::Kind::kPostfix:
-            return this->visitExpression(*e.as<PostfixExpression>().fOperand);
+            return this->visitExpression(*e.template as<PostfixExpression>().fOperand);
+
         case Expression::Kind::kPrefix:
-            return this->visitExpression(*e.as<PrefixExpression>().fOperand);
+            return this->visitExpression(*e.template as<PrefixExpression>().fOperand);
+
         case Expression::Kind::kSwizzle:
-            return this->visitExpression(*e.as<Swizzle>().fBase);
+            return this->visitExpression(*e.template as<Swizzle>().fBase);
+
         case Expression::Kind::kTernary: {
-            const TernaryExpression& t = e.as<TernaryExpression>();
-            return this->visitExpression(*t.fTest) ||
-                   this->visitExpression(*t.fIfTrue) ||
-                   this->visitExpression(*t.fIfFalse); }
+            auto& t = e.template as<TernaryExpression>();
+            return this->visitExpression(*t.fTest) || this->visitExpression(*t.fIfTrue) ||
+                   this->visitExpression(*t.fIfFalse);
+        }
         default:
             SkUNREACHABLE;
     }
 }
 
-bool ProgramVisitor::visitStatement(const Statement& s) {
-    switch(s.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitStatement(STMT s) {
+    switch (s.kind()) {
         case Statement::Kind::kBreak:
         case Statement::Kind::kContinue:
         case Statement::Kind::kDiscard:
@@ -412,81 +425,114 @@
         case Statement::Kind::kNop:
             // Leaf statements just return false
             return false;
+
         case Statement::Kind::kBlock:
-            for (const std::unique_ptr<Statement>& stmt : s.as<Block>().children()) {
+            for (auto& stmt : s.template as<Block>().children()) {
                 if (this->visitStatement(*stmt)) {
                     return true;
                 }
             }
             return false;
+
         case Statement::Kind::kDo: {
-            const DoStatement& d = s.as<DoStatement>();
-            return this->visitExpression(*d.test()) || this->visitStatement(*d.statement()); }
+            auto& d = s.template as<DoStatement>();
+            return this->visitExpression(*d.test()) || this->visitStatement(*d.statement());
+        }
         case Statement::Kind::kExpression:
-            return this->visitExpression(*s.as<ExpressionStatement>().expression());
+            return this->visitExpression(*s.template as<ExpressionStatement>().expression());
+
         case Statement::Kind::kFor: {
-            const ForStatement& f = s.as<ForStatement>();
+            auto& f = s.template as<ForStatement>();
             return (f.fInitializer && this->visitStatement(*f.fInitializer)) ||
                    (f.fTest && this->visitExpression(*f.fTest)) ||
                    (f.fNext && this->visitExpression(*f.fNext)) ||
-                   this->visitStatement(*f.fStatement); }
+                   this->visitStatement(*f.fStatement);
+        }
         case Statement::Kind::kIf: {
-            const IfStatement& i = s.as<IfStatement>();
+            auto& i = s.template as<IfStatement>();
             return this->visitExpression(*i.fTest) ||
                    this->visitStatement(*i.fIfTrue) ||
-                   (i.fIfFalse && this->visitStatement(*i.fIfFalse)); }
+                   (i.fIfFalse && this->visitStatement(*i.fIfFalse));
+        }
         case Statement::Kind::kReturn: {
-            const ReturnStatement& r = s.as<ReturnStatement>();
-            return r.fExpression && this->visitExpression(*r.fExpression); }
+            auto& r = s.template as<ReturnStatement>();
+            return r.fExpression && this->visitExpression(*r.fExpression);
+        }
         case Statement::Kind::kSwitch: {
-            const SwitchStatement& sw = s.as<SwitchStatement>();
-            if (this->visitExpression(*sw.fValue)) { return true; }
-            for (const auto& c : sw.fCases) {
-                if (c->fValue && this->visitExpression(*c->fValue)) { return true; }
-                for (const std::unique_ptr<Statement>& st : c->fStatements) {
-                    if (this->visitStatement(*st)) { return true; }
+            auto& sw = s.template as<SwitchStatement>();
+            if (this->visitExpression(*sw.fValue)) {
+                return true;
+            }
+            for (auto& c : sw.fCases) {
+                if (c->fValue && this->visitExpression(*c->fValue)) {
+                    return true;
+                }
+                for (auto& st : c->fStatements) {
+                    if (this->visitStatement(*st)) {
+                        return true;
+                    }
                 }
             }
-            return false; }
+            return false;
+        }
         case Statement::Kind::kVarDeclaration: {
-            const VarDeclaration& v = s.as<VarDeclaration>();
-            for (const std::unique_ptr<Expression>& sizeExpr : v.fSizes) {
-                if (sizeExpr && this->visitExpression(*sizeExpr)) { return true; }
+            auto& v = s.template as<VarDeclaration>();
+            for (auto& sizeExpr : v.fSizes) {
+                if (sizeExpr && this->visitExpression(*sizeExpr)) {
+                    return true;
+                }
             }
-            return v.fValue && this->visitExpression(*v.fValue); }
+            return v.fValue && this->visitExpression(*v.fValue);
+        }
         case Statement::Kind::kVarDeclarations:
-            return this->visitProgramElement(*s.as<VarDeclarationsStatement>().fDeclaration);
+            return this->visitProgramElement(
+                    *s.template as<VarDeclarationsStatement>().fDeclaration);
+
         case Statement::Kind::kWhile: {
-            const WhileStatement& w = s.as<WhileStatement>();
-            return this->visitExpression(*w.fTest) || this->visitStatement(*w.fStatement); }
+            auto& w = s.template as<WhileStatement>();
+            return this->visitExpression(*w.fTest) || this->visitStatement(*w.fStatement);
+        }
         default:
             SkUNREACHABLE;
     }
 }
 
-bool ProgramVisitor::visitProgramElement(const ProgramElement& pe) {
-    switch(pe.kind()) {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+bool TProgramVisitor<PROG, EXPR, STMT, ELEM>::visitProgramElement(ELEM pe) {
+    switch (pe.kind()) {
         case ProgramElement::Kind::kEnum:
         case ProgramElement::Kind::kExtension:
         case ProgramElement::Kind::kModifiers:
         case ProgramElement::Kind::kSection:
             // Leaf program elements just return false by default
             return false;
+
         case ProgramElement::Kind::kFunction:
-            return this->visitStatement(*pe.as<FunctionDefinition>().fBody);
+            return this->visitStatement(*pe.template as<FunctionDefinition>().fBody);
+
         case ProgramElement::Kind::kInterfaceBlock:
-            for (const auto& e : pe.as<InterfaceBlock>().fSizes) {
-                if (this->visitExpression(*e)) { return true; }
+            for (auto& e : pe.template as<InterfaceBlock>().fSizes) {
+                if (this->visitExpression(*e)) {
+                    return true;
+                }
             }
             return false;
+
         case ProgramElement::Kind::kVar:
-            for (const auto& v : pe.as<VarDeclarations>().fVars) {
-                if (this->visitStatement(*v)) { return true; }
+            for (auto& v : pe.template as<VarDeclarations>().fVars) {
+                if (this->visitStatement(*v)) {
+                    return true;
+                }
             }
             return false;
+
         default:
             SkUNREACHABLE;
     }
 }
 
+template class TProgramVisitor<const Program&, const Expression&,
+                               const Statement&, const ProgramElement&>;
+template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+
 }  // namespace SkSL
diff --git a/src/sksl/SkSLAnalysis.h b/src/sksl/SkSLAnalysis.h
index b9629ee..95c7ce1 100644
--- a/src/sksl/SkSLAnalysis.h
+++ b/src/sksl/SkSLAnalysis.h
@@ -57,18 +57,35 @@
  * stack.
  */
 
-class ProgramVisitor {
+template <typename PROG, typename EXPR, typename STMT, typename ELEM>
+class TProgramVisitor {
 public:
-    virtual ~ProgramVisitor() = default;
+    virtual ~TProgramVisitor() = default;
 
-    bool visit(const Program&);
+    bool visit(PROG program);
 
 protected:
-    virtual bool visitExpression(const Expression&);
-    virtual bool visitStatement(const Statement&);
-    virtual bool visitProgramElement(const ProgramElement&);
+    virtual bool visitExpression(EXPR expression);
+    virtual bool visitStatement(STMT statement);
+    virtual bool visitProgramElement(ELEM programElement);
 };
 
+// Squelch bogus Clang warning about template vtables: https://bugs.llvm.org/show_bug.cgi?id=18733
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wweak-template-vtables"
+#endif
+extern template class TProgramVisitor<const Program&, const Expression&,
+                                      const Statement&, const ProgramElement&>;
+extern template class TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+
+using ProgramVisitor = TProgramVisitor<const Program&, const Expression&,
+                                       const Statement&, const ProgramElement&>;
+using ProgramWriter = TProgramVisitor<Program&, Expression&, Statement&, ProgramElement&>;
+
 }  // namespace SkSL
 
 #endif
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 9062d34..5f3f026 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -222,13 +222,12 @@
 std::unique_ptr<Expression> clone_with_ref_kind(const Expression& expr,
                                                 VariableReference::RefKind refKind) {
     std::unique_ptr<Expression> clone = expr.clone();
-    class SetRefKindInExpression : public ProgramVisitor {
+    class SetRefKindInExpression : public ProgramWriter {
     public:
         SetRefKindInExpression(VariableReference::RefKind refKind) : fRefKind(refKind) {}
-        bool visitExpression(const Expression& expr) override {
+        bool visitExpression(Expression& expr) override {
             if (expr.is<VariableReference>()) {
-                // TODO: create a const-savvy ProgramVisitor and remove const_cast
-                const_cast<VariableReference&>(expr.as<VariableReference>()).setRefKind(fRefKind);
+                expr.as<VariableReference>().setRefKind(fRefKind);
             }
             return INHERITED::visitExpression(expr);
         }
@@ -236,7 +235,7 @@
     private:
         VariableReference::RefKind fRefKind;
 
-        using INHERITED = ProgramVisitor;
+        using INHERITED = ProgramWriter;
     };
 
     SetRefKindInExpression{refKind}.visitExpression(*clone);