Replace getVecComponent with getConstantSubexpression.

Change-Id: I792f23d3aba45bdaea174ee51d4aca5bd9cb4ea4
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393079
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index ddc0697..3f9fb87 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -795,17 +795,20 @@
 }
 
 static bool get_constant_value(const Expression& expr, double* val) {
-    if (!expr.isCompileTimeConstant()) {
+    const Expression* valExpr = expr.getConstantSubexpression(0);
+    if (!valExpr) {
         return false;
     }
-    if (!expr.type().isNumber()) {
-        SkDEBUGFAILF("unexpected constant type (%s)", expr.type().description().c_str());
-        return false;
+    if (valExpr->is<IntLiteral>()) {
+        *val = static_cast<double>(valExpr->as<IntLiteral>().value());
+        return true;
     }
-
-    *val = expr.type().isInteger() ? static_cast<double>(expr.getConstantInt())
-                                   : static_cast<double>(expr.getConstantFloat());
-    return true;
+    if (valExpr->is<FloatLiteral>()) {
+        *val = static_cast<double>(valExpr->as<FloatLiteral>().value());
+        return true;
+    }
+    SkDEBUGFAILF("unexpected constant type (%s)", expr.type().description().c_str());
+    return false;
 }
 
 static const char* invalid_for_ES2(int offset,
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index 440b4cc..d78edfb 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -98,7 +98,8 @@
         ExpressionArray args;
         args.reserve_back(type.columns());
         for (int i = 0; i < type.columns(); i++) {
-            U value = foldFn(left.getVecComponent<T>(i), right.getVecComponent<T>(i));
+            U value = foldFn(left.getConstantSubexpression(i)->as<Literal<T>>().value(),
+                             right.getConstantSubexpression(i)->as<Literal<T>>().value());
             args.push_back(Literal<T>::Make(left.fOffset, value, &componentType));
         }
         auto foldedCtor = Constructor::Convert(context, left.fOffset, type, std::move(args));
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index de2be9e..6927e34 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -1188,25 +1188,13 @@
     SPIRVVectorConstant key{this->getType(type),
                             /*fValueId=*/{SpvId(-1), SpvId(-1), SpvId(-1), SpvId(-1)}};
 
-    if (c.componentType().isFloat()) {
-        for (int i = 0; i < type.columns(); i++) {
-            FloatLiteral literal(c.fOffset, c.getFVecComponent(i), &c.componentType());
-            key.fValueId[i] = this->writeFloatLiteral(literal);
+    for (int n = 0; n < type.columns(); n++) {
+        const Expression* expr = c.getConstantSubexpression(n);
+        if (!expr) {
+            SkDEBUGFAILF("writeConstantVector: %s not actually constant", c.description().c_str());
+            return (SpvId)-1;
         }
-    } else if (c.componentType().isInteger()) {
-        for (int i = 0; i < type.columns(); i++) {
-            IntLiteral literal(c.fOffset, c.getIVecComponent(i), &c.componentType());
-            key.fValueId[i] = this->writeIntLiteral(literal);
-        }
-    } else if (c.componentType().isBoolean()) {
-        for (int i = 0; i < type.columns(); i++) {
-            BoolLiteral literal(c.fOffset, c.getBVecComponent(i), &c.componentType());
-            key.fValueId[i] = this->writeBoolLiteral(literal);
-        }
-    } else {
-        SkDEBUGFAILF("unexpected vector component type: %s",
-                     c.componentType().displayName().c_str());
-        return SpvId(-1);
+        key.fValueId[n] = this->writeExpression(*expr, fConstantBuffer);
     }
 
     // Check to see if we've already synthesized this vector constant.
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index dbdd129..1f8c350 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -65,10 +65,6 @@
                                                                 : ComparisonResult::kNotEqual;
     }
 
-    bool getConstantBool() const override {
-        return this->value();
-    }
-
     std::unique_ptr<Expression> clone() const override {
         return std::make_unique<BoolLiteral>(fOffset, this->value(), &this->type());
     }
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index f489f80..e524c7e 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -175,127 +175,6 @@
     return std::make_unique<Constructor>(offset, type, std::move(args));
 }
 
-template <typename ResultType>
-ResultType Constructor::getConstantValue(const Expression& expr) const {
-    const Type& type = expr.type();
-    SkASSERT(type.isScalar());
-    if (type.isFloat()) {
-        return ResultType(expr.getConstantFloat());
-    } else if (type.isInteger()) {
-        return ResultType(expr.getConstantInt());
-    } else if (type.isBoolean()) {
-        return ResultType(expr.getConstantBool());
-    }
-    SkDEBUGFAILF("unrecognized kind of constant value: %s", expr.description().c_str());
-    return ResultType(0);
-}
-
-template <typename ResultType>
-ResultType Constructor::getInnerVecComponent(const Expression& expr, int position) const {
-    const Type& type = expr.type().componentType();
-    if (type.isFloat()) {
-        return ResultType(expr.getVecComponent<SKSL_FLOAT>(position));
-    } else if (type.isInteger()) {
-        return ResultType(expr.getVecComponent<SKSL_INT>(position));
-    } else if (type.isBoolean()) {
-        return ResultType(expr.getVecComponent<bool>(position));
-    }
-    SkDEBUGFAILF("unrecognized type of constant: %s", expr.description().c_str());
-    return ResultType(0);
-};
-
-template <typename ResultType>
-ResultType Constructor::getVecComponent(int index) const {
-    static_assert(std::is_same<ResultType, SKSL_FLOAT>::value ||
-                  std::is_same<ResultType, SKSL_INT>::value ||
-                  std::is_same<ResultType, bool>::value);
-
-    SkASSERT(this->type().isVector());
-    SkASSERT(this->isCompileTimeConstant());
-
-    if (this->arguments().size() == 1 &&
-        this->arguments()[0]->type().isScalar()) {
-        // This constructor just wraps a scalar. Propagate out the value.
-        return this->getConstantValue<ResultType>(*this->arguments()[0]);
-    }
-
-    // Walk through all the constructor arguments until we reach the index we're searching for.
-    int current = 0;
-    for (const std::unique_ptr<Expression>& arg : this->arguments()) {
-        if (current > index) {
-            // Somehow, we went past the argument we're looking for. Bail.
-            break;
-        }
-
-        if (arg->type().isScalar()) {
-            if (index == current) {
-                // We're on the proper argument, and it's a scalar; fetch it.
-                return this->getConstantValue<ResultType>(*arg);
-            }
-            current++;
-            continue;
-        }
-
-        if (arg->type().isVector()) {
-            if (current + arg->type().columns() > index) {
-                // We've found an expression that encompasses the proper argument. Descend into it.
-                return this->getInnerVecComponent<ResultType>(*arg, index - current);
-            }
-        }
-
-        current += arg->type().columns();
-    }
-
-    SkDEBUGFAILF("failed to find vector component %d in %s\n", index, description().c_str());
-    return ResultType(0);
-}
-
-template SKSL_INT Constructor::getVecComponent(int) const;
-template SKSL_FLOAT Constructor::getVecComponent(int) const;
-template bool Constructor::getVecComponent(int) const;
-
-SKSL_INT Constructor::getConstantInt() const {
-    // We're looking for scalar integer constructors only, i.e. `int(1)`.
-    SkASSERT(this->arguments().size() == 1);
-    SkASSERT(this->type().columns() == 1);
-    SkASSERT(this->type().isInteger());
-
-    // This might be a cast, meaning the inner argument would actually be a different scalar type.
-    const Expression& expr = *this->arguments().front();
-    SkASSERT(expr.type().isInteger() || expr.type().isFloat() || expr.type().isBoolean());
-    return expr.type().isInteger() ? expr.getConstantInt() :
-             expr.type().isFloat() ? (SKSL_INT)expr.getConstantFloat() :
-                                     (SKSL_INT)expr.getConstantBool();
-}
-
-SKSL_FLOAT Constructor::getConstantFloat() const {
-    // We're looking for scalar integer constructors only, i.e. `float(1.0)`.
-    SkASSERT(this->arguments().size() == 1);
-    SkASSERT(this->type().columns() == 1);
-    SkASSERT(this->type().isFloat());
-
-    // This might be a cast, meaning the inner argument would actually be a different scalar type.
-    const Expression& expr = *this->arguments().front();
-    SkASSERT(expr.type().isInteger() || expr.type().isFloat() || expr.type().isBoolean());
-    return   expr.type().isFloat() ? expr.getConstantFloat() :
-           expr.type().isInteger() ? (SKSL_FLOAT)expr.getConstantInt() :
-                                     (SKSL_FLOAT)expr.getConstantBool();
-}
-
-bool Constructor::getConstantBool() const {
-    // We're looking for scalar Boolean constructors only, i.e. `bool(true)`.
-    SkASSERT(this->arguments().size() == 1);
-    SkASSERT(this->type().columns() == 1);
-    SkASSERT(this->type().isBoolean());
-
-    // This might be a cast, meaning the inner argument would actually be a different scalar type.
-    const Expression& expr = *this->arguments().front();
-    SkASSERT(expr.type().isInteger() || expr.type().isFloat() || expr.type().isBoolean());
-    return expr.type().isBoolean() ? expr.getConstantBool() :
-           expr.type().isInteger() ? (bool)expr.getConstantInt() :
-                                     (bool)expr.getConstantFloat();
-}
-
 const Expression* AnyConstructor::getConstantSubexpression(int n) const {
     SkASSERT(n >= 0 && n < (int)this->type().slotCount());
     for (const std::unique_ptr<Expression>& arg : this->argumentSpan()) {
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 7ff91c9..37e61db 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -187,42 +187,6 @@
         return std::make_unique<Constructor>(fOffset, this->type(), this->cloneArguments());
     }
 
-    template <typename ResultType>
-    ResultType getVecComponent(int index) const;
-
-    /**
-     * For a literal vector expression, return the float value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a compile-time constant vector of
-     * floating-point type.
-     */
-    SKSL_FLOAT getFVecComponent(int n) const override {
-        return this->getVecComponent<SKSL_FLOAT>(n);
-    }
-
-    /**
-     * For a literal vector expression, return the integer value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a compile-time constant vector of
-     * integer type.
-     */
-    SKSL_INT getIVecComponent(int n) const override {
-        return this->getVecComponent<SKSL_INT>(n);
-    }
-
-    /**
-     * For a literal vector expression, return the boolean value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a compile-time constant vector of
-     * Boolean type.
-     */
-    bool getBVecComponent(int n) const override {
-        return this->getVecComponent<bool>(n);
-    }
-
-    SKSL_INT getConstantInt() const override;
-
-    SKSL_FLOAT getConstantFloat() const override;
-
-    bool getConstantBool() const override;
-
 private:
     static std::unique_ptr<Expression> MakeScalarConstructor(const Context& context,
                                                              int offset,
@@ -234,11 +198,6 @@
                                                                const Type& type,
                                                                ExpressionArray args);
 
-    template <typename ResultType> ResultType getConstantValue(const Expression& expr) const;
-
-    template <typename ResultType>
-    ResultType getInnerVecComponent(const Expression& expr, int position) const;
-
     using INHERITED = MultiArgumentConstructor;
 };
 
diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h
index 5fb2a45..4a27e00 100644
--- a/src/sksl/ir/SkSLConstructorSplat.h
+++ b/src/sksl/ir/SkSLConstructorSplat.h
@@ -38,18 +38,6 @@
         return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
     }
 
-    SKSL_FLOAT getFVecComponent(int) const override {
-        return this->argument()->getConstantFloat();
-    }
-
-    SKSL_INT getIVecComponent(int) const override {
-        return this->argument()->getConstantInt();
-    }
-
-    bool getBVecComponent(int) const override {
-        return this->argument()->getConstantBool();
-    }
-
     const Expression* getConstantSubexpression(int n) const override {
         SkASSERT(n >= 0 && n < this->type().columns());
         return this->argument()->getConstantSubexpression(0);
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 7b53868..982d56b 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -132,30 +132,6 @@
     }
 
     /**
-     * For an expression which evaluates to a constant int, returns the value. Otherwise calls
-     * SK_ABORT.
-     */
-    virtual SKSL_INT getConstantInt() const {
-        SK_ABORT("not a constant int");
-    }
-
-    /**
-     * For an expression which evaluates to a constant float, returns the value. Otherwise calls
-     * SK_ABORT.
-     */
-    virtual SKSL_FLOAT getConstantFloat() const {
-        SK_ABORT("not a constant float");
-    }
-
-    /**
-     * For an expression which evaluates to a constant Boolean, returns the value. Otherwise calls
-     * SK_ABORT.
-     */
-    virtual bool getConstantBool() const {
-        SK_ABORT("not a constant Boolean");
-    }
-
-    /**
      * Returns true if, given fixed values for uniforms, this expression always evaluates to the
      * same result with no side effects.
      */
@@ -190,43 +166,6 @@
         return nullptr;
     }
 
-    /**
-     * For a vector of floating point values, return the value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a vector of floating-point
-     * constant expressions.
-     */
-    virtual SKSL_FLOAT getFVecComponent(int n) const {
-        SkDEBUGFAILF("expression does not support getVecComponent: %s",
-                     this->description().c_str());
-        return 0;
-    }
-
-    /**
-     * For a vector of integer values, return the value of the n'th vector component. It is an error
-     * to call this method on an expression which is not a vector of integer constant expressions.
-     */
-    virtual SKSL_INT getIVecComponent(int n) const {
-        SkDEBUGFAILF("expression does not support getVecComponent: %s",
-                     this->description().c_str());
-        return 0;
-    }
-
-    /**
-     * For a vector of Boolean values, return the value of the n'th vector component. It is an error
-     * to call this method on an expression which is not a vector of Boolean constant expressions.
-     */
-    virtual bool getBVecComponent(int n) const {
-        SkDEBUGFAILF("expression does not support getVecComponent: %s",
-                     this->description().c_str());
-        return false;
-    }
-
-    /**
-     * For a vector of literals, return the value of the n'th vector component. It is an error to
-     * call this method on an expression which is not a vector of Literal<T>.
-     */
-    template <typename T> T getVecComponent(int index) const;
-
     virtual std::unique_ptr<Expression> clone() const = 0;
 
 private:
@@ -235,18 +174,6 @@
     using INHERITED = IRNode;
 };
 
-template <> inline SKSL_FLOAT Expression::getVecComponent<SKSL_FLOAT>(int index) const {
-    return this->getFVecComponent(index);
-}
-
-template <> inline SKSL_INT Expression::getVecComponent<SKSL_INT>(int index) const {
-    return this->getIVecComponent(index);
-}
-
-template <> inline bool Expression::getVecComponent<bool>(int index) const {
-    return this->getBVecComponent(index);
-}
-
 }  // namespace SkSL
 
 #endif
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 40eebc2..184af12 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -71,10 +71,6 @@
                                                                  : ComparisonResult::kNotEqual;
     }
 
-    SKSL_FLOAT getConstantFloat() const override {
-        return this->value();
-    }
-
     std::unique_ptr<Expression> clone() const override {
         return std::make_unique<FloatLiteral>(fOffset, this->value(), &this->type());
     }
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index b9763a7..d5a2cc5 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -73,10 +73,6 @@
         return INHERITED::coercionCost(target);
     }
 
-    SKSL_INT getConstantInt() const override {
-        return this->value();
-    }
-
     std::unique_ptr<Expression> clone() const override {
         return std::make_unique<IntLiteral>(fOffset, this->value(), &this->type());
     }