Support SPV_KHR_terminate_invocation (#3568)

Covers:
- assembler
- disassembler
- validator
- optimizer

Co-authored-by: David Neto <dneto@google.com>
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index d393495..741f947 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -876,8 +876,10 @@
 // for the first index.
 Optimizer::PassToken CreateDescriptorScalarReplacementPass();
 
-// Create a pass to replace all OpKill instruction with a function call to a
-// function that has a single OpKill.  This allows more code to be inlined.
+// Create a pass to replace each OpKill instruction with a function call to a
+// function that has a single OpKill.  Also replace each OpTerminateInvocation
+// instruction  with a function call to a function that has a single
+// OpTerminateInvocation.  This allows more code to be inlined.
 Optimizer::PassToken CreateWrapOpKillPass();
 
 // Replaces the extensions VK_AMD_shader_ballot,VK_AMD_gcn_shader, and
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 3781a8d..f93cfd3 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -446,7 +446,7 @@
 
 bool spvOpcodeIsReturnOrAbort(SpvOp opcode) {
   return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill ||
-         opcode == SpvOpUnreachable;
+         opcode == SpvOpUnreachable || opcode == SpvOpTerminateInvocation;
 }
 
 bool spvOpcodeIsBlockTerminator(SpvOp opcode) {
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 9fcfd3a..b755787 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -986,6 +986,7 @@
       "SPV_KHR_ray_tracing",
       "SPV_EXT_fragment_invocation_density",
       "SPV_EXT_physical_storage_buffer",
+      "SPV_KHR_terminate_invocation",
   });
 }
 
diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp
index da5073a..7e61506 100644
--- a/source/opt/dominator_tree.cpp
+++ b/source/opt/dominator_tree.cpp
@@ -176,7 +176,8 @@
     // The tree construction requires 1 entry point, so we add a dummy node
     // that is connected to all function exiting basic blocks.
     // An exiting basic block is a block with an OpKill, OpUnreachable,
-    // OpReturn or OpReturnValue as terminator instruction.
+    // OpReturn, OpReturnValue, or OpTerminateInvocation  as terminator
+    // instruction.
     for (BasicBlock& bb : f) {
       if (bb.hasSuccessor()) {
         BasicBlockListTy& pred_list = predecessors_[&bb];
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index cb5a126..ef94d0d 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -384,7 +384,8 @@
   for (auto callee_block_itr = calleeFn->begin();
        callee_block_itr != calleeFn->end(); ++callee_block_itr) {
     if (callee_block_itr->tail()->opcode() == SpvOpUnreachable ||
-        callee_block_itr->tail()->opcode() == SpvOpKill) {
+        callee_block_itr->tail()->opcode() == SpvOpKill ||
+        callee_block_itr->tail()->opcode() == SpvOpTerminateInvocation) {
       returnLabelId = context()->TakeNextId();
       break;
     }
@@ -738,16 +739,18 @@
   bool func_is_called_from_continue =
       funcs_called_from_continue_.count(func->result_id()) != 0;
 
-  if (func_is_called_from_continue && ContainsKill(func)) {
+  if (func_is_called_from_continue && ContainsKillOrTerminateInvocation(func)) {
     return false;
   }
 
   return true;
 }
 
-bool InlinePass::ContainsKill(Function* func) const {
-  return !func->WhileEachInst(
-      [](Instruction* inst) { return inst->opcode() != SpvOpKill; });
+bool InlinePass::ContainsKillOrTerminateInvocation(Function* func) const {
+  return !func->WhileEachInst([](Instruction* inst) {
+    const auto opcode = inst->opcode();
+    return (opcode != SpvOpKill) && (opcode != SpvOpTerminateInvocation);
+  });
 }
 
 void InlinePass::InitializeInline() {
diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h
index 202bc97..abe773a 100644
--- a/source/opt/inline_pass.h
+++ b/source/opt/inline_pass.h
@@ -139,8 +139,9 @@
   // Return true if |func| is a function that can be inlined.
   bool IsInlinableFunction(Function* func);
 
-  // Returns true if |func| contains an OpKill instruction.
-  bool ContainsKill(Function* func) const;
+  // Returns true if |func| contains an OpKill or OpTerminateInvocation
+  // instruction.
+  bool ContainsKillOrTerminateInvocation(Function* func) const;
 
   // Update phis in succeeding blocks to point to new last block
   void UpdateSucceedingPhis(
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index 05704c1..9b8c112 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -382,6 +382,7 @@
       "SPV_KHR_ray_tracing",
       "SPV_KHR_ray_query",
       "SPV_EXT_fragment_invocation_density",
+      "SPV_KHR_terminate_invocation",
   });
 }
 
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index 5757282..bd5d751 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -267,6 +267,7 @@
       "SPV_KHR_ray_query",
       "SPV_EXT_fragment_invocation_density",
       "SPV_EXT_physical_storage_buffer",
+      "SPV_KHR_terminate_invocation",
   });
 }
 
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index 6626d87..2384107 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -121,6 +121,7 @@
       "SPV_KHR_ray_query",
       "SPV_EXT_fragment_invocation_density",
       "SPV_EXT_physical_storage_buffer",
