Do merge return if the return is not at the end of the function. (#3337)
* Do merge return if the return is not at the end of the function.
We will remove the code in inlining to handle a return in the middle of
a function. To inline those functions, we need to run merge return to
move the return to the end of the function.
diff --git a/source/opt/function.h b/source/opt/function.h
index d7e4176..e68a1d0 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -88,6 +88,10 @@
// Returns the entry basic block for this function.
const std::unique_ptr<BasicBlock>& entry() const { return blocks_.front(); }
+ // Returns the last basic block in this function.
+ BasicBlock* tail() { return blocks_.back().get(); }
+ const BasicBlock* tail() const { return blocks_.back().get(); }
+
iterator begin() { return iterator(&blocks_, blocks_.begin()); }
iterator end() { return iterator(&blocks_, blocks_.end()); }
const_iterator begin() const { return cbegin(); }
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index bbac4bb..7d238da 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -39,8 +39,11 @@
if (!is_shader || return_blocks.size() == 0) {
return false;
}
- if (context()->GetStructuredCFGAnalysis()->ContainingConstruct(
- return_blocks[0]->id()) == 0) {
+ bool isInConstruct =
+ context()->GetStructuredCFGAnalysis()->ContainingConstruct(
+ return_blocks[0]->id()) != 0;
+ bool endsWithReturn = return_blocks[0] == function->tail();
+ if (!isInConstruct && endsWithReturn) {
return false;
}
}
diff --git a/test/opt/pass_merge_return_test.cpp b/test/opt/pass_merge_return_test.cpp
index d16b65c..e3ec312 100644
--- a/test/opt/pass_merge_return_test.cpp
+++ b/test/opt/pass_merge_return_test.cpp
@@ -1970,6 +1970,41 @@
EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
}
+TEST_F(MergeReturnPassTest, SingleReturnInMiddle) {
+ const std::string before =
+ R"(
+; CHECK: OpFunction
+; CHECK: OpReturn
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %main "main"
+ OpSource GLSL 450
+ OpName %main "main"
+ OpName %foo_ "foo("
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %foo_ = OpFunction %void None %4
+ %7 = OpLabel
+ OpSelectionMerge %8 None
+ OpBranchConditional %true %9 %8
+ %8 = OpLabel
+ OpReturn
+ %9 = OpLabel
+ OpBranch %8
+ OpFunctionEnd
+ %main = OpFunction %void None %4
+ %10 = OpLabel
+ %11 = OpFunctionCall %void %foo_
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<MergeReturnPass>(before, false);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools