spirv-opt: Implement opt::Function::HasEarlyReturn function (#3711)
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 21ce0c6..52054ea 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -227,6 +227,18 @@
return nullptr;
}
+bool Function::HasEarlyReturn() const {
+ auto post_dominator_analysis =
+ blocks_.front()->GetLabel()->context()->GetPostDominatorAnalysis(this);
+ for (auto& block : blocks_) {
+ if (spvOpcodeIsReturn(block->tail()->opcode()) &&
+ !post_dominator_analysis->Dominates(block.get(), entry().get())) {
+ return true;
+ }
+ }
+ return false;
+}
+
bool Function::IsRecursive() const {
IRContext* ctx = blocks_.front()->GetLabel()->context();
IRContext::ProcessFunction mark_visited = [this](Function* fp) {
diff --git a/source/opt/function.h b/source/opt/function.h
index 1d11a09..b7c17a6 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -158,7 +158,10 @@
BasicBlock* InsertBasicBlockBefore(std::unique_ptr<BasicBlock>&& new_block,
BasicBlock* position);
- // Return true if the function calls itself either directly or indirectly.
+ // Returns true if the function has a return block other than the exit block.
+ bool HasEarlyReturn() const;
+
+ // Returns true if the function calls itself either directly or indirectly.
bool IsRecursive() const;
// Pretty-prints all the basic blocks in this function into a std::string.
diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp
index b67ca49..af25bac 100644
--- a/test/opt/function_test.cpp
+++ b/test/opt/function_test.cpp
@@ -29,6 +29,60 @@
using ::testing::Eq;
+TEST(FunctionTest, HasEarlyReturn) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Vertex %6 "main"
+
+; Types
+ %2 = OpTypeBool
+ %3 = OpTypeVoid
+ %4 = OpTypeFunction %3
+
+; Constants
+ %5 = OpConstantTrue %2
+
+; main function without early return
+ %6 = OpFunction %3 None %4
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+
+; function with early return
+ %11 = OpFunction %3 None %4
+ %12 = OpLabel
+ OpSelectionMerge %15 None
+ OpBranchConditional %5 %13 %14
+ %13 = OpLabel
+ OpReturn
+ %14 = OpLabel
+ OpBranch %15
+ %15 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, shader,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+
+ // Tests |function| without early return.
+ auto* function = spvtest::GetFunction(context->module(), 6);
+ ASSERT_FALSE(function->HasEarlyReturn());
+
+ // Tests |function| with early return.
+ function = spvtest::GetFunction(context->module(), 11);
+ ASSERT_TRUE(function->HasEarlyReturn());
+}
+
TEST(FunctionTest, IsNotRecursive) {
const std::string text = R"(
OpCapability Shader