Fix switch optimization pass.

The optimizer now properly recognizes all types of exits from a switch
statement. Break, continue and return are all potential exits and need
to be considered when determining the exit path from the switch.

Previously, dead code elimination was hiding the effects of this bug
from us, but it meant that an optimized switch had the potential to
generate lots of worthless IR nodes which then needed to be detected and
eliminated by the CFG. In particular, this affected the enum form of
blend, causing a catastrophic amount of extra work to be done.

Change-Id: If857e38cadfc016884624ea4db25a273ad3dce5b
Bug: skia:11352
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/372958
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/gn/sksl_tests.gni b/gn/sksl_tests.gni
index 4306f3c..f195bcc 100644
--- a/gn/sksl_tests.gni
+++ b/gn/sksl_tests.gni
@@ -162,6 +162,9 @@
   "/sksl/errors/StaticIfTest.sksl",
   "/sksl/errors/StaticSwitchConditionalBreak.sksl",
   "/sksl/errors/StaticSwitchTest.sksl",
+  "/sksl/errors/StaticSwitchWithConditionalBreak.sksl",
+  "/sksl/errors/StaticSwitchWithConditionalContinue.sksl",
+  "/sksl/errors/StaticSwitchWithConditionalReturn.sksl",
   "/sksl/errors/StructNameWithoutIdentifier.sksl",
   "/sksl/errors/StructTooDeeplyNested.sksl",
   "/sksl/errors/SwitchDuplicateCase.sksl",
@@ -401,6 +404,7 @@
   "/sksl/shared/StaticSwitchWithBreakInsideBlock.sksl",
   "/sksl/shared/StaticSwitchWithConditionalBreak.sksl",
   "/sksl/shared/StaticSwitchWithConditionalBreakInsideBlock.sksl",
+  "/sksl/shared/StaticSwitchWithContinue.sksl",
   "/sksl/shared/StaticSwitchWithFallthroughA.sksl",
   "/sksl/shared/StaticSwitchWithFallthroughB.sksl",
   "/sksl/shared/StaticSwitchWithStaticConditionalBreak.sksl",
diff --git a/resources/sksl/errors/StaticSwitchWithConditionalBreak.sksl b/resources/sksl/errors/StaticSwitchWithConditionalBreak.sksl
new file mode 100644
index 0000000..b3e12135
--- /dev/null
+++ b/resources/sksl/errors/StaticSwitchWithConditionalBreak.sksl
@@ -0,0 +1,3 @@
+uniform half4 testInputs;
+
+void test_break()    { @switch (1) { case 1: if (testInputs.x > 2) break; } }
diff --git a/resources/sksl/errors/StaticSwitchWithConditionalContinue.sksl b/resources/sksl/errors/StaticSwitchWithConditionalContinue.sksl
new file mode 100644
index 0000000..aabefa2
--- /dev/null
+++ b/resources/sksl/errors/StaticSwitchWithConditionalContinue.sksl
@@ -0,0 +1,3 @@
+uniform half4 testInputs;
+
+void test_continue() { for (;;) { @switch (1) { case 1: if (testInputs.x > 3) continue; } } }
diff --git a/resources/sksl/errors/StaticSwitchWithConditionalReturn.sksl b/resources/sksl/errors/StaticSwitchWithConditionalReturn.sksl
new file mode 100644
index 0000000..77c1d07
--- /dev/null
+++ b/resources/sksl/errors/StaticSwitchWithConditionalReturn.sksl
@@ -0,0 +1,3 @@
+uniform half4 testInputs;
+
+void test_return()   { @switch (1) { case 1: if (testInputs.x > 1) return; } }
diff --git a/resources/sksl/shared/StaticSwitchWithContinue.sksl b/resources/sksl/shared/StaticSwitchWithContinue.sksl
new file mode 100644
index 0000000..77aeee4
--- /dev/null
+++ b/resources/sksl/shared/StaticSwitchWithContinue.sksl
@@ -0,0 +1,24 @@
+/*#pragma settings NoDeadCodeElimination*/
+
+// A continue inside a switch (where permitted) prevents fallthrough to the next case block, just
+// like a break statement would.
+
+// Make sure that we properly dead-strip code following `continue` in a switch.
+// This is particularly relevant because our inliner replaces return statements with continue.
+
+uniform half4 colorGreen, colorRed;
+
+half4 main() {
+    // A looping construct is required for continue.
+    float result = 0;
+    for (int x=0; x<=1; x++) {
+        @switch (2) {
+            case 1: result = abs(1); continue;
+            case 2: result = abs(2); continue;
+            case 3: result = abs(3); continue;
+            case 4: result = abs(4); continue;
+        }
+    }
+
+    return result == 2 ? colorGreen : colorRed;
+}
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 1b2a469..68226e2 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -1172,18 +1172,29 @@
     }
 }
 
-// Returns true if this statement could potentially execute a break at the current level. We ignore
-// nested loops and switches, since any breaks inside of them will merely break the loop / switch.
-static bool contains_conditional_break(Statement& stmt) {
-    class ContainsConditionalBreak : public ProgramVisitor {
+static bool contains_exit(Statement& stmt, bool conditionalExits) {
+    class ContainsExit : public ProgramVisitor {
     public:
+        ContainsExit(bool e) : fConditionalExits(e) {}
+
         bool visitStatement(const Statement& stmt) override {
             switch (stmt.kind()) {
                 case Statement::Kind::kBlock:
                     return INHERITED::visitStatement(stmt);
 
+                case Statement::Kind::kReturn:
+                    // Returns are an early exit regardless of the surrounding control structures.
+                    return fConditionalExits ? fInConditional : !fInConditional;
+
+                case Statement::Kind::kContinue:
+                    // Continues are an early exit from switches, but not loops.
+                    return !fInLoop &&
+                           (fConditionalExits ? fInConditional : !fInConditional);
+
                 case Statement::Kind::kBreak:
-                    return fInConditional > 0;
+                    // Breaks cannot escape from switches or loops.
+                    return !fInLoop && !fInSwitch &&
+                           (fConditionalExits ? fInConditional : !fInConditional);
 
                 case Statement::Kind::kIf: {
                     ++fInConditional;
@@ -1192,40 +1203,51 @@
                     return result;
                 }
 
+                case Statement::Kind::kFor:
+                case Statement::Kind::kDo: {
+                    // Loops are treated as conditionals because a loop could potentially execute
+                    // zero times. We don't have a straightforward way to determine that a loop
+                    // definitely executes at least once.
+                    ++fInConditional;
+                    ++fInLoop;
+                    bool result = INHERITED::visitStatement(stmt);
+                    --fInLoop;
+                    --fInConditional;
+                    return result;
+                }
+
+                case Statement::Kind::kSwitch: {
+                    ++fInSwitch;
+                    bool result = INHERITED::visitStatement(stmt);
+                    --fInSwitch;
+                    return result;
+                }
+
                 default:
                     return false;
             }
         }
 
+        bool fConditionalExits = false;
         int fInConditional = 0;
+        int fInLoop = 0;
+        int fInSwitch = 0;
         using INHERITED = ProgramVisitor;
     };
 
-    return ContainsConditionalBreak{}.visitStatement(stmt);
+    return ContainsExit{conditionalExits}.visitStatement(stmt);
 }
 
-// returns true if this statement definitely executes a break at the current level (we ignore
-// nested loops and switches, since any breaks inside of them will merely break the loop / switch)
-static bool contains_unconditional_break(Statement& stmt) {
-    class ContainsUnconditionalBreak : public ProgramVisitor {
-    public:
-        bool visitStatement(const Statement& stmt) override {
-            switch (stmt.kind()) {
-                case Statement::Kind::kBlock:
-                    return INHERITED::visitStatement(stmt);
+// Finds unconditional exits from a switch-case. Returns true if this statement unconditionally
+// causes an exit from this switch (via continue, break or return).
+static bool contains_unconditional_exit(Statement& stmt) {
+    return contains_exit(stmt, /*conditionalExits=*/false);
+}
 
-                case Statement::Kind::kBreak:
-                    return true;
-
-                default:
-                    return false;
-            }
-        }
-
-        using INHERITED = ProgramVisitor;
-    };
-
-    return ContainsUnconditionalBreak{}.visitStatement(stmt);
+// Finds conditional exits from a switch-case. Returns true if this statement contains a conditional
+// that wraps a potential exit from the switch (via continue, break or return).
+static bool contains_conditional_exit(Statement& stmt) {
+    return contains_exit(stmt, /*conditionalExits=*/true);
 }
 
 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
@@ -1277,23 +1299,22 @@
     // stuck and can't simplify at all. If we find an unconditional break, we have a range of
     // statements that we can use for simplification.
     auto startIter = iter;
-    Statement* unconditionalBreakStmt = nullptr;
+    Statement* stripBreakStmt = nullptr;
     for (; iter != switchStatement->cases().end(); ++iter) {
         for (std::unique_ptr<Statement>& stmt : (*iter)->statements()) {
-            if (contains_conditional_break(*stmt)) {
-                // We can't reduce switch-cases to a block when they have conditional breaks.
+            if (contains_conditional_exit(*stmt)) {
+                // We can't reduce switch-cases to a block when they have conditional exits.
                 return nullptr;
             }
-
-            if (contains_unconditional_break(*stmt)) {
-                // We found an unconditional break. We can use this block, but we need to strip
-                // out the break statement.
-                unconditionalBreakStmt = stmt.get();
+            if (contains_unconditional_exit(*stmt)) {
+                // We found an unconditional exit. We can use this block, but we need to strip
+                // out a break statement if it has one.
+                stripBreakStmt = stmt.get();
                 break;
             }
         }
 
-        if (unconditionalBreakStmt != nullptr) {
+        if (stripBreakStmt) {
             break;
         }
     }
@@ -1310,13 +1331,12 @@
         ++startIter;
     }
 
-    // If we found an unconditional break at the end, we need to move what we can while avoiding
-    // that break.
-    if (unconditionalBreakStmt != nullptr) {
+    // For the last statement, we need to move what we can, stopping at a break if there is one.
+    if (stripBreakStmt != nullptr) {
         for (std::unique_ptr<Statement>& stmt : (*startIter)->statements()) {
-            if (stmt.get() == unconditionalBreakStmt) {
+            if (stmt.get() == stripBreakStmt) {
                 move_all_but_break(stmt, &caseStmts);
-                unconditionalBreakStmt = nullptr;
+                stripBreakStmt = nullptr;
                 break;
             }
 
@@ -1324,7 +1344,7 @@
         }
     }
 
-    SkASSERT(unconditionalBreakStmt == nullptr);  // Verify that we fixed the unconditional break.
+    SkASSERT(stripBreakStmt == nullptr);  // Verify that we fixed the unconditional break.
 
     // Return our newly-synthesized block.
     return std::make_unique<Block>(/*offset=*/-1, std::move(caseStmts), switchStatement->symbols());
@@ -1424,7 +1444,7 @@
                                 auto [iter, didInsert] = optimizationContext->fSilences.insert(&s);
                                 if (didInsert) {
                                     this->error(s.fOffset, "static switch contains non-static "
-                                                           "conditional break");
+                                                           "conditional exit");
                                 }
                             }
                             return; // can't simplify
@@ -1443,7 +1463,7 @@
                                 auto [iter, didInsert] = optimizationContext->fSilences.insert(&s);
                                 if (didInsert) {
                                     this->error(s.fOffset, "static switch contains non-static "
-                                                           "conditional break");
+                                                           "conditional exit");
                                 }
                             }
                             return; // can't simplify
diff --git a/tests/sksl/errors/StaticSwitchConditionalBreak.glsl b/tests/sksl/errors/StaticSwitchConditionalBreak.glsl
index cb7d670..aca3d8a 100644
--- a/tests/sksl/errors/StaticSwitchConditionalBreak.glsl
+++ b/tests/sksl/errors/StaticSwitchConditionalBreak.glsl
@@ -1,4 +1,4 @@
 ### Compilation failed:
 
-error: 3: static switch contains non-static conditional break
+error: 3: static switch contains non-static conditional exit
 1 error
diff --git a/tests/sksl/errors/StaticSwitchWithConditionalBreak.glsl b/tests/sksl/errors/StaticSwitchWithConditionalBreak.glsl
new file mode 100644
index 0000000..aca3d8a
--- /dev/null
+++ b/tests/sksl/errors/StaticSwitchWithConditionalBreak.glsl
@@ -0,0 +1,4 @@
+### Compilation failed:
+
+error: 3: static switch contains non-static conditional exit
+1 error
diff --git a/tests/sksl/errors/StaticSwitchWithConditionalContinue.glsl b/tests/sksl/errors/StaticSwitchWithConditionalContinue.glsl
new file mode 100644
index 0000000..aca3d8a
--- /dev/null
+++ b/tests/sksl/errors/StaticSwitchWithConditionalContinue.glsl
@@ -0,0 +1,4 @@
+### Compilation failed:
+
+error: 3: static switch contains non-static conditional exit
+1 error
diff --git a/tests/sksl/errors/StaticSwitchWithConditionalReturn.glsl b/tests/sksl/errors/StaticSwitchWithConditionalReturn.glsl
new file mode 100644
index 0000000..aca3d8a
--- /dev/null
+++ b/tests/sksl/errors/StaticSwitchWithConditionalReturn.glsl
@@ -0,0 +1,4 @@
+### Compilation failed:
+
+error: 3: static switch contains non-static conditional exit
+1 error
diff --git a/tests/sksl/inliner/EnumsCanBeInlinedSafely.glsl b/tests/sksl/inliner/EnumsCanBeInlinedSafely.glsl
index cc423dc..ff6091a 100644
--- a/tests/sksl/inliner/EnumsCanBeInlinedSafely.glsl
+++ b/tests/sksl/inliner/EnumsCanBeInlinedSafely.glsl
@@ -2,13 +2,6 @@
 out vec4 sk_FragColor;
 vec4 helper();
 void main() {
-    for (int _1_loop = 0;_1_loop < 1; _1_loop++) {
-        {
-            {
-                continue;
-            }
-        }
-    }
     sk_FragColor = vec4(0.5, 0.5, 0.5, 1.0);
 
 }
diff --git a/tests/sksl/inliner/StaticSwitch.glsl b/tests/sksl/inliner/StaticSwitch.glsl
index 2775fc0..e0000e7 100644
--- a/tests/sksl/inliner/StaticSwitch.glsl
+++ b/tests/sksl/inliner/StaticSwitch.glsl
@@ -10,14 +10,6 @@
                 _0_get = abs(2.0);
                 continue;
             }
-            {
-                _0_get = abs(3.0);
-                continue;
-            }
-            {
-                _0_get = abs(4.0);
-                continue;
-            }
         }
         {
             _0_get = abs(5.0);
diff --git a/tests/sksl/shared/StaticSwitchWithContinue.asm.frag b/tests/sksl/shared/StaticSwitchWithContinue.asm.frag
new file mode 100644
index 0000000..5d6aa07
--- /dev/null
+++ b/tests/sksl/shared/StaticSwitchWithContinue.asm.frag
@@ -0,0 +1,100 @@
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %_entrypoint "_entrypoint" %sk_FragColor %sk_Clockwise
+OpExecutionMode %_entrypoint OriginUpperLeft
+OpName %sk_FragColor "sk_FragColor"
+OpName %sk_Clockwise "sk_Clockwise"
+OpName %_UniformBuffer "_UniformBuffer"
+OpMemberName %_UniformBuffer 0 "colorGreen"
+OpMemberName %_UniformBuffer 1 "colorRed"
+OpName %_entrypoint "_entrypoint"
+OpName %main "main"
+OpName %result "result"
+OpName %x "x"
+OpDecorate %sk_FragColor RelaxedPrecision
+OpDecorate %sk_FragColor Location 0
+OpDecorate %sk_FragColor Index 0
+OpDecorate %sk_Clockwise RelaxedPrecision
+OpDecorate %sk_Clockwise BuiltIn FrontFacing
+OpMemberDecorate %_UniformBuffer 0 Offset 0
+OpMemberDecorate %_UniformBuffer 0 RelaxedPrecision
+OpMemberDecorate %_UniformBuffer 1 Offset 16
+OpMemberDecorate %_UniformBuffer 1 RelaxedPrecision
+OpDecorate %_UniformBuffer Block
+OpDecorate %10 Binding 0
+OpDecorate %10 DescriptorSet 0
+OpDecorate %48 RelaxedPrecision
+OpDecorate %50 RelaxedPrecision
+OpDecorate %51 RelaxedPrecision
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%sk_FragColor = OpVariable %_ptr_Output_v4float Output
+%bool = OpTypeBool
+%_ptr_Input_bool = OpTypePointer Input %bool
+%sk_Clockwise = OpVariable %_ptr_Input_bool Input
+%_UniformBuffer = OpTypeStruct %v4float %v4float
+%_ptr_Uniform__UniformBuffer = OpTypePointer Uniform %_UniformBuffer
+%10 = OpVariable %_ptr_Uniform__UniformBuffer Uniform
+%void = OpTypeVoid
+%15 = OpTypeFunction %void
+%18 = OpTypeFunction %v4float
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%int_1 = OpConstant %int 1
+%float_2 = OpConstant %float 2
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Uniform_v4float = OpTypePointer Uniform %v4float
+%_entrypoint = OpFunction %void None %15
+%16 = OpLabel
+%17 = OpFunctionCall %v4float %main
+OpStore %sk_FragColor %17
+OpReturn
+OpFunctionEnd
+%main = OpFunction %v4float None %18
+%19 = OpLabel
+%result = OpVariable %_ptr_Function_float Function
+%x = OpVariable %_ptr_Function_int Function
+%41 = OpVariable %_ptr_Function_v4float Function
+OpStore %result %float_0
+OpStore %x %int_0
+OpBranch %27
+%27 = OpLabel
+OpLoopMerge %31 %30 None
+OpBranch %28
+%28 = OpLabel
+%32 = OpLoad %int %x
+%34 = OpSLessThanEqual %bool %32 %int_1
+OpBranchConditional %34 %29 %31
+%29 = OpLabel
+%35 = OpExtInst %float %1 FAbs %float_2
+OpStore %result %35
+OpBranch %30
+%30 = OpLabel
+%37 = OpLoad %int %x
+%38 = OpIAdd %int %37 %int_1
+OpStore %x %38
+OpBranch %27
+%31 = OpLabel
+%39 = OpLoad %float %result
+%40 = OpFOrdEqual %bool %39 %float_2
+OpSelectionMerge %45 None
+OpBranchConditional %40 %43 %44
+%43 = OpLabel
+%46 = OpAccessChain %_ptr_Uniform_v4float %10 %int_0
+%48 = OpLoad %v4float %46
+OpStore %41 %48
+OpBranch %45
+%44 = OpLabel
+%49 = OpAccessChain %_ptr_Uniform_v4float %10 %int_1
+%50 = OpLoad %v4float %49
+OpStore %41 %50
+OpBranch %45
+%45 = OpLabel
+%51 = OpLoad %v4float %41
+OpReturnValue %51
+OpFunctionEnd
diff --git a/tests/sksl/shared/StaticSwitchWithContinue.glsl b/tests/sksl/shared/StaticSwitchWithContinue.glsl
new file mode 100644
index 0000000..be465db
--- /dev/null
+++ b/tests/sksl/shared/StaticSwitchWithContinue.glsl
@@ -0,0 +1,14 @@
+
+out vec4 sk_FragColor;
+uniform vec4 colorGreen;
+uniform vec4 colorRed;
+vec4 main() {
+    float result = 0.0;
+    for (int x = 0;x <= 1; x++) {
+        {
+            result = abs(2.0);
+            continue;
+        }
+    }
+    return result == 2.0 ? colorGreen : colorRed;
+}
diff --git a/tests/sksl/shared/StaticSwitchWithContinue.metal b/tests/sksl/shared/StaticSwitchWithContinue.metal
new file mode 100644
index 0000000..fe8388d
--- /dev/null
+++ b/tests/sksl/shared/StaticSwitchWithContinue.metal
@@ -0,0 +1,27 @@
+#include <metal_stdlib>
+#include <simd/simd.h>
+using namespace metal;
+struct Uniforms {
+    float4 colorGreen;
+    float4 colorRed;
+};
+struct Inputs {
+};
+struct Outputs {
+    float4 sk_FragColor [[color(0)]];
+};
+
+
+fragment Outputs fragmentMain(Inputs _in [[stage_in]], constant Uniforms& _uniforms [[buffer(0)]], bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]) {
+    Outputs _out;
+    (void)_out;
+    float result = 0.0;
+    for (int x = 0;x <= 1; x++) {
+        {
+            result = abs(2.0);
+            continue;
+        }
+    }
+    _out.sk_FragColor = result == 2.0 ? _uniforms.colorGreen : _uniforms.colorRed;
+    return _out;
+}