+      "SPV_KHR_terminate_invocation",
   });
 }
 bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp
index 10fac04..40cf6bc 100644
--- a/source/opt/loop_unroller.cpp
+++ b/source/opt/loop_unroller.cpp
@@ -997,7 +997,8 @@
     const BasicBlock* block = context_->cfg()->block(label_id);
     if (block->ctail()->opcode() == SpvOp::SpvOpKill ||
         block->ctail()->opcode() == SpvOp::SpvOpReturn ||
-        block->ctail()->opcode() == SpvOp::SpvOpReturnValue) {
+        block->ctail()->opcode() == SpvOp::SpvOpReturnValue ||
+        block->ctail()->opcode() == SpvOp::SpvOpTerminateInvocation) {
       return false;
     }
   }
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index 51d23a7..2e253ad 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -60,7 +60,8 @@
   return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp;
 }
 inline bool IsTerminatorInst(SpvOp opcode) {
-  return opcode >= SpvOpBranch && opcode <= SpvOpUnreachable;
+  return (opcode >= SpvOpBranch && opcode <= SpvOpUnreachable) ||
+         (opcode == SpvOpTerminateInvocation);
 }
 
 }  // namespace opt
diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp
index 4e0f24f..38b7539 100644
--- a/source/opt/replace_invalid_opc.cpp
+++ b/source/opt/replace_invalid_opc.cpp
@@ -141,6 +141,7 @@
       // TODO: Teach |ReplaceInstruction| to handle block terminators.  Then
       // uncomment the OpKill case.
       // case SpvOpKill:
+      // case SpvOpTerminateInstruction:
       return true;
     default:
       return false;
diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp
index 3c8bae6..4d70840 100644
--- a/source/opt/wrap_opkill.cpp
+++ b/source/opt/wrap_opkill.cpp
@@ -27,7 +27,8 @@
   for (uint32_t func_id : func_to_process) {
     Function* func = context()->GetFunction(func_id);
     bool successful = func->WhileEachInst([this, &modified](Instruction* inst) {
-      if (inst->opcode() == SpvOpKill) {
+      const auto opcode = inst->opcode();
+      if ((opcode == SpvOpKill) || (opcode == SpvOpTerminateInvocation)) {
         modified = true;
         if (!ReplaceWithFunctionCall(inst)) {
           return false;
@@ -46,16 +47,22 @@
            "The function should only be generated if something was modified.");
     context()->AddFunction(std::move(opkill_function_));
   }
+  if (opterminateinvocation_function_ != nullptr) {
+    assert(modified &&
+           "The function should only be generated if something was modified.");
+    context()->AddFunction(std::move(opterminateinvocation_function_));
+  }
   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
 }
 
 bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
-  assert(inst->opcode() == SpvOpKill &&
-         "|inst| must be an OpKill instruction.");
+  assert((inst->opcode() == SpvOpKill ||
+          inst->opcode() == SpvOpTerminateInvocation) &&
+         "|inst| must be an OpKill or OpTerminateInvocation instruction.");
   InstructionBuilder ir_builder(
       context(), inst,
       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
-  uint32_t func_id = GetOpKillFuncId();
+  uint32_t func_id = GetKillingFuncId(inst->opcode());
   if (func_id == 0) {
     return false;
   }
@@ -108,13 +115,20 @@
   return type_mgr->GetTypeInstruction(&func_type);
 }
 
-uint32_t WrapOpKill::GetOpKillFuncId() {
-  if (opkill_function_ != nullptr) {
-    return opkill_function_->result_id();
+uint32_t WrapOpKill::GetKillingFuncId(SpvOp opcode) {
+  //  Parameterize by opcode
+  assert(opcode == SpvOpKill || opcode == SpvOpTerminateInvocation);
+
+  std::unique_ptr<Function>* const killing_func =
+      (opcode == SpvOpKill) ? &opkill_function_
+                            : &opterminateinvocation_function_;
+
+  if (*killing_func != nullptr) {
+    return (*killing_func)->result_id();
   }
 
-  uint32_t opkill_func_id = TakeNextId();
-  if (opkill_func_id == 0) {
+  uint32_t killing_func_id = TakeNextId();
+  if (killing_func_id == 0) {
     return 0;
   }
 
@@ -125,15 +139,15 @@
 
   // Generate the function start instruction
   std::unique_ptr<Instruction> func_start(new Instruction(
-      context(), SpvOpFunction, void_type_id, opkill_func_id, {}));
+      context(), SpvOpFunction, void_type_id, killing_func_id, {}));
   func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}});
   func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}});
