Add support for Private & Output to initializer decompose flag (#2537)

Fixes #2388
diff --git a/source/opt/decompose_initialized_variables_pass.cpp b/source/opt/decompose_initialized_variables_pass.cpp
index b2471ca..875bf7e 100644
--- a/source/opt/decompose_initialized_variables_pass.cpp
+++ b/source/opt/decompose_initialized_variables_pass.cpp
@@ -34,43 +34,78 @@
 
 Pass::Status DecomposeInitializedVariablesPass::Process() {
   auto* module = context()->module();
-  bool changed = false;
+  std::unordered_set<Instruction*> changed;
 
-  // TODO(zoddicus): Handle 'Output' variables
-  // TODO(zoddicus): Handle 'Private' variables
+  std::vector<std::tuple<uint32_t, uint32_t>> global_stores;
+  for (auto iter = module->types_values_begin();
+       iter != module->types_values_end(); ++iter) {
+    Instruction* inst = &(*iter);
+    if (!HasInitializer(inst)) continue;
 
-  // Handle 'Function' variables
+    auto var_id = inst->result_id();
+    auto val_id = inst->GetOperand(3).words[0];
+    global_stores.push_back(std::make_tuple(var_id, val_id));
+    iter->RemoveOperand(3);
+    changed.insert(&*iter);
+  }
+
+  std::unordered_set<uint32_t> entry_ids;
+  for (auto entry = module->entry_points().begin();
+       entry != module->entry_points().end(); ++entry) {
+    entry_ids.insert(entry->GetSingleWordInOperand(1));
+  }
+
   for (auto func = module->begin(); func != module->end(); ++func) {
-    auto block = func->entry().get();
-    std::vector<Instruction*> new_stores;
-
-    auto last_var = block->begin();
-    for (auto iter = block->begin();
-         iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
-      last_var = iter;
+    std::vector<Instruction*> function_stores;
+    auto first_block = func->entry().get();
+    inst_iterator insert_point = first_block->begin();
+    for (auto iter = first_block->begin();
+         iter != first_block->end() && iter->opcode() == SpvOpVariable;
+         ++iter) {
+      // For valid SPIRV-V, there is guaranteed to be at least one instruction
+      // after the OpVariable instructions.
+      insert_point = (*iter).NextNode();
       Instruction* inst = &(*iter);
       if (!HasInitializer(inst)) continue;
 
-      changed = true;
       auto var_id = inst->result_id();
       auto val_id = inst->GetOperand(3).words[0];
       Instruction* store_inst = new Instruction(
           context(), SpvOpStore, 0, 0,
           {{SPV_OPERAND_TYPE_ID, {var_id}}, {SPV_OPERAND_TYPE_ID, {val_id}}});
-      new_stores.push_back(store_inst);
+      function_stores.push_back(store_inst);
       iter->RemoveOperand(3);
-      get_def_use_mgr()->UpdateDefUse(&*iter);
+      changed.insert(&*iter);
     }
 
-    for (auto store = new_stores.begin(); store != new_stores.end(); ++store) {
-      context()->AnalyzeDefUse(*store);
-      context()->set_instr_block(*store, block);
-      (*store)->InsertAfter(&*last_var);
-      last_var = *store;
+    if (entry_ids.find(func->result_id()) != entry_ids.end()) {
+      for (auto store_ids : global_stores) {
+        uint32_t var_id;
+        uint32_t val_id;
+        std::tie(var_id, val_id) = store_ids;
+        auto* store_inst = new Instruction(
+            context(), SpvOpStore, 0, 0,
+            {{SPV_OPERAND_TYPE_ID, {var_id}}, {SPV_OPERAND_TYPE_ID, {val_id}}});
+        context()->set_instr_block(store_inst, &*first_block);
+        first_block->AddInstruction(std::unique_ptr<Instruction>(store_inst));
+        store_inst->InsertBefore(&*insert_point);
+        changed.insert(store_inst);
+      }
+    }
+
+    for (auto store = function_stores.begin(); store != function_stores.end();
+         ++store) {
+      context()->set_instr_block(*store, first_block);
+      (*store)->InsertBefore(&*insert_point);
+      changed.insert(*store);
     }
   }
 
-  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+  auto* def_use_mgr = get_def_use_mgr();
+  for (auto* inst : changed) def_use_mgr->UpdateDefUse(inst);
+
+  return !changed.empty() ? Pass::Status::SuccessWithChange
+                          : Pass::Status::SuccessWithoutChange;
 }
 
 }  // namespace opt
diff --git a/test/opt/decompose_initialized_variables_test.cpp b/test/opt/decompose_initialized_variables_test.cpp
index 188e8b2..cdebb3f 100644
--- a/test/opt/decompose_initialized_variables_test.cpp
+++ b/test/opt/decompose_initialized_variables_test.cpp
@@ -23,63 +23,230 @@
 
 using DecomposeInitializedVariablesTest = PassTest<::testing::Test>;
 
-void operator+=(std::vector<const char*>& lhs,
-                const std::vector<const char*> rhs) {
-  for (auto elem : rhs) lhs.push_back(elem);
-}
+std::string single_entry_header = R"(OpCapability Shader
+OpCapability VulkanMemoryModelKHR
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint Vertex %1 "shader"
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%4 = OpConstantNull %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+)";
 
-std::vector<const char*> header = {
-    "OpCapability Shader",
-    "OpCapability VulkanMemoryModelKHR",
-    "OpExtension \"SPV_KHR_vulkan_memory_model\"",
-    "OpMemoryModel Logical VulkanKHR",
-    "OpEntryPoint Vertex %1 \"shader\"",
-    "%uint = OpTypeInt 32 0",
-    "%uint_1 = OpConstant %uint 1",
-    "%4 = OpConstantNull %uint",
-    "%void = OpTypeVoid",
-    "%6 = OpTypeFunction %void"};
-
-std::string GetFunctionTest(std::vector<const char*> body) {
-  auto result = header;
-  result += {"%_ptr_Function_uint = OpTypePointer Function %uint",
-             "%1 = OpFunction %void None %6", "%8 = OpLabel"};
-  result += body;
-  result += {"OpReturn", "OpFunctionEnd"};
-  return JoinAllInsts(result);
+std::string GetFunctionTest(std::string body) {
+  auto result = single_entry_header;
+  result += "%_ptr_Function_uint = OpTypePointer Function %uint\n";
+  result += "%1 = OpFunction %void None %6\n";
+  result += "%8 = OpLabel\n";
+  result += body + "\n";
+  result += "OpReturn\n";
+  result += "OpFunctionEnd\n";
+  return result;
 }
 
 TEST_F(DecomposeInitializedVariablesTest, FunctionChanged) {
-  std::string input =
-      GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function %uint_1"});
-  std::string expected = GetFunctionTest(
-      {"%9 = OpVariable %_ptr_Function_uint Function", "OpStore %9 %uint_1"});
+  std::string input = "%9 = OpVariable %_ptr_Function_uint Function %uint_1";
+  std::string expected = R"(%9 = OpVariable %_ptr_Function_uint Function
+OpStore %9 %uint_1)";
 
   SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      GetFunctionTest(input), GetFunctionTest(expected),
+      /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, FunctionUnchanged) {
+  std::string input = "%9 = OpVariable %_ptr_Function_uint Function";
+
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      GetFunctionTest(input), GetFunctionTest(input), /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, FunctionMultipleVariables) {
+  std::string input = R"(%9 = OpVariable %_ptr_Function_uint Function %uint_1
+%10 = OpVariable %_ptr_Function_uint Function %4)";
+  std::string expected = R"(%9 = OpVariable %_ptr_Function_uint Function
+%10 = OpVariable %_ptr_Function_uint Function
+OpStore %9 %uint_1
+OpStore %10 %4)";
+
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      GetFunctionTest(input), GetFunctionTest(expected),
+      /* skip_nop = */ false);
+}
+
+std::string GetGlobalTest(std::string storage_class, bool initialized,
+                          bool decomposed) {
+  auto result = single_entry_header;
+
+  result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+            storage_class + " %uint\n";
+  if (initialized) {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + " %4\n";
+  } else {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + "\n";
+  }
+  result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+  if (decomposed) result += "OpStore %8 %4\n";
+  result += R"(OpReturn
+OpFunctionEnd
+)";
+  return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateChanged) {
+  std::string input = GetGlobalTest("Private", true, false);
+  std::string expected = GetGlobalTest("Private", false, true);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
       input, expected, /* skip_nop = */ false);
 }
 
-TEST_F(DecomposeInitializedVariablesTest, FunctionUnchanged) {
-  std::string input =
-      GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function"});
-
+TEST_F(DecomposeInitializedVariablesTest, PrivateUnchanged) {
+  std::string input = GetGlobalTest("Private", false, false);
   SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
       input, input, /* skip_nop = */ false);
 }
 
-TEST_F(DecomposeInitializedVariablesTest, FunctionMultipleVariables) {
-  std::string input =
-      GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function %uint_1",
-                       "%10 = OpVariable %_ptr_Function_uint Function %4"});
-  std::string expected =
-      GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function",
-                       "%10 = OpVariable %_ptr_Function_uint Function",
-                       "OpStore %9 %uint_1", "OpStore %10 %4"});
-
+TEST_F(DecomposeInitializedVariablesTest, OutputChanged) {
+  std::string input = GetGlobalTest("Output", true, false);
+  std::string expected = GetGlobalTest("Output", false, true);
   SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
       input, expected, /* skip_nop = */ false);
 }
 
+TEST_F(DecomposeInitializedVariablesTest, OutputUnchanged) {
+  std::string input = GetGlobalTest("Output", false, false);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, input, /* skip_nop = */ false);
+}
+
+std::string multiple_entry_header = R"(OpCapability Shader
+OpCapability VulkanMemoryModelKHR
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint Vertex %1 "vertex"
+OpEntryPoint Fragment %2 "fragment"
+%uint = OpTypeInt 32 0
+%4 = OpConstantNull %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+)";
+
+std::string GetGlobalMultipleEntryTest(std::string storage_class,
+                                       bool initialized, bool decomposed) {
+  auto result = multiple_entry_header;
+  result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+            storage_class + " %uint\n";
+  if (initialized) {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + " %4\n";
+  } else {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + "\n";
+  }
+  result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+  if (decomposed) result += "OpStore %8 %4\n";
+  result += R"(OpReturn
+OpFunctionEnd
+%2 = OpFunction %void None %10
+%10 = OpLabel
+)";
+  if (decomposed) result += "OpStore %8 %4\n";
+  result += R"(OpReturn
+OpFunctionEnd
+)";
+
+  return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateMultipleEntryChanged) {
+  std::string input = GetGlobalMultipleEntryTest("Private", true, false);
+  std::string expected = GetGlobalMultipleEntryTest("Private", false, true);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateMultipleEntryUnchanged) {
+  std::string input = GetGlobalMultipleEntryTest("Private", false, false);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, input, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputMultipleEntryChanged) {
+  std::string input = GetGlobalMultipleEntryTest("Output", true, false);
+  std::string expected = GetGlobalMultipleEntryTest("Output", false, true);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputMultipleEntryUnchanged) {
+  std::string input = GetGlobalMultipleEntryTest("Output", false, false);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, input, /* skip_nop = */ false);
+}
+
+std::string GetGlobalWithNonEntryPointTest(std::string storage_class,
+                                           bool initialized, bool decomposed) {
+  auto result = single_entry_header;
+  result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+            storage_class + " %uint\n";
+  if (initialized) {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + " %4\n";
+  } else {
+    result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+              storage_class + "\n";
+  }
+  result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+  if (decomposed) result += "OpStore %8 %4\n";
+  result += R"(OpReturn
+OpFunctionEnd
+%10 = OpFunction %void None %11
+%11 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateWithNonEntryPointChanged) {
+  std::string input = GetGlobalWithNonEntryPointTest("Private", true, false);
+  std::string expected = GetGlobalWithNonEntryPointTest("Private", false, true);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateWithNonEntryPointUnchanged) {
+  std::string input = GetGlobalWithNonEntryPointTest("Private", false, false);
+  //  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, input, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputWithNonEntryPointChanged) {
+  std::string input = GetGlobalWithNonEntryPointTest("Output", true, false);
+  std::string expected = GetGlobalWithNonEntryPointTest("Output", false, true);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputWithNonEntryPointUnchanged) {
+  std::string input = GetGlobalWithNonEntryPointTest("Output", false, false);
+  //  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+      input, input, /* skip_nop = */ false);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools