Add DSL PossibleExpression & PossibleStatement

These are currently unused, but in future CLs they will be used to
capture line number information in DSL error handling.

Change-Id: Ieee730e0ad8323043437972fedb5bec471c367e4
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/375069
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/dsl/DSLCore.h b/src/sksl/dsl/DSLCore.h
index d2fe16c..e97fb4f 100644
--- a/src/sksl/dsl/DSLCore.h
+++ b/src/sksl/dsl/DSLCore.h
@@ -32,16 +32,6 @@
 using namespace SkSL::SwizzleComponent;
 
 /**
- * Class which is notified in the event of an error.
- */
-class ErrorHandler {
-public:
-    virtual ~ErrorHandler() {}
-
-    virtual void handleError(const char* msg) = 0;
-};
-
-/**
  * Starts DSL output on the current thread using the specified compiler. This must be called
  * prior to any other DSL functions.
  */
diff --git a/src/sksl/dsl/DSLErrorHandling.h b/src/sksl/dsl/DSLErrorHandling.h
new file mode 100644
index 0000000..f0dfcfa
--- /dev/null
+++ b/src/sksl/dsl/DSLErrorHandling.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2021 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_ERROR_HANDLING
+#define SKSL_DSL_ERROR_HANDLING
+
+namespace SkSL {
+
+namespace dsl {
+
+class PositionInfo {
+public:
+#if defined(__GNUC__) || defined(__clang__)
+    PositionInfo(const char* file = __builtin_FILE(), int line = __builtin_LINE())
+#else
+    PositionInfo(const char* file = nullptr, int line = -1)
+#endif // defined(__GNUC__) || defined(__clang__)
+        : fFile(file)
+        , fLine(line) {}
+
+    const char* file_name() {
+        return fFile;
+    }
+
+    int line() {
+        return fLine;
+    }
+
+private:
+    const char* fFile;
+    int fLine;
+};
+
+/**
+ * Class which is notified in the event of an error.
+ */
+class ErrorHandler {
+public:
+    virtual ~ErrorHandler() {}
+
+    /**
+     * Reports a DSL error. Position may not be available, in which case it will be null.
+     */
+    virtual void handleError(const char* msg, PositionInfo* position) = 0;
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLExpression.cpp b/src/sksl/dsl/DSLExpression.cpp
index 780d22f..5f7cb74 100644
--- a/src/sksl/dsl/DSLExpression.cpp
+++ b/src/sksl/dsl/DSLExpression.cpp
@@ -66,6 +66,14 @@
                                                         var.var(),
                                                         SkSL::VariableReference::RefKind::kRead)) {}
 
+DSLExpression::DSLExpression(DSLPossibleExpression expr, PositionInfo pos) {
+    if (DSLWriter::Compiler().errorCount()) {
+        DSLWriter::ReportError(DSLWriter::Compiler().errorText(/*showCount=*/false).c_str(), &pos);
+        DSLWriter::Compiler().setErrorCount(0);
+    }
+    fExpression = std::move(expr.fExpression);
+}
+
 DSLExpression::~DSLExpression() {
 #if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
     if (fExpression && DSLWriter::InFragmentProcessor()) {
@@ -189,6 +197,92 @@
     return DSLWriter::Coerce(this->release(), type).release();
 }
 
+DSLPossibleExpression::DSLPossibleExpression(std::unique_ptr<SkSL::Expression> expr)
+    : fExpression(std::move(expr)) {}
+
+DSLPossibleExpression::~DSLPossibleExpression() {
+    if (fExpression) {
+        // this handles incorporating the expression into the output tree
+        DSLExpression(std::move(fExpression));
+    }
+}
+
+DSLExpression DSLPossibleExpression::x(PositionInfo pos) {
+    return DSLExpression(this->release()).x();
+}
+
+DSLExpression DSLPossibleExpression::y(PositionInfo pos) {
+    return DSLExpression(this->release()).y();
+}
+
+DSLExpression DSLPossibleExpression::z(PositionInfo pos) {
+    return DSLExpression(this->release()).z();
+}
+
+DSLExpression DSLPossibleExpression::w(PositionInfo pos) {
+    return DSLExpression(this->release()).w();
+}
+
+DSLExpression DSLPossibleExpression::r(PositionInfo pos) {
+    return DSLExpression(this->release()).r();
+}
+
+DSLExpression DSLPossibleExpression::g(PositionInfo pos) {
+    return DSLExpression(this->release()).g();
+}
+
+DSLExpression DSLPossibleExpression::b(PositionInfo pos) {
+    return DSLExpression(this->release()).b();
+}
+
+DSLExpression DSLPossibleExpression::a(PositionInfo pos) {
+    return DSLExpression(this->release()).a();
+}
+
+DSLExpression DSLPossibleExpression::field(const char* name, PositionInfo pos) {
+    return DSLExpression(this->release()).field(name);
+}
+
+DSLExpression DSLPossibleExpression::operator=(const DSLVar& var) {
+    return this->operator=(DSLExpression(var));
+}
+
+DSLExpression DSLPossibleExpression::operator=(DSLExpression expr) {
+    return DSLExpression(this->release()) = std::move(expr);
+}
+
+DSLExpression DSLPossibleExpression::operator=(int expr) {
+    return this->operator=(DSLExpression(expr));
+}
+
+DSLExpression DSLPossibleExpression::operator=(float expr) {
+    return this->operator=(DSLExpression(expr));
+}
+
+DSLExpression DSLPossibleExpression::operator[](DSLExpression index) {
+    return DSLExpression(this->release())[std::move(index)];
+}
+
+DSLExpression DSLPossibleExpression::operator++() {
+    return ++DSLExpression(this->release());
+}
+
+DSLExpression DSLPossibleExpression::operator++(int) {
+    return DSLExpression(this->release())++;
+}
+
+DSLExpression DSLPossibleExpression::operator--() {
+    return --DSLExpression(this->release());
+}
+
+DSLExpression DSLPossibleExpression::operator--(int) {
+    return DSLExpression(this->release())--;
+}
+
+std::unique_ptr<SkSL::Expression> DSLPossibleExpression::release() {
+    return std::move(fExpression);
+}
+
 } // namespace dsl
 
 } // namespace SkSL
diff --git a/src/sksl/dsl/DSLExpression.h b/src/sksl/dsl/DSLExpression.h
index a9dc428..4cc0b89 100644
--- a/src/sksl/dsl/DSLExpression.h
+++ b/src/sksl/dsl/DSLExpression.h
@@ -9,6 +9,7 @@
 #define SKSL_DSL_EXPRESSION
 
 #include "include/core/SkTypes.h"
+#include "src/sksl/dsl/DSLErrorHandling.h"
 #include "src/sksl/ir/SkSLIRNode.h"
 
 #include <cstdint>
@@ -20,6 +21,7 @@
 
 namespace dsl {
 
+class DSLPossibleExpression;
 class DSLStatement;
 class DSLVar;
 
@@ -60,6 +62,8 @@
      */
     DSLExpression(const DSLVar& var);
 
+    DSLExpression(DSLPossibleExpression expr, PositionInfo pos = PositionInfo());
+
     ~DSLExpression();
 
     /**
@@ -116,6 +120,7 @@
 
     friend class DSLCore;
     friend class DSLFunction;
+    friend class DSLPossibleExpression;
     friend class DSLVar;
     friend class DSLWriter;
 };
@@ -156,6 +161,67 @@
 DSLExpression operator--(DSLExpression expr);
 DSLExpression operator--(DSLExpression expr, int);
 
+/**
+ * Represents an Expression which may have failed and/or have pending errors to report. Converting a
+ * PossibleExpression into an Expression requires PositionInfo so that any pending errors can be
+ * reported at the correct position.
+ *
+ * PossibleExpression is used instead of Expression in situations where it is not possible to
+ * capture the PositionInfo at the time of Expression construction (notably in operator overloads,
+ * where we cannot add default parameters).
+ */
+class DSLPossibleExpression {
+public:
+    DSLPossibleExpression(std::unique_ptr<SkSL::Expression> expression);
+
+    DSLPossibleExpression(DSLPossibleExpression&& other) = default;
+
+    ~DSLPossibleExpression();
+
+    DSLExpression x(PositionInfo pos = PositionInfo());
+
+    DSLExpression y(PositionInfo pos = PositionInfo());
+
+    DSLExpression z(PositionInfo pos = PositionInfo());
+
+    DSLExpression w(PositionInfo pos = PositionInfo());
+
+    DSLExpression r(PositionInfo pos = PositionInfo());
+
+    DSLExpression g(PositionInfo pos = PositionInfo());
+
+    DSLExpression b(PositionInfo pos = PositionInfo());
+
+    DSLExpression a(PositionInfo pos = PositionInfo());
+
+    DSLExpression field(const char* name, PositionInfo pos = PositionInfo());
+
+    DSLExpression operator=(const DSLVar& var);
+
+    DSLExpression operator=(DSLExpression expr);
+
+    DSLExpression operator=(int expr);
+
+    DSLExpression operator=(float expr);
+
+    DSLExpression operator[](DSLExpression index);
+
+    DSLExpression operator++();
+
+    DSLExpression operator++(int);
+
+    DSLExpression operator--();
+
+    DSLExpression operator--(int);
+
+    std::unique_ptr<SkSL::Expression> release();
+
+private:
+    std::unique_ptr<SkSL::Expression> fExpression;
+
+    friend class DSLExpression;
+};
+
 } // namespace dsl
 
 } // namespace SkSL
diff --git a/src/sksl/dsl/DSLStatement.cpp b/src/sksl/dsl/DSLStatement.cpp
index 55c4b5e..9f9862c 100644
--- a/src/sksl/dsl/DSLStatement.cpp
+++ b/src/sksl/dsl/DSLStatement.cpp
@@ -42,6 +42,17 @@
     }
 }
 
+DSLStatement::DSLStatement(DSLPossibleExpression expr, PositionInfo pos)
+    : DSLStatement(DSLExpression(std::move(expr), pos)) {}
+
+DSLStatement::DSLStatement(DSLPossibleStatement stmt, PositionInfo pos) {
+    if (DSLWriter::Compiler().errorCount()) {
+        DSLWriter::ReportError(DSLWriter::Compiler().errorText(/*showCount=*/false).c_str(), &pos);
+        DSLWriter::Compiler().setErrorCount(0);
+    }
+    fStatement = std::move(stmt.fStatement);
+}
+
 DSLStatement::~DSLStatement() {
 #if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
     if (fStatement && DSLWriter::InFragmentProcessor()) {
@@ -52,6 +63,16 @@
     SkASSERTF(!fStatement, "Statement destroyed without being incorporated into program");
 }
 
+DSLPossibleStatement::DSLPossibleStatement(std::unique_ptr<SkSL::Statement> statement)
+    : fStatement(std::move(statement)) {}
+
+DSLPossibleStatement::~DSLPossibleStatement() {
+    if (fStatement) {
+        // this handles incorporating the expression into the output tree
+        DSLStatement(std::move(fStatement));
+    }
+}
+
 } // namespace dsl
 
 } // namespace SkSL
diff --git a/src/sksl/dsl/DSLStatement.h b/src/sksl/dsl/DSLStatement.h
index b03dc095..941674f 100644
--- a/src/sksl/dsl/DSLStatement.h
+++ b/src/sksl/dsl/DSLStatement.h
@@ -10,6 +10,7 @@
 
 #include "include/core/SkString.h"
 #include "include/core/SkTypes.h"
+#include "src/sksl/dsl/DSLErrorHandling.h"
 
 #include <memory>
 
@@ -24,6 +25,8 @@
 
 class DSLBlock;
 class DSLExpression;
+class DSLPossibleExpression;
+class DSLPossibleStatement;
 class DSLVar;
 
 class DSLStatement {
@@ -32,6 +35,10 @@
 
     DSLStatement(DSLExpression expr);
 
+    DSLStatement(DSLPossibleExpression expr, PositionInfo pos = PositionInfo());
+
+    DSLStatement(DSLPossibleStatement stmt, PositionInfo pos = PositionInfo());
+
     DSLStatement(DSLBlock block);
 
     DSLStatement(DSLStatement&&) = default;
@@ -52,9 +59,36 @@
     friend class DSLBlock;
     friend class DSLCore;
     friend class DSLExpression;
+    friend class DSLPossibleStatement;
     friend class DSLWriter;
 };
 
+/**
+ * Represents a Statement which may have failed and/or have pending errors to report. Converting a
+ * PossibleStatement into a Statement requires PositionInfo so that any pending errors can be
+ * reported at the correct position.
+ *
+ * PossibleStatement is used instead of Statement in situations where it is not possible to capture
+ * the PositionInfo at the time of Statement construction.
+ */
+class DSLPossibleStatement {
+public:
+    DSLPossibleStatement(std::unique_ptr<SkSL::Statement> stmt);
+
+    DSLPossibleStatement(DSLPossibleStatement&& other) = default;
+
+    ~DSLPossibleStatement();
+
+    std::unique_ptr<SkSL::Statement> release() {
+        return std::move(fStatement);
+    }
+
+private:
+    std::unique_ptr<SkSL::Statement> fStatement;
+
+    friend class DSLStatement;
+};
+
 } // namespace dsl
 
 } // namespace SkSL
diff --git a/src/sksl/dsl/priv/DSLWriter.cpp b/src/sksl/dsl/priv/DSLWriter.cpp
index 2495e19..66f7079 100644
--- a/src/sksl/dsl/priv/DSLWriter.cpp
+++ b/src/sksl/dsl/priv/DSLWriter.cpp
@@ -141,10 +141,15 @@
                                  IRGenerator().fSymbolTable);
 }
 
-
-void DSLWriter::ReportError(const char* msg) {
+void DSLWriter::ReportError(const char* msg, PositionInfo* info) {
+    if (info && !info->file_name()) {
+        info = nullptr;
+    }
     if (Instance().fErrorHandler) {
-        Instance().fErrorHandler->handleError(msg);
+        Instance().fErrorHandler->handleError(msg, info);
+    } else if (info) {
+        SK_ABORT("%s: %d: %sNo SkSL DSL error handler configured, treating this as a fatal error\n",
+                 info->file_name(), info->line(), msg);
     } else {
         SK_ABORT("%sNo SkSL DSL error handler configured, treating this as a fatal error\n", msg);
     }
diff --git a/src/sksl/dsl/priv/DSLWriter.h b/src/sksl/dsl/priv/DSLWriter.h
index ee0597c..8a46b85 100644
--- a/src/sksl/dsl/priv/DSLWriter.h
+++ b/src/sksl/dsl/priv/DSLWriter.h
@@ -167,7 +167,7 @@
      * Notifies the current ErrorHandler that a DSL error has occurred. With a null ErrorHandler
      * (the default), any errors will be dumped to stderr and a fatal exception will be generated.
      */
-    static void ReportError(const char* msg);
+    static void ReportError(const char* msg, PositionInfo* info = nullptr);
 
     /**
      * Returns whether name mangling is enabled. This should always be enabled outside of tests.
diff --git a/tests/SkSLDSLTest.cpp b/tests/SkSLDSLTest.cpp
index 0fde847..0055f78 100644
--- a/tests/SkSLDSLTest.cpp
+++ b/tests/SkSLDSLTest.cpp
@@ -44,7 +44,7 @@
         SetErrorHandler(nullptr);
     }
 
-    void handleError(const char* msg) override {
+    void handleError(const char* msg, PositionInfo* pos) override {
         REPORTER_ASSERT(fReporter, !strcmp(msg, fMsg),
                         "Error mismatch: expected:\n%sbut received:\n%s", fMsg, msg);
         fMsg = nullptr;