moved SkSL FunctionDefinition data into IRNode

Change-Id: Ia828de0793ee66301ba315f4593b4d7d69222b4e
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/326717
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/core/SkRuntimeEffect.cpp b/src/core/SkRuntimeEffect.cpp
index 68910d6..951c386 100644
--- a/src/core/SkRuntimeEffect.cpp
+++ b/src/core/SkRuntimeEffect.cpp
@@ -214,7 +214,7 @@
         // Functions
         else if (elem->is<SkSL::FunctionDefinition>()) {
             const auto& func = elem->as<SkSL::FunctionDefinition>();
-            const SkSL::FunctionDeclaration& decl = func.fDeclaration;
+            const SkSL::FunctionDeclaration& decl = func.declaration();
             if (decl.name() == "main") {
                 hasMain = true;
             }
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index a462fb3..6e6e83d 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -324,7 +324,7 @@
 }
 
 bool Analysis::NodeCountExceeds(const FunctionDefinition& function, int limit) {
-    return NodeCountVisitor{limit}.visit(*function.fBody) > limit;
+    return NodeCountVisitor{limit}.visit(*function.body()) > limit;
 }
 
 bool Analysis::StatementWritesToVariable(const Statement& stmt, const Variable& var) {
@@ -506,7 +506,7 @@
             return false;
 
         case ProgramElement::Kind::kFunction:
-            return this->visitStatement(*pe.template as<FunctionDefinition>().fBody);
+            return this->visitStatement(*pe.template as<FunctionDefinition>().body());
 
         case ProgramElement::Kind::kInterfaceBlock:
             for (auto& e : pe.template as<InterfaceBlock>().fSizes) {
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index a3927e5..4928074 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -184,14 +184,14 @@
 
 std::unique_ptr<ByteCodeFunction> ByteCodeGenerator::writeFunction(const FunctionDefinition& f) {
     fFunction = &f;
-    std::unique_ptr<ByteCodeFunction> result(new ByteCodeFunction(&f.fDeclaration));
+    std::unique_ptr<ByteCodeFunction> result(new ByteCodeFunction(&f.declaration()));
     fParameterCount = result->fParameterCount;
     fLoopCount = fMaxLoopCount = 0;
     fConditionCount = fMaxConditionCount = 0;
     fStackCount = fMaxStackCount = 0;
     fCode = &result->fCode;
 
-    this->writeStatement(*f.fBody);
+    this->writeStatement(*f.body());
     if (0 == fErrors.errorCount()) {
         SkASSERT(fLoopCount == 0);
         SkASSERT(fConditionCount == 0);
@@ -204,7 +204,7 @@
     result->fLoopCount      = fMaxLoopCount;
     result->fStackCount     = fMaxStackCount;
 
-    const Type& returnType = f.fDeclaration.returnType();
+    const Type& returnType = f.declaration().returnType();
     if (returnType != *fContext.fVoid_Type) {
         result->fReturnCount = SlotCount(returnType);
     }
@@ -441,7 +441,7 @@
         }
         case Variable::Storage::kParameter: {
             int offset = 0;
-            for (const auto& p : fFunction->fDeclaration.parameters()) {
+            for (const auto& p : fFunction->declaration().parameters()) {
                 if (p == &var) {
                     SkASSERT(offset <= 255);
                     return { offset, Storage::kLocal };
@@ -1238,7 +1238,7 @@
     // before they're defined. This is an easy-to-understand rule that prevents recursion.
     int idx = -1;
     for (size_t i = 0; i < fFunctions.size(); ++i) {
-        if (f.function().matches(fFunctions[i]->fDeclaration)) {
+        if (f.function().matches(fFunctions[i]->declaration())) {
             idx = i;
             break;
         }
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 79c128a..0c2cd6b 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -654,7 +654,7 @@
     CFG result;
     result.fStart = result.newBlock();
     result.fCurrent = result.fStart;
-    this->addStatement(result, &f.fBody);
+    this->addStatement(result, &f.body());
     result.newBlock();
     result.fExit = result.fCurrent;
     return result;
diff --git a/src/sksl/SkSLCPPCodeGenerator.cpp b/src/sksl/SkSLCPPCodeGenerator.cpp
index 4537608..8d26ede 100644
--- a/src/sksl/SkSLCPPCodeGenerator.cpp
+++ b/src/sksl/SkSLCPPCodeGenerator.cpp
@@ -595,7 +595,7 @@
 }
 
 void CPPCodeGenerator::writeFunction(const FunctionDefinition& f) {
-    const FunctionDeclaration& decl = f.fDeclaration;
+    const FunctionDeclaration& decl = f.declaration();
     if (decl.isBuiltin()) {
         return;
     }
@@ -605,7 +605,7 @@
     fOut = &buffer;
     if (decl.name() == "main") {
         fInMain = true;
-        for (const std::unique_ptr<Statement>& s : f.fBody->as<Block>().children()) {
+        for (const std::unique_ptr<Statement>& s : f.body()->as<Block>().children()) {
             this->writeStatement(*s);
             this->writeLine();
         }
@@ -625,7 +625,7 @@
         }
         args += "};";
         this->addExtraEmitCodeLine(args.c_str());
-        for (const std::unique_ptr<Statement>& s : f.fBody->as<Block>().children()) {
+        for (const std::unique_ptr<Statement>& s : f.body()->as<Block>().children()) {
             this->writeStatement(*s);
             this->writeLine();
         }
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index ee29937..528ce07 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -360,8 +360,8 @@
         switch (element->kind()) {
             case ProgramElement::Kind::kFunction: {
                 const FunctionDefinition& f = element->as<FunctionDefinition>();
-                SkASSERT(f.fDeclaration.isBuiltin());
-                intrinsics->insertOrDie(f.fDeclaration.description(), std::move(element));
+                SkASSERT(f.declaration().isBuiltin());
+                intrinsics->insertOrDie(f.declaration().description(), std::move(element));
                 break;
             }
             case ProgramElement::Kind::kEnum: {
@@ -1525,9 +1525,9 @@
     }
 
     // check for missing return
-    if (f.fDeclaration.returnType() != *fContext->fVoid_Type) {
+    if (f.declaration().returnType() != *fContext->fVoid_Type) {
         if (cfg.fBlocks[cfg.fExit].fIsReachable) {
-            this->error(f.fOffset, String("function '" + String(f.fDeclaration.name()) +
+            this->error(f.fOffset, String("function '" + String(f.declaration().name()) +
                                           "' can exit without returning a value"));
         }
     }
@@ -1601,8 +1601,8 @@
                                            return false;
                                        }
                                        const auto& fn = element->as<FunctionDefinition>();
-                                       bool dead = fn.fDeclaration.callCount() == 0 &&
-                                                   fn.fDeclaration.name() != "main";
+                                       bool dead = fn.declaration().callCount() == 0 &&
+                                                   fn.declaration().name() != "main";
                                        madeChanges |= dead;
                                        return dead;
                                    }),
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 6bbcfd4..ea36e94 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -529,11 +529,11 @@
         case ProgramElement::Kind::kFunction: {
             const FunctionDefinition& f = e.as<FunctionDefinition>();
             this->writeCommand(Rehydrator::kFunctionDefinition_Command);
-            this->writeU16(this->symbolId(&f.fDeclaration));
-            this->write(f.fBody.get());
-            this->writeU8(f.fReferencedIntrinsics.size());
+            this->writeU16(this->symbolId(&f.declaration()));
+            this->write(f.body().get());
+            this->writeU8(f.referencedIntrinsics().size());
             std::set<uint16_t> ordered;
-            for (const FunctionDeclaration* ref : f.fReferencedIntrinsics) {
+            for (const FunctionDeclaration* ref : f.referencedIntrinsics()) {
                 ordered.insert(this->symbolId(ref));
             }
             for (uint16_t ref : ordered) {
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 1d53ad4a..1994bf3 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -1044,11 +1044,11 @@
     // accidentally end up here.
     SkASSERT(fProgramKind != Program::kPipelineStage_Kind);
 
-    this->writeTypePrecision(f.fDeclaration.returnType());
-    this->writeType(f.fDeclaration.returnType());
-    this->write(" " + f.fDeclaration.name() + "(");
+    this->writeTypePrecision(f.declaration().returnType());
+    this->writeType(f.declaration().returnType());
+    this->write(" " + f.declaration().name() + "(");
     const char* separator = "";
-    for (const auto& param : f.fDeclaration.parameters()) {
+    for (const auto& param : f.declaration().parameters()) {
         this->write(separator);
         separator = ", ";
         this->writeModifiers(param->modifiers(), false);
@@ -1076,7 +1076,7 @@
     OutputStream* oldOut = fOut;
     StringStream buffer;
     fOut = &buffer;
-    for (const std::unique_ptr<Statement>& stmt : f.fBody->as<Block>().children()) {
+    for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
         if (!stmt->isEmpty()) {
             this->writeStatement(*stmt);
             this->writeLine();
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index a1b784f..239f32b 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -727,7 +727,7 @@
                                                                 fContext.fVoid_Type.get(),
                                                                 /*builtin=*/false));
     fProgramElements->push_back(std::make_unique<FunctionDefinition>(/*offset=*/-1,
-                                                                     *invokeDecl,
+                                                                     invokeDecl,
                                                                      std::move(main)));
 
     std::vector<std::unique_ptr<VarDeclaration>> variables;
@@ -1061,10 +1061,10 @@
         if (Program::kVertex_Kind == fKind && funcData.fName == "main" && fRTAdjust) {
             body->children().push_back(this->getNormalizeSkPositionCode());
         }
-        auto result = std::make_unique<FunctionDefinition>(f.fOffset, *decl, std::move(body),
+        auto result = std::make_unique<FunctionDefinition>(f.fOffset, decl, std::move(body),
                                                            std::move(fReferencedIntrinsics));
         decl->setDefinition(result.get());
-        result->fSource = &f;
+        result->setSource(&f);
         fProgramElements->push_back(std::move(result));
     }
 }
@@ -2001,8 +2001,8 @@
 
         // Sort the referenced intrinsics into a consistent order; otherwise our output will become
         // non-deterministic.
-        std::vector<const FunctionDeclaration*> intrinsics(original.fReferencedIntrinsics.begin(),
-                                                           original.fReferencedIntrinsics.end());
+        std::vector<const FunctionDeclaration*> intrinsics(original.referencedIntrinsics().begin(),
+                                                           original.referencedIntrinsics().end());
         std::sort(intrinsics.begin(), intrinsics.end(),
                   [](const FunctionDeclaration* a, const FunctionDeclaration* b) {
                       if (a->isBuiltin() != b->isBuiltin()) {
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index c80374d..bbb6ee0 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -660,10 +660,10 @@
 
     // Create a variable to hold the result in the extra statements (excepting void).
     std::unique_ptr<Expression> resultExpr;
-    if (function.fDeclaration.returnType() != *fContext->fVoid_Type) {
+    if (function.declaration().returnType() != *fContext->fVoid_Type) {
         std::unique_ptr<Expression> noInitialValue;
-        resultExpr = makeInlineVar(String(function.fDeclaration.name()),
-                                   &function.fDeclaration.returnType(),
+        resultExpr = makeInlineVar(String(function.declaration().name()),
+                                   &function.declaration().returnType(),
                                    Modifiers{}, &noInitialValue);
    }
 
@@ -672,13 +672,13 @@
     VariableRewriteMap varMap;
     std::vector<int> argsToCopyBack;
     for (int i = 0; i < (int) arguments.size(); ++i) {
-        const Variable* param = function.fDeclaration.parameters()[i];
+        const Variable* param = function.declaration().parameters()[i];
         bool isOutParam = param->modifiers().fFlags & Modifiers::kOut_Flag;
 
         // If this argument can be inlined trivially (e.g. a swizzle, or a constant array index)...
         if (is_trivial_argument(*arguments[i])) {
             // ... and it's an `out` param, or it isn't written to within the inline function...
-            if (isOutParam || !Analysis::StatementWritesToVariable(*function.fBody, *param)) {
+            if (isOutParam || !Analysis::StatementWritesToVariable(*function.body(), *param)) {
                 // ... we don't need to copy it at all! We can just use the existing expression.
                 varMap[param] = arguments[i]->clone();
                 continue;
@@ -693,7 +693,7 @@
                                       param->modifiers(), &arguments[i]);
     }
 
-    const Block& body = function.fBody->as<Block>();
+    const Block& body = function.body()->as<Block>();
     auto inlineBlock = std::make_unique<Block>(offset, StatementArray{});
     inlineBlock->children().reserve(body.children().size());
     for (const std::unique_ptr<Statement>& stmt : body.children()) {
@@ -718,7 +718,7 @@
 
     // Copy back the values of `out` parameters into their real destinations.
     for (int i : argsToCopyBack) {
-        const Variable* p = function.fDeclaration.parameters()[i];
+        const Variable* p = function.declaration().parameters()[i];
         SkASSERT(varMap.find(p) != varMap.end());
         inlinedBody.children().push_back(
                 std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
@@ -823,7 +823,7 @@
             case ProgramElement::Kind::kFunction: {
                 FunctionDefinition& funcDef = pe->as<FunctionDefinition>();
                 fEnclosingFunction = &funcDef;
-                this->visitStatement(&funcDef.fBody);
+                this->visitStatement(&funcDef.body());
                 break;
             }
             default:
@@ -1164,7 +1164,7 @@
 
         // Convert the function call to its inlined equivalent.
         InlinedCall inlinedCall = this->inlineCall(&funcCall, candidate.fSymbols,
-                                                   &candidate.fEnclosingFunction->fDeclaration);
+                                                   &candidate.fEnclosingFunction->declaration());
         if (inlinedCall.fInlinedBody) {
             // Ensure that the inlined body has a scope if it needs one.
             this->ensureScopedBlocks(inlinedCall.fInlinedBody.get(), candidate.fParentStmt->get());
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 710bf60..985c9fe 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -936,7 +936,7 @@
 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
     fRTHeightName = fProgram.fInputs.fRTHeight ? "_globals->_anonInterface0->u_skRTHeight" : "";
     const char* separator = "";
-    if ("main" == f.fDeclaration.name()) {
+    if ("main" == f.declaration().name()) {
         switch (fProgram.fKind) {
             case Program::kFragment_Kind:
                 this->write("fragment Outputs fragmentMain");
@@ -1006,11 +1006,11 @@
         }
         separator = ", ";
     } else {
-        this->writeType(f.fDeclaration.returnType());
+        this->writeType(f.declaration().returnType());
         this->write(" ");
-        this->writeName(f.fDeclaration.name());
+        this->writeName(f.declaration().name());
         this->write("(");
-        Requirements requirements = this->requirements(f.fDeclaration);
+        Requirements requirements = this->requirements(f.declaration());
         if (requirements & kInputs_Requirement) {
             this->write("Inputs _in");
             separator = ", ";
@@ -1036,7 +1036,7 @@
             separator = ", ";
         }
     }
-    for (const auto& param : f.fDeclaration.parameters()) {
+    for (const auto& param : f.declaration().parameters()) {
         this->write(separator);
         separator = ", ";
         this->writeModifiers(param->modifiers(), false);
@@ -1064,7 +1064,7 @@
 
     SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
 
-    if (f.fDeclaration.name() == "main") {
+    if (f.declaration().name() == "main") {
         this->writeGlobalInit();
         this->writeLine("    Outputs _outputStruct;");
         this->writeLine("    thread Outputs* _out = &_outputStruct;");
@@ -1075,13 +1075,13 @@
     StringStream buffer;
     fOut = &buffer;
     fIndentation++;
-    for (const std::unique_ptr<Statement>& stmt : f.fBody->as<Block>().children()) {
+    for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
         if (!stmt->isEmpty()) {
             this->writeStatement(*stmt);
             this->writeLine();
         }
     }
-    if (f.fDeclaration.name() == "main") {
+    if (f.declaration().name() == "main") {
         switch (fProgram.fKind) {
             case Program::kFragment_Kind:
                 this->writeLine("return *_out;");
@@ -1813,8 +1813,8 @@
         for (const auto& e : fProgram.elements()) {
             if (e->is<FunctionDefinition>()) {
                 const FunctionDefinition& def = e->as<FunctionDefinition>();
-                if (&def.fDeclaration == &f) {
-                    Requirements reqs = this->requirements(def.fBody.get());
+                if (&def.declaration() == &f) {
+                    Requirements reqs = this->requirements(def.body().get());
                     fRequirements[&f] = reqs;
                     return reqs;
                 }
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index be2c69a..3c5d6e0 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -83,7 +83,7 @@
         int index = 0;
         for (const auto& e : fProgram.elements()) {
             if (e->is<FunctionDefinition>()) {
-                if (&e->as<FunctionDefinition>().fDeclaration == &function) {
+                if (&e->as<FunctionDefinition>().declaration() == &function) {
                     break;
                 }
                 ++index;
@@ -175,8 +175,8 @@
     OutputStream* oldOut = fOut;
     StringStream buffer;
     fOut = &buffer;
-    if (f.fDeclaration.name() == "main") {
-        for (const std::unique_ptr<Statement>& stmt : f.fBody->as<Block>().children()) {
+    if (f.declaration().name() == "main") {
+        for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
             this->writeStatement(*stmt);
             this->writeLine();
         }
@@ -184,7 +184,7 @@
         this->write(fFunctionHeader);
         this->write(buffer.str());
     } else {
-        const FunctionDeclaration& decl = f.fDeclaration;
+        const FunctionDeclaration& decl = f.declaration();
         Compiler::GLSLFunction result;
         if (!type_to_grsltype(fContext, decl.returnType(), &result.fReturnType)) {
             fErrors.error(f.fOffset, "unsupported return type");
@@ -199,7 +199,7 @@
             }
             result.fParameters.emplace_back(v->name(), paramSLType);
         }
-        for (const std::unique_ptr<Statement>& stmt : f.fBody->as<Block>().children()) {
+        for (const std::unique_ptr<Statement>& stmt : f.body()->as<Block>().children()) {
             this->writeStatement(*stmt);
             this->writeLine();
         }
diff --git a/src/sksl/SkSLRehydrator.cpp b/src/sksl/SkSLRehydrator.cpp
index 9ae5850..fd5453f 100644
--- a/src/sksl/SkSLRehydrator.cpp
+++ b/src/sksl/SkSLRehydrator.cpp
@@ -300,7 +300,7 @@
                 refs.insert(this->symbolRef<FunctionDeclaration>(
                                                                Symbol::Kind::kFunctionDeclaration));
             }
-            FunctionDefinition* result = new FunctionDefinition(-1, *decl, std::move(body),
+            FunctionDefinition* result = new FunctionDefinition(-1, decl, std::move(body),
                                                                 std::move(refs));
             decl->setDefinition(result);
             return std::unique_ptr<ProgramElement>(result);
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index f3f0580..27b2833 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2598,17 +2598,17 @@
 
 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
     fVariableBuffer.reset();
-    SpvId result = this->writeFunctionStart(f.fDeclaration, out);
+    SpvId result = this->writeFunctionStart(f.declaration(), out);
     this->writeLabel(this->nextId(), out);
     StringStream bodyBuffer;
-    this->writeBlock((Block&) *f.fBody, bodyBuffer);
+    this->writeBlock((Block&) *f.body(), bodyBuffer);
     write_stringstream(fVariableBuffer, out);
-    if (f.fDeclaration.name() == "main") {
+    if (f.declaration().name() == "main") {
         write_stringstream(fGlobalInitializersBuffer, out);
     }
     write_stringstream(bodyBuffer, out);
     if (fCurrentBlock) {
-        if (f.fDeclaration.returnType() == *fContext.fVoid_Type) {
+        if (f.declaration().returnType() == *fContext.fVoid_Type) {
             this->writeInstruction(SpvOpReturn, out);
         } else {
             this->writeInstruction(SpvOpUnreachable, out);
@@ -3167,7 +3167,7 @@
         switch (e->kind()) {
             case ProgramElement::Kind::kFunction: {
                 const FunctionDefinition& f = e->as<FunctionDefinition>();
-                fFunctionMap[&f.fDeclaration] = this->nextId();
+                fFunctionMap[&f.declaration()] = this->nextId();
                 break;
             }
             case ProgramElement::Kind::kModifiers: {
diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h
index c122173..304aee1 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.h
+++ b/src/sksl/ir/SkSLFunctionDefinition.h
@@ -12,8 +12,6 @@
 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
 #include "src/sksl/ir/SkSLProgramElement.h"
 
-#include <unordered_set>
-
 namespace SkSL {
 
 struct ASTNode;
@@ -25,34 +23,48 @@
     static constexpr Kind kProgramElementKind = Kind::kFunction;
 
     FunctionDefinition(int offset,
-                       const FunctionDeclaration& declaration,
+                       const FunctionDeclaration* declaration,
                        std::unique_ptr<Statement> body,
                        std::unordered_set<const FunctionDeclaration*> referencedIntrinsics = {})
-        : INHERITED(offset, kProgramElementKind)
-        , fDeclaration(declaration)
-        , fBody(std::move(body))
-        , fReferencedIntrinsics(std::move(referencedIntrinsics)) {}
+        : INHERITED(offset, FunctionDefinitionData{declaration, std::move(referencedIntrinsics),
+                                                   nullptr}) {
+        fStatementChildren.push_back(std::move(body));
+    }
+
+    const FunctionDeclaration& declaration() const {
+        return *this->functionDefinitionData().fDeclaration;
+    }
+
+    std::unique_ptr<Statement>& body() {
+        return this->fStatementChildren[0];
+    }
+
+    const std::unique_ptr<Statement>& body() const {
+        return this->fStatementChildren[0];
+    }
+
+    const std::unordered_set<const FunctionDeclaration*>& referencedIntrinsics() const {
+        return this->functionDefinitionData().fReferencedIntrinsics;
+    }
+
+    const ASTNode* source() const {
+        return this->functionDefinitionData().fSource;
+    }
+
+    void setSource(const ASTNode* source) {
+        this->functionDefinitionData().fSource = source;
+    }
 
     std::unique_ptr<ProgramElement> clone() const override {
-        return std::make_unique<FunctionDefinition>(fOffset, fDeclaration,
-                                                    fBody->clone(), fReferencedIntrinsics);
+        return std::make_unique<FunctionDefinition>(fOffset, &this->declaration(),
+                                                    this->body()->clone(),
+                                                    this->referencedIntrinsics());
     }
 
     String description() const override {
-        return fDeclaration.description() + " " + fBody->description();
+        return this->declaration().description() + " " + this->body()->description();
     }
 
-    const FunctionDeclaration& fDeclaration;
-    std::unique_ptr<Statement> fBody;
-    // We track intrinsic functions we reference so that we can ensure that all of them end up
-    // copied into the final output.
-    std::unordered_set<const FunctionDeclaration*> fReferencedIntrinsics;
-    // This pointer may be null, and even when non-null is not guaranteed to remain valid for the
-    // entire lifespan of this object. The parse tree's lifespan is normally controlled by
-    // IRGenerator, so the IRGenerator being destroyed or being used to compile another file will
-    // invalidate this pointer.
-    const ASTNode* fSource = nullptr;
-
     using INHERITED = ProgramElement;
 };
 
diff --git a/src/sksl/ir/SkSLIRNode.cpp b/src/sksl/ir/SkSLIRNode.cpp
index 8cfcda3..7dbd665 100644
--- a/src/sksl/ir/SkSLIRNode.cpp
+++ b/src/sksl/ir/SkSLIRNode.cpp
@@ -62,6 +62,11 @@
 , fKind(kind)
 , fData(data) {}
 
+IRNode::IRNode(int offset, int kind, const FunctionDefinitionData& data)
+: fOffset(offset)
+, fKind(kind)
+, fData(data) {}
+
 IRNode::IRNode(int offset, int kind, const FunctionReferenceData& data)
 : fOffset(offset)
 , fKind(kind)
diff --git a/src/sksl/ir/SkSLIRNode.h b/src/sksl/ir/SkSLIRNode.h
index de68ceb..af9ad33 100644
--- a/src/sksl/ir/SkSLIRNode.h
+++ b/src/sksl/ir/SkSLIRNode.h
@@ -16,6 +16,7 @@
 
 #include <algorithm>
 #include <atomic>
+#include <unordered_set>
 #include <vector>
 
 namespace SkSL {
@@ -138,6 +139,18 @@
         }
     };
 
+    struct FunctionDefinitionData {
+        const FunctionDeclaration* fDeclaration;
+        // We track intrinsic functions we reference so that we can ensure that all of them end up
+        // copied into the final output.
+        std::unordered_set<const FunctionDeclaration*> fReferencedIntrinsics;
+        // This pointer may be null, and even when non-null is not guaranteed to remain valid for
+        // the entire lifespan of this object. The parse tree's lifespan is normally controlled by
+        // IRGenerator, so the IRGenerator being destroyed or being used to compile another file
+        // will invalidate this pointer.
+        const ASTNode* fSource;
+    };
+
     struct FunctionReferenceData {
         const Type* fType;
         std::vector<const FunctionDeclaration*> fFunctions;
@@ -239,6 +252,7 @@
             kForStatement,
             kFunctionCall,
             kFunctionDeclaration,
+            kFunctionDefinition,
             kFunctionReference,
             kIfStatement,
             kInlineMarker,
@@ -271,6 +285,7 @@
             ForStatementData fForStatement;
             FunctionCallData fFunctionCall;
             FunctionDeclarationData fFunctionDeclaration;
+            FunctionDefinitionData fFunctionDefinition;
             FunctionReferenceData fFunctionReference;
             IfStatementData fIfStatement;
             InlineMarkerData fInlineMarker;
@@ -345,6 +360,11 @@
             *(new(&fContents) FunctionDeclarationData) = data;
         }
 
+        NodeData(const FunctionDefinitionData& data)
+            : fKind(Kind::kFunctionDefinition) {
+            *(new(&fContents) FunctionDefinitionData) = data;
+        }
+
         NodeData(const FunctionReferenceData& data)
             : fKind(Kind::kFunctionReference) {
             *(new(&fContents) FunctionReferenceData) = data;
@@ -474,6 +494,9 @@
                     *(new(&fContents) FunctionDeclarationData) =
                                                                other.fContents.fFunctionDeclaration;
                     break;
+                case Kind::kFunctionDefinition:
+                    *(new(&fContents) FunctionDefinitionData) = other.fContents.fFunctionDefinition;
+                    break;
                 case Kind::kFunctionReference:
                     *(new(&fContents) FunctionReferenceData) = other.fContents.fFunctionReference;
                     break;
@@ -570,6 +593,9 @@
                 case Kind::kFunctionDeclaration:
                     fContents.fFunctionDeclaration.~FunctionDeclarationData();
                     break;
+                case Kind::kFunctionDefinition:
+                    fContents.fFunctionDefinition.~FunctionDefinitionData();
+                    break;
                 case Kind::kFunctionReference:
                     fContents.fFunctionReference.~FunctionReferenceData();
                     break;
@@ -647,6 +673,8 @@
 
     IRNode(int offset, int kind, const FunctionDeclarationData& data);
 
+    IRNode(int offset, int kind, const FunctionDefinitionData& data);
+
     IRNode(int offset, int kind, const FunctionReferenceData& data);
 
     IRNode(int offset, int kind, const IfStatementData& data);
@@ -782,6 +810,16 @@
         return fData.fContents.fFunctionDeclaration;
     }
 
+    FunctionDefinitionData& functionDefinitionData() {
+        SkASSERT(fData.fKind == NodeData::Kind::kFunctionDefinition);
+        return fData.fContents.fFunctionDefinition;
+    }
+
+    const FunctionDefinitionData& functionDefinitionData() const {
+        SkASSERT(fData.fKind == NodeData::Kind::kFunctionDefinition);
+        return fData.fContents.fFunctionDefinition;
+    }
+
     const FunctionReferenceData& functionReferenceData() const {
         SkASSERT(fData.fKind == NodeData::Kind::kFunctionReference);
         return fData.fContents.fFunctionReference;
diff --git a/src/sksl/ir/SkSLProgramElement.h b/src/sksl/ir/SkSLProgramElement.h
index 2ea4fd3..7a2873b 100644
--- a/src/sksl/ir/SkSLProgramElement.h
+++ b/src/sksl/ir/SkSLProgramElement.h
@@ -37,19 +37,22 @@
         SkASSERT(kind >= Kind::kFirst && kind <= Kind::kLast);
     }
 
-    ProgramElement(int offset, const EnumData& enumData)
-    : INHERITED(offset, (int) Kind::kEnum, enumData) {}
+    ProgramElement(int offset, const EnumData& data)
+    : INHERITED(offset, (int) Kind::kEnum, data) {}
 
-    ProgramElement(int offset, const ModifiersDeclarationData& enumData)
-    : INHERITED(offset, (int) Kind::kModifiers, enumData) {}
+    ProgramElement(int offset, const FunctionDefinitionData& data)
+    : INHERITED(offset, (int) Kind::kFunction, data) {}
+
+    ProgramElement(int offset, const ModifiersDeclarationData& data)
+    : INHERITED(offset, (int) Kind::kModifiers, data) {}
 
     ProgramElement(int offset, Kind kind, const String& data)
     : INHERITED(offset, (int) kind, data) {
         SkASSERT(kind >= Kind::kFirst && kind <= Kind::kLast);
     }
 
-    ProgramElement(int offset, const SectionData& sectionData)
-    : INHERITED(offset, (int) Kind::kSection, sectionData) {}
+    ProgramElement(int offset, const SectionData& data)
+    : INHERITED(offset, (int) Kind::kSection, data) {}
 
     Kind kind() const {
         return (Kind) fKind;