Improve non-semantic instruction handling in the optimizer (#3693)
* No longer blindly add global non-semantic info instructions to global
types and values
* functions now have a list of non-semantic instructions that succeed
them in the global scope
* global non-semantic instructions go in global types and values if
they appear before any function, otherwise they are attached to the
immediate function predecessor in the module
* changed ADCE to use the function removal utility
* Modified EliminateFunction to have special handling for non-semantic
instructions in the global scope
* non-semantic instructions are moved to an earlier function (or full
global set) if the function they are attached to is eliminated
* Added IRContext::KillNonSemanticInfo to remove the tree of
non-semantic instructions that use an instruction
* this is used in function elimination
* There is still significant work in the optimizer to handle
non-semantic instructions fully in the optimizer
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 71bbed1..c8688a3 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -22,6 +22,7 @@
#include "source/cfa.h"
#include "source/latest_version_glsl_std_450_header.h"
+#include "source/opt/eliminate_dead_functions_util.h"
#include "source/opt/iterator.h"
#include "source/opt/reflect.h"
#include "source/spirv_constant.h"
@@ -727,8 +728,8 @@
funcIter != get_module()->end();) {
if (live_function_set.count(&*funcIter) == 0) {
modified = true;
- EliminateFunction(&*funcIter);
- funcIter = funcIter.Erase();
+ funcIter =
+ eliminatedeadfunctionsutil::EliminateFunction(context(), &funcIter);
} else {
++funcIter;
}
@@ -737,12 +738,6 @@
return modified;
}
-void AggressiveDCEPass::EliminateFunction(Function* func) {
- // Remove all of the instruction in the function body
- func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); },
- true);
-}
-
bool AggressiveDCEPass::ProcessGlobalValues() {
// Remove debug and annotation statements referencing dead instructions.
// This must be done before killing the instructions, otherwise there are
diff --git a/source/opt/aggressive_dead_code_elim_pass.h b/source/opt/aggressive_dead_code_elim_pass.h
index 2ce5b57..f02e729 100644
--- a/source/opt/aggressive_dead_code_elim_pass.h
+++ b/source/opt/aggressive_dead_code_elim_pass.h
@@ -127,9 +127,6 @@
// Erases functions that are unreachable from the entry points of the module.
bool EliminateDeadFunctions();
- // Removes |func| from the module and deletes all its instructions.
- void EliminateFunction(Function* func);
-
// For function |func|, mark all Stores to non-function-scope variables
// and block terminating instructions as live. Recursively mark the values
// they use. When complete, mark any non-live instructions to be deleted.
diff --git a/source/opt/eliminate_dead_functions_util.cpp b/source/opt/eliminate_dead_functions_util.cpp
index 8a38959..6b5234b 100644
--- a/source/opt/eliminate_dead_functions_util.cpp
+++ b/source/opt/eliminate_dead_functions_util.cpp
@@ -21,9 +21,35 @@
Module::iterator EliminateFunction(IRContext* context,
Module::iterator* func_iter) {
+ bool first_func = *func_iter == context->module()->begin();
+ bool seen_func_end = false;
(*func_iter)
- ->ForEachInst([context](Instruction* inst) { context->KillInst(inst); },
- true);
+ ->ForEachInst(
+ [context, first_func, func_iter, &seen_func_end](Instruction* inst) {
+ if (inst->opcode() == SpvOpFunctionEnd) {
+ seen_func_end = true;
+ }
+ // Move non-semantic instructions to the previous function or
+ // global values if this is the first function.
+ if (seen_func_end && inst->opcode() == SpvOpExtInst) {
+ assert(inst->IsNonSemanticInstruction());
+ std::unique_ptr<Instruction> clone(inst->Clone(context));
+ context->ForgetUses(inst);
+ context->AnalyzeDefUse(clone.get());
+ if (first_func) {
+ context->AddGlobalValue(std::move(clone));
+ } else {
+ auto prev_func_iter = *func_iter;
+ --prev_func_iter;
+ prev_func_iter->AddNonSemanticInstruction(std::move(clone));
+ }
+ inst->ToNop();
+ } else {
+ context->KillNonSemanticInfo(inst);
+ context->KillInst(inst);
+ }
+ },
+ true, true);
return func_iter->Erase();
}
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 320f8ca..21ce0c6 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -47,31 +47,40 @@
}
clone->SetFunctionEnd(std::unique_ptr<Instruction>(EndInst()->Clone(ctx)));
+
+ clone->non_semantic_.reserve(non_semantic_.size());
+ for (auto& non_semantic : non_semantic_) {
+ clone->AddNonSemanticInstruction(
+ std::unique_ptr<Instruction>(non_semantic->Clone(ctx)));
+ }
return clone;
}
void Function::ForEachInst(const std::function<void(Instruction*)>& f,
- bool run_on_debug_line_insts) {
+ bool run_on_debug_line_insts,
+ bool run_on_non_semantic_insts) {
WhileEachInst(
[&f](Instruction* inst) {
f(inst);
return true;
},
- run_on_debug_line_insts);
+ run_on_debug_line_insts, run_on_non_semantic_insts);
}
void Function::ForEachInst(const std::function<void(const Instruction*)>& f,
- bool run_on_debug_line_insts) const {
+ bool run_on_debug_line_insts,
+ bool run_on_non_semantic_insts) const {
WhileEachInst(
[&f](const Instruction* inst) {
f(inst);
return true;
},
- run_on_debug_line_insts);
+ run_on_debug_line_insts, run_on_non_semantic_insts);
}
bool Function::WhileEachInst(const std::function<bool(Instruction*)>& f,
- bool run_on_debug_line_insts) {
+ bool run_on_debug_line_insts,
+ bool run_on_non_semantic_insts) {
if (def_inst_) {
if (!def_inst_->WhileEachInst(f, run_on_debug_line_insts)) {
return false;
@@ -99,13 +108,26 @@
}
}
- if (end_inst_) return end_inst_->WhileEachInst(f, run_on_debug_line_insts);
+ if (end_inst_) {
+ if (!end_inst_->WhileEachInst(f, run_on_debug_line_insts)) {
+ return false;
+ }
+ }
+
+ if (run_on_non_semantic_insts) {
+ for (auto& non_semantic : non_semantic_) {
+ if (!non_semantic->WhileEachInst(f, run_on_debug_line_insts)) {
+ return false;
+ }
+ }
+ }
return true;
}
bool Function::WhileEachInst(const std::function<bool(const Instruction*)>& f,
- bool run_on_debug_line_insts) const {
+ bool run_on_debug_line_insts,
+ bool run_on_non_semantic_insts) const {
if (def_inst_) {
if (!static_cast<const Instruction*>(def_inst_.get())
->WhileEachInst(f, run_on_debug_line_insts)) {
@@ -133,9 +155,21 @@
}
}
- if (end_inst_)
- return static_cast<const Instruction*>(end_inst_.get())
- ->WhileEachInst(f, run_on_debug_line_insts);
+ if (end_inst_) {
+ if (!static_cast<const Instruction*>(end_inst_.get())
+ ->WhileEachInst(f, run_on_debug_line_insts)) {
+ return false;
+ }
+ }
+
+ if (run_on_non_semantic_insts) {
+ for (auto& non_semantic : non_semantic_) {
+ if (!static_cast<const Instruction*>(non_semantic.get())
+ ->WhileEachInst(f, run_on_debug_line_insts)) {
+ return false;
+ }
+ }
+ }
return true;
}
diff --git a/source/opt/function.h b/source/opt/function.h
index f5035f0..1d11a09 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -79,6 +79,11 @@
// Saves the given function end instruction.
inline void SetFunctionEnd(std::unique_ptr<Instruction> end_inst);
+ // Add a non-semantic instruction that succeeds this function in the module.
+ // These instructions are maintained in the order they are added.
+ inline void AddNonSemanticInstruction(
+ std::unique_ptr<Instruction> non_semantic);
+
// Returns the given function end instruction.
inline Instruction* EndInst() { return end_inst_.get(); }
inline const Instruction* EndInst() const { return end_inst_.get(); }
@@ -115,19 +120,24 @@
}
// Runs the given function |f| on instructions in this function, in order,
- // and optionally on debug line instructions that might precede them.
+ // and optionally on debug line instructions that might precede them and
+ // non-semantic instructions that succceed the function.
void ForEachInst(const std::function<void(Instruction*)>& f,
- bool run_on_debug_line_insts = false);
+ bool run_on_debug_line_insts = false,
+ bool run_on_non_semantic_insts = false);
void ForEachInst(const std::function<void(const Instruction*)>& f,
- bool run_on_debug_line_insts = false) const;
+ bool run_on_debug_line_insts = false,
+ bool run_on_non_semantic_insts = false) const;
// Runs the given function |f| on instructions in this function, in order,
- // and optionally on debug line instructions that might precede them.
- // If |f| returns false, iteration is terminated and this function returns
- // false.
+ // and optionally on debug line instructions that might precede them and
+ // non-semantic instructions that succeed the function. If |f| returns
+ // false, iteration is terminated and this function returns false.
bool WhileEachInst(const std::function<bool(Instruction*)>& f,
- bool run_on_debug_line_insts = false);
+ bool run_on_debug_line_insts = false,
+ bool run_on_non_semantic_insts = false);
bool WhileEachInst(const std::function<bool(const Instruction*)>& f,
- bool run_on_debug_line_insts = false) const;
+ bool run_on_debug_line_insts = false,
+ bool run_on_non_semantic_insts = false) const;
// Runs the given function |f| on each parameter instruction in this function,
// in order, and optionally on debug line instructions that might precede
@@ -172,6 +182,8 @@
std::vector<std::unique_ptr<BasicBlock>> blocks_;
// The OpFunctionEnd instruction.
std::unique_ptr<Instruction> end_inst_;
+ // Non-semantic instructions succeeded by this function.
+ std::vector<std::unique_ptr<Instruction>> non_semantic_;
};
// Pretty-prints |func| to |str|. Returns |str|.
@@ -235,6 +247,11 @@
end_inst_ = std::move(end_inst);
}
+inline void Function::AddNonSemanticInstruction(
+ std::unique_ptr<Instruction> non_semantic) {
+ non_semantic_.emplace_back(std::move(non_semantic));
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
index 9335231..8cc3d79 100644
--- a/source/opt/instruction.cpp
+++ b/source/opt/instruction.cpp
@@ -892,6 +892,16 @@
}
}
+bool Instruction::IsNonSemanticInstruction() const {
+ if (!HasResultId()) return false;
+ if (opcode() != SpvOpExtInst) return false;
+
+ auto import_inst =
+ context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(0));
+ std::string import_name = import_inst->GetInOperand(0).AsString();
+ return import_name.find("NonSemantic.") == 0;
+}
+
void DebugScope::ToBinary(uint32_t type_id, uint32_t result_id,
uint32_t ext_set,
std::vector<uint32_t>* binary) const {
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index 924f044..067f69c 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -549,6 +549,9 @@
return GetOpenCL100DebugOpcode() != OpenCLDebugInfo100InstructionsMax;
}
+ // Returns true if this instructions a non-semantic instruction.
+ bool IsNonSemanticInstruction() const;
+
// Dump this instruction on stderr. Useful when running interactive
// debuggers.
void Dump() const;
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index a56ff06..f147b0b 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -213,6 +213,30 @@
return next_instruction;
}
+void IRContext::KillNonSemanticInfo(Instruction* inst) {
+ if (!inst->HasResultId()) return;
+ std::vector<Instruction*> work_list;
+ std::vector<Instruction*> to_kill;
+ std::unordered_set<Instruction*> seen;
+ work_list.push_back(inst);
+
+ while (!work_list.empty()) {
+ auto* i = work_list.back();
+ work_list.pop_back();
+ get_def_use_mgr()->ForEachUser(
+ i, [&work_list, &to_kill, &seen](Instruction* user) {
+ if (user->IsNonSemanticInstruction() && seen.insert(user).second) {
+ work_list.push_back(user);
+ to_kill.push_back(user);
+ }
+ });
+ }
+
+ for (auto* dead : to_kill) {
+ KillInst(dead);
+ }
+}
+
bool IRContext::KillDef(uint32_t id) {
Instruction* def = get_def_use_mgr()->GetDef(id);
if (def != nullptr) {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index b193657..8c1b5d4 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -403,6 +403,9 @@
// instruction exists.
Instruction* KillInst(Instruction* inst);
+ // Removes the non-semantic instruction tree that uses |inst|'s result id.
+ void KillNonSemanticInfo(Instruction* inst);
+
// Returns true if all of the given analyses are valid.
bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }
diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp
index acd41cd..a10812e 100644
--- a/source/opt/ir_loader.cpp
+++ b/source/opt/ir_loader.cpp
@@ -167,13 +167,22 @@
} else if (IsTypeInst(opcode)) {
module_->AddType(std::move(spv_inst));
} else if (IsConstantInst(opcode) || opcode == SpvOpVariable ||
- opcode == SpvOpUndef ||
- (opcode == SpvOpExtInst &&
- spvExtInstIsNonSemantic(inst->ext_inst_type))) {
+ opcode == SpvOpUndef) {
module_->AddGlobalValue(std::move(spv_inst));
} else if (opcode == SpvOpExtInst &&
spvExtInstIsDebugInfo(inst->ext_inst_type)) {
module_->AddExtInstDebugInfo(std::move(spv_inst));
+ } else if (opcode == SpvOpExtInst &&
+ spvExtInstIsNonSemantic(inst->ext_inst_type)) {
+ // If there are no functions, add the non-semantic instructions to the
+ // global values. Otherwise append it to the list of the last function.
+ auto func_begin = module_->begin();
+ auto func_end = module_->end();
+ if (func_begin == func_end) {
+ module_->AddGlobalValue(std::move(spv_inst));
+ } else {
+ (--func_end)->AddNonSemanticInstruction(std::move(spv_inst));
+ }
} else {
Errorf(consumer_, src, loc,
"Unhandled inst type (opcode: %d) found outside function "
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index 2959d3d..6707631 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -98,7 +98,10 @@
DELEGATE(ext_inst_debuginfo_);
DELEGATE(annotations_);
DELEGATE(types_values_);
- for (auto& i : functions_) i->ForEachInst(f, run_on_debug_line_insts);
+ for (auto& i : functions_) {
+ i->ForEachInst(f, run_on_debug_line_insts,
+ /* run_on_non_semantic_insts = */ true);
+ }
#undef DELEGATE
}
@@ -120,8 +123,9 @@
for (auto& i : types_values_) DELEGATE(i);
for (auto& i : ext_inst_debuginfo_) DELEGATE(i);
for (auto& i : functions_) {
- static_cast<const Function*>(i.get())->ForEachInst(f,
- run_on_debug_line_insts);
+ static_cast<const Function*>(i.get())->ForEachInst(
+ f, run_on_debug_line_insts,
+ /* run_on_non_semantic_insts = */ true);
}
if (run_on_debug_line_insts) {
for (auto& i : trailing_dbg_line_info_) DELEGATE(i);
diff --git a/test/opt/eliminate_dead_functions_test.cpp b/test/opt/eliminate_dead_functions_test.cpp
index 2f8fa9a..96ecdc6 100644
--- a/test/opt/eliminate_dead_functions_test.cpp
+++ b/test/opt/eliminate_dead_functions_test.cpp
@@ -344,6 +344,101 @@
SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, false);
}
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoPersists) {
+ const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK: OpExtInst [[void]] [[import]] 2
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2
+)";
+
+ SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoRemoveDependent) {
+ const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK-NOT: OpExtInst [[void]] [[import]] 2
+; CHECK: OpExtInst [[void]] [[import]] 3
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2 %foo
+%non_semantic3 = OpExtInst %void %ext 3
+)";
+
+ SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoRemoveDependentTree) {
+ const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK-NOT: OpExtInst [[void]] [[import]] 2
+; CHECK: OpExtInst [[void]] [[import]] 3
+; CHECK-NOT: OpExtInst [[void]] [[import]] 4
+; CHECK-NOT: OpExtInst [[void]] [[import]] 5
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2 %foo
+%non_semantic3 = OpExtInst %void %ext 3
+%non_semantic4 = OpExtInst %void %ext 4 %non_semantic2
+%non_semantic5 = OpExtInst %void %ext 5 %non_semantic4
+)";
+
+ SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp
index 38ab298..b67ca49 100644
--- a/test/opt/function_test.cpp
+++ b/test/opt/function_test.cpp
@@ -168,6 +168,80 @@
EXPECT_FALSE(func->IsRecursive());
}
+TEST(FunctionTest, NonSemanticInfoSkipIteration) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%6 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%7 = OpExtInst %2 %1 2
+%8 = OpExtInst %2 %1 3
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 4);
+ ASSERT_TRUE(func != nullptr);
+ std::unordered_set<uint32_t> non_semantic_ids;
+ func->ForEachInst(
+ [&non_semantic_ids](const Instruction* inst) {
+ if (inst->opcode() == SpvOpExtInst) {
+ non_semantic_ids.insert(inst->result_id());
+ }
+ },
+ true, false);
+
+ EXPECT_EQ(1, non_semantic_ids.count(6));
+ EXPECT_EQ(0, non_semantic_ids.count(7));
+ EXPECT_EQ(0, non_semantic_ids.count(8));
+}
+
+TEST(FunctionTest, NonSemanticInfoIncludeIteration) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%6 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%7 = OpExtInst %2 %1 2
+%8 = OpExtInst %2 %1 3
+)";
+
+ std::unique_ptr<IRContext> ctx =
+ spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ auto* func = spvtest::GetFunction(ctx->module(), 4);
+ ASSERT_TRUE(func != nullptr);
+ std::unordered_set<uint32_t> non_semantic_ids;
+ func->ForEachInst(
+ [&non_semantic_ids](const Instruction* inst) {
+ if (inst->opcode() == SpvOpExtInst) {
+ non_semantic_ids.insert(inst->result_id());
+ }
+ },
+ true, true);
+
+ EXPECT_EQ(1, non_semantic_ids.count(6));
+ EXPECT_EQ(1, non_semantic_ids.count(7));
+ EXPECT_EQ(1, non_semantic_ids.count(8));
+}
+
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/module_test.cpp b/test/opt/module_test.cpp
index 406da09..a3c2eed 100644
--- a/test/opt/module_test.cpp
+++ b/test/opt/module_test.cpp
@@ -295,6 +295,47 @@
AssembleAndDisassemble(text);
}
+
+TEST(ModuleTest, NonSemanticInfoIteration) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpExtInst %2 %1 1
+%5 = OpFunction %2 None %3
+%6 = OpLabel
+%7 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%8 = OpExtInst %2 %1 1
+%9 = OpFunction %2 None %3
+%10 = OpLabel
+%11 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%12 = OpExtInst %2 %1 1
+)";
+
+ std::unique_ptr<IRContext> context = BuildModule(text);
+ std::unordered_set<uint32_t> non_semantic_ids;
+ context->module()->ForEachInst(
+ [&non_semantic_ids](const Instruction* inst) {
+ if (inst->opcode() == SpvOpExtInst) {
+ non_semantic_ids.insert(inst->result_id());
+ }
+ },
+ false);
+
+ EXPECT_EQ(1, non_semantic_ids.count(4));
+ EXPECT_EQ(1, non_semantic_ids.count(7));
+ EXPECT_EQ(1, non_semantic_ids.count(8));
+ EXPECT_EQ(1, non_semantic_ids.count(11));
+ EXPECT_EQ(1, non_semantic_ids.count(12));
+}
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/strip_reflect_info_test.cpp b/test/opt/strip_reflect_info_test.cpp
index 5db34b7..f3fc115 100644
--- a/test/opt/strip_reflect_info_test.cpp
+++ b/test/opt/strip_reflect_info_test.cpp
@@ -25,6 +25,7 @@
namespace {
using StripLineReflectInfoTest = PassTest<::testing::Test>;
+using StripNonSemanticInfoTest = PassTest<::testing::Test>;
// This test acts as an end-to-end code example on how to strip
// reflection info from a SPIR-V module. Use this code pattern
@@ -132,6 +133,99 @@
SinglePassRunAndCheck<StripReflectInfoPass>(before, after, false);
}
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticImport) {
+ std::string text = R"(
+; CHECK-NOT: OpExtension "SPV_KHR_non_semantic_info"
+; CHECK-NOT: OpExtInstImport
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+)";
+
+ SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticGlobal) {
+ std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%1 = OpExtInst %void %ext 1
+)";
+
+ SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticInFunction) {
+ std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+%1 = OpExtInst %void %ext 1 %foo
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticAfterFunction) {
+ std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%1 = OpExtInst %void %ext 1 %foo
+)";
+
+ SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticBetweenFunctions) {
+ std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%1 = OpExtInst %void %ext 1 %foo
+%bar = OpFunction %void None %void_fn
+%bar_entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools