moved SkSL Switch data into IRNode
Change-Id: I0373cccfd3acc56417f8d1545bbe7320dc2dfa05
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/327256
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index 87c74d3..98dc92c 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -482,14 +482,14 @@
}
case Statement::Kind::kSwitch: {
auto& sw = s.template as<SwitchStatement>();
- if (this->visitExpression(*sw.fValue)) {
+ if (this->visitExpression(*sw.value())) {
return true;
}
- for (auto& c : sw.fCases) {
- if (c->fValue && this->visitExpression(*c->fValue)) {
+ for (const auto& c : sw.cases()) {
+ if (c.value() && this->visitExpression(*c.value())) {
return true;
}
- for (auto& st : c->fStatements) {
+ for (auto& st : c.statements()) {
if (this->visitStatement(*st)) {
return true;
}
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 0c2cd6b..4b67514 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -612,26 +612,26 @@
}
case Statement::Kind::kSwitch: {
SwitchStatement& ss = (*s)->as<SwitchStatement>();
- this->addExpression(cfg, &ss.fValue, /*constantPropagate=*/true);
+ this->addExpression(cfg, &ss.value(), /*constantPropagate=*/true);
cfg.currentBlock().fNodes.push_back(BasicBlock::MakeStatement(s));
BlockId start = cfg.fCurrent;
BlockId switchExit = cfg.newIsolatedBlock();
fLoopExits.push(switchExit);
- for (const auto& c : ss.fCases) {
+ for (auto& c : ss.cases()) {
cfg.newBlock();
cfg.addExit(start, cfg.fCurrent);
- if (c->fValue) {
+ if (c.value()) {
// technically this should go in the start block, but it doesn't actually matter
// because it must be constant. Not worth running two loops for.
- this->addExpression(cfg, &c->fValue, /*constantPropagate=*/true);
+ this->addExpression(cfg, &c.value(), /*constantPropagate=*/true);
}
- for (auto& caseStatement : c->fStatements) {
+ for (auto& caseStatement : c.statements()) {
this->addStatement(cfg, &caseStatement);
}
}
cfg.addExit(cfg.fCurrent, switchExit);
// note that unlike GLSL, our grammar requires the default case to be last
- if (ss.fCases.empty() || ss.fCases.back()->fValue) {
+ if (ss.cases().empty() || ss.cases().back().value()) {
// switch does not have a default clause, mark that it can skip straight to the end
cfg.addExit(start, switchExit);
}
diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp
index 8d26ede..f3cc39e 100644
--- a/src/sksl/SkSLCPPCodeGenerator.cpp
+++ b/src/sksl/SkSLCPPCodeGenerator.cpp
@@ -380,7 +380,7 @@
}
void CPPCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
- if (s.fIsStatic) {
+ if (s.isStatic()) {
this->write("@");
}
INHERITED::writeSwitchStatement(s);
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index de55721..13c4575 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -1205,9 +1205,9 @@
// We have to be careful to not move any of the pointers until after we're sure we're going to
// succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
// of action. First, find the switch-case we are interested in.
- auto iter = switchStatement->fCases.begin();
- for (; iter != switchStatement->fCases.end(); ++iter) {
- if (iter->get() == caseToCapture) {
+ auto iter = switchStatement->cases().begin();
+ for (; iter != switchStatement->cases().end(); ++iter) {
+ if (&*iter == caseToCapture) {
break;
}
}
@@ -1217,8 +1217,8 @@
// statements that we can use for simplification.
auto startIter = iter;
Statement* unconditionalBreakStmt = nullptr;
- for (; iter != switchStatement->fCases.end(); ++iter) {
- for (std::unique_ptr<Statement>& stmt : (*iter)->fStatements) {
+ for (; iter != switchStatement->cases().end(); ++iter) {
+ for (std::unique_ptr<Statement>& stmt : iter->statements()) {
if (contains_conditional_break(*stmt)) {
// We can't reduce switch-cases to a block when they have conditional breaks.
return nullptr;
@@ -1243,7 +1243,7 @@
// We can move over most of the statements as-is.
while (startIter != iter) {
- for (std::unique_ptr<Statement>& stmt : (*startIter)->fStatements) {
+ for (std::unique_ptr<Statement>& stmt : startIter->statements()) {
caseStmts.push_back(std::move(stmt));
}
++startIter;
@@ -1252,7 +1252,7 @@
// If we found an unconditional break at the end, we need to move what we can while avoiding
// that break.
if (unconditionalBreakStmt != nullptr) {
- for (std::unique_ptr<Statement>& stmt : (*startIter)->fStatements) {
+ for (std::unique_ptr<Statement>& stmt : startIter->statements()) {
if (stmt.get() == unconditionalBreakStmt) {
move_all_but_break(stmt, &caseStmts);
unconditionalBreakStmt = nullptr;
@@ -1266,7 +1266,7 @@
SkASSERT(unconditionalBreakStmt == nullptr); // Verify that we fixed the unconditional break.
// Return our newly-synthesized block.
- return std::make_unique<Block>(/*offset=*/-1, std::move(caseStmts), switchStatement->fSymbols);
+ return std::make_unique<Block>(/*offset=*/-1, std::move(caseStmts), switchStatement->symbols());
}
void Compiler::simplifyStatement(DefinitionMap& definitions,
@@ -1334,28 +1334,30 @@
case Statement::Kind::kSwitch: {
SwitchStatement& s = stmt->as<SwitchStatement>();
int64_t switchValue;
- if (fIRGenerator->getConstantInt(*s.fValue, &switchValue)) {
+ if (fIRGenerator->getConstantInt(*s.value(), &switchValue)) {
// switch is constant, replace it with the case that matches
bool found = false;
SwitchCase* defaultCase = nullptr;
- for (const std::unique_ptr<SwitchCase>& c : s.fCases) {
- if (!c->fValue) {
- defaultCase = c.get();
+ for (SwitchCase& c : s.cases()) {
+ if (!c.value()) {
+ defaultCase = &c;
continue;
}
int64_t caseValue;
- SkAssertResult(fIRGenerator->getConstantInt(*c->fValue, &caseValue));
+ SkAssertResult(fIRGenerator->getConstantInt(*c.value(), &caseValue));
if (caseValue == switchValue) {
- std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
+ std::unique_ptr<Statement> newBlock = block_for_case(&s, &c);
if (newBlock) {
(*iter)->setStatement(std::move(newBlock));
found = true;
break;
} else {
- if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
+ if (s.isStatic() && !(fFlags & kPermitInvalidStaticTests_Flag) &&
+ optimizationContext->fSilences.find(&s) ==
+ optimizationContext->fSilences.end()) {
this->error(s.fOffset,
"static switch contains non-static conditional break");
- s.fIsStatic = false;
+ optimizationContext->fSilences.insert(&s);
}
return; // can't simplify
}
@@ -1368,10 +1370,12 @@
if (newBlock) {
(*iter)->setStatement(std::move(newBlock));
} else {
- if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
+ if (s.isStatic() && !(fFlags & kPermitInvalidStaticTests_Flag) &&
+ optimizationContext->fSilences.find(&s) ==
+ optimizationContext->fSilences.end()) {
this->error(s.fOffset,
"static switch contains non-static conditional break");
- s.fIsStatic = false;
+ optimizationContext->fSilences.insert(&s);
}
return; // can't simplify
}
@@ -1498,8 +1502,10 @@
++iter;
break;
case Statement::Kind::kSwitch:
- if (s.as<SwitchStatement>().fIsStatic &&
- !(fFlags & kPermitInvalidStaticTests_Flag)) {
+ if (s.as<SwitchStatement>().isStatic() &&
+ !(fFlags & kPermitInvalidStaticTests_Flag) &&
+ optimizationContext.fSilences.find(&s) ==
+ optimizationContext.fSilences.end()) {
this->error(s.fOffset, "static switch has non-static test");
}
++iter;
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 1a6c4a9..663b57b 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -466,14 +466,14 @@
case Statement::Kind::kSwitch: {
const SwitchStatement& ss = s->as<SwitchStatement>();
this->writeCommand(Rehydrator::kSwitch_Command);
- this->writeU8(ss.fIsStatic);
- AutoDehydratorSymbolTable symbols(this, ss.fSymbols);
- this->write(ss.fValue.get());
- this->writeU8(ss.fCases.size());
- for (const std::unique_ptr<SwitchCase>& sc : ss.fCases) {
- this->write(sc->fValue.get());
- this->writeU8(sc->fStatements.size());
- for (const std::unique_ptr<Statement>& stmt : sc->fStatements) {
+ this->writeU8(ss.isStatic());
+ AutoDehydratorSymbolTable symbols(this, ss.symbols());
+ this->write(ss.value().get());
+ this->writeU8(ss.cases().count());
+ for (const SwitchCase& sc : ss.cases()) {
+ this->write(sc.value().get());
+ this->writeU8(sc.statements().size());
+ for (const std::unique_ptr<Statement>& stmt : sc.statements()) {
this->write(stmt.get());
}
}
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 11e1357..f6d4d61 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -1445,19 +1445,19 @@
void GLSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
this->write("switch (");
- this->writeExpression(*s.fValue, kTopLevel_Precedence);
+ this->writeExpression(*s.value(), kTopLevel_Precedence);
this->writeLine(") {");
fIndentation++;
- for (const auto& c : s.fCases) {
- if (c->fValue) {
+ for (const SwitchCase& c : s.cases()) {
+ if (c.value()) {
this->write("case ");
- this->writeExpression(*c->fValue, kTopLevel_Precedence);
+ this->writeExpression(*c.value(), kTopLevel_Precedence);
this->writeLine(":");
} else {
this->writeLine("default:");
}
fIndentation++;
- for (const auto& stmt : c->fStatements) {
+ for (const auto& stmt : c.statements()) {
this->writeStatement(*stmt);
this->writeLine();
}
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 1ff4109..f29ab93 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -535,14 +535,14 @@
case Statement::Kind::kSwitch: {
const SwitchStatement& ss = statement.as<SwitchStatement>();
std::vector<std::unique_ptr<SwitchCase>> cases;
- cases.reserve(ss.fCases.size());
- for (const auto& sc : ss.fCases) {
- cases.push_back(std::make_unique<SwitchCase>(offset, expr(sc->fValue),
- stmts(sc->fStatements)));
+ cases.reserve(ss.cases().count());
+ for (const SwitchCase& sc : ss.cases()) {
+ cases.push_back(std::make_unique<SwitchCase>(offset, expr(sc.value()),
+ stmts(sc.statements())));
}
- return std::make_unique<SwitchStatement>(offset, ss.fIsStatic, expr(ss.fValue),
+ return std::make_unique<SwitchStatement>(offset, ss.isStatic(), expr(ss.value()),
std::move(cases),
- SymbolTable::WrapIfBuiltin(ss.fSymbols));
+ SymbolTable::WrapIfBuiltin(ss.symbols()));
}
case Statement::Kind::kVarDeclaration: {
const VarDeclaration& decl = statement.as<VarDeclaration>();
@@ -935,14 +935,14 @@
}
case Statement::Kind::kSwitch: {
SwitchStatement& switchStmt = (*stmt)->as<SwitchStatement>();
- if (switchStmt.fSymbols) {
- fSymbolTableStack.push_back(switchStmt.fSymbols.get());
+ if (switchStmt.symbols()) {
+ fSymbolTableStack.push_back(switchStmt.symbols().get());
}
- this->visitExpression(&switchStmt.fValue);
- for (std::unique_ptr<SwitchCase>& switchCase : switchStmt.fCases) {
+ this->visitExpression(&switchStmt.value());
+ for (SwitchCase& switchCase : switchStmt.cases()) {
// The switch-case's fValue cannot be a FunctionCall; skip it.
- for (std::unique_ptr<Statement>& caseBlock : switchCase->fStatements) {
+ for (std::unique_ptr<Statement>& caseBlock : switchCase.statements()) {
this->visitStatement(&caseBlock);
}
}
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 22f111c..c8ef96a 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -1365,19 +1365,19 @@
void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
this->write("switch (");
- this->writeExpression(*s.fValue, kTopLevel_Precedence);
+ this->writeExpression(*s.value(), kTopLevel_Precedence);
this->writeLine(") {");
fIndentation++;
- for (const auto& c : s.fCases) {
- if (c->fValue) {
+ for (const SwitchCase& c : s.cases()) {
+ if (c.value()) {
this->write("case ");
- this->writeExpression(*c->fValue, kTopLevel_Precedence);
+ this->writeExpression(*c.value(), kTopLevel_Precedence);
this->writeLine(":");
} else {
this->writeLine("default:");
}
fIndentation++;
- for (const auto& stmt : c->fStatements) {
+ for (const auto& stmt : c.statements()) {
this->writeStatement(*stmt);
this->writeLine();
}
@@ -1801,9 +1801,9 @@
}
case Statement::Kind::kSwitch: {
const SwitchStatement& sw = s->as<SwitchStatement>();
- Requirements result = this->requirements(sw.fValue.get());
- for (const auto& c : sw.fCases) {
- for (const auto& st : c->fStatements) {
+ Requirements result = this->requirements(sw.value().get());
+ for (const SwitchCase& sc : sw.cases()) {
+ for (const auto& st : sc.statements()) {
result |= this->requirements(st.get());
}
}
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 3c5d6e0..171af16 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -164,7 +164,7 @@
}
void PipelineStageCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
- if (s.fIsStatic) {
+ if (s.isStatic()) {
this->write("@");
}
INHERITED::writeSwitchStatement(s);
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index ba43ff9..0062431 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -3048,16 +3048,17 @@
}
void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
- SpvId value = this->writeExpression(*s.fValue, out);
+ SpvId value = this->writeExpression(*s.value(), out);
std::vector<SpvId> labels;
SpvId end = this->nextId();
SpvId defaultLabel = end;
fBreakTarget.push(end);
int size = 3;
- for (const auto& c : s.fCases) {
+ auto cases = s.cases();
+ for (const SwitchCase& c : cases) {
SpvId label = this->nextId();
labels.push_back(label);
- if (c->fValue) {
+ if (c.value()) {
size += 2;
} else {
defaultLabel = label;
@@ -3068,16 +3069,16 @@
this->writeOpCode(SpvOpSwitch, size, out);
this->writeWord(value, out);
this->writeWord(defaultLabel, out);
- for (size_t i = 0; i < s.fCases.size(); ++i) {
- if (!s.fCases[i]->fValue) {
+ for (int i = 0; i < cases.count(); ++i) {
+ if (!cases[i].value()) {
continue;
}
- this->writeWord(s.fCases[i]->fValue->as<IntLiteral>().value(), out);
+ this->writeWord(cases[i].value()->as<IntLiteral>().value(), out);
this->writeWord(labels[i], out);
}
- for (size_t i = 0; i < s.fCases.size(); ++i) {
+ for (int i = 0; i < cases.count(); ++i) {
this->writeLabel(labels[i], out);
- for (const auto& stmt : s.fCases[i]->fStatements) {
+ for (const auto& stmt : cases[i].statements()) {
this->writeStatement(*stmt, out);
}
if (fCurrentBlock) {
diff --git a/src/sksl/ir/SkSLIRNode.cpp b/src/sksl/ir/SkSLIRNode.cpp
index 34d1194..f3d81e3 100644
--- a/src/sksl/ir/SkSLIRNode.cpp
+++ b/src/sksl/ir/SkSLIRNode.cpp
@@ -112,6 +112,11 @@
, fKind(kind)
, fData(data) {}
+IRNode::IRNode(int offset, int kind, const SwitchStatementData& data)
+: fOffset(offset)
+, fKind(kind)
+, fData(data) {}
+
IRNode::IRNode(int offset, int kind, const SwizzleData& data)
: fOffset(offset)
, fKind(kind)
diff --git a/src/sksl/ir/SkSLIRNode.h b/src/sksl/ir/SkSLIRNode.h
index 58241af..007bdba 100644
--- a/src/sksl/ir/SkSLIRNode.h
+++ b/src/sksl/ir/SkSLIRNode.h
@@ -206,6 +206,11 @@
const Type* fType;
};
+ struct SwitchStatementData {
+ bool fIsStatic;
+ std::shared_ptr<SymbolTable> fSymbols;
+ };
+
struct SwizzleData {
const Type* fType;
std::vector<int> fComponents;
@@ -284,6 +289,7 @@
kSection,
kSetting,
kString,
+ kSwitchStatement,
kSwizzle,
kSymbol,
kSymbolAlias,
@@ -318,6 +324,7 @@
SectionData fSection;
SettingData fSetting;
String fString;
+ SwitchStatementData fSwitchStatement;
SwizzleData fSwizzle;
SymbolData fSymbol;
SymbolAliasData fSymbolAlias;
@@ -434,6 +441,11 @@
*(new(&fContents) String) = data;
}
+ NodeData(const SwitchStatementData& data)
+ : fKind(Kind::kSwitchStatement) {
+ *(new(&fContents) SwitchStatementData) = data;
+ }
+
NodeData(const SwizzleData& data)
: fKind(Kind::kSwizzle) {
*(new(&fContents) SwizzleData) = data;
@@ -554,6 +566,9 @@
case Kind::kString:
*(new(&fContents) String) = other.fContents.fString;
break;
+ case Kind::kSwitchStatement:
+ *(new(&fContents) SwitchStatementData) = other.fContents.fSwitchStatement;
+ break;
case Kind::kSwizzle:
*(new(&fContents) SwizzleData) = other.fContents.fSwizzle;
break;
@@ -655,6 +670,9 @@
case Kind::kString:
fContents.fString.~String();
break;
+ case Kind::kSwitchStatement:
+ fContents.fSwitchStatement.~SwitchStatementData();
+ break;
case Kind::kSwizzle:
fContents.fSwizzle.~SwizzleData();
break;
@@ -728,6 +746,8 @@
IRNode(int offset, int kind, const String& data);
+ IRNode(int offset, int kind, const SwitchStatementData& data);
+
IRNode(int offset, int kind, const SwizzleData& data);
IRNode(int offset, int kind, const SymbolData& data);
@@ -907,6 +927,16 @@
return fData.fContents.fString;
}
+ SwitchStatementData& switchStatementData() {
+ SkASSERT(fData.fKind == NodeData::Kind::kSwitchStatement);
+ return fData.fContents.fSwitchStatement;
+ }
+
+ const SwitchStatementData& switchStatementData() const {
+ SkASSERT(fData.fKind == NodeData::Kind::kSwitchStatement);
+ return fData.fContents.fSwitchStatement;
+ }
+
SwizzleData& swizzleData() {
SkASSERT(fData.fKind == NodeData::Kind::kSwizzle);
return fData.fContents.fSwizzle;
diff --git a/src/sksl/ir/SkSLNodeArrayWrapper.h b/src/sksl/ir/SkSLNodeArrayWrapper.h
index 527f789..d0e2d32 100644
--- a/src/sksl/ir/SkSLNodeArrayWrapper.h
+++ b/src/sksl/ir/SkSLNodeArrayWrapper.h
@@ -5,6 +5,9 @@
* found in the LICENSE file.
*/
+#ifndef SKSL_NODEARRAYWRAPPER
+#define SKSL_NODEARRAYWRAPPER
+
#include "include/private/SkTArray.h"
namespace SkSL {
@@ -205,7 +208,7 @@
friend class ConstNodeArrayWrapper;
};
- ConstNodeArrayWrapper(SkTArray<std::unique_ptr<Base>>* contents)
+ ConstNodeArrayWrapper(const SkTArray<std::unique_ptr<Base>>* contents)
: fContents(contents) {}
ConstNodeArrayWrapper(const ConstNodeArrayWrapper& other)
@@ -252,3 +255,5 @@
};
} // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLStatement.h b/src/sksl/ir/SkSLStatement.h
index 8ceabc2..b4d26a1 100644
--- a/src/sksl/ir/SkSLStatement.h
+++ b/src/sksl/ir/SkSLStatement.h
@@ -58,6 +58,9 @@
Statement(int offset, const InlineMarkerData& data)
: INHERITED(offset, (int) Kind::kInlineMarker, data) {}
+ Statement(int offset, const SwitchStatementData& data)
+ : INHERITED(offset, (int) Kind::kSwitch, data) {}
+
Statement(int offset, const VarDeclarationData& data)
: INHERITED(offset, (int) Kind::kVarDeclaration, data) {}
diff --git a/src/sksl/ir/SkSLSwitchCase.h b/src/sksl/ir/SkSLSwitchCase.h
index 7a1022e..4072070 100644
--- a/src/sksl/ir/SkSLSwitchCase.h
+++ b/src/sksl/ir/SkSLSwitchCase.h
@@ -16,42 +16,58 @@
/**
* A single case of a 'switch' statement.
*/
-struct SwitchCase : public Statement {
+class SwitchCase : public Statement {
+public:
static constexpr Kind kStatementKind = Kind::kSwitchCase;
+ // null value implies "default" case
SwitchCase(int offset, std::unique_ptr<Expression> value, StatementArray statements)
- : INHERITED(offset, kStatementKind)
- , fValue(std::move(value))
- , fStatements(std::move(statements)) {}
+ : INHERITED(offset, kStatementKind) {
+ fExpressionChildren.push_back(std::move(value));
+ fStatementChildren = std::move(statements);
+ }
+
+ std::unique_ptr<Expression>& value() {
+ return fExpressionChildren[0];
+ }
+
+ const std::unique_ptr<Expression>& value() const {
+ return fExpressionChildren[0];
+ }
+
+ StatementArray& statements() {
+ return fStatementChildren;
+ }
+
+ const StatementArray& statements() const {
+ return fStatementChildren;
+ }
std::unique_ptr<Statement> clone() const override {
StatementArray cloned;
- cloned.reserve_back(fStatements.size());
- for (const auto& s : fStatements) {
+ cloned.reserve_back(this->statements().size());
+ for (const auto& s : this->statements()) {
cloned.push_back(s->clone());
}
return std::make_unique<SwitchCase>(fOffset,
- fValue ? fValue->clone() : nullptr,
+ this->value() ? this->value()->clone() : nullptr,
std::move(cloned));
}
String description() const override {
String result;
- if (fValue) {
- result.appendf("case %s:\n", fValue->description().c_str());
+ if (this->value()) {
+ result.appendf("case %s:\n", this->value()->description().c_str());
} else {
result += "default:\n";
}
- for (const auto& s : fStatements) {
+ for (const auto& s : this->statements()) {
result += s->description() + "\n";
}
return result;
}
- // null value implies "default" case
- std::unique_ptr<Expression> fValue;
- StatementArray fStatements;
-
+private:
using INHERITED = Statement;
};
diff --git a/src/sksl/ir/SkSLSwitchStatement.h b/src/sksl/ir/SkSLSwitchStatement.h
index b2a1af7..bd3055a 100644
--- a/src/sksl/ir/SkSLSwitchStatement.h
+++ b/src/sksl/ir/SkSLSwitchStatement.h
@@ -8,6 +8,7 @@
#ifndef SKSL_SWITCHSTATEMENT
#define SKSL_SWITCHSTATEMENT
+#include "src/sksl/ir/SkSLNodeArrayWrapper.h"
#include "src/sksl/ir/SkSLStatement.h"
#include "src/sksl/ir/SkSLSwitchCase.h"
@@ -18,48 +19,76 @@
/**
* A 'switch' statement.
*/
-struct SwitchStatement : public Statement {
+class SwitchStatement : public Statement {
+public:
static constexpr Kind kStatementKind = Kind::kSwitch;
+ using CaseArray = NodeArrayWrapper<SwitchCase, Statement>;
+
+ using ConstCaseArray = ConstNodeArrayWrapper<SwitchCase, Statement>;
+
SwitchStatement(int offset, bool isStatic, std::unique_ptr<Expression> value,
std::vector<std::unique_ptr<SwitchCase>> cases,
const std::shared_ptr<SymbolTable> symbols)
- : INHERITED(offset, kStatementKind)
- , fIsStatic(isStatic)
- , fValue(std::move(value))
- , fSymbols(std::move(symbols))
- , fCases(std::move(cases)) {}
+ : INHERITED(offset, SwitchStatementData{isStatic, std::move(symbols)}) {
+ fExpressionChildren.push_back(std::move(value));
+ fStatementChildren.reserve_back(cases.size());
+ for (std::unique_ptr<SwitchCase>& c : cases) {
+ fStatementChildren.push_back(std::move(c));
+ }
+ }
+
+ std::unique_ptr<Expression>& value() {
+ return fExpressionChildren[0];
+ }
+
+ const std::unique_ptr<Expression>& value() const {
+ return fExpressionChildren[0];
+ }
+
+ CaseArray cases() {
+ return CaseArray(&fStatementChildren);
+ }
+
+ ConstCaseArray cases() const {
+ return ConstCaseArray(&fStatementChildren);
+ }
+
+ bool isStatic() const {
+ return this->switchStatementData().fIsStatic;
+ }
+
+ const std::shared_ptr<SymbolTable>& symbols() const {
+ return this->switchStatementData().fSymbols;
+ }
std::unique_ptr<Statement> clone() const override {
std::vector<std::unique_ptr<SwitchCase>> cloned;
- for (const auto& s : fCases) {
- cloned.push_back(std::unique_ptr<SwitchCase>((SwitchCase*) s->clone().release()));
+ for (const std::unique_ptr<Statement>& s : fStatementChildren) {
+ cloned.emplace_back((SwitchCase*) s->as<SwitchCase>().clone().release());
}
- return std::make_unique<SwitchStatement>(fOffset, fIsStatic, fValue->clone(),
- std::move(cloned),
- SymbolTable::WrapIfBuiltin(fSymbols));
+ return std::unique_ptr<Statement>(new SwitchStatement(
+ fOffset,
+ this->isStatic(),
+ this->value()->clone(),
+ std::move(cloned),
+ SymbolTable::WrapIfBuiltin(this->symbols())));
}
String description() const override {
String result;
- if (fIsStatic) {
+ if (this->isStatic()) {
result += "@";
}
- result += String::printf("switch (%s) {\n", fValue->description().c_str());
- for (const auto& c : fCases) {
- result += c->description();
+ result += String::printf("switch (%s) {\n", this->value()->description().c_str());
+ for (const auto& c : this->cases()) {
+ result += c.description();
}
result += "}";
return result;
}
- bool fIsStatic;
- std::unique_ptr<Expression> fValue;
- // it's important to keep fCases defined after (and thus destroyed before) fSymbols, because
- // destroying statements can modify reference counts in symbols
- const std::shared_ptr<SymbolTable> fSymbols;
- std::vector<std::unique_ptr<SwitchCase>> fCases;
-
+private:
using INHERITED = Statement;
};