Fix SSA re-writing in the presence of variable pointers. (#4010)

This fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/3873.

In the presence of variable pointers, the reaching definition may be
another pointer.  For example, the following fragment:

 %2 = OpVariable %_ptr_Input_float Input
%11 = OpVariable %_ptr_Function__ptr_Input_float Function
      OpStore %11 %2
%12 = OpLoad %_ptr_Input_float %11
%13 = OpLoad %float %12

corresponds to the pseudo-code:

layout(location = 0) in flat float *%2
float %13;
float *%12;
float **%11;
*%11 = %2;
%12 = *%11;
%13 = *%12;

which ultimately, should correspond to:

%13 = *%2;

During rewriting, the pointer %12 is found to be replaceable by %2.
However, when processing the load %13 = *%12, the type of %12's reaching
definition is another float pointer (%2), instead of a float value.

When this happens, we need to continue looking up the reaching definition
chain until we get to a float value or a non-target var (i.e. a variable
that cannot be SSA replaced, like %2 in this case since it is a function
argument).
diff --git a/source/opt/ssa_rewrite_pass.cpp b/source/opt/ssa_rewrite_pass.cpp
index 5a56887..3ff0361 100644
--- a/source/opt/ssa_rewrite_pass.cpp
+++ b/source/opt/ssa_rewrite_pass.cpp
@@ -47,6 +47,7 @@
 #include "source/opcode.h"
 #include "source/opt/cfg.h"
 #include "source/opt/mem_pass.h"
+#include "source/opt/types.h"
 #include "source/util/make_unique.h"
 
 // Debug logging (0: Off, 1-N: Verbosity level).  Replace this with the
@@ -326,32 +327,94 @@
 }
 
 bool SSARewriter::ProcessLoad(Instruction* inst, BasicBlock* bb) {
+  // Get the pointer that we are using to load from.
   uint32_t var_id = 0;
   (void)pass_->GetPtr(inst, &var_id);
-  if (pass_->IsTargetVar(var_id)) {
-    // Get the immediate reaching definition for |var_id|.
-    uint32_t val_id = GetReachingDef(var_id, bb);
+
+  // Get the immediate reaching definition for |var_id|.
+  //
+  // In the presence of variable pointers, the reaching definition may be
+  // another pointer.  For example, the following fragment:
+  //
+  //  %2 = OpVariable %_ptr_Input_float Input
+  // %11 = OpVariable %_ptr_Function__ptr_Input_float Function
+  //       OpStore %11 %2
+  // %12 = OpLoad %_ptr_Input_float %11
+  // %13 = OpLoad %float %12
+  //
+  // corresponds to the pseudo-code:
+  //
+  // layout(location = 0) in flat float *%2
+  // float %13;
+  // float *%12;
+  // float **%11;
+  // *%11 = %2;
+  // %12 = *%11;
+  // %13 = *%12;
+  //
+  // which ultimately, should correspond to:
+  //
+  // %13 = *%2;
+  //
+  // During rewriting, the pointer %12 is found to be replaceable by %2 (i.e.,
+  // load_replacement_[12] is 2). However, when processing the load
+  // %13 = *%12, the type of %12's reaching definition is another float
+  // pointer (%2), instead of a float value.
+  //
+  // When this happens, we need to continue looking up the reaching definition
+  // chain until we get to a float value or a non-target var (i.e. a variable
+  // that cannot be SSA replaced, like %2 in this case since it is a function
+  // argument).
+  analysis::DefUseManager* def_use_mgr = pass_->context()->get_def_use_mgr();
+  analysis::TypeManager* type_mgr = pass_->context()->get_type_mgr();
+  analysis::Type* load_type = type_mgr->GetType(inst->type_id());
+  uint32_t val_id = 0;
+  bool found_reaching_def = false;
+  while (!found_reaching_def) {
+    if (!pass_->IsTargetVar(var_id)) {
+      // If the variable we are loading from is not an SSA target (globals,
+      // function parameters), do nothing.
+      return true;
+    }
+
+    val_id = GetReachingDef(var_id, bb);
     if (val_id == 0) {
       return false;
     }
 
-    // Schedule a replacement for the result of this load instruction with
-    // |val_id|. After all the rewriting decisions are made, every use of
-    // this load will be replaced with |val_id|.
-    const uint32_t load_id = inst->result_id();
-    assert(load_replacement_.count(load_id) == 0);
-    load_replacement_[load_id] = val_id;
-    PhiCandidate* defining_phi = GetPhiCandidate(val_id);
-    if (defining_phi) {
-      defining_phi->AddUser(load_id);
+    // If the reaching definition is a pointer type different than the type of
+    // the instruction we are analyzing, then it must be a reference to another
+    // pointer (otherwise, this would be invalid SPIRV).  We continue
+    // de-referencing it by making |val_id| be |var_id|.
+    //
+    // NOTE: if there is no reaching definition instruction, it means |val_id|
+    // is an undef.
+    Instruction* reaching_def_inst = def_use_mgr->GetDef(val_id);
+    if (reaching_def_inst &&
+        !type_mgr->GetType(reaching_def_inst->type_id())->IsSame(load_type)) {
+      var_id = val_id;
+    } else {
+      found_reaching_def = true;
     }
+  }
+
+  // Schedule a replacement for the result of this load instruction with
+  // |val_id|. After all the rewriting decisions are made, every use of
+  // this load will be replaced with |val_id|.
+  uint32_t load_id = inst->result_id();
+  assert(load_replacement_.count(load_id) == 0);
+  load_replacement_[load_id] = val_id;
+  PhiCandidate* defining_phi = GetPhiCandidate(val_id);
+  if (defining_phi) {
+    defining_phi->AddUser(load_id);
+  }
 
 #if SSA_REWRITE_DEBUGGING_LEVEL > 1
-    std::cerr << "\tFound load: "
-              << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
-              << " (replacement for %" << load_id << " is %" << val_id << ")\n";
+  std::cerr << "\tFound load: "
+            << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
+            << " (replacement for %" << load_id << " is %" << val_id << ")\n";
 #endif
-  }
+
   return true;
 }
 
@@ -390,8 +453,8 @@
     }
   }
 
