Handle dontinline function in spread-volatile-semantics (#4776)

Handle function calls in spread-volatile-semantics
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index a80d4f2..c9c3f1b 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -926,6 +926,19 @@
   return modified;
 }
 
+void IRContext::CollectCallTreeFromRoots(unsigned entryId,
+                                         std::unordered_set<uint32_t>* funcs) {
+  std::queue<uint32_t> roots;
+  roots.push(entryId);
+  while (!roots.empty()) {
+    const uint32_t fi = roots.front();
+    roots.pop();
+    funcs->insert(fi);
+    Function* fn = GetFunction(fi);
+    AddCalls(fn, &roots);
+  }
+}
+
 void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
   if (!consumer()) {
     return;
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 946f9e9..f9f5153 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -411,6 +411,10 @@
   void CollectNonSemanticTree(Instruction* inst,
                               std::unordered_set<Instruction*>* to_kill);
 
+  // Collect function reachable from |entryId|, returns |funcs|
+  void CollectCallTreeFromRoots(unsigned entryId,
+                                std::unordered_set<uint32_t>* funcs);
+
   // Returns true if all of the given analyses are valid.
   bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }
 
diff --git a/source/opt/spread_volatile_semantics.cpp b/source/opt/spread_volatile_semantics.cpp
index a1d3432..b61fd0f 100644
--- a/source/opt/spread_volatile_semantics.cpp
+++ b/source/opt/spread_volatile_semantics.cpp
@@ -68,38 +68,12 @@
   return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
 }
 
-bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
-  std::unordered_set<uint32_t> entry_function_ids;
-  for (Instruction& entry_point : module->entry_points()) {
-    entry_function_ids.insert(
-        entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
-  }
-  for (auto& function : *module) {
-    if (entry_function_ids.find(function.result_id()) ==
-        entry_function_ids.end()) {
-      std::string message(
-          "Functions of SPIR-V for spread-volatile-semantics pass input must "
-          "be inlined except entry points");
-      message += "\n  " + function.DefInst().PrettyPrint(
-                              SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
-      context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
-      return false;
-    }
-  }
-  return true;
-}
-
 }  // namespace
 
 Pass::Status SpreadVolatileSemantics::Process() {
   if (HasNoExecutionModel()) {
     return Status::SuccessWithoutChange;
   }
-
-  if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
-    return Status::Failure;
-  }
-
   const bool is_vk_memory_model_enabled =
       context()->get_feature_mgr()->HasCapability(
           SpvCapabilityVulkanMemoryModel);
@@ -142,6 +116,8 @@
     uint32_t var_id, Instruction* entry_point) {
   uint32_t entry_function_id =
       entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
+  std::unordered_set<uint32_t> funcs;
+  context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
   return !VisitLoadsOfPointersToVariableInEntries(
       var_id,
       [](Instruction* load) {
@@ -154,7 +130,7 @@
             load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
         return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
       },
-      {entry_function_id});
+      funcs);
 }
 
 bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
@@ -225,7 +201,7 @@
 
 bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
     uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
-    const std::unordered_set<uint32_t>& entry_function_ids) {
+    const std::unordered_set<uint32_t>& function_ids) {
   std::vector<uint32_t> worklist({var_id});
   auto* def_use_mgr = context()->get_def_use_mgr();
   while (!worklist.empty()) {
@@ -233,11 +209,11 @@
     worklist.pop_back();
     bool finish_traversal = !def_use_mgr->WhileEachUser(
         ptr_id, [this, &worklist, &ptr_id, handle_load,
-                 &entry_function_ids](Instruction* user) {
+                 &function_ids](Instruction* user) {
           BasicBlock* block = context()->get_instr_block(user);
           if (block == nullptr ||
-              entry_function_ids.find(block->GetParent()->result_id()) ==
-                  entry_function_ids.end()) {
+              function_ids.find(block->GetParent()->result_id()) ==
+                  function_ids.end()) {
             return true;
           }
 
@@ -266,21 +242,25 @@
     Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
   // Set Volatile memory operand for all load instructions if they do not have
   // it.
-  VisitLoadsOfPointersToVariableInEntries(
-      var->result_id(),
-      [](Instruction* load) {
-        if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
-          load->AddOperand(
-              {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
+  for (auto entry_id : entry_function_ids) {
+    std::unordered_set<uint32_t> funcs;
+    context()->CollectCallTreeFromRoots(entry_id, &funcs);
+    VisitLoadsOfPointersToVariableInEntries(
+        var->result_id(),
+        [](Instruction* load) {
+          if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
+            load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
+                              {SpvMemoryAccessVolatileMask}});
+            return true;
+          }
+          uint32_t memory_operands =
+              load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
+          memory_operands |= SpvMemoryAccessVolatileMask;
+          load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
           return true;
-        }
-        uint32_t memory_operands =
-            load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
-        memory_operands |= SpvMemoryAccessVolatileMask;
-        load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
-        return true;
-      },
-      entry_function_ids);
+        },
+        funcs);
+  }
 }
 
 bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
diff --git a/source/opt/spread_volatile_semantics.h b/source/opt/spread_volatile_semantics.h
index 531a21d..014858d 100644
--- a/source/opt/spread_volatile_semantics.h
+++ b/source/opt/spread_volatile_semantics.h
@@ -72,15 +72,14 @@
                                                  Instruction* entry_point);
 
   // Visits load instructions of pointers to variable whose result id is
-  // |var_id| if the load instructions are in entry points whose
-  // function id is one of |entry_function_ids|. |handle_load| is a function to
-  // do some actions for the load instructions. Finishes the traversal and
-  // returns false if |handle_load| returns false for a load instruction.
-  // Otherwise, returns true after running |handle_load| for all the load
-  // instructions.
+  // |var_id| if the load instructions are in reachable functions from entry
+  // points. |handle_load| is a function to do some actions for the load
+  // instructions. Finishes the traversal and returns false if |handle_load|
+  // returns false for a load instruction. Otherwise, returns true after running
+  // |handle_load| for all the load instructions.
   bool VisitLoadsOfPointersToVariableInEntries(
       uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
-      const std::unordered_set<uint32_t>& entry_function_ids);
+      const std::unordered_set<uint32_t>& function_ids);
 
   // Sets Memory Operands of OpLoad instructions that load |var| or pointers
   // of |var| as Volatile if the function id of the OpLoad instruction is
diff --git a/test/opt/spread_volatile_semantics_test.cpp b/test/opt/spread_volatile_semantics_test.cpp
index fdabd92..dbb889c 100644
--- a/test/opt/spread_volatile_semantics_test.cpp
+++ b/test/opt/spread_volatile_semantics_test.cpp
@@ -54,6 +54,7 @@
 OpSourceExtension "GL_EXT_nonuniform_qualifier"
 OpSourceExtension "GL_KHR_ray_tracing"
 OpName %main "main"
+OpName %fn "fn"
 OpName %StorageBuffer "StorageBuffer"
 OpMemberName %StorageBuffer 0 "index"
 OpMemberName %StorageBuffer 1 "red"
@@ -109,6 +110,11 @@
 %29 = OpCompositeExtract %float %27 0
 %31 = OpAccessChain %_ptr_Uniform_float %sbo %int_1
 OpStore %31 %29
+%32 = OpFunctionCall %void %fn
+OpReturn
+OpFunctionEnd
+%fn = OpFunction %void None %3
+%33 = OpLabel
 OpReturn
 OpFunctionEnd
 )");
@@ -782,12 +788,7 @@
 OpFunctionEnd
 )";
 
-  EXPECT_EQ(RunPass(text), Pass::Status::Failure);
-  const char expected_error[] =
-      "ERROR: 0: Functions of SPIR-V for spread-volatile-semantics pass "
-      "input must be inlined except entry points";
-  EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
-               expected_error);
+  EXPECT_EQ(RunPass(text), Pass::Status::SuccessWithoutChange);
 }
 
 TEST_F(VolatileSpreadErrorTest, VarNotUsedInEntryPointForVolatile) {
@@ -1133,6 +1134,134 @@
   EXPECT_EQ(status, Pass::Status::SuccessWithoutChange);
 }
 
+TEST_F(VolatileSpreadTest, NoInlinedfuncCalls) {
+  const std::string text = R"(
+OpCapability RayTracingNV
+OpCapability VulkanMemoryModel
+OpCapability GroupNonUniform
+OpExtension "SPV_NV_ray_tracing"
+OpExtension "SPV_KHR_vulkan_memory_model"
+OpMemoryModel Logical Vulkan
+OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
+OpSource HLSL 630
+OpName %main "main"
+OpName %src_main "src.main"
+OpName %bb_entry "bb.entry"
+OpName %func0 "func0"
+OpName %bb_entry_0 "bb.entry"
+OpName %func2 "func2"
+OpName %bb_entry_1 "bb.entry"
+OpName %param_var_count "param.var.count"
+OpName %func1 "func1"
+OpName %bb_entry_2 "bb.entry"
+OpName %func3 "func3"
+OpName %count "count"
+OpName %bb_entry_3 "bb.entry"
+OpDecorate %SubgroupSize BuiltIn SubgroupSize
+%uint = OpTypeInt 32 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+%_ptr_Function_uint = OpTypePointer Function %uint
+%25 = OpTypeFunction %void %_ptr_Function_uint
+%SubgroupSize = OpVariable %_ptr_Input_uint Input
+%main = OpFunction %void None %6
+%7 = OpLabel
+%8 = OpFunctionCall %void %src_main
+OpReturn
+OpFunctionEnd
+%src_main = OpFunction %void None %6
+%bb_entry = OpLabel
+%11 = OpFunctionCall %void %func0
+OpReturn
+OpFunctionEnd
+%func0 = OpFunction %void DontInline %6
+%bb_entry_0 = OpLabel
+%14 = OpFunctionCall %void %func2
+%16 = OpFunctionCall %void %func1
+OpReturn
+OpFunctionEnd
+%func2 = OpFunction %void DontInline %6
+%bb_entry_1 = OpLabel
+%param_var_count = OpVariable %_ptr_Function_uint Function
+; CHECK: {{%\w+}} = OpLoad %uint %SubgroupSize Volatile
+%21 = OpLoad %uint %SubgroupSize
+OpStore %param_var_count %21
+%22 = OpFunctionCall %void %func3 %param_var_count
+OpReturn
+OpFunctionEnd
+%func1 = OpFunction %void DontInline %6
+%bb_entry_2 = OpLabel
+OpReturn
+OpFunctionEnd
+%func3 = OpFunction %void DontInline %25
+%count = OpFunctionParameter %_ptr_Function_uint
+%bb_entry_3 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+  SinglePassRunAndMatch<SpreadVolatileSemantics>(text, true);
+}
+
+TEST_F(VolatileSpreadErrorTest, NoInlinedMultiEntryfuncCalls) {
+  const std::string text = R"(
+OpCapability RayTracingNV
+OpCapability SubgroupBallotKHR
+OpExtension "SPV_NV_ray_tracing"
+OpExtension "SPV_KHR_shader_ballot"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
+OpEntryPoint GLCompute %main2 "main2" %gl_LocalInvocationIndex %SubgroupSize
+OpSource HLSL 630
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %main2 "main2"
+OpName %bb_entry_0 "bb.entry"
+OpName %func "func"
+OpName %count "count"
+OpName %bb_entry_1 "bb.entry"
+OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
+OpDecorate %SubgroupSize BuiltIn SubgroupSize
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%void = OpTypeVoid
+%12 = OpTypeFunction %void
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%29 = OpTypeFunction %void %_ptr_Function_v4float
+%34 = OpTypeFunction %void %_ptr_Function_uint
+%SubgroupSize = OpVariable %_ptr_Input_uint Input
+%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
+%main = OpFunction %void None %12
+%bb_entry = OpLabel
+%20 = OpFunctionCall %void %func
+OpReturn
+OpFunctionEnd
+%main2 = OpFunction %void None %12
+%bb_entry_0 = OpLabel
+%33 = OpFunctionCall %void %func
+OpReturn
+OpFunctionEnd
+%func = OpFunction %void DontInline %12
+%bb_entry_1 = OpLabel
+%count = OpVariable %_ptr_Function_uint Function
+%35 = OpLoad %uint %SubgroupSize
+OpStore %count %35
+OpReturn
+OpFunctionEnd
+)";
+  EXPECT_EQ(RunPass(text), Pass::Status::Failure);
+  const char expected_error[] =
+      "ERROR: 0: Variable is a target for Volatile semantics for an entry "
+      "point, but it is not for another entry point";
+  EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
+               expected_error);
+}
+
 }  // namespace
 }  // namespace opt
 }  // namespace spvtools