Replace getMatExpression with getConstantSubexpression.

This approach gives us similar flexibility but requires fewer lines of
code to get the same result. In a followup CL we will be able to
eliminate get[BFI]VecExpression as well. This approach will also scale
to arrays and structs if we want to support constant-folding on these.

Change-Id: Ib0034935926c7004f84ba62ddbdb3168df8ce91d
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393076
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index bb086eb..dbdd129 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -73,6 +73,11 @@
         return std::make_unique<BoolLiteral>(fOffset, this->value(), &this->type());
     }
 
+    const Expression* getConstantSubexpression(int n) const override {
+        SkASSERT(n == 0);
+        return this;
+    }
+
 private:
     bool fValue;
 
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 9cfa6fe..f489f80 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -175,65 +175,6 @@
     return std::make_unique<Constructor>(offset, type, std::move(args));
 }
 
-Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
-    if (other.is<ConstructorDiagonalMatrix>()) {
-        return other.compareConstant(*this);
-    }
-    if (other.is<ConstructorMatrixResize>()) {
-        return other.compareConstant(*this);
-    }
-    if (other.is<ConstructorSplat>()) {
-        return other.compareConstant(*this);
-    }
-    if (!other.is<Constructor>()) {
-        return ComparisonResult::kUnknown;
-    }
-    const Constructor& c = other.as<Constructor>();
-    const Type& myType = this->type();
-    SkASSERT(myType == c.type());
-
-    if (myType.isVector()) {
-        if (myType.componentType().isFloat()) {
-            for (int i = 0; i < myType.columns(); i++) {
-                if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
-                    return ComparisonResult::kNotEqual;
-                }
-            }
-            return ComparisonResult::kEqual;
-        }
-        if (myType.componentType().isInteger()) {
-            for (int i = 0; i < myType.columns(); i++) {
-                if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
-                    return ComparisonResult::kNotEqual;
-                }
-            }
-            return ComparisonResult::kEqual;
-        }
-        if (myType.componentType().isBoolean()) {
-            for (int i = 0; i < myType.columns(); i++) {
-                if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
-                    return ComparisonResult::kNotEqual;
-                }
-            }
-            return ComparisonResult::kEqual;
-        }
-    }
-
-    if (myType.isMatrix()) {
-        for (int col = 0; col < myType.columns(); col++) {
-            for (int row = 0; row < myType.rows(); row++) {
-                if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
-                    return ComparisonResult::kNotEqual;
-                }
-            }
-        }
-        return ComparisonResult::kEqual;
-    }
-
-    SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
-    return ComparisonResult::kUnknown;
-}
-
 template <typename ResultType>
 ResultType Constructor::getConstantValue(const Expression& expr) const {
     const Type& type = expr.type();
@@ -313,50 +254,6 @@
 template SKSL_FLOAT Constructor::getVecComponent(int) const;
 template bool Constructor::getVecComponent(int) const;
 
-SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
-    SkDEBUGCODE(const Type& myType = this->type();)
-    SkASSERT(this->isCompileTimeConstant());
-    SkASSERT(myType.isMatrix());
-    SkASSERT(col < myType.columns() && row < myType.rows());
-    if (this->arguments().size() == 1) {
-        const Type& argType = this->arguments()[0]->type();
-        if (argType.isScalar()) {
-            // single scalar argument, so matrix is of the form:
-            // x 0 0
-            // 0 x 0
-            // 0 0 x
-            // return x if col == row
-            return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
-        }
-        if (argType.isMatrix()) {
-            SkASSERT(this->arguments()[0]->isAnyConstructor());
-            // single matrix argument. make sure we're within the argument's bounds.
-            if (col < argType.columns() && row < argType.rows()) {
-                // within bounds, defer to argument
-                return this->arguments()[0]->getMatComponent(col, row);
-            }
-            // out of bounds
-            return 0.0;
-        }
-    }
-    int currentIndex = 0;
-    int targetIndex = col * this->type().rows() + row;
-    for (const auto& arg : this->arguments()) {
-        const Type& argType = arg->type();
-        SkASSERT(targetIndex >= currentIndex);
-        SkASSERT(argType.rows() == 1);
-        if (currentIndex + argType.columns() > targetIndex) {
-            if (argType.columns() == 1) {
-                return arg->getConstantFloat();
-            } else {
-                return arg->getFVecComponent(targetIndex - currentIndex);
-            }
-        }
-        currentIndex += argType.columns();
-    }
-    SK_ABORT("can't happen, matrix component out of bounds");
-}
-
 SKSL_INT Constructor::getConstantInt() const {
     // We're looking for scalar integer constructors only, i.e. `int(1)`.
     SkASSERT(this->arguments().size() == 1);
@@ -399,6 +296,45 @@
                                      (bool)expr.getConstantFloat();
 }
 
+const Expression* AnyConstructor::getConstantSubexpression(int n) const {
+    SkASSERT(n >= 0 && n < (int)this->type().slotCount());
+    for (const std::unique_ptr<Expression>& arg : this->argumentSpan()) {
+        int argSlots = arg->type().slotCount();
+        if (n < argSlots) {
+            return arg->getConstantSubexpression(n);
+        }
+        n -= argSlots;
+    }
+
+    SkDEBUGFAIL("argument-list slot count doesn't match constructor-type slot count");
+    return nullptr;
+}
+
+Expression::ComparisonResult AnyConstructor::compareConstant(const Expression& other) const {
+    ComparisonResult result = ComparisonResult::kEqual;
+    SkASSERT(this->type().slotCount() == other.type().slotCount());
+
+    int exprs = this->type().slotCount();
+    for (int n = 0; n < exprs; ++n) {
+        // Get the n'th subexpression from each side. If either one is null, return "unknown."
+        const Expression* left = this->getConstantSubexpression(n);
+        if (!left) {
+            return ComparisonResult::kUnknown;
+        }
+        const Expression* right = other.getConstantSubexpression(n);
+        if (!right) {
+            return ComparisonResult::kUnknown;
+        }
+        // Recurse into the subexpressions; the literal types will perform real comparisons, and
+        // most other expressions fall back on the base class Expression which returns unknown.
+        result = left->compareConstant(*right);
+        if (result != ComparisonResult::kEqual) {
+            break;
+        }
+    }
+    return result;
+}
+
 AnyConstructor& Expression::asAnyConstructor() {
     SkASSERT(this->isAnyConstructor());
     return static_cast<AnyConstructor&>(*this);
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index d7c6c3a..7ff91c9 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -71,6 +71,10 @@
         return true;
     }
 
+    const Expression* getConstantSubexpression(int n) const override;
+
+    ComparisonResult compareConstant(const Expression& other) const override;
+
 private:
     std::unique_ptr<Expression> fArgument;
 
@@ -183,8 +187,6 @@
         return std::make_unique<Constructor>(fOffset, this->type(), this->cloneArguments());
     }
 
-    ComparisonResult compareConstant(const Expression& other) const override;
-
     template <typename ResultType>
     ResultType getVecComponent(int index) const;
 
@@ -215,8 +217,6 @@
         return this->getVecComponent<bool>(n);
     }
 
-    SKSL_FLOAT getMatComponent(int col, int row) const override;
-
     SKSL_INT getConstantInt() const override;
 
     SKSL_FLOAT getConstantFloat() const override;
diff --git a/src/sksl/ir/SkSLConstructorArray.cpp b/src/sksl/ir/SkSLConstructorArray.cpp
index 4517b75..c305fb4 100644
--- a/src/sksl/ir/SkSLConstructorArray.cpp
+++ b/src/sksl/ir/SkSLConstructorArray.cpp
@@ -56,22 +56,4 @@
     return std::make_unique<ConstructorArray>(offset, type, std::move(args));
 }
 
-Expression::ComparisonResult ConstructorArray::compareConstant(const Expression& other) const {
-    // There is only one array-constructor type, so if this comparison had type-checked
-    // successfully, `other` should be a ConstructorArray with the same array size.
-    const ConstructorArray& otherArray = other.as<ConstructorArray>();
-    int numColumns = this->type().columns();
-    SkASSERT(numColumns == otherArray.type().columns());
-
-    ComparisonResult check = ComparisonResult::kEqual;
-    for (int index = 0; index < numColumns; index++) {
-        check = this->arguments()[index]->compareConstant(*otherArray.arguments()[index]);
-        if (check != ComparisonResult::kEqual) {
-            break;
-        }
-    }
-
-    return check;
-}
-
 }  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorArray.h b/src/sksl/ir/SkSLConstructorArray.h
index a2b154d..139de52 100644
--- a/src/sksl/ir/SkSLConstructorArray.h
+++ b/src/sksl/ir/SkSLConstructorArray.h
@@ -39,8 +39,6 @@
         return std::make_unique<ConstructorArray>(fOffset, this->type(), this->cloneArguments());
     }
 
-    ComparisonResult compareConstant(const Expression& other) const override;
-
 private:
     using INHERITED = MultiArgumentConstructor;
 };
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
index ccdcdef..97cf760 100644
--- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
@@ -21,35 +21,17 @@
     return std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
 }
 
-Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
-        const Expression& other) const {
-    SkASSERT(other.type().isMatrix());
-    SkASSERT(this->type() == other.type());
+const Expression* ConstructorDiagonalMatrix::getConstantSubexpression(int n) const {
+    int rows = this->type().rows();
+    int row = n % rows;
+    int col = n / rows;
 
-    // The other constructor might not be DiagonalMatrix-based, so we check each cell individually.
-    for (int col = 0; col < this->type().columns(); col++) {
-        for (int row = 0; row < this->type().rows(); row++) {
-            if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
-                return ComparisonResult::kNotEqual;
-            }
-        }
-    }
-
-    return ComparisonResult::kEqual;
-}
-
-SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
-    SkASSERT(this->isCompileTimeConstant());
     SkASSERT(col >= 0);
     SkASSERT(row >= 0);
     SkASSERT(col < this->type().columns());
     SkASSERT(row < this->type().rows());
 
-    // Our matrix is of the form:
-    //  |x 0 0|
-    //  |0 x 0|
-    //  |0 0 x|
-    return (col == row) ? this->argument()->getConstantFloat() : 0.0;
+    return (col == row) ? this->argument()->getConstantSubexpression(0) : &fZeroLiteral;
 }
 
 }  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
index 3e87051..b9fd229 100644
--- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
@@ -27,7 +27,8 @@
     static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
 
     ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
-        : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+        : INHERITED(offset, kExpressionKind, &type, std::move(arg))
+        , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) {}
 
     static std::unique_ptr<Expression> Make(const Context& context,
                                             int offset,
@@ -39,11 +40,10 @@
                                                            argument()->clone());
     }
 
-    ComparisonResult compareConstant(const Expression& other) const override;
-
-    SKSL_FLOAT getMatComponent(int col, int row) const override;
+    const Expression* getConstantSubexpression(int n) const override;
 
 private:
+    const FloatLiteral fZeroLiteral;
     using INHERITED = SingleArgumentConstructor;
 };
 
diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.cpp b/src/sksl/ir/SkSLConstructorMatrixResize.cpp
index e7e41f3..4c8b374 100644
--- a/src/sksl/ir/SkSLConstructorMatrixResize.cpp
+++ b/src/sksl/ir/SkSLConstructorMatrixResize.cpp
@@ -27,27 +27,11 @@
     return std::make_unique<ConstructorMatrixResize>(offset, type, std::move(arg));
 }
 
-Expression::ComparisonResult ConstructorMatrixResize::compareConstant(
-        const Expression& other) const {
-    SkASSERT(other.type().isMatrix());
-    SkASSERT(this->type() == other.type());
-    SkASSERT(this->type().rows() == other.type().rows());
-    SkASSERT(this->type().columns() == other.type().columns());
+const Expression* ConstructorMatrixResize::getConstantSubexpression(int n) const {
+    int rows = this->type().rows();
+    int row = n % rows;
+    int col = n / rows;
 
-    // Check each cell individually.
-    for (int col = 0; col < this->type().columns(); col++) {
-        for (int row = 0; row < this->type().rows(); row++) {
-            if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
-                return ComparisonResult::kNotEqual;
-            }
-        }
-    }
-
-    return ComparisonResult::kEqual;
-}
-
-SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
-    SkASSERT(this->isCompileTimeConstant());
     SkASSERT(col >= 0);
     SkASSERT(row >= 0);
     SkASSERT(col < this->type().columns());
@@ -59,13 +43,15 @@
     //  |0 0 1|
     // Where `m` is the matrix being wrapped, and other cells contain the identity matrix.
 
-    // Forward `getMatComponent` to the wrapped matrix if the position is in its bounds.
+    // Forward `getConstantSubexpression` to the wrapped matrix if the position is in its bounds.
     if (col < this->argument()->type().columns() && row < this->argument()->type().rows()) {
-        return this->argument()->getMatComponent(col, row);
+        // Recalculate `n` in terms of the inner matrix's dimensions.
+        n = row + (col * this->argument()->type().rows());
+        return this->argument()->getConstantSubexpression(n);
     }
 
     // Synthesize an identity matrix for out-of-bounds positions.
-    return (col == row) ? 1.0f : 0.0f;
+    return (col == row) ? &fOneLiteral : &fZeroLiteral;
 }
 
 }  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.h b/src/sksl/ir/SkSLConstructorMatrixResize.h
index 2a18403..0036180 100644
--- a/src/sksl/ir/SkSLConstructorMatrixResize.h
+++ b/src/sksl/ir/SkSLConstructorMatrixResize.h
@@ -28,7 +28,9 @@
     static constexpr Kind kExpressionKind = Kind::kConstructorMatrixResize;
 
     ConstructorMatrixResize(int offset, const Type& type, std::unique_ptr<Expression> arg)
