reworked SPIR-V binary operations and added support for VectorTimesScalar
Bug: skia:
Change-Id: I03b8a1ed3cf78060c5b9a5ede8d0371998116744
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/208677
Reviewed-by: Greg Daniel <egdaniel@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 6edf530..b8cf3fe 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2012,48 +2012,43 @@
return result;
}
-SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
- // handle cases where we don't necessarily evaluate both LHS and RHS
- switch (b.fOperator) {
- case Token::EQ: {
- SpvId rhs = this->writeExpression(*b.fRight, out);
- this->getLValue(*b.fLeft, out)->store(rhs, out);
- return rhs;
- }
- case Token::LOGICALAND:
- return this->writeLogicalAnd(b, out);
- case Token::LOGICALOR:
- return this->writeLogicalOr(b, out);
- default:
- break;
+std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
+ if (type.isInteger()) {
+ return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
}
-
- // "normal" operators
- const Type& resultType = b.fType;
- std::unique_ptr<LValue> lvalue;
- SpvId lhs;
- if (is_assignment(b.fOperator)) {
- lvalue = this->getLValue(*b.fLeft, out);
- lhs = lvalue->load(out);
+ else if (type.isFloat()) {
+ return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
} else {
- lvalue = nullptr;
- lhs = this->writeExpression(*b.fLeft, out);
+ ABORT("math is unsupported on type '%s'", type.name().c_str());
}
- SpvId rhs = this->writeExpression(*b.fRight, out);
- if (b.fOperator == Token::COMMA) {
- return rhs;
- }
+}
+
+SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
+ const Type& rightType, SpvId rhs,
+ const Type& resultType, OutputStream& out) {
Type tmp("<invalid>");
// overall type we are operating on: float2, int, uint4...
const Type* operandType;
// IR allows mismatched types in expressions (e.g. float2 * float), but they need special
// handling in SPIR-V
- if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
- if (b.fLeft->fType.kind() == Type::kVector_Kind &&
- b.fRight->fType.isNumber()) {
+ if (this->getActualType(leftType) != this->getActualType(rightType)) {
+ if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
+ if (op == Token::SLASH) {
+ SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
+ SpvId inverse = this->nextId();
+ this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
+ rhs = inverse;
+ op = Token::STAR;
+ }
+ if (op == Token::STAR) {
+ SpvId result = this->nextId();
+ this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
+ result, lhs, rhs, out);
+ return result;
+ }
// promote number to vector
SpvId vec = this->nextId();
- const Type& vecType = b.fLeft->fType;
+ const Type& vecType = leftType;
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
this->writeWord(this->getType(vecType), out);
this->writeWord(vec, out);
@@ -2061,12 +2056,17 @@
this->writeWord(rhs, out);
}
rhs = vec;
- operandType = &b.fLeft->fType;
- } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
- b.fLeft->fType.isNumber()) {
+ operandType = &leftType;
+ } else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
+ if (op == Token::STAR) {
+ SpvId result = this->nextId();
+ this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
+ result, rhs, lhs, out);
+ return result;
+ }
// promote number to vector
SpvId vec = this->nextId();
- const Type& vecType = b.fRight->fType;
+ const Type& vecType = rightType;
this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
this->writeWord(this->getType(vecType), out);
this->writeWord(vec, out);
@@ -2074,52 +2074,41 @@
this->writeWord(lhs, out);
}
lhs = vec;
- SkASSERT(!lvalue);
- operandType = &b.fRight->fType;
- } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
- SpvOp_ op;
- if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
- op = SpvOpMatrixTimesMatrix;
- } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
- op = SpvOpMatrixTimesVector;
+ operandType = &rightType;
+ } else if (leftType.kind() == Type::kMatrix_Kind) {
+ SpvOp_ spvop;
+ if (rightType.kind() == Type::kMatrix_Kind) {
+ spvop = SpvOpMatrixTimesMatrix;
+ } else if (rightType.kind() == Type::kVector_Kind) {
+ spvop = SpvOpMatrixTimesVector;
} else {
- SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
- op = SpvOpMatrixTimesScalar;
+ SkASSERT(rightType.kind() == Type::kScalar_Kind);
+ spvop = SpvOpMatrixTimesScalar;
}
SpvId result = this->nextId();
- this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
- if (b.fOperator == Token::STAREQ) {
- lvalue->store(result, out);
- } else {
- SkASSERT(b.fOperator == Token::STAR);
- }
+ this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
return result;
- } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ } else if (rightType.kind() == Type::kMatrix_Kind) {
SpvId result = this->nextId();
- if (b.fLeft->fType.kind() == Type::kVector_Kind) {
- this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
+ if (leftType.kind() == Type::kVector_Kind) {
+ this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
lhs, rhs, out);
} else {
- SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
- this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
- lhs, out);
- }
- if (b.fOperator == Token::STAREQ) {
- lvalue->store(result, out);
- } else {
- SkASSERT(b.fOperator == Token::STAR);
+ SkASSERT(leftType.kind() == Type::kScalar_Kind);
+ this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
+ rhs, lhs, out);
}
return result;
} else {
- ABORT("unsupported binary expression: %s (%s, %s)", b.description().c_str(),
- b.fLeft->fType.description().c_str(), b.fRight->fType.description().c_str());
+ SkASSERT(false);
+ return -1;
}
} else {
- tmp = this->getActualType(b.fLeft->fType);
+ tmp = this->getActualType(leftType);
operandType = &tmp;
- SkASSERT(*operandType == this->getActualType(b.fRight->fType));
+ SkASSERT(*operandType == this->getActualType(rightType));
}
- switch (b.fOperator) {
+ switch (op) {
case Token::EQEQ: {
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
@@ -2178,26 +2167,26 @@
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
SpvOpULessThanEqual, SpvOpUndef, out);
case Token::PLUS:
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
- SkASSERT(b.fLeft->fType == b.fRight->fType);
- return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ if (leftType.kind() == Type::kMatrix_Kind &&
+ rightType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(leftType == rightType);
+ return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
SpvOpFAdd, SpvOpIAdd, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
case Token::MINUS:
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
- SkASSERT(b.fLeft->fType == b.fRight->fType);
- return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ if (leftType.kind() == Type::kMatrix_Kind &&
+ rightType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(leftType == rightType);
+ return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
SpvOpFSub, SpvOpISub, out);
}
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
case Token::STAR:
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ if (leftType.kind() == Type::kMatrix_Kind &&
+ rightType.kind() == Type::kMatrix_Kind) {
// matrix multiply
SpvId result = this->nextId();
this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2229,114 +2218,48 @@
case Token::BITWISEXOR:
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
- case Token::PLUSEQ: {
- SpvId result;
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
- SkASSERT(b.fLeft->fType == b.fRight->fType);
- result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
- SpvOpFAdd, SpvOpIAdd, out);
- }
- else {
- result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
- SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
- }
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::MINUSEQ: {
- SpvId result;
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
- SkASSERT(b.fLeft->fType == b.fRight->fType);
- result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
- SpvOpFSub, SpvOpISub, out);
- }
- else {
- result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
- SpvOpISub, SpvOpISub, SpvOpUndef, out);
- }
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::STAREQ: {
- if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
- b.fRight->fType.kind() == Type::kMatrix_Kind) {
- // matrix multiply
- SpvId result = this->nextId();
- this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
- lhs, rhs, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
- SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::SLASHEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
- SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::PERCENTEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
- SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::SHLEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
- SpvOpUndef, SpvOpShiftLeftLogical,
- SpvOpShiftLeftLogical, SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::SHREQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
- SpvOpUndef, SpvOpShiftRightArithmetic,
- SpvOpShiftRightLogical, SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::BITWISEANDEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
- SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
- SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::BITWISEOREQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
- SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
- SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
- case Token::BITWISEXOREQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
- SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
- SpvOpUndef, out);
- SkASSERT(lvalue);
- lvalue->store(result, out);
- return result;
- }
+ case Token::COMMA:
+ return rhs;
default:
- ABORT("unsupported binary expression: %s", b.description().c_str());
+ SkASSERT(false);
+ return -1;
}
}
+SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
+ // handle cases where we don't necessarily evaluate both LHS and RHS
+ switch (b.fOperator) {
+ case Token::EQ: {
+ SpvId rhs = this->writeExpression(*b.fRight, out);
+ this->getLValue(*b.fLeft, out)->store(rhs, out);
+ return rhs;
+ }
+ case Token::LOGICALAND:
+ return this->writeLogicalAnd(b, out);
+ case Token::LOGICALOR:
+ return this->writeLogicalOr(b, out);
+ default:
+ break;
+ }
+
+ std::unique_ptr<LValue> lvalue;
+ SpvId lhs;
+ if (is_assignment(b.fOperator)) {
+ lvalue = this->getLValue(*b.fLeft, out);
+ lhs = lvalue->load(out);
+ } else {
+ lvalue = nullptr;
+ lhs = this->writeExpression(*b.fLeft, out);
+ }
+ SpvId rhs = this->writeExpression(*b.fRight, out);
+ SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
+ b.fRight->fType, rhs, b.fType, out);
+ if (lvalue) {
+ lvalue->store(result, out);
+ }
+ return result;
+}
+
SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
SkASSERT(a.fOperator == Token::LOGICALAND);
BoolLiteral falseLiteral(fContext, -1, false);
@@ -2413,17 +2336,6 @@
return result;
}
-std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
- if (type.isInteger()) {
- return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
- }
- else if (type.isFloat()) {
- return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
- } else {
- ABORT("math is unsupported on type '%s'", type.name().c_str());
- }
-}
-
SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
if (p.fOperator == Token::MINUS) {
SpvId result = this->nextId();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 26560d5..fd4482f 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -274,6 +274,10 @@
SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
SpvOp_ ifUInt, OutputStream& out);
+ SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
+ const Type& rightType, SpvId rhs, const Type& resultType,
+ OutputStream& out);
+
SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
diff --git a/src/sksl/SkSLUtil.cpp b/src/sksl/SkSLUtil.cpp
index 4684df4..70b4918 100644
--- a/src/sksl/SkSLUtil.cpp
+++ b/src/sksl/SkSLUtil.cpp
@@ -67,7 +67,7 @@
case Token::LOGICALOREQ: return Token::LOGICALOR;
case Token::LOGICALXOREQ: return Token::LOGICALXOR;
case Token::LOGICALANDEQ: return Token::LOGICALAND;
- default: return Token::INVALID;
+ default: return op;
}
}