-  opkill_function_.reset(new Function(std::move(func_start)));
+  (*killing_func).reset(new Function(std::move(func_start)));
 
   // Generate the function end instruction
   std::unique_ptr<Instruction> func_end(
       new Instruction(context(), SpvOpFunctionEnd, 0, 0, {}));
-  opkill_function_->SetFunctionEnd(std::move(func_end));
+  (*killing_func)->SetFunctionEnd(std::move(func_end));
 
   // Create the one basic block for the function.
   uint32_t lab_id = TakeNextId();
@@ -146,21 +160,22 @@
 
   // Add the OpKill to the basic block
   std::unique_ptr<Instruction> kill_inst(
-      new Instruction(context(), SpvOpKill, 0, 0, {}));
+      new Instruction(context(), opcode, 0, 0, {}));
   bb->AddInstruction(std::move(kill_inst));
 
   // Add the bb to the function
-  bb->SetParent(opkill_function_.get());
-  opkill_function_->AddBasicBlock(std::move(bb));
+  bb->SetParent((*killing_func).get());
+  (*killing_func)->AddBasicBlock(std::move(bb));
 
   // Add the function to the module.
   if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) {
-    opkill_function_->ForEachInst(
-        [this](Instruction* inst) { context()->AnalyzeDefUse(inst); });
+    (*killing_func)->ForEachInst([this](Instruction* inst) {
+      context()->AnalyzeDefUse(inst);
+    });
   }
 
   if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) {
-    for (BasicBlock& basic_block : *opkill_function_) {
+    for (BasicBlock& basic_block : *(*killing_func)) {
       context()->set_instr_block(basic_block.GetLabelInst(), &basic_block);
       for (Instruction& inst : basic_block) {
         context()->set_instr_block(&inst, &basic_block);
@@ -168,7 +183,7 @@
     }
   }
 
-  return opkill_function_->result_id();
+  return (*killing_func)->result_id();
 }
 
 uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) {
diff --git a/source/opt/wrap_opkill.h b/source/opt/wrap_opkill.h
index 09f2dfa..7e43ca6 100644
--- a/source/opt/wrap_opkill.h
+++ b/source/opt/wrap_opkill.h
@@ -38,10 +38,10 @@
   }
 
  private:
-  // Replaces the OpKill instruction |inst| with a function call to a function
-  // that contains a single instruction, which is OpKill.  An OpUnreachable
-  // instruction will be placed after the function call.  Return true if
-  // successful.
+  // Replaces the OpKill or OpTerminateInvocation instruction |inst| with a
+  // function call to a function that contains a single instruction, a clone of
+  // |inst|.  An OpUnreachable instruction will be placed after the function
+  // call. Return true if successful.
   bool ReplaceWithFunctionCall(Instruction* inst);
 
   // Returns the id of the void type.
@@ -51,9 +51,9 @@
   uint32_t GetVoidFunctionTypeId();
 
   // Return the id of a function that has return type void, has no parameters,
-  // and contains a single instruction, which is an OpKill.  Returns 0 if the
-  // function could not be generated.
-  uint32_t GetOpKillFuncId();
+  // and contains a single instruction, which is |opcode|, either OpKill or
+  // OpTerminateInvocation.  Returns 0 if the function could not be generated.
+  uint32_t GetKillingFuncId(SpvOp opcode);
 
   // Returns the id of the return type for the function that contains |inst|.
   // Returns 0 if |inst| is not in a function.
