Fix SplitLoopHeader to handle single block loop (#4829)
The code in `CFG::SplitLoopHeader` assumes the loop header is not also the
latch block, so for a single-block loop it fails to find the latch block.
This has been fixed, and a test has been added.
Fixes #4527
diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp
index ac0fcc3..5358be6 100644
--- a/source/opt/cfg.cpp
+++ b/source/opt/cfg.cpp
@@ -205,7 +205,7 @@
// Find the back edge
BasicBlock* latch_block = nullptr;
Function::iterator latch_block_iter = header_it;
- while (++latch_block_iter != fn->end()) {
+ for (; latch_block_iter != fn->end(); ++latch_block_iter) {
// If blocks are in the proper order, then the only branch that appears
// after the header is the latch.
if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
@@ -237,6 +237,15 @@
context->set_instr_block(inst, new_header);
});
+ // If |bb| was the latch block, the branch back to the header is now in
+ // |new_header|.
+ if (latch_block == bb) {
+ if (new_header->ContinueBlockId() == bb->id()) {
+ new_header->GetLoopMergeInst()->SetInOperand(1, {new_header_id});
+ }
+ latch_block = new_header;
+ }
+
// Adjust the OpPhi instructions as needed.
bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
std::vector<uint32_t> preheader_phi_ops;
diff --git a/test/opt/cfg_test.cpp b/test/opt/cfg_test.cpp
index 2cfc9f3..a4c6271 100644
--- a/test/opt/cfg_test.cpp
+++ b/test/opt/cfg_test.cpp
@@ -200,6 +200,77 @@
ContainerEq(expected_result2)));
}
+TEST_F(CFGTest, SplitLoopHeaderForSingleBlockLoop) {
+ const std::string test = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main"
+ OpExecutionMode %2 OriginUpperLeft
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %6 = OpTypeFunction %void
+ %2 = OpFunction %void None %6
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ %9 = OpPhi %uint %uint_0 %7 %9 %8
+ OpLoopMerge %10 %8 None
+ OpBranch %8
+ %10 = OpLabel
+ OpUnreachable
+ OpFunctionEnd
+)";
+
+ const std::string expected_result = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%6 = OpTypeFunction %void
+%2 = OpFunction %void None %6
+%7 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpBranch %11
+%11 = OpLabel
+%9 = OpPhi %uint %9 %11 %uint_0 %8
+OpLoopMerge %10 %11 None
+OpBranch %11
+%10 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, context);
+
+ BasicBlock* loop_header = context->get_instr_block(8);
+ ASSERT_TRUE(loop_header->GetLoopMergeInst() != nullptr);
+
+ CFG* cfg = context->cfg();
+ cfg->SplitLoopHeader(loop_header);
+
+ std::vector<uint32_t> binary;
+ bool skip_nop = false;
+ context->module()->ToBinary(&binary, skip_nop);
+
+ std::string optimized_asm;
+ SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
+ EXPECT_TRUE(tools.Disassemble(binary, &optimized_asm,
+ SpirvTools::kDefaultDisassembleOption))
+ << "Disassembling failed for shader\n"
+ << std::endl;
+
+ EXPECT_EQ(optimized_asm, expected_result);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools