Replace getMatExpression with getConstantSubexpression.
This approach gives us similar flexibility but requires fewer lines of
code to get the same result. In a followup CL we will be able to
eliminate get[BFI]VecExpression as well. This approach will also scale
to arrays and structs if we want to support constant-folding on these.
Change-Id: Ib0034935926c7004f84ba62ddbdb3168df8ce91d
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/393076
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index bb086eb..dbdd129 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -73,6 +73,11 @@
return std::make_unique<BoolLiteral>(fOffset, this->value(), &this->type());
}
+ const Expression* getConstantSubexpression(int n) const override {
+ SkASSERT(n == 0);
+ return this;
+ }
+
private:
bool fValue;
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 9cfa6fe..f489f80 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -175,65 +175,6 @@
return std::make_unique<Constructor>(offset, type, std::move(args));
}
-Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
- if (other.is<ConstructorDiagonalMatrix>()) {
- return other.compareConstant(*this);
- }
- if (other.is<ConstructorMatrixResize>()) {
- return other.compareConstant(*this);
- }
- if (other.is<ConstructorSplat>()) {
- return other.compareConstant(*this);
- }
- if (!other.is<Constructor>()) {
- return ComparisonResult::kUnknown;
- }
- const Constructor& c = other.as<Constructor>();
- const Type& myType = this->type();
- SkASSERT(myType == c.type());
-
- if (myType.isVector()) {
- if (myType.componentType().isFloat()) {
- for (int i = 0; i < myType.columns(); i++) {
- if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
- return ComparisonResult::kNotEqual;
- }
- }
- return ComparisonResult::kEqual;
- }
- if (myType.componentType().isInteger()) {
- for (int i = 0; i < myType.columns(); i++) {
- if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
- return ComparisonResult::kNotEqual;
- }
- }
- return ComparisonResult::kEqual;
- }
- if (myType.componentType().isBoolean()) {
- for (int i = 0; i < myType.columns(); i++) {
- if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
- return ComparisonResult::kNotEqual;
- }
- }
- return ComparisonResult::kEqual;
- }
- }
-
- if (myType.isMatrix()) {
- for (int col = 0; col < myType.columns(); col++) {
- for (int row = 0; row < myType.rows(); row++) {
- if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
- return ComparisonResult::kNotEqual;
- }
- }
- }
- return ComparisonResult::kEqual;
- }
-
- SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
- return ComparisonResult::kUnknown;
-}
-
template <typename ResultType>
ResultType Constructor::getConstantValue(const Expression& expr) const {
const Type& type = expr.type();
@@ -313,50 +254,6 @@
template SKSL_FLOAT Constructor::getVecComponent(int) const;
template bool Constructor::getVecComponent(int) const;
-SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
- SkDEBUGCODE(const Type& myType = this->type();)
- SkASSERT(this->isCompileTimeConstant());
- SkASSERT(myType.isMatrix());
- SkASSERT(col < myType.columns() && row < myType.rows());
- if (this->arguments().size() == 1) {
- const Type& argType = this->arguments()[0]->type();
- if (argType.isScalar()) {
- // single scalar argument, so matrix is of the form:
- // x 0 0
- // 0 x 0
- // 0 0 x
- // return x if col == row
- return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
- }
- if (argType.isMatrix()) {
- SkASSERT(this->arguments()[0]->isAnyConstructor());
- // single matrix argument. make sure we're within the argument's bounds.
- if (col < argType.columns() && row < argType.rows()) {
- // within bounds, defer to argument
- return this->arguments()[0]->getMatComponent(col, row);
- }
- // out of bounds
- return 0.0;
- }
- }
- int currentIndex = 0;
- int targetIndex = col * this->type().rows() + row;
- for (const auto& arg : this->arguments()) {
- const Type& argType = arg->type();
- SkASSERT(targetIndex >= currentIndex);
- SkASSERT(argType.rows() == 1);
- if (currentIndex + argType.columns() > targetIndex) {
- if (argType.columns() == 1) {
- return arg->getConstantFloat();
- } else {
- return arg->getFVecComponent(targetIndex - currentIndex);
- }
- }
- currentIndex += argType.columns();
- }
- SK_ABORT("can't happen, matrix component out of bounds");
-}
-
SKSL_INT Constructor::getConstantInt() const {
// We're looking for scalar integer constructors only, i.e. `int(1)`.
SkASSERT(this->arguments().size() == 1);
@@ -399,6 +296,45 @@
(bool)expr.getConstantFloat();
}
+const Expression* AnyConstructor::getConstantSubexpression(int n) const {
+ SkASSERT(n >= 0 && n < (int)this->type().slotCount());
+ for (const std::unique_ptr<Expression>& arg : this->argumentSpan()) {
+ int argSlots = arg->type().slotCount();
+ if (n < argSlots) {
+ return arg->getConstantSubexpression(n);
+ }
+ n -= argSlots;
+ }
+
+ SkDEBUGFAIL("argument-list slot count doesn't match constructor-type slot count");
+ return nullptr;
+}
+
+Expression::ComparisonResult AnyConstructor::compareConstant(const Expression& other) const {
+ ComparisonResult result = ComparisonResult::kEqual;
+ SkASSERT(this->type().slotCount() == other.type().slotCount());
+
+ int exprs = this->type().slotCount();
+ for (int n = 0; n < exprs; ++n) {
+ // Get the n'th subexpression from each side. If either one is null, return "unknown."
+ const Expression* left = this->getConstantSubexpression(n);
+ if (!left) {
+ return ComparisonResult::kUnknown;
+ }
+ const Expression* right = other.getConstantSubexpression(n);
+ if (!right) {
+ return ComparisonResult::kUnknown;
+ }
+ // Recurse into the subexpressions; the literal types will perform real comparisons, and
+ // most other expressions fall back on the base class Expression which returns unknown.
+ result = left->compareConstant(*right);
+ if (result != ComparisonResult::kEqual) {
+ break;
+ }
+ }
+ return result;
+}
+
AnyConstructor& Expression::asAnyConstructor() {
SkASSERT(this->isAnyConstructor());
return static_cast<AnyConstructor&>(*this);
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index d7c6c3a..7ff91c9 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -71,6 +71,10 @@
return true;
}
+ const Expression* getConstantSubexpression(int n) const override;
+
+ ComparisonResult compareConstant(const Expression& other) const override;
+
private:
std::unique_ptr<Expression> fArgument;
@@ -183,8 +187,6 @@
return std::make_unique<Constructor>(fOffset, this->type(), this->cloneArguments());
}
- ComparisonResult compareConstant(const Expression& other) const override;
-
template <typename ResultType>
ResultType getVecComponent(int index) const;
@@ -215,8 +217,6 @@
return this->getVecComponent<bool>(n);
}
- SKSL_FLOAT getMatComponent(int col, int row) const override;
-
SKSL_INT getConstantInt() const override;
SKSL_FLOAT getConstantFloat() const override;
diff --git a/src/sksl/ir/SkSLConstructorArray.cpp b/src/sksl/ir/SkSLConstructorArray.cpp
index 4517b75..c305fb4 100644
--- a/src/sksl/ir/SkSLConstructorArray.cpp
+++ b/src/sksl/ir/SkSLConstructorArray.cpp
@@ -56,22 +56,4 @@
return std::make_unique<ConstructorArray>(offset, type, std::move(args));
}
-Expression::ComparisonResult ConstructorArray::compareConstant(const Expression& other) const {
- // There is only one array-constructor type, so if this comparison had type-checked
- // successfully, `other` should be a ConstructorArray with the same array size.
- const ConstructorArray& otherArray = other.as<ConstructorArray>();
- int numColumns = this->type().columns();
- SkASSERT(numColumns == otherArray.type().columns());
-
- ComparisonResult check = ComparisonResult::kEqual;
- for (int index = 0; index < numColumns; index++) {
- check = this->arguments()[index]->compareConstant(*otherArray.arguments()[index]);
- if (check != ComparisonResult::kEqual) {
- break;
- }
- }
-
- return check;
-}
-
} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorArray.h b/src/sksl/ir/SkSLConstructorArray.h
index a2b154d..139de52 100644
--- a/src/sksl/ir/SkSLConstructorArray.h
+++ b/src/sksl/ir/SkSLConstructorArray.h
@@ -39,8 +39,6 @@
return std::make_unique<ConstructorArray>(fOffset, this->type(), this->cloneArguments());
}
- ComparisonResult compareConstant(const Expression& other) const override;
-
private:
using INHERITED = MultiArgumentConstructor;
};
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
index ccdcdef..97cf760 100644
--- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
@@ -21,35 +21,17 @@
return std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
}
-Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
- const Expression& other) const {
- SkASSERT(other.type().isMatrix());
- SkASSERT(this->type() == other.type());
+const Expression* ConstructorDiagonalMatrix::getConstantSubexpression(int n) const {
+ int rows = this->type().rows();
+ int row = n % rows;
+ int col = n / rows;
- // The other constructor might not be DiagonalMatrix-based, so we check each cell individually.
- for (int col = 0; col < this->type().columns(); col++) {
- for (int row = 0; row < this->type().rows(); row++) {
- if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
- return ComparisonResult::kNotEqual;
- }
- }
- }
-
- return ComparisonResult::kEqual;
-}
-
-SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
- SkASSERT(this->isCompileTimeConstant());
SkASSERT(col >= 0);
SkASSERT(row >= 0);
SkASSERT(col < this->type().columns());
SkASSERT(row < this->type().rows());
- // Our matrix is of the form:
- // |x 0 0|
- // |0 x 0|
- // |0 0 x|
- return (col == row) ? this->argument()->getConstantFloat() : 0.0;
+ return (col == row) ? this->argument()->getConstantSubexpression(0) : &fZeroLiteral;
}
} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
index 3e87051..b9fd229 100644
--- a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
@@ -27,7 +27,8 @@
static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
- : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+ : INHERITED(offset, kExpressionKind, &type, std::move(arg))
+ , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType()) {}
static std::unique_ptr<Expression> Make(const Context& context,
int offset,
@@ -39,11 +40,10 @@
argument()->clone());
}
- ComparisonResult compareConstant(const Expression& other) const override;
-
- SKSL_FLOAT getMatComponent(int col, int row) const override;
+ const Expression* getConstantSubexpression(int n) const override;
private:
+ const FloatLiteral fZeroLiteral;
using INHERITED = SingleArgumentConstructor;
};
diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.cpp b/src/sksl/ir/SkSLConstructorMatrixResize.cpp
index e7e41f3..4c8b374 100644
--- a/src/sksl/ir/SkSLConstructorMatrixResize.cpp
+++ b/src/sksl/ir/SkSLConstructorMatrixResize.cpp
@@ -27,27 +27,11 @@
return std::make_unique<ConstructorMatrixResize>(offset, type, std::move(arg));
}
-Expression::ComparisonResult ConstructorMatrixResize::compareConstant(
- const Expression& other) const {
- SkASSERT(other.type().isMatrix());
- SkASSERT(this->type() == other.type());
- SkASSERT(this->type().rows() == other.type().rows());
- SkASSERT(this->type().columns() == other.type().columns());
+const Expression* ConstructorMatrixResize::getConstantSubexpression(int n) const {
+ int rows = this->type().rows();
+ int row = n % rows;
+ int col = n / rows;
- // Check each cell individually.
- for (int col = 0; col < this->type().columns(); col++) {
- for (int row = 0; row < this->type().rows(); row++) {
- if (this->getMatComponent(col, row) != other.getMatComponent(col, row)) {
- return ComparisonResult::kNotEqual;
- }
- }
- }
-
- return ComparisonResult::kEqual;
-}
-
-SKSL_FLOAT ConstructorMatrixResize::getMatComponent(int col, int row) const {
- SkASSERT(this->isCompileTimeConstant());
SkASSERT(col >= 0);
SkASSERT(row >= 0);
SkASSERT(col < this->type().columns());
@@ -59,13 +43,15 @@
// |0 0 1|
// Where `m` is the matrix being wrapped, and other cells contain the identity matrix.
- // Forward `getMatComponent` to the wrapped matrix if the position is in its bounds.
+ // Forward `getConstantSubexpression` to the wrapped matrix if the position is in its bounds.
if (col < this->argument()->type().columns() && row < this->argument()->type().rows()) {
- return this->argument()->getMatComponent(col, row);
+ // Recalculate `n` in terms of the inner matrix's dimensions.
+ n = row + (col * this->argument()->type().rows());
+ return this->argument()->getConstantSubexpression(n);
}
// Synthesize an identity matrix for out-of-bounds positions.
- return (col == row) ? 1.0f : 0.0f;
+ return (col == row) ? &fOneLiteral : &fZeroLiteral;
}
} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorMatrixResize.h b/src/sksl/ir/SkSLConstructorMatrixResize.h
index 2a18403..0036180 100644
--- a/src/sksl/ir/SkSLConstructorMatrixResize.h
+++ b/src/sksl/ir/SkSLConstructorMatrixResize.h
@@ -28,7 +28,9 @@
static constexpr Kind kExpressionKind = Kind::kConstructorMatrixResize;
ConstructorMatrixResize(int offset, const Type& type, std::unique_ptr<Expression> arg)
- : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+ : INHERITED(offset, kExpressionKind, &type, std::move(arg))
+ , fZeroLiteral(offset, /*value=*/0.0f, &type.componentType())
+ , fOneLiteral(offset, /*value=*/1.0f, &type.componentType()) {}
static std::unique_ptr<Expression> Make(const Context& context,
int offset,
@@ -40,12 +42,12 @@
argument()->clone());
}
- Expression::ComparisonResult compareConstant(const Expression& other) const override;
-
- SKSL_FLOAT getMatComponent(int col, int row) const override;
+ const Expression* getConstantSubexpression(int n) const override;
private:
using INHERITED = SingleArgumentConstructor;
+ const FloatLiteral fZeroLiteral;
+ const FloatLiteral fOneLiteral;
};
} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.cpp b/src/sksl/ir/SkSLConstructorSplat.cpp
index d8d0062..4d07288 100644
--- a/src/sksl/ir/SkSLConstructorSplat.cpp
+++ b/src/sksl/ir/SkSLConstructorSplat.cpp
@@ -25,29 +25,4 @@
return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
}
-Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
- SkASSERT(this->type() == other.type());
- if (!other.isAnyConstructor()) {
- return ComparisonResult::kUnknown;
- }
-
- return this->compareConstantConstructor(other.asAnyConstructor());
-}
-
-Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
- const AnyConstructor& other) const {
- ComparisonResult check = ComparisonResult::kEqual;
- for (const std::unique_ptr<Expression>& expr : other.argumentSpan()) {
- // We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)`
- check = expr->isAnyConstructor()
- ? this->compareConstantConstructor(expr->asAnyConstructor())
- : argument()->compareConstant(*expr);
- if (check != ComparisonResult::kEqual) {
- break;
- }
- }
-
- return check;
-}
-
} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h
index 4ca3a07..5fb2a45 100644
--- a/src/sksl/ir/SkSLConstructorSplat.h
+++ b/src/sksl/ir/SkSLConstructorSplat.h
@@ -38,8 +38,6 @@
return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
}
- ComparisonResult compareConstant(const Expression& other) const override;
-
SKSL_FLOAT getFVecComponent(int) const override {
return this->argument()->getConstantFloat();
}
@@ -52,9 +50,12 @@
return this->argument()->getConstantBool();
}
-private:
- Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
+ const Expression* getConstantSubexpression(int n) const override {
+ SkASSERT(n >= 0 && n < this->type().columns());
+ return this->argument()->getConstantSubexpression(0);
+ }
+private:
using INHERITED = SingleArgumentConstructor;
};
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index c332992..7b53868 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -179,6 +179,18 @@
}
/**
+ * Returns the n'th compile-time constant expression within a literal or constructor.
+ * Use Type::slotCount to determine the number of subexpressions within an expression.
+ * Subexpressions which are not compile-time constants will return null.
+ * `vec4(1, vec2(2), 3)` contains four subexpressions: (1, 2, 2, 3)
+ * `mat2(f)` contains four subexpressions: (null, 0,
+ * 0, null)
+ */
+ virtual const Expression* getConstantSubexpression(int n) const {
+ return nullptr;
+ }
+
+ /**
* For a vector of floating point values, return the value of the n'th vector component. It is
* an error to call this method on an expression which is not a vector of floating-point
* constant expressions.
@@ -215,16 +227,6 @@
*/
template <typename T> T getVecComponent(int index) const;
- /**
- * For a literal matrix expression, return the floating point value of the component at
- * [col][row]. It is an error to call this method on an expression which is not a literal
- * matrix.
- */
- virtual SKSL_FLOAT getMatComponent(int col, int row) const {
- SkASSERT(false);
- return 0;
- }
-
virtual std::unique_ptr<Expression> clone() const = 0;
private:
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 4cc8eb9..40eebc2 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -79,6 +79,11 @@
return std::make_unique<FloatLiteral>(fOffset, this->value(), &this->type());
}
+ const Expression* getConstantSubexpression(int n) const override {
+ SkASSERT(n == 0);
+ return this;
+ }
+
private:
float fValue;
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 3505079..b9763a7 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -81,6 +81,11 @@
return std::make_unique<IntLiteral>(fOffset, this->value(), &this->type());
}
+ const Expression* getConstantSubexpression(int n) const override {
+ SkASSERT(n == 0);
+ return this;
+ }
+
private:
SKSL_INT fValue;