Handle dontinline function in spread-volatile-semantics (#4776)
Handle function calls in spread-volatile-semantics
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index a80d4f2..c9c3f1b 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -926,6 +926,19 @@
return modified;
}
+void IRContext::CollectCallTreeFromRoots(unsigned entryId,
+ std::unordered_set<uint32_t>* funcs) {
+ std::queue<uint32_t> roots;
+ roots.push(entryId);
+ while (!roots.empty()) {
+ const uint32_t fi = roots.front();
+ roots.pop();
+ funcs->insert(fi);
+ Function* fn = GetFunction(fi);
+ AddCalls(fn, &roots);
+ }
+}
+
void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
if (!consumer()) {
return;
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 946f9e9..f9f5153 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -411,6 +411,10 @@
void CollectNonSemanticTree(Instruction* inst,
std::unordered_set<Instruction*>* to_kill);
+ // Collect function reachable from |entryId|, returns |funcs|
+ void CollectCallTreeFromRoots(unsigned entryId,
+ std::unordered_set<uint32_t>* funcs);
+
// Returns true if all of the given analyses are valid.
bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }
diff --git a/source/opt/spread_volatile_semantics.cpp b/source/opt/spread_volatile_semantics.cpp
index a1d3432..b61fd0f 100644
--- a/source/opt/spread_volatile_semantics.cpp
+++ b/source/opt/spread_volatile_semantics.cpp
@@ -68,38 +68,12 @@
return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
}
-bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
- std::unordered_set<uint32_t> entry_function_ids;
- for (Instruction& entry_point : module->entry_points()) {
- entry_function_ids.insert(
- entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
- }
- for (auto& function : *module) {
- if (entry_function_ids.find(function.result_id()) ==
- entry_function_ids.end()) {
- std::string message(
- "Functions of SPIR-V for spread-volatile-semantics pass input must "
- "be inlined except entry points");
- message += "\n " + function.DefInst().PrettyPrint(
- SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
- context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
- return false;
- }
- }
- return true;
-}
-
} // namespace
Pass::Status SpreadVolatileSemantics::Process() {
if (HasNoExecutionModel()) {
return Status::SuccessWithoutChange;
}
-
- if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
- return Status::Failure;
- }
-
const bool is_vk_memory_model_enabled =
context()->get_feature_mgr()->HasCapability(
SpvCapabilityVulkanMemoryModel);
@@ -142,6 +116,8 @@
uint32_t var_id, Instruction* entry_point) {
uint32_t entry_function_id =
entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
+ std::unordered_set<uint32_t> funcs;
+ context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
return !VisitLoadsOfPointersToVariableInEntries(
var_id,
[](Instruction* load) {
@@ -154,7 +130,7 @@
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
},
- {entry_function_id});
+ funcs);
}
bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
@@ -225,7 +201,7 @@
bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
- const std::unordered_set<uint32_t>& entry_function_ids) {
+ const std::unordered_set<uint32_t>& function_ids) {
std::vector<uint32_t> worklist({var_id});
auto* def_use_mgr = context()->get_def_use_mgr();
while (!worklist.empty()) {
@@ -233,11 +209,11 @@
worklist.pop_back();
bool finish_traversal = !def_use_mgr->WhileEachUser(
ptr_id, [this, &worklist, &ptr_id, handle_load,
- &entry_function_ids](Instruction* user) {
+ &function_ids](Instruction* user) {
BasicBlock* block = context()->get_instr_block(user);
if (block == nullptr ||
- entry_function_ids.find(block->GetParent()->result_id()) ==
- entry_function_ids.end()) {
+ function_ids.find(block->GetParent()->result_id()) ==
+ function_ids.end()) {
return true;
}
@@ -266,21 +242,25 @@
Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
// Set Volatile memory operand for all load instructions if they do not have
// it.
- VisitLoadsOfPointersToVariableInEntries(
- var->result_id(),
- [](Instruction* load) {
- if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
- load->AddOperand(
- {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
+ for (auto entry_id : entry_function_ids) {
+ std::unordered_set<uint32_t> funcs;
+ context()->CollectCallTreeFromRoots(entry_id, &funcs);
+ VisitLoadsOfPointersToVariableInEntries(
+ var->result_id(),
+ [](Instruction* load) {
+ if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
+ load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
+ {SpvMemoryAccessVolatileMask}});
+ return true;
+ }
+ uint32_t memory_operands =
+ load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
+ memory_operands |= SpvMemoryAccessVolatileMask;
+ load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
return true;
- }
- uint32_t memory_operands =
- load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
- memory_operands |= SpvMemoryAccessVolatileMask;
- load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
- return true;
- },
- entry_function_ids);
+ },
+ funcs);
+ }
}
bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
diff --git a/source/opt/spread_volatile_semantics.h b/source/opt/spread_volatile_semantics.h
index 531a21d..014858d 100644
--- a/source/opt/spread_volatile_semantics.h
+++ b/source/opt/spread_volatile_semantics.h
@@ -72,15 +72,14 @@
Instruction* entry_point);
// Visits load instructions of pointers to variable whose result id is
- // |var_id| if the load instructions are in entry points whose
- // function id is one of |entry_function_ids|. |handle_load| is a function to
- // do some actions for the load instructions. Finishes the traversal and
- // returns false if |handle_load| returns false for a load instruction.
- // Otherwise, returns true after running |handle_load| for all the load
- // instructions.
+ // |var_id| if the load instructions are in reachable functions from entry
+ // points. |handle_load| is a function to do some actions for the load
+ // instructions. Finishes the traversal and returns false if |handle_load|
+ // returns false for a load instruction. Otherwise, returns true after running
+ // |handle_load| for all the load instructions.
bool VisitLoadsOfPointersToVariableInEntries(
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
- const std::unordered_set<uint32_t>& entry_function_ids);
+ const std::unordered_set<uint32_t>& function_ids);
// Sets Memory Operands of OpLoad instructions that load |var| or pointers
// of |var| as Volatile if the function id of the OpLoad instruction is
diff --git a/test/opt/spread_volatile_semantics_test.cpp b/test/opt/spread_volatile_semantics_test.cpp
index fdabd92..dbb889c 100644
--- a/test/opt/spread_volatile_semantics_test.cpp
+++ b/test/opt/spread_volatile_semantics_test.cpp
@@ -54,6 +54,7 @@
OpSourceExtension "GL_EXT_nonuniform_qualifier"
OpSourceExtension "GL_KHR_ray_tracing"
OpName %main "main"
+OpName %fn "fn"
OpName %StorageBuffer "StorageBuffer"
OpMemberName %StorageBuffer 0 "index"
OpMemberName %StorageBuffer 1 "red"
@@ -109,6 +110,11 @@
%29 = OpCompositeExtract %float %27 0
%31 = OpAccessChain %_ptr_Uniform_float %sbo %int_1
OpStore %31 %29
+%32 = OpFunctionCall %void %fn
+OpReturn
+OpFunctionEnd
+%fn = OpFunction %void None %3
+%33 = OpLabel
OpReturn
OpFunctionEnd
)");
@@ -782,12 +788,7 @@
OpFunctionEnd
)";
- EXPECT_EQ(RunPass(text), Pass::Status::Failure);
- const char expected_error[] =
- "ERROR: 0: Functions of SPIR-V for spread-volatile-semantics pass "
- "input must be inlined except entry points";
- EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
- expected_error);
+ EXPECT_EQ(RunPass(text), Pass::Status::SuccessWithoutChange);
}
TEST_F(VolatileSpreadErrorTest, VarNotUsedInEntryPointForVolatile) {
@@ -1133,6 +1134,134 @@
EXPECT_EQ(status, Pass::Status::SuccessWithoutChange);
}
+TEST_F(VolatileSpreadTest, NoInlinedfuncCalls) {
+ const std::string text = R"(
+OpCapability RayTracingNV
+OpCapability VulkanMemoryModel
+OpCapability GroupNonUniform
+OpExtension "SPV_NV_ray_tracing"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical Vulkan
+OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
+OpSource HLSL 630
+OpName %main "main"
+OpName %src_main "src.main"
+OpName %bb_entry "bb.entry"
+OpName %func0 "func0"
+OpName %bb_entry_0 "bb.entry"
+OpName %func2 "func2"
+OpName %bb_entry_1 "bb.entry"
+OpName %param_var_count "param.var.count"
+OpName %func1 "func1"
+OpName %bb_entry_2 "bb.entry"
+OpName %func3 "func3"
+OpName %count "count"
+OpName %bb_entry_3 "bb.entry"
+OpDecorate %SubgroupSize BuiltIn SubgroupSize
+%uint = OpTypeInt 32 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+%_ptr_Function_uint = OpTypePointer Function %uint
+%25 = OpTypeFunction %void %_ptr_Function_uint
+%SubgroupSize = OpVariable %_ptr_Input_uint Input
+%main = OpFunction %void None %6
+%7 = OpLabel
+%8 = OpFunctionCall %void %src_main
+OpReturn
+OpFunctionEnd
+%src_main = OpFunction %void None %6
+%bb_entry = OpLabel
+%11 = OpFunctionCall %void %func0
+OpReturn
+OpFunctionEnd
+%func0 = OpFunction %void DontInline %6
+%bb_entry_0 = OpLabel
+%14 = OpFunctionCall %void %func2
+%16 = OpFunctionCall %void %func1
+OpReturn
+OpFunctionEnd
+%func2 = OpFunction %void DontInline %6
+%bb_entry_1 = OpLabel
+%param_var_count = OpVariable %_ptr_Function_uint Function
+; CHECK: {{%\w+}} = OpLoad %uint %SubgroupSize Volatile
+%21 = OpLoad %uint %SubgroupSize
+OpStore %param_var_count %21
+%22 = OpFunctionCall %void %func3 %param_var_count
+OpReturn
+OpFunctionEnd
+%func1 = OpFunction %void DontInline %6
+%bb_entry_2 = OpLabel
+OpReturn
+OpFunctionEnd
+%func3 = OpFunction %void DontInline %25
+%count = OpFunctionParameter %_ptr_Function_uint
+%bb_entry_3 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+ SinglePassRunAndMatch<SpreadVolatileSemantics>(text, true);
+}
+
+TEST_F(VolatileSpreadErrorTest, NoInlinedMultiEntryfuncCalls) {
+ const std::string text = R"(
+OpCapability RayTracingNV
+OpCapability SubgroupBallotKHR
+OpExtension "SPV_NV_ray_tracing"
+OpExtension "SPV_KHR_shader_ballot"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
+OpEntryPoint GLCompute %main2 "main2" %gl_LocalInvocationIndex %SubgroupSize
+OpSource HLSL 630
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %main2 "main2"
+OpName %bb_entry_0 "bb.entry"
+OpName %func "func"
+OpName %count "count"
+OpName %bb_entry_1 "bb.entry"
+OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
+OpDecorate %SubgroupSize BuiltIn SubgroupSize
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%void = OpTypeVoid
+%12 = OpTypeFunction %void
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%29 = OpTypeFunction %void %_ptr_Function_v4float
+%34 = OpTypeFunction %void %_ptr_Function_uint
+%SubgroupSize = OpVariable %_ptr_Input_uint Input
+%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
+%main = OpFunction %void None %12
+%bb_entry = OpLabel
+%20 = OpFunctionCall %void %func
+OpReturn
+OpFunctionEnd
+%main2 = OpFunction %void None %12
+%bb_entry_0 = OpLabel
+%33 = OpFunctionCall %void %func
+OpReturn
+OpFunctionEnd
+%func = OpFunction %void DontInline %12
+%bb_entry_1 = OpLabel
+%count = OpVariable %_ptr_Function_uint Function
+%35 = OpLoad %uint %SubgroupSize
+OpStore %count %35
+OpReturn
+OpFunctionEnd
+)";
+ EXPECT_EQ(RunPass(text), Pass::Status::Failure);
+ const char expected_error[] =
+ "ERROR: 0: Variable is a target for Volatile semantics for an entry "
+ "point, but it is not for another entry point";
+ EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
+ expected_error);
+}
+
} // namespace
} // namespace opt
} // namespace spvtools