Fix SkSL constant propagation within nested casts.
Previously, the inner type was ignored.
Change-Id: I51d251fc38358ef889b5a3f85d5f2d23bd8cf4c5
Bug: skia:10615
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/310657
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 342b720..6352dfe 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -136,46 +136,86 @@
return true;
}
- template<typename type>
+ template <typename type>
type getVecComponent(int index) const {
SkASSERT(fType.kind() == Type::kVector_Kind);
if (fArguments.size() == 1 && fArguments[0]->fType.kind() == Type::kScalar_Kind) {
+ // This constructor just wraps a scalar. Propagate out the value.
if (std::is_floating_point<type>::value) {
return fArguments[0]->getConstantFloat();
} else {
return fArguments[0]->getConstantInt();
}
}
+
+ // Walk through all the constructor arguments until we reach the index we're searching for.
int current = 0;
- for (const auto& arg : fArguments) {
- SkASSERT(current <= index);
+ for (const std::unique_ptr<Expression>& arg : fArguments) {
+ if (current > index) {
+ // Somehow, we went past the argument we're looking for. Bail.
+ break;
+ }
+
if (arg->fType.kind() == Type::kScalar_Kind) {
if (index == current) {
+ // We're on the proper argument, and it's a scalar; fetch it.
if (std::is_floating_point<type>::value) {
- return arg.get()->getConstantFloat();
+ return arg->getConstantFloat();
} else {
- return arg.get()->getConstantInt();
+ return arg->getConstantInt();
}
}
current++;
- } else if (arg->fKind == kConstructor_Kind) {
- if (current + arg->fType.columns() > index) {
- return ((const Constructor&) *arg).getVecComponent<type>(index - current);
- }
- current += arg->fType.columns();
- } else {
- if (current + arg->fType.columns() > index) {
- SkASSERT(arg->fKind == kPrefix_Kind);
- const PrefixExpression& p = (PrefixExpression&) *arg;
- const Constructor& c = (const Constructor&) *p.fOperand;
- return -c.getVecComponent<type>(index - current);
- }
- current += arg->fType.columns();
+ continue;
}
+
+ switch (arg->fKind) {
+ case kConstructor_Kind: {
+ const Constructor& constructor = static_cast<const Constructor&>(*arg);
+ if (current + constructor.fType.columns() > index) {
+ // We've found a constructor that overlaps the proper argument. Descend into
+ // it, honoring the type.
+ if (constructor.fType.componentType().isFloat()) {
+ return type(constructor.getVecComponent<SKSL_FLOAT>(index - current));
+ } else {
+ return type(constructor.getVecComponent<SKSL_INT>(index - current));
+ }
+ }
+ break;
+ }
+ case kPrefix_Kind: {
+ const PrefixExpression& prefix = static_cast<const PrefixExpression&>(*arg);
+ if (current + prefix.fType.columns() > index) {
+ // We found a prefix operator that contains the proper argument. Descend
+ // into it. We only support for constant propagation of the unary minus, so
+ // we shouldn't see any other tokens here.
+ SkASSERT(prefix.fOperator == Token::Kind::TK_MINUS);
+
+ // We expect the - prefix to always be attached to a constructor.
+ SkASSERT(prefix.fOperand->fKind == kConstructor_Kind);
+ const Constructor& constructor =
+ static_cast<const Constructor&>(*prefix.fOperand);
+
+ // Descend into this constructor, honoring the type.
+ if (constructor.fType.componentType().isFloat()) {
+ return -type(constructor.getVecComponent<SKSL_FLOAT>(index - current));
+ } else {
+ return -type(constructor.getVecComponent<SKSL_INT>(index - current));
+ }
+ }
+ break;
+ }
+ default: {
+ SkDEBUGFAILF("unexpected component %d { %s } in %s\n",
+ index, arg->description().c_str(), description().c_str());
+ break;
+ }
+ }
+
+ current += arg->fType.columns();
}
-#ifdef SK_DEBUG
- ABORT("failed to find vector component %d in %s\n", index, description().c_str());
-#endif
+
+ SkDEBUGFAILF("failed to find vector component %d in %s\n", index, description().c_str());
return -1;
}
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index afd30c7..5d7a5d9 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -106,21 +106,21 @@
std::unique_ptr<Expression> constantPropagate(const IRGenerator& irGenerator,
const DefinitionMap& definitions) override {
- if (fBase->fKind == Expression::kConstructor_Kind && fBase->isCompileTimeConstant()) {
- // we're swizzling a constant vector, e.g. float4(1).x. Simplify it.
- SkASSERT(fBase->fKind == Expression::kConstructor_Kind);
- if (fType.isInteger()) {
- SkASSERT(fComponents.size() == 1);
- int64_t value = ((Constructor&) *fBase).getIVecComponent(fComponents[0]);
- return std::unique_ptr<Expression>(new IntLiteral(irGenerator.fContext,
- -1,
- value));
- } else if (fType.isFloat()) {
- SkASSERT(fComponents.size() == 1);
- double value = ((Constructor&) *fBase).getFVecComponent(fComponents[0]);
- return std::unique_ptr<Expression>(new FloatLiteral(irGenerator.fContext,
- -1,
- value));
+ if (fBase->fKind == Expression::kConstructor_Kind) {
+ Constructor& constructor = static_cast<Constructor&>(*fBase);
+ if (constructor.isCompileTimeConstant()) {
+ // we're swizzling a constant vector, e.g. float4(1).x. Simplify it.
+ if (fType.isInteger()) {
+ SkASSERT(fComponents.size() == 1);
+ int64_t value = constructor.getIVecComponent(fComponents[0]);
+ return std::make_unique<IntLiteral>(irGenerator.fContext, constructor.fOffset,
+ value);
+ } else if (fType.isFloat()) {
+ SkASSERT(fComponents.size() == 1);
+ double value = constructor.getFVecComponent(fComponents[0]);
+ return std::make_unique<FloatLiteral>(irGenerator.fContext, constructor.fOffset,
+ value);
+ }
}
}
return nullptr;
diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp
index a6b3ccb..7c315c9 100644
--- a/tests/SkSLGLSLTest.cpp
+++ b/tests/SkSLGLSLTest.cpp
@@ -2578,6 +2578,64 @@
"}\n");
}
+DEF_TEST(SkSLStackingVectorCasts, r) {
+ test(r,
+ "void main() {"
+ " if (half4(0, 0, 1, 1) == half4(int4(0, 0, 1, 1)))"
+ " sk_FragColor = half4(0, 1, 0, 1);"
+ " else"
+ " sk_FragColor = half4(1, 0, 0, 1);"
+ "}",
+ *SkSL::ShaderCapsFactory::Default(),
+ "#version 400\n"
+ "out vec4 sk_FragColor;\n"
+ "void main() {\n"
+ " sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+ "}\n");
+ test(r,
+ "void main() {"
+ " if (half4(int4(0, 0, 1, 1)) == half4(int4(half4(0, 0, 1, 1))))"
+ " sk_FragColor = half4(0, 1, 0, 1);"
+ " else"
+ " sk_FragColor = half4(1, 0, 0, 1);"
+ "}",
+ *SkSL::ShaderCapsFactory::Default(),
+ "#version 400\n"
+ "out vec4 sk_FragColor;\n"
+ "void main() {\n"
+ " sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+ "}\n");
+}
+
+DEF_TEST(SkSLCastsRoundTowardZero, r) {
+ test(r,
+ "void main() {"
+ " if (half4(int4(0, 0, 1, 2)) == half4(int4(half4(0.01, 0.99, 1.49, 2.75))))"
+ " sk_FragColor = half4(0, 1, 0, 1);"
+ " else"
+ " sk_FragColor = half4(1, 0, 0, 1);"
+ "}",
+ *SkSL::ShaderCapsFactory::Default(),
+ "#version 400\n"
+ "out vec4 sk_FragColor;\n"
+ "void main() {\n"
+ " sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+ "}\n");
+ test(r,
+ "void main() {"
+ " if (half4(int4(0, 0, -1, -2)) == half4(int4(half4(-0.01, -0.99, -1.49, -2.75))))"
+ " sk_FragColor = half4(0, 1, 0, 1);"
+ " else"
+ " sk_FragColor = half4(1, 0, 0, 1);"
+ "}",
+ *SkSL::ShaderCapsFactory::Default(),
+ "#version 400\n"
+ "out vec4 sk_FragColor;\n"
+ "void main() {\n"
+ " sk_FragColor = vec4(0.0, 1.0, 0.0, 1.0);\n"
+ "}\n");
+}
+
DEF_TEST(SkSLNegatedVectorLiteral, r) {
test(r,
"void main() {"