Add pass to generate needed initializers for WebGPU (#2481)

Fixes #2387
diff --git a/Android.mk b/Android.mk
index fe7a93d..4479ca8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -107,6 +107,7 @@
 		source/opt/fold_spec_constant_op_and_composite_pass.cpp \
 		source/opt/freeze_spec_constant_value_pass.cpp \
 		source/opt/function.cpp \
+		source/opt/generate_webgpu_initializers_pass.cpp \
 		source/opt/if_conversion.cpp \
 		source/opt/inline_pass.cpp \
 		source/opt/inline_exhaustive_pass.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index 24f5006..96c5037 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -514,6 +514,8 @@
     "source/opt/freeze_spec_constant_value_pass.h",
     "source/opt/function.cpp",
     "source/opt/function.h",
+    "source/opt/generate_webgpu_initializers_pass.cpp",
+    "source/opt/generate_webgpu_initializers_pass.hpp",
     "source/opt/if_conversion.cpp",
     "source/opt/if_conversion.h",
     "source/opt/inline_exhaustive_pass.cpp",
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index adfd492..4e92bb0 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -747,6 +747,12 @@
 // where an instruction is moved into a more deeply nested construct.
 Optimizer::PassToken CreateCodeSinkingPass();
 
+// Create a pass to adds initializers for OpVariable calls that require them
+// in WebGPU. Currently this pass naively initializes variables that are
+// missing an initializer with a null value. In the future it may initialize
+// variables to the first value stored in them, if that is a constant.
+Optimizer::PassToken CreateGenerateWebGPUInitializersPass();
+
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 53901a4..9eff861 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -46,6 +46,7 @@
   fold_spec_constant_op_and_composite_pass.h
   freeze_spec_constant_value_pass.h
   function.h
+  generate_webgpu_initializers_pass.h
   if_conversion.h
   inline_exhaustive_pass.h
   inline_opaque_pass.h
@@ -143,6 +144,7 @@
   fold_spec_constant_op_and_composite_pass.cpp
   freeze_spec_constant_value_pass.cpp
   function.cpp
+  generate_webgpu_initializers_pass.cpp
   if_conversion.cpp
   inline_exhaustive_pass.cpp
   inline_opaque_pass.cpp
diff --git a/source/opt/generate_webgpu_initializers_pass.cpp b/source/opt/generate_webgpu_initializers_pass.cpp
new file mode 100644
index 0000000..a8e00b6
--- /dev/null
+++ b/source/opt/generate_webgpu_initializers_pass.cpp
@@ -0,0 +1,112 @@
+// Copyright (c) 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "source/opt/generate_webgpu_initializers_pass.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+
+using inst_iterator = InstructionList::iterator;
+
+namespace {
+
+bool NeedsWebGPUInitializer(Instruction* inst) {
+  if (inst->opcode() != SpvOpVariable) return false;
+
+  auto storage_class = inst->GetSingleWordOperand(2);
+  if (storage_class != SpvStorageClassOutput &&
+      storage_class != SpvStorageClassPrivate &&
+      storage_class != SpvStorageClassFunction) {
+    return false;
+  }
+
+  if (inst->NumOperands() > 3) return false;
+
+  return true;
+}
+
+}  // namespace
+
+Pass::Status GenerateWebGPUInitializersPass::Process() {
+  auto* module = context()->module();
+  bool changed = false;
+
+  // Handle global/module scoped variables
+  for (auto iter = module->types_values_begin();
+       iter != module->types_values_end(); ++iter) {
+    Instruction* inst = &(*iter);
+
+    if (inst->opcode() == SpvOpConstantNull) {
+      null_constant_type_map_[inst->type_id()] = inst;
+      seen_null_constants_.insert(inst);
+      continue;
+    }
+
+    if (!NeedsWebGPUInitializer(inst)) continue;
+
+    changed = true;
+
+    auto* constant_inst = GetNullConstantForVariable(inst);
+    if (seen_null_constants_.find(constant_inst) ==
+        seen_null_constants_.end()) {
+      constant_inst->InsertBefore(inst);
+      null_constant_type_map_[inst->type_id()] = inst;
+      seen_null_constants_.insert(inst);
+    }
+    AddNullInitializerToVariable(constant_inst, inst);
+  }
+
+  // Handle local/function scoped variables
+  for (auto func = module->begin(); func != module->end(); ++func) {
+    auto block = func->entry().get();
+    for (auto iter = block->begin();
+         iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
+      Instruction* inst = &(*iter);
+      if (!NeedsWebGPUInitializer(inst)) continue;
+
+      changed = true;
+      auto* constant_inst = GetNullConstantForVariable(inst);
+      AddNullInitializerToVariable(constant_inst, inst);
+    }
+  }
+
+  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+Instruction* GenerateWebGPUInitializersPass::GetNullConstantForVariable(
+    Instruction* variable_inst) {
+  auto constant_mgr = context()->get_constant_mgr();
+  auto* def_use_mgr = get_def_use_mgr();
+
+  auto* ptr_inst = def_use_mgr->GetDef(variable_inst->type_id());
+  auto type_id = ptr_inst->GetInOperand(1).words[0];
+  if (null_constant_type_map_.find(type_id) == null_constant_type_map_.end()) {
+    auto* constant_type = context()->get_type_mgr()->GetType(type_id);
+    auto* constant = constant_mgr->GetConstant(constant_type, {});
+    return constant_mgr->GetDefiningInstruction(constant, type_id);
+  } else {
+    return null_constant_type_map_[type_id];
+  }
+}
+
+void GenerateWebGPUInitializersPass::AddNullInitializerToVariable(
+    Instruction* constant_inst, Instruction* variable_inst) {
+  auto constant_id = constant_inst->result_id();
+  variable_inst->AddOperand(Operand(SPV_OPERAND_TYPE_ID, {constant_id}));
+  get_def_use_mgr()->AnalyzeInstUse(variable_inst);
+}
+
+}  // namespace opt
+}  // namespace spvtools
diff --git a/source/opt/generate_webgpu_initializers_pass.h b/source/opt/generate_webgpu_initializers_pass.h
new file mode 100644
index 0000000..9aa970d
--- /dev/null
+++ b/source/opt/generate_webgpu_initializers_pass.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_GENERATE_WEBGPU_INITIALIZERS_PASS_H_
+#define SOURCE_OPT_GENERATE_WEBGPU_INITIALIZERS_PASS_H_
+
+#include "source/opt/ir_context.h"
+#include "source/opt/module.h"
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// Adds initializers to variables with storage classes Output, Private, and
+// Function if they are missing. In the WebGPU environment these storage classes
+// require that the variables are initialized. Currently they are initialized to
+// NULL, though in the future some of them may be initialized to the first value
+// that is stored in them, if that was a constant.
+class GenerateWebGPUInitializersPass : public Pass {
+ public:
+  const char* name() const override { return "generate-webgpu-initializers"; }
+  Status Process() override;
+
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisInstrToBlockMapping |
+           IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
+           IRContext::kAnalysisCFG | IRContext::kAnalysisDominatorAnalysis |
+           IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap |
+           IRContext::kAnalysisScalarEvolution |
+           IRContext::kAnalysisRegisterPressure |
+           IRContext::kAnalysisValueNumberTable |
+           IRContext::kAnalysisStructuredCFG |
+           IRContext::kAnalysisBuiltinVarId |
+           IRContext::kAnalysisIdToFuncMapping | IRContext::kAnalysisTypes |
+           IRContext::kAnalysisDefUse | IRContext::kAnalysisConstants;
+  }
+
+ private:
+  using NullConstantTypeMap = std::unordered_map<uint32_t, Instruction*>;
+  NullConstantTypeMap null_constant_type_map_;
+  std::unordered_set<Instruction*> seen_null_constants_;
+
+  Instruction* GetNullConstantForVariable(Instruction* variable_inst);
+  void AddNullInitializerToVariable(Instruction* constant_inst,
+                                    Instruction* variable_inst);
+};
+
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // SOURCE_OPT_GENERATE_WEBGPU_INITIALIZERS_PASS_H_
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index 887a9c2..c6e48e6 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -220,6 +220,7 @@
 Optimizer& Optimizer::RegisterWebGPUPasses() {
   return RegisterPass(CreateStripDebugInfoPass())
       .RegisterPass(CreateStripAtomicCounterMemoryPass())
+      .RegisterPass(CreateGenerateWebGPUInitializersPass())
       .RegisterPass(CreateEliminateDeadConstantPass())
       .RegisterPass(CreateFlattenDecorationPass())
       .RegisterPass(CreateAggressiveDCEPass())
@@ -456,6 +457,8 @@
     RegisterSizePasses();
   } else if (pass_name == "legalize-hlsl") {
     RegisterLegalizationPasses();
+  } else if (pass_name == "generate-webgpu-initializers") {
+    RegisterPass(CreateGenerateWebGPUInitializersPass());
   } else {
     Errorf(consumer(), nullptr, {},
            "Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -826,4 +829,9 @@
       MakeUnique<opt::CodeSinkingPass>());
 }
 
+Optimizer::PassToken CreateGenerateWebGPUInitializersPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::GenerateWebGPUInitializersPass>());
+}
+
 }  // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 2b97793..d80f4ac 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -35,6 +35,7 @@
 #include "source/opt/flatten_decoration_pass.h"
 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
 #include "source/opt/freeze_spec_constant_value_pass.h"
+#include "source/opt/generate_webgpu_initializers_pass.h"
 #include "source/opt/if_conversion.h"
 #include "source/opt/inline_exhaustive_pass.h"
 #include "source/opt/inline_opaque_pass.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 398baa4..adc78be 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -41,6 +41,7 @@
        fold_test.cpp
        freeze_spec_const_test.cpp
        function_test.cpp
+       generate_webgpu_initializers_test.cpp
        if_conversion_test.cpp
        inline_opaque_test.cpp
        inline_test.cpp
diff --git a/test/opt/generate_webgpu_initializers_test.cpp b/test/opt/generate_webgpu_initializers_test.cpp
new file mode 100644
index 0000000..3e4be55
--- /dev/null
+++ b/test/opt/generate_webgpu_initializers_test.cpp
@@ -0,0 +1,347 @@
+// Copyright (c) 2019 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+typedef std::tuple<std::string, bool> GenerateWebGPUInitializersParam;
+
+using GlobalVariableTest =
+    PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
+using LocalVariableTest =
+    PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
+
+using GenerateWebGPUInitializersTest = PassTest<::testing::Test>;
+
+void operator+=(std::vector<const char*>& lhs, const char* rhs) {
+  lhs.push_back(rhs);
+}
+
+void operator+=(std::vector<const char*>& lhs,
+                const std::vector<const char*>& rhs) {
+  lhs.reserve(lhs.size() + rhs.size());
+  for (auto* c : rhs) lhs.push_back(c);
+}
+
+std::string GetGlobalVariableTestString(std::string ptr_str,
+                                        std::string var_str,
+                                        std::string const_str = "") {
+  std::vector<const char*> result = {
+      // clang-format off
+               "OpCapability Shader",
+               "OpCapability VulkanMemoryModelKHR",
+               "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+               "OpMemoryModel Logical VulkanKHR",
+               "OpEntryPoint Vertex %1 \"shader\"",
+       "%uint = OpTypeInt 32 0",
+                ptr_str.c_str()};
+  // clang-format on
+
+  if (!const_str.empty()) result += const_str.c_str();
+
+  result += {
+      // clang-format off
+                var_str.c_str(),
+     "%uint_0 = OpConstant %uint 0",
+       "%void = OpTypeVoid",
+          "%7 = OpTypeFunction %void",
+          "%1 = OpFunction %void None %7",
+          "%8 = OpLabel",
+               "OpStore %4 %uint_0",
+               "OpReturn",
+               "OpFunctionEnd"
+      // clang-format on
+  };
+  return JoinAllInsts(result);
+}
+
+std::string GetPointerString(std::string storage_type) {
+  std::string result = "%_ptr_";
+  result += storage_type + "_uint = OpTypePointer ";
+  result += storage_type + " %uint";
+  return result;
+}
+
+std::string GetGlobalVariableString(std::string storage_type,
+                                    bool initialized) {
+  std::string result = "%4 = OpVariable %_ptr_";
+  result += storage_type + "_uint ";
+  result += storage_type;
+  if (initialized) result += " %9";
+  return result;
+}
+
+std::string GetUninitializedGlobalVariableTestString(std::string storage_type) {
+  return GetGlobalVariableTestString(
+      GetPointerString(storage_type),
+      GetGlobalVariableString(storage_type, false));
+}
+
+std::string GetNullConstantString() { return "%9 = OpConstantNull %uint"; }
+
+std::string GetInitializedGlobalVariableTestString(std::string storage_type) {
+  return GetGlobalVariableTestString(
+      GetPointerString(storage_type),
+      GetGlobalVariableString(storage_type, true), GetNullConstantString());
+}
+
+TEST_P(GlobalVariableTest, Check) {
+  std::string storage_class = std::get<0>(GetParam());
+  bool changed = std::get<1>(GetParam());
+  std::string input = GetUninitializedGlobalVariableTestString(storage_class);
+  std::string expected =
+      changed ? GetInitializedGlobalVariableTestString(storage_class) : input;
+
+  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input, expected,
+                                                        /* skip_nop = */ false);
+}
+
+// clang-format off
+INSTANTIATE_TEST_SUITE_P(
+    GenerateWebGPUInitializers, GlobalVariableTest,
+    ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
+       std::make_tuple("Private", true),
+       std::make_tuple("Output", true),
+       std::make_tuple("Function", true),
+       std::make_tuple("UniformConstant", false),
+       std::make_tuple("Input", false),
+       std::make_tuple("Uniform", false),
+       std::make_tuple("Workgroup", false)
+    })));
+// clang-format on
+
+std::string GetLocalVariableTestString(std::string ptr_str, std::string var_str,
+                                       std::string const_str = "") {
+  std::vector<const char*> result = {
+      // clang-format off
+               "OpCapability Shader",
+               "OpCapability VulkanMemoryModelKHR",
+               "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+               "OpMemoryModel Logical VulkanKHR",
+               "OpEntryPoint Vertex %1 \"shader\"",
+       "%uint = OpTypeInt 32 0",
+                ptr_str.c_str(),
+     "%uint_0 = OpConstant %uint 0",
+       "%void = OpTypeVoid",
+          "%6 = OpTypeFunction %void"};
+  // clang-format on
+
+  if (!const_str.empty()) result += const_str.c_str();
+
+  result += {
+      // clang-format off
+          "%1 = OpFunction %void None %6",
+          "%7 = OpLabel",
+                var_str.c_str(),
+               "OpStore %8 %uint_0"
+      // clang-format on
+  };
+  return JoinAllInsts(result);
+}
+
+std::string GetLocalVariableString(std::string storage_type, bool initialized) {
+  std::string result = "%8 = OpVariable %_ptr_";
+  result += storage_type + "_uint ";
+  result += storage_type;
+  if (initialized) result += " %9";
+  return result;
+}
+
+std::string GetUninitializedLocalVariableTestString(std::string storage_type) {
+  return GetLocalVariableTestString(
+      GetPointerString(storage_type),
+      GetLocalVariableString(storage_type, false));
+}
+
+std::string GetInitializedLocalVariableTestString(std::string storage_type) {
+  return GetLocalVariableTestString(GetPointerString(storage_type),
+                                    GetLocalVariableString(storage_type, true),
+                                    GetNullConstantString());
+}
+
+TEST_P(LocalVariableTest, Check) {
+  std::string storage_class = std::get<0>(GetParam());
+  bool changed = std::get<1>(GetParam());
+
+  std::string input = GetUninitializedLocalVariableTestString(storage_class);
+  std::string expected =
+      changed ? GetInitializedLocalVariableTestString(storage_class) : input;
+
+  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input, expected,
+                                                        /* skip_nop = */ false);
+}
+
+// clang-format off
+INSTANTIATE_TEST_SUITE_P(
+    GenerateWebGPUInitializers, LocalVariableTest,
+    ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
+       std::make_tuple("Private", true),
+       std::make_tuple("Output", true),
+       std::make_tuple("Function", true),
+       std::make_tuple("UniformConstant", false),
+       std::make_tuple("Input", false),
+       std::make_tuple("Uniform", false),
+       std::make_tuple("Workgroup", false)
+    })));
+// clang-format on
+
+TEST_F(GenerateWebGPUInitializersTest, AlreadyInitializedUnchanged) {
+  std::vector<const char*> spirv = {
+      // clang-format off
+                       "OpCapability Shader",
+                       "OpCapability VulkanMemoryModelKHR",
+                       "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+                       "OpMemoryModel Logical VulkanKHR",
+                       "OpEntryPoint Vertex %1 \"shader\"",
+               "%uint = OpTypeInt 32 0",
+  "%_ptr_Private_uint = OpTypePointer Private %uint",
+             "%uint_0 = OpConstant %uint 0",
+                  "%5 = OpVariable %_ptr_Private_uint Private %uint_0",
+               "%void = OpTypeVoid",
+                  "%7 = OpTypeFunction %void",
+                  "%1 = OpFunction %void None %7",
+                  "%8 = OpLabel",
+                       "OpReturn",
+                       "OpFunctionEnd"
+      // clang-format on
+  };
+  std::string str = JoinAllInsts(spirv);
+
+  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(str, str,
+                                                        /* skip_nop = */ false);
+}
+
+TEST_F(GenerateWebGPUInitializersTest, AmbigiousArrays) {
+  std::vector<const char*> input_spirv = {
+      // clang-format off
+                                   "OpCapability Shader",
+                                   "OpCapability VulkanMemoryModelKHR",
+                                   "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+                                   "OpMemoryModel Logical VulkanKHR",
+                                   "OpEntryPoint Vertex %1 \"shader\"",
+                           "%uint = OpTypeInt 32 0",
+                         "%uint_2 = OpConstant %uint 2",
+               "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
+             "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
+  "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
+"%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
+                              "%8 = OpConstantNull %_arr_uint_uint_2_0",
+                              "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private",
+                             "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
+                           "%void = OpTypeVoid",
+                             "%12 = OpTypeFunction %void",
+                              "%1 = OpFunction %void None %12",
+                             "%13 = OpLabel",
+                                   "OpReturn",
+                                   "OpFunctionEnd"
+      // clang-format on
+  };
+  std::string input_str = JoinAllInsts(input_spirv);
+
+  std::vector<const char*> expected_spirv = {
+      // clang-format off
+                                   "OpCapability Shader",
+                                   "OpCapability VulkanMemoryModelKHR",
+                                   "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+                                   "OpMemoryModel Logical VulkanKHR",
+                                   "OpEntryPoint Vertex %1 \"shader\"",
+                           "%uint = OpTypeInt 32 0",
+                         "%uint_2 = OpConstant %uint 2",
+               "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
+             "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
+  "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
+"%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
+                              "%8 = OpConstantNull %_arr_uint_uint_2_0",
+                             "%14 = OpConstantNull %_arr_uint_uint_2",
+                              "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private %14",
+                             "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
+                           "%void = OpTypeVoid",
+                             "%12 = OpTypeFunction %void",
+                              "%1 = OpFunction %void None %12",
+                             "%13 = OpLabel",
+                                   "OpReturn",
+                                   "OpFunctionEnd"
+      // clang-format on
+  };
+  std::string expected_str = JoinAllInsts(expected_spirv);
+
+  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input_str, expected_str,
+                                                        /* skip_nop = */ false);
+}
+
+TEST_F(GenerateWebGPUInitializersTest, AmbigiousStructs) {
+  std::vector<const char*> input_spirv = {
+      // clang-format off
+                          "OpCapability Shader",
+                          "OpCapability VulkanMemoryModelKHR",
+                          "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+                          "OpMemoryModel Logical VulkanKHR",
+                          "OpEntryPoint Vertex %1 \"shader\"",
+                  "%uint = OpTypeInt 32 0",
+             "%_struct_3 = OpTypeStruct %uint",
+             "%_struct_4 = OpTypeStruct %uint",
+"%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
+"%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
+                     "%7 = OpConstantNull %_struct_3",
+                     "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
+                     "%9 = OpVariable %_ptr_Private__struct_4 Private",
+                  "%void = OpTypeVoid",
+                    "%11 = OpTypeFunction %void",
+                     "%1 = OpFunction %void None %11",
+                    "%12 = OpLabel",
+                          "OpReturn",
+                          "OpFunctionEnd"
+      // clang-format on
+  };
+  std::string input_str = JoinAllInsts(input_spirv);
+
+  std::vector<const char*> expected_spirv = {
+      // clang-format off
+                          "OpCapability Shader",
+                          "OpCapability VulkanMemoryModelKHR",
+                          "OpExtension \"SPV_KHR_vulkan_memory_model\"",
+                          "OpMemoryModel Logical VulkanKHR",
+                          "OpEntryPoint Vertex %1 \"shader\"",
+                  "%uint = OpTypeInt 32 0",
+             "%_struct_3 = OpTypeStruct %uint",
+             "%_struct_4 = OpTypeStruct %uint",
+"%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
+"%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
+                     "%7 = OpConstantNull %_struct_3",
+                     "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
+                    "%13 = OpConstantNull %_struct_4",
+                     "%9 = OpVariable %_ptr_Private__struct_4 Private %13",
+                  "%void = OpTypeVoid",
+                    "%11 = OpTypeFunction %void",
+                     "%1 = OpFunction %void None %11",
+                    "%12 = OpLabel",
+                          "OpReturn",
+                          "OpFunctionEnd"
+      // clang-format on
+  };
+  std::string expected_str = JoinAllInsts(expected_spirv);
+
+  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input_str, expected_str,
+                                                        /* skip_nop = */ false);
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
diff --git a/test/opt/optimizer_test.cpp b/test/opt/optimizer_test.cpp
index 77d2d1a..513aa16 100644
--- a/test/opt/optimizer_test.cpp
+++ b/test/opt/optimizer_test.cpp
@@ -236,7 +236,8 @@
                                               "eliminate-dead-const",
                                               "flatten-decorations",
                                               "strip-debug",
-                                              "strip-atomic-counter-memory"};
+                                              "strip-atomic-counter-memory",
+                                              "generate-webgpu-initializers"};
   std::sort(registered_passes.begin(), registered_passes.end());
   std::sort(expected_passes.begin(), expected_passes.end());
 
@@ -436,7 +437,45 @@
          "OpReturn\n"
          "OpFunctionEnd\n",
          // pass
-         "strip-atomic-counter-memory"}}));
+         "strip-atomic-counter-memory"},
+        // Generate WebGPU Initializers
+        {// input
+         "OpCapability Shader\n"
+         "OpCapability VulkanMemoryModelKHR\n"
+         "OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
+         "OpMemoryModel Logical VulkanKHR\n"
+         "OpEntryPoint Vertex %func \"shader\"\n"
+         "%u32 = OpTypeInt 32 0\n"
+         "%u32_ptr = OpTypePointer Private %u32\n"
+         "%u32_var = OpVariable %u32_ptr Private\n"
+         "%u32_0 = OpConstant %u32 0\n"
+         "%void = OpTypeVoid\n"
+         "%void_f = OpTypeFunction %void\n"
+         "%func = OpFunction %void None %void_f\n"
+         "%label = OpLabel\n"
+         "OpStore %u32_var %u32_0\n"
+         "OpReturn\n"
+         "OpFunctionEnd\n",
+         // expected
+         "OpCapability Shader\n"
+         "OpCapability VulkanMemoryModelKHR\n"
+         "OpExtension \"SPV_KHR_vulkan_memory_model\"\n"
+         "OpMemoryModel Logical VulkanKHR\n"
+         "OpEntryPoint Vertex %1 \"shader\"\n"
+         "%uint = OpTypeInt 32 0\n"
+         "%_ptr_Private_uint = OpTypePointer Private %uint\n"
+         "%9 = OpConstantNull %uint\n"
+         "%4 = OpVariable %_ptr_Private_uint Private %9\n"
+         "%uint_0 = OpConstant %uint 0\n"
+         "%void = OpTypeVoid\n"
+         "%7 = OpTypeFunction %void\n"
+         "%1 = OpFunction %void None %7\n"
+         "%8 = OpLabel\n"
+         "OpStore %4 %uint_0\n"
+         "OpReturn\n"
+         "OpFunctionEnd\n",
+         // pass
+         "generate-webgpu-initializers"}}));
 
 }  // namespace
 }  // namespace opt
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index fb48110..473dd75 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -178,6 +178,9 @@
   --freeze-spec-const
                Freeze the values of specialization constants to their default
                values.
+  --generate-webgpu-initializers
+               Adds initial values to OpVariable instructions that are missing
+               them, due to their storage type requiring them for WebGPU.
   --if-conversion
                Convert if-then-else like assignments into OpSelect.
   --inline-entry-points-exhaustive