Split IRGenerator's VarDeclaration processing into two parts

This allows things to align better between the IR and DSL sides
and gives us equivalent error handling on both sides.

Change-Id: I6d5569e29df51a4d1a6cb0ad1e6611d419dfe30c
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/373737
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index febe77d..7177046 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -428,14 +428,11 @@
     this->checkModifiers(offset, modifiers, permitted);
 }
 
-std::unique_ptr<Statement> IRGenerator::convertVarDeclaration(int offset,
-                                                              const Modifiers& modifiers,
-                                                              const Type* baseType,
-                                                              StringFragment name,
-                                                              bool isArray,
-                                                              std::unique_ptr<Expression> arraySize,
-                                                              std::unique_ptr<Expression> value,
-                                                              Variable::Storage storage) {
+std::unique_ptr<Variable> IRGenerator::convertVar(int offset, const Modifiers& modifiers,
+                                                  const Type* baseType, StringFragment name,
+                                                  bool isArray,
+                                                  std::unique_ptr<Expression> arraySize,
+                                                  Variable::Storage storage) {
     if (modifiers.fLayout.fLocation == 0 && modifiers.fLayout.fIndex == 0 &&
         (modifiers.fFlags & Modifiers::kOut_Flag) &&
         this->programKind() == ProgramKind::kFragment && name != "sk_FragColor") {
@@ -452,43 +449,70 @@
         }
         type = fSymbolTable->addArrayDimension(type, arraySizeValue);
     }
-    auto var = std::make_unique<Variable>(offset, fModifiers->addToPool(modifiers),
-                                          name, type, fIsBuiltinCode, storage);
+    return std::make_unique<Variable>(offset, fModifiers->addToPool(modifiers), name, type,
+                                      fIsBuiltinCode, storage);
+}
+
+std::unique_ptr<Statement> IRGenerator::convertVarDeclaration(std::unique_ptr<Variable> var,
+                                                              std::unique_ptr<Expression> value) {
+    if (value) {
+        if (var->type().isOpaque()) {
+            this->errorReporter().error(
+                    value->fOffset,
+                    "opaque type '" + var->type().name() +
+                    "' cannot use initializer expressions");
+        }
+        if (var->modifiers().fFlags & Modifiers::kIn_Flag) {
+            this->errorReporter().error(value->fOffset,
+                                        "'in' variables cannot use initializer expressions");
+        }
+        if (var->modifiers().fFlags & Modifiers::kUniform_Flag) {
+            this->errorReporter().error(value->fOffset,
+                                        "'uniform' variables cannot use initializer expressions");
+        }
+        value = this->coerce(std::move(value), var->type());
+        if (!value) {
+            return nullptr;
+        }
+    }
+    const Type* baseType = &var->type();
+    int arraySize = 0;
+    if (baseType->isArray()) {
+        arraySize = baseType->columns();
+        baseType = &baseType->componentType();
+    }
+    auto result = std::make_unique<VarDeclaration>(var.get(), baseType, arraySize,
+                                                   std::move(value));
+    var->setDeclaration(result.get());
     if (var->name() == Compiler::RTADJUST_NAME) {
         SkASSERT(!fRTAdjust);
         SkASSERT(var->type() == *fContext.fTypes.fFloat4);
         fRTAdjust = var.get();
     }
-    if (value) {
-        if (type->isOpaque()) {
-            this->errorReporter().error(
-                    value->fOffset,
-                    "opaque type '" + type->name() + "' cannot use initializer expressions");
-        }
-        if (modifiers.fFlags & Modifiers::kIn_Flag) {
-            this->errorReporter().error(value->fOffset,
-                                        "'in' variables cannot use initializer expressions");
-        }
-        if (modifiers.fFlags & Modifiers::kUniform_Flag) {
-            this->errorReporter().error(value->fOffset,
-                                        "'uniform' variables cannot use initializer expressions");
-        }
-        value = this->coerce(std::move(value), *type);
-        if (!value) {
-            return {};
-        }
-    }
     const Symbol* symbol = (*fSymbolTable)[var->name()];
-    if (symbol && storage == Variable::Storage::kGlobal && var->name() == "sk_FragColor") {
+    if (symbol && var->storage() == Variable::Storage::kGlobal && var->name() == "sk_FragColor") {
         // Already defined, ignore.
         return nullptr;
     } else {
-        auto result = std::make_unique<VarDeclaration>(var.get(), baseType, arraySizeValue,
-                                                       std::move(value));
-        var->setDeclaration(result.get());
         fSymbolTable->add(std::move(var));
-        return std::move(result);
     }
+    return std::move(result);
+}
+
+std::unique_ptr<Statement> IRGenerator::convertVarDeclaration(int offset,
+                                                              const Modifiers& modifiers,
+                                                              const Type* baseType,
+                                                              StringFragment name,
+                                                              bool isArray,
+                                                              std::unique_ptr<Expression> arraySize,
+                                                              std::unique_ptr<Expression> value,
+                                                              Variable::Storage storage) {
+    std::unique_ptr<Variable> var = this->convertVar(offset, modifiers, baseType, name, isArray,
+                                                     std::move(arraySize), storage);
+    if (!var) {
+        return nullptr;
+    }
+    return this->convertVarDeclaration(std::move(var), std::move(value));
 }
 
 StatementArray IRGenerator::convertVarDeclarations(const ASTNode& decls,
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 0421e69..bf5f2fc 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -158,6 +158,12 @@
     void checkModifiers(int offset, const Modifiers& modifiers, int permitted);
     void checkVarDeclaration(int offset, const Modifiers& modifiers, const Type* baseType,
                              Variable::Storage storage);
+    std::unique_ptr<Variable> convertVar(int offset, const Modifiers& modifiers,
+                                         const Type* baseType, StringFragment name, bool isArray,
+                                         std::unique_ptr<Expression> arraySize,
+                                         Variable::Storage storage);
+    std::unique_ptr<Statement> convertVarDeclaration(std::unique_ptr<Variable> var,
+                                                     std::unique_ptr<Expression> value);
     std::unique_ptr<Statement> convertVarDeclaration(int offset, const Modifiers& modifiers,
                                                      const Type* baseType, StringFragment name,
                                                      bool isArray,
diff --git a/src/sksl/dsl/DSLCore.cpp b/src/sksl/dsl/DSLCore.cpp
index 131b8a4..ebe5719 100644
--- a/src/sksl/dsl/DSLCore.cpp
+++ b/src/sksl/dsl/DSLCore.cpp
@@ -73,16 +73,14 @@
     }
 
     static DSLStatement Declare(DSLVar& var, DSLExpression initialValue) {
-        if (!var.fDeclaration) {
-            DSLWriter::ReportError("Declare failed (was the variable already declared?)");
+        if (var.fConstVar) {
+            DSLWriter::ReportError("Variable already declared");
             return DSLStatement();
         }
-        VarDeclaration& decl = var.fDeclaration->as<SkSL::VarDeclaration>();
-        std::unique_ptr<Expression> expr = initialValue.coerceAndRelease(decl.var().type());
-        if (expr) {
-            decl.fValue = std::move(expr);
-        }
-        return DSLStatement(std::move(var.fDeclaration));
+        SkASSERT(var.fVar);
+        var.fConstVar = var.fVar.get();
+        return DSLWriter::IRGenerator().convertVarDeclaration(std::move(var.fVar),
+                                                              initialValue.release());
     }
 
     static DSLStatement Discard() {
diff --git a/src/sksl/dsl/DSLVar.cpp b/src/sksl/dsl/DSLVar.cpp
index 3471475..52466e5 100644
--- a/src/sksl/dsl/DSLVar.cpp
+++ b/src/sksl/dsl/DSLVar.cpp
@@ -30,7 +30,8 @@
         // converting all DSL code into strings rather than nodes, all we really need is a
         // correctly-named variable with the right type, so we just create a placeholder for it.
         // TODO(skia/11330): we'll need to fix this when switching over to nodes.
-        fVar = DSLWriter::SymbolTable()->takeOwnershipOfIRNode(std::make_unique<SkSL::Variable>(
+        fConstVar = DSLWriter::SymbolTable()->takeOwnershipOfIRNode(
+                            std::make_unique<SkSL::Variable>(
                                   /*offset=*/-1,
                                   DSLWriter::IRGenerator().fModifiers->addToPool(SkSL::Modifiers()),
                                   fName,
@@ -42,7 +43,7 @@
 #endif
     const SkSL::Symbol* result = (*DSLWriter::SymbolTable())[fName];
     SkASSERTF(result, "could not find '%s' in symbol table", fName);
-    fVar = &result->as<SkSL::Variable>();
+    fConstVar = &result->as<SkSL::Variable>();
 }
 
 DSLVar::DSLVar(DSLType type, const char* name)
@@ -84,15 +85,13 @@
 #endif // SK_SUPPORT_GPU && !defined(SKSL_STANDALONE)
     DSLWriter::IRGenerator().checkVarDeclaration(/*offset=*/-1, modifiers.fModifiers,
                                                  &type.skslType(), storage);
-    fDeclaration = DSLWriter::IRGenerator().convertVarDeclaration(/*offset=*/-1,
-                                                                  modifiers.fModifiers,
-                                                                  &type.skslType(),
-                                                                  fName,
-                                                                  /*isArray=*/false,
-                                                                  /*arraySize=*/nullptr,
-                                                                  /*value=*/nullptr,
-                                                                  storage);
-    fVar = &fDeclaration->as<SkSL::VarDeclaration>().var();
+    fVar = DSLWriter::IRGenerator().convertVar(/*offset=*/-1,
+                                               modifiers.fModifiers,
+                                               &type.skslType(),
+                                               fName,
+                                               /*isArray=*/false,
+                                               /*arraySize=*/nullptr,
+                                               storage);
 }
 
 #if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
diff --git a/src/sksl/dsl/DSLVar.h b/src/sksl/dsl/DSLVar.h
index ed7bbf5..7587fb2 100644
--- a/src/sksl/dsl/DSLVar.h
+++ b/src/sksl/dsl/DSLVar.h
@@ -102,7 +102,7 @@
     DSLVar(const char* name);
 
     const SkSL::Variable* var() const {
-        return fVar;
+        return fVar ? fVar.get() : fConstVar;
     }
 
     const char* name() const {
@@ -116,7 +116,8 @@
 #endif
 
     std::unique_ptr<SkSL::Statement> fDeclaration;
-    const SkSL::Variable* fVar = nullptr;
+    std::unique_ptr<SkSL::Variable> fVar = nullptr;
+    const SkSL::Variable* fConstVar = nullptr;
     const char* fName;
 
     friend DSLVar sk_SampleCoord();
diff --git a/src/sksl/ir/SkSLVarDeclarations.h b/src/sksl/ir/SkSLVarDeclarations.h
index 609edb8..8735ec9 100644
--- a/src/sksl/ir/SkSLVarDeclarations.h
+++ b/src/sksl/ir/SkSLVarDeclarations.h
@@ -91,7 +91,7 @@
     int fArraySize;  // zero means "not an array", Type::kUnsizedArray means var[]
     std::unique_ptr<Expression> fValue;
 
-    friend class dsl::DSLCore;
+    friend class IRGenerator;
 
     using INHERITED = Statement;
 };
diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h
index b6402f0..b660035 100644
--- a/src/sksl/ir/SkSLVariable.h
+++ b/src/sksl/ir/SkSLVariable.h
@@ -19,6 +19,10 @@
 class Expression;
 class VarDeclaration;
 
+namespace dsl {
+class DSLCore;
+} // namespace dsl
+
 enum class VariableStorage : int8_t {
     kGlobal,
     kInterfaceBlock,
@@ -75,6 +79,7 @@
 
     using INHERITED = Symbol;
 
+    friend class dsl::DSLCore;
     friend class VariableReference;
 };
 
diff --git a/tests/SkSLDSLTest.cpp b/tests/SkSLDSLTest.cpp
index eb29f86..80a1fca 100644
--- a/tests/SkSLDSLTest.cpp
+++ b/tests/SkSLDSLTest.cpp
@@ -1297,8 +1297,8 @@
     AutoDSLContext context(ctxInfo.directContext()->priv().getGpu());
 
     Var v1(kConst_Modifier, kInt, "v1");
-    Statement d1 = Declare(v1);
-    EXPECT_EQUAL(d1, "const int v1;");
+    Statement d1 = Declare(v1, 0);
+    EXPECT_EQUAL(d1, "const int v1 = 0;");
 
     // Most modifiers require an appropriate context to be legal. We can't yet give them that
     // context, so we can't as yet Declare() variables with these modifiers.