reworked SPIR-V binary operations and added support for VectorTimesScalar

Bug: skia:
Change-Id: I03b8a1ed3cf78060c5b9a5ede8d0371998116744
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/208677
Reviewed-by: Greg Daniel <egdaniel@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 6edf530..b8cf3fe 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2012,48 +2012,43 @@
     return result;
 }
 
-SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
-    // handle cases where we don't necessarily evaluate both LHS and RHS
-    switch (b.fOperator) {
-        case Token::EQ: {
-            SpvId rhs = this->writeExpression(*b.fRight, out);
-            this->getLValue(*b.fLeft, out)->store(rhs, out);
-            return rhs;
-        }
-        case Token::LOGICALAND:
-            return this->writeLogicalAnd(b, out);
-        case Token::LOGICALOR:
-            return this->writeLogicalOr(b, out);
-        default:
-            break;
+std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
+    if (type.isInteger()) {
+        return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
     }
-
-    // "normal" operators
-    const Type& resultType = b.fType;
-    std::unique_ptr<LValue> lvalue;
-    SpvId lhs;
-    if (is_assignment(b.fOperator)) {
-        lvalue = this->getLValue(*b.fLeft, out);
-        lhs = lvalue->load(out);
+    else if (type.isFloat()) {
+        return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
     } else {
-        lvalue = nullptr;
-        lhs = this->writeExpression(*b.fLeft, out);
+        ABORT("math is unsupported on type '%s'", type.name().c_str());
     }
-    SpvId rhs = this->writeExpression(*b.fRight, out);
-    if (b.fOperator == Token::COMMA) {
-        return rhs;
-    }
+}
+
+SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
+                                                const Type& rightType, SpvId rhs,
+                                                const Type& resultType, OutputStream& out) {
     Type tmp("<invalid>");
     // overall type we are operating on: float2, int, uint4...
     const Type* operandType;
     // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
     // handling in SPIR-V
-    if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
-        if (b.fLeft->fType.kind() == Type::kVector_Kind &&
-            b.fRight->fType.isNumber()) {
+    if (this->getActualType(leftType) != this->getActualType(rightType)) {
+        if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
+            if (op == Token::SLASH) {
+                SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
+                SpvId inverse = this->nextId();
+                this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
+                rhs = inverse;
+                op = Token::STAR;
+            }
+            if (op == Token::STAR) {
+                SpvId result = this->nextId();
+                this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
+                                       result, lhs, rhs, out);
+                return result;
+            }
             // promote number to vector
             SpvId vec = this->nextId();
-            const Type& vecType = b.fLeft->fType;
+            const Type& vecType = leftType;
             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
             this->writeWord(this->getType(vecType), out);
             this->writeWord(vec, out);
@@ -2061,12 +2056,17 @@
                 this->writeWord(rhs, out);
             }
             rhs = vec;
-            operandType = &b.fLeft->fType;
-        } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
-                   b.fLeft->fType.isNumber()) {
+            operandType = &leftType;
+        } else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
+            if (op == Token::STAR) {
+                SpvId result = this->nextId();
+                this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
+                                       result, rhs, lhs, out);
+                return result;
+            }
             // promote number to vector
             SpvId vec = this->nextId();
-            const Type& vecType = b.fRight->fType;
+            const Type& vecType = rightType;
             this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
             this->writeWord(this->getType(vecType), out);
             this->writeWord(vec, out);
@@ -2074,52 +2074,41 @@
                 this->writeWord(lhs, out);
             }
             lhs = vec;
-            SkASSERT(!lvalue);
-            operandType = &b.fRight->fType;
-        } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
-            SpvOp_ op;
-            if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                op = SpvOpMatrixTimesMatrix;
-            } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
-                op = SpvOpMatrixTimesVector;
+            operandType = &rightType;
+        } else if (leftType.kind() == Type::kMatrix_Kind) {
+            SpvOp_ spvop;
+            if (rightType.kind() == Type::kMatrix_Kind) {
+                spvop = SpvOpMatrixTimesMatrix;
+            } else if (rightType.kind() == Type::kVector_Kind) {
+                spvop = SpvOpMatrixTimesVector;
             } else {
-                SkASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
-                op = SpvOpMatrixTimesScalar;
+                SkASSERT(rightType.kind() == Type::kScalar_Kind);
+                spvop = SpvOpMatrixTimesScalar;
             }
             SpvId result = this->nextId();
-            this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
-            if (b.fOperator == Token::STAREQ) {
-                lvalue->store(result, out);
-            } else {
-                SkASSERT(b.fOperator == Token::STAR);
-            }
+            this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
             return result;
-        } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
+        } else if (rightType.kind() == Type::kMatrix_Kind) {
             SpvId result = this->nextId();
-            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
-                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
+            if (leftType.kind() == Type::kVector_Kind) {
+                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
                                        lhs, rhs, out);
             } else {
-                SkASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
-                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
-                                       lhs, out);
-            }
-            if (b.fOperator == Token::STAREQ) {
-                lvalue->store(result, out);
-            } else {
-                SkASSERT(b.fOperator == Token::STAR);
+                SkASSERT(leftType.kind() == Type::kScalar_Kind);
+                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
+                                       rhs, lhs, out);
             }
             return result;
         } else {
-            ABORT("unsupported binary expression: %s (%s, %s)", b.description().c_str(),
-                  b.fLeft->fType.description().c_str(), b.fRight->fType.description().c_str());
+            SkASSERT(false);
+            return -1;
         }
     } else {
-        tmp = this->getActualType(b.fLeft->fType);
+        tmp = this->getActualType(leftType);
         operandType = &tmp;
-        SkASSERT(*operandType == this->getActualType(b.fRight->fType));
+        SkASSERT(*operandType == this->getActualType(rightType));
     }
-    switch (b.fOperator) {
+    switch (op) {
         case Token::EQEQ: {
             if (operandType->kind() == Type::kMatrix_Kind) {
                 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
@@ -2178,26 +2167,26 @@
                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
                                               SpvOpULessThanEqual, SpvOpUndef, out);
         case Token::PLUS:
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                SkASSERT(b.fLeft->fType == b.fRight->fType);
-                return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+            if (leftType.kind() == Type::kMatrix_Kind &&
+                rightType.kind() == Type::kMatrix_Kind) {
+                SkASSERT(leftType == rightType);
+                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
                                                             SpvOpFAdd, SpvOpIAdd, out);
             }
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
         case Token::MINUS:
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                SkASSERT(b.fLeft->fType == b.fRight->fType);
-                return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+            if (leftType.kind() == Type::kMatrix_Kind &&
+                rightType.kind() == Type::kMatrix_Kind) {
+                SkASSERT(leftType == rightType);
+                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
                                                             SpvOpFSub, SpvOpISub, out);
             }
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
         case Token::STAR:
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
+            if (leftType.kind() == Type::kMatrix_Kind &&
+                rightType.kind() == Type::kMatrix_Kind) {
                 // matrix multiply
                 SpvId result = this->nextId();
                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
@@ -2229,114 +2218,48 @@
         case Token::BITWISEXOR:
             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
                                               SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
-        case Token::PLUSEQ: {
-            SpvId result;
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                SkASSERT(b.fLeft->fType == b.fRight->fType);
-                result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
-                                                            SpvOpFAdd, SpvOpIAdd, out);
-            }
-            else {
-                result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
-                                                      SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
-            }
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::MINUSEQ: {
-            SpvId result;
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                SkASSERT(b.fLeft->fType == b.fRight->fType);
-                result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
-                                                            SpvOpFSub, SpvOpISub, out);
-            }
-            else {
-                result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
-                                                      SpvOpISub, SpvOpISub, SpvOpUndef, out);
-            }
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::STAREQ: {
-            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
-                b.fRight->fType.kind() == Type::kMatrix_Kind) {
-                // matrix multiply
-                SpvId result = this->nextId();
-                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
-                                       lhs, rhs, out);
-                SkASSERT(lvalue);
-                lvalue->store(result, out);
-                return result;
-            }
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
-                                                      SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::SLASHEQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
-                                                      SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::PERCENTEQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
-                                                      SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::SHLEQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
-                                                      SpvOpUndef, SpvOpShiftLeftLogical,
-                                                      SpvOpShiftLeftLogical, SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::SHREQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
-                                                      SpvOpUndef, SpvOpShiftRightArithmetic,
-                                                      SpvOpShiftRightLogical, SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::BITWISEANDEQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
-                                                      SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
-                                                      SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::BITWISEOREQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
-                                                      SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
-                                                      SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
-        case Token::BITWISEXOREQ: {
-            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
-                                                      SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
-                                                      SpvOpUndef, out);
-            SkASSERT(lvalue);
-            lvalue->store(result, out);
-            return result;
-        }
+        case Token::COMMA:
+            return rhs;
         default:
-            ABORT("unsupported binary expression: %s", b.description().c_str());
+            SkASSERT(false);
+            return -1;
     }
 }
 
+SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
+    // handle cases where we don't necessarily evaluate both LHS and RHS
+    switch (b.fOperator) {
+        case Token::EQ: {
+            SpvId rhs = this->writeExpression(*b.fRight, out);
+            this->getLValue(*b.fLeft, out)->store(rhs, out);
+            return rhs;
+        }
+        case Token::LOGICALAND:
+            return this->writeLogicalAnd(b, out);
+        case Token::LOGICALOR:
+            return this->writeLogicalOr(b, out);
+        default:
+            break;
+    }
+
+    std::unique_ptr<LValue> lvalue;
+    SpvId lhs;
+    if (is_assignment(b.fOperator)) {
+        lvalue = this->getLValue(*b.fLeft, out);
+        lhs = lvalue->load(out);
+    } else {
+        lvalue = nullptr;
+        lhs = this->writeExpression(*b.fLeft, out);
+    }
+    SpvId rhs = this->writeExpression(*b.fRight, out);
+    SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
+                                               b.fRight->fType, rhs, b.fType, out);
+    if (lvalue) {
+        lvalue->store(result, out);
+    }
+    return result;
+}
+
 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
     SkASSERT(a.fOperator == Token::LOGICALAND);
     BoolLiteral falseLiteral(fContext, -1, false);
@@ -2413,17 +2336,6 @@
     return result;
 }
 
-std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
-    if (type.isInteger()) {
-        return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
-    }
-    else if (type.isFloat()) {
-        return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
-    } else {
-        ABORT("math is unsupported on type '%s'", type.name().c_str());
-    }
-}
-
 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
     if (p.fOperator == Token::MINUS) {
         SpvId result = this->nextId();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 26560d5..fd4482f 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -274,6 +274,10 @@
     SpvId writeBinaryOperation(const BinaryExpression& expr, SpvOp_ ifFloat, SpvOp_ ifInt,
                                SpvOp_ ifUInt, OutputStream& out);
 
+    SpvId writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
+                                const Type& rightType, SpvId rhs, const Type& resultType,
+                                OutputStream& out);
+
     SpvId writeBinaryExpression(const BinaryExpression& b, OutputStream& out);
 
     SpvId writeTernaryExpression(const TernaryExpression& t, OutputStream& out);
diff --git a/src/sksl/SkSLUtil.cpp b/src/sksl/SkSLUtil.cpp
index 4684df4..70b4918 100644
--- a/src/sksl/SkSLUtil.cpp
+++ b/src/sksl/SkSLUtil.cpp
@@ -67,7 +67,7 @@
         case Token::LOGICALOREQ:  return Token::LOGICALOR;
         case Token::LOGICALXOREQ: return Token::LOGICALXOR;
         case Token::LOGICALANDEQ: return Token::LOGICALAND;
-        default: return Token::INVALID;
+        default: return op;
     }
 }