Improve non-semantic instruction handling in the optimizer (#3693)

* No longer blindly add global non-semantic info instructions to global
  types and values
  * functions now have a list of non-semantic instructions that succeed
    them in the global scope
  * global non-semantic instructions go in global types and values if
    they appear before any function, otherwise they are attached to the
    immediate function predecessor in the module
* changed ADCE to use the function removal utility
* Modified EliminateFunction to have special handling for non-semantic
  instructions in the global scope
  * non-semantic instructions are moved to an earlier function (or full
    global set) if the function they are attached to is eliminated
  * Added IRContext::KillNonSemanticInfo to remove the tree of
    non-semantic instructions that use an instruction
  * this is used in function elimination
* There is still significant work in the optimizer to handle
  non-semantic instructions fully in the optimizer
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 71bbed1..c8688a3 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -22,6 +22,7 @@
 
 #include "source/cfa.h"
 #include "source/latest_version_glsl_std_450_header.h"
+#include "source/opt/eliminate_dead_functions_util.h"
 #include "source/opt/iterator.h"
 #include "source/opt/reflect.h"
 #include "source/spirv_constant.h"
@@ -727,8 +728,8 @@
        funcIter != get_module()->end();) {
     if (live_function_set.count(&*funcIter) == 0) {
       modified = true;
-      EliminateFunction(&*funcIter);
-      funcIter = funcIter.Erase();
+      funcIter =
+          eliminatedeadfunctionsutil::EliminateFunction(context(), &funcIter);
     } else {
       ++funcIter;
     }
@@ -737,12 +738,6 @@
   return modified;
 }
 
-void AggressiveDCEPass::EliminateFunction(Function* func) {
-  // Remove all of the instruction in the function body
-  func->ForEachInst([this](Instruction* inst) { context()->KillInst(inst); },
-                    true);
-}
-
 bool AggressiveDCEPass::ProcessGlobalValues() {
   // Remove debug and annotation statements referencing dead instructions.
   // This must be done before killing the instructions, otherwise there are
diff --git a/source/opt/aggressive_dead_code_elim_pass.h b/source/opt/aggressive_dead_code_elim_pass.h
index 2ce5b57..f02e729 100644
--- a/source/opt/aggressive_dead_code_elim_pass.h
+++ b/source/opt/aggressive_dead_code_elim_pass.h
@@ -127,9 +127,6 @@
   // Erases functions that are unreachable from the entry points of the module.
   bool EliminateDeadFunctions();
 
-  // Removes |func| from the module and deletes all its instructions.
-  void EliminateFunction(Function* func);
-
   // For function |func|, mark all Stores to non-function-scope variables
   // and block terminating instructions as live. Recursively mark the values
   // they use. When complete, mark any non-live instructions to be deleted.
diff --git a/source/opt/eliminate_dead_functions_util.cpp b/source/opt/eliminate_dead_functions_util.cpp
index 8a38959..6b5234b 100644
--- a/source/opt/eliminate_dead_functions_util.cpp
+++ b/source/opt/eliminate_dead_functions_util.cpp
@@ -21,9 +21,35 @@
 
 Module::iterator EliminateFunction(IRContext* context,
                                    Module::iterator* func_iter) {
+  bool first_func = *func_iter == context->module()->begin();
+  bool seen_func_end = false;
   (*func_iter)
-      ->ForEachInst([context](Instruction* inst) { context->KillInst(inst); },
-                    true);
+      ->ForEachInst(
+          [context, first_func, func_iter, &seen_func_end](Instruction* inst) {
+            if (inst->opcode() == SpvOpFunctionEnd) {
+              seen_func_end = true;
+            }
+            // Move non-semantic instructions to the previous function or
+            // global values if this is the first function.
+            if (seen_func_end && inst->opcode() == SpvOpExtInst) {
+              assert(inst->IsNonSemanticInstruction());
+              std::unique_ptr<Instruction> clone(inst->Clone(context));
+              context->ForgetUses(inst);
+              context->AnalyzeDefUse(clone.get());
+              if (first_func) {
+                context->AddGlobalValue(std::move(clone));
+              } else {
+                auto prev_func_iter = *func_iter;
+                --prev_func_iter;
+                prev_func_iter->AddNonSemanticInstruction(std::move(clone));
+              }
+              inst->ToNop();
+            } else {
+              context->KillNonSemanticInfo(inst);
+              context->KillInst(inst);
+            }
+          },
+          true, true);
   return func_iter->Erase();
 }
 
diff --git a/source/opt/function.cpp b/source/opt/function.cpp
index 320f8ca..21ce0c6 100644
--- a/source/opt/function.cpp
+++ b/source/opt/function.cpp
@@ -47,31 +47,40 @@
   }
 
   clone->SetFunctionEnd(std::unique_ptr<Instruction>(EndInst()->Clone(ctx)));
+
+  clone->non_semantic_.reserve(non_semantic_.size());
+  for (auto& non_semantic : non_semantic_) {
+    clone->AddNonSemanticInstruction(
+        std::unique_ptr<Instruction>(non_semantic->Clone(ctx)));
+  }
   return clone;
 }
 
 void Function::ForEachInst(const std::function<void(Instruction*)>& f,
-                           bool run_on_debug_line_insts) {
+                           bool run_on_debug_line_insts,
+                           bool run_on_non_semantic_insts) {
   WhileEachInst(
       [&f](Instruction* inst) {
         f(inst);
         return true;
       },
-      run_on_debug_line_insts);
+      run_on_debug_line_insts, run_on_non_semantic_insts);
 }
 
 void Function::ForEachInst(const std::function<void(const Instruction*)>& f,
-                           bool run_on_debug_line_insts) const {
+                           bool run_on_debug_line_insts,
+                           bool run_on_non_semantic_insts) const {
   WhileEachInst(
       [&f](const Instruction* inst) {
         f(inst);
         return true;
       },
-      run_on_debug_line_insts);
+      run_on_debug_line_insts, run_on_non_semantic_insts);
 }
 
 bool Function::WhileEachInst(const std::function<bool(Instruction*)>& f,
-                             bool run_on_debug_line_insts) {
+                             bool run_on_debug_line_insts,
+                             bool run_on_non_semantic_insts) {
   if (def_inst_) {
     if (!def_inst_->WhileEachInst(f, run_on_debug_line_insts)) {
       return false;
@@ -99,13 +108,26 @@
     }
   }
 
-  if (end_inst_) return end_inst_->WhileEachInst(f, run_on_debug_line_insts);
+  if (end_inst_) {
+    if (!end_inst_->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  if (run_on_non_semantic_insts) {
+    for (auto& non_semantic : non_semantic_) {
+      if (!non_semantic->WhileEachInst(f, run_on_debug_line_insts)) {
+        return false;
+      }
+    }
+  }
 
   return true;
 }
 
 bool Function::WhileEachInst(const std::function<bool(const Instruction*)>& f,
-                             bool run_on_debug_line_insts) const {
+                             bool run_on_debug_line_insts,
+                             bool run_on_non_semantic_insts) const {
   if (def_inst_) {
     if (!static_cast<const Instruction*>(def_inst_.get())
              ->WhileEachInst(f, run_on_debug_line_insts)) {
@@ -133,9 +155,21 @@
     }
   }
 
-  if (end_inst_)
-    return static_cast<const Instruction*>(end_inst_.get())
-        ->WhileEachInst(f, run_on_debug_line_insts);
+  if (end_inst_) {
+    if (!static_cast<const Instruction*>(end_inst_.get())
+             ->WhileEachInst(f, run_on_debug_line_insts)) {
+      return false;
+    }
+  }
+
+  if (run_on_non_semantic_insts) {
+    for (auto& non_semantic : non_semantic_) {
+      if (!static_cast<const Instruction*>(non_semantic.get())
+               ->WhileEachInst(f, run_on_debug_line_insts)) {
+        return false;
+      }
+    }
+  }
 
   return true;
 }
diff --git a/source/opt/function.h b/source/opt/function.h
index f5035f0..1d11a09 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -79,6 +79,11 @@
   // Saves the given function end instruction.
   inline void SetFunctionEnd(std::unique_ptr<Instruction> end_inst);
 
+  // Add a non-semantic instruction that succeeds this function in the module.
+  // These instructions are maintained in the order they are added.
+  inline void AddNonSemanticInstruction(
+      std::unique_ptr<Instruction> non_semantic);
+
   // Returns the given function end instruction.
   inline Instruction* EndInst() { return end_inst_.get(); }
   inline const Instruction* EndInst() const { return end_inst_.get(); }
@@ -115,19 +120,24 @@
   }
 
   // Runs the given function |f| on instructions in this function, in order,
-  // and optionally on debug line instructions that might precede them.
+  // and optionally on debug line instructions that might precede them and
+  // non-semantic instructions that succceed the function.
   void ForEachInst(const std::function<void(Instruction*)>& f,
-                   bool run_on_debug_line_insts = false);
+                   bool run_on_debug_line_insts = false,
+                   bool run_on_non_semantic_insts = false);
   void ForEachInst(const std::function<void(const Instruction*)>& f,
-                   bool run_on_debug_line_insts = false) const;
+                   bool run_on_debug_line_insts = false,
+                   bool run_on_non_semantic_insts = false) const;
   // Runs the given function |f| on instructions in this function, in order,
-  // and optionally on debug line instructions that might precede them.
-  // If |f| returns false, iteration is terminated and this function returns
-  // false.
+  // and optionally on debug line instructions that might precede them and
+  // non-semantic instructions that succeed the function.  If |f| returns
+  // false, iteration is terminated and this function returns false.
   bool WhileEachInst(const std::function<bool(Instruction*)>& f,
-                     bool run_on_debug_line_insts = false);
+                     bool run_on_debug_line_insts = false,
+                     bool run_on_non_semantic_insts = false);
   bool WhileEachInst(const std::function<bool(const Instruction*)>& f,
-                     bool run_on_debug_line_insts = false) const;
+                     bool run_on_debug_line_insts = false,
+                     bool run_on_non_semantic_insts = false) const;
 
   // Runs the given function |f| on each parameter instruction in this function,
   // in order, and optionally on debug line instructions that might precede
@@ -172,6 +182,8 @@
   std::vector<std::unique_ptr<BasicBlock>> blocks_;
   // The OpFunctionEnd instruction.
   std::unique_ptr<Instruction> end_inst_;
+  // Non-semantic instructions succeeded by this function.
+  std::vector<std::unique_ptr<Instruction>> non_semantic_;
 };
 
 // Pretty-prints |func| to |str|. Returns |str|.
@@ -235,6 +247,11 @@
   end_inst_ = std::move(end_inst);
 }
 
+inline void Function::AddNonSemanticInstruction(
+    std::unique_ptr<Instruction> non_semantic) {
+  non_semantic_.emplace_back(std::move(non_semantic));
+}
+
 }  // namespace opt
 }  // namespace spvtools
 
diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp
index 9335231..8cc3d79 100644
--- a/source/opt/instruction.cpp
+++ b/source/opt/instruction.cpp
@@ -892,6 +892,16 @@
   }
 }
 
+bool Instruction::IsNonSemanticInstruction() const {
+  if (!HasResultId()) return false;
+  if (opcode() != SpvOpExtInst) return false;
+
+  auto import_inst =
+      context()->get_def_use_mgr()->GetDef(GetSingleWordInOperand(0));
+  std::string import_name = import_inst->GetInOperand(0).AsString();
+  return import_name.find("NonSemantic.") == 0;
+}
+
 void DebugScope::ToBinary(uint32_t type_id, uint32_t result_id,
                           uint32_t ext_set,
                           std::vector<uint32_t>* binary) const {
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index 924f044..067f69c 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -549,6 +549,9 @@
     return GetOpenCL100DebugOpcode() != OpenCLDebugInfo100InstructionsMax;
   }
 
+  // Returns true if this instructions a non-semantic instruction.
+  bool IsNonSemanticInstruction() const;
+
   // Dump this instruction on stderr.  Useful when running interactive
   // debuggers.
   void Dump() const;
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index a56ff06..f147b0b 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -213,6 +213,30 @@
   return next_instruction;
 }
 
+void IRContext::KillNonSemanticInfo(Instruction* inst) {
+  if (!inst->HasResultId()) return;
+  std::vector<Instruction*> work_list;
+  std::vector<Instruction*> to_kill;
+  std::unordered_set<Instruction*> seen;
+  work_list.push_back(inst);
+
+  while (!work_list.empty()) {
+    auto* i = work_list.back();
+    work_list.pop_back();
+    get_def_use_mgr()->ForEachUser(
+        i, [&work_list, &to_kill, &seen](Instruction* user) {
+          if (user->IsNonSemanticInstruction() && seen.insert(user).second) {
+            work_list.push_back(user);
+            to_kill.push_back(user);
+          }
+        });
+  }
+
+  for (auto* dead : to_kill) {
+    KillInst(dead);
+  }
+}
+
 bool IRContext::KillDef(uint32_t id) {
   Instruction* def = get_def_use_mgr()->GetDef(id);
   if (def != nullptr) {
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index b193657..8c1b5d4 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -403,6 +403,9 @@
   // instruction exists.
   Instruction* KillInst(Instruction* inst);
 
+  // Removes the non-semantic instruction tree that uses |inst|'s result id.
+  void KillNonSemanticInfo(Instruction* inst);
+
   // Returns true if all of the given analyses are valid.
   bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }
 
diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp
index acd41cd..a10812e 100644
--- a/source/opt/ir_loader.cpp
+++ b/source/opt/ir_loader.cpp
@@ -167,13 +167,22 @@
       } else if (IsTypeInst(opcode)) {
         module_->AddType(std::move(spv_inst));
       } else if (IsConstantInst(opcode) || opcode == SpvOpVariable ||
-                 opcode == SpvOpUndef ||
-                 (opcode == SpvOpExtInst &&
-                  spvExtInstIsNonSemantic(inst->ext_inst_type))) {
+                 opcode == SpvOpUndef) {
         module_->AddGlobalValue(std::move(spv_inst));
       } else if (opcode == SpvOpExtInst &&
                  spvExtInstIsDebugInfo(inst->ext_inst_type)) {
         module_->AddExtInstDebugInfo(std::move(spv_inst));
+      } else if (opcode == SpvOpExtInst &&
+                 spvExtInstIsNonSemantic(inst->ext_inst_type)) {
+        // If there are no functions, add the non-semantic instructions to the
+        // global values. Otherwise append it to the list of the last function.
+        auto func_begin = module_->begin();
+        auto func_end = module_->end();
+        if (func_begin == func_end) {
+          module_->AddGlobalValue(std::move(spv_inst));
+        } else {
+          (--func_end)->AddNonSemanticInstruction(std::move(spv_inst));
+        }
       } else {
         Errorf(consumer_, src, loc,
                "Unhandled inst type (opcode: %d) found outside function "
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index 2959d3d..6707631 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -98,7 +98,10 @@
   DELEGATE(ext_inst_debuginfo_);
   DELEGATE(annotations_);
   DELEGATE(types_values_);
-  for (auto& i : functions_) i->ForEachInst(f, run_on_debug_line_insts);
+  for (auto& i : functions_) {
+    i->ForEachInst(f, run_on_debug_line_insts,
+                   /* run_on_non_semantic_insts = */ true);
+  }
 #undef DELEGATE
 }
 
@@ -120,8 +123,9 @@
   for (auto& i : types_values_) DELEGATE(i);
   for (auto& i : ext_inst_debuginfo_) DELEGATE(i);
   for (auto& i : functions_) {
-    static_cast<const Function*>(i.get())->ForEachInst(f,
-                                                       run_on_debug_line_insts);
+    static_cast<const Function*>(i.get())->ForEachInst(
+        f, run_on_debug_line_insts,
+        /* run_on_non_semantic_insts = */ true);
   }
   if (run_on_debug_line_insts) {
     for (auto& i : trailing_dbg_line_info_) DELEGATE(i);
diff --git a/test/opt/eliminate_dead_functions_test.cpp b/test/opt/eliminate_dead_functions_test.cpp
index 2f8fa9a..96ecdc6 100644
--- a/test/opt/eliminate_dead_functions_test.cpp
+++ b/test/opt/eliminate_dead_functions_test.cpp
@@ -344,6 +344,101 @@
   SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, false);
 }
 
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoPersists) {
+  const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK: OpExtInst [[void]] [[import]] 2
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2
+)";
+
+  SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoRemoveDependent) {
+  const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK-NOT: OpExtInst [[void]] [[import]] 2
+; CHECK: OpExtInst [[void]] [[import]] 3
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2 %foo
+%non_semantic3 = OpExtInst %void %ext 3 
+)";
+
+  SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
+TEST_F(EliminateDeadFunctionsBasicTest, NonSemanticInfoRemoveDependentTree) {
+  const std::string text = R"(
+; CHECK: [[import:%\w+]] = OpExtInstImport
+; CHECK: [[void:%\w+]] = OpTypeVoid
+; CHECK-NOT: OpExtInst [[void]] [[import]] 1
+; CHECK-NOT: OpExtInst [[void]] [[import]] 2
+; CHECK: OpExtInst [[void]] [[import]] 3
+; CHECK-NOT: OpExtInst [[void]] [[import]] 4
+; CHECK-NOT: OpExtInst [[void]] [[import]] 5
+OpCapability Shader
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+OpEntryPoint GLCompute %main "main"
+OpExecutionMode %main LocalSize 1 1 1
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%main = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%foo = OpFunction %void None %void_fn
+%foo_entry = OpLabel
+%non_semantic1 = OpExtInst %void %ext 1
+OpReturn
+OpFunctionEnd
+%non_semantic2 = OpExtInst %void %ext 2 %foo
+%non_semantic3 = OpExtInst %void %ext 3 
+%non_semantic4 = OpExtInst %void %ext 4 %non_semantic2
+%non_semantic5 = OpExtInst %void %ext 5 %non_semantic4
+)";
+
+  SinglePassRunAndMatch<EliminateDeadFunctionsPass>(text, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools
diff --git a/test/opt/function_test.cpp b/test/opt/function_test.cpp
index 38ab298..b67ca49 100644
--- a/test/opt/function_test.cpp
+++ b/test/opt/function_test.cpp
@@ -168,6 +168,80 @@
   EXPECT_FALSE(func->IsRecursive());
 }
 
+TEST(FunctionTest, NonSemanticInfoSkipIteration) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%6 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%7 = OpExtInst %2 %1 2
+%8 = OpExtInst %2 %1 3
+)";
+
+  std::unique_ptr<IRContext> ctx =
+      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  auto* func = spvtest::GetFunction(ctx->module(), 4);
+  ASSERT_TRUE(func != nullptr);
+  std::unordered_set<uint32_t> non_semantic_ids;
+  func->ForEachInst(
+      [&non_semantic_ids](const Instruction* inst) {
+        if (inst->opcode() == SpvOpExtInst) {
+          non_semantic_ids.insert(inst->result_id());
+        }
+      },
+      true, false);
+
+  EXPECT_EQ(1, non_semantic_ids.count(6));
+  EXPECT_EQ(0, non_semantic_ids.count(7));
+  EXPECT_EQ(0, non_semantic_ids.count(8));
+}
+
+TEST(FunctionTest, NonSemanticInfoIncludeIteration) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpFunction %2 None %3
+%5 = OpLabel
+%6 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%7 = OpExtInst %2 %1 2
+%8 = OpExtInst %2 %1 3
+)";
+
+  std::unique_ptr<IRContext> ctx =
+      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  auto* func = spvtest::GetFunction(ctx->module(), 4);
+  ASSERT_TRUE(func != nullptr);
+  std::unordered_set<uint32_t> non_semantic_ids;
+  func->ForEachInst(
+      [&non_semantic_ids](const Instruction* inst) {
+        if (inst->opcode() == SpvOpExtInst) {
+          non_semantic_ids.insert(inst->result_id());
+        }
+      },
+      true, true);
+
+  EXPECT_EQ(1, non_semantic_ids.count(6));
+  EXPECT_EQ(1, non_semantic_ids.count(7));
+  EXPECT_EQ(1, non_semantic_ids.count(8));
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools
diff --git a/test/opt/module_test.cpp b/test/opt/module_test.cpp
index 406da09..a3c2eed 100644
--- a/test/opt/module_test.cpp
+++ b/test/opt/module_test.cpp
@@ -295,6 +295,47 @@
 
   AssembleAndDisassemble(text);
 }