-            : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+            : INHERITED(offset, kExpressionKind, &type, std::move(arg))
+            , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType())
+            , fOneLiteral(offset, /*value=*/1.0f, &type.componentType()) {}
 
     static std::unique_ptr<Expression> Make(const Context& context,
                                             int offset,
@@ -40,12 +42,12 @@
                                                          argument()->clone());
     }
 
-    Expression::ComparisonResult compareConstant(const Expression& other) const override;
-
-    SKSL_FLOAT getMatComponent(int col, int row) const override;
+    const Expression* getConstantSubexpression(int n) const override;
 
 private:
     using INHERITED = SingleArgumentConstructor;
+    const FloatLiteral fZeroLiteral;
+    const FloatLiteral fOneLiteral;
 };
 
 }  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.cpp b/src/sksl/ir/SkSLConstructorSplat.cpp
index d8d0062..4d07288 100644
--- a/src/sksl/ir/SkSLConstructorSplat.cpp
+++ b/src/sksl/ir/SkSLConstructorSplat.cpp
@@ -25,29 +25,4 @@
     return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
 }
 
-Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
-    SkASSERT(this->type() == other.type());
-    if (!other.isAnyConstructor()) {
-        return ComparisonResult::kUnknown;
-    }
-
-    return this->compareConstantConstructor(other.asAnyConstructor());
-}
-
-Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
-        const AnyConstructor& other) const {
-    ComparisonResult check = ComparisonResult::kEqual;
-    for (const std::unique_ptr<Expression>& expr : other.argumentSpan()) {
-        // We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)`
-        check = expr->isAnyConstructor()
-                        ? this->compareConstantConstructor(expr->asAnyConstructor())
-                        : argument()->compareConstant(*expr);
-        if (check != ComparisonResult::kEqual) {
-            break;
-        }
-    }
-
-    return check;
-}
-
 }  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h
index 4ca3a07..5fb2a45 100644
--- a/src/sksl/ir/SkSLConstructorSplat.h
+++ b/src/sksl/ir/SkSLConstructorSplat.h
@@ -38,8 +38,6 @@
         return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
     }
 
-    ComparisonResult compareConstant(const Expression& other) const override;
-
     SKSL_FLOAT getFVecComponent(int) const override {
         return this->argument()->getConstantFloat();
     }
@@ -52,9 +50,12 @@
         return this->argument()->getConstantBool();
     }
 
-private:
-    Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
+    const Expression* getConstantSubexpression(int n) const override {
+        SkASSERT(n >= 0 && n < this->type().columns());
+        return this->argument()->getConstantSubexpression(0);
+    }
 
+private:
     using INHERITED = SingleArgumentConstructor;
 };
 
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index c332992..7b53868 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -179,6 +179,18 @@
     }
 
     /**
+     * Returns the n'th compile-time constant expression within a literal or constructor.
+     * Use Type::slotCount to determine the number of subexpressions within an expression.
+     * Subexpressions which are not compile-time constants will return null.
+     * `vec4(1, vec2(2), 3)` contains four subexpressions: (1, 2, 2, 3)
+     * `mat2(f)` contains four subexpressions: (null, 0,
+     *                                          0, null)
+     */
+    virtual const Expression* getConstantSubexpression(int n) const {
+        return nullptr;
+    }
+
+    /**
      * For a vector of floating point values, return the value of the n'th vector component. It is
      * an error to call this method on an expression which is not a vector of floating-point
      * constant expressions.
@@ -215,16 +227,6 @@
      */
     template <typename T> T getVecComponent(int index) const;
 
-    /**
-     * For a literal matrix expression, return the floating point value of the component at
-     * [col][row]. It is an error to call this method on an expression which is not a literal
-     * matrix.
-     */
-    virtual SKSL_FLOAT getMatComponent(int col, int row) const {
-        SkASSERT(false);
-        return 0;
-    }
-
     virtual std::unique_ptr<Expression> clone() const = 0;
 
 private:
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 4cc8eb9..40eebc2 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -79,6 +79,11 @@
         return std::make_unique<FloatLiteral>(fOffset, this->value(), &this->type());
     }
 
+    const Expression* getConstantSubexpression(int n) const override {
+        SkASSERT(n == 0);
+        return this;
+    }
+
 private:
     float fValue;
 
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 3505079..b9763a7 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -81,6 +81,11 @@
         return std::make_unique<IntLiteral>(fOffset, this->value(), &this->type());
     }
 
+    const Expression* getConstantSubexpression(int n) const override {
+        SkASSERT(n == 0);
+        return this;
+    }
+
 private:
     SKSL_INT fValue;