moved SkSL Switch data into IRNode

Change-Id: I0373cccfd3acc56417f8d1545bbe7320dc2dfa05
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/327256
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index 87c74d3..98dc92c 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -482,14 +482,14 @@
         }
         case Statement::Kind::kSwitch: {
             auto& sw = s.template as<SwitchStatement>();
-            if (this->visitExpression(*sw.fValue)) {
+            if (this->visitExpression(*sw.value())) {
                 return true;
             }
-            for (auto& c : sw.fCases) {
-                if (c->fValue && this->visitExpression(*c->fValue)) {
+            for (const auto& c : sw.cases()) {
+                if (c.value() && this->visitExpression(*c.value())) {
                     return true;
                 }
-                for (auto& st : c->fStatements) {
+                for (auto& st : c.statements()) {
                     if (this->visitStatement(*st)) {
                         return true;
                     }
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 0c2cd6b..4b67514 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -612,26 +612,26 @@
         }
         case Statement::Kind::kSwitch: {
             SwitchStatement& ss = (*s)->as<SwitchStatement>();
-            this->addExpression(cfg, &ss.fValue, /*constantPropagate=*/true);
+            this->addExpression(cfg, &ss.value(), /*constantPropagate=*/true);
             cfg.currentBlock().fNodes.push_back(BasicBlock::MakeStatement(s));
             BlockId start = cfg.fCurrent;
             BlockId switchExit = cfg.newIsolatedBlock();
             fLoopExits.push(switchExit);
-            for (const auto& c : ss.fCases) {
+            for (auto& c : ss.cases()) {
                 cfg.newBlock();
                 cfg.addExit(start, cfg.fCurrent);
-                if (c->fValue) {
+                if (c.value()) {
                     // technically this should go in the start block, but it doesn't actually matter
                     // because it must be constant. Not worth running two loops for.
-                    this->addExpression(cfg, &c->fValue, /*constantPropagate=*/true);
+                    this->addExpression(cfg, &c.value(), /*constantPropagate=*/true);
                 }
-                for (auto& caseStatement : c->fStatements) {
+                for (auto& caseStatement : c.statements()) {
                     this->addStatement(cfg, &caseStatement);
                 }
             }
             cfg.addExit(cfg.fCurrent, switchExit);
             // note that unlike GLSL, our grammar requires the default case to be last
-            if (ss.fCases.empty() || ss.fCases.back()->fValue) {
+            if (ss.cases().empty() || ss.cases().back().value()) {
                 // switch does not have a default clause, mark that it can skip straight to the end
                 cfg.addExit(start, switchExit);
             }
diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp
index 8d26ede..f3cc39e 100644
--- a/src/sksl/SkSLCPPCodeGenerator.cpp
+++ b/src/sksl/SkSLCPPCodeGenerator.cpp
@@ -380,7 +380,7 @@
 }
 
 void CPPCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
-    if (s.fIsStatic) {
+    if (s.isStatic()) {
         this->write("@");
     }
     INHERITED::writeSwitchStatement(s);
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index de55721..13c4575 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -1205,9 +1205,9 @@
     // We have to be careful to not move any of the pointers until after we're sure we're going to
     // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
     // of action. First, find the switch-case we are interested in.
-    auto iter = switchStatement->fCases.begin();
-    for (; iter != switchStatement->fCases.end(); ++iter) {
-        if (iter->get() == caseToCapture) {
+    auto iter = switchStatement->cases().begin();
+    for (; iter != switchStatement->cases().end(); ++iter) {
+        if (&*iter == caseToCapture) {
             break;
         }
     }
@@ -1217,8 +1217,8 @@
     // statements that we can use for simplification.
     auto startIter = iter;
     Statement* unconditionalBreakStmt = nullptr;
-    for (; iter != switchStatement->fCases.end(); ++iter) {
-        for (std::unique_ptr<Statement>& stmt : (*iter)->fStatements) {
+    for (; iter != switchStatement->cases().end(); ++iter) {
+        for (std::unique_ptr<Statement>& stmt : iter->statements()) {
             if (contains_conditional_break(*stmt)) {
                 // We can't reduce switch-cases to a block when they have conditional breaks.
                 return nullptr;
@@ -1243,7 +1243,7 @@
 
     // We can move over most of the statements as-is.
     while (startIter != iter) {
-        for (std::unique_ptr<Statement>& stmt : (*startIter)->fStatements) {
+        for (std::unique_ptr<Statement>& stmt : startIter->statements()) {
             caseStmts.push_back(std::move(stmt));
         }
         ++startIter;
@@ -1252,7 +1252,7 @@
     // If we found an unconditional break at the end, we need to move what we can while avoiding
     // that break.
     if (unconditionalBreakStmt != nullptr) {
-        for (std::unique_ptr<Statement>& stmt : (*startIter)->fStatements) {
+        for (std::unique_ptr<Statement>& stmt : startIter->statements()) {
             if (stmt.get() == unconditionalBreakStmt) {
                 move_all_but_break(stmt, &caseStmts);
                 unconditionalBreakStmt = nullptr;
@@ -1266,7 +1266,7 @@
     SkASSERT(unconditionalBreakStmt == nullptr);  // Verify that we fixed the unconditional break.
 
     // Return our newly-synthesized block.
-    return std::make_unique<Block>(/*offset=*/-1, std::move(caseStmts), switchStatement->fSymbols);
+    return std::make_unique<Block>(/*offset=*/-1, std::move(caseStmts), switchStatement->symbols());
 }
 
 void Compiler::simplifyStatement(DefinitionMap& definitions,
@@ -1334,28 +1334,30 @@
         case Statement::Kind::kSwitch: {
             SwitchStatement& s = stmt->as<SwitchStatement>();
             int64_t switchValue;
-            if (fIRGenerator->getConstantInt(*s.fValue, &switchValue)) {
+            if (fIRGenerator->getConstantInt(*s.value(), &switchValue)) {
                 // switch is constant, replace it with the case that matches
                 bool found = false;
                 SwitchCase* defaultCase = nullptr;
-                for (const std::unique_ptr<SwitchCase>& c : s.fCases) {
-                    if (!c->fValue) {
-                        defaultCase = c.get();
+                for (SwitchCase& c : s.cases()) {
+                    if (!c.value()) {
+                        defaultCase = &c;
                         continue;
                     }
                     int64_t caseValue;
-                    SkAssertResult(fIRGenerator->getConstantInt(*c->fValue, &caseValue));
+                    SkAssertResult(fIRGenerator->getConstantInt(*c.value(), &caseValue));
                     if (caseValue == switchValue) {
-                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
+                        std::unique_ptr<Statement> newBlock = block_for_case(&s, &c);
                         if (newBlock) {
                             (*iter)->setStatement(std::move(newBlock));
                             found = true;
                             break;
                         } else {
-                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
+                            if (s.isStatic() && !(fFlags & kPermitInvalidStaticTests_Flag) &&
+                                optimizationContext->fSilences.find(&s) ==
+                                optimizationContext->fSilences.end()) {
                                 this->error(s.fOffset,
                                             "static switch contains non-static conditional break");
-                                s.fIsStatic = false;
+                                optimizationContext->fSilences.insert(&s);
                             }
                             return; // can't simplify
                         }
@@ -1368,10 +1370,12 @@
                         if (newBlock) {
                             (*iter)->setStatement(std::move(newBlock));
                         } else {
-                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
+                            if (s.isStatic() && !(fFlags & kPermitInvalidStaticTests_Flag) &&
+                                optimizationContext->fSilences.find(&s) ==
+                                optimizationContext->fSilences.end()) {
                                 this->error(s.fOffset,
                                             "static switch contains non-static conditional break");
-                                s.fIsStatic = false;
+                                optimizationContext->fSilences.insert(&s);
                             }
                             return; // can't simplify
                         }
@@ -1498,8 +1502,10 @@
                         ++iter;
                         break;
                     case Statement::Kind::kSwitch:
-                        if (s.as<SwitchStatement>().fIsStatic &&
-                            !(fFlags & kPermitInvalidStaticTests_Flag)) {
+                        if (s.as<SwitchStatement>().isStatic() &&
+                            !(fFlags & kPermitInvalidStaticTests_Flag) &&
+                            optimizationContext.fSilences.find(&s) ==
+                            optimizationContext.fSilences.end()) {
                             this->error(s.fOffset, "static switch has non-static test");
                         }
                         ++iter;
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 1a6c4a9..663b57b 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -466,14 +466,14 @@
             case Statement::Kind::kSwitch: {
                 const SwitchStatement& ss = s->as<SwitchStatement>();
                 this->writeCommand(Rehydrator::kSwitch_Command);
-                this->writeU8(ss.fIsStatic);
-                AutoDehydratorSymbolTable symbols(this, ss.fSymbols);
-                this->write(ss.fValue.get());
-                this->writeU8(ss.fCases.size());
-                for (const std::unique_ptr<SwitchCase>& sc : ss.fCases) {
-                    this->write(sc->fValue.get());
-                    this->writeU8(sc->fStatements.size());
-                    for (const std::unique_ptr<Statement>& stmt : sc->fStatements) {
+                this->writeU8(ss.isStatic());
+                AutoDehydratorSymbolTable symbols(this, ss.symbols());
+                this->write(ss.value().get());
+                this->writeU8(ss.cases().count());
+                for (const SwitchCase& sc : ss.cases()) {
+                    this->write(sc.value().get());
+                    this->writeU8(sc.statements().size());
+                    for (const std::unique_ptr<Statement>& stmt : sc.statements()) {
                         this->write(stmt.get());
                     }
                 }
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 11e1357..f6d4d61 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -1445,19 +1445,19 @@
 
 void GLSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
     this->write("switch (");
-    this->writeExpression(*s.fValue, kTopLevel_Precedence);
+    this->writeExpression(*s.value(), kTopLevel_Precedence);
     this->writeLine(") {");
     fIndentation++;
-    for (const auto& c : s.fCases) {
-        if (c->fValue) {
+    for (const SwitchCase& c : s.cases()) {
+        if (c.value()) {
             this->write("case ");
-            this->writeExpression(*c->fValue, kTopLevel_Precedence);
+            this->writeExpression(*c.value(), kTopLevel_Precedence);
             this->writeLine(":");
         } else {
             this->writeLine("default:");
         }
         fIndentation++;
-        for (const auto& stmt : c->fStatements) {
+        for (const auto& stmt : c.statements()) {
             this->writeStatement(*stmt);
             this->writeLine();
         }
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 1ff4109..f29ab93 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -535,14 +535,14 @@
         case Statement::Kind::kSwitch: {
             const SwitchStatement& ss = statement.as<SwitchStatement>();
             std::vector<std::unique_ptr<SwitchCase>> cases;
-            cases.reserve(ss.fCases.size());
-            for (const auto& sc : ss.fCases) {
-                cases.push_back(std::make_unique<SwitchCase>(offset, expr(sc->fValue),
-                                                             stmts(sc->fStatements)));
+            cases.reserve(ss.cases().count());
+            for (const SwitchCase& sc : ss.cases()) {
+                cases.push_back(std::make_unique<SwitchCase>(offset, expr(sc.value()),
+                                                             stmts(sc.statements())));
             }
-            return std::make_unique<SwitchStatement>(offset, ss.fIsStatic, expr(ss.fValue),
+            return std::make_unique<SwitchStatement>(offset, ss.isStatic(), expr(ss.value()),
                                                      std::move(cases),
-                                                     SymbolTable::WrapIfBuiltin(ss.fSymbols));
+                                                     SymbolTable::WrapIfBuiltin(ss.symbols()));
         }
         case Statement::Kind::kVarDeclaration: {
             const VarDeclaration& decl = statement.as<VarDeclaration>();
@@ -935,14 +935,14 @@
             }
             case Statement::Kind::kSwitch: {
                 SwitchStatement& switchStmt = (*stmt)->as<SwitchStatement>();
-                if (switchStmt.fSymbols) {
-                    fSymbolTableStack.push_back(switchStmt.fSymbols.get());
+                if (switchStmt.symbols()) {
+                    fSymbolTableStack.push_back(switchStmt.symbols().get());
                 }
 
-                this->visitExpression(&switchStmt.fValue);
-                for (std::unique_ptr<SwitchCase>& switchCase : switchStmt.fCases) {
+                this->visitExpression(&switchStmt.value());
+                for (SwitchCase& switchCase : switchStmt.cases()) {
                     // The switch-case's fValue cannot be a FunctionCall; skip it.
-                    for (std::unique_ptr<Statement>& caseBlock : switchCase->fStatements) {
+                    for (std::unique_ptr<Statement>& caseBlock : switchCase.statements()) {
                         this->visitStatement(&caseBlock);
                     }
                 }
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 22f111c..c8ef96a 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -1365,19 +1365,19 @@
 
 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
     this->write("switch (");
-    this->writeExpression(*s.fValue, kTopLevel_Precedence);
+    this->writeExpression(*s.value(), kTopLevel_Precedence);
     this->writeLine(") {");
     fIndentation++;
-    for (const auto& c : s.fCases) {
-        if (c->fValue) {
+    for (const SwitchCase& c : s.cases()) {
+        if (c.value()) {
             this->write("case ");
-            this->writeExpression(*c->fValue, kTopLevel_Precedence);
+            this->writeExpression(*c.value(), kTopLevel_Precedence);
             this->writeLine(":");
         } else {
             this->writeLine("default:");
         }
         fIndentation++;
-        for (const auto& stmt : c->fStatements) {
+        for (const auto& stmt : c.statements()) {
             this->writeStatement(*stmt);
             this->writeLine();
         }
@@ -1801,9 +1801,9 @@
         }
         case Statement::Kind::kSwitch: {
             const SwitchStatement& sw = s->as<SwitchStatement>();
-            Requirements result = this->requirements(sw.fValue.get());
-            for (const auto& c : sw.fCases) {
-                for (const auto& st : c->fStatements) {
+            Requirements result = this->requirements(sw.value().get());
+            for (const SwitchCase& sc : sw.cases()) {
+                for (const auto& st : sc.statements()) {
                     result |= this->requirements(st.get());
                 }
             }
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 3c5d6e0..171af16 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -164,7 +164,7 @@
 }
 
 void PipelineStageCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
-    if (s.fIsStatic) {
+    if (s.isStatic()) {
         this->write("@");
     }
     INHERITED::writeSwitchStatement(s);
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index ba43ff9..0062431 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -3048,16 +3048,17 @@
 }
 
 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
-    SpvId value = this->writeExpression(*s.fValue, out);
+    SpvId value = this->writeExpression(*s.value(), out);
     std::vector<SpvId> labels;
     SpvId end = this->nextId();
     SpvId defaultLabel = end;
     fBreakTarget.push(end);
     int size = 3;
-    for (const auto& c : s.fCases) {
+    auto cases = s.cases();
+    for (const SwitchCase& c : cases) {
         SpvId label = this->nextId();
         labels.push_back(label);
-        if (c->fValue) {
+        if (c.value()) {
             size += 2;
         } else {
             defaultLabel = label;
@@ -3068,16 +3069,16 @@
     this->writeOpCode(SpvOpSwitch, size, out);
     this->writeWord(value, out);
     this->writeWord(defaultLabel, out);
-    for (size_t i = 0; i < s.fCases.size(); ++i) {
-        if (!s.fCases[i]->fValue) {
+    for (int i = 0; i < cases.count(); ++i) {
+        if (!cases[i].value()) {
             continue;
         }
-        this->writeWord(s.fCases[i]->fValue->as<IntLiteral>().value(), out);
+        this->writeWord(cases[i].value()->as<IntLiteral>().value(), out);
         this->writeWord(labels[i], out);
     }
-    for (size_t i = 0; i < s.fCases.size(); ++i) {
+    for (int i = 0; i < cases.count(); ++i) {
         this->writeLabel(labels[i], out);
-        for (const auto& stmt : s.fCases[i]->fStatements) {
+        for (const auto& stmt : cases[i].statements()) {
             this->writeStatement(*stmt, out);
         }
         if (fCurrentBlock) {
diff --git a/src/sksl/ir/SkSLIRNode.cpp b/src/sksl/ir/SkSLIRNode.cpp
index 34d1194..f3d81e3 100644
--- a/src/sksl/ir/SkSLIRNode.cpp
+++ b/src/sksl/ir/SkSLIRNode.cpp
@@ -112,6 +112,11 @@
 , fKind(kind)
 , fData(data) {}
 
+IRNode::IRNode(int offset, int kind, const SwitchStatementData& data)
+: fOffset(offset)
+, fKind(kind)
+, fData(data) {}
+
 IRNode::IRNode(int offset, int kind, const SwizzleData& data)
 : fOffset(offset)
 , fKind(kind)
diff --git a/src/sksl/ir/SkSLIRNode.h b/src/sksl/ir/SkSLIRNode.h
index 58241af..007bdba 100644
--- a/src/sksl/ir/SkSLIRNode.h
+++ b/src/sksl/ir/SkSLIRNode.h
@@ -206,6 +206,11 @@
         const Type* fType;
     };
 
+    struct SwitchStatementData {
+        bool fIsStatic;
+        std::shared_ptr<SymbolTable> fSymbols;
+    };
+
     struct SwizzleData {
         const Type* fType;
         std::vector<int> fComponents;
@@ -284,6 +289,7 @@
             kSection,
             kSetting,
             kString,
+            kSwitchStatement,
             kSwizzle,
             kSymbol,
             kSymbolAlias,
@@ -318,6 +324,7 @@
             SectionData fSection;
             SettingData fSetting;
             String fString;
+            SwitchStatementData fSwitchStatement;
             SwizzleData fSwizzle;
             SymbolData fSymbol;
             SymbolAliasData fSymbolAlias;
@@ -434,6 +441,11 @@
             *(new(&fContents) String) = data;
         }
 
+        NodeData(const SwitchStatementData& data)
+            : fKind(Kind::kSwitchStatement) {
+            *(new(&fContents) SwitchStatementData) = data;
+        }
+
         NodeData(const SwizzleData& data)
             : fKind(Kind::kSwizzle) {
             *(new(&fContents) SwizzleData) = data;
@@ -554,6 +566,9 @@
                 case Kind::kString:
                     *(new(&fContents) String) = other.fContents.fString;
                     break;
+                case Kind::kSwitchStatement:
+                    *(new(&fContents) SwitchStatementData) = other.fContents.fSwitchStatement;
+                    break;
                 case Kind::kSwizzle:
                     *(new(&fContents) SwizzleData) = other.fContents.fSwizzle;
                     break;
@@ -655,6 +670,9 @@
                 case Kind::kString:
                     fContents.fString.~String();
                     break;
+                case Kind::kSwitchStatement:
+                    fContents.fSwitchStatement.~SwitchStatementData();
+                    break;
                 case Kind::kSwizzle:
                     fContents.fSwizzle.~SwizzleData();
                     break;
@@ -728,6 +746,8 @@
 
     IRNode(int offset, int kind, const String& data);
 
+    IRNode(int offset, int kind, const SwitchStatementData& data);
+
     IRNode(int offset, int kind, const SwizzleData& data);
 
     IRNode(int offset, int kind, const SymbolData& data);
@@ -907,6 +927,16 @@
         return fData.fContents.fString;
     }
 
+    SwitchStatementData& switchStatementData() {
+        SkASSERT(fData.fKind == NodeData::Kind::kSwitchStatement);
+        return fData.fContents.fSwitchStatement;
+    }
+
+    const SwitchStatementData& switchStatementData() const {
+        SkASSERT(fData.fKind == NodeData::Kind::kSwitchStatement);
+        return fData.fContents.fSwitchStatement;
+    }
+
     SwizzleData& swizzleData() {
         SkASSERT(fData.fKind == NodeData::Kind::kSwizzle);
         return fData.fContents.fSwizzle;
diff --git a/src/sksl/ir/SkSLNodeArrayWrapper.h b/src/sksl/ir/SkSLNodeArrayWrapper.h
index 527f789..d0e2d32 100644
--- a/src/sksl/ir/SkSLNodeArrayWrapper.h
+++ b/src/sksl/ir/SkSLNodeArrayWrapper.h
@@ -5,6 +5,9 @@
  * found in the LICENSE file.
  */
 
+#ifndef SKSL_NODEARRAYWRAPPER
+#define SKSL_NODEARRAYWRAPPER
+
 #include "include/private/SkTArray.h"
 
 namespace SkSL {
@@ -205,7 +208,7 @@
         friend class ConstNodeArrayWrapper;
     };
 
-    ConstNodeArrayWrapper(SkTArray<std::unique_ptr<Base>>* contents)
+    ConstNodeArrayWrapper(const SkTArray<std::unique_ptr<Base>>* contents)
         : fContents(contents) {}
 
     ConstNodeArrayWrapper(const ConstNodeArrayWrapper& other)
@@ -252,3 +255,5 @@
 };
 
 } // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLStatement.h b/src/sksl/ir/SkSLStatement.h
index 8ceabc2..b4d26a1 100644
--- a/src/sksl/ir/SkSLStatement.h
+++ b/src/sksl/ir/SkSLStatement.h
@@ -58,6 +58,9 @@
     Statement(int offset, const InlineMarkerData& data)
     : INHERITED(offset, (int) Kind::kInlineMarker, data) {}
 
+    Statement(int offset, const SwitchStatementData& data)
+    : INHERITED(offset, (int) Kind::kSwitch, data) {}
+
     Statement(int offset, const VarDeclarationData& data)
     : INHERITED(offset, (int) Kind::kVarDeclaration, data) {}
 
diff --git a/src/sksl/ir/SkSLSwitchCase.h b/src/sksl/ir/SkSLSwitchCase.h
index 7a1022e..4072070 100644
--- a/src/sksl/ir/SkSLSwitchCase.h
+++ b/src/sksl/ir/SkSLSwitchCase.h
@@ -16,42 +16,58 @@
 /**
  * A single case of a 'switch' statement.
  */
-struct SwitchCase : public Statement {
+class SwitchCase : public Statement {
+public:
     static constexpr Kind kStatementKind = Kind::kSwitchCase;
 
+    // null value implies "default" case
     SwitchCase(int offset, std::unique_ptr<Expression> value, StatementArray statements)
-            : INHERITED(offset, kStatementKind)
-            , fValue(std::move(value))
-            , fStatements(std::move(statements)) {}
+            : INHERITED(offset, kStatementKind) {
+        fExpressionChildren.push_back(std::move(value));
+        fStatementChildren = std::move(statements);
+    }
+
+    std::unique_ptr<Expression>& value() {
+        return fExpressionChildren[0];
+    }
+
+    const std::unique_ptr<Expression>& value() const {
+        return fExpressionChildren[0];
+    }
+
+    StatementArray& statements() {
+        return fStatementChildren;
+    }
+
+    const StatementArray& statements() const {
+        return fStatementChildren;
+    }
 
     std::unique_ptr<Statement> clone() const override {
         StatementArray cloned;
-        cloned.reserve_back(fStatements.size());
-        for (const auto& s : fStatements) {
+        cloned.reserve_back(this->statements().size());
+        for (const auto& s : this->statements()) {
             cloned.push_back(s->clone());
         }
         return std::make_unique<SwitchCase>(fOffset,
-                                            fValue ? fValue->clone() : nullptr,
+                                            this->value() ? this->value()->clone() : nullptr,
                                             std::move(cloned));
     }
 
     String description() const override {
         String result;
-        if (fValue) {
-            result.appendf("case %s:\n", fValue->description().c_str());
+        if (this->value()) {
+            result.appendf("case %s:\n", this->value()->description().c_str());
         } else {
             result += "default:\n";
         }
-        for (const auto& s : fStatements) {
+        for (const auto& s : this->statements()) {
             result += s->description() + "\n";
         }
         return result;
     }
 
-    // null value implies "default" case
-    std::unique_ptr<Expression> fValue;
-    StatementArray fStatements;
-
+private:
     using INHERITED = Statement;
 };
 
diff --git a/src/sksl/ir/SkSLSwitchStatement.h b/src/sksl/ir/SkSLSwitchStatement.h
index b2a1af7..bd3055a 100644
--- a/src/sksl/ir/SkSLSwitchStatement.h
+++ b/src/sksl/ir/SkSLSwitchStatement.h
@@ -8,6 +8,7 @@
 #ifndef SKSL_SWITCHSTATEMENT
 #define SKSL_SWITCHSTATEMENT
 
+#include "src/sksl/ir/SkSLNodeArrayWrapper.h"
 #include "src/sksl/ir/SkSLStatement.h"
 #include "src/sksl/ir/SkSLSwitchCase.h"
 
@@ -18,48 +19,76 @@
 /**
  * A 'switch' statement.
  */
-struct SwitchStatement : public Statement {
+class SwitchStatement : public Statement {
+public:
     static constexpr Kind kStatementKind = Kind::kSwitch;
 
+    using CaseArray = NodeArrayWrapper<SwitchCase, Statement>;
+
+    using ConstCaseArray = ConstNodeArrayWrapper<SwitchCase, Statement>;
+
     SwitchStatement(int offset, bool isStatic, std::unique_ptr<Expression> value,
                     std::vector<std::unique_ptr<SwitchCase>> cases,
                     const std::shared_ptr<SymbolTable> symbols)
-    : INHERITED(offset, kStatementKind)
-    , fIsStatic(isStatic)
-    , fValue(std::move(value))
-    , fSymbols(std::move(symbols))
-    , fCases(std::move(cases)) {}
+    : INHERITED(offset, SwitchStatementData{isStatic, std::move(symbols)}) {
+        fExpressionChildren.push_back(std::move(value));
+        fStatementChildren.reserve_back(cases.size());
+        for (std::unique_ptr<SwitchCase>& c : cases) {
+            fStatementChildren.push_back(std::move(c));
+        }
+    }
+
+    std::unique_ptr<Expression>& value() {
+        return fExpressionChildren[0];
+    }
+
+    const std::unique_ptr<Expression>& value() const {
+        return fExpressionChildren[0];
+    }
+
+    CaseArray cases() {
+        return CaseArray(&fStatementChildren);
+    }
+
+    ConstCaseArray cases() const {
+        return ConstCaseArray(&fStatementChildren);
+    }
+
+    bool isStatic() const {
+        return this->switchStatementData().fIsStatic;
+    }
+
+    const std::shared_ptr<SymbolTable>& symbols() const {
+        return this->switchStatementData().fSymbols;
+    }
 
     std::unique_ptr<Statement> clone() const override {
         std::vector<std::unique_ptr<SwitchCase>> cloned;
-        for (const auto& s : fCases) {
-            cloned.push_back(std::unique_ptr<SwitchCase>((SwitchCase*) s->clone().release()));
+        for (const std::unique_ptr<Statement>& s : fStatementChildren) {
+            cloned.emplace_back((SwitchCase*) s->as<SwitchCase>().clone().release());
         }
-        return std::make_unique<SwitchStatement>(fOffset, fIsStatic, fValue->clone(),
-                                                 std::move(cloned),
-                                                 SymbolTable::WrapIfBuiltin(fSymbols));
+        return std::unique_ptr<Statement>(new SwitchStatement(
+                                                      fOffset,
+                                                      this->isStatic(),
+                                                      this->value()->clone(),
+                                                      std::move(cloned),
+                                                      SymbolTable::WrapIfBuiltin(this->symbols())));
     }
 
     String description() const override {
         String result;
-        if (fIsStatic) {
+        if (this->isStatic()) {
             result += "@";
         }
-        result += String::printf("switch (%s) {\n", fValue->description().c_str());
-        for (const auto& c : fCases) {
-            result += c->description();
+        result += String::printf("switch (%s) {\n", this->value()->description().c_str());
+        for (const auto& c : this->cases()) {
+            result += c.description();
         }
         result += "}";
         return result;
     }
 
-    bool fIsStatic;
-    std::unique_ptr<Expression> fValue;
-    // it's important to keep fCases defined after (and thus destroyed before) fSymbols, because
-    // destroying statements can modify reference counts in symbols
-    const std::shared_ptr<SymbolTable> fSymbols;
-    std::vector<std::unique_ptr<SwitchCase>> fCases;
-
+private:
     using INHERITED = Statement;
 };