-  // Seal |bb|. This means that all the stores in it have been scanned and it's
-  // ready to feed them into its successors.
+  // Seal |bb|. This means that all the stores in it have been scanned and
+  // it's ready to feed them into its successors.
   SealBlock(bb);
 
 #if SSA_REWRITE_DEBUGGING_LEVEL > 1
@@ -504,8 +567,8 @@
   }
 
   // Scan uses for all inserted Phi instructions. Do this separately from the
-  // registration of the Phi instruction itself to avoid trying to analyze uses
-  // of Phi instructions that have not been registered yet.
+  // registration of the Phi instruction itself to avoid trying to analyze
+  // uses of Phi instructions that have not been registered yet.
   for (Instruction* phi_inst : generated_phis) {
     pass_->get_def_use_mgr()->AnalyzeInstUse(&*phi_inst);
   }
@@ -562,7 +625,8 @@
   // This candidate is now completed.
   phi_candidate->MarkComplete();
 
-  // If |phi_candidate| is not trivial, add it to the list of Phis to generate.
+  // If |phi_candidate| is not trivial, add it to the list of Phis to
+  // generate.
   if (TryRemoveTrivialPhi(phi_candidate) == phi_candidate->result_id()) {
     // If we could not remove |phi_candidate|, it means that it is complete
     // and not trivial. Add it to the list of Phis to generate.
diff --git a/test/opt/local_ssa_elim_test.cpp b/test/opt/local_ssa_elim_test.cpp
index 2ecd238..ff42193 100644
--- a/test/opt/local_ssa_elim_test.cpp
+++ b/test/opt/local_ssa_elim_test.cpp
@@ -3891,6 +3891,44 @@
   SinglePassRunAndMatch<SSARewritePass>(text, true);
 }
 
+// Check support for pointer variables. When pointer variables are used, the
+// computation of reaching definitions may need to follow pointer chains.
+// See https://github.com/KhronosGroup/SPIRV-Tools/issues/3873 for details.
+TEST_F(LocalSSAElimTest, PointerVariables) {
+  const std::string text = R"(
+               OpCapability Shader
+               OpCapability VariablePointers
+               OpExtension "SPV_KHR_variable_pointers"
+               OpMemoryModel Logical Simple
+               OpEntryPoint Fragment %1 "main" %2 %3
+               OpExecutionMode %1 OriginUpperLeft
+      %float = OpTypeFloat 32
+       %void = OpTypeVoid
+          %6 = OpTypeFunction %void
+%_ptr_Input_float = OpTypePointer Input %float
+%_ptr_Output_float = OpTypePointer Output %float
+%_ptr_Function__ptr_Input_float = OpTypePointer Function %_ptr_Input_float
+          %2 = OpVariable %_ptr_Input_float Input
+          %3 = OpVariable %_ptr_Output_float Output
+          %1 = OpFunction %void None %6
+         %10 = OpLabel
+         %11 = OpVariable %_ptr_Function__ptr_Input_float Function
+               OpStore %11 %2
+
+; CHECK-NOT: %12 = OpLoad %_ptr_Input_float %11
+         %12 = OpLoad %_ptr_Input_float %11
+
+; CHECK: %13 = OpLoad %float %2
+         %13 = OpLoad %float %12
+
+               OpStore %3 %13
+               OpReturn
+               OpFunctionEnd
+)";
+
+  SinglePassRunAndMatch<SSARewritePass>(text, true);
+}
+
 // TODO(greg-lunarg): Add tests to verify handling of these cases:
 //
 //    No optimization in the presence of