sksl can now fold constant vector or matrix equality expressions

Bug: skia:
Change-Id: Icaddae68e53ed3629bcdc04b5f0b541d9e4398e2
Reviewed-on: https://skia-review.googlesource.com/14260
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ben Wagner <benjaminwagner@google.com>
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 56858a9..523b7a0 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1043,6 +1043,12 @@
             return std::unique_ptr<Expression>(new Constructor(Position(), left.fType, \
                                                                std::move(args)));
         switch (op) {
+            case Token::EQEQ:
+                return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
+                                                            left.compareConstant(fContext, right)));
+            case Token::NEQ:
+                return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
+                                                           !left.compareConstant(fContext, right)));
             case Token::PLUS:  RETURN_VEC_COMPONENTWISE_RESULT(+);
             case Token::MINUS: RETURN_VEC_COMPONENTWISE_RESULT(-);
             case Token::STAR:  RETURN_VEC_COMPONENTWISE_RESULT(*);
@@ -1050,6 +1056,20 @@
             default:           return nullptr;
         }
     }
+    if (left.fType.kind() == Type::kMatrix_Kind &&
+        right.fType.kind() == Type::kMatrix_Kind &&
+        left.fKind == right.fKind) {
+        switch (op) {
+            case Token::EQEQ:
+                return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
+                                                            left.compareConstant(fContext, right)));
+            case Token::NEQ:
+                return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(),
+                                                           !left.compareConstant(fContext, right)));
+            default:
+                return nullptr;
+        }
+    }
     #undef RESULT
     return nullptr;
 }
diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h
index 13203a4..a4151b8 100644
--- a/src/sksl/ir/SkSLBoolLiteral.h
+++ b/src/sksl/ir/SkSLBoolLiteral.h
@@ -33,6 +33,11 @@
         return true;
     }
 
+    bool compareConstant(const Context& context, const Expression& other) const override {
+        BoolLiteral& b = (BoolLiteral&) other;
+        return fValue == b.fValue;
+    }
+
     const bool fValue;
 
     typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 208031a..05f4096 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -81,6 +81,44 @@
         return true;
     }
 
+    bool compareConstant(const Context& context, const Expression& other) const override {
+        ASSERT(other.fKind == Expression::kConstructor_Kind && other.fType == fType);
+        Constructor& c = (Constructor&) other;
+        if (c.fType.kind() == Type::kVector_Kind) {
+            for (int i = 0; i < fType.columns(); i++) {
+                if (!this->getVecComponent(i).compareConstant(context, c.getVecComponent(i))) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        // shouldn't be possible to have a constant constructor that isn't a vector or matrix;
+        // a constant scalar constructor should have been collapsed down to the appropriate
+        // literal
+        ASSERT(fType.kind() == Type::kMatrix_Kind);
+        const FloatLiteral fzero(context, Position(), 0);
+        const IntLiteral izero(context, Position(), 0);
+        const Expression* zero;
+        if (fType.componentType() == *context.fFloat_Type) {
+            zero = &fzero;
+        } else {
+            ASSERT(fType.componentType() == *context.fInt_Type);
+            zero = &izero;
+        }
+        for (int col = 0; col < fType.columns(); col++) {
+            for (int row = 0; row < fType.rows(); row++) {
+                const Expression* component1 = getMatComponent(col, row);
+                const Expression* component2 = c.getMatComponent(col, row);
+                if (!(component1 ? component1 : zero)->compareConstant(
+                                                                context,
+                                                                component2 ? *component2 : *zero)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
     const Expression& getVecComponent(int index) const {
         ASSERT(fType.kind() == Type::kVector_Kind);
         if (fArguments.size() == 1 && fArguments[0]->fType.kind() == Type::kScalar_Kind) {
@@ -118,6 +156,51 @@
         return ((IntLiteral&) c).fValue;
     }
 
+    // null return should be interpreted as zero
+    const Expression* getMatComponent(int col, int row) const {
+        ASSERT(this->isConstant());
+        ASSERT(fType.kind() == Type::kMatrix_Kind);
+        ASSERT(col < fType.columns() && row < fType.rows());
+        if (fArguments.size() == 1) {
+            if (fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+                // single scalar argument, so matrix is of the form:
+                // x 0 0
+                // 0 x 0
+                // 0 0 x
+                // return x if col == row
+                return col == row ? fArguments[0].get() : nullptr;
+            }
+            if (fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
+                ASSERT(fArguments[0]->fKind == Expression::kConstructor_Kind);
+                // single matrix argument. make sure we're within the argument's bounds.
+                const Type& argType = ((Constructor&) *fArguments[0]).fType;
+                if (col < argType.columns() && row < argType.rows()) {
+                    // within bounds, defer to argument
+                    return ((Constructor&) *fArguments[0]).getMatComponent(col, row);
+                }
+                // out of bounds, return 0
+                return nullptr;
+            }
+        }
+        int currentIndex = 0;
+        int targetIndex = col * fType.rows() + row;
+        for (const auto& arg : fArguments) {
+            ASSERT(targetIndex >= currentIndex);
+            ASSERT(arg->fType.rows() == 1);
+            if (currentIndex + arg->fType.columns() > targetIndex) {
+                if (arg->fType.columns() == 1) {
+                    return arg.get();
+                } else {
+                    ASSERT(arg->fType.kind() == Type::kVector_Kind);
+                    ASSERT(arg->fKind == Expression::kConstructor_Kind);
+                    return &((Constructor&) *arg).getVecComponent(targetIndex - currentIndex);
+                }
+            }
+            currentIndex += arg->fType.columns();
+        }
+        ABORT("can't happen, matrix component out of bounds");
+    }
+
     std::vector<std::unique_ptr<Expression>> fArguments;
 
     typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 5db9ddf..07dad1d 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -48,11 +48,24 @@
     , fKind(kind)
     , fType(std::move(type)) {}
 
+    /**
+     * Returns true if this expression is constant. compareConstant must be implemented for all
+     * constants!
+     */
     virtual bool isConstant() const {
         return false;
     }
 
     /**
+     * Compares this constant expression against another constant expression of the same type. It is
+     * an error to call this on non-constant expressions, or if the types of the expressions do not
+     * match.
+     */
+    virtual bool compareConstant(const Context& context, const Expression& other) const {
+        ABORT("cannot call compareConstant on this type");
+    }
+
+    /**
      * Returns true if evaluating the expression potentially has side effects. Expressions may never
      * return false if they actually have side effects, but it is legal (though suboptimal) to
      * return true if there are not actually any side effects.
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 8f83e28..21a485f 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -34,6 +34,11 @@
         return true;
     }
 
+    bool compareConstant(const Context& context, const Expression& other) const override {
+        FloatLiteral& f = (FloatLiteral&) other;
+        return fValue == f.fValue;
+    }
+
     const double fValue;
 
     typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 3a95ed6..d8eba55 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -35,6 +35,11 @@
         return true;
     }
 
+    bool compareConstant(const Context& context, const Expression& other) const override {
+        IntLiteral& i = (IntLiteral&) other;
+        return fValue == i.fValue;
+    }
+
     const int64_t fValue;
 
     typedef Expression INHERITED;
diff --git a/tests/SkSLErrorTest.cpp b/tests/SkSLErrorTest.cpp
index bd0c64a..47b9af8 100644
--- a/tests/SkSLErrorTest.cpp
+++ b/tests/SkSLErrorTest.cpp
@@ -125,6 +125,9 @@
     test_failure(r,
                  "struct foo { int x; } foo; void main() { vec2 x = vec2(foo); }",
                  "error: 1: 'foo' is not a valid parameter to 'vec2' constructor\n1 error\n");
+    test_failure(r,
+                 "void main() { mat2 x = mat2(true); }",
+                 "error: 1: expected 'float', but found 'bool'\n1 error\n");
 }
 
 DEF_TEST(SkSLConstructorArgumentCount, r) {
diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp
index 97d7acb..1dc522b 100644
--- a/tests/SkSLGLSLTest.cpp
+++ b/tests/SkSLGLSLTest.cpp
@@ -573,6 +573,34 @@
          "sk_FragColor = vec4(2) * vec4(1, 2, 3, 4);"
          "sk_FragColor = vec4(12) / vec4(1, 2, 3, 4);"
          "sk_FragColor.r = (vec4(12) / vec4(1, 2, 3, 4)).y;"
+         "sk_FragColor.x = vec4(1) == vec4(1) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec4(1) == vec4(2) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec2(1) == vec2(1, 1) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec2(1, 1) == vec2(1, 1) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec2(1) == vec2(1, 0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec4(1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), 1, 0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) == "
+                          "mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(1.0, 1.0)) == "
+                          "mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1) == mat2(1) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1) == mat2(0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(2) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat3x2(2) == mat3x2(vec2(2.0, 0.0), vec2(0.0, 2.0), vec2(0.0)) ? "
+                                                                                        "1.0 : 0.0;"
+         "sk_FragColor.x = vec2(1) != vec2(1, 0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = vec4(1) != vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1) != mat2(1) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1) != mat2(0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 0.0)) == "
+                          "mat3(mat2(1.0)) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(mat3(1.0)) == mat2(1.0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(vec4(1.0, 0.0, 0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(1.0, 0.0, vec2(0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;"
+         "sk_FragColor.x = mat2(vec2(1.0, 0.0), 0.0, 1.0) == mat2(1.0) ? 1.0 : 0.0;"
          "}",
          *SkSL::ShaderCapsFactory::Default(),
          "#version 400\n"
@@ -617,6 +645,30 @@
          "    sk_FragColor = vec4(2.0, 4.0, 6.0, 8.0);\n"
          "    sk_FragColor = vec4(12.0, 6.0, 4.0, 3.0);\n"
          "    sk_FragColor.x = 6.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 0.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
+         "    sk_FragColor.x = 1.0;\n"
          "}\n");
 }