@@ -67,6 +67,11 @@
   // function has a void return type and takes no parameters. If the function is
   // |nullptr|, then the function has not been generated.
   std::unique_ptr<Function> opkill_function_;
+  // The function that is a single instruction, which is an
+  // OpTerminateInvocation. The function has a void return type and takes no
+  // parameters. If the function is |nullptr|, then the function has not been
+  // generated.
+  std::unique_ptr<Function> opterminateinvocation_function_;
 };
 
 }  // namespace opt
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index a2fe882..8eb3a96 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -1096,12 +1096,18 @@
     case SpvOpKill:
     case SpvOpReturnValue:
     case SpvOpUnreachable:
+    case SpvOpTerminateInvocation:
       _.current_function().RegisterBlockEnd(std::vector<uint32_t>());
       if (opcode == SpvOpKill) {
         _.current_function().RegisterExecutionModelLimitation(
             SpvExecutionModelFragment,
             "OpKill requires Fragment execution model");
       }
+      if (opcode == SpvOpTerminateInvocation) {
+        _.current_function().RegisterExecutionModelLimitation(
+            SpvExecutionModelFragment,
+            "OpTerminateInvocation requires Fragment execution model");
+      }
       break;
     default:
       break;
diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp
index 6478b3c..9d395fb 100644
--- a/source/val/validate_instruction.cpp
+++ b/source/val/validate_instruction.cpp
@@ -296,7 +296,12 @@
            << SPV_SPIRV_VERSION_MINOR_PART(last_version) << " or earlier";
   }
 
-  if (inst_desc->numCapabilities > 0u) {
+  // OpTerminateInvocation is special because it is enabled by Shader
+  // capability, but also requries a extension and/or version check.
+  const bool capability_check_is_sufficient =
+      inst->opcode() != SpvOpTerminateInvocation;
+
+  if (capability_check_is_sufficient && (inst_desc->numCapabilities > 0u)) {
     // We already checked that the direct capability dependency has been
     // satisfied. We don't need to check any further.
     return SPV_SUCCESS;
diff --git a/test/opt/block_merge_test.cpp b/test/opt/block_merge_test.cpp
index f1460c5..7381908 100644
--- a/test/opt/block_merge_test.cpp
+++ b/test/opt/block_merge_test.cpp
@@ -639,6 +639,40 @@
   SinglePassRunAndMatch<BlockMergePass>(text, true);
 }
 
+TEST_F(BlockMergeTest, DontMergeTerminateInvocation) {
+  const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None
+; CHECK-NEXT: OpBranch [[ret:%\w+]]
+; CHECK: [[ret:%\w+]] = OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-DAG: [[cont]] = OpLabel
+; CHECK-DAG: [[merge]] = OpLabel
+OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+OpExecutionMode %func OriginUpperLeft
+%void = OpTypeVoid
+%bool = OpTypeBool
+%functy = OpTypeFunction %void
+%func = OpFunction %void None %functy
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %3 %4 None
+OpBranch %5
+%5 = OpLabel
+OpTerminateInvocation
+%4 = OpLabel
+OpBranch %2
+%3 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<BlockMergePass>(text, true);
+}
+
 TEST_F(BlockMergeTest, DontMergeUnreachable) {
   const std::string text = R"(
 ; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp
index fc2197c..ffd3e38 100644
--- a/test/opt/inline_test.cpp
+++ b/test/opt/inline_test.cpp
@@ -2453,6 +2453,103 @@
   SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true);
 }
 
+TEST_F(InlineTest, DontInlineFuncWithOpTerminateInvocationInContinue) {
+  const std::string test =
+      R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpBranch %9
+%9 = OpLabel
+OpLoopMerge %11 %12 None
+OpBranch %13
+%13 = OpLabel
+OpBranchConditional %true %10 %11
+%10 = OpLabel
+OpBranch %12
+%12 = OpLabel
+%16 = OpFunctionCall %void %kill_
+OpBranch %9
+%11 = OpLabel
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<InlineExhaustivePass>(test, test, false, true);
+}
+
+TEST_F(InlineTest, InlineFuncWithOpTerminateInvocationNotInContinue) {
+  const std::string before =
+      R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+%16 = OpFunctionCall %void %kill_
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+  const std::string after =
+      R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpTerminateInvocation
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true);
+}
+
 TEST_F(InlineTest, EarlyReturnFunctionInlined) {
   // #version 140
   //
diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp
index 62f77d7..0f93302 100644
--- a/test/opt/loop_optimizations/unroll_assumptions.cpp
+++ b/test/opt/loop_optimizations/unroll_assumptions.cpp
@@ -467,6 +467,73 @@
   SinglePassRunAndCheck<LoopUnroller>(text, text, false);
 }
 
