Fix SSA re-writing in the presence of variable pointers. (#4010)
This fixes https://github.com/KhronosGroup/SPIRV-Tools/issues/3873.
In the presence of variable pointers, the reaching definition may be
another pointer. For example, the following fragment:
%2 = OpVariable %_ptr_Input_float Input
%11 = OpVariable %_ptr_Function__ptr_Input_float Function
OpStore %11 %2
%12 = OpLoad %_ptr_Input_float %11
%13 = OpLoad %float %12
corresponds to the pseudo-code:
layout(location = 0) in flat float *%2
float %13;
float *%12;
float **%11;
*%11 = %2;
%12 = *%11;
%13 = *%12;
which, ultimately, should correspond to:
%13 = *%2;
During rewriting, the pointer %12 is found to be replaceable by %2.
However, when processing the load %13 = *%12, the type of %12's reaching
definition is another float pointer (%2), instead of a float value.
When this happens, we need to continue looking up the reaching definition
chain until we get to a float value or a non-target var (i.e. a variable
that cannot be SSA replaced, like %2 in this case since it is an Input
variable, not a function-scope one).
diff --git a/source/opt/ssa_rewrite_pass.cpp b/source/opt/ssa_rewrite_pass.cpp
index 5a56887..3ff0361 100644
--- a/source/opt/ssa_rewrite_pass.cpp
+++ b/source/opt/ssa_rewrite_pass.cpp
@@ -47,6 +47,7 @@
#include "source/opcode.h"
#include "source/opt/cfg.h"
#include "source/opt/mem_pass.h"
+#include "source/opt/types.h"
#include "source/util/make_unique.h"
// Debug logging (0: Off, 1-N: Verbosity level). Replace this with the
@@ -326,32 +327,94 @@
}
bool SSARewriter::ProcessLoad(Instruction* inst, BasicBlock* bb) {
+ // Get the pointer that we are using to load from.
uint32_t var_id = 0;
(void)pass_->GetPtr(inst, &var_id);
- if (pass_->IsTargetVar(var_id)) {
- // Get the immediate reaching definition for |var_id|.
- uint32_t val_id = GetReachingDef(var_id, bb);
+
+ // Get the immediate reaching definition for |var_id|.
+ //
+ // In the presence of variable pointers, the reaching definition may be
+ // another pointer. For example, the following fragment:
+ //
+ // %2 = OpVariable %_ptr_Input_float Input
+ // %11 = OpVariable %_ptr_Function__ptr_Input_float Function
+ // OpStore %11 %2
+ // %12 = OpLoad %_ptr_Input_float %11
+ // %13 = OpLoad %float %12
+ //
+ // corresponds to the pseudo-code:
+ //
+ // layout(location = 0) in flat float *%2
+ // float %13;
+ // float *%12;
+ // float **%11;
+ // *%11 = %2;
+ // %12 = *%11;
+ // %13 = *%12;
+ //
+ // which, ultimately, should correspond to:
+ //
+ // %13 = *%2;
+ //
+ // During rewriting, the pointer %12 is found to be replaceable by %2 (i.e.,
+ // load_replacement_[12] is 2). However, when processing the load
+ // %13 = *%12, the type of %12's reaching definition is another float
+ // pointer (%2), instead of a float value.
+ //
+ // When this happens, we need to continue looking up the reaching definition
+ // chain until we get to a float value or a non-target var (i.e. a variable
+ // that cannot be SSA replaced, like %2 in this case since it is an Input
+ // variable, not a function-scope one).
+ analysis::DefUseManager* def_use_mgr = pass_->context()->get_def_use_mgr();
+ analysis::TypeManager* type_mgr = pass_->context()->get_type_mgr();
+ analysis::Type* load_type = type_mgr->GetType(inst->type_id());
+ uint32_t val_id = 0;
+ bool found_reaching_def = false;
+ while (!found_reaching_def) {
+ if (!pass_->IsTargetVar(var_id)) {
+ // If the variable we are loading from is not an SSA target (globals,
+ // function parameters), do nothing.
+ return true;
+ }
+
+ val_id = GetReachingDef(var_id, bb);
if (val_id == 0) {
return false;
}
- // Schedule a replacement for the result of this load instruction with
- // |val_id|. After all the rewriting decisions are made, every use of
- // this load will be replaced with |val_id|.
- const uint32_t load_id = inst->result_id();
- assert(load_replacement_.count(load_id) == 0);
- load_replacement_[load_id] = val_id;
- PhiCandidate* defining_phi = GetPhiCandidate(val_id);
- if (defining_phi) {
- defining_phi->AddUser(load_id);
+ // If the reaching definition is a pointer type different than the type of
+ // the instruction we are analyzing, then it must be a reference to another
+ // pointer (otherwise, this would be invalid SPIRV). We continue
+ // de-referencing it by making |val_id| be |var_id|.
+ //
+ // NOTE: if there is no reaching definition instruction, it means |val_id|
+ // is an undef.
+ Instruction* reaching_def_inst = def_use_mgr->GetDef(val_id);
+ if (reaching_def_inst &&
+ !type_mgr->GetType(reaching_def_inst->type_id())->IsSame(load_type)) {
+ var_id = val_id;
+ } else {
+ found_reaching_def = true;
}
+ }
+
+ // Schedule a replacement for the result of this load instruction with
+ // |val_id|. After all the rewriting decisions are made, every use of
+ // this load will be replaced with |val_id|.
+ uint32_t load_id = inst->result_id();
+ assert(load_replacement_.count(load_id) == 0);
+ load_replacement_[load_id] = val_id;
+ PhiCandidate* defining_phi = GetPhiCandidate(val_id);
+ if (defining_phi) {
+ defining_phi->AddUser(load_id);
+ }
#if SSA_REWRITE_DEBUGGING_LEVEL > 1
- std::cerr << "\tFound load: "
- << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
- << " (replacement for %" << load_id << " is %" << val_id << ")\n";
+ std::cerr << "\tFound load: "
+ << inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
+ << " (replacement for %" << load_id << " is %" << val_id << ")\n";
#endif
- }
+
return true;
}
@@ -390,8 +453,8 @@
}
}
- // Seal |bb|. This means that all the stores in it have been scanned and it's
- // ready to feed them into its successors.
+ // Seal |bb|. This means that all the stores in it have been scanned and
+ // it's ready to feed them into its successors.
SealBlock(bb);
#if SSA_REWRITE_DEBUGGING_LEVEL > 1
@@ -504,8 +567,8 @@
}
// Scan uses for all inserted Phi instructions. Do this separately from the
- // registration of the Phi instruction itself to avoid trying to analyze uses
- // of Phi instructions that have not been registered yet.
+ // registration of the Phi instruction itself to avoid trying to analyze
+ // uses of Phi instructions that have not been registered yet.
for (Instruction* phi_inst : generated_phis) {
pass_->get_def_use_mgr()->AnalyzeInstUse(&*phi_inst);
}
@@ -562,7 +625,8 @@
// This candidate is now completed.
phi_candidate->MarkComplete();
- // If |phi_candidate| is not trivial, add it to the list of Phis to generate.
+ // If |phi_candidate| is not trivial, add it to the list of Phis to
+ // generate.
if (TryRemoveTrivialPhi(phi_candidate) == phi_candidate->result_id()) {
// If we could not remove |phi_candidate|, it means that it is complete
// and not trivial. Add it to the list of Phis to generate.
diff --git a/test/opt/local_ssa_elim_test.cpp b/test/opt/local_ssa_elim_test.cpp
index 2ecd238..ff42193 100644
--- a/test/opt/local_ssa_elim_test.cpp
+++ b/test/opt/local_ssa_elim_test.cpp
@@ -3891,6 +3891,44 @@
SinglePassRunAndMatch<SSARewritePass>(text, true);
}
+// Check support for variable pointers. When variable pointers are used, the
+// computation of reaching definitions may need to follow pointer chains.
+// See https://github.com/KhronosGroup/SPIRV-Tools/issues/3873 for details.
+TEST_F(LocalSSAElimTest, PointerVariables) {
+ const std::string text = R"(
+ OpCapability Shader
+ OpCapability VariablePointers
+ OpExtension "SPV_KHR_variable_pointers"
+ OpMemoryModel Logical Simple
+ OpEntryPoint Fragment %1 "main" %2 %3
+ OpExecutionMode %1 OriginUpperLeft
+ %float = OpTypeFloat 32
+ %void = OpTypeVoid
+ %6 = OpTypeFunction %void
+%_ptr_Input_float = OpTypePointer Input %float
+%_ptr_Output_float = OpTypePointer Output %float
+%_ptr_Function__ptr_Input_float = OpTypePointer Function %_ptr_Input_float
+ %2 = OpVariable %_ptr_Input_float Input
+ %3 = OpVariable %_ptr_Output_float Output
+ %1 = OpFunction %void None %6
+ %10 = OpLabel
+ %11 = OpVariable %_ptr_Function__ptr_Input_float Function
+ OpStore %11 %2
+
+; CHECK-NOT: %12 = OpLoad %_ptr_Input_float %11
+ %12 = OpLoad %_ptr_Input_float %11
+
+; CHECK: %13 = OpLoad %float %2
+ %13 = OpLoad %float %12
+
+ OpStore %3 %13
+ OpReturn
+ OpFunctionEnd
+)";
+
+ SinglePassRunAndMatch<SSARewritePass>(text, true);
+}
+
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//
// No optimization in the presence of