Support ES2-compliant for loops in SkSL-to-SkVM

ES2-compliant loops are always unrollable, so the generator fully unrolls them during code generation.

Bug: skia:11094
Change-Id: I1b34917b6f2d015ae7867415d0120a5df0ffd618
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/353619
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Mike Klein <mtklein@google.com>
diff --git a/src/sksl/SkSLVMGenerator.cpp b/src/sksl/SkSLVMGenerator.cpp
index 763404d..5c8c95c 100644
--- a/src/sksl/SkSLVMGenerator.cpp
+++ b/src/sksl/SkSLVMGenerator.cpp
@@ -215,7 +215,7 @@
         // As we encounter (possibly conditional) return statements, fReturned is updated to store
         // the lanes that have already returned. For the remainder of the current function, those
         // lanes should be disabled.
-        return fMask & ~currentFunction().fReturned;
+        return fConditionMask & fLoopMask & ~currentFunction().fReturned;
     }
 
     Value writeExpression(const Expression& expr);
@@ -231,6 +231,9 @@
 
     void writeStatement(const Statement& s);
     void writeBlock(const Block& b);
+    void writeBreakStatement();
+    void writeContinueStatement();
+    void writeForStatement(const ForStatement& f);
     void writeIfStatement(const IfStatement& stmt);
     void writeReturnStatement(const ReturnStatement& r);
     void writeVarDeclaration(const VarDeclaration& decl);
@@ -255,8 +258,15 @@
     std::unordered_map<const Variable*, Slot> fVariableMap;
     std::vector<skvm::Val> fSlots;
 
-    // Conditional execution mask (changes are managed by AutoMask, and tied to control-flow scopes)
-    skvm::I32 fMask;
+    // Conditional execution mask (managed by ScopedCondition, and tied to control-flow scopes)
+    skvm::I32 fConditionMask;
+
+    // Similar: loop execution masks. Each loop starts with all lanes active (fLoopMask).
+    // 'break' disables a lane in fLoopMask until the loop finishes
+    // 'continue' disables a lane in fLoopMask, and sets fContinueMask to be re-enabled on the next
+    //   iteration
+    skvm::I32 fLoopMask;
+    skvm::I32 fContinueMask;
 
     //
     // State that's local to the generation of a single function:
@@ -268,18 +278,18 @@
     std::vector<Function> fFunctionStack;
     Function& currentFunction() { return fFunctionStack.back(); }
 
-    class AutoMask {
+    class ScopedCondition {
     public:
-        AutoMask(SkVMGenerator* generator, skvm::I32 mask)
-                : fGenerator(generator), fOldMask(fGenerator->fMask) {
-            fGenerator->fMask &= mask;
+        ScopedCondition(SkVMGenerator* generator, skvm::I32 mask)
+                : fGenerator(generator), fOldConditionMask(fGenerator->fConditionMask) {
+            fGenerator->fConditionMask &= mask;
         }
 
-        ~AutoMask() { fGenerator->fMask = fOldMask; }
+        ~ScopedCondition() { fGenerator->fConditionMask = fOldConditionMask; }
 
     private:
         SkVMGenerator* fGenerator;
-        skvm::I32 fOldMask;
+        skvm::I32 fOldConditionMask;
     };
 };
 
@@ -375,7 +385,7 @@
 
             { "sample", Intrinsic::kSample },
         } {
-    fMask = fBuilder->splat(0xffff'ffff);
+    fConditionMask = fLoopMask = fBuilder->splat(0xffff'ffff);
 
     // Now, add storage for each global variable (including uniforms) to fSlots, and entries in
     // fVariableMap to remember where every variable is stored.
@@ -540,7 +550,7 @@
             SkASSERT(!isAssignment);
             SkASSERT(nk == Type::NumberKind::kBoolean);
             skvm::I32 lVal = i32(this->writeExpression(left));
-            AutoMask shortCircuit(this, lVal);
+            ScopedCondition shortCircuit(this, lVal);
             skvm::I32 rVal = i32(this->writeExpression(right));
             return lVal & rVal;
         }
@@ -548,7 +558,7 @@
             SkASSERT(!isAssignment);
             SkASSERT(nk == Type::NumberKind::kBoolean);
             skvm::I32 lVal = i32(this->writeExpression(left));
-            AutoMask shortCircuit(this, ~lVal);
+            ScopedCondition shortCircuit(this, ~lVal);
             skvm::I32 rVal = i32(this->writeExpression(right));
             return lVal | rVal;
         }
@@ -1158,9 +1168,9 @@
     }
 
     {
-        // This AutoMask merges currentFunction().fReturned into fMask. Lanes that conditionally
+        // This merges currentFunction().fReturned into fConditionMask. Lanes that conditionally
         // returned in the current function would otherwise resume execution within the child.
-        AutoMask m(this, ~currentFunction().fReturned);
+        ScopedCondition m(this, ~currentFunction().fReturned);
         this->writeFunction(*f.function().definition(), argVals, result.asSpan());
     }
 
@@ -1268,11 +1278,11 @@
     Value ifTrue, ifFalse;
 
     {
-        AutoMask m(this, test);
+        ScopedCondition m(this, test);
         ifTrue = this->writeExpression(*t.ifTrue());
     }
     {
-        AutoMask m(this, ~test);
+        ScopedCondition m(this, ~test);
         ifFalse = this->writeExpression(*t.ifFalse());
     }
 
@@ -1347,14 +1357,55 @@
     }
 }
 
+void SkVMGenerator::writeBreakStatement() {
+    // Any active lanes stop executing for the duration of the current loop
+    fLoopMask &= ~this->mask();
+}
+
+void SkVMGenerator::writeContinueStatement() {
+    // Any active lanes stop executing for the current iteration.
+    // Remember them in fContinueMask, to be re-enabled later.
+    skvm::I32 mask = this->mask();
+    fLoopMask &= ~mask;
+    fContinueMask |= mask;
+}
+
+void SkVMGenerator::writeForStatement(const ForStatement& f) {
+    // We require that all loops be ES2-compliant (unrollable), and actually unroll them here
+    Analysis::UnrollableLoopInfo loop;
+    SkAssertResult(Analysis::ForLoopIsValidForES2(f, &loop, /*errors=*/nullptr));
+    SkASSERT(slot_count(loop.fIndex->type()) == 1);
+
+    Slot index = this->getSlot(*loop.fIndex);
+    double val = loop.fStart;
+
+    skvm::I32 oldLoopMask     = fLoopMask,
+              oldContinueMask = fContinueMask;
+
+    for (int i = 0; i < loop.fCount; ++i) {
+        fSlots[index] = loop.fIndex->type().isInteger()
+                                ? fBuilder->splat(static_cast<int>(val)).id
+                                : fBuilder->splat(static_cast<float>(val)).id;
+
+        fContinueMask = fBuilder->splat(0);
+        this->writeStatement(*f.statement());
+        fLoopMask |= fContinueMask;
+
+        val += loop.fDelta;
+    }
+
+    fLoopMask     = oldLoopMask;
+    fContinueMask = oldContinueMask;
+}
+
 void SkVMGenerator::writeIfStatement(const IfStatement& i) {
     Value test = this->writeExpression(*i.test());
     {
-        AutoMask ifTrue(this, i32(test));
+        ScopedCondition ifTrue(this, i32(test));
         this->writeStatement(*i.ifTrue());
     }
     if (i.ifFalse()) {
-        AutoMask ifFalse(this, ~i32(test));
+        ScopedCondition ifFalse(this, ~i32(test));
         this->writeStatement(*i.ifFalse());
     }
 }
@@ -1390,9 +1441,18 @@
         case Statement::Kind::kBlock:
             this->writeBlock(s.as<Block>());
             break;
+        case Statement::Kind::kBreak:
+            this->writeBreakStatement();
+            break;
+        case Statement::Kind::kContinue:
+            this->writeContinueStatement();
+            break;
         case Statement::Kind::kExpression:
             this->writeExpression(*s.as<ExpressionStatement>().expression());
             break;
+        case Statement::Kind::kFor:
+            this->writeForStatement(s.as<ForStatement>());
+            break;
         case Statement::Kind::kIf:
             this->writeIfStatement(s.as<IfStatement>());
             break;
@@ -1402,11 +1462,8 @@
         case Statement::Kind::kVarDeclaration:
             this->writeVarDeclaration(s.as<VarDeclaration>());
             break;
-        case Statement::Kind::kBreak:
-        case Statement::Kind::kContinue:
         case Statement::Kind::kDiscard:
         case Statement::Kind::kDo:
-        case Statement::Kind::kFor:
         case Statement::Kind::kSwitch:
             SkDEBUGFAIL("Unsupported control flow");
             break;
diff --git a/tests/SkSLInterpreterTest.cpp b/tests/SkSLInterpreterTest.cpp
index 3292ee3..260cf25 100644
--- a/tests/SkSLInterpreterTest.cpp
+++ b/tests/SkSLInterpreterTest.cpp
@@ -222,11 +222,8 @@
 
 void test(skiatest::Reporter* r, const char* src,
           float inR, float inG, float inB, float inA,
-          float exR, float exG, float exB, float exA,
-          bool testWithSkVM = true) {
-    if (testWithSkVM) {
-        test_skvm(r, src, inR, inG, inB, inA, exR, exG, exB, exA);
-    }
+          float exR, float exG, float exB, float exA) {
+    test_skvm(r, src, inR, inG, inB, inA, exR, exG, exB, exA);
 
     ByteCodeBuilder byteCode(r, src);
     if (!byteCode) { return; }
@@ -500,11 +497,9 @@
 }
 
 DEF_TEST(SkSLInterpreterFor, r) {
-    // TODO: SkVM for-loop support
     test(r, "void main(inout half4 color) { for (int i = 1; i <= 10; ++i) color.r += i; }",
          0, 0, 0, 0,
-         55, 0, 0, 0,
-         /*testWithSkVM=*/false);
+         55, 0, 0, 0);
     test(r,
          "void main(inout half4 color) {"
          "    for (int i = 1; i <= 10; ++i)"
@@ -512,8 +507,7 @@
          "            if (j >= i) { color.r += j; }"
          "}",
          0, 0, 0, 0,
-         385, 0, 0, 0,
-         /*testWithSkVM=*/false);
+         385, 0, 0, 0);
     test(r,
          "void main(inout half4 color) {"
          "    for (int i = 1; i <= 10; ++i)"
@@ -524,8 +518,7 @@
          "        }"
          "}",
          0, 0, 0, 0,
-         495, 0, 0, 0,
-         /*testWithSkVM=*/false);
+         495, 0, 0, 0);
 }
 
 DEF_TEST(SkSLInterpreterPrefixPostfix, r) {
@@ -678,7 +671,7 @@
         REPORTER_ASSERT(r, out == 8);
     }
 
-    // TODO: Doesn't work until SkVM generator supports loops
+    // TODO: Doesn't work until SkVM generator supports indexing-by-loop variable
     if (false) {
         float in[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
         float out = 0;
@@ -699,7 +692,7 @@
         REPORTER_ASSERT(r, out == gRects[2]);
     }
 
-    // TODO: Doesn't work until SkVM generator supports loops
+    // TODO: Doesn't work until SkVM generator supports indexing-by-loop variable
     if (false) {
         ManyRects in;
         memset(&in, 0, sizeof(in));