+TEST_F(PassClassTest, KillInBody) {
+  const std::string text = R"(OpCapability Shader
+OpMemoryModel Logical Simple
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpTypeBool
+%5 = OpTypeInt 32 0
+%6 = OpConstant %5 0
+%7 = OpConstant %5 1
+%8 = OpConstant %5 5
+%1 = OpFunction %2 None %3
+%9 = OpLabel
+OpBranch %10
+%10 = OpLabel
+%11 = OpPhi %5 %6 %9 %12 %13
+%14 = OpULessThan %4 %11 %8
+OpLoopMerge %15 %13 Unroll
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpKill
+%13 = OpLabel
+%12 = OpIAdd %5 %11 %7
+OpBranch %10
+%15 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<LoopUnroller>(text, text, false);
+}
+
+TEST_F(PassClassTest, TerminateInvocationInBody) {
+  const std::string text = R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+OpMemoryModel Logical Simple
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpTypeBool
+%5 = OpTypeInt 32 0
+%6 = OpConstant %5 0
+%7 = OpConstant %5 1
+%8 = OpConstant %5 5
+%1 = OpFunction %2 None %3
+%9 = OpLabel
+OpBranch %10
+%10 = OpLabel
+%11 = OpPhi %5 %6 %9 %12 %13
+%14 = OpULessThan %4 %11 %8
+OpLoopMerge %15 %13 Unroll
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpTerminateInvocation
+%13 = OpLabel
+%12 = OpIAdd %5 %11 %7
+OpBranch %10
+%15 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+  SinglePassRunAndCheck<LoopUnroller>(text, text, false);
+}
+
 /*
 Generated from the following GLSL
 #version 440 core
diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp
index 33e52f0..e944109 100644
--- a/test/opt/wrap_opkill_test.cpp
+++ b/test/opt/wrap_opkill_test.cpp
@@ -193,6 +193,310 @@
   SinglePassRunAndMatch<WrapOpKill>(text, true);
 }
 
+TEST_F(WrapOpKillTest, SingleOpTerminateInvocation) {
+  const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+       %void = OpTypeVoid
+          %5 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %5
+          %8 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpLoopMerge %10 %11 None
+               OpBranch %12
+         %12 = OpLabel
+               OpBranchConditional %true %13 %10
+         %13 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+         %14 = OpFunctionCall %void %kill_
+               OpBranch %9
+         %10 = OpLabel
+               OpReturn
+               OpFunctionEnd
+      %kill_ = OpFunction %void None %5
+         %15 = OpLabel
+               OpTerminateInvocation
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleTerminateInvocationInSameFunc) {
+  const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpSelectionMerge
+; CHECK-NEXT: OpBranchConditional
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+       %void = OpTypeVoid
+          %5 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %5
+          %8 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpLoopMerge %10 %11 None
+               OpBranch %12
+         %12 = OpLabel
+               OpBranchConditional %true %13 %10
+         %13 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+         %14 = OpFunctionCall %void %kill_
+               OpBranch %9
+         %10 = OpLabel
+               OpReturn
+               OpFunctionEnd
+      %kill_ = OpFunction %void None %5
+         %15 = OpLabel
+               OpSelectionMerge %16 None
+               OpBranchConditional %true %17 %18
+         %17 = OpLabel
+               OpTerminateInvocation
+         %18 = OpLabel
+               OpTerminateInvocation
+         %16 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleOpTerminateInvocationDifferentFunc) {
+  const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]]
+; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]]
+; CHECK: [[orig_kill1]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[orig_kill2]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %4
+          %7 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+               OpLoopMerge %9 %10 None
+               OpBranch %11
+         %11 = OpLabel
+               OpBranchConditional %true %12 %9
+         %12 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+         %13 = OpFunctionCall %void %14
+         %15 = OpFunctionCall %void %16
+               OpBranch %8
+          %9 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %14 = OpFunction %void None %4
+         %17 = OpLabel
+               OpTerminateInvocation
+               OpFunctionEnd
+         %16 = OpFunction %void None %4
+         %18 = OpLabel
+               OpTerminateInvocation
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, KillAndTerminateInvocationSameFunc) {
+  const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpSelectionMerge
+; CHECK-NEXT: OpBranchConditional
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+; CHECK-NEXT: [[new_terminate]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+       %void = OpTypeVoid
+          %5 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %5
+          %8 = OpLabel
+               OpBranch %9
+          %9 = OpLabel
+               OpLoopMerge %10 %11 None
+               OpBranch %12
+         %12 = OpLabel
+               OpBranchConditional %true %13 %10
+         %13 = OpLabel
+               OpBranch %11
+         %11 = OpLabel
+         %14 = OpFunctionCall %void %kill_
+               OpBranch %9
+         %10 = OpLabel
+               OpReturn
+               OpFunctionEnd
+      %kill_ = OpFunction %void None %5
+         %15 = OpLabel
+               OpSelectionMerge %16 None
+               OpBranchConditional %true %17 %18
+         %17 = OpLabel
+               OpKill
+         %18 = OpLabel
+               OpTerminateInvocation
+         %16 = OpLabel
+               OpReturn
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, KillAndTerminateInvocationDifferentFunc) {
+  const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]]
+; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]]
+; CHECK: [[orig_kill1]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[orig_kill2]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+; CHECK-NEXT: [[new_terminate]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+               OpCapability Shader
+               OpExtension "SPV_KHR_terminate_invocation"
+          %1 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main"
+               OpExecutionMode %main OriginUpperLeft
+               OpSource GLSL 330
+               OpName %main "main"
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+       %bool = OpTypeBool
+       %true = OpConstantTrue %bool
+       %main = OpFunction %void None %4
+          %7 = OpLabel
+               OpBranch %8
+          %8 = OpLabel
+               OpLoopMerge %9 %10 None
+               OpBranch %11
+         %11 = OpLabel
+               OpBranchConditional %true %12 %9
+         %12 = OpLabel
+               OpBranch %10
+         %10 = OpLabel
+         %13 = OpFunctionCall %void %14
+         %15 = OpFunctionCall %void %16
+               OpBranch %8
+          %9 = OpLabel
+               OpReturn
+               OpFunctionEnd
+         %14 = OpFunction %void None %4
+         %17 = OpLabel
+               OpTerminateInvocation
+               OpFunctionEnd
+         %16 = OpFunction %void None %4
+         %18 = OpLabel
+               OpKill
+               OpFunctionEnd
+  )";
+
+  SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
 TEST_F(WrapOpKillTest, FuncWithReturnValue) {
   const std::string text = R"(
 ; CHECK: OpEntryPoint Fragment [[main:%\w+]]
diff --git a/test/text_to_binary.control_flow_test.cpp b/test/text_to_binary.control_flow_test.cpp
index ee8fed4..3e117b8 100644
--- a/test/text_to_binary.control_flow_test.cpp
+++ b/test/text_to_binary.control_flow_test.cpp
@@ -388,12 +388,35 @@
     }));
 // clang-format on
 
+using OpKillTest = spvtest::TextToBinaryTest;
+
+INSTANTIATE_TEST_SUITE_P(OpKillTest, ControlFlowRoundTripTest,
+                         Values("OpKill\n"));
+
+TEST_F(OpKillTest, ExtraArgsAssemblyError) {
+  const std::string input = "OpKill 1";
+  EXPECT_THAT(CompileFailure(input),
+              Eq("Expected <opcode> or <result-id> at the beginning of an "
+                 "instruction, found '1'."));
+}
+
+using OpTerminateInvocationTest = spvtest::TextToBinaryTest;
+
+INSTANTIATE_TEST_SUITE_P(OpTerminateInvocationTest, ControlFlowRoundTripTest,
+                         Values("OpTerminateInvocation\n"));
+
+TEST_F(OpTerminateInvocationTest, ExtraArgsAssemblyError) {
+  const std::string input = "OpTerminateInvocation 1";
+  EXPECT_THAT(CompileFailure(input),
+              Eq("Expected <opcode> or <result-id> at the beginning of an "
+                 "instruction, found '1'."));
+}
+
 // TODO(dneto): OpPhi
 // TODO(dneto): OpLoopMerge
 // TODO(dneto): OpLabel
 // TODO(dneto): OpBranch
 // TODO(dneto): OpSwitch
-// TODO(dneto): OpKill
 // TODO(dneto): OpReturn
 // TODO(dneto): OpReturnValue
 // TODO(dneto): OpUnreachable
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 138e711..23d7a19 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -38,6 +38,7 @@
        val_entry_point.cpp
        val_explicit_reserved_test.cpp
        val_extensions_test.cpp
+       val_extension_spv_khr_terminate_invocation.cpp
        val_ext_inst_test.cpp
        ${VAL_TEST_COMMON_SRCS}
   LIBS ${SPIRV_TOOLS}
diff --git a/test/val/val_extension_spv_khr_terminate_invocation.cpp b/test/val/val_extension_spv_khr_terminate_invocation.cpp
new file mode 100644
index 0000000..4cabf9e
--- /dev/null
+++ b/test/val/val_extension_spv_khr_terminate_invocation.cpp
@@ -0,0 +1,150 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for OpExtension validator rules.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "source/enum_string_mapping.h"
+#include "source/extensions.h"
+#include "source/spirv_target_env.h"
+#include "test/test_fixture.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+using ValidateSpvKHRTerminateInvocation = spvtest::ValidateBase<bool>;
+
+TEST_F(ValidateSpvKHRTerminateInvocation, Valid) {
+  const std::string str = R"(
+    OpCapability Shader
+    OpExtension "SPV_KHR_terminate_invocation"
+    OpMemoryModel Logical Simple
+    OpEntryPoint Fragment %main "main"
+    OpExecutionMode %main OriginUpperLeft
+    
+    %void    = OpTypeVoid
+    %void_fn = OpTypeFunction %void
+
+    %main = OpFunction %void None %void_fn
+    %entry = OpLabel
+    OpTerminateInvocation
+    OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresExtension) {
+  const std::string str = R"(
+    OpCapability Shader
+    OpMemoryModel Logical Simple
+    OpEntryPoint Fragment %main "main"
+    OpExecutionMode %main OriginUpperLeft
+    
+    %void    = OpTypeVoid
+    %void_fn = OpTypeFunction %void
+
+    %main = OpFunction %void None %void_fn
+    %entry = OpLabel
+    OpTerminateInvocation
+    OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("TerminateInvocation requires one of the following "
+                        "extensions: SPV_KHR_terminate_invocation"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresShaderCapability) {
+  const std::string str = R"(
+    OpCapability Kernel
+    OpCapability Addresses
+    OpExtension "SPV_KHR_terminate_invocation"
+    OpMemoryModel Physical32 OpenCL
+    OpEntryPoint Kernel %main "main"
+    
+    %void    = OpTypeVoid
+    %void_fn = OpTypeFunction %void
+
+    %main = OpFunction %void None %void_fn
+    %entry = OpLabel
+    OpTerminateInvocation
+    OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "TerminateInvocation requires one of these capabilities: Shader \n"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresFragmentShader) {
+  const std::string str = R"(
+    OpCapability Shader
+    OpExtension "SPV_KHR_terminate_invocation"
+    OpMemoryModel Logical Simple
+    OpEntryPoint GLCompute %main "main"
+    
+    %void    = OpTypeVoid
+    %void_fn = OpTypeFunction %void
+
+    %main = OpFunction %void None %void_fn
+    %entry = OpLabel
+    OpTerminateInvocation
+    OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpTerminateInvocation requires Fragment execution model"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, IsTerminatorInstruction) {
+  const std::string str = R"(
+    OpCapability Shader
+    OpExtension "SPV_KHR_terminate_invocation"
+    OpMemoryModel Logical Simple
+    OpEntryPoint GLCompute %main "main"
+    
+    %void    = OpTypeVoid
+    %void_fn = OpTypeFunction %void
+
+    %main = OpFunction %void None %void_fn
+    %entry = OpLabel
+    OpTerminateInvocation
+    OpReturn
+    OpFunctionEnd
+)";
+  CompileSuccessfully(str.c_str());
+  EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Return must appear in a block"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp
index 682c321..491a808 100644
--- a/test/val/val_extensions_test.cpp
+++ b/test/val/val_extensions_test.cpp
@@ -62,7 +62,8 @@
         "SPV_EXT_shader_viewport_index_layer",
         "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
         "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
-        "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing"));
+        "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing",
+        "SPV_KHR_terminate_invocation"));
 
 INSTANTIATE_TEST_SUITE_P(FailSilently, ValidateUnknownExtensions,
                          Values("ERROR_unknown_extension", "SPV_KHR_",