Support SPV_KHR_terminate_invocation (#3568)
Covers:
- assembler
- disassembler
- validator
- optimizer
Co-authored-by: David Neto <dneto@google.com>
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index d393495..741f947 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -876,8 +876,10 @@
// for the first index.
Optimizer::PassToken CreateDescriptorScalarReplacementPass();
-// Create a pass to replace all OpKill instruction with a function call to a
-// function that has a single OpKill. This allows more code to be inlined.
+// Create a pass to replace each OpKill instruction with a function call to a
+// function that has a single OpKill. Also replace each OpTerminateInvocation
+// instruction with a function call to a function that has a single
+// OpTerminateInvocation. This allows more code to be inlined.
Optimizer::PassToken CreateWrapOpKillPass();
// Replaces the extensions VK_AMD_shader_ballot,VK_AMD_gcn_shader, and
diff --git a/source/opcode.cpp b/source/opcode.cpp
index 3781a8d..f93cfd3 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -446,7 +446,7 @@
bool spvOpcodeIsReturnOrAbort(SpvOp opcode) {
return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill ||
- opcode == SpvOpUnreachable;
+ opcode == SpvOpUnreachable || opcode == SpvOpTerminateInvocation;
}
bool spvOpcodeIsBlockTerminator(SpvOp opcode) {
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
index 9fcfd3a..b755787 100644
--- a/source/opt/aggressive_dead_code_elim_pass.cpp
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -986,6 +986,7 @@
"SPV_KHR_ray_tracing",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_terminate_invocation",
});
}
diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp
index da5073a..7e61506 100644
--- a/source/opt/dominator_tree.cpp
+++ b/source/opt/dominator_tree.cpp
@@ -176,7 +176,8 @@
// The tree construction requires 1 entry point, so we add a dummy node
// that is connected to all function exiting basic blocks.
// An exiting basic block is a block with an OpKill, OpUnreachable,
- // OpReturn or OpReturnValue as terminator instruction.
+ // OpReturn, OpReturnValue, or OpTerminateInvocation as terminator
+ // instruction.
for (BasicBlock& bb : f) {
if (bb.hasSuccessor()) {
BasicBlockListTy& pred_list = predecessors_[&bb];
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index cb5a126..ef94d0d 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -384,7 +384,8 @@
for (auto callee_block_itr = calleeFn->begin();
callee_block_itr != calleeFn->end(); ++callee_block_itr) {
if (callee_block_itr->tail()->opcode() == SpvOpUnreachable ||
- callee_block_itr->tail()->opcode() == SpvOpKill) {
+ callee_block_itr->tail()->opcode() == SpvOpKill ||
+ callee_block_itr->tail()->opcode() == SpvOpTerminateInvocation) {
returnLabelId = context()->TakeNextId();
break;
}
@@ -738,16 +739,18 @@
bool func_is_called_from_continue =
funcs_called_from_continue_.count(func->result_id()) != 0;
- if (func_is_called_from_continue && ContainsKill(func)) {
+ if (func_is_called_from_continue && ContainsKillOrTerminateInvocation(func)) {
return false;
}
return true;
}
-bool InlinePass::ContainsKill(Function* func) const {
- return !func->WhileEachInst(
- [](Instruction* inst) { return inst->opcode() != SpvOpKill; });
+bool InlinePass::ContainsKillOrTerminateInvocation(Function* func) const {
+ return !func->WhileEachInst([](Instruction* inst) {
+ const auto opcode = inst->opcode();
+ return (opcode != SpvOpKill) && (opcode != SpvOpTerminateInvocation);
+ });
}
void InlinePass::InitializeInline() {
diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h
index 202bc97..abe773a 100644
--- a/source/opt/inline_pass.h
+++ b/source/opt/inline_pass.h
@@ -139,8 +139,9 @@
// Return true if |func| is a function that can be inlined.
bool IsInlinableFunction(Function* func);
- // Returns true if |func| contains an OpKill instruction.
- bool ContainsKill(Function* func) const;
+ // Returns true if |func| contains an OpKill or OpTerminateInvocation
+ // instruction.
+ bool ContainsKillOrTerminateInvocation(Function* func) const;
// Update phis in succeeding blocks to point to new last block
void UpdateSucceedingPhis(
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index 05704c1..9b8c112 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -382,6 +382,7 @@
"SPV_KHR_ray_tracing",
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
+ "SPV_KHR_terminate_invocation",
});
}
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index 5757282..bd5d751 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -267,6 +267,7 @@
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_terminate_invocation",
});
}
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
index 6626d87..2384107 100644
--- a/source/opt/local_single_store_elim_pass.cpp
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -121,6 +121,7 @@
"SPV_KHR_ray_query",
"SPV_EXT_fragment_invocation_density",
"SPV_EXT_physical_storage_buffer",
+ "SPV_KHR_terminate_invocation",
});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp
index 10fac04..40cf6bc 100644
--- a/source/opt/loop_unroller.cpp
+++ b/source/opt/loop_unroller.cpp
@@ -997,7 +997,8 @@
const BasicBlock* block = context_->cfg()->block(label_id);
if (block->ctail()->opcode() == SpvOp::SpvOpKill ||
block->ctail()->opcode() == SpvOp::SpvOpReturn ||
- block->ctail()->opcode() == SpvOp::SpvOpReturnValue) {
+ block->ctail()->opcode() == SpvOp::SpvOpReturnValue ||
+ block->ctail()->opcode() == SpvOp::SpvOpTerminateInvocation) {
return false;
}
}
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index 51d23a7..2e253ad 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -60,7 +60,8 @@
return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp;
}
inline bool IsTerminatorInst(SpvOp opcode) {
- return opcode >= SpvOpBranch && opcode <= SpvOpUnreachable;
+ return (opcode >= SpvOpBranch && opcode <= SpvOpUnreachable) ||
+ (opcode == SpvOpTerminateInvocation);
}
} // namespace opt
diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp
index 4e0f24f..38b7539 100644
--- a/source/opt/replace_invalid_opc.cpp
+++ b/source/opt/replace_invalid_opc.cpp
@@ -141,6 +141,7 @@
// TODO: Teach |ReplaceInstruction| to handle block terminators. Then
// uncomment the OpKill case.
// case SpvOpKill:
+ // case SpvOpTerminateInstruction:
return true;
default:
return false;
diff --git a/source/opt/wrap_opkill.cpp b/source/opt/wrap_opkill.cpp
index 3c8bae6..4d70840 100644
--- a/source/opt/wrap_opkill.cpp
+++ b/source/opt/wrap_opkill.cpp
@@ -27,7 +27,8 @@
for (uint32_t func_id : func_to_process) {
Function* func = context()->GetFunction(func_id);
bool successful = func->WhileEachInst([this, &modified](Instruction* inst) {
- if (inst->opcode() == SpvOpKill) {
+ const auto opcode = inst->opcode();
+ if ((opcode == SpvOpKill) || (opcode == SpvOpTerminateInvocation)) {
modified = true;
if (!ReplaceWithFunctionCall(inst)) {
return false;
@@ -46,16 +47,22 @@
"The function should only be generated if something was modified.");
context()->AddFunction(std::move(opkill_function_));
}
+ if (opterminateinvocation_function_ != nullptr) {
+ assert(modified &&
+ "The function should only be generated if something was modified.");
+ context()->AddFunction(std::move(opterminateinvocation_function_));
+ }
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
bool WrapOpKill::ReplaceWithFunctionCall(Instruction* inst) {
- assert(inst->opcode() == SpvOpKill &&
- "|inst| must be an OpKill instruction.");
+ assert((inst->opcode() == SpvOpKill ||
+ inst->opcode() == SpvOpTerminateInvocation) &&
+ "|inst| must be an OpKill or OpTerminateInvocation instruction.");
InstructionBuilder ir_builder(
context(), inst,
IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
- uint32_t func_id = GetOpKillFuncId();
+ uint32_t func_id = GetKillingFuncId(inst->opcode());
if (func_id == 0) {
return false;
}
@@ -108,13 +115,20 @@
return type_mgr->GetTypeInstruction(&func_type);
}
-uint32_t WrapOpKill::GetOpKillFuncId() {
- if (opkill_function_ != nullptr) {
- return opkill_function_->result_id();
+uint32_t WrapOpKill::GetKillingFuncId(SpvOp opcode) {
+ // Parameterize by opcode
+ assert(opcode == SpvOpKill || opcode == SpvOpTerminateInvocation);
+
+ std::unique_ptr<Function>* const killing_func =
+ (opcode == SpvOpKill) ? &opkill_function_
+ : &opterminateinvocation_function_;
+
+ if (*killing_func != nullptr) {
+ return (*killing_func)->result_id();
}
- uint32_t opkill_func_id = TakeNextId();
- if (opkill_func_id == 0) {
+ uint32_t killing_func_id = TakeNextId();
+ if (killing_func_id == 0) {
return 0;
}
@@ -125,15 +139,15 @@
// Generate the function start instruction
std::unique_ptr<Instruction> func_start(new Instruction(
- context(), SpvOpFunction, void_type_id, opkill_func_id, {}));
+ context(), SpvOpFunction, void_type_id, killing_func_id, {}));
func_start->AddOperand({SPV_OPERAND_TYPE_FUNCTION_CONTROL, {0}});
func_start->AddOperand({SPV_OPERAND_TYPE_ID, {GetVoidFunctionTypeId()}});
- opkill_function_.reset(new Function(std::move(func_start)));
+ (*killing_func).reset(new Function(std::move(func_start)));
// Generate the function end instruction
std::unique_ptr<Instruction> func_end(
new Instruction(context(), SpvOpFunctionEnd, 0, 0, {}));
- opkill_function_->SetFunctionEnd(std::move(func_end));
+ (*killing_func)->SetFunctionEnd(std::move(func_end));
// Create the one basic block for the function.
uint32_t lab_id = TakeNextId();
@@ -146,21 +160,22 @@
// Add the OpKill to the basic block
std::unique_ptr<Instruction> kill_inst(
- new Instruction(context(), SpvOpKill, 0, 0, {}));
+ new Instruction(context(), opcode, 0, 0, {}));
bb->AddInstruction(std::move(kill_inst));
// Add the bb to the function
- bb->SetParent(opkill_function_.get());
- opkill_function_->AddBasicBlock(std::move(bb));
+ bb->SetParent((*killing_func).get());
+ (*killing_func)->AddBasicBlock(std::move(bb));
// Add the function to the module.
if (context()->AreAnalysesValid(IRContext::kAnalysisDefUse)) {
- opkill_function_->ForEachInst(
- [this](Instruction* inst) { context()->AnalyzeDefUse(inst); });
+ (*killing_func)->ForEachInst([this](Instruction* inst) {
+ context()->AnalyzeDefUse(inst);
+ });
}
if (context()->AreAnalysesValid(IRContext::kAnalysisInstrToBlockMapping)) {
- for (BasicBlock& basic_block : *opkill_function_) {
+ for (BasicBlock& basic_block : *(*killing_func)) {
context()->set_instr_block(basic_block.GetLabelInst(), &basic_block);
for (Instruction& inst : basic_block) {
context()->set_instr_block(&inst, &basic_block);
@@ -168,7 +183,7 @@
}
}
- return opkill_function_->result_id();
+ return (*killing_func)->result_id();
}
uint32_t WrapOpKill::GetOwningFunctionsReturnType(Instruction* inst) {
diff --git a/source/opt/wrap_opkill.h b/source/opt/wrap_opkill.h
index 09f2dfa..7e43ca6 100644
--- a/source/opt/wrap_opkill.h
+++ b/source/opt/wrap_opkill.h
@@ -38,10 +38,10 @@
}
private:
- // Replaces the OpKill instruction |inst| with a function call to a function
- // that contains a single instruction, which is OpKill. An OpUnreachable
- // instruction will be placed after the function call. Return true if
- // successful.
+ // Replaces the OpKill or OpTerminateInvocation instruction |inst| with a
+ // function call to a function that contains a single instruction, a clone of
+ // |inst|. An OpUnreachable instruction will be placed after the function
+ // call. Return true if successful.
bool ReplaceWithFunctionCall(Instruction* inst);
// Returns the id of the void type.
@@ -51,9 +51,9 @@
uint32_t GetVoidFunctionTypeId();
// Return the id of a function that has return type void, has no parameters,
- // and contains a single instruction, which is an OpKill. Returns 0 if the
- // function could not be generated.
- uint32_t GetOpKillFuncId();
+ // and contains a single instruction, which is |opcode|, either OpKill or
+ // OpTerminateInvocation. Returns 0 if the function could not be generated.
+ uint32_t GetKillingFuncId(SpvOp opcode);
// Returns the id of the return type for the function that contains |inst|.
// Returns 0 if |inst| is not in a function.
@@ -67,6 +67,11 @@
// function has a void return type and takes no parameters. If the function is
// |nullptr|, then the function has not been generated.
std::unique_ptr<Function> opkill_function_;
+ // The function that is a single instruction, which is an
+ // OpTerminateInvocation. The function has a void return type and takes no
+ // parameters. If the function is |nullptr|, then the function has not been
+ // generated.
+ std::unique_ptr<Function> opterminateinvocation_function_;
};
} // namespace opt
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index a2fe882..8eb3a96 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -1096,12 +1096,18 @@
case SpvOpKill:
case SpvOpReturnValue:
case SpvOpUnreachable:
+ case SpvOpTerminateInvocation:
_.current_function().RegisterBlockEnd(std::vector<uint32_t>());
if (opcode == SpvOpKill) {
_.current_function().RegisterExecutionModelLimitation(
SpvExecutionModelFragment,
"OpKill requires Fragment execution model");
}
+ if (opcode == SpvOpTerminateInvocation) {
+ _.current_function().RegisterExecutionModelLimitation(
+ SpvExecutionModelFragment,
+ "OpTerminateInvocation requires Fragment execution model");
+ }
break;
default:
break;
diff --git a/source/val/validate_instruction.cpp b/source/val/validate_instruction.cpp
index 6478b3c..9d395fb 100644
--- a/source/val/validate_instruction.cpp
+++ b/source/val/validate_instruction.cpp
@@ -296,7 +296,12 @@
<< SPV_SPIRV_VERSION_MINOR_PART(last_version) << " or earlier";
}
- if (inst_desc->numCapabilities > 0u) {
+ // OpTerminateInvocation is special because it is enabled by Shader
+ // capability, but also requries a extension and/or version check.
+ const bool capability_check_is_sufficient =
+ inst->opcode() != SpvOpTerminateInvocation;
+
+ if (capability_check_is_sufficient && (inst_desc->numCapabilities > 0u)) {
// We already checked that the direct capability dependency has been
// satisfied. We don't need to check any further.
return SPV_SUCCESS;
diff --git a/test/opt/block_merge_test.cpp b/test/opt/block_merge_test.cpp
index f1460c5..7381908 100644
--- a/test/opt/block_merge_test.cpp
+++ b/test/opt/block_merge_test.cpp
@@ -639,6 +639,40 @@
SinglePassRunAndMatch<BlockMergePass>(text, true);
}
+TEST_F(BlockMergeTest, DontMergeTerminateInvocation) {
+ const std::string text = R"(
+; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None
+; CHECK-NEXT: OpBranch [[ret:%\w+]]
+; CHECK: [[ret:%\w+]] = OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-DAG: [[cont]] = OpLabel
+; CHECK-DAG: [[merge]] = OpLabel
+OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %func "func"
+OpExecutionMode %func OriginUpperLeft
+%void = OpTypeVoid
+%bool = OpTypeBool
+%functy = OpTypeFunction %void
+%func = OpFunction %void None %functy
+%1 = OpLabel
+OpBranch %2
+%2 = OpLabel
+OpLoopMerge %3 %4 None
+OpBranch %5
+%5 = OpLabel
+OpTerminateInvocation
+%4 = OpLabel
+OpBranch %2
+%3 = OpLabel
+OpUnreachable
+OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<BlockMergePass>(text, true);
+}
+
TEST_F(BlockMergeTest, DontMergeUnreachable) {
const std::string text = R"(
; CHECK: OpLoopMerge [[merge:%\w+]] [[cont:%\w+]] None
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp
index fc2197c..ffd3e38 100644
--- a/test/opt/inline_test.cpp
+++ b/test/opt/inline_test.cpp
@@ -2453,6 +2453,103 @@
SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true);
}
+TEST_F(InlineTest, DontInlineFuncWithOpTerminateInvocationInContinue) {
+ const std::string test =
+ R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpBranch %9
+%9 = OpLabel
+OpLoopMerge %11 %12 None
+OpBranch %13
+%13 = OpLabel
+OpBranchConditional %true %10 %11
+%10 = OpLabel
+OpBranch %12
+%12 = OpLabel
+%16 = OpFunctionCall %void %kill_
+OpBranch %9
+%11 = OpLabel
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<InlineExhaustivePass>(test, test, false, true);
+}
+
+TEST_F(InlineTest, InlineFuncWithOpTerminateInvocationNotInContinue) {
+ const std::string before =
+ R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+%16 = OpFunctionCall %void %kill_
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 330
+OpName %main "main"
+OpName %kill_ "kill("
+%void = OpTypeVoid
+%3 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%main = OpFunction %void None %3
+%5 = OpLabel
+OpTerminateInvocation
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+%kill_ = OpFunction %void None %3
+%7 = OpLabel
+OpTerminateInvocation
+OpFunctionEnd
+)";
+
+ SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true);
+}
+
TEST_F(InlineTest, EarlyReturnFunctionInlined) {
// #version 140
//
diff --git a/test/opt/loop_optimizations/unroll_assumptions.cpp b/test/opt/loop_optimizations/unroll_assumptions.cpp
index 62f77d7..0f93302 100644
--- a/test/opt/loop_optimizations/unroll_assumptions.cpp
+++ b/test/opt/loop_optimizations/unroll_assumptions.cpp
@@ -467,6 +467,73 @@
SinglePassRunAndCheck<LoopUnroller>(text, text, false);
}
+TEST_F(PassClassTest, KillInBody) {
+ const std::string text = R"(OpCapability Shader
+OpMemoryModel Logical Simple
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpTypeBool
+%5 = OpTypeInt 32 0
+%6 = OpConstant %5 0
+%7 = OpConstant %5 1
+%8 = OpConstant %5 5
+%1 = OpFunction %2 None %3
+%9 = OpLabel
+OpBranch %10
+%10 = OpLabel
+%11 = OpPhi %5 %6 %9 %12 %13
+%14 = OpULessThan %4 %11 %8
+OpLoopMerge %15 %13 Unroll
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpKill
+%13 = OpLabel
+%12 = OpIAdd %5 %11 %7
+OpBranch %10
+%15 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<LoopUnroller>(text, text, false);
+}
+
+TEST_F(PassClassTest, TerminateInvocationInBody) {
+ const std::string text = R"(OpCapability Shader
+OpExtension "SPV_KHR_terminate_invocation"
+OpMemoryModel Logical Simple
+OpEntryPoint Fragment %1 "main"
+OpExecutionMode %1 OriginUpperLeft
+%2 = OpTypeVoid
+%3 = OpTypeFunction %2
+%4 = OpTypeBool
+%5 = OpTypeInt 32 0
+%6 = OpConstant %5 0
+%7 = OpConstant %5 1
+%8 = OpConstant %5 5
+%1 = OpFunction %2 None %3
+%9 = OpLabel
+OpBranch %10
+%10 = OpLabel
+%11 = OpPhi %5 %6 %9 %12 %13
+%14 = OpULessThan %4 %11 %8
+OpLoopMerge %15 %13 Unroll
+OpBranchConditional %14 %16 %15
+%16 = OpLabel
+OpTerminateInvocation
+%13 = OpLabel
+%12 = OpIAdd %5 %11 %7
+OpBranch %10
+%15 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ SetDisassembleOptions(SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
+ SinglePassRunAndCheck<LoopUnroller>(text, text, false);
+}
+
/*
Generated from the following GLSL
#version 440 core
diff --git a/test/opt/wrap_opkill_test.cpp b/test/opt/wrap_opkill_test.cpp
index 33e52f0..e944109 100644
--- a/test/opt/wrap_opkill_test.cpp
+++ b/test/opt/wrap_opkill_test.cpp
@@ -193,6 +193,310 @@
SinglePassRunAndMatch<WrapOpKill>(text, true);
}
+TEST_F(WrapOpKillTest, SingleOpTerminateInvocation) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %5
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpLoopMerge %10 %11 None
+ OpBranch %12
+ %12 = OpLabel
+ OpBranchConditional %true %13 %10
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %14 = OpFunctionCall %void %kill_
+ OpBranch %9
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %kill_ = OpFunction %void None %5
+ %15 = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleTerminateInvocationInSameFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpSelectionMerge
+; CHECK-NEXT: OpBranchConditional
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %5
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpLoopMerge %10 %11 None
+ OpBranch %12
+ %12 = OpLabel
+ OpBranchConditional %true %13 %10
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %14 = OpFunctionCall %void %kill_
+ OpBranch %9
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %kill_ = OpFunction %void None %5
+ %15 = OpLabel
+ OpSelectionMerge %16 None
+ OpBranchConditional %true %17 %18
+ %17 = OpLabel
+ OpTerminateInvocation
+ %18 = OpLabel
+ OpTerminateInvocation
+ %16 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, MultipleOpTerminateInvocationDifferentFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]]
+; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]]
+; CHECK: [[orig_kill1]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[orig_kill2]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %4
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ OpLoopMerge %9 %10 None
+ OpBranch %11
+ %11 = OpLabel
+ OpBranchConditional %true %12 %9
+ %12 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ %13 = OpFunctionCall %void %14
+ %15 = OpFunctionCall %void %16
+ OpBranch %8
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %14 = OpFunction %void None %4
+ %17 = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+ %16 = OpFunction %void None %4
+ %18 = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, KillAndTerminateInvocationSameFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill:%\w+]]
+; CHECK: [[orig_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpSelectionMerge
+; CHECK-NEXT: OpBranchConditional
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+; CHECK-NEXT: [[new_terminate]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %5 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %5
+ %8 = OpLabel
+ OpBranch %9
+ %9 = OpLabel
+ OpLoopMerge %10 %11 None
+ OpBranch %12
+ %12 = OpLabel
+ OpBranchConditional %true %13 %10
+ %13 = OpLabel
+ OpBranch %11
+ %11 = OpLabel
+ %14 = OpFunctionCall %void %kill_
+ OpBranch %9
+ %10 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %kill_ = OpFunction %void None %5
+ %15 = OpLabel
+ OpSelectionMerge %16 None
+ OpBranchConditional %true %17 %18
+ %17 = OpLabel
+ OpKill
+ %18 = OpLabel
+ OpTerminateInvocation
+ %16 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
+TEST_F(WrapOpKillTest, KillAndTerminateInvocationDifferentFunc) {
+ const std::string text = R"(
+; CHECK: OpEntryPoint Fragment [[main:%\w+]]
+; CHECK: [[main]] = OpFunction
+; CHECK: OpFunctionCall %void [[orig_kill1:%\w+]]
+; CHECK-NEXT: OpFunctionCall %void [[orig_kill2:%\w+]]
+; CHECK: [[orig_kill1]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_terminate:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[orig_kill2]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpFunctionCall %void [[new_kill:%\w+]]
+; CHECK-NEXT: OpReturn
+; CHECK: [[new_kill]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpKill
+; CHECK-NEXT: OpFunctionEnd
+; CHECK-NEXT: [[new_terminate]] = OpFunction
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpTerminateInvocation
+; CHECK-NEXT: OpFunctionEnd
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+ OpSource GLSL 330
+ OpName %main "main"
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+ %main = OpFunction %void None %4
+ %7 = OpLabel
+ OpBranch %8
+ %8 = OpLabel
+ OpLoopMerge %9 %10 None
+ OpBranch %11
+ %11 = OpLabel
+ OpBranchConditional %true %12 %9
+ %12 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ %13 = OpFunctionCall %void %14
+ %15 = OpFunctionCall %void %16
+ OpBranch %8
+ %9 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ %14 = OpFunction %void None %4
+ %17 = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+ %16 = OpFunction %void None %4
+ %18 = OpLabel
+ OpKill
+ OpFunctionEnd
+ )";
+
+ SinglePassRunAndMatch<WrapOpKill>(text, true);
+}
+
TEST_F(WrapOpKillTest, FuncWithReturnValue) {
const std::string text = R"(
; CHECK: OpEntryPoint Fragment [[main:%\w+]]
diff --git a/test/text_to_binary.control_flow_test.cpp b/test/text_to_binary.control_flow_test.cpp
index ee8fed4..3e117b8 100644
--- a/test/text_to_binary.control_flow_test.cpp
+++ b/test/text_to_binary.control_flow_test.cpp
@@ -388,12 +388,35 @@
}));
// clang-format on
+using OpKillTest = spvtest::TextToBinaryTest;
+
+INSTANTIATE_TEST_SUITE_P(OpKillTest, ControlFlowRoundTripTest,
+ Values("OpKill\n"));
+
+TEST_F(OpKillTest, ExtraArgsAssemblyError) {
+ const std::string input = "OpKill 1";
+ EXPECT_THAT(CompileFailure(input),
+ Eq("Expected <opcode> or <result-id> at the beginning of an "
+ "instruction, found '1'."));
+}
+
+using OpTerminateInvocationTest = spvtest::TextToBinaryTest;
+
+INSTANTIATE_TEST_SUITE_P(OpTerminateInvocationTest, ControlFlowRoundTripTest,
+ Values("OpTerminateInvocation\n"));
+
+TEST_F(OpTerminateInvocationTest, ExtraArgsAssemblyError) {
+ const std::string input = "OpTerminateInvocation 1";
+ EXPECT_THAT(CompileFailure(input),
+ Eq("Expected <opcode> or <result-id> at the beginning of an "
+ "instruction, found '1'."));
+}
+
// TODO(dneto): OpPhi
// TODO(dneto): OpLoopMerge
// TODO(dneto): OpLabel
// TODO(dneto): OpBranch
// TODO(dneto): OpSwitch
-// TODO(dneto): OpKill
// TODO(dneto): OpReturn
// TODO(dneto): OpReturnValue
// TODO(dneto): OpUnreachable
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index 138e711..23d7a19 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -38,6 +38,7 @@
val_entry_point.cpp
val_explicit_reserved_test.cpp
val_extensions_test.cpp
+ val_extension_spv_khr_terminate_invocation.cpp
val_ext_inst_test.cpp
${VAL_TEST_COMMON_SRCS}
LIBS ${SPIRV_TOOLS}
diff --git a/test/val/val_extension_spv_khr_terminate_invocation.cpp b/test/val/val_extension_spv_khr_terminate_invocation.cpp
new file mode 100644
index 0000000..4cabf9e
--- /dev/null
+++ b/test/val/val_extension_spv_khr_terminate_invocation.cpp
@@ -0,0 +1,150 @@
+// Copyright (c) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for OpExtension validator rules.
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "source/enum_string_mapping.h"
+#include "source/extensions.h"
+#include "source/spirv_target_env.h"
+#include "test/test_fixture.h"
+#include "test/unit_spirv.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+using ::testing::ValuesIn;
+
+using ValidateSpvKHRTerminateInvocation = spvtest::ValidateBase<bool>;
+
+TEST_F(ValidateSpvKHRTerminateInvocation, Valid) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ OpMemoryModel Logical Simple
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresExtension) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpMemoryModel Logical Simple
+ OpEntryPoint Fragment %main "main"
+ OpExecutionMode %main OriginUpperLeft
+
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("TerminateInvocation requires one of the following "
+ "extensions: SPV_KHR_terminate_invocation"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresShaderCapability) {
+ const std::string str = R"(
+ OpCapability Kernel
+ OpCapability Addresses
+ OpExtension "SPV_KHR_terminate_invocation"
+ OpMemoryModel Physical32 OpenCL
+ OpEntryPoint Kernel %main "main"
+
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "TerminateInvocation requires one of these capabilities: Shader \n"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, RequiresFragmentShader) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ OpMemoryModel Logical Simple
+ OpEntryPoint GLCompute %main "main"
+
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+ OpTerminateInvocation
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpTerminateInvocation requires Fragment execution model"));
+}
+
+TEST_F(ValidateSpvKHRTerminateInvocation, IsTerminatorInstruction) {
+ const std::string str = R"(
+ OpCapability Shader
+ OpExtension "SPV_KHR_terminate_invocation"
+ OpMemoryModel Logical Simple
+ OpEntryPoint GLCompute %main "main"
+
+ %void = OpTypeVoid
+ %void_fn = OpTypeFunction %void
+
+ %main = OpFunction %void None %void_fn
+ %entry = OpLabel
+ OpTerminateInvocation
+ OpReturn
+ OpFunctionEnd
+)";
+ CompileSuccessfully(str.c_str());
+ EXPECT_NE(SPV_SUCCESS, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Return must appear in a block"));
+}
+
+} // namespace
+} // namespace val
+} // namespace spvtools
diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp
index 682c321..491a808 100644
--- a/test/val/val_extensions_test.cpp
+++ b/test/val/val_extensions_test.cpp
@@ -62,7 +62,8 @@
"SPV_EXT_shader_viewport_index_layer",
"SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
"SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
- "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing"));
+ "SPV_NV_shader_subgroup_partitioned", "SPV_EXT_descriptor_indexing",
+ "SPV_KHR_terminate_invocation"));
INSTANTIATE_TEST_SUITE_P(FailSilently, ValidateUnknownExtensions,
Values("ERROR_unknown_extension", "SPV_KHR_",