Fix SkSL constant propagation within nested casts.

Previously, the inner type was ignored.

Change-Id: I51d251fc38358ef889b5a3f85d5f2d23bd8cf4c5
Bug: skia:10615
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/310657
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 342b720..6352dfe 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -136,46 +136,86 @@
         return true;
     }
 
-    template<typename type>
+    template <typename type>
     type getVecComponent(int index) const {
         SkASSERT(fType.kind() == Type::kVector_Kind);
         if (fArguments.size() == 1 && fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+            // This constructor just wraps a scalar. Propagate out the value.
             if (std::is_floating_point<type>::value) {
                 return fArguments[0]->getConstantFloat();
             } else {
                 return fArguments[0]->getConstantInt();
             }
         }
+
+        // Walk through all the constructor arguments until we reach the index we're searching for.
         int current = 0;
-        for (const auto& arg : fArguments) {
-            SkASSERT(current <= index);
+        for (const std::unique_ptr<Expression>& arg : fArguments) {
+            if (current > index) {
+                // Somehow, we went past the argument we're looking for. Bail.
+                break;
+            }
+
             if (arg->fType.kind() == Type::kScalar_Kind) {
                 if (index == current) {
+                    // We're on the proper argument, and it's a scalar; fetch it.
                     if (std::is_floating_point<type>::value) {
-                        return arg.get()->getConstantFloat();
+                        return arg->getConstantFloat();
                     } else {
-                        return arg.get()->getConstantInt();
+                        return arg->getConstantInt();
                     }
                 }
                 current++;
-            } else if (arg->fKind == kConstructor_Kind) {
-                if (current + arg->fType.columns() > index) {
-                    return ((const Constructor&) *arg).getVecComponent<type>(index - current);
-                }
-                current += arg->fType.columns();
-            } else {
-                if (current + arg->fType.columns() > index) {
-                    SkASSERT(arg->fKind == kPrefix_Kind);
-                    const PrefixExpression& p = (PrefixExpression&) *arg;
-                    const Constructor& c = (const Constructor&) *p.fOperand;
-                    return -c.getVecComponent<type>(index - current);
-                }
-                current += arg->fType.columns();
+                continue;
             }
+
+            switch (arg->fKind) {
+                case kConstructor_Kind: {
+                    const Constructor& constructor = static_cast<const Constructor&>(*arg);
+                    if (current + constructor.fType.columns() > index) {
+                        // We've found a constructor that overlaps the proper argument. Descend into
+                        // it, honoring the type.
+                        if (constructor.fType.componentType().isFloat()) {
+                            return type(constructor.getVecComponent<SKSL_FLOAT>(index - current));
+                        } else {
+                            return type(constructor.getVecComponent<SKSL_INT>(index - current));
+                        }
+                    }
+                    break;
+                }
+                case kPrefix_Kind: {
+                    const PrefixExpression& prefix = static_cast<const PrefixExpression&>(*arg);
+                    if (current + prefix.fType.columns() > index) {
+                        // We found a prefix operator that contains the proper argument. Descend
+                        // into it. We only support for constant propagation of the unary minus, so
+                        // we shouldn't see any other tokens here.
+                        SkASSERT(prefix.fOperator == Token::Kind::TK_MINUS);
+
+                        // We expect the - prefix to always be attached to a constructor.
+                        SkASSERT(prefix.fOperand->fKind == kConstructor_Kind);
+                        const Constructor& constructor =
+                                static_cast<const Constructor&>(*prefix.fOperand);
+
+                        // Descend into this constructor, honoring the type.
+                        if (constructor.fType.componentType().isFloat()) {
+                            return -type(constructor.getVecComponent<SKSL_FLOAT>(index - current));
+                        } else {
+                            return -type(constructor.getVecComponent<SKSL_INT>(index - current));
+                        }
+                    }
+                    break;
+                }
+                default: {
+                    SkDEBUGFAILF("unexpected component %d { %s } in %s\n",
+                                 index, arg->description().c_str(), description().c_str());
+                    break;
+                }
+            }
+
+            current += arg->fType.columns();
         }
-#ifdef SK_DEBUG
-        ABORT("failed to find vector component %d in %s\n", index, description().c_str());
-#endif
+
+        SkDEBUGFAILF("failed to find vector component %d in %s\n", index, description().c_str());
         return -1;
     }
 
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index afd30c7..5d7a5d9 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -106,21 +106,21 @@
 
     std::unique_ptr<Expression> constantPropagate(const IRGenerator& irGenerator,
                                                   const DefinitionMap& definitions) override {
-        if (fBase->fKind == Expression::kConstructor_Kind && fBase->isCompileTimeConstant()) {
-            // we're swizzling a constant vector, e.g. float4(1).x. Simplify it.
-            SkASSERT(fBase->fKind == Expression::kConstructor_Kind);
-            if (fType.isInteger()) {
-                SkASSERT(fComponents.size() == 1);
-                int64_t value = ((Constructor&) *fBase).getIVecComponent(fComponents[0]);
-                return std::unique_ptr<Expression>(new IntLiteral(irGenerator.fContext,
-                                                                  -1,
-                                                                  value));
-            } else if (fType.isFloat()) {
-                SkASSERT(fComponents.size() == 1);
-                double value = ((Constructor&) *fBase).getFVecComponent(fComponents[0]);
-                return std::unique_ptr<Expression>(new FloatLiteral(irGenerator.fContext,
-                                                                    -1,
-                                                                    value));
+        if (fBase->fKind == Expression::kConstructor_Kind) {
+            Constructor& constructor = static_cast<Constructor&>(*fBase);
+            if (constructor.isCompileTimeConstant()) {
+                // we're swizzling a constant vector, e.g. float4(1).x. Simplify it.
+                if (fType.isInteger()) {
+                    SkASSERT(fComponents.size() == 1);
+                    int64_t value = constructor.getIVecComponent(fComponents[0]);
+                    return std::make_unique<IntLiteral>(irGenerator.fContext, constructor.fOffset,
+                                                        value);
+                } else if (fType.isFloat()) {
+                    SkASSERT(fComponents.size() == 1);
+                    double value = constructor.getFVecComponent(fComponents[0]);
+                    return std::make_unique<FloatLiteral>(irGenerator.fContext, constructor.fOffset,
+                                                          value);
+                }
             }
         }
         return nullptr;
diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp
index a6b3ccb..7c315c9 100644
--- a/tests/SkSLGLSLTest.cpp
+++ b/tests/SkSLGLSLTest.cpp
@@ -2578,6 +2578,64 @@
          "}\n");
 }
 
+DEF_TEST(SkSLStackingVectorCasts, r) {
+    test(r,
+         "void main() {"
+         "    if (half4(0, 0, 1, 1) == half4(int4(0, 0, 1, 1)))"
+         "        sk_FragColor = half4(0, 1, 0, 1);"
+         "    else"
+         "        sk_FragColor = half4(1, 0, 0, 1);"
+         "}",
+         *SkSL::ShaderCapsFactory::Default(),
+         "#version 400\n"
+         "out vec4 sk_FragColor;\n"
+         "void main() {\n"
+         "    sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+         "}\n");
+    test(r,
+         "void main() {"
+         "    if (half4(int4(0, 0, 1, 1)) == half4(int4(half4(0, 0, 1, 1))))"
+         "        sk_FragColor = half4(0, 1, 0, 1);"
+         "    else"
+         "        sk_FragColor = half4(1, 0, 0, 1);"
+         "}",
+         *SkSL::ShaderCapsFactory::Default(),
+         "#version 400\n"
+         "out vec4 sk_FragColor;\n"
+         "void main() {\n"
+         "    sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+         "}\n");
+}
+
+DEF_TEST(SkSLCastsRoundTowardZero, r) {
+    test(r,
+         "void main() {"
+         "    if (half4(int4(0, 0, 1, 2)) == half4(int4(half4(0.01, 0.99, 1.49, 2.75))))"
+         "        sk_FragColor = half4(0, 1, 0, 1);"
+         "    else"
+         "        sk_FragColor = half4(1, 0, 0, 1);"
+         "}",
+         *SkSL::ShaderCapsFactory::Default(),
+         "#version 400\n"
+         "out vec4 sk_FragColor;\n"
+         "void main() {\n"
+         "    sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+         "}\n");
+    test(r,
+         "void main() {"
+         "    if (half4(int4(0, 0, -1, -2)) == half4(int4(half4(-0.01, -0.99, -1.49, -2.75))))"
+         "        sk_FragColor = half4(0, 1, 0, 1);"
+         "    else"
+         "        sk_FragColor = half4(1, 0, 0, 1);"
+         "}",
+         *SkSL::ShaderCapsFactory::Default(),
+         "#version 400\n"
+         "out vec4 sk_FragColor;\n"
+         "void main() {\n"
+         "    sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+         "}\n");
+}
+
 DEF_TEST(SkSLNegatedVectorLiteral, r) {
     test(r,
          "void main() {"