+
+TEST(ModuleTest, NonSemanticInfoIteration) {
+  const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%1 = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpExtInst %2 %1 1
+%5 = OpFunction %2 None %3
+%6 = OpLabel
+%7 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%8 = OpExtInst %2 %1 1
+%9 = OpFunction %2 None %3
+%10 = OpLabel
+%11 = OpExtInst %2 %1 1
+OpReturn
+OpFunctionEnd
+%12 = OpExtInst %2 %1 1
+)";
+
+  std::unique_ptr<IRContext> context = BuildModule(text);
+  std::unordered_set<uint32_t> non_semantic_ids;
+  context->module()->ForEachInst(
+      [&non_semantic_ids](const Instruction* inst) {
+        if (inst->opcode() == SpvOpExtInst) {
+          non_semantic_ids.insert(inst->result_id());
+        }
+      },
+      false);
+
+  EXPECT_EQ(1, non_semantic_ids.count(4));
+  EXPECT_EQ(1, non_semantic_ids.count(7));
+  EXPECT_EQ(1, non_semantic_ids.count(8));
+  EXPECT_EQ(1, non_semantic_ids.count(11));
+  EXPECT_EQ(1, non_semantic_ids.count(12));
+}
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools
diff --git a/test/opt/strip_reflect_info_test.cpp b/test/opt/strip_reflect_info_test.cpp
index 5db34b7..f3fc115 100644
--- a/test/opt/strip_reflect_info_test.cpp
+++ b/test/opt/strip_reflect_info_test.cpp
@@ -25,6 +25,7 @@
 namespace {
 
 using StripLineReflectInfoTest = PassTest<::testing::Test>;
+using StripNonSemanticInfoTest = PassTest<::testing::Test>;
 
 // This test acts as an end-to-end code example on how to strip
 // reflection info from a SPIR-V module.  Use this code pattern
@@ -132,6 +133,99 @@
   SinglePassRunAndCheck<StripReflectInfoPass>(before, after, false);
 }
 
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticImport) {
+  std::string text = R"(
+; CHECK-NOT: OpExtension "SPV_KHR_non_semantic_info"
+; CHECK-NOT: OpExtInstImport
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+)";
+
+  SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticGlobal) {
+  std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%1 = OpExtInst %void %ext 1
+)";
+
+  SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticInFunction) {
+  std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+%1 = OpExtInst %void %ext 1 %foo
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticAfterFunction) {
+  std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%1 = OpExtInst %void %ext 1 %foo
+)";
+
+  SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
+TEST_F(StripNonSemanticInfoTest, StripNonSemanticBetweenFunctions) {
+  std::string text = R"(
+; CHECK-NOT: OpExtInst
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_non_semantic_info"
+%ext = OpExtInstImport "NonSemantic.Test"
+OpMemoryModel Logical GLSL450
+%void = OpTypeVoid
+%void_fn = OpTypeFunction %void
+%foo = OpFunction %void None %void_fn
+%entry = OpLabel
+OpReturn
+OpFunctionEnd
+%1 = OpExtInst %void %ext 1 %foo
+%bar = OpFunction %void None %void_fn
+%bar_entry = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<StripReflectInfoPass>(text, true);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools