Initial patch for scalar evolution analysis
This patch adds support for the analysis of scalars in loops. It works
by traversing the def-use chain to build a DAG of scalar operations and
then simplifies the DAG by folding constants and grouping like terms.
It represents induction variables as recurrent expressions with respect
to a given loop and can simplify DAGs containing recurrent expressions by
rewriting the entire DAG to be a recurrent expression with respect to
the same loop.
diff --git a/Android.mk b/Android.mk
index b0f3adf..b775541 100644
--- a/Android.mk
+++ b/Android.mk
@@ -118,6 +118,8 @@
source/opt/redundancy_elimination.cpp \
source/opt/remove_duplicates_pass.cpp \
source/opt/replace_invalid_opc.cpp \
+ source/opt/scalar_analysis.cpp \
+ source/opt/scalar_analysis_simplification.cpp \
source/opt/scalar_replacement_pass.cpp \
source/opt/set_spec_constant_default_value_pass.cpp \
source/opt/simplification_pass.cpp \
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index e589bd1..6002809 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -77,6 +77,8 @@
reflect.h
remove_duplicates_pass.h
replace_invalid_opc.h
+ scalar_analysis.h
+ scalar_analysis_nodes.h
scalar_replacement_pass.h
set_spec_constant_default_value_pass.h
simplification_pass.h
@@ -151,6 +153,8 @@
redundancy_elimination.cpp
remove_duplicates_pass.cpp
replace_invalid_opc.cpp
+ scalar_analysis.cpp
+ scalar_analysis_simplification.cpp
scalar_replacement_pass.cpp
set_spec_constant_default_value_pass.cpp
simplification_pass.cpp
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index efa3006..856e403 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -45,6 +45,9 @@
if (set & kAnalysisNameMap) {
BuildIdToNameMap();
}
+ if (set & kAnalysisScalarEvolution) {
+ BuildScalarEvolutionAnalysis();
+ }
}
void IRContext::InvalidateAnalysesExceptFor(
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 209d6d7..bb44b49 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -24,6 +24,7 @@
#include "feature_manager.h"
#include "loop_descriptor.h"
#include "module.h"
+#include "scalar_analysis.h"
#include "type_manager.h"
#include <algorithm>
@@ -58,7 +59,8 @@
kAnalysisDominatorAnalysis = 1 << 5,
kAnalysisLoopAnalysis = 1 << 6,
kAnalysisNameMap = 1 << 7,
- kAnalysisEnd = 1 << 8
+ kAnalysisScalarEvolution = 1 << 8,
+ kAnalysisEnd = 1 << 9
};
friend inline Analysis operator|(Analysis lhs, Analysis rhs);
@@ -258,6 +260,15 @@
return type_mgr_.get();
}
+ // Returns a pointer to the scalar evolution analysis. If it is invalid it
+ // will be rebuilt first.
+ opt::ScalarEvolutionAnalysis* GetScalarEvolutionAnalysis() {
+ if (!AreAnalysesValid(kAnalysisScalarEvolution)) {
+ BuildScalarEvolutionAnalysis();
+ }
+ return scalar_evolution_analysis_.get();
+ }
+
// Build the map from the ids to the OpName and OpMemberName instruction
// associated with it.
inline void BuildIdToNameMap();
@@ -444,6 +455,11 @@
valid_analyses_ = valid_analyses_ | kAnalysisCFG;
}
+ void BuildScalarEvolutionAnalysis() {
+ scalar_evolution_analysis_.reset(new opt::ScalarEvolutionAnalysis(this));
+ valid_analyses_ = valid_analyses_ | kAnalysisScalarEvolution;
+ }
+
// Removes all computed dominator and post-dominator trees. This will force
// the context to rebuild the trees on demand.
void ResetDominatorAnalysis() {
@@ -544,6 +560,9 @@
// A map from an id to its corresponding OpName and OpMemberName instructions.
std::unique_ptr<std::multimap<uint32_t, Instruction*>> id_to_name_;
+
+// The cached scalar evolution analysis.
+ std::unique_ptr<opt::ScalarEvolutionAnalysis> scalar_evolution_analysis_;
};
inline ir::IRContext::Analysis operator|(ir::IRContext::Analysis lhs,
diff --git a/source/opt/scalar_analysis.cpp b/source/opt/scalar_analysis.cpp
new file mode 100644
index 0000000..ccdb66c
--- /dev/null
+++ b/source/opt/scalar_analysis.cpp
@@ -0,0 +1,638 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "opt/scalar_analysis.h"
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "opt/ir_context.h"
+
+// Transforms a given scalar operation instruction into a DAG representation.
+//
+// 1. Take an instruction and traverse its operands until we reach a
+// constant node or an instruction which we do not know how to compute the
+// value, such as a load.
+//
+// 2. Create a new node for each instruction traversed and build the nodes for
+// the in operands of that instruction as well.
+//
+// 3. Add the operand nodes as children of the first and hash the node. Use the
+// hash to see if the node is already in the cache. We ensure the children are
+// always in sorted order so that two nodes with the same children but inserted
+// in a different order have the same hash and so that the overloaded operator==
+// will return true. If the node is already in the cache return the cached
+// version instead.
+//
+// 4. The created DAG can then be simplified by
+// ScalarAnalysis::SimplifyExpression, implemented in
+// scalar_analysis_simplification.cpp. See that file for further information on
+// the simplification process.
+//
+
+namespace spvtools {
+namespace opt {
+
+uint32_t SENode::NumberOfNodes = 0;
+
+ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(ir::IRContext* context)
+ : context_(context) {
+ // Create and cache the CantComputeNode.
+ cached_cant_compute_ =
+ GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
+ // If operand is can't compute then the whole graph is can't compute.
+ if (operand->IsCantCompute()) return CreateCantComputeNode();
+
+ if (operand->GetType() == SENode::Constant) {
+ return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
+ }
+ std::unique_ptr<SENode> negation_node{new SENegative(this)};
+ negation_node->AddChild(operand);
+ return GetCachedOrAdd(std::move(negation_node));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
+ return GetCachedOrAdd(
+ std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
+ const ir::Loop* loop, SENode* offset, SENode* coefficient) {
+ assert(loop && "Recurrent add expressions must have a valid loop.");
+
+ // If operands are can't compute then the whole graph is can't compute.
+ if (offset->IsCantCompute() || coefficient->IsCantCompute())
+ return CreateCantComputeNode();
+
+ std::unique_ptr<SERecurrentNode> phi_node{new SERecurrentNode(this, loop)};
+ phi_node->AddOffset(offset);
+ phi_node->AddCoefficient(coefficient);
+
+ return GetCachedOrAdd(std::move(phi_node));
+}
+
+SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
+ const ir::Instruction* multiply) {
+ assert(multiply->opcode() == SpvOp::SpvOpIMul &&
+ "Multiply node did not come from a multiply instruction");
+ opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr();
+
+ SENode* op1 =
+ AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
+ SENode* op2 =
+ AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
+
+ return CreateMultiplyNode(op1, op2);
+}
+
+SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
+ SENode* operand_2) {
+ // If operands are can't compute then the whole graph is can't compute.
+ if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
+ return CreateCantComputeNode();
+
+ if (operand_1->GetType() == SENode::Constant &&
+ operand_2->GetType() == SENode::Constant) {
+ return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
+ operand_2->AsSEConstantNode()->FoldToSingleValue());
+ }
+
+ std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
+
+ multiply_node->AddChild(operand_1);
+ multiply_node->AddChild(operand_2);
+
+ return GetCachedOrAdd(std::move(multiply_node));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
+ SENode* operand_2) {
+ // Fold if both operands are constant.
+ if (operand_1->GetType() == SENode::Constant &&
+ operand_2->GetType() == SENode::Constant) {
+ return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
+ operand_2->AsSEConstantNode()->FoldToSingleValue());
+ }
+
+ return CreateAddNode(operand_1, CreateNegation(operand_2));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
+ SENode* operand_2) {
+ // Fold if both operands are constant.
+ if (operand_1->GetType() == SENode::Constant &&
+ operand_2->GetType() == SENode::Constant) {
+ return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
+ operand_2->AsSEConstantNode()->FoldToSingleValue());
+ }
+
+ // If operands are can't compute then the whole graph is can't compute.
+ if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
+ return CreateCantComputeNode();
+
+ std::unique_ptr<SENode> add_node{new SEAddNode(this)};
+
+ add_node->AddChild(operand_1);
+ add_node->AddChild(operand_2);
+
+ return GetCachedOrAdd(std::move(add_node));
+}
+
+SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(
+ const ir::Instruction* inst) {
+ auto itr = recurrent_node_map_.find(inst);
+ if (itr != recurrent_node_map_.end()) return itr->second;
+
+ SENode* output = nullptr;
+ switch (inst->opcode()) {
+ case SpvOp::SpvOpPhi: {
+ output = AnalyzePhiInstruction(inst);
+ break;
+ }
+ case SpvOp::SpvOpConstant:
+ case SpvOp::SpvOpConstantNull: {
+ output = AnalyzeConstant(inst);
+ break;
+ }
+ case SpvOp::SpvOpISub:
+ case SpvOp::SpvOpIAdd: {
+ output = AnalyzeAddOp(inst);
+ break;
+ }
+ case SpvOp::SpvOpIMul: {
+ output = AnalyzeMultiplyOp(inst);
+ break;
+ }
+ default: {
+ output = CreateValueUnknownNode(inst);
+ break;
+ }
+ }
+
+ return output;
+}
+
+SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const ir::Instruction* inst) {
+ if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);
+
+ assert(inst->opcode() == SpvOp::SpvOpConstant);
+ assert(inst->NumInOperands() == 1);
+ int64_t value = 0;
+
+ // Look up the instruction in the constant manager.
+ const opt::analysis::Constant* constant =
+ context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
+
+ if (!constant) return CreateCantComputeNode();
+
+ const opt::analysis::IntConstant* int_constant = constant->AsIntConstant();
+
+ // Exit out if the constant is not a 32-bit integer.
+ if (!int_constant || int_constant->words().size() != 1)
+ return CreateCantComputeNode();
+
+ if (int_constant->type()->AsInteger()->IsSigned()) {
+ value = int_constant->GetS32BitValue();
+ } else {
+ value = int_constant->GetU32BitValue();
+ }
+
+ return CreateConstant(value);
+}
+
+// Handles both addition and subtraction. If the instruction is an OpISub
+// then the result will be op1+(-op2) otherwise op1+op2.
+SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const ir::Instruction* inst) {
+ assert((inst->opcode() == SpvOp::SpvOpIAdd ||
+ inst->opcode() == SpvOp::SpvOpISub) &&
+ "Add node must be created from a OpIAdd or OpISub instruction");
+
+ opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr();
+
+ SENode* op1 =
+ AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
+
+ SENode* op2 =
+ AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
+
+ // To handle subtraction we wrap the second operand in a unary negation node.
+ if (inst->opcode() == SpvOp::SpvOpISub) {
+ op2 = CreateNegation(op2);
+ }
+
+ return CreateAddNode(op1, op2);
+}
+
+SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(
+ const ir::Instruction* phi) {
+ // The phi should only have two incoming value pairs.
+ if (phi->NumInOperands() != 4) {
+ return CreateCantComputeNode();
+ }
+
+ opt::analysis::DefUseManager* def_use = context_->get_def_use_mgr();
+
+ // Get the basic block this instruction belongs to.
+ ir::BasicBlock* basic_block =
+ context_->get_instr_block(const_cast<ir::Instruction*>(phi));
+
+ // And then the function that the basic blocks belongs to.
+ ir::Function* function = basic_block->GetParent();
+
+ // Use the function to get the loop descriptor.
+ ir::LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
+
+ // We only handle phis in loops at the moment.
+ if (!loop_descriptor) return CreateCantComputeNode();
+
+ // Get the innermost loop which this block belongs to.
+ ir::Loop* loop = (*loop_descriptor)[basic_block->id()];
+
+ // If the loop doesn't exist or doesn't have a preheader or latch block, exit
+ // out.
+ if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
+ loop->GetHeaderBlock() != basic_block)
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ std::unique_ptr<SERecurrentNode> phi_node{new SERecurrentNode(this, loop)};
+
+ // We add the node to this map to allow it to be returned before the node is
+ // fully built. This is needed as the subsequent call to AnalyzeInstruction
+ // could lead back to this |phi| instruction so we return the pointer
+ // immediately in AnalyzeInstruction to break the recursion.
+ recurrent_node_map_[phi] = phi_node.get();
+
+ // Traverse the operands of the instruction and create new nodes for each one.
+ for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
+ uint32_t value_id = phi->GetSingleWordInOperand(i);
+ uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
+
+ ir::Instruction* value_inst = def_use->GetDef(value_id);
+ SENode* value_node = AnalyzeInstruction(value_inst);
+
+ // If any operand is CantCompute then the whole graph is CantCompute.
+ if (value_node->IsCantCompute())
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ // If the value is coming from the preheader block then the value is the
+ // initial value of the phi.
+ if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
+ phi_node->AddOffset(value_node);
+ } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
+ // Assumed to be in the form of step + phi.
+ if (value_node->GetType() != SENode::Add)
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ SENode* step_node = nullptr;
+ SENode* phi_operand = nullptr;
+ SENode* operand_1 = value_node->GetChild(0);
+ SENode* operand_2 = value_node->GetChild(1);
+
+ // Find which node is the step term.
+ if (!operand_1->AsSERecurrentNode())
+ step_node = operand_1;
+ else if (!operand_2->AsSERecurrentNode())
+ step_node = operand_2;
+
+ // Find which node is the recurrent expression.
+ if (operand_1->AsSERecurrentNode())
+ phi_operand = operand_1;
+ else if (operand_2->AsSERecurrentNode())
+ phi_operand = operand_2;
+
+ // If it is not in the form step + phi exit out.
+ if (!(step_node && phi_operand))
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ // If the phi operand is not the same phi node exit out.
+ if (phi_operand != phi_node.get())
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ if (!IsLoopInvariant(loop, step_node))
+ return recurrent_node_map_[phi] = CreateCantComputeNode();
+
+ phi_node->AddCoefficient(step_node);
+ }
+ }
+
+ // Once the node is fully built we update the map with the version from the
+ // cache (if it has already been added to the cache).
+ return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
+ const ir::Instruction* inst) {
+ std::unique_ptr<SEValueUnknown> load_node{
+ new SEValueUnknown(this, inst->result_id())};
+ return GetCachedOrAdd(std::move(load_node));
+}
+
+SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
+ return cached_cant_compute_;
+}
+
+// Add the created node into the cache of nodes. If it already exists return it.
+SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
+ std::unique_ptr<SENode> prospective_node) {
+ auto itr = node_cache_.find(prospective_node);
+ if (itr != node_cache_.end()) {
+ return (*itr).get();
+ }
+
+ SENode* raw_ptr_to_node = prospective_node.get();
+ node_cache_.insert(std::move(prospective_node));
+ return raw_ptr_to_node;
+}
+
+bool ScalarEvolutionAnalysis::IsLoopInvariant(const ir::Loop* loop,
+ const SENode* node) const {
+ for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
+ if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
+ const ir::BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
+
+ // If the loop which the recurrent expression belongs to is either |loop|
+ // or a nested loop inside |loop| then we assume it is variant.
+ if (loop->IsInsideLoop(header)) {
+ return false;
+ }
+ } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
+ // If the instruction is inside the loop we conservatively assume it is
+ // loop variant.
+ if (loop->IsInsideLoop(unknown->ResultId())) return false;
+ }
+ }
+
+ return true;
+}
+
+SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
+ SENode* node, const ir::Loop* loop) {
+ // Traverse the DAG to find the recurrent expression belonging to |loop|.
+ for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
+ SERecurrentNode* rec = itr->AsSERecurrentNode();
+ if (rec && rec->GetLoop() == loop) {
+ return rec->GetCoefficient();
+ }
+ }
+ return CreateConstant(0);
+}
+
+SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
+ SENode* old_child,
+ SENode* new_child) {
+ // Only handles add.
+ if (parent->GetType() != SENode::Add) return parent;
+
+ std::vector<SENode*> new_children;
+ for (SENode* child : *parent) {
+ if (child == old_child) {
+ new_children.push_back(new_child);
+ } else {
+ new_children.push_back(child);
+ }
+ }
+
+ std::unique_ptr<SENode> add_node{new SEAddNode(this)};
+ for (SENode* child : new_children) {
+ add_node->AddChild(child);
+ }
+
+ return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
+}
+
+// Rebuild the |node| eliminating, if it exists, the recurrent term which
+// belongs to the |loop|.
+SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
+ SENode* node, const ir::Loop* loop) {
+ // If the node is already a recurrent expression belonging to loop then just
+ // return the offset.
+ SERecurrentNode* recurrent = node->AsSERecurrentNode();
+ if (recurrent) {
+ if (recurrent->GetLoop() == loop) {
+ return recurrent->GetOffset();
+ } else {
+ return node;
+ }
+ }
+
+ std::vector<SENode*> new_children;
+ // Otherwise find the recurrent node in the children of this node.
+ for (auto itr : *node) {
+ recurrent = itr->AsSERecurrentNode();
+ if (recurrent && recurrent->GetLoop() == loop) {
+ new_children.push_back(recurrent->GetOffset());
+ } else {
+ new_children.push_back(itr);
+ }
+ }
+
+ std::unique_ptr<SENode> add_node{new SEAddNode(this)};
+ for (SENode* child : new_children) {
+ add_node->AddChild(child);
+ }
+
+ return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
+}
+
+// Return the recurrent term belonging to |loop| if it appears in the graph
+// starting at |node| or null if it doesn't.
+SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(
+ SENode* node, const ir::Loop* loop) {
+ for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
+ SERecurrentNode* rec = itr->AsSERecurrentNode();
+ if (rec && rec->GetLoop() == loop) {
+ return rec;
+ }
+ }
+ return nullptr;
+}
+std::string SENode::AsString() const {
+ switch (GetType()) {
+ case Constant:
+ return "Constant";
+ case RecurrentAddExpr:
+ return "RecurrentAddExpr";
+ case Add:
+ return "Add";
+ case Negative:
+ return "Negative";
+ case Multiply:
+ return "Multiply";
+ case ValueUnknown:
+ return "Value Unknown";
+ case CanNotCompute:
+ return "Can not compute";
+ }
+ return "NULL";
+}
+
+bool SENode::operator==(const SENode& other) const {
+ if (GetType() != other.GetType()) return false;
+
+ if (other.GetChildren().size() != children_.size()) return false;
+
+ const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
+
+ // Check the children are the same, for SERecurrentNodes we need to check the
+ // offset and coefficient manually as the child vector is sorted by ids so the
+ // offset/coefficient information is lost.
+ if (!this_as_recurrent) {
+ for (size_t index = 0; index < children_.size(); ++index) {
+ if (other.GetChildren()[index] != children_[index]) return false;
+ }
+ } else {
+ const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
+
+ // We've already checked the types are the same, this should not fail if
+ // this->AsSERecurrentNode() succeeded.
+ assert(other_as_recurrent);
+
+ if (this_as_recurrent->GetCoefficient() !=
+ other_as_recurrent->GetCoefficient())
+ return false;
+
+ if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
+ return false;
+
+ if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
+ return false;
+ }
+
+ // If we're dealing with a value unknown node check both nodes were created by
+ // the same instruction.
+ if (GetType() == SENode::ValueUnknown) {
+ if (AsSEValueUnknown()->ResultId() !=
+ other.AsSEValueUnknown()->ResultId()) {
+ return false;
+ }
+ }
+
+ if (AsSEConstantNode()) {
+ if (AsSEConstantNode()->FoldToSingleValue() !=
+ other.AsSEConstantNode()->FoldToSingleValue())
+ return false;
+ }
+
+ return true;
+}
+
+bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
+
+namespace {
+// Helper functions to insert 32/64 bit values into the 32 bit hash string. This
+// allows us to add pointers to the string by reinterpreting the pointers as
+// uintptr_t. PushToString will deduce the type, call sizeof on it and use
+// that size to call into the correct PushToStringImpl functor depending on
+// whether it is 32 or 64 bit.
+
+template <typename T, size_t size_of_t>
+struct PushToStringImpl;
+
+template <typename T>
+struct PushToStringImpl<T, 8> {
+ void operator()(T id, std::u32string* str) {
+ str->push_back(static_cast<uint32_t>(id >> 32));
+ str->push_back(static_cast<uint32_t>(id));
+ }
+};
+
+template <typename T>
+struct PushToStringImpl<T, 4> {
+ void operator()(T id, std::u32string* str) {
+ str->push_back(static_cast<uint32_t>(id));
+ }
+};
+
+template <typename T>
+static void PushToString(T id, std::u32string* str) {
+ PushToStringImpl<T, sizeof(T)>{}(id, str);
+}
+
+} // namespace
+
+// Implements the hashing of SENodes.
+size_t SENodeHash::operator()(const SENode* node) const {
+ // Concatenate the terms into a string which we can hash.
+ std::u32string hash_string{};
+
+ // Hashing the type as a string is safer than hashing the enum as the enum is
+ // very likely to collide with constants.
+ for (char ch : node->AsString()) {
+ hash_string.push_back(static_cast<char32_t>(ch));
+ }
+
+ // We just ignore the literal value unless it is a constant.
+ if (node->GetType() == SENode::Constant)
+ PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
+
+ const SERecurrentNode* recurrent = node->AsSERecurrentNode();
+
+ // If we're dealing with a recurrent expression hash the loop as well so that
+ // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
+ if (recurrent) {
+ PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
+ &hash_string);
+
+ // Recurrent expressions can't be hashed using the normal method as the
+ // order of coefficient and offset matters to the hash.
+ PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
+ &hash_string);
+ PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
+ &hash_string);
+
+ return std::hash<std::u32string>{}(hash_string);
+ }
+
+ // Hash the result id of the original instruction which created this node if
+ // it is a value unknown node.
+ if (node->GetType() == SENode::ValueUnknown) {
+ PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
+ }
+
+ // Hash the pointers of the child nodes, each SENode has a unique pointer
+ // associated with it.
+ const std::vector<SENode*>& children = node->GetChildren();
+ for (const SENode* child : children) {
+ PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
+ }
+
+ return std::hash<std::u32string>{}(hash_string);
+}
+
+// This overload is the actual overload used by the node_cache_ set.
+size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
+ return this->operator()(node.get());
+}
+
+void SENode::DumpDot(std::ostream& out, bool recurse) const {
+ size_t unique_id = std::hash<const SENode*>{}(this);
+ out << unique_id << " [label=\"" << AsString() << " ";
+ if (GetType() == SENode::Constant) {
+ out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
+ }
+ out << "\"]\n";
+ for (const SENode* child : children_) {
+ size_t child_unique_id = std::hash<const SENode*>{}(child);
+ out << unique_id << " -> " << child_unique_id << " \n";
+ if (recurse) child->DumpDot(out, true);
+ }
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/scalar_analysis.h b/source/opt/scalar_analysis.h
new file mode 100644
index 0000000..71cc424
--- /dev/null
+++ b/source/opt/scalar_analysis.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_SCALAR_ANALYSIS_H_
+#define SOURCE_OPT_SCALAR_ANALYSIS_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <unordered_set>
+#include <vector>
+
+#include "opt/basic_block.h"
+#include "opt/instruction.h"
+#include "opt/scalar_analysis_nodes.h"
+
+namespace spvtools {
+namespace ir {
+class IRContext;
+class Loop;
+} // namespace ir
+
+namespace opt {
+
+// Manager for the Scalar Evolution analysis. Creates and maintains a DAG of
+// scalar operations generated from analysing the use def graph from incoming
+// instructions. Each node is hashed as it is added so like nodes (for instance,
+// two induction variables i=0,i++ and j=0,j++) become the same node. After
+// creating a DAG with AnalyzeInstruction it can then be simplified into a more
+// usable form with SimplifyExpression.
+class ScalarEvolutionAnalysis {
+ public:
+ explicit ScalarEvolutionAnalysis(ir::IRContext* context);
+
+ // Create a unary negative node on |operand|.
+ SENode* CreateNegation(SENode* operand);
+
+ // Creates a subtraction between the two operands by adding |operand_1| to the
+ // negation of |operand_2|.
+ SENode* CreateSubtraction(SENode* operand_1, SENode* operand_2);
+
+ // Create an addition node between two operands. If both operands are
+ // constant the result will be folded and returned as an SEConstantNode
+ // instead of an addition node.
+ SENode* CreateAddNode(SENode* operand_1, SENode* operand_2);
+
+ // Create a multiply node between two operands.
+ SENode* CreateMultiplyNode(SENode* operand_1, SENode* operand_2);
+
+ // Create a node representing a constant integer.
+ SENode* CreateConstant(int64_t integer);
+
+ // Create a value unknown node, such as a load.
+ SENode* CreateValueUnknownNode(const ir::Instruction* inst);
+
+ // Create a CantComputeNode. Used to exit out of analysis.
+ SENode* CreateCantComputeNode();
+
+ // Create a new recurrent node with |offset| and |coefficient|, with respect
+ // to |loop|.
+ SENode* CreateRecurrentExpression(const ir::Loop* loop, SENode* offset,
+ SENode* coefficient);
+
+ // Construct the DAG by traversing use def chain of |inst|.
+ SENode* AnalyzeInstruction(const ir::Instruction* inst);
+
+ // Simplify the |node| by grouping like terms or if contains a recurrent
+ // expression, rewrite the graph so the whole DAG (from |node| down) is in
+ // terms of that recurrent expression.
+ //
+ // For example.
+ // Induction variable i=0, i++ would produce Rec(0,1) so i+1 could be
+ // transformed into Rec(1,1).
+ //
+ // X+X*2+Y-Y+34-17 would be transformed into 3*X + 17, where X and Y are
+ // ValueUnknown nodes (such as a load instruction).
+ SENode* SimplifyExpression(SENode* node);
+
+ // Add |prospective_node| into the cache and return a raw pointer to it. If
+ // |prospective_node| is already in the cache just return the raw pointer.
+ SENode* GetCachedOrAdd(std::unique_ptr<SENode> prospective_node);
+
+ // Checks that the graph starting from |node| is invariant to the |loop|.
+ bool IsLoopInvariant(const ir::Loop* loop, const SENode* node) const;
+
+ // Find the recurrent term belonging to |loop| in the graph starting from
+ // |node| and return the coefficient of that recurrent term. Constant zero
+ // will be returned if no recurrent could be found. |node| should be in
+ // simplest form.
+ SENode* GetCoefficientFromRecurrentTerm(SENode* node, const ir::Loop* loop);
+
+ // Return a rebuilt graph starting from |node| with the recurrent expression
+ // belonging to |loop| being zeroed out. Returned node will be simplified.
+ SENode* BuildGraphWithoutRecurrentTerm(SENode* node, const ir::Loop* loop);
+
+ // Return the recurrent term belonging to |loop| if it appears in the graph
+ // starting at |node| or null if it doesn't.
+ SERecurrentNode* GetRecurrentTerm(SENode* node, const ir::Loop* loop);
+
+ SENode* UpdateChildNode(SENode* parent, SENode* child, SENode* new_child);
+
+ private:
+ SENode* AnalyzeConstant(const ir::Instruction* inst);
+
+ // Handles both addition and subtraction. If the |instruction| is OpISub
+ // then the resulting node will be op1+(-op2) otherwise if it is OpIAdd then
+ // the result will be op1+op2. |instruction| must be OpIAdd or OpISub.
+ SENode* AnalyzeAddOp(const ir::Instruction* instruction);
+
+ SENode* AnalyzeMultiplyOp(const ir::Instruction* multiply);
+
+ SENode* AnalyzePhiInstruction(const ir::Instruction* phi);
+
+ ir::IRContext* context_;
+
+ // A map of instructions to SENodes. This is used to track recurrent
+ // expressions as they are added when analyzing instructions. Recurrent
+ // expressions come from phi nodes which by nature can include recursion so we
+ // check if nodes have already been built when analyzing instructions.
+ std::map<const ir::Instruction*, SENode*> recurrent_node_map_;
+
+ // On creation we create and cache the CantCompute node so we do not need to
+ // perform a needless create step.
+ SENode* cached_cant_compute_;
+
+ // Helper functor to allow two unique_ptrs to nodes to be compared by
+ // dereferencing them. This is only needed for the unordered_set
+ // implementation.
+ struct NodePointersEquality {
+ bool operator()(const std::unique_ptr<SENode>& lhs,
+ const std::unique_ptr<SENode>& rhs) const {
+ return *lhs == *rhs;
+ }
+ };
+
+ // Cache of nodes. All pointers to the nodes are references to the memory
+ // managed by the set.
+ std::unordered_set<std::unique_ptr<SENode>, SENodeHash, NodePointersEquality>
+ node_cache_;
+};
+
+} // namespace opt
+} // namespace spvtools
+#endif  // SOURCE_OPT_SCALAR_ANALYSIS_H_
diff --git a/source/opt/scalar_analysis_nodes.h b/source/opt/scalar_analysis_nodes.h
new file mode 100644
index 0000000..094ee8e
--- /dev/null
+++ b/source/opt/scalar_analysis_nodes.h
@@ -0,0 +1,313 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
+#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+#include "opt/tree_iterator.h"
+
+namespace spvtools {
+namespace ir {
+class Loop;
+} // namespace ir
+
+namespace opt {
+
+class ScalarEvolutionAnalysis;
+class SEConstantNode;
+class SERecurrentNode;
+class SEAddNode;
+class SEMultiplyNode;
+class SENegative;
+class SEValueUnknown;
+class SECantCompute;
+
+// Abstract class representing a node in the scalar evolution DAG. Each node
+// contains a vector of pointers to its children and each subclass of SENode
+// implements GetType and an As method to allow casting. SENodes can be hashed
+// using the SENodeHash functor. The vector of children is sorted when a node is
+// added. This is important as it allows the hash of X+Y to be the same as Y+X.
+class SENode {
+ public:
+  enum SENodeType {
+    Constant,
+    RecurrentAddExpr,
+    Add,
+    Multiply,
+    Negative,
+    ValueUnknown,
+    CanNotCompute
+  };
+
+  using ChildContainerType = std::vector<SENode*>;
+
+  explicit SENode(opt::ScalarEvolutionAnalysis* parent_analysis)
+      : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {}
+
+  virtual SENodeType GetType() const = 0;
+
+  virtual ~SENode() {}
+
+  // Add |child| to this node, keeping |children_| ordered by unique id so that
+  // two nodes with the same set of children hash and compare equal regardless
+  // of the order the children were added in (X+Y == Y+X).
+  virtual inline void AddChild(SENode* child) {
+    // If this is a constant node, assert.
+    if (AsSEConstantNode()) {
+      assert(false && "Trying to add a child node to a constant!");
+    }
+
+    // Find the first point in the vector where |child| is greater than the node
+    // currently in the vector.
+    auto find_first_less_than = [child](const SENode* node) {
+      return child->unique_id_ <= node->unique_id_;
+    };
+
+    auto position = std::find_if_not(children_.begin(), children_.end(),
+                                     find_first_less_than);
+    // Children are sorted so the hashing and equality operator will be the same
+    // for a node with the same children. X+Y should be the same as Y+X.
+    children_.insert(position, child);
+  }
+
+  // Get the type as an std::string. This is used to represent the node in the
+  // dot output and is used to hash the type as well.
+  std::string AsString() const;
+
+  // Dump the SENode and its immediate children, if |recurse| is true then it
+  // will recurse through all children to print the DAG starting from this node
+  // as a root.
+  void DumpDot(std::ostream& out, bool recurse = false) const;
+
+  // Checks if two nodes are the same by hashing them.
+  bool operator==(const SENode& other) const;
+
+  // Checks if two nodes are not the same by comparing the hashes.
+  bool operator!=(const SENode& other) const;
+
+  // Return the child node at |index|.
+  inline SENode* GetChild(size_t index) { return children_[index]; }
+  inline const SENode* GetChild(size_t index) const { return children_[index]; }
+
+  // Iterator to iterate over the child nodes.
+  using iterator = ChildContainerType::iterator;
+  using const_iterator = ChildContainerType::const_iterator;
+
+  // Iterate over immediate child nodes.
+  iterator begin() { return children_.begin(); }
+  iterator end() { return children_.end(); }
+
+  // Constant overloads for iterating over immediate child nodes. cbegin/cend
+  // are const-qualified so they can be called on a const SENode, matching the
+  // standard container convention.
+  const_iterator begin() const { return children_.cbegin(); }
+  const_iterator end() const { return children_.cend(); }
+  const_iterator cbegin() const { return children_.cbegin(); }
+  const_iterator cend() const { return children_.cend(); }
+
+  // Iterator to iterate over the entire DAG. Even though we are using the tree
+  // iterator it should still be safe to iterate over. However, nodes with
+  // multiple parents will be visited multiple times, unlike in a tree.
+  using dag_iterator = TreeDFIterator<SENode>;
+  using const_dag_iterator = TreeDFIterator<const SENode>;
+
+  // Iterate over all child nodes in the graph.
+  dag_iterator graph_begin() { return dag_iterator(this); }
+  dag_iterator graph_end() { return dag_iterator(); }
+  const_dag_iterator graph_begin() const { return graph_cbegin(); }
+  const_dag_iterator graph_end() const { return graph_cend(); }
+  const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); }
+  const_dag_iterator graph_cend() const { return const_dag_iterator(); }
+
+  // Return the vector of immediate children.
+  const ChildContainerType& GetChildren() const { return children_; }
+  ChildContainerType& GetChildren() { return children_; }
+
+  // Return true if this node is a cant compute node.
+  bool IsCantCompute() const { return GetType() == CanNotCompute; }
+
+// Implements a casting method for each type.
+#define DeclareCastMethod(target)                  \
+  virtual target* As##target() { return nullptr; } \
+  virtual const target* As##target() const { return nullptr; }
+  DeclareCastMethod(SEConstantNode);
+  DeclareCastMethod(SERecurrentNode);
+  DeclareCastMethod(SEAddNode);
+  DeclareCastMethod(SEMultiplyNode);
+  DeclareCastMethod(SENegative);
+  DeclareCastMethod(SEValueUnknown);
+  DeclareCastMethod(SECantCompute);
+#undef DeclareCastMethod
+
+  // Get the analysis which has this node in its cache.
+  inline opt::ScalarEvolutionAnalysis* GetParentAnalysis() const {
+    return parent_analysis_;
+  }
+
+ protected:
+  ChildContainerType children_;
+
+  opt::ScalarEvolutionAnalysis* parent_analysis_;
+
+  // The unique id of this node, assigned on creation by incrementing the static
+  // node count.
+  uint32_t unique_id_;
+
+  // The number of nodes created. NOTE(review): incremented without
+  // synchronization in the constructor — confirm nodes are only created from a
+  // single thread.
+  static uint32_t NumberOfNodes;
+};
+
+// Function object to handle the hashing of SENodes. Hashing algorithm hashes
+// the type (as a string), the literal value of any constants, and the child
+// pointers which are assumed to be unique.
+struct SENodeHash {
+  // Overload used when nodes are stored by unique_ptr (e.g. the node cache).
+  size_t operator()(const std::unique_ptr<SENode>& node) const;
+  // Overload used for raw, non-owning node pointers.
+  size_t operator()(const SENode* node) const;
+};
+
+// A node representing a constant integer.
+class SEConstantNode : public SENode {
+ public:
+  SEConstantNode(opt::ScalarEvolutionAnalysis* parent_analysis, int64_t value)
+      : SENode(parent_analysis), literal_value_(value) {}
+
+  SENodeType GetType() const final { return Constant; }
+
+  // Return the integer literal this node wraps.
+  int64_t FoldToSingleValue() const { return literal_value_; }
+
+  SEConstantNode* AsSEConstantNode() override { return this; }
+  const SEConstantNode* AsSEConstantNode() const override { return this; }
+
+  // Constants are leaves of the DAG; adding a child is always a logic error.
+  inline void AddChild(SENode*) final {
+    assert(false && "Attempting to add a child to a constant node!");
+  }
+
+ protected:
+  // The constant value represented by this node.
+  int64_t literal_value_;
+};
+
+// A node representing a recurrent expression in the code. A recurrent
+// expression is an expression whose value can be expressed as a linear
+// expression of the loop iterations. Such as an induction variable. The actual
+// value of a recurrent expression is coefficient_ * iteration + offset_, hence
+// an induction variable i=0, i++ becomes a recurrent expression with an offset
+// of zero and a coefficient of one.
+class SERecurrentNode : public SENode {
+ public:
+  SERecurrentNode(opt::ScalarEvolutionAnalysis* parent_analysis,
+                  const ir::Loop* loop)
+      : SENode(parent_analysis), loop_(loop) {}
+
+  SENodeType GetType() const final { return RecurrentAddExpr; }
+
+  // Record |child| as the coefficient of this recurrence and also register it
+  // as a regular child so it participates in hashing and equality.
+  inline void AddCoefficient(SENode* child) {
+    coefficient_ = child;
+    SENode::AddChild(child);
+  }
+
+  // Record |child| as the offset of this recurrence and also register it as a
+  // regular child so it participates in hashing and equality.
+  inline void AddOffset(SENode* child) {
+    offset_ = child;
+    SENode::AddChild(child);
+  }
+
+  inline const SENode* GetCoefficient() const { return coefficient_; }
+  inline SENode* GetCoefficient() { return coefficient_; }
+
+  inline const SENode* GetOffset() const { return offset_; }
+  inline SENode* GetOffset() { return offset_; }
+
+  // Return the loop which this recurrent expression is recurring within.
+  const ir::Loop* GetLoop() const { return loop_; }
+
+  SERecurrentNode* AsSERecurrentNode() override { return this; }
+  const SERecurrentNode* AsSERecurrentNode() const override { return this; }
+
+ private:
+  // Initialized to nullptr so that a partially constructed node (before
+  // AddCoefficient/AddOffset have been called) reads as null rather than as an
+  // indeterminate pointer.
+  SENode* coefficient_ = nullptr;
+  SENode* offset_ = nullptr;
+  const ir::Loop* loop_;
+};
+
+// A node representing an addition operation between child nodes. The value of
+// this node is the sum of the values of all of its children.
+class SEAddNode : public SENode {
+ public:
+  explicit SEAddNode(opt::ScalarEvolutionAnalysis* parent_analysis)
+      : SENode(parent_analysis) {}
+
+  SENodeType GetType() const final { return Add; }
+
+  SEAddNode* AsSEAddNode() override { return this; }
+  const SEAddNode* AsSEAddNode() const override { return this; }
+};
+
+// A node representing a multiply operation between child nodes. The value of
+// this node is the product of the values of all of its children.
+class SEMultiplyNode : public SENode {
+ public:
+  explicit SEMultiplyNode(opt::ScalarEvolutionAnalysis* parent_analysis)
+      : SENode(parent_analysis) {}
+
+  SENodeType GetType() const final { return Multiply; }
+
+  SEMultiplyNode* AsSEMultiplyNode() override { return this; }
+  const SEMultiplyNode* AsSEMultiplyNode() const override { return this; }
+};
+
+// A node representing a unary negative operation applied to its single child.
+class SENegative : public SENode {
+ public:
+  explicit SENegative(opt::ScalarEvolutionAnalysis* parent_analysis)
+      : SENode(parent_analysis) {}
+
+  SENodeType GetType() const final { return Negative; }
+
+  SENegative* AsSENegative() override { return this; }
+  const SENegative* AsSENegative() const override { return this; }
+};
+
+// A node representing a value which we do not know the value of, such as a load
+// instruction.
+class SEValueUnknown : public SENode {
+ public:
+  // SEValueUnknowns must come from an instruction; |result_id| is the result
+  // id of that instruction. This is so we can compare value unknowns and have
+  // a unique value unknown for each instruction.
+  SEValueUnknown(opt::ScalarEvolutionAnalysis* parent_analysis,
+                 uint32_t result_id)
+      : SENode(parent_analysis), result_id_(result_id) {}
+
+  SENodeType GetType() const final { return ValueUnknown; }
+
+  SEValueUnknown* AsSEValueUnknown() override { return this; }
+  const SEValueUnknown* AsSEValueUnknown() const override { return this; }
+
+  // Return the result id of the instruction this unknown value came from.
+  inline uint32_t ResultId() const { return result_id_; }
+
+ private:
+  // Result id of the originating instruction; distinguishes unknowns.
+  uint32_t result_id_;
+};
+
+// A node which we cannot reason about at all. Acts as a poison value: any
+// expression containing it cannot be simplified or compared meaningfully.
+class SECantCompute : public SENode {
+ public:
+  explicit SECantCompute(opt::ScalarEvolutionAnalysis* parent_analysis)
+      : SENode(parent_analysis) {}
+
+  SENodeType GetType() const final { return CanNotCompute; }
+
+  SECantCompute* AsSECantCompute() override { return this; }
+  const SECantCompute* AsSECantCompute() const override { return this; }
+};
+
+} // namespace opt
+} // namespace spvtools
+#endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
diff --git a/source/opt/scalar_analysis_simplification.cpp b/source/opt/scalar_analysis_simplification.cpp
new file mode 100644
index 0000000..018896a
--- /dev/null
+++ b/source/opt/scalar_analysis_simplification.cpp
@@ -0,0 +1,539 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "opt/scalar_analysis.h"
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <set>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+// Simplifies scalar analysis DAGs.
+//
+// 1. Given a node passed to SimplifyExpression we first simplify the graph by
+// calling SimplifyPolynomial. This groups like nodes following basic arithmetic
+// rules, so multiple adds of the same load instruction could be grouped into a
+// single multiply of that instruction. SimplifyPolynomial will traverse the DAG
+// and build up an accumulator buffer for each class of instruction it finds.
+// For example take the loop:
+// for (i=0, i<N; i++) { i+B+23+4+B+C; }
+// In this example the expression "i+B+23+4+B+C" has four classes of
+// instruction, induction variable i, the two value unknowns B and C, and the
+// constants. The accumulator buffer is then used to rebuild the graph using
+// the accumulation of each type. This example would then be folded into
+// i+2*B+C+27.
+//
+// This new graph contains a single add node (or if only one type found then
+// just that node) with each of the like terms (or multiplication node) as a
+// child.
+//
+// 2. FoldRecurrentAddExpressions is then called on this new DAG. This will take
+// RecurrentAddExpressions which are with respect to the same loop and fold them
+// into a single new RecurrentAddExpression with respect to that same loop. An
+// expression can have multiple RecurrentAddExpression's with respect to
+// different loops in the case of nested loops. These expressions cannot be
+// folded further. For example:
+//
+// for (i=0; i<N;i++) for(j=0,k=1; j<N;++j,++k)
+//
+// The 'j' and 'k' are RecurrentAddExpression with respect to the second loop
+// and 'i' to the first. If 'j' and 'k' are used in an expression together then
+// they will be folded into a new RecurrentAddExpression with respect to the
+// second loop in that expression.
+//
+//
+// 3. If the DAG now only contains a single RecurrentAddExpression we can now
+// perform a final optimization SimplifyRecurrentAddExpression. This will
+// transform the entire DAG into a RecurrentAddExpression. Additions to the
+// RecurrentAddExpression are added to the offset field and multiplications to
+// the coefficient.
+//
+
+namespace spvtools {
+namespace opt {
+
+// Implementation of the functions which are used to simplify the graph. Graphs
+// of unknowns, multiplies, additions, and constants can be turned into a linear
+// add node with each term as a child. For instance a large graph built from, X
+// + X*2 + Y - Y*3 + 4 - 1, would become a single add expression with the
+// children X*3, -Y*2, and the constant 3. Graphs containing a recurrent
+// expression will be simplified to represent the entire graph around a single
+// recurrent expression. So for an induction variable (i=0, i++) if you add 1 to
+// i in an expression we can rewrite the graph of that expression to be a single
+// recurrent expression of (i=1,i++).
+class SENodeSimplifyImpl {
+ public:
+  // |analysis| is the analysis which owns the node cache; |node_to_simplify|
+  // is the root of the DAG to be simplified.
+  SENodeSimplifyImpl(ScalarEvolutionAnalysis* analysis,
+                     SENode* node_to_simplify)
+      : analysis_(*analysis),
+        node_(node_to_simplify),
+        constant_accumulator_(0) {}
+
+  // Return the result of the simplification.
+  SENode* Simplify();
+
+ private:
+  // Recursively descend through the graph to build up the accumulator objects
+  // which are used to flatten the graph. |child| is the node currently being
+  // traversed and the |negation| flag is used to signify that this operation
+  // was preceded by a unary negative operation and as such the result should be
+  // negated.
+  void GatherAccumulatorsFromChildNodes(SENode* new_node, SENode* child,
+                                        bool negation);
+
+  // Given a |multiply| node add to the accumulators for the term type within
+  // the |multiply| expression. Will return true if the accumulators could be
+  // calculated successfully. If the |multiply| is in any form other than
+  // unknown*constant then we return false. |negation| signifies that the
+  // operation was preceded by a unary negative.
+  bool AccumulatorsFromMultiply(SENode* multiply, bool negation);
+
+  // Return a copy of |recurrent| whose coefficient has been multiplied by
+  // |coefficient_update|.
+  SERecurrentNode* UpdateCoefficient(SERecurrentNode* recurrent,
+                                     int64_t coefficient_update) const;
+
+  // If the graph contains a recurrent expression, i.e., an expression with the
+  // loop iterations as a term in the expression, then the whole expression
+  // can be rewritten to be a recurrent expression.
+  SENode* SimplifyRecurrentAddExpression(SERecurrentNode* node);
+
+  // Simplify the whole graph by linking like terms together in a single flat
+  // add node. So X*2 + Y - Y + 3 + 6 would become X*2 + 9. Where X and Y are a
+  // ValueUnknown node (i.e., a load) or a recurrent expression.
+  SENode* SimplifyPolynomial();
+
+  // Each recurrent expression is an expression with respect to a specific loop.
+  // If we have two different recurrent terms with respect to the same loop in a
+  // single expression then we can fold those terms into a single new term.
+  // For instance:
+  //
+  // induction i = 0, i++
+  // temp = i*10
+  // array[i+temp]
+  //
+  // We can fold the i + temp into a single expression. Rec(0,1) + Rec(0,10) can
+  // become Rec(0,11).
+  SENode* FoldRecurrentAddExpressions(SENode*);
+
+  // We can eliminate recurrent expressions which have a coefficient of zero by
+  // replacing them with their offset value. We are able to do this because a
+  // recurrent expression represents the equation coefficient*iterations +
+  // offset.
+  SENode* EliminateZeroCoefficientRecurrents(SENode* node);
+
+  // A reference to the analysis which requested the simplification.
+  ScalarEvolutionAnalysis& analysis_;
+
+  // The node being simplified.
+  SENode* node_;
+
+  // An accumulator of the net result of all the constant operations performed
+  // in a graph.
+  int64_t constant_accumulator_;
+
+  // An accumulator for each of the non constant terms in the graph, mapping
+  // each term to the signed count of its occurrences.
+  std::map<SENode*, int64_t> accumulators_;
+};
+
+// From a |multiply| build up the accumulator objects. Only succeeds (and
+// returns true) when the multiply has exactly the shape term*constant, where
+// the term is a value unknown or a recurrent expression.
+bool SENodeSimplifyImpl::AccumulatorsFromMultiply(SENode* multiply,
+                                                  bool negation) {
+  // Only a two-operand multiply node can be folded here.
+  if (multiply->GetType() != SENode::Multiply ||
+      multiply->GetChildren().size() != 2)
+    return false;
+
+  SENode* lhs = multiply->GetChild(0);
+  SENode* rhs = multiply->GetChild(1);
+
+  // True when |operand| is a term we track in the accumulators.
+  auto is_term = [](const SENode* operand) {
+    return operand->GetType() == SENode::ValueUnknown ||
+           operand->GetType() == SENode::RecurrentAddExpr;
+  };
+
+  // Work out which operand is the tracked term.
+  SENode* term = nullptr;
+  if (is_term(lhs))
+    term = lhs;
+  else if (is_term(rhs))
+    term = rhs;
+
+  // Work out which operand is the constant coefficient.
+  SENode* coefficient = nullptr;
+  if (lhs->GetType() == SENode::Constant)
+    coefficient = lhs;
+  else if (rhs->GetType() == SENode::Constant)
+    coefficient = rhs;
+
+  // Bail out unless the expression is exactly term*constant.
+  if (!term || !coefficient) return false;
+
+  // Apply the pending unary negation to the coefficient's value.
+  const int64_t signed_value =
+      coefficient->AsSEConstantNode()->FoldToSingleValue() *
+      (negation ? -1 : 1);
+
+  // Merge the result of the multiplication into the per-term accumulator.
+  auto entry = accumulators_.find(term);
+  if (entry == accumulators_.end())
+    accumulators_.insert({term, signed_value});
+  else
+    entry->second += signed_value;
+
+  return true;
+}
+
+SENode* SENodeSimplifyImpl::Simplify() {
+  // We only handle graphs with an addition, multiplication, or negation, at the
+  // root.
+  if (node_->GetType() != SENode::Add && node_->GetType() != SENode::Multiply &&
+      node_->GetType() != SENode::Negative)
+    return node_;
+
+  // Step 1: flatten the graph and group like terms into a single add node.
+  SENode* simplified_polynomial = SimplifyPolynomial();
+
+  SERecurrentNode* recurrent_expr = nullptr;
+  node_ = simplified_polynomial;
+
+  // Step 2: fold recurrent expressions which are with respect to the same loop
+  // into a single recurrent expression.
+  simplified_polynomial = FoldRecurrentAddExpressions(simplified_polynomial);
+
+  // Step 2b: recurrents whose coefficient folded to constant zero degenerate
+  // to their offset and are replaced by it.
+  simplified_polynomial =
+      EliminateZeroCoefficientRecurrents(simplified_polynomial);
+
+  // Traverse the immediate children of the new node to find the recurrent
+  // expression. If there is more than one there is nothing further we can do.
+  for (SENode* child : simplified_polynomial->GetChildren()) {
+    if (child->GetType() == SENode::RecurrentAddExpr) {
+      recurrent_expr = child->AsSERecurrentNode();
+    }
+  }
+
+  // We need to count the number of unique recurrent expressions in the DAG to
+  // ensure there is only one. Uniqueness is judged by pointer identity —
+  // presumably valid because nodes are uniqued through the analysis cache
+  // (GetCachedOrAdd); confirm if the cache is ever bypassed.
+  for (auto child_iterator = simplified_polynomial->graph_begin();
+       child_iterator != simplified_polynomial->graph_end(); ++child_iterator) {
+    if (child_iterator->GetType() == SENode::RecurrentAddExpr &&
+        recurrent_expr != child_iterator->AsSERecurrentNode()) {
+      return simplified_polynomial;
+    }
+  }
+
+  // Step 3: with exactly one recurrent term the whole graph can be rewritten
+  // as a single recurrent expression around it.
+  if (recurrent_expr) {
+    return SimplifyRecurrentAddExpression(recurrent_expr);
+  }
+
+  return simplified_polynomial;
+}
+
+// Traverse the graph to build up the accumulator objects, dispatching on the
+// type of |child|. |negation| records whether an odd number of unary negations
+// has been seen on the path from the root to |child|.
+void SENodeSimplifyImpl::GatherAccumulatorsFromChildNodes(SENode* new_node,
+                                                          SENode* child,
+                                                          bool negation) {
+  int32_t sign = negation ? -1 : 1;
+
+  switch (child->GetType()) {
+    case SENode::Constant:
+      // Fold every constant into the single running constant accumulator.
+      constant_accumulator_ +=
+          child->AsSEConstantNode()->FoldToSingleValue() * sign;
+      break;
+
+    case SENode::ValueUnknown:
+    case SENode::RecurrentAddExpr: {
+      // To rebuild the graph of X+X+X*2 into 4*X we count the occurrences of
+      // X and create a new node of count*X after. X can either be a
+      // ValueUnknown or a RecurrentAddExpr. The count for each X is stored in
+      // the accumulators_ map.
+      auto entry = accumulators_.find(child);
+      if (entry == accumulators_.end())
+        accumulators_.insert({child, sign});
+      else
+        entry->second += sign;
+      break;
+    }
+
+    case SENode::Multiply:
+      // Only unknown*constant multiplies can be accumulated; anything else is
+      // preserved verbatim as a child of the rebuilt node.
+      if (!AccumulatorsFromMultiply(child, negation)) {
+        new_node->AddChild(child);
+      }
+      break;
+
+    case SENode::Add:
+      // Flatten nested additions by recursing into each operand.
+      for (SENode* grandchild : *child) {
+        GatherAccumulatorsFromChildNodes(new_node, grandchild, negation);
+      }
+      break;
+
+    case SENode::Negative:
+      // A unary negation flips the sign applied to everything beneath it.
+      GatherAccumulatorsFromChildNodes(new_node, child->GetChild(0), !negation);
+      break;
+
+    default:
+      // If we can't work out how to fold the expression just add it back into
+      // the graph.
+      new_node->AddChild(child);
+      break;
+  }
+}
+
+// Build a new recurrent node equivalent to |recurrent| with its coefficient
+// multiplied by |coefficient_update|. When the update is negative the offset
+// is negated as well, so the sign change applies to the whole recurrence. The
+// result is uniqued through the analysis cache.
+SERecurrentNode* SENodeSimplifyImpl::UpdateCoefficient(
+    SERecurrentNode* recurrent, int64_t coefficient_update) const {
+  std::unique_ptr<SERecurrentNode> new_recurrent_node{new SERecurrentNode(
+      recurrent->GetParentAnalysis(), recurrent->GetLoop())};
+
+  SENode* new_coefficient = analysis_.CreateMultiplyNode(
+      recurrent->GetCoefficient(),
+      analysis_.CreateConstant(coefficient_update));
+
+  // See if the node can be simplified; keep the multiply form if it can't.
+  SENode* simplified = analysis_.SimplifyExpression(new_coefficient);
+  if (simplified->GetType() != SENode::CanNotCompute)
+    new_coefficient = simplified;
+
+  if (coefficient_update < 0) {
+    new_recurrent_node->AddOffset(
+        analysis_.CreateNegation(recurrent->GetOffset()));
+  } else {
+    new_recurrent_node->AddOffset(recurrent->GetOffset());
+  }
+
+  new_recurrent_node->AddCoefficient(new_coefficient);
+
+  return analysis_.GetCachedOrAdd(std::move(new_recurrent_node))
+      ->AsSERecurrentNode();
+}
+
+// Simplify all the terms in the polynomial function. Gathers the per-term and
+// constant accumulators from the graph rooted at |node_| and rebuilds it as a
+// flat addition with one child per surviving term.
+SENode* SENodeSimplifyImpl::SimplifyPolynomial() {
+  std::unique_ptr<SENode> new_add{new SEAddNode(node_->GetParentAnalysis())};
+
+  // Traverse the graph and gather the accumulators from it.
+  GatherAccumulatorsFromChildNodes(new_add.get(), node_, false);
+
+  // Fold all the constants into a single constant node.
+  if (constant_accumulator_ != 0) {
+    new_add->AddChild(analysis_.CreateConstant(constant_accumulator_));
+  }
+
+  for (auto& pair : accumulators_) {
+    SENode* term = pair.first;
+    int64_t count = pair.second;
+
+    // We can eliminate the term completely.
+    if (count == 0) continue;
+
+    if (count == 1) {
+      new_add->AddChild(term);
+    } else if (count == -1 && term->GetType() != SENode::RecurrentAddExpr) {
+      // If the count is -1 we can just add a negative version of that node,
+      // unless it is a recurrent expression as we would rather the negative
+      // goes on the recurrent expressions children. This makes it easier to
+      // work with in other places.
+      new_add->AddChild(analysis_.CreateNegation(term));
+    } else {
+      // Output value unknown terms as count*term and output recurrent
+      // expression terms as rec(offset, coefficient + count) offset and
+      // coefficient are the same as in the original expression.
+      if (term->GetType() == SENode::ValueUnknown) {
+        SENode* count_as_constant = analysis_.CreateConstant(count);
+        new_add->AddChild(
+            analysis_.CreateMultiplyNode(count_as_constant, term));
+      } else {
+        assert(term->GetType() == SENode::RecurrentAddExpr &&
+               "We only handle value unknowns or recurrent expressions");
+
+        // Create a new recurrent expression by adding the count to the
+        // coefficient of the old one.
+        new_add->AddChild(UpdateCoefficient(term->AsSERecurrentNode(), count));
+      }
+    }
+  }
+
+  // If there is only one term in the addition left just return that term.
+  if (new_add->GetChildren().size() == 1) {
+    return new_add->GetChild(0);
+  }
+
+  // If there are no terms left in the addition just return 0.
+  if (new_add->GetChildren().size() == 0) {
+    return analysis_.CreateConstant(0);
+  }
+
+  return analysis_.GetCachedOrAdd(std::move(new_add));
+}
+
+// Fold all recurrent expressions recurring with respect to the same loop into
+// a single recurrent expression per loop. Terms which are not recurrent are
+// carried over into the resulting addition untouched.
+SENode* SENodeSimplifyImpl::FoldRecurrentAddExpressions(SENode* root) {
+  std::unique_ptr<SEAddNode> new_node{new SEAddNode(&analysis_)};
+
+  // A mapping of loops to the list of recurrent expressions which are with
+  // respect to those loops, each paired with whether it was negated.
+  std::map<const ir::Loop*, std::vector<std::pair<SERecurrentNode*, bool>>>
+      loops_to_recurrent{};
+
+  bool has_multiple_same_loop_recurrent_terms = false;
+
+  for (SENode* child : *root) {
+    bool negation = false;
+
+    // Look through a unary negation and remember the sign.
+    if (child->GetType() == SENode::Negative) {
+      child = child->GetChild(0);
+      negation = true;
+    }
+
+    if (child->GetType() == SENode::RecurrentAddExpr) {
+      const ir::Loop* loop = child->AsSERecurrentNode()->GetLoop();
+
+      SERecurrentNode* rec = child->AsSERecurrentNode();
+      if (loops_to_recurrent.find(loop) == loops_to_recurrent.end()) {
+        loops_to_recurrent[loop] = {std::make_pair(rec, negation)};
+      } else {
+        loops_to_recurrent[loop].push_back(std::make_pair(rec, negation));
+        has_multiple_same_loop_recurrent_terms = true;
+      }
+    } else {
+      new_node->AddChild(child);
+    }
+  }
+
+  // Nothing to fold unless some loop contributed more than one term.
+  if (!has_multiple_same_loop_recurrent_terms) return root;
+
+  // Iterate by reference: each mapped value is a vector, and iterating by
+  // value would copy it on every iteration.
+  for (auto& pair : loops_to_recurrent) {
+    std::vector<std::pair<SERecurrentNode*, bool>>& recurrent_expressions =
+        pair.second;
+    const ir::Loop* loop = pair.first;
+
+    // Sum the coefficients and offsets of every same-loop recurrent term,
+    // negating the contributions of negated terms.
+    std::unique_ptr<SENode> new_coefficient{new SEAddNode(&analysis_)};
+    std::unique_ptr<SENode> new_offset{new SEAddNode(&analysis_)};
+
+    for (const auto& node_pair : recurrent_expressions) {
+      SERecurrentNode* node = node_pair.first;
+      bool negative = node_pair.second;
+
+      if (!negative) {
+        new_coefficient->AddChild(node->GetCoefficient());
+        new_offset->AddChild(node->GetOffset());
+      } else {
+        new_coefficient->AddChild(
+            analysis_.CreateNegation(node->GetCoefficient()));
+        new_offset->AddChild(analysis_.CreateNegation(node->GetOffset()));
+      }
+    }
+
+    std::unique_ptr<SERecurrentNode> new_recurrent{
+        new SERecurrentNode(&analysis_, loop)};
+
+    SENode* new_coefficient_simplified =
+        analysis_.SimplifyExpression(new_coefficient.get());
+
+    SENode* new_offset_simplified =
+        analysis_.SimplifyExpression(new_offset.get());
+
+    // A zero coefficient means the recurrence degenerates to its offset.
+    // NOTE(review): this early return discards any children already added to
+    // |new_node| and the folded terms of any other loops — confirm it is only
+    // reachable when the expression contains a single loop's terms.
+    if (new_coefficient_simplified->GetType() == SENode::Constant &&
+        new_coefficient_simplified->AsSEConstantNode()->FoldToSingleValue() ==
+            0) {
+      return new_offset_simplified;
+    }
+
+    new_recurrent->AddCoefficient(new_coefficient_simplified);
+    new_recurrent->AddOffset(new_offset_simplified);
+
+    new_node->AddChild(analysis_.GetCachedOrAdd(std::move(new_recurrent)));
+  }
+
+  // If we only have one child in the add just return that.
+  if (new_node->GetChildren().size() == 1) {
+    return new_node->GetChild(0);
+  }
+
+  return analysis_.GetCachedOrAdd(std::move(new_node));
+}
+
+SENode* SENodeSimplifyImpl::EliminateZeroCoefficientRecurrents(SENode* node) {
+  // Only flat additions are inspected; any other node is returned untouched.
+  if (node->GetType() != SENode::Add) return node;
+
+  // True when |child| is a recurrent expression whose coefficient is the
+  // constant zero, meaning it degenerates to just its offset because a
+  // recurrent expression represents coefficient*iterations + offset.
+  auto is_zero_coefficient_recurrent = [](SENode* child) {
+    if (child->GetType() != SENode::RecurrentAddExpr) return false;
+    const SENode* coefficient = child->AsSERecurrentNode()->GetCoefficient();
+    return coefficient->GetType() == SENode::Constant &&
+           coefficient->AsSEConstantNode()->FoldToSingleValue() == 0;
+  };
+
+  bool replaced_any = false;
+  std::vector<SENode*> replacement_children{};
+  for (SENode* child : *node) {
+    if (is_zero_coefficient_recurrent(child)) {
+      // Substitute the offset for the whole recurrent expression.
+      replacement_children.push_back(child->AsSERecurrentNode()->GetOffset());
+      replaced_any = true;
+    } else {
+      replacement_children.push_back(child);
+    }
+  }
+
+  if (!replaced_any) return node;
+
+  // Rebuild the addition with the substituted children and unique it through
+  // the cache.
+  std::unique_ptr<SENode> replacement{
+      new SEAddNode(node_->GetParentAnalysis())};
+  for (SENode* child : replacement_children) {
+    replacement->AddChild(child);
+  }
+
+  return analysis_.GetCachedOrAdd(std::move(replacement));
+}
+
+// Rewrite the whole of |node_| as a single recurrent expression around
+// |recurrent_expr|: every non-recurrent sibling is folded into the offset and
+// the coefficient is carried over unchanged.
+SENode* SENodeSimplifyImpl::SimplifyRecurrentAddExpression(
+    SERecurrentNode* recurrent_expr) {
+  const std::vector<SENode*>& children = node_->GetChildren();
+
+  std::unique_ptr<SERecurrentNode> recurrent_node{new SERecurrentNode(
+      recurrent_expr->GetParentAnalysis(), recurrent_expr->GetLoop())};
+
+  // Create and simplify the new offset node.
+  std::unique_ptr<SENode> new_offset{
+      new SEAddNode(recurrent_expr->GetParentAnalysis())};
+  new_offset->AddChild(recurrent_expr->GetOffset());
+
+  for (SENode* child : children) {
+    if (child->GetType() != SENode::RecurrentAddExpr) {
+      new_offset->AddChild(child);
+    }
+  }
+
+  // Simplify the new offset.
+  SENode* simplified_child = analysis_.SimplifyExpression(new_offset.get());
+
+  // If the child can be simplified, add the simplified form otherwise, add the
+  // unsimplified offset via the usual caching mechanism. In both cases the
+  // offset must be attached to the node being built (|recurrent_node|); the
+  // previous code attached it to the input |recurrent_expr| in the fallback
+  // branch, leaving |recurrent_node|'s offset unset and mutating a node that
+  // may already live in the cache.
+  if (simplified_child->GetType() != SENode::CanNotCompute) {
+    recurrent_node->AddOffset(simplified_child);
+  } else {
+    recurrent_node->AddOffset(analysis_.GetCachedOrAdd(std::move(new_offset)));
+  }
+
+  recurrent_node->AddCoefficient(recurrent_expr->GetCoefficient());
+
+  return analysis_.GetCachedOrAdd(std::move(recurrent_node));
+}
+
+/*
+ * Scalar Analysis simplification public methods.
+ */
+
+// Simplify |node| by delegating to a fresh implementation object, which owns
+// all of the accumulator state for this one simplification request.
+SENode* ScalarEvolutionAnalysis::SimplifyExpression(SENode* node) {
+  return SENodeSimplifyImpl{this, node}.Simplify();
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index 5b8ace1..de1cd16 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -301,8 +301,13 @@
SRCS simplification_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)
-
add_spvtools_unittest(TARGET copy_prop_array
SRCS copy_prop_array_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)
+
+add_spvtools_unittest(TARGET scalar_analysis
+ SRCS scalar_analysis.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
diff --git a/test/opt/scalar_analysis.cpp b/test/opt/scalar_analysis.cpp
new file mode 100644
index 0000000..a73953e
--- /dev/null
+++ b/test/opt/scalar_analysis.cpp
@@ -0,0 +1,1228 @@
+// Copyright (c) 2018 Google LLC.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gmock/gmock.h>
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "assembly_builder.h"
+#include "function_utils.h"
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+#include "opt/iterator.h"
+#include "opt/loop_descriptor.h"
+#include "opt/pass.h"
+#include "opt/scalar_analysis.h"
+#include "opt/tree_iterator.h"
+
+namespace {
+
+using namespace spvtools;
+using ::testing::UnorderedElementsAre;
+
+using ScalarAnalysisTest = PassTest<::testing::Test>;
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 410 core
+layout (location = 1) out float array[10];
+void main() {
+ for (int i = 0; i < 10; ++i) {
+ array[i] = array[i+1];
+ }
+}
+*/
+// Checks that for "array[i+1]" inside a canonical 0..10 loop the analysis
+// builds ADD(REC(0,1), 1), and that simplification folds the whole DAG into
+// the single recurrent expression REC(1,1).
+TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
+  const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %24
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %24 "array"
+ OpDecorate %24 Location 1
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpConstant %6 0
+ %16 = OpConstant %6 10
+ %17 = OpTypeBool
+ %19 = OpTypeFloat 32
+ %20 = OpTypeInt 32 0
+ %21 = OpConstant %20 10
+ %22 = OpTypeArray %19 %21
+ %23 = OpTypePointer Output %22
+ %24 = OpVariable %23 Output
+ %27 = OpConstant %6 1
+ %29 = OpTypePointer Output %19
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ OpBranch %10
+ %10 = OpLabel
+ %35 = OpPhi %6 %9 %5 %34 %13
+ OpLoopMerge %12 %13 None
+ OpBranch %14
+ %14 = OpLabel
+ %18 = OpSLessThan %17 %35 %16
+ OpBranchConditional %18 %11 %12
+ %11 = OpLabel
+ %28 = OpIAdd %6 %35 %27
+ %30 = OpAccessChain %29 %24 %28
+ %31 = OpLoad %19 %30
+ %32 = OpAccessChain %29 %24 %35
+ OpStore %32 %31
+ OpBranch %13
+ %13 = OpLabel
+ %34 = OpIAdd %6 %35 %27
+ OpBranch %10
+ %12 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 4);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Locate the load and store in the loop body (basic block %11).
+  const ir::Instruction* store = nullptr;
+  const ir::Instruction* load = nullptr;
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
+    if (inst.opcode() == SpvOp::SpvOpStore) {
+      store = &inst;
+    }
+    if (inst.opcode() == SpvOp::SpvOpLoad) {
+      load = &inst;
+    }
+  }
+
+  EXPECT_NE(load, nullptr);
+  EXPECT_NE(store, nullptr);
+
+  // Walk from the load back through its access chain to the index
+  // expression (%28 = i + 1), which is the instruction under analysis.
+  ir::Instruction* access_chain =
+      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
+
+  ir::Instruction* child = context->get_def_use_mgr()->GetDef(
+      access_chain->GetSingleWordInOperand(1));
+  const opt::SENode* node = analysis.AnalyzeInstruction(child);
+
+  EXPECT_NE(node, nullptr);
+
+  // Unsimplified node should have the form of ADD(REC(0,1), 1)
+  EXPECT_EQ(node->GetType(), opt::SENode::Add);
+
+  // Child ordering is not pinned down here, so accept the constant and the
+  // recurrent expression in either position.
+  const opt::SENode* child_1 = node->GetChild(0);
+  EXPECT_TRUE(child_1->GetType() == opt::SENode::Constant ||
+              child_1->GetType() == opt::SENode::RecurrentAddExpr);
+
+  const opt::SENode* child_2 = node->GetChild(1);
+  EXPECT_TRUE(child_2->GetType() == opt::SENode::Constant ||
+              child_2->GetType() == opt::SENode::RecurrentAddExpr);
+
+  opt::SENode* simplified =
+      analysis.SimplifyExpression(const_cast<opt::SENode*>(node));
+  // Simplified should be in the form of REC(1,1)
+  EXPECT_EQ(simplified->GetType(), opt::SENode::RecurrentAddExpr);
+
+  EXPECT_EQ(simplified->GetChild(0)->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
+            1);
+
+  EXPECT_EQ(simplified->GetChild(1)->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
+            1);
+
+  // Identical constant nodes should be shared through the analysis' cache,
+  // so both children must be the very same object.
+  EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 410 core
+layout (location = 1) out float array[10];
+layout (location = 2) flat in int loop_invariant;
+void main() {
+ for (int i = 0; i < 10; ++i) {
+ array[i] = array[i+loop_invariant];
+ }
+}
+
+*/
+// Checks that an index involving a loop-invariant load ("array[i + X]")
+// simplifies to a recurrent expression whose offset is a ValueUnknown (the
+// load) and whose coefficient is the constant step 1.
+TEST_F(ScalarAnalysisTest, LoadTest) {
+  const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %3 "array"
+ OpName %4 "loop_invariant"
+ OpDecorate %3 Location 1
+ OpDecorate %4 Flat
+ OpDecorate %4 Location 2
+ %5 = OpTypeVoid
+ %6 = OpTypeFunction %5
+ %7 = OpTypeInt 32 1
+ %8 = OpTypePointer Function %7
+ %9 = OpConstant %7 0
+ %10 = OpConstant %7 10
+ %11 = OpTypeBool
+ %12 = OpTypeFloat 32
+ %13 = OpTypeInt 32 0
+ %14 = OpConstant %13 10
+ %15 = OpTypeArray %12 %14
+ %16 = OpTypePointer Output %15
+ %3 = OpVariable %16 Output
+ %17 = OpTypePointer Input %7
+ %4 = OpVariable %17 Input
+ %18 = OpTypePointer Output %12
+ %19 = OpConstant %7 1
+ %2 = OpFunction %5 None %6
+ %20 = OpLabel
+ OpBranch %21
+ %21 = OpLabel
+ %22 = OpPhi %7 %9 %20 %23 %24
+ OpLoopMerge %25 %24 None
+ OpBranch %26
+ %26 = OpLabel
+ %27 = OpSLessThan %11 %22 %10
+ OpBranchConditional %27 %28 %25
+ %28 = OpLabel
+ %29 = OpLoad %7 %4
+ %30 = OpIAdd %7 %22 %29
+ %31 = OpAccessChain %18 %3 %30
+ %32 = OpLoad %12 %31
+ %33 = OpAccessChain %18 %3 %22
+ OpStore %33 %32
+ OpBranch %24
+ %24 = OpLabel
+ %23 = OpIAdd %7 %22 %19
+ OpBranch %21
+ %25 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Find the last load in the loop body (basic block %28).
+  const ir::Instruction* load = nullptr;
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 28)) {
+    if (inst.opcode() == SpvOp::SpvOpLoad) {
+      load = &inst;
+    }
+  }
+
+  EXPECT_NE(load, nullptr);
+
+  // The access chain's index operand is %30 = i + loop_invariant.
+  ir::Instruction* access_chain =
+      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
+
+  ir::Instruction* child = context->get_def_use_mgr()->GetDef(
+      access_chain->GetSingleWordInOperand(1));
+
+  const opt::SENode* node = analysis.AnalyzeInstruction(child);
+
+  EXPECT_NE(node, nullptr);
+
+  // Unsimplified node should have the form of ADD(REC(0,1), X)
+  EXPECT_EQ(node->GetType(), opt::SENode::Add);
+
+  // Child ordering is not pinned down; accept the unknown value and the
+  // recurrent expression in either position.
+  const opt::SENode* child_1 = node->GetChild(0);
+  EXPECT_TRUE(child_1->GetType() == opt::SENode::ValueUnknown ||
+              child_1->GetType() == opt::SENode::RecurrentAddExpr);
+
+  const opt::SENode* child_2 = node->GetChild(1);
+  EXPECT_TRUE(child_2->GetType() == opt::SENode::ValueUnknown ||
+              child_2->GetType() == opt::SENode::RecurrentAddExpr);
+
+  opt::SENode* simplified =
+      analysis.SimplifyExpression(const_cast<opt::SENode*>(node));
+  EXPECT_EQ(simplified->GetType(), opt::SENode::RecurrentAddExpr);
+
+  const opt::SERecurrentNode* rec = simplified->AsSERecurrentNode();
+
+  // Offset (the unknown load) and coefficient (constant 1) must be distinct.
+  EXPECT_NE(rec->GetChild(0), rec->GetChild(1));
+
+  EXPECT_EQ(rec->GetOffset()->GetType(), opt::SENode::ValueUnknown);
+
+  EXPECT_EQ(rec->GetCoefficient()->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 410 core
+layout (location = 1) out float array[10];
+layout (location = 2) flat in int loop_invariant;
+void main() {
+ array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
+loop_invariant+ 16 * 3];
+}
+
+*/
+// Checks constant folding and like-term grouping outside of any loop:
+// the index X*2 + 4 + 5 - 24 - X - X + 48 must fold to the constant 33
+// (the X terms cancel: 2X - X - X == 0; 4 + 5 - 24 + 48 == 33).
+TEST_F(ScalarAnalysisTest, SimplifySimple) {
+  const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %3 "array"
+ OpName %4 "loop_invariant"
+ OpDecorate %3 Location 1
+ OpDecorate %4 Flat
+ OpDecorate %4 Location 2
+ %5 = OpTypeVoid
+ %6 = OpTypeFunction %5
+ %7 = OpTypeFloat 32
+ %8 = OpTypeInt 32 0
+ %9 = OpConstant %8 10
+ %10 = OpTypeArray %7 %9
+ %11 = OpTypePointer Output %10
+ %3 = OpVariable %11 Output
+ %12 = OpTypeInt 32 1
+ %13 = OpConstant %12 0
+ %14 = OpTypePointer Input %12
+ %4 = OpVariable %14 Input
+ %15 = OpConstant %12 2
+ %16 = OpConstant %12 4
+ %17 = OpConstant %12 5
+ %18 = OpConstant %12 24
+ %19 = OpConstant %12 48
+ %20 = OpTypePointer Output %7
+ %2 = OpFunction %5 None %6
+ %21 = OpLabel
+ %22 = OpLoad %12 %4
+ %23 = OpIMul %12 %22 %15
+ %24 = OpIAdd %12 %23 %16
+ %25 = OpIAdd %12 %24 %17
+ %26 = OpISub %12 %25 %18
+ %28 = OpISub %12 %26 %22
+ %30 = OpISub %12 %28 %22
+ %31 = OpIAdd %12 %30 %19
+ %32 = OpAccessChain %20 %3 %31
+ %33 = OpLoad %7 %32
+ %34 = OpAccessChain %20 %3 %13
+ OpStore %34 %33
+ OpReturn
+ OpFunctionEnd
+ )";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Find the load whose index is the long arithmetic chain (%33).
+  const ir::Instruction* load = nullptr;
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
+    if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
+      load = &inst;
+    }
+  }
+
+  EXPECT_NE(load, nullptr);
+
+  ir::Instruction* access_chain =
+      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
+
+  ir::Instruction* child = context->get_def_use_mgr()->GetDef(
+      access_chain->GetSingleWordInOperand(1));
+
+  const opt::SENode* node = analysis.AnalyzeInstruction(child);
+
+  // Unsimplified is a very large graph with an add at the top.
+  EXPECT_NE(node, nullptr);
+  EXPECT_EQ(node->GetType(), opt::SENode::Add);
+
+  // Simplified node should resolve down to a constant expression as the loads
+  // will eliminate themselves.
+  opt::SENode* simplified =
+      analysis.SimplifyExpression(const_cast<opt::SENode*>(node));
+
+  EXPECT_EQ(simplified->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 410 core
+layout(location = 0) in vec4 c;
+layout (location = 1) out float array[10];
+void main() {
+ int N = int(c.x);
+ for (int i = 0; i < 10; ++i) {
+ array[i] = array[i];
+ array[i] = array[i-1];
+ array[i] = array[i+1];
+ array[i+1] = array[i+1];
+ array[i+N] = array[i+N];
+ array[i] = array[i+N];
+ }
+}
+
+*/
+// For several store/load index pairs in one loop, checks that the difference
+// (store index - load index) simplifies to the expected distance:
+// [i]-[i]==0, [i]-[i-1]==1, [i]-[i+1]==-1, [i+1]-[i+1]==0, [i+N]-[i+N]==0,
+// and [i]-[i+N]==-N (a Negative-wrapped unknown).
+// NOTE(review): the six sections below are intentionally copy-pasted with the
+// same extraction steps; a small helper would shorten this considerably.
+TEST_F(ScalarAnalysisTest, Simplify) {
+  const std::string text = R"( OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main" %12 %33
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource GLSL 410
+ OpName %4 "main"
+ OpName %8 "N"
+ OpName %12 "c"
+ OpName %19 "i"
+ OpName %33 "array"
+ OpDecorate %12 Location 0
+ OpDecorate %33 Location 1
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %7 = OpTypePointer Function %6
+ %9 = OpTypeFloat 32
+ %10 = OpTypeVector %9 4
+ %11 = OpTypePointer Input %10
+ %12 = OpVariable %11 Input
+ %13 = OpTypeInt 32 0
+ %14 = OpConstant %13 0
+ %15 = OpTypePointer Input %9
+ %20 = OpConstant %6 0
+ %27 = OpConstant %6 10
+ %28 = OpTypeBool
+ %30 = OpConstant %13 10
+ %31 = OpTypeArray %9 %30
+ %32 = OpTypePointer Output %31
+ %33 = OpVariable %32 Output
+ %36 = OpTypePointer Output %9
+ %42 = OpConstant %6 1
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %8 = OpVariable %7 Function
+ %19 = OpVariable %7 Function
+ %16 = OpAccessChain %15 %12 %14
+ %17 = OpLoad %9 %16
+ %18 = OpConvertFToS %6 %17
+ OpStore %8 %18
+ OpStore %19 %20
+ OpBranch %21
+ %21 = OpLabel
+ %78 = OpPhi %6 %20 %5 %77 %24
+ OpLoopMerge %23 %24 None
+ OpBranch %25
+ %25 = OpLabel
+ %29 = OpSLessThan %28 %78 %27
+ OpBranchConditional %29 %22 %23
+ %22 = OpLabel
+ %37 = OpAccessChain %36 %33 %78
+ %38 = OpLoad %9 %37
+ %39 = OpAccessChain %36 %33 %78
+ OpStore %39 %38
+ %43 = OpISub %6 %78 %42
+ %44 = OpAccessChain %36 %33 %43
+ %45 = OpLoad %9 %44
+ %46 = OpAccessChain %36 %33 %78
+ OpStore %46 %45
+ %49 = OpIAdd %6 %78 %42
+ %50 = OpAccessChain %36 %33 %49
+ %51 = OpLoad %9 %50
+ %52 = OpAccessChain %36 %33 %78
+ OpStore %52 %51
+ %54 = OpIAdd %6 %78 %42
+ %56 = OpIAdd %6 %78 %42
+ %57 = OpAccessChain %36 %33 %56
+ %58 = OpLoad %9 %57
+ %59 = OpAccessChain %36 %33 %54
+ OpStore %59 %58
+ %62 = OpIAdd %6 %78 %18
+ %65 = OpIAdd %6 %78 %18
+ %66 = OpAccessChain %36 %33 %65
+ %67 = OpLoad %9 %66
+ %68 = OpAccessChain %36 %33 %62
+ OpStore %68 %67
+ %72 = OpIAdd %6 %78 %18
+ %73 = OpAccessChain %36 %33 %72
+ %74 = OpLoad %9 %73
+ %75 = OpAccessChain %36 %33 %78
+ OpStore %75 %74
+ OpBranch %24
+ %24 = OpLabel
+ %77 = OpIAdd %6 %78 %42
+ OpStore %19 %77
+ OpBranch %21
+ %23 = OpLabel
+ OpReturn
+ OpFunctionEnd
+)";
+  // clang-format on
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 4);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Collect the six load/store pairs from the loop body (block %22), in
+  // program order so loads[k] pairs with stores[k].
+  const ir::Instruction* loads[6];
+  const ir::Instruction* stores[6];
+  int load_count = 0;
+  int store_count = 0;
+
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
+    if (inst.opcode() == SpvOp::SpvOpLoad) {
+      loads[load_count] = &inst;
+      ++load_count;
+    }
+    if (inst.opcode() == SpvOp::SpvOpStore) {
+      stores[store_count] = &inst;
+      ++store_count;
+    }
+  }
+
+  EXPECT_EQ(load_count, 6);
+  EXPECT_EQ(store_count, 6);
+
+  ir::Instruction* load_access_chain;
+  ir::Instruction* store_access_chain;
+  ir::Instruction* load_child;
+  ir::Instruction* store_child;
+  opt::SENode* load_node;
+  opt::SENode* store_node;
+  opt::SENode* subtract_node;
+  opt::SENode* simplified_node;
+
+  // Testing [i] - [i] == 0
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
+
+  // Testing [i] - [i-1] == 1
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);
+
+  // Testing [i] - [i+1] == -1
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);
+
+  // Testing [i+1] - [i+1] == 0
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
+
+  // Testing [i+N] - [i+N] == 0
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
+
+  // Testing [i] - [i+N] == -N
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));
+
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  load_node = analysis.AnalyzeInstruction(load_child);
+  store_node = analysis.AnalyzeInstruction(store_child);
+
+  subtract_node = analysis.CreateSubtraction(store_node, load_node);
+  simplified_node = analysis.SimplifyExpression(subtract_node);
+  // N is unknown, so the result is a Negative node wrapping the unknown.
+  EXPECT_EQ(simplified_node->GetType(), opt::SENode::Negative);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 430
+layout(location = 1) out float array[10];
+layout(location = 2) flat in int loop_invariant;
+void main(void) {
+ for (int i = 0; i < 10; ++i) {
+ array[i * 2 + i * 5] = array[i * i * 2];
+ array[i * 2] = array[i * 5];
+ }
+}
+
+*/
+
+// Checks that recurrent expressions with respect to the same loop are folded
+// together: i*2 + i*5 simplified in one pass must produce the same cached
+// node as simplifying i*2 and i*5 independently and then adding them.
+TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
+  const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %5 "i"
+ OpName %3 "array"
+ OpName %4 "loop_invariant"
+ OpDecorate %3 Location 1
+ OpDecorate %4 Flat
+ OpDecorate %4 Location 2
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 10
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 10
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Output %16
+ %3 = OpVariable %17 Output
+ %18 = OpConstant %8 2
+ %19 = OpConstant %8 5
+ %20 = OpTypePointer Output %13
+ %21 = OpConstant %8 1
+ %22 = OpTypePointer Input %8
+ %4 = OpVariable %22 Input
+ %2 = OpFunction %6 None %7
+ %23 = OpLabel
+ %5 = OpVariable %9 Function
+ OpStore %5 %10
+ OpBranch %24
+ %24 = OpLabel
+ %25 = OpPhi %8 %10 %23 %26 %27
+ OpLoopMerge %28 %27 None
+ OpBranch %29
+ %29 = OpLabel
+ %30 = OpSLessThan %12 %25 %11
+ OpBranchConditional %30 %31 %28
+ %31 = OpLabel
+ %32 = OpIMul %8 %25 %18
+ %33 = OpIMul %8 %25 %19
+ %34 = OpIAdd %8 %32 %33
+ %35 = OpIMul %8 %25 %25
+ %36 = OpIMul %8 %35 %18
+ %37 = OpAccessChain %20 %3 %36
+ %38 = OpLoad %13 %37
+ %39 = OpAccessChain %20 %3 %34
+ OpStore %39 %38
+ %40 = OpIMul %8 %25 %18
+ %41 = OpIMul %8 %25 %19
+ %42 = OpAccessChain %20 %3 %41
+ %43 = OpLoad %13 %42
+ %44 = OpAccessChain %20 %3 %40
+ OpStore %44 %43
+ OpBranch %27
+ %27 = OpLabel
+ %26 = OpIAdd %8 %25 %21
+ OpStore %5 %26
+ OpBranch %24
+ %28 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Collect both load/store pairs from the loop body (block %31).
+  const ir::Instruction* loads[2] = {nullptr, nullptr};
+  const ir::Instruction* stores[2] = {nullptr, nullptr};
+  int load_count = 0;
+  int store_count = 0;
+
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
+    if (inst.opcode() == SpvOp::SpvOpLoad) {
+      loads[load_count] = &inst;
+      ++load_count;
+    }
+    if (inst.opcode() == SpvOp::SpvOpStore) {
+      stores[store_count] = &inst;
+      ++store_count;
+    }
+  }
+
+  EXPECT_EQ(load_count, 2);
+  EXPECT_EQ(store_count, 2);
+
+  // First pair: the store index is i*2 + i*5, simplified in a single pass.
+  ir::Instruction* load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
+  ir::Instruction* store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
+
+  ir::Instruction* load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  ir::Instruction* store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  opt::SENode* store_node = analysis.AnalyzeInstruction(store_child);
+
+  opt::SENode* store_simplified = analysis.SimplifyExpression(store_node);
+
+  // Second pair: i*2 (store) and i*5 (load), simplified separately and then
+  // combined with an explicit add.
+  load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
+  store_access_chain =
+      context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
+  load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+  store_child = context->get_def_use_mgr()->GetDef(
+      store_access_chain->GetSingleWordInOperand(1));
+
+  opt::SENode* second_store =
+      analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
+  opt::SENode* second_load =
+      analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
+  opt::SENode* combined_add = analysis.SimplifyExpression(
+      analysis.CreateAddNode(second_load, second_store));
+
+  // We're checking that the two recurrent expression have been correctly
+  // folded. In store_simplified they will have been folded as the entire
+  // expression was simplified as one. In combined_add the two expressions have
+  // been simplified one after the other which means the recurrent expressions
+  // aren't exactly the same but should still be folded as they are with respect
+  // to the same loop.
+  EXPECT_EQ(combined_add, store_simplified);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+
+#version 430
+layout(location = 1) out float array[10];
+void main(void) {
+  for (int i = 0; i < 10; --i) {
+    array[i] = array[i];
+  }
+}
+*/
+
+// Checks a loop whose induction variable steps by -1: the analysis must
+// produce REC(0,-1) directly, and simplification must be a no-op that
+// returns the identical cached node with identical children.
+TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
+  const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %5 "i"
+ OpName %3 "array"
+ OpName %4 "loop_invariant"
+ OpDecorate %3 Location 1
+ OpDecorate %4 Flat
+ OpDecorate %4 Location 2
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 10
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 10
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Output %16
+ %3 = OpVariable %17 Output
+ %18 = OpTypePointer Output %13
+ %19 = OpConstant %8 1
+ %20 = OpTypePointer Input %8
+ %4 = OpVariable %20 Input
+ %2 = OpFunction %6 None %7
+ %21 = OpLabel
+ %5 = OpVariable %9 Function
+ OpStore %5 %10
+ OpBranch %22
+ %22 = OpLabel
+ %23 = OpPhi %8 %10 %21 %24 %25
+ OpLoopMerge %26 %25 None
+ OpBranch %27
+ %27 = OpLabel
+ %28 = OpSLessThan %12 %23 %11
+ OpBranchConditional %28 %29 %26
+ %29 = OpLabel
+ %30 = OpAccessChain %18 %3 %23
+ %31 = OpLoad %13 %30
+ %32 = OpAccessChain %18 %3 %23
+ OpStore %32 %31
+ OpBranch %25
+ %25 = OpLabel
+ %24 = OpISub %8 %23 %19
+ OpStore %5 %24
+ OpBranch %22
+ %26 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+  std::unique_ptr<ir::IRContext> context =
+      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+  ir::Module* module = context->module();
+  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+                             << text << std::endl;
+  const ir::Function* f = spvtest::GetFunction(module, 2);
+  opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+  // Find the single load in the loop body (block %29).
+  const ir::Instruction* loads[1] = {nullptr};
+  int load_count = 0;
+
+  for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
+    if (inst.opcode() == SpvOp::SpvOpLoad) {
+      loads[load_count] = &inst;
+      ++load_count;
+    }
+  }
+
+  EXPECT_EQ(load_count, 1);
+
+  ir::Instruction* load_access_chain =
+      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
+  ir::Instruction* load_child = context->get_def_use_mgr()->GetDef(
+      load_access_chain->GetSingleWordInOperand(1));
+
+  opt::SENode* load_node = analysis.AnalyzeInstruction(load_child);
+
+  // The index is the phi itself, so the analysis should already yield a
+  // recurrent expression with coefficient -1 and offset 0.
+  EXPECT_TRUE(load_node);
+  EXPECT_EQ(load_node->GetType(), opt::SENode::RecurrentAddExpr);
+  EXPECT_TRUE(load_node->AsSERecurrentNode());
+
+  opt::SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
+  opt::SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();
+
+  EXPECT_EQ(child_1->GetType(), opt::SENode::Constant);
+  EXPECT_EQ(child_2->GetType(), opt::SENode::Constant);
+
+  EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
+  EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);
+
+  // Simplifying an already-canonical recurrent expression must return the
+  // same cached node, not a copy.
+  opt::SERecurrentNode* load_simplified =
+      analysis.SimplifyExpression(load_node)->AsSERecurrentNode();
+
+  EXPECT_TRUE(load_simplified);
+  EXPECT_EQ(load_node, load_simplified);
+
+  EXPECT_EQ(load_simplified->GetType(), opt::SENode::RecurrentAddExpr);
+  EXPECT_TRUE(load_simplified->AsSERecurrentNode());
+
+  opt::SENode* simplified_child_1 =
+      load_simplified->AsSERecurrentNode()->GetCoefficient();
+  opt::SENode* simplified_child_2 =
+      load_simplified->AsSERecurrentNode()->GetOffset();
+
+  EXPECT_EQ(child_1, simplified_child_1);
+  EXPECT_EQ(child_2, simplified_child_2);
+}
+
+/*
+Generated from the following GLSL + --eliminate-local-multi-store
+#version 430
+layout(location = 1) out float array[10];
+layout(location = 2) flat in int N;
+void main(void) {
+  for (int i = 0; i < 10; --i) {
+    array[i + 2*N] = array[i + N]; array[2*i + 2*N + 1] = array[2*i + N + 1];
+  }
+}
+*/
+
+TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %5 "i"
+ OpName %3 "array"
+ OpName %4 "N"
+ OpDecorate %3 Location 1
+ OpDecorate %4 Flat
+ OpDecorate %4 Location 2
+ %6 = OpTypeVoid
+ %7 = OpTypeFunction %6
+ %8 = OpTypeInt 32 1
+ %9 = OpTypePointer Function %8
+ %10 = OpConstant %8 0
+ %11 = OpConstant %8 10
+ %12 = OpTypeBool
+ %13 = OpTypeFloat 32
+ %14 = OpTypeInt 32 0
+ %15 = OpConstant %14 10
+ %16 = OpTypeArray %13 %15
+ %17 = OpTypePointer Output %16
+ %3 = OpVariable %17 Output
+ %18 = OpConstant %8 2
+ %19 = OpTypePointer Input %8
+ %4 = OpVariable %19 Input
+ %20 = OpTypePointer Output %13
+ %21 = OpConstant %8 1
+ %2 = OpFunction %6 None %7
+ %22 = OpLabel
+ %5 = OpVariable %9 Function
+ OpStore %5 %10
+ OpBranch %23
+ %23 = OpLabel
+ %24 = OpPhi %8 %10 %22 %25 %26
+ OpLoopMerge %27 %26 None
+ OpBranch %28
+ %28 = OpLabel
+ %29 = OpSLessThan %12 %24 %11
+ OpBranchConditional %29 %30 %27
+ %30 = OpLabel
+ %31 = OpLoad %8 %4
+ %32 = OpIMul %8 %18 %31
+ %33 = OpIAdd %8 %24 %32
+ %35 = OpIAdd %8 %24 %31
+ %36 = OpAccessChain %20 %3 %35
+ %37 = OpLoad %13 %36
+ %38 = OpAccessChain %20 %3 %33
+ OpStore %38 %37
+ %39 = OpIMul %8 %18 %24
+ %41 = OpIMul %8 %18 %31
+ %42 = OpIAdd %8 %39 %41
+ %43 = OpIAdd %8 %42 %21
+ %44 = OpIMul %8 %18 %24
+ %46 = OpIAdd %8 %44 %31
+ %47 = OpIAdd %8 %46 %21
+ %48 = OpAccessChain %20 %3 %47
+ %49 = OpLoad %13 %48
+ %50 = OpAccessChain %20 %3 %43
+ OpStore %50 %49
+ OpBranch %26
+ %26 = OpLabel
+ %25 = OpISub %8 %24 %21
+ OpStore %5 %25
+ OpBranch %23
+ %27 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+ const ir::Function* f = spvtest::GetFunction(module, 2);
+ opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+ std::vector<const ir::Instruction*> loads{};
+ std::vector<const ir::Instruction*> stores{};
+
+ for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
+ if (inst.opcode() == SpvOp::SpvOpLoad) {
+ loads.push_back(&inst);
+ }
+ if (inst.opcode() == SpvOp::SpvOpStore) {
+ stores.push_back(&inst);
+ }
+ }
+
+ EXPECT_EQ(loads.size(), 3u);
+ EXPECT_EQ(stores.size(), 2u);
+ {
+ ir::Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
+ stores[0]->GetSingleWordInOperand(0));
+
+ ir::Instruction* store_child = context->get_def_use_mgr()->GetDef(
+ store_access_chain->GetSingleWordInOperand(1));
+
+ opt::SENode* store_node = analysis.AnalyzeInstruction(store_child);
+
+ opt::SENode* store_simplified = analysis.SimplifyExpression(store_node);
+
+ ir::Instruction* load_access_chain =
+ context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
+
+ ir::Instruction* load_child = context->get_def_use_mgr()->GetDef(
+ load_access_chain->GetSingleWordInOperand(1));
+
+ opt::SENode* load_node = analysis.AnalyzeInstruction(load_child);
+
+ opt::SENode* load_simplified = analysis.SimplifyExpression(load_node);
+
+ opt::SENode* difference =
+ analysis.CreateSubtraction(store_simplified, load_simplified);
+
+ opt::SENode* difference_simplified =
+ analysis.SimplifyExpression(difference);
+
+ // Check that i+2*N - i*N, turns into just N when both sides have already
+ // been simplified into a single recurrent expression.
+ EXPECT_EQ(difference_simplified->GetType(), opt::SENode::ValueUnknown);
+
+ // Check that the inverse, i*N - i+2*N turns into -N.
+ opt::SENode* difference_inverse = analysis.SimplifyExpression(
+ analysis.CreateSubtraction(load_simplified, store_simplified));
+
+ EXPECT_EQ(difference_inverse->GetType(), opt::SENode::Negative);
+ EXPECT_EQ(difference_inverse->GetChild(0)->GetType(),
+ opt::SENode::ValueUnknown);
+ EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
+ }
+
+ {
+ ir::Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
+ stores[1]->GetSingleWordInOperand(0));
+
+ ir::Instruction* store_child = context->get_def_use_mgr()->GetDef(
+ store_access_chain->GetSingleWordInOperand(1));
+ opt::SENode* store_node = analysis.AnalyzeInstruction(store_child);
+ opt::SENode* store_simplified = analysis.SimplifyExpression(store_node);
+
+ ir::Instruction* load_access_chain =
+ context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
+
+ ir::Instruction* load_child = context->get_def_use_mgr()->GetDef(
+ load_access_chain->GetSingleWordInOperand(1));
+
+ opt::SENode* load_node = analysis.AnalyzeInstruction(load_child);
+
+ opt::SENode* load_simplified = analysis.SimplifyExpression(load_node);
+
+ opt::SENode* difference =
+ analysis.CreateSubtraction(store_simplified, load_simplified);
+ opt::SENode* difference_simplified =
+ analysis.SimplifyExpression(difference);
+
+ // Check that 2*i + 2*N + 1 - 2*i + N + 1, turns into just N when both
+ // sides have already been simplified into a single recurrent expression.
+ EXPECT_EQ(difference_simplified->GetType(), opt::SENode::ValueUnknown);
+
+ // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N.
+ opt::SENode* difference_inverse = analysis.SimplifyExpression(
+ analysis.CreateSubtraction(load_simplified, store_simplified));
+
+ EXPECT_EQ(difference_inverse->GetType(), opt::SENode::Negative);
+ EXPECT_EQ(difference_inverse->GetChild(0)->GetType(),
+ opt::SENode::ValueUnknown);
+ EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
+ }
+}
+
+/* Generated from the following GLSL + --eliminate-local-multi-store
+
+ #version 430
+ layout(location = 1) out float array[10];
+ layout(location = 2) flat in int N;
+ void main(void) {
+ int step = 0;
+ for (int i = 0; i < N; i += step) {
+ step++;
+ }
+ }
+*/
+TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
+ const std::string text = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %2 "main" %3 %4
+ OpExecutionMode %2 OriginUpperLeft
+ OpSource GLSL 430
+ OpName %2 "main"
+ OpName %5 "step"
+ OpName %6 "i"
+ OpName %3 "N"
+ OpName %4 "array"
+ OpDecorate %3 Flat
+ OpDecorate %3 Location 2
+ OpDecorate %4 Location 1
+ %7 = OpTypeVoid
+ %8 = OpTypeFunction %7
+ %9 = OpTypeInt 32 1
+ %10 = OpTypePointer Function %9
+ %11 = OpConstant %9 0
+ %12 = OpTypePointer Input %9
+ %3 = OpVariable %12 Input
+ %13 = OpTypeBool
+ %14 = OpConstant %9 1
+ %15 = OpTypeFloat 32
+ %16 = OpTypeInt 32 0
+ %17 = OpConstant %16 10
+ %18 = OpTypeArray %15 %17
+ %19 = OpTypePointer Output %18
+ %4 = OpVariable %19 Output
+ %2 = OpFunction %7 None %8
+ %20 = OpLabel
+ %5 = OpVariable %10 Function
+ %6 = OpVariable %10 Function
+ OpStore %5 %11
+ OpStore %6 %11
+ OpBranch %21
+ %21 = OpLabel
+ %22 = OpPhi %9 %11 %20 %23 %24
+ %25 = OpPhi %9 %11 %20 %26 %24
+ OpLoopMerge %27 %24 None
+ OpBranch %28
+ %28 = OpLabel
+ %29 = OpLoad %9 %3
+ %30 = OpSLessThan %13 %25 %29
+ OpBranchConditional %30 %31 %27
+ %31 = OpLabel
+ %23 = OpIAdd %9 %22 %14
+ OpStore %5 %23
+ OpBranch %24
+ %24 = OpLabel
+ %26 = OpIAdd %9 %25 %23
+ OpStore %6 %26
+ OpBranch %21
+ %27 = OpLabel
+ OpReturn
+ OpFunctionEnd
+ )";
+ std::unique_ptr<ir::IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ir::Module* module = context->module();
+ EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
+ << text << std::endl;
+ const ir::Function* f = spvtest::GetFunction(module, 2);
+ opt::ScalarEvolutionAnalysis analysis{context.get()};
+
+ std::vector<const ir::Instruction*> phis{};
+
+ for (const ir::Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
+ if (inst.opcode() == SpvOp::SpvOpPhi) {
+ phis.push_back(&inst);
+ }
+ }
+
+ EXPECT_EQ(phis.size(), 2u);
+ opt::SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
+ opt::SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
+ // Use ASSERT so the dereferences below are not reached on a null node.
+ ASSERT_NE(phi_node_1, nullptr);
+ ASSERT_NE(phi_node_2, nullptr);
+
+ EXPECT_EQ(phi_node_1->GetType(), opt::SENode::RecurrentAddExpr);
+ EXPECT_EQ(phi_node_2->GetType(), opt::SENode::CanNotCompute);
+
+ opt::SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
+ opt::SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);
+
+ EXPECT_EQ(simplified_1->GetType(), opt::SENode::RecurrentAddExpr);
+ EXPECT_EQ(simplified_2->GetType(), opt::SENode::CanNotCompute);
+}
+
+} // namespace