Don't inline recursive functions. (#2130)
* Move ProcessFunction* functions from the pass to the context.
There are a few functions that are used to traverse the call tree.
They currently live in the Pass class, but they have nothing to do with
a pass, and may be needed outside of a pass. They would be better in
the ir context, or in a specific call tree class if we ever have a need
for it.
* Don't inline recursive functions.
Inlining does not check if a function is recursive or not. This has
been fine as long as the shader was a Vulkan shader, since Vulkan forbids
recursive functions. However, not all shaders are Vulkan shaders, so either
we limit inlining to Vulkan shaders or we teach it to look for recursive
functions.
I prefer to keep the passes as general as is reasonable. The change
does not require much new code in inlining and gives a reason to refactor
some other code.
The changes add a member function to the Function class that
checks if that function is recursive or not.
Inlining then uses this check to avoid inlining a call to a
recursive function.
* Add id to function analysis
There are a few places that build a map from ids to the Function whose
result id is that id. I decided to add an analysis to the context for this
to reduce that code, and simplify some of the functions.
* Add missing file.
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index db9b7b4..2793312 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -562,7 +562,7 @@
// Process all entry point functions.
ProcessFunction pfn = [this](Function* fp) { return AggressiveDCE(fp); };
- modified |= ProcessEntryPointCallTree(pfn, get_module());
+ modified |= context()->ProcessEntryPointCallTree(pfn);
// Process module-level instructions. Now that all live instructions have
// been marked, it is safe to remove dead global values.
@@ -575,7 +575,7 @@
// Cleanup all CFG including all unreachable blocks.
ProcessFunction cleanup = [this](Function* f) { return CFGCleanup(f); };
- modified |= ProcessEntryPointCallTree(cleanup, get_module());
+ modified |= context()->ProcessEntryPointCallTree(cleanup);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
@@ -589,7 +589,7 @@
live_function_set.insert(fp);
return false;
};
- ProcessEntryPointCallTree(mark_live, get_module());
+ context()->ProcessEntryPointCallTree(mark_live);
bool modified = false;
for (auto funcIter = get_module()->begin();
diff --git a/source/opt/block_merge_pass.cpp b/source/opt/block_merge_pass.cpp
index 99d3db6..09deb21 100644
--- a/source/opt/block_merge_pass.cpp
+++ b/source/opt/block_merge_pass.cpp
@@ -136,7 +136,7 @@
Pass::Status BlockMergePass::Process() {
// Process all entry point functions.
ProcessFunction pfn = [this](Function* fp) { return MergeBlocks(fp); };
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/ccp_pass.cpp b/source/opt/ccp_pass.cpp
index a8411d9..8356195 100644
--- a/source/opt/ccp_pass.cpp
+++ b/source/opt/ccp_pass.cpp
@@ -319,7 +319,7 @@
// Process all entry point functions.
ProcessFunction pfn = [this](Function* fp) { return PropagateConstants(fp); };
- bool modified = ProcessReachableCallTree(pfn, context());
+ bool modified = context()->ProcessReachableCallTree(pfn);
return modified ? Pass::Status::SuccessWithChange
: Pass::Status::SuccessWithoutChange;
}
diff --git a/source/opt/cfg_cleanup_pass.cpp b/source/opt/cfg_cleanup_pass.cpp
index 2d54846..6d48637 100644
--- a/source/opt/cfg_cleanup_pass.cpp
+++ b/source/opt/cfg_cleanup_pass.cpp
@@ -30,7 +30,7 @@
Pass::Status CFGCleanupPass::Process() {
// Process all entry point functions.
ProcessFunction pfn = [this](Function* fp) { return CFGCleanup(fp); };
- bool modified = ProcessReachableCallTree(pfn, context());
+ bool modified = context()->ProcessReachableCallTree(pfn);
return modified ? Pass::Status::SuccessWithChange
: Pass::Status::SuccessWithoutChange;
}
diff --git a/source/opt/common_uniform_elim_pass.cpp b/source/opt/common_uniform_elim_pass.cpp
index d447d11..efa40aa 100644
--- a/source/opt/common_uniform_elim_pass.cpp
+++ b/source/opt/common_uniform_elim_pass.cpp
@@ -526,7 +526,7 @@
ProcessFunction pfn = [this](Function* fp) {
return EliminateCommonUniform(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
index c65e551..9893536 100644
--- a/source/opt/dead_branch_elim_pass.cpp
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -419,9 +419,9 @@
// Structured order is more intuitive so use it where possible.
if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
- ProcessReachableCallTree(reorder_structured, context());
+ context()->ProcessReachableCallTree(reorder_structured);
} else {
- ProcessReachableCallTree(reorder_dominators, context());
+ context()->ProcessReachableCallTree(reorder_dominators);
}
}
@@ -435,7 +435,7 @@
ProcessFunction pfn = [this](Function* fp) {
return EliminateDeadBranches(fp);
};
- bool modified = ProcessReachableCallTree(pfn, context());
+ bool modified = context()->ProcessReachableCallTree(pfn);
if (modified) FixBlockOrder();
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/dead_insert_elim_pass.cpp b/source/opt/dead_insert_elim_pass.cpp
index b42588f..7d56343 100644
--- a/source/opt/dead_insert_elim_pass.cpp
+++ b/source/opt/dead_insert_elim_pass.cpp
@@ -255,7 +255,7 @@
ProcessFunction pfn = [this](Function* fp) {
return EliminateDeadInserts(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/eliminate_dead_functions_pass.cpp b/source/opt/eliminate_dead_functions_pass.cpp
index 5be983a..f067be5 100644
--- a/source/opt/eliminate_dead_functions_pass.cpp
+++ b/source/opt/eliminate_dead_functions_pass.cpp
@@ -29,7 +29,7 @@
live_function_set.insert(fp);
return false;
};
- ProcessReachableCallTree(mark_live, context());
+ context()->ProcessReachableCallTree(mark_live);
bool modified = false;
for (auto funcIter = get_module()->begin();
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 6092e69..d4457ad 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -13,7 +13,10 @@
// limitations under the License.
#include "source/opt/function.h"
+#include "function.h"
+#include "ir_context.h"
+#include <source/util/bit_vector.h>
#include <ostream>
#include <sstream>
@@ -96,6 +99,19 @@
return nullptr;
}
+bool Function::IsRecursive() const {
+ IRContext* ctx = blocks_.front()->GetLabel()->context();
+ IRContext::ProcessFunction mark_visited = [this](Function* fp) {
+ return fp == this;
+ };
+
+ // Process the call tree from all of the functions called by |this|. If it
+ // gets back to |this|, then we have a recursive function.
+ std::queue<uint32_t> roots;
+ ctx->AddCalls(this, &roots);
+ return ctx->ProcessCallTreeFromRoots(mark_visited, &roots);
+}
+
std::ostream& operator<<(std::ostream& str, const Function& func) {
str << func.PrettyPrint();
return str;
@@ -115,6 +131,5 @@
});
return str.str();
}
-
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/function.h b/source/opt/function.h
index 549a029..7c0166f 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -121,6 +121,9 @@
BasicBlock* InsertBasicBlockAfter(std::unique_ptr<BasicBlock>&& new_block,
BasicBlock* position);
+ // Return true if the function calls itself either directly or indirectly.
+ bool IsRecursive() const;
+
// Pretty-prints all the basic blocks in this function into a std::string.
//
// |options| are the disassembly options. SPV_BINARY_TO_TEXT_OPTION_NO_HEADER
diff --git a/source/opt/inline_exhaustive_pass.cpp b/source/opt/inline_exhaustive_pass.cpp
index 5714cd8..10b5e98 100644
--- a/source/opt/inline_exhaustive_pass.cpp
+++ b/source/opt/inline_exhaustive_pass.cpp
@@ -64,7 +64,7 @@
Pass::Status InlineExhaustivePass::ProcessImpl() {
// Attempt exhaustive inlining on each entry point function in module
ProcessFunction pfn = [this](Function* fp) { return InlineExhaustive(fp); };
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/inline_opaque_pass.cpp b/source/opt/inline_opaque_pass.cpp
index c2c3719..e94f26d 100644
--- a/source/opt/inline_opaque_pass.cpp
+++ b/source/opt/inline_opaque_pass.cpp
@@ -98,7 +98,7 @@
Pass::Status InlineOpaquePass::ProcessImpl() {
// Do opaque inlining on each function in entry point call tree
ProcessFunction pfn = [this](Function* fp) { return InlineOpaque(fp); };
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index 0543f44..cdd5659 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -617,8 +617,15 @@
// done validly if the return was not in a loop in the original function.
// Also remember functions with multiple (early) returns.
AnalyzeReturns(func);
- return no_return_in_loop_.find(func->result_id()) !=
- no_return_in_loop_.cend();
+ if (no_return_in_loop_.find(func->result_id()) == no_return_in_loop_.cend()) {
+ return false;
+ }
+
+ if (func->IsRecursive()) {
+ return false;
+ }
+
+ return true;
}
void InlinePass::InitializeInline() {
diff --git a/source/opt/instrument_pass.cpp b/source/opt/instrument_pass.cpp
index 3291bbb..7f56a1e 100644
--- a/source/opt/instrument_pass.cpp
+++ b/source/opt/instrument_pass.cpp
@@ -574,7 +574,7 @@
if (done.insert(fi).second) {
Function* fn = id2function_.at(fi);
// Add calls first so we don't add new output function
- AddCalls(fn, roots);
+ context()->AddCalls(fn, roots);
modified = InstrumentFunction(fn, stage_idx, pfn) || modified;
}
}
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index 038ad6d..c1158e7 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -27,6 +27,7 @@
static const int kSpvDecorateDecorationInIdx = 1;
static const int kSpvDecorateBuiltinInIdx = 2;
static const int kEntryPointInterfaceInIdx = 3;
+static const int kEntryPointFunctionIdInIdx = 1;
} // anonymous namespace
@@ -70,6 +71,9 @@
if (set & kAnalysisStructuredCFG) {
BuildStructuredCFGAnalysis();
}
+ if (set & kAnalysisIdToFuncMapping) {
+ BuildIdToFuncMapping();
+ }
}
void IRContext::InvalidateAnalysesExceptFor(
@@ -110,6 +114,9 @@
if (analyses_to_invalidate & kAnalysisStructuredCFG) {
struct_cfg_analysis_.reset(nullptr);
}
+ if (analyses_to_invalidate & kAnalysisIdToFuncMapping) {
+ id_to_func_.clear();
+ }
valid_analyses_ = Analysis(valid_analyses_ & ~analyses_to_invalidate);
}
@@ -673,6 +680,70 @@
return var_id;
}
+void IRContext::AddCalls(const Function* func, std::queue<uint32_t>* todo) {
+ for (auto bi = func->begin(); bi != func->end(); ++bi)
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii)
+ if (ii->opcode() == SpvOpFunctionCall)
+ todo->push(ii->GetSingleWordInOperand(0));
+}
+
+bool IRContext::ProcessEntryPointCallTree(ProcessFunction& pfn) {
+ // Collect all of the entry points as the roots.
+ std::queue<uint32_t> roots;
+ for (auto& e : module()->entry_points()) {
+ roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
+ }
+ return ProcessCallTreeFromRoots(pfn, &roots);
+}
+
+bool IRContext::ProcessReachableCallTree(ProcessFunction& pfn) {
+ std::queue<uint32_t> roots;
+
+ // Add all entry points since they can be reached from outside the module.
+ for (auto& e : module()->entry_points())
+ roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
+
+ // Add all exported functions since they can be reached from outside the
+ // module.
+ for (auto& a : annotations()) {
+ // TODO: Handle group decorations as well. Currently not generated by any
+ // front-end, but could be coming.
+ if (a.opcode() == SpvOp::SpvOpDecorate) {
+ if (a.GetSingleWordOperand(1) ==
+ SpvDecoration::SpvDecorationLinkageAttributes) {
+ uint32_t lastOperand = a.NumOperands() - 1;
+ if (a.GetSingleWordOperand(lastOperand) ==
+ SpvLinkageType::SpvLinkageTypeExport) {
+ uint32_t id = a.GetSingleWordOperand(0);
+ if (GetFunction(id)) {
+ roots.push(id);
+ }
+ }
+ }
+ }
+ }
+
+ return ProcessCallTreeFromRoots(pfn, &roots);
+}
+
+bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn,
+ std::queue<uint32_t>* roots) {
+ // Process call tree
+ bool modified = false;
+ std::unordered_set<uint32_t> done;
+
+ while (!roots->empty()) {
+ const uint32_t fi = roots->front();
+ roots->pop();
+ if (done.insert(fi).second) {
+ Function* fn = GetFunction(fi);
+ modified = pfn(fn) || modified;
+ AddCalls(fn, roots);
+ }
+ }
+ return modified;
+}
+
// Gets the dominator analysis for function |f|.
DominatorAnalysis* IRContext::GetDominatorAnalysis(const Function* f) {
if (!AreAnalysesValid(kAnalysisDominatorAnalysis)) {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 6e9eda8..83e06b8 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -20,6 +20,7 @@
#include <limits>
#include <map>
#include <memory>
+#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -74,9 +75,12 @@
kAnalysisValueNumberTable = 1 << 10,
kAnalysisStructuredCFG = 1 << 11,
kAnalysisBuiltinVarId = 1 << 12,
- kAnalysisEnd = 1 << 13
+ kAnalysisIdToFuncMapping = 1 << 13,
+ kAnalysisEnd = 1 << 14
};
+ using ProcessFunction = std::function<bool(Function*)>;
+
friend inline Analysis operator|(Analysis lhs, Analysis rhs);
friend inline Analysis& operator|=(Analysis& lhs, Analysis rhs);
friend inline Analysis operator<<(Analysis a, int shift);
@@ -478,6 +482,43 @@
// supported, return 0.
uint32_t GetBuiltinVarId(uint32_t builtin);
+ // Returns the function whose id is |id|, if one exists. Returns |nullptr|
+ // otherwise.
+ Function* GetFunction(uint32_t id) {
+ if (!AreAnalysesValid(kAnalysisIdToFuncMapping)) {
+ BuildIdToFuncMapping();
+ }
+ auto entry = id_to_func_.find(id);
+ return (entry != id_to_func_.end()) ? entry->second : nullptr;
+ }
+
+ Function* GetFunction(Instruction* inst) {
+ if (inst->opcode() != SpvOpFunction) {
+ return nullptr;
+ }
+ return GetFunction(inst->result_id());
+ }
+
+ // Add to |todo| all ids of functions called in |func|.
+ void AddCalls(const Function* func, std::queue<uint32_t>* todo);
+
+ // Applies |pfn| to every function in the call trees that are rooted at the
+ // entry points. Returns true if any call to |pfn| returns true. By
+ // convention |pfn| should return true if it modified the module.
+ bool ProcessEntryPointCallTree(ProcessFunction& pfn);
+
+ // Applies |pfn| to every function in the call trees rooted at the entry
+ // points and exported functions. Returns true if any call to |pfn| returns
+ // true. By convention |pfn| should return true if it modified the module.
+ bool ProcessReachableCallTree(ProcessFunction& pfn);
+
+ // Applies |pfn| to every function in the call trees rooted at the elements of
+ // |roots|. Returns true if any call to |pfn| returns true. By convention
+ // |pfn| should return true if it modified the module. After returning
+ // |roots| will be empty.
+ bool ProcessCallTreeFromRoots(ProcessFunction& pfn,
+ std::queue<uint32_t>* roots);
+
private:
// Builds the def-use manager from scratch, even if it was already valid.
void BuildDefUseManager() {
@@ -498,6 +539,15 @@
valid_analyses_ = valid_analyses_ | kAnalysisInstrToBlockMapping;
}
+ // Builds the id-to-function map for the whole module.
+ void BuildIdToFuncMapping() {
+ id_to_func_.clear();
+ for (auto& fn : *module_) {
+ id_to_func_[fn.result_id()] = &fn;
+ }
+ valid_analyses_ = valid_analyses_ | kAnalysisIdToFuncMapping;
+ }
+
void BuildDecorationManager() {
decoration_mgr_ = MakeUnique<analysis::DecorationManager>(module());
valid_analyses_ = valid_analyses_ | kAnalysisDecorations;
@@ -613,13 +663,20 @@
std::unique_ptr<analysis::DecorationManager> decoration_mgr_;
std::unique_ptr<FeatureManager> feature_mgr_;
- // A map from instructions the the basic block they belong to. This mapping is
+ // A map from instructions to the basic block they belong to. This mapping is
// built on-demand when get_instr_block() is called.
//
// NOTE: Do not traverse this map. Ever. Use the function and basic block
// iterators to traverse instructions.
std::unordered_map<Instruction*, BasicBlock*> instr_to_block_;
+ // A map from ids to the function they define. This mapping is
+ // built on-demand when GetFunction() is called.
+ //
+ // NOTE: Do not traverse this map. Ever. Use the module's function iterators
+ // to traverse functions.
+ std::unordered_map<uint32_t, Function*> id_to_func_;
+
// A bitset indicating which analyes are currently valid.
Analysis valid_analyses_;
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index 18dc9a0..5b976a1 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -292,7 +292,7 @@
ProcessFunction pfn = [this](Function* fp) {
return ConvertLocalAccessChains(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index 047ff0b..9330ab7 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -200,7 +200,7 @@
return LocalSingleBlockLoadStoreElim(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index 71ee6b3..6c09dec 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -67,7 +67,7 @@
ProcessFunction pfn = [this](Function* fp) {
return LocalSingleStoreElim(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/local_ssa_elim_pass.cpp b/source/opt/local_ssa_elim_pass.cpp
index 902caeb..8209aa4 100644
--- a/source/opt/local_ssa_elim_pass.cpp
+++ b/source/opt/local_ssa_elim_pass.cpp
@@ -50,7 +50,7 @@
ProcessFunction pfn = [this](Function* fp) {
return SSARewriter(this).RewriteFunctionIntoSSA(fp);
};
- bool modified = ProcessEntryPointCallTree(pfn, get_module());
+ bool modified = context()->ProcessEntryPointCallTree(pfn);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index 10c7d2f..820760c 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -54,7 +54,7 @@
return true;
};
- bool modified = ProcessReachableCallTree(pfn, context());
+ bool modified = context()->ProcessReachableCallTree(pfn);
if (failed) {
return Status::Failure;
diff --git a/source/opt/pass.cpp b/source/opt/pass.cpp
index 4bf719d..edcd245 100644
--- a/source/opt/pass.cpp
+++ b/source/opt/pass.cpp
@@ -23,86 +23,12 @@
namespace {
-const uint32_t kEntryPointFunctionIdInIdx = 1;
const uint32_t kTypePointerTypeIdInIdx = 1;
} // namespace
Pass::Pass() : consumer_(nullptr), context_(nullptr), already_run_(false) {}
-void Pass::AddCalls(Function* func, std::queue<uint32_t>* todo) {
- for (auto bi = func->begin(); bi != func->end(); ++bi)
- for (auto ii = bi->begin(); ii != bi->end(); ++ii)
- if (ii->opcode() == SpvOpFunctionCall)
- todo->push(ii->GetSingleWordInOperand(0));
-}
-
-bool Pass::ProcessEntryPointCallTree(ProcessFunction& pfn, Module* module) {
- // Map from function's result id to function
- std::unordered_map<uint32_t, Function*> id2function;
- for (auto& fn : *module) id2function[fn.result_id()] = &fn;
-
- // Collect all of the entry points as the roots.
- std::queue<uint32_t> roots;
- for (auto& e : module->entry_points()) {
- roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
- }
- return ProcessCallTreeFromRoots(pfn, id2function, &roots);
-}
-
-bool Pass::ProcessReachableCallTree(ProcessFunction& pfn,
- IRContext* irContext) {
- // Map from function's result id to function
- std::unordered_map<uint32_t, Function*> id2function;
- for (auto& fn : *irContext->module()) id2function[fn.result_id()] = &fn;
-
- std::queue<uint32_t> roots;
-
- // Add all entry points since they can be reached from outside the module.
- for (auto& e : irContext->module()->entry_points())
- roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
-
- // Add all exported functions since they can be reached from outside the
- // module.
- for (auto& a : irContext->annotations()) {
- // TODO: Handle group decorations as well. Currently not generate by any
- // front-end, but could be coming.
- if (a.opcode() == SpvOp::SpvOpDecorate) {
- if (a.GetSingleWordOperand(1) ==
- SpvDecoration::SpvDecorationLinkageAttributes) {
- uint32_t lastOperand = a.NumOperands() - 1;
- if (a.GetSingleWordOperand(lastOperand) ==
- SpvLinkageType::SpvLinkageTypeExport) {
- uint32_t id = a.GetSingleWordOperand(0);
- if (id2function.count(id) != 0) roots.push(id);
- }
- }
- }
- }
-
- return ProcessCallTreeFromRoots(pfn, id2function, &roots);
-}
-
-bool Pass::ProcessCallTreeFromRoots(
- ProcessFunction& pfn,
- const std::unordered_map<uint32_t, Function*>& id2function,
- std::queue<uint32_t>* roots) {
- // Process call tree
- bool modified = false;
- std::unordered_set<uint32_t> done;
-
- while (!roots->empty()) {
- const uint32_t fi = roots->front();
- roots->pop();
- if (done.insert(fi).second) {
- Function* fn = id2function.at(fi);
- modified = pfn(fn) || modified;
- AddCalls(fn, roots);
- }
- }
- return modified;
-}
-
Pass::Status Pass::Run(IRContext* ctx) {
if (already_run_) {
return Status::Failure;
diff --git a/source/opt/pass.h b/source/opt/pass.h
index df17450..aabc645 100644
--- a/source/opt/pass.h
+++ b/source/opt/pass.h
@@ -17,7 +17,6 @@
#include <algorithm>
#include <map>
-#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
@@ -91,28 +90,6 @@
// Returns a pointer to the CFG for current module.
CFG* cfg() const { return context()->cfg(); }
- // Add to |todo| all ids of functions called in |func|.
- void AddCalls(Function* func, std::queue<uint32_t>* todo);
-
- // Applies |pfn| to every function in the call trees that are rooted at the
- // entry points. Returns true if any call |pfn| returns true. By convention
- // |pfn| should return true if it modified the module.
- bool ProcessEntryPointCallTree(ProcessFunction& pfn, Module* module);
-
- // Applies |pfn| to every function in the call trees rooted at the entry
- // points and exported functions. Returns true if any call |pfn| returns
- // true. By convention |pfn| should return true if it modified the module.
- bool ProcessReachableCallTree(ProcessFunction& pfn, IRContext* irContext);
-
- // Applies |pfn| to every function in the call trees rooted at the elements of
- // |roots|. Returns true if any call to |pfn| returns true. By convention
- // |pfn| should return true if it modified the module. After returning
- // |roots| will be empty.
- bool ProcessCallTreeFromRoots(
- ProcessFunction& pfn,
- const std::unordered_map<uint32_t, Function*>& id2function,
- std::queue<uint32_t>* roots);
-
// Run the pass on the given |module|. Returns Status::Failure if errors occur
// when processing. Returns the corresponding Status::Success if processing is
// successful to indicate whether changes are made to the module. If there
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index cbaa656..cfbac19 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -38,6 +38,7 @@
fold_spec_const_op_composite_test.cpp
fold_test.cpp
freeze_spec_const_test.cpp
+ function_test.cpp
if_conversion_test.cpp
inline_opaque_test.cpp
inline_test.cpp
@@ -61,7 +62,6 @@
pass_manager_test.cpp
pass_merge_return_test.cpp
pass_remove_duplicates_test.cpp
- pass_test.cpp pass_utils.cpp
pass_utils.cpp
private_to_local_test.cpp
process_lines_test.cpp
@@ -88,3 +88,4 @@
LIBS SPIRV-Tools-opt
PCH_FILE pch_test_opt
)
+
diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp
new file mode 100644
index 0000000..38ab298
--- /dev/null
+++ b/test/opt/function_test.cpp
@@ -0,0 +1,173 @@
+// Copyright (c) 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "function_utils.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "source/opt/build_module.h"
+#include "source/opt/ir_context.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using ::testing::Eq;
+
+TEST(FunctionTest, IsNotRecursive) {
+ const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 9);
+ EXPECT_FALSE(func->IsRecursive());
+
+ func = spvtest::GetFunction(ctx->module(), 12);
+ EXPECT_FALSE(func->IsRecursive());
+}
+
+TEST(FunctionTest, IsDirectlyRecursive) {
+ const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 9);
+ EXPECT_TRUE(func->IsRecursive());
+}
+
+TEST(FunctionTest, IsIndirectlyRecursive) {
+ const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+%14 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 9);
+ EXPECT_TRUE(func->IsRecursive());
+
+ func = spvtest::GetFunction(ctx->module(), 12);
+ EXPECT_TRUE(func->IsRecursive());
+}
+
+TEST(FunctionTest, IsNotRecuriseCallingRecursive) {
+ const std::string text = R"(
+OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 1);
+ EXPECT_FALSE(func->IsRecursive());
+}
+
+} // namespace
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp
index 2ece9f4..44a9698 100644
--- a/test/opt/inline_test.cpp
+++ b/test/opt/inline_test.cpp
@@ -3046,12 +3046,72 @@
SinglePassRunAndMatch<InlineExhaustivePass>(text, true);
}
+TEST_F(InlineTest, DontInlineDirectlyRecursiveFunc) {
+ // Test that a call to a directly recursive function is not inlined.
+ const std::string test =
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<InlineExhaustivePass>(test, test, false, true);
+}
+
+TEST_F(InlineTest, DontInlineInDirectlyRecursiveFunc) {
+ // Test that calls to indirectly (mutually) recursive functions are not inlined.
+ const std::string test =
+ R"(OpCapability Shader
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+OpDecorate %2 DescriptorSet 439418829
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_struct_6 = OpTypeStruct %float %float
+%7 = OpTypeFunction %_struct_6
+%1 = OpFunction %void Pure|Const %4
+%8 = OpLabel
+%2 = OpFunctionCall %_struct_6 %9
+OpKill
+OpFunctionEnd
+%9 = OpFunction %_struct_6 None %7
+%10 = OpLabel
+%11 = OpFunctionCall %_struct_6 %12
+OpUnreachable
+OpFunctionEnd
+%12 = OpFunction %_struct_6 None %7
+%13 = OpLabel
+%14 = OpFunctionCall %_struct_6 %9
+OpUnreachable
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<InlineExhaustivePass>(test, test, false, true);
+}
+
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// Empty modules
// Modules without function definitions
// Modules in which all functions do not call other functions
-// Recursive functions (calling self & calling each other)
// Caller and callee both accessing the same global variable
// Functions with OpLine & OpNoLine
// Others?
diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp
index 10a92a7..f66b16e 100644
--- a/test/opt/ir_context_test.cpp
+++ b/test/opt/ir_context_test.cpp
@@ -30,6 +30,7 @@
using Analysis = IRContext::Analysis;
using ::testing::Each;
+using ::testing::UnorderedElementsAre;
class DummyPassPreservesNothing : public Pass {
public:
@@ -370,6 +371,202 @@
EXPECT_TRUE(context->annotations().empty());
}
+TEST_F(IRContextTest, BasicVisitFromEntryPoint) {
+ // Expect only the entry point (%10) and the function it calls (%11) to be
+ // visited; %Dead and %ExportedFunc must not be.
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %10 "main"
+ OpName %10 "main"
+ OpName %Dead "Dead"
+ OpName %11 "Constant"
+ OpName %ExportedFunc "ExportedFunc"
+ OpDecorate %ExportedFunc LinkageAttributes "ExportedFunc" Export
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+ %10 = OpFunction %void None %6
+ %14 = OpLabel
+ %15 = OpFunctionCall %void %11
+ %16 = OpFunctionCall %void %11
+ OpReturn
+ OpFunctionEnd
+ %11 = OpFunction %void None %6
+ %18 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %Dead = OpFunction %void None %6
+ %19 = OpLabel
+ OpReturn
+ OpFunctionEnd
+%ExportedFunc = OpFunction %void None %6
+ %20 = OpLabel
+ %21 = OpFunctionCall %void %11
+ OpReturn
+ OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<IRContext> localContext =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+ << text << std::endl;
+ std::vector<uint32_t> processed;
+ Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
+ processed.push_back(fp->result_id());
+ return false;
+ };
+ localContext->ProcessEntryPointCallTree(mark_visited);
+ EXPECT_THAT(processed, UnorderedElementsAre(10, 11));
+}
+
+TEST_F(IRContextTest, BasicVisitReachable) {
+ // Make sure we visit the entry point (%10), the exported function (%12), and
+ // the functions they call (%11, %13). Do not visit Dead.
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %10 "main"
+ OpName %10 "main"
+ OpName %Dead "Dead"
+ OpName %11 "Constant"
+ OpName %12 "ExportedFunc"
+ OpName %13 "Constant2"
+ OpDecorate %12 LinkageAttributes "ExportedFunc" Export
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+ %10 = OpFunction %void None %6
+ %14 = OpLabel
+ %15 = OpFunctionCall %void %11
+ %16 = OpFunctionCall %void %11
+ OpReturn
+ OpFunctionEnd
+ %11 = OpFunction %void None %6
+ %18 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %Dead = OpFunction %void None %6
+ %19 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %12 = OpFunction %void None %6
+ %20 = OpLabel
+ %21 = OpFunctionCall %void %13
+ OpReturn
+ OpFunctionEnd
+ %13 = OpFunction %void None %6
+ %22 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<IRContext> localContext =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ std::vector<uint32_t> processed;
+ Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
+ processed.push_back(fp->result_id());
+ return false;
+ };
+ localContext->ProcessReachableCallTree(mark_visited);
+ EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13));
+}
+
+TEST_F(IRContextTest, BasicVisitOnlyOnce) {
+ // Make sure we visit %12 only once, even though it is called from two
+ // different functions (%10 and %11).
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %10 "main"
+ OpName %10 "main"
+ OpName %Dead "Dead"
+ OpName %11 "Constant"
+ OpName %12 "ExportedFunc"
+ OpDecorate %12 LinkageAttributes "ExportedFunc" Export
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+ %10 = OpFunction %void None %6
+ %14 = OpLabel
+ %15 = OpFunctionCall %void %11
+ %16 = OpFunctionCall %void %12
+ OpReturn
+ OpFunctionEnd
+ %11 = OpFunction %void None %6
+ %18 = OpLabel
+ %19 = OpFunctionCall %void %12
+ OpReturn
+ OpFunctionEnd
+ %Dead = OpFunction %void None %6
+ %20 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %12 = OpFunction %void None %6
+ %21 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<IRContext> localContext =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ std::vector<uint32_t> processed;
+ Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
+ processed.push_back(fp->result_id());
+ return false;
+ };
+ localContext->ProcessReachableCallTree(mark_visited);
+ EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12));
+}
+
+TEST_F(IRContextTest, BasicDontVisitExportedVariable) {
+ // Make sure only functions are visited: the exported variable %12 is skipped.
+ const std::string text = R"(
+ OpCapability Shader
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %10 "main"
+ OpExecutionMode %10 OriginUpperLeft
+ OpSource GLSL 150
+ OpName %10 "main"
+ OpName %12 "export_var"
+ OpDecorate %12 LinkageAttributes "export_var" Export
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %float_1 = OpConstant %float 1
+ %12 = OpVariable %float Output
+ %10 = OpFunction %void None %6
+ %14 = OpLabel
+ OpStore %12 %float_1
+ OpReturn
+ OpFunctionEnd
+)";
+ // clang-format on
+
+ std::unique_ptr<IRContext> localContext =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
+ << text << std::endl;
+
+ std::vector<uint32_t> processed;
+ Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
+ processed.push_back(fp->result_id());
+ return false;
+ };
+ localContext->ProcessReachableCallTree(mark_visited);
+ EXPECT_THAT(processed, UnorderedElementsAre(10));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/pass_test.cpp b/test/opt/pass_test.cpp
deleted file mode 100644
index bce05b6..0000000
--- a/test/opt/pass_test.cpp
+++ /dev/null
@@ -1,242 +0,0 @@
-// Copyright (c) 2017 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "gmock/gmock.h"
-#include "source/opt/pass.h"
-#include "test/opt/assembly_builder.h"
-#include "test/opt/pass_fixture.h"
-#include "test/opt/pass_utils.h"
-
-namespace spvtools {
-namespace opt {
-namespace {
-
-class DummyPass : public Pass {
- public:
- const char* name() const override { return "dummy-pass"; }
- Status Process() override { return Status::SuccessWithoutChange; }
-};
-
-using ::testing::UnorderedElementsAre;
-using PassClassTest = PassTest<::testing::Test>;
-
-TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
- // Make sure we visit the entry point, and the function it calls.
- // Do not visit Dead or Exported.
- const std::string text = R"(
- OpCapability Shader
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %10 "main"
- OpName %10 "main"
- OpName %Dead "Dead"
- OpName %11 "Constant"
- OpName %ExportedFunc "ExportedFunc"
- OpDecorate %ExportedFunc LinkageAttributes "ExportedFunc" Export
- %void = OpTypeVoid
- %6 = OpTypeFunction %void
- %10 = OpFunction %void None %6
- %14 = OpLabel
- %15 = OpFunctionCall %void %11
- %16 = OpFunctionCall %void %11
- OpReturn
- OpFunctionEnd
- %11 = OpFunction %void None %6
- %18 = OpLabel
- OpReturn
- OpFunctionEnd
- %Dead = OpFunction %void None %6
- %19 = OpLabel
- OpReturn
- OpFunctionEnd
-%ExportedFunc = OpFunction %void None %7
- %20 = OpLabel
- %21 = OpFunctionCall %void %11
- OpReturn
- OpFunctionEnd
-)";
- // clang-format on
-
- std::unique_ptr<IRContext> localContext =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
- << text << std::endl;
- DummyPass testPass;
- std::vector<uint32_t> processed;
- Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
- processed.push_back(fp->result_id());
- return false;
- };
- testPass.ProcessEntryPointCallTree(mark_visited, localContext->module());
- EXPECT_THAT(processed, UnorderedElementsAre(10, 11));
-}
-
-TEST_F(PassClassTest, BasicVisitReachable) {
- // Make sure we visit the entry point, exported function, and the function
- // they call. Do not visit Dead.
- const std::string text = R"(
- OpCapability Shader
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %10 "main"
- OpName %10 "main"
- OpName %Dead "Dead"
- OpName %11 "Constant"
- OpName %12 "ExportedFunc"
- OpName %13 "Constant2"
- OpDecorate %12 LinkageAttributes "ExportedFunc" Export
- %void = OpTypeVoid
- %6 = OpTypeFunction %void
- %10 = OpFunction %void None %6
- %14 = OpLabel
- %15 = OpFunctionCall %void %11
- %16 = OpFunctionCall %void %11
- OpReturn
- OpFunctionEnd
- %11 = OpFunction %void None %6
- %18 = OpLabel
- OpReturn
- OpFunctionEnd
- %Dead = OpFunction %void None %6
- %19 = OpLabel
- OpReturn
- OpFunctionEnd
- %12 = OpFunction %void None %9
- %20 = OpLabel
- %21 = OpFunctionCall %void %13
- OpReturn
- OpFunctionEnd
- %13 = OpFunction %void None %6
- %22 = OpLabel
- OpReturn
- OpFunctionEnd
-)";
- // clang-format on
-
- std::unique_ptr<IRContext> localContext =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
- << text << std::endl;
-
- DummyPass testPass;
- std::vector<uint32_t> processed;
- Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
- processed.push_back(fp->result_id());
- return false;
- };
- testPass.ProcessReachableCallTree(mark_visited, localContext.get());
- EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12, 13));
-}
-
-TEST_F(PassClassTest, BasicVisitOnlyOnce) {
- // Make sure we visit %11 only once, even if it is called from two different
- // functions.
- const std::string text = R"(
- OpCapability Shader
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %10 "main" %gl_FragColor
- OpName %10 "main"
- OpName %Dead "Dead"
- OpName %11 "Constant"
- OpName %12 "ExportedFunc"
- OpDecorate %12 LinkageAttributes "ExportedFunc" Export
- %void = OpTypeVoid
- %6 = OpTypeFunction %void
- %10 = OpFunction %void None %6
- %14 = OpLabel
- %15 = OpFunctionCall %void %11
- %16 = OpFunctionCall %void %12
- OpReturn
- OpFunctionEnd
- %11 = OpFunction %void None %6
- %18 = OpLabel
- %19 = OpFunctionCall %void %12
- OpReturn
- OpFunctionEnd
- %Dead = OpFunction %void None %6
- %20 = OpLabel
- OpReturn
- OpFunctionEnd
- %12 = OpFunction %void None %9
- %21 = OpLabel
- OpReturn
- OpFunctionEnd
-)";
- // clang-format on
-
- std::unique_ptr<IRContext> localContext =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
- << text << std::endl;
-
- DummyPass testPass;
- std::vector<uint32_t> processed;
- Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
- processed.push_back(fp->result_id());
- return false;
- };
- testPass.ProcessReachableCallTree(mark_visited, localContext.get());
- EXPECT_THAT(processed, UnorderedElementsAre(10, 11, 12));
-}
-
-TEST_F(PassClassTest, BasicDontVisitExportedVariable) {
- // Make sure we only visit functions and not exported variables.
- const std::string text = R"(
- OpCapability Shader
- OpMemoryModel Logical GLSL450
- OpEntryPoint Fragment %10 "main" %gl_FragColor
- OpExecutionMode %10 OriginUpperLeft
- OpSource GLSL 150
- OpName %10 "main"
- OpName %Dead "Dead"
- OpName %11 "Constant"
- OpName %12 "export_var"
- OpDecorate %12 LinkageAttributes "export_var" Export
- %void = OpTypeVoid
- %6 = OpTypeFunction %void
- %float = OpTypeFloat 32
- %float_1 = OpConstant %float 1
- %12 = OpVariable %float Output
- %10 = OpFunction %void None %6
- %14 = OpLabel
- OpStore %12 %float_1
- OpReturn
- OpFunctionEnd
-)";
- // clang-format on
-
- std::unique_ptr<IRContext> localContext =
- BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
- SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
- EXPECT_NE(nullptr, localContext) << "Assembling failed for shader:\n"
- << text << std::endl;
-
- DummyPass testPass;
- std::vector<uint32_t> processed;
- Pass::ProcessFunction mark_visited = [&processed](Function* fp) {
- processed.push_back(fp->result_id());
- return false;
- };
- testPass.ProcessReachableCallTree(mark_visited, localContext.get());
- EXPECT_THAT(processed, UnorderedElementsAre(10));
-}
-
-} // namespace
-} // namespace opt
-} // namespace spvtools