Build struct order only for the section needed when unrolling. (#4830)
We currently build the structured order for all nodes reachable from the
loop header when unrolling a loop. However, unrolling only needs the
nodes in the loop and possibly the merge node.
To avoid needless computation, I have implemented a search that will
stop at the merge node.
Fixes #4827
diff --git a/source/cfa.h b/source/cfa.h
index 7cadf55..f55a7bd 100644
--- a/source/cfa.h
+++ b/source/cfa.h
@@ -68,6 +68,8 @@
/// CFG following postorder traversal semantics
/// @param[in] backedge A function that will be called when a backedge is
/// encountered during a traversal
+ /// @param[in] terminal A function that will be called to determine if the
+ /// search should stop at the given node.
/// NOTE: The @p successor_func and predecessor_func each return a pointer to
/// a
/// collection such that iterators to that collection remain valid for the
@@ -76,7 +78,8 @@
const BB* entry, get_blocks_func successor_func,
std::function<void(cbb_ptr)> preorder,
std::function<void(cbb_ptr)> postorder,
- std::function<void(cbb_ptr, cbb_ptr)> backedge);
+ std::function<void(cbb_ptr, cbb_ptr)> backedge,
+ std::function<bool(cbb_ptr)> terminal);
/// @brief Calculates dominator edges for a set of blocks
///
@@ -138,7 +141,8 @@
const BB* entry, get_blocks_func successor_func,
std::function<void(cbb_ptr)> preorder,
std::function<void(cbb_ptr)> postorder,
- std::function<void(cbb_ptr, cbb_ptr)> backedge) {
+ std::function<void(cbb_ptr, cbb_ptr)> backedge,
+ std::function<bool(cbb_ptr)> terminal) {
std::unordered_set<uint32_t> processed;
/// NOTE: work_list is the sequence of nodes from the root node to the node
@@ -152,7 +156,7 @@
while (!work_list.empty()) {
block_info& top = work_list.back();
- if (top.iter == end(*successor_func(top.block))) {
+ if (terminal(top.block) || top.iter == end(*successor_func(top.block))) {
postorder(top.block);
work_list.pop_back();
} else {
@@ -266,11 +270,13 @@
auto mark_visited = [&visited](const BB* b) { visited.insert(b); };
auto ignore_block = [](const BB*) {};
auto ignore_blocks = [](const BB*, const BB*) {};
+ auto no_terminal_blocks = [](const BB*) { return false; };
auto traverse_from_root = [&mark_visited, &succ_func, &ignore_block,
- &ignore_blocks](const BB* entry) {
+ &ignore_blocks,
+ &no_terminal_blocks](const BB* entry) {
DepthFirstTraversal(entry, succ_func, mark_visited, ignore_block,
- ignore_blocks);
+ ignore_blocks, no_terminal_blocks);
};
std::vector<BB*> result;
diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp
index 5358be6..66b1aed 100644
--- a/source/opt/cfg.cpp
+++ b/source/opt/cfg.cpp
@@ -74,6 +74,12 @@
void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
std::list<BasicBlock*>* order) {
+ ComputeStructuredOrder(func, root, nullptr, order);
+}
+
+void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
+ BasicBlock* end,
+ std::list<BasicBlock*>* order) {
assert(module_->context()->get_feature_mgr()->HasCapability(
SpvCapabilityShader) &&
"This only works on structured control flow");
@@ -82,6 +88,8 @@
ComputeStructuredSuccessors(func);
auto ignore_block = [](cbb_ptr) {};
auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+ auto terminal = [end](cbb_ptr bb) { return bb == end; };
+
auto get_structured_successors = [this](const BasicBlock* b) {
return &(block2structured_succs_[b]);
};
@@ -92,7 +100,8 @@
order->push_front(const_cast<BasicBlock*>(b));
};
CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
- ignore_block, post_order, ignore_edge);
+ ignore_block, post_order, ignore_edge,
+ terminal);
}
void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
diff --git a/source/opt/cfg.h b/source/opt/cfg.h
index 33412f1..fa4fef2 100644
--- a/source/opt/cfg.h
+++ b/source/opt/cfg.h
@@ -66,6 +66,14 @@
void ComputeStructuredOrder(Function* func, BasicBlock* root,
std::list<BasicBlock*>* order);
+ // Compute structured block order into |order| for |func| starting at |root|
+ // and ending at |end|. This order has the property that dominators come
+ // before all blocks they dominate, merge blocks come after all blocks that
+ // are in the control constructs of their header, and continue blocks come
+ // after all the blocks in the body of their loop.
+ void ComputeStructuredOrder(Function* func, BasicBlock* root, BasicBlock* end,
+ std::list<BasicBlock*>* order);
+
// Applies |f| to all blocks that can be reach from |bb| in post order.
void ForEachBlockInPostOrder(BasicBlock* bb,
const std::function<void(BasicBlock*)>& f);
diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp
index d86de15..d6017bb 100644
--- a/source/opt/dominator_tree.cpp
+++ b/source/opt/dominator_tree.cpp
@@ -59,7 +59,9 @@
PreLambda pre, PostLambda post) {
// Ignore backedge operation.
auto nop_backedge = [](const BBType*, const BBType*) {};
- CFA<BBType>::DepthFirstTraversal(bb, successors, pre, post, nop_backedge);
+ auto no_terminal_blocks = [](const BBType*) { return false; };
+ CFA<BBType>::DepthFirstTraversal(bb, successors, pre, post, nop_backedge,
+ no_terminal_blocks);
}
// Wrapper around CFA::DepthFirstTraversal to provide an interface to perform
diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp
index 4feb64e..13982d1 100644
--- a/source/opt/loop_descriptor.cpp
+++ b/source/opt/loop_descriptor.cpp
@@ -497,7 +497,8 @@
// continue blocks that must be copied to retain the structured order.
// The structured order will include these.
std::list<BasicBlock*> order;
- cfg.ComputeStructuredOrder(loop_header_->GetParent(), loop_header_, &order);
+ cfg.ComputeStructuredOrder(loop_header_->GetParent(), loop_header_,
+ loop_merge_, &order);
for (BasicBlock* bb : order) {
if (bb == GetMergeBlock()) {
break;
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index dd605d2..be05c9c 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -861,12 +861,13 @@
std::vector<std::pair<uint32_t, uint32_t>> back_edges;
auto ignore_block = [](const BasicBlock*) {};
auto ignore_edge = [](const BasicBlock*, const BasicBlock*) {};
+ auto no_terminal_blocks = [](const BasicBlock*) { return false; };
if (!function.ordered_blocks().empty()) {
/// calculate dominators
CFA<BasicBlock>::DepthFirstTraversal(
function.first_block(), function.AugmentedCFGSuccessorsFunction(),
ignore_block, [&](const BasicBlock* b) { postorder.push_back(b); },
- ignore_edge);
+ ignore_edge, no_terminal_blocks);
auto edges = CFA<BasicBlock>::CalculateDominators(
postorder, function.AugmentedCFGPredecessorsFunction());
for (auto edge : edges) {
@@ -879,7 +880,7 @@
function.pseudo_exit_block(),
function.AugmentedCFGPredecessorsFunction(), ignore_block,
[&](const BasicBlock* b) { postdom_postorder.push_back(b); },
- ignore_edge);
+ ignore_edge, no_terminal_blocks);
auto postdom_edges = CFA<BasicBlock>::CalculateDominators(
postdom_postorder, function.AugmentedCFGSuccessorsFunction());
for (auto edge : postdom_edges) {
@@ -893,7 +894,8 @@
ignore_block, ignore_block,
[&](const BasicBlock* from, const BasicBlock* to) {
back_edges.emplace_back(from->id(), to->id());
- });
+ },
+ no_terminal_blocks);
}
UpdateContinueConstructExitBlocks(function, back_edges);
diff --git a/test/opt/cfg_test.cpp b/test/opt/cfg_test.cpp
index a4c6271..7dfd2bc 100644
--- a/test/opt/cfg_test.cpp
+++ b/test/opt/cfg_test.cpp
@@ -271,6 +271,54 @@
EXPECT_EQ(optimized_asm, expected_result);
}
+TEST_F(CFGTest, ComputeStructedOrderForLoop) {
+ const std::string test = R"(
+OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Vertex %main "main"
+OpName %main "main"
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%uint = OpTypeInt 32 0
+%5 = OpConstant %uint 5
+%main = OpFunction %void None %4
+%8 = OpLabel
+OpBranch %9
+%9 = OpLabel
+OpLoopMerge %11 %10 None
+OpBranchConditional %true %11 %10
+%10 = OpLabel
+OpBranch %9
+%11 = OpLabel
+OpBranch %12
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ CFG* cfg = context->cfg();
+ Module* module = context->module();
+ Function* function = &*module->begin();
+ std::list<BasicBlock*> order;
+ cfg->ComputeStructuredOrder(function, context->get_instr_block(9),
+ context->get_instr_block(11), &order);
+
+ // Order should contain the loop header, the continue target, and the merge
+ // node.
+ std::list<BasicBlock*> expected_result = {context->get_instr_block(9),
+ context->get_instr_block(10),
+ context->get_instr_block(11)};
+ EXPECT_THAT(order, ContainerEq(expected_result));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools