spirv-opt : Add FixFuncCallArgumentsPass (#4775)

spirv validation require OpFunctionCall with memory object, usually this
is non issue as all the functions are inlined.
This pass deal with some case for
DontInline function. accesschain input operand would be replaced new
created variable
diff --git a/Android.mk b/Android.mk
index b9fbcc8..d5b83b8 100644
--- a/Android.mk
+++ b/Android.mk
@@ -109,6 +109,7 @@
 		source/opt/eliminate_dead_input_components_pass.cpp \
 		source/opt/eliminate_dead_members_pass.cpp \
 		source/opt/feature_manager.cpp \
+		source/opt/fix_func_call_arguments.cpp \
 		source/opt/fix_storage_class.cpp \
 		source/opt/flatten_decoration_pass.cpp \
 		source/opt/fold.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index ba05497..5428d88 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -629,6 +629,8 @@
     "source/opt/empty_pass.h",
     "source/opt/feature_manager.cpp",
     "source/opt/feature_manager.h",
+    "source/opt/fix_func_call_arguments.cpp",
+    "source/opt/fix_func_call_arguments.h",
     "source/opt/fix_storage_class.cpp",
     "source/opt/fix_storage_class.h",
     "source/opt/flatten_decoration_pass.cpp",
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index fbbd9bc..df830d7 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -907,6 +907,11 @@
 // from every function in the module.  This is useful if you want the inliner to
 // inline these functions some reason.
 Optimizer::PassToken CreateRemoveDontInlinePass();
+// Create a fix-func-call-param pass to fix non memory argument for the function
+// call, as spirv-validation requires function parameters to be an memory
+// object, currently the pass would remove accesschain pointer argument passed
+// to the function
+Optimizer::PassToken CreateFixFuncCallArgumentsPass();
 }  // namespace spvtools
 
 #endif  // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 61e7a98..c42ae22 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 set(SPIRV_TOOLS_OPT_SOURCES
+  fix_func_call_arguments.h
   aggressive_dead_code_elim_pass.h
   amd_ext_to_khr.h
   basic_block.h
@@ -126,6 +127,7 @@
   workaround1209.h
   wrap_opkill.h
 
+  fix_func_call_arguments.cpp
   aggressive_dead_code_elim_pass.cpp
   amd_ext_to_khr.cpp
   basic_block.cpp
diff --git a/source/opt/fix_func_call_arguments.cpp b/source/opt/fix_func_call_arguments.cpp
new file mode 100644
index 0000000..d140fb4
--- /dev/null
+++ b/source/opt/fix_func_call_arguments.cpp
@@ -0,0 +1,90 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fix_func_call_arguments.h"
+
+#include "ir_builder.h"
+
+using namespace spvtools;
+using namespace opt;
+
+bool FixFuncCallArgumentsPass::ModuleHasASingleFunction() {
+  auto funcsNum = get_module()->end() - get_module()->begin();
+  return funcsNum == 1;
+}
+
+Pass::Status FixFuncCallArgumentsPass::Process() {
+  bool modified = false;
+  if (ModuleHasASingleFunction()) return Status::SuccessWithoutChange;
+  for (auto& func : *get_module()) {
+    func.ForEachInst([this, &modified](Instruction* inst) {
+      if (inst->opcode() == SpvOpFunctionCall) {
+        modified |= FixFuncCallArguments(inst);
+      }
+    });
+  }
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+bool FixFuncCallArgumentsPass::FixFuncCallArguments(
+    Instruction* func_call_inst) {
+  bool modified = false;
+  for (uint32_t i = 0; i < func_call_inst->NumInOperands(); ++i) {
+    Operand& op = func_call_inst->GetInOperand(i);
+    if (op.type != SPV_OPERAND_TYPE_ID) continue;
+    Instruction* operand_inst = get_def_use_mgr()->GetDef(op.AsId());
+    if (operand_inst->opcode() == SpvOpAccessChain) {
+      uint32_t var_id =
+          ReplaceAccessChainFuncCallArguments(func_call_inst, operand_inst);
+      func_call_inst->SetInOperand(i, {var_id});
+      modified = true;
+    }
+  }
+  if (modified) {
+    context()->UpdateDefUse(func_call_inst);
+  }
+  return modified;
+}
+
+uint32_t FixFuncCallArgumentsPass::ReplaceAccessChainFuncCallArguments(
+    Instruction* func_call_inst, Instruction* operand_inst) {
+  InstructionBuilder builder(
+      context(), func_call_inst,
+      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
+
+  Instruction* next_insert_point = func_call_inst->NextNode();
+  // Get Variable insertion point
+  Function* func = context()->get_instr_block(func_call_inst)->GetParent();
+  Instruction* variable_insertion_point = &*(func->begin()->begin());
+  Instruction* op_ptr_type = get_def_use_mgr()->GetDef(operand_inst->type_id());
+  Instruction* op_type =
+      get_def_use_mgr()->GetDef(op_ptr_type->GetSingleWordInOperand(1));
+  uint32_t varType = context()->get_type_mgr()->FindPointerToType(
+      op_type->result_id(), SpvStorageClassFunction);
+  // Create new variable
+  builder.SetInsertPoint(variable_insertion_point);
+  Instruction* var = builder.AddVariable(varType, SpvStorageClassFunction);
+  // Load access chain to the new variable before function call
+  builder.SetInsertPoint(func_call_inst);
+
+  uint32_t operand_id = operand_inst->result_id();
+  Instruction* load = builder.AddLoad(op_type->result_id(), operand_id);
+  builder.AddStore(var->result_id(), load->result_id());
+  // Load return value to the acesschain after function call
+  builder.SetInsertPoint(next_insert_point);
+  load = builder.AddLoad(op_type->result_id(), var->result_id());
+  builder.AddStore(operand_id, load->result_id());
+
+  return var->result_id();
+}
diff --git a/source/opt/fix_func_call_arguments.h b/source/opt/fix_func_call_arguments.h
new file mode 100644
index 0000000..15781b8
--- /dev/null
+++ b/source/opt/fix_func_call_arguments.h
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _VAR_FUNC_CALL_PASS_H
+#define _VAR_FUNC_CALL_PASS_H
+
+#include "source/opt/pass.h"
+
+namespace spvtools {
+namespace opt {
+class FixFuncCallArgumentsPass : public Pass {
+ public:
+  FixFuncCallArgumentsPass() {}
+  const char* name() const override { return "fix-for-funcall-param"; }
+  Status Process() override;
+  // Returns true if the module has one one function.
+  bool ModuleHasASingleFunction();
+  // Copies from the memory pointed to by |operand_inst| to a new function scope
+  // variable created before |func_call_inst|, and
+  // copies the value of the new variable back to the memory pointed to by
+  // |operand_inst| after |funct_call_inst|  Returns the id of
+  // the new variable.
+  uint32_t ReplaceAccessChainFuncCallArguments(Instruction* func_call_inst,
+                                               Instruction* operand_inst);
+
+  // Fix function call |func_call_inst| non memory object arguments
+  bool FixFuncCallArguments(Instruction* func_call_inst);
+
+  IRContext::Analysis GetPreservedAnalyses() override {
+    return IRContext::kAnalysisTypes;
+  }
+};
+}  // namespace opt
+}  // namespace spvtools
+
+#endif  // _VAR_FUNC_CALL_PASS_H
\ No newline at end of file
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index 4433cf0..9d4fa8f 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -487,6 +487,15 @@
     return AddInstruction(std::move(new_inst));
   }
 
+  Instruction* AddVariable(uint32_t type_id, uint32_t storage_class) {
+    std::vector<Operand> operands;
+    operands.push_back({SPV_OPERAND_TYPE_ID, {storage_class}});
+    std::unique_ptr<Instruction> new_inst(
+        new Instruction(GetContext(), SpvOpVariable, type_id,
+                        GetContext()->TakeNextId(), operands));
+    return AddInstruction(std::move(new_inst));
+  }
+
   Instruction* AddStore(uint32_t ptr_id, uint32_t obj_id) {
     std::vector<Operand> operands;
     operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}});
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index f28b1ba..051d573 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -525,6 +525,8 @@
     RegisterPass(CreateRemoveDontInlinePass());
   } else if (pass_name == "eliminate-dead-input-components") {
     RegisterPass(CreateEliminateDeadInputComponentsPass());
+  } else if (pass_name == "fix-func-call-param") {
+    RegisterPass(CreateFixFuncCallArgumentsPass());
   } else if (pass_name == "convert-to-sampled-image") {
     if (pass_args.size() > 0) {
       auto descriptor_set_binding_pairs =
@@ -1022,4 +1024,9 @@
   return MakeUnique<Optimizer::PassToken::Impl>(
       MakeUnique<opt::RemoveDontInline>());
 }
+
+Optimizer::PassToken CreateFixFuncCallArgumentsPass() {
+  return MakeUnique<Optimizer::PassToken::Impl>(
+      MakeUnique<opt::FixFuncCallArgumentsPass>());
+}
 }  // namespace spvtools
diff --git a/source/opt/passes.h b/source/opt/passes.h
index a12c76b..facaa41 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -37,6 +37,7 @@
 #include "source/opt/eliminate_dead_input_components_pass.h"
 #include "source/opt/eliminate_dead_members_pass.h"
 #include "source/opt/empty_pass.h"
+#include "source/opt/fix_func_call_arguments.h"
 #include "source/opt/fix_storage_class.h"
 #include "source/opt/flatten_decoration_pass.h"
 #include "source/opt/fold_spec_constant_op_and_composite_pass.h"
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 6dfb1b7..aa47dee 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -45,6 +45,7 @@
        eliminate_dead_input_components_test.cpp
        eliminate_dead_member_test.cpp
        feature_manager_test.cpp
+       fix_func_call_arguments_test.cpp
        fix_storage_class_test.cpp
        flatten_decoration_test.cpp
        fold_spec_const_op_composite_test.cpp
@@ -84,7 +85,7 @@
        reduce_load_size_test.cpp
        redundancy_elimination_test.cpp
        remove_dontinline_test.cpp
-	   remove_unused_interface_variables_test.cpp
+       remove_unused_interface_variables_test.cpp
        register_liveness.cpp
        relax_float_ops_test.cpp
        replace_desc_array_access_using_var_index_test.cpp
@@ -96,7 +97,7 @@
        spread_volatile_semantics_test.cpp
        strength_reduction_test.cpp
        strip_debug_info_test.cpp
-        strip_nonsemantic_info_test.cpp
+       strip_nonsemantic_info_test.cpp
        struct_cfg_analysis_test.cpp
        type_manager_test.cpp
        types_test.cpp
diff --git a/test/opt/fix_func_call_arguments_test.cpp b/test/opt/fix_func_call_arguments_test.cpp
new file mode 100644
index 0000000..ecd13a8
--- /dev/null
+++ b/test/opt/fix_func_call_arguments_test.cpp
@@ -0,0 +1,152 @@
+// Copyright (c) 2022 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "test/opt/pass_fixture.h"
+#include "test/opt/pass_utils.h"
+
+namespace spvtools {
+namespace opt {
+namespace {
+
+using FixFuncCallArgumentsTest = PassTest<::testing::Test>;
+TEST_F(FixFuncCallArgumentsTest, Simple) {
+  const std::string text = R"(
+;
+; CHECK: [[v0:%\w+]] = OpVariable %_ptr_Function_float Function
+; CHECK: [[v1:%\w+]] = OpVariable %_ptr_Function_float Function
+; CHECK: [[v2:%\w+]] = OpVariable %_ptr_Function_T Function
+; CHECK: [[ac0:%\w+]] = OpAccessChain %_ptr_Function_float %t %int_0
+; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
+; CHECK: [[ld0:%\w+]] = OpLoad %float [[ac0]]
+; CHECK:                OpStore [[v1]] [[ld0]]
+; CHECK: [[ld1:%\w+]] = OpLoad %float [[ac1]]
+; CHECK:                OpStore [[v0]] [[ld1]]
+; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[v1]] [[v0]]
+; CHECK: [[ld2:%\w+]] = OpLoad %float [[v0]]
+; CHECK: OpStore [[ac1]] [[ld2]]
+; CHECK: [[ld3:%\w+]] = OpLoad %float [[v1]]
+; CHECK: OpStore [[ac0]] [[ld3]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpSource HLSL 630
+OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
+OpName %r1 "r1"
+OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
+OpMemberName %type_ACSBuffer_counter 0 "counter"
+OpName %counter_var_r1 "counter.var.r1"
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %T "T"
+OpMemberName %T 0 "t0"
+OpName %t "t"
+OpName %fn "fn"
+OpName %p0 "p0"
+OpName %p2 "p2"
+OpName %bb_entry_0 "bb.entry"
+OpDecorate %main LinkageAttributes "main" Export
+OpDecorate %r1 DescriptorSet 0
+OpDecorate %r1 Binding 0
+OpDecorate %counter_var_r1 DescriptorSet 0
+OpDecorate %counter_var_r1 Binding 1
+OpDecorate %_runtimearr_float ArrayStride 4
+OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
+OpDecorate %type_RWStructuredBuffer_float BufferBlock
+OpMemberDecorate %type_ACSBuffer_counter 0 Offset 0
+OpDecorate %type_ACSBuffer_counter BufferBlock
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%int_1 = OpConstant %int 1
+%float = OpTypeFloat 32
+%_runtimearr_float = OpTypeRuntimeArray %float
+%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
+%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
+%type_ACSBuffer_counter = OpTypeStruct %int
+%_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter
+%15 = OpTypeFunction %int
+%T = OpTypeStruct %float
+%_ptr_Function_T = OpTypePointer Function %T
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+%void = OpTypeVoid
+%27 = OpTypeFunction %void %_ptr_Function_float %_ptr_Function_float
+%r1 = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
+%counter_var_r1 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
+%main = OpFunction %int None %15
+%bb_entry = OpLabel
+%t = OpVariable %_ptr_Function_T Function
+%21 = OpAccessChain %_ptr_Function_float %t %int_0
+%23 = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
+%25 = OpFunctionCall %void %fn %21 %23
+OpReturnValue %int_1
+OpFunctionEnd
+%fn = OpFunction %void DontInline %27
+%p0 = OpFunctionParameter %_ptr_Function_float
+%p2 = OpFunctionParameter %_ptr_Function_float
+%bb_entry_0 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, true);
+}
+
+TEST_F(FixFuncCallArgumentsTest, NotAccessChainInput) {
+  const std::string text = R"(
+;
+; CHECK: [[o:%\w+]] = OpCopyObject %_ptr_Function_float %t
+; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[o]]
+;
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpSource HLSL 630
+OpName %main "main"
+OpName %bb_entry "bb.entry"
+OpName %t "t"
+OpName %fn "fn"
+OpName %p0 "p0"
+OpName %bb_entry_0 "bb.entry"
+OpDecorate %main LinkageAttributes "main" Export
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%4 = OpTypeFunction %int
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%void = OpTypeVoid
+%12 = OpTypeFunction %void %_ptr_Function_float
+%main = OpFunction %int None %4
+%bb_entry = OpLabel
+%t = OpVariable %_ptr_Function_float Function
+%t1 = OpCopyObject %_ptr_Function_float %t
+%10 = OpFunctionCall %void %fn %t1
+OpReturnValue %int_1
+OpFunctionEnd
+%fn = OpFunction %void DontInline %12
+%p0 = OpFunctionParameter %_ptr_Function_float
+%bb_entry_0 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, false);
+}
+
+}  // namespace
+}  // namespace opt
+}  // namespace spvtools
\ No newline at end of file
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index 0129478..349b114 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -237,6 +237,10 @@
                loads and stores. Performed only on entry point call tree
                functions.)");
   printf(R"(
+  --fix-func-call-param
+               fix non memory argument for the function call, replace 
+               accesschain pointer argument with a variable.)");
+  printf(R"(
   --flatten-decorations
                Replace decoration groups with repeated OpDecorate and
                OpMemberDecorate instructions.)");