Add support for Private & Output to initializer decompose flag (#2537)
Fixes #2388
diff --git a/source/opt/decompose_initialized_variables_pass.cpp b/source/opt/decompose_initialized_variables_pass.cpp
index b2471ca..875bf7e 100644
--- a/source/opt/decompose_initialized_variables_pass.cpp
+++ b/source/opt/decompose_initialized_variables_pass.cpp
@@ -34,43 +34,78 @@
Pass::Status DecomposeInitializedVariablesPass::Process() {
auto* module = context()->module();
- bool changed = false;
+ std::unordered_set<Instruction*> changed;
- // TODO(zoddicus): Handle 'Output' variables
- // TODO(zoddicus): Handle 'Private' variables
+ std::vector<std::tuple<uint32_t, uint32_t>> global_stores;
+ for (auto iter = module->types_values_begin();
+ iter != module->types_values_end(); ++iter) {
+ Instruction* inst = &(*iter);
+ if (!HasInitializer(inst)) continue;
- // Handle 'Function' variables
+ auto var_id = inst->result_id();
+ auto val_id = inst->GetOperand(3).words[0];
+ global_stores.push_back(std::make_tuple(var_id, val_id));
+ iter->RemoveOperand(3);
+ changed.insert(&*iter);
+ }
+
+ std::unordered_set<uint32_t> entry_ids;
+ for (auto entry = module->entry_points().begin();
+ entry != module->entry_points().end(); ++entry) {
+ entry_ids.insert(entry->GetSingleWordInOperand(1));
+ }
+
for (auto func = module->begin(); func != module->end(); ++func) {
- auto block = func->entry().get();
- std::vector<Instruction*> new_stores;
-
- auto last_var = block->begin();
- for (auto iter = block->begin();
- iter != block->end() && iter->opcode() == SpvOpVariable; ++iter) {
- last_var = iter;
+ std::vector<Instruction*> function_stores;
+ auto first_block = func->entry().get();
+ inst_iterator insert_point = first_block->begin();
+ for (auto iter = first_block->begin();
+ iter != first_block->end() && iter->opcode() == SpvOpVariable;
+ ++iter) {
+ // For valid SPIRV-V, there is guaranteed to be at least one instruction
+ // after the OpVariable instructions.
+ insert_point = (*iter).NextNode();
Instruction* inst = &(*iter);
if (!HasInitializer(inst)) continue;
- changed = true;
auto var_id = inst->result_id();
auto val_id = inst->GetOperand(3).words[0];
Instruction* store_inst = new Instruction(
context(), SpvOpStore, 0, 0,
{{SPV_OPERAND_TYPE_ID, {var_id}}, {SPV_OPERAND_TYPE_ID, {val_id}}});
- new_stores.push_back(store_inst);
+ function_stores.push_back(store_inst);
iter->RemoveOperand(3);
- get_def_use_mgr()->UpdateDefUse(&*iter);
+ changed.insert(&*iter);
}
- for (auto store = new_stores.begin(); store != new_stores.end(); ++store) {
- context()->AnalyzeDefUse(*store);
- context()->set_instr_block(*store, block);
- (*store)->InsertAfter(&*last_var);
- last_var = *store;
+ if (entry_ids.find(func->result_id()) != entry_ids.end()) {
+ for (auto store_ids : global_stores) {
+ uint32_t var_id;
+ uint32_t val_id;
+ std::tie(var_id, val_id) = store_ids;
+ auto* store_inst = new Instruction(
+ context(), SpvOpStore, 0, 0,
+ {{SPV_OPERAND_TYPE_ID, {var_id}}, {SPV_OPERAND_TYPE_ID, {val_id}}});
+ context()->set_instr_block(store_inst, &*first_block);
+ first_block->AddInstruction(std::unique_ptr<Instruction>(store_inst));
+ store_inst->InsertBefore(&*insert_point);
+ changed.insert(store_inst);
+ }
+ }
+
+ for (auto store = function_stores.begin(); store != function_stores.end();
+ ++store) {
+ context()->set_instr_block(*store, first_block);
+ (*store)->InsertBefore(&*insert_point);
+ changed.insert(*store);
}
}
- return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+ auto* def_use_mgr = get_def_use_mgr();
+ for (auto* inst : changed) def_use_mgr->UpdateDefUse(inst);
+
+ return !changed.empty() ? Pass::Status::SuccessWithChange
+ : Pass::Status::SuccessWithoutChange;
}
} // namespace opt
diff --git a/test/opt/decompose_initialized_variables_test.cpp b/test/opt/decompose_initialized_variables_test.cpp
index 188e8b2..cdebb3f 100644
--- a/test/opt/decompose_initialized_variables_test.cpp
+++ b/test/opt/decompose_initialized_variables_test.cpp
@@ -23,63 +23,230 @@
using DecomposeInitializedVariablesTest = PassTest<::testing::Test>;
-void operator+=(std::vector<const char*>& lhs,
- const std::vector<const char*> rhs) {
- for (auto elem : rhs) lhs.push_back(elem);
-}
+std::string single_entry_header = R"(OpCapability Shader
+OpCapability VulkanMemoryModelKHR
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint Vertex %1 "shader"
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%4 = OpConstantNull %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+)";
-std::vector<const char*> header = {
- "OpCapability Shader",
- "OpCapability VulkanMemoryModelKHR",
- "OpExtension \"SPV_KHR_vulkan_memory_model\"",
- "OpMemoryModel Logical VulkanKHR",
- "OpEntryPoint Vertex %1 \"shader\"",
- "%uint = OpTypeInt 32 0",
- "%uint_1 = OpConstant %uint 1",
- "%4 = OpConstantNull %uint",
- "%void = OpTypeVoid",
- "%6 = OpTypeFunction %void"};
-
-std::string GetFunctionTest(std::vector<const char*> body) {
- auto result = header;
- result += {"%_ptr_Function_uint = OpTypePointer Function %uint",
- "%1 = OpFunction %void None %6", "%8 = OpLabel"};
- result += body;
- result += {"OpReturn", "OpFunctionEnd"};
- return JoinAllInsts(result);
+std::string GetFunctionTest(std::string body) {
+ auto result = single_entry_header;
+ result += "%_ptr_Function_uint = OpTypePointer Function %uint\n";
+ result += "%1 = OpFunction %void None %6\n";
+ result += "%8 = OpLabel\n";
+ result += body + "\n";
+ result += "OpReturn\n";
+ result += "OpFunctionEnd\n";
+ return result;
}
TEST_F(DecomposeInitializedVariablesTest, FunctionChanged) {
- std::string input =
- GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function %uint_1"});
- std::string expected = GetFunctionTest(
- {"%9 = OpVariable %_ptr_Function_uint Function", "OpStore %9 %uint_1"});
+ std::string input = "%9 = OpVariable %_ptr_Function_uint Function %uint_1";
+ std::string expected = R"(%9 = OpVariable %_ptr_Function_uint Function
+OpStore %9 %uint_1)";
SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ GetFunctionTest(input), GetFunctionTest(expected),
+ /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, FunctionUnchanged) {
+ std::string input = "%9 = OpVariable %_ptr_Function_uint Function";
+
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ GetFunctionTest(input), GetFunctionTest(input), /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, FunctionMultipleVariables) {
+ std::string input = R"(%9 = OpVariable %_ptr_Function_uint Function %uint_1
+%10 = OpVariable %_ptr_Function_uint Function %4)";
+ std::string expected = R"(%9 = OpVariable %_ptr_Function_uint Function
+%10 = OpVariable %_ptr_Function_uint Function
+OpStore %9 %uint_1
+OpStore %10 %4)";
+
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ GetFunctionTest(input), GetFunctionTest(expected),
+ /* skip_nop = */ false);
+}
+
+std::string GetGlobalTest(std::string storage_class, bool initialized,
+ bool decomposed) {
+ auto result = single_entry_header;
+
+ result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+ storage_class + " %uint\n";
+ if (initialized) {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + " %4\n";
+ } else {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + "\n";
+ }
+ result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+ if (decomposed) result += "OpStore %8 %4\n";
+ result += R"(OpReturn
+OpFunctionEnd
+)";
+ return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateChanged) {
+ std::string input = GetGlobalTest("Private", true, false);
+ std::string expected = GetGlobalTest("Private", false, true);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
input, expected, /* skip_nop = */ false);
}
-TEST_F(DecomposeInitializedVariablesTest, FunctionUnchanged) {
- std::string input =
- GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function"});
-
+TEST_F(DecomposeInitializedVariablesTest, PrivateUnchanged) {
+ std::string input = GetGlobalTest("Private", false, false);
SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
input, input, /* skip_nop = */ false);
}
-TEST_F(DecomposeInitializedVariablesTest, FunctionMultipleVariables) {
- std::string input =
- GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function %uint_1",
- "%10 = OpVariable %_ptr_Function_uint Function %4"});
- std::string expected =
- GetFunctionTest({"%9 = OpVariable %_ptr_Function_uint Function",
- "%10 = OpVariable %_ptr_Function_uint Function",
- "OpStore %9 %uint_1", "OpStore %10 %4"});
-
+TEST_F(DecomposeInitializedVariablesTest, OutputChanged) {
+ std::string input = GetGlobalTest("Output", true, false);
+ std::string expected = GetGlobalTest("Output", false, true);
SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
input, expected, /* skip_nop = */ false);
}
+TEST_F(DecomposeInitializedVariablesTest, OutputUnchanged) {
+ std::string input = GetGlobalTest("Output", false, false);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, input, /* skip_nop = */ false);
+}
+
+std::string multiple_entry_header = R"(OpCapability Shader
+OpCapability VulkanMemoryModelKHR
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical VulkanKHR
+OpEntryPoint Vertex %1 "vertex"
+OpEntryPoint Fragment %2 "fragment"
+%uint = OpTypeInt 32 0
+%4 = OpConstantNull %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+)";
+
+std::string GetGlobalMultipleEntryTest(std::string storage_class,
+ bool initialized, bool decomposed) {
+ auto result = multiple_entry_header;
+ result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+ storage_class + " %uint\n";
+ if (initialized) {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + " %4\n";
+ } else {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + "\n";
+ }
+ result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+ if (decomposed) result += "OpStore %8 %4\n";
+ result += R"(OpReturn
+OpFunctionEnd
+%2 = OpFunction %void None %10
+%10 = OpLabel
+)";
+ if (decomposed) result += "OpStore %8 %4\n";
+ result += R"(OpReturn
+OpFunctionEnd
+)";
+
+ return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateMultipleEntryChanged) {
+ std::string input = GetGlobalMultipleEntryTest("Private", true, false);
+ std::string expected = GetGlobalMultipleEntryTest("Private", false, true);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateMultipleEntryUnchanged) {
+ std::string input = GetGlobalMultipleEntryTest("Private", false, false);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, input, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputMultipleEntryChanged) {
+ std::string input = GetGlobalMultipleEntryTest("Output", true, false);
+ std::string expected = GetGlobalMultipleEntryTest("Output", false, true);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputMultipleEntryUnchanged) {
+ std::string input = GetGlobalMultipleEntryTest("Output", false, false);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, input, /* skip_nop = */ false);
+}
+
+std::string GetGlobalWithNonEntryPointTest(std::string storage_class,
+ bool initialized, bool decomposed) {
+ auto result = single_entry_header;
+ result += "%_ptr_" + storage_class + "_uint = OpTypePointer " +
+ storage_class + " %uint\n";
+ if (initialized) {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + " %4\n";
+ } else {
+ result += "%8 = OpVariable %_ptr_" + storage_class + "_uint " +
+ storage_class + "\n";
+ }
+ result += R"(%1 = OpFunction %void None %9
+%9 = OpLabel
+)";
+ if (decomposed) result += "OpStore %8 %4\n";
+ result += R"(OpReturn
+OpFunctionEnd
+%10 = OpFunction %void None %11
+%11 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ return result;
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateWithNonEntryPointChanged) {
+ std::string input = GetGlobalWithNonEntryPointTest("Private", true, false);
+ std::string expected = GetGlobalWithNonEntryPointTest("Private", false, true);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, PrivateWithNonEntryPointUnchanged) {
+ std::string input = GetGlobalWithNonEntryPointTest("Private", false, false);
+ // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, input, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputWithNonEntryPointChanged) {
+ std::string input = GetGlobalWithNonEntryPointTest("Output", true, false);
+ std::string expected = GetGlobalWithNonEntryPointTest("Output", false, true);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, expected, /* skip_nop = */ false);
+}
+
+TEST_F(DecomposeInitializedVariablesTest, OutputWithNonEntryPointUnchanged) {
+ std::string input = GetGlobalWithNonEntryPointTest("Output", false, false);
+ // SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<DecomposeInitializedVariablesPass>(
+ input, input, /* skip_nop = */ false);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools