Take new (raytracing) termination instructions into account. (#4050)
* Take new (raytracing) termination instructions into account.
* Remove duplicate function and add unit test.
* Use KHR for symbols in the test.
diff --git a/source/opcode.cpp b/source/opcode.cpp
index c80e3a0..d87e828 100644
--- a/source/opcode.cpp
+++ b/source/opcode.cpp
@@ -444,15 +444,32 @@
}
}
+bool spvOpcodeIsAbort(SpvOp opcode) {
+ switch (opcode) {
+ case SpvOpKill:
+ case SpvOpUnreachable:
+ case SpvOpTerminateInvocation:
+ case SpvOpTerminateRayKHR:
+ case SpvOpIgnoreIntersectionKHR:
+ return true;
+ default:
+ return false;
+ }
+}
+
bool spvOpcodeIsReturnOrAbort(SpvOp opcode) {
- return spvOpcodeIsReturn(opcode) || opcode == SpvOpKill ||
- opcode == SpvOpUnreachable || opcode == SpvOpTerminateInvocation;
+ return spvOpcodeIsReturn(opcode) || spvOpcodeIsAbort(opcode);
}
bool spvOpcodeIsBlockTerminator(SpvOp opcode) {
return spvOpcodeIsBranch(opcode) || spvOpcodeIsReturnOrAbort(opcode);
}
+bool spvOpcodeTerminatesExecution(SpvOp opcode) {
+ return opcode == SpvOpKill || opcode == SpvOpTerminateInvocation ||
+ opcode == SpvOpTerminateRayKHR || opcode == SpvOpIgnoreIntersectionKHR;
+}
+
bool spvOpcodeIsBaseOpaqueType(SpvOp opcode) {
switch (opcode) {
case SpvOpTypeImage:
diff --git a/source/opcode.h b/source/opcode.h
index 3702cb3..c8525a2 100644
--- a/source/opcode.h
+++ b/source/opcode.h
@@ -110,10 +110,18 @@
// Returns true if the given opcode is a return instruction.
bool spvOpcodeIsReturn(SpvOp opcode);
+// Returns true if the given opcode aborts execution.
+bool spvOpcodeIsAbort(SpvOp opcode);
+
// Returns true if the given opcode is a return instruction or it aborts
// execution.
bool spvOpcodeIsReturnOrAbort(SpvOp opcode);
+// Returns true if the given opcode is a kill instruction or it terminates
+// execution. Note that branches, returns, and unreachables do not terminate
+// execution.
+bool spvOpcodeTerminatesExecution(SpvOp opcode);
+
// Returns true if the given opcode is a basic block terminator.
bool spvOpcodeIsBlockTerminator(SpvOp opcode);
diff --git a/source/opt/basic_block.cpp b/source/opt/basic_block.cpp
index b7e122c..e82a744 100644
--- a/source/opt/basic_block.cpp
+++ b/source/opt/basic_block.cpp
@@ -230,7 +230,7 @@
std::ostringstream str;
ForEachInst([&str, options](const Instruction* inst) {
str << inst->PrettyPrint(options);
- if (!IsTerminatorInst(inst->opcode())) {
+ if (!spvOpcodeIsBlockTerminator(inst->opcode())) {
str << std::endl;
}
});
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index 88f395f..8159ebf 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -383,9 +383,7 @@
uint32_t returnLabelId = 0;
for (auto callee_block_itr = calleeFn->begin();
callee_block_itr != calleeFn->end(); ++callee_block_itr) {
- if (callee_block_itr->tail()->opcode() == SpvOpUnreachable ||
- callee_block_itr->tail()->opcode() == SpvOpKill ||
- callee_block_itr->tail()->opcode() == SpvOpTerminateInvocation) {
+ if (spvOpcodeIsAbort(callee_block_itr->tail()->opcode())) {
returnLabelId = context()->TakeNextId();
break;
}
@@ -759,8 +757,7 @@
bool InlinePass::ContainsKillOrTerminateInvocation(Function* func) const {
return !func->WhileEachInst([](Instruction* inst) {
- const auto opcode = inst->opcode();
- return (opcode != SpvOpKill) && (opcode != SpvOpTerminateInvocation);
+ return !spvOpcodeTerminatesExecution(inst->opcode());
});
}
diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp
index 06099ce..70e5144 100644
--- a/source/opt/ir_loader.cpp
+++ b/source/opt/ir_loader.cpp
@@ -137,7 +137,7 @@
return false;
}
block_ = MakeUnique<BasicBlock>(std::move(spv_inst));
- } else if (IsTerminatorInst(opcode)) {
+ } else if (spvOpcodeIsBlockTerminator(opcode)) {
if (function_ == nullptr) {
Error(consumer_, src, loc, "terminator instruction outside function");
return false;
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index 9d3b0ed..0c88601 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -188,7 +188,7 @@
i->ToBinaryWithoutAttachedDebugInsts(binary);
}
// Update the last line instruction.
- if (IsTerminatorInst(opcode) || opcode == SpvOpNoLine) {
+ if (spvOpcodeIsBlockTerminator(opcode) || opcode == SpvOpNoLine) {
last_line_inst = nullptr;
} else if (opcode == SpvOpLoopMerge || opcode == SpvOpSelectionMerge) {
between_merge_and_branch = true;
diff --git a/source/opt/reflect.h b/source/opt/reflect.h
index d374e68..c7d46df 100644
--- a/source/opt/reflect.h
+++ b/source/opt/reflect.h
@@ -59,10 +59,6 @@
inline bool IsSpecConstantInst(SpvOp opcode) {
return opcode >= SpvOpSpecConstantTrue && opcode <= SpvOpSpecConstantOp;
}
-inline bool IsTerminatorInst(SpvOp opcode) {
- return (opcode >= SpvOpBranch && opcode <= SpvOpUnreachable) ||
- (opcode == SpvOpTerminateInvocation);
-}
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/inline_test.cpp b/test/opt/inline_test.cpp
index d3ef09c..2939901 100644
--- a/test/opt/inline_test.cpp
+++ b/test/opt/inline_test.cpp
@@ -2581,6 +2581,63 @@
SinglePassRunAndCheck<InlineExhaustivePass>(before, after, false, true);
}
+TEST_F(InlineTest, InlineFuncWithOpTerminateRayNotInContinue) {
+ const std::string text =
+ R"(
+ OpCapability RayTracingKHR
+ OpExtension "SPV_KHR_ray_tracing"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint AnyHitKHR %MyAHitMain2 "MyAHitMain2" %a
+ OpSource HLSL 630
+ OpName %a "a"
+ OpName %MyAHitMain2 "MyAHitMain2"
+ OpName %param_var_a "param.var.a"
+ OpName %src_MyAHitMain2 "src.MyAHitMain2"
+ OpName %a_0 "a"
+ OpName %bb_entry "bb.entry"
+ %int = OpTypeInt 32 1
+%_ptr_IncomingRayPayloadKHR_int = OpTypePointer IncomingRayPayloadKHR %int
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+%_ptr_Function_int = OpTypePointer Function %int
+ %14 = OpTypeFunction %void %_ptr_Function_int
+ %a = OpVariable %_ptr_IncomingRayPayloadKHR_int IncomingRayPayloadKHR
+%MyAHitMain2 = OpFunction %void None %6
+ %7 = OpLabel
+%param_var_a = OpVariable %_ptr_Function_int Function
+ %10 = OpLoad %int %a
+ OpStore %param_var_a %10
+ %11 = OpFunctionCall %void %src_MyAHitMain2 %param_var_a
+ %13 = OpLoad %int %param_var_a
+ OpStore %a %13
+ OpReturn
+ OpFunctionEnd
+%src_MyAHitMain2 = OpFunction %void None %14
+ %a_0 = OpFunctionParameter %_ptr_Function_int
+ %bb_entry = OpLabel
+ %17 = OpLoad %int %a_0
+ OpStore %a %17
+ OpTerminateRayKHR
+ OpFunctionEnd
+
+; CHECK: %MyAHitMain2 = OpFunction %void None
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: %param_var_a = OpVariable %_ptr_Function_int Function
+; CHECK-NEXT: OpLoad %int %a
+; CHECK-NEXT: OpStore %param_var_a {{%\d+}}
+; CHECK-NEXT: OpLoad %int %param_var_a
+; CHECK-NEXT: OpStore %a {{%\d+}}
+; CHECK-NEXT: OpTerminateRayKHR
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: OpLoad %int %param_var_a
+; CHECK-NEXT: OpStore %a %16
+; CHECK-NEXT: OpReturn
+; CHECK-NEXT: OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<InlineExhaustivePass>(text, false);
+}
+
TEST_F(InlineTest, EarlyReturnFunctionInlined) {
// #version 140
//