Do merge return if the return is not at the end of the function. (#3337)

* Do merge return if the return is not at the end of the function.

We will remove the code in inlining to handle a return in the middle of
a function.  To inline those functions, we need to run merge return to
move the return to the end of the function.
diff --git a/source/opt/function.h b/source/opt/function.h
index d7e4176..e68a1d0 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -88,6 +88,10 @@
   // Returns the entry basic block for this function.
   const std::unique_ptr<BasicBlock>& entry() const { return blocks_.front(); }
 
+  // Returns the last basic block in this function.
+  BasicBlock* tail() { return blocks_.back().get(); }
+  const BasicBlock* tail() const { return blocks_.back().get(); }
+
   iterator begin() { return iterator(&blocks_, blocks_.begin()); }
   iterator end() { return iterator(&blocks_, blocks_.end()); }
   const_iterator begin() const { return cbegin(); }
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index bbac4bb..7d238da 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -39,8 +39,11 @@
       if (!is_shader || return_blocks.size() == 0) {
         return false;
       }
-      if (context()->GetStructuredCFGAnalysis()->ContainingConstruct(
-              return_blocks[0]->id()) == 0) {
+      bool isInConstruct =
+          context()->GetStructuredCFGAnalysis()->ContainingConstruct(
+              return_blocks[0]->id()) != 0;
+      bool endsWithReturn = return_blocks[0] == function->tail();
+      if (!isInConstruct && endsWithReturn) {
         return false;
       }
     }
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index d16b65c..e3ec312 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1970,6 +1970,41 @@
   EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
 }
 
+TEST_F(MergeReturnPassTest, SingleReturnInMiddle) {
+  const std::string before =
+      R"(
+; CHECK: OpFunction
+; CHECK: OpReturn
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Vertex %main "main"
+               OpSource GLSL 450
+               OpName %main "main"
+               OpName %foo_ "foo("
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %foo_ = OpFunction %void None %4
+          %7 = OpLabel
+               OpSelectionMerge %8 None
+               OpBranchConditional %true %9 %8
+          %8 = OpLabel
+               OpReturn
+          %9 = OpLabel
+               OpBranch %8
+               OpFunctionEnd
+       %main = OpFunction %void None %4
+         %10 = OpLabel
+         %11 = OpFunctionCall %void %foo_
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<MergeReturnPass>(before, false);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools