Fix SplitLoopHeader to handle single block loop (#4829)

The code in `CFG::SplitLoopHeader` assumes the loop header is not the
latch.  This leads to it not being able to find the latch block.  This
has been fixed, and a test added.

Fixes #4527
diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp
index ac0fcc3..5358be6 100644
--- a/source/opt/cfg.cpp
+++ b/source/opt/cfg.cpp
@@ -205,7 +205,7 @@
   // Find the back edge
   BasicBlock* latch_block = nullptr;
   Function::iterator latch_block_iter = header_it;
-  while (++latch_block_iter != fn->end()) {
+  for (; latch_block_iter != fn->end(); ++latch_block_iter) {
     // If blocks are in the proper order, then the only branch that appears
     // after the header is the latch.
     if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
@@ -237,6 +237,15 @@
     context->set_instr_block(inst, new_header);
   });
 
+  // If |bb| was the latch block, the branch back to the header is not in
+  // |new_header|.
+  if (latch_block == bb) {
+    if (new_header->ContinueBlockId() == bb->id()) {
+      new_header->GetLoopMergeInst()->SetInOperand(1, {new_header_id});
+    }
+    latch_block = new_header;
+  }
+
   // Adjust the OpPhi instructions as needed.
   bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
     std::vector<uint32_t> preheader_phi_ops;
diff --git a/test/opt/cfg_test.cpp b/test/opt/cfg_test.cpp
index 2cfc9f3..a4c6271 100644
--- a/test/opt/cfg_test.cpp
+++ b/test/opt/cfg_test.cpp
@@ -200,6 +200,77 @@
                            ContainerEq(expected_result2)));
 }
 
+TEST_F(CFGTest, SplitLoopHeaderForSingleBlockLoop) {
+  const std::string test = R"(
+               OpCapability Shader
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %2 "main"
+               OpExecutionMode %2 OriginUpperLeft
+       %void = OpTypeVoid
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+          %6 = OpTypeFunction %void
+          %2 = OpFunction %void None %6
+          %7 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+          %9 = OpPhi %uint %uint_0 %7 %9 %8
+               OpLoopMerge %10 %8 None
+               OpBranch %8
+         %10 = OpLabel
+               OpUnreachable
+               OpFunctionEnd
+)";
+
+  const std::string expected_result = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %2 "main"
+OpExecutionMode %2 OriginUpperLeft
+%void = OpTypeVoid
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%6 = OpTypeFunction %void
+%2 = OpFunction %void None %6
+%7 = OpLabel
+OpBranch %8
+%8 = OpLabel
+OpBranch %11
+%11 = OpLabel
+%9 = OpPhi %uint %9 %11 %uint_0 %8
+OpLoopMerge %10 %11 None
+OpBranch %11
+%10 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+  std::unique_ptr<IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ASSERT_NE(nullptr, context);
+
+  BasicBlock* loop_header = context->get_instr_block(8);
+  ASSERT_TRUE(loop_header->GetLoopMergeInst() != nullptr);
+
+  CFG* cfg = context->cfg();
+  cfg->SplitLoopHeader(loop_header);
+
+  std::vector<uint32_t> binary;
+  bool skip_nop = false;
+  context->module()->ToBinary(&binary, skip_nop);
+
+  std::string optimized_asm;
+  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
+  EXPECT_TRUE(tools.Disassemble(binary, &optimized_asm,
+                                SpirvTools::kDefaultDisassembleOption))
+      << "Disassembling failed for shader\n"
+      << std::endl;
+
+  EXPECT_EQ(optimized_asm, expected_result);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools