Document in the context what happens with id overflow. (#2159)
Added documentation to the ir context to indicates that TakeNextId()
returns 0 when the max id is reached. TODOs were added to each call
sight so that we know where we have to start to handle this case.
Handle id overflow in |SplitLoopHeader|.
Handle id overflow in |GetOrCreatePreHeaderBlock|.
Handle failure to create preheader in LICM.
Part of https://github.com/KhronosGroup/SPIRV-Tools/issues/1841.
diff --git a/source/opt/cfg.cpp b/source/opt/cfg.cpp
index 778c527..7e1097e 100644
--- a/source/opt/cfg.cpp
+++ b/source/opt/cfg.cpp
@@ -167,6 +167,13 @@
Function* fn = bb->GetParent();
IRContext* context = module_->context();
+ // Get the new header id up front. If we are out of ids, then we cannot split
+ // the loop.
+ uint32_t new_header_id = context->TakeNextId();
+ if (new_header_id == 0) {
+ return nullptr;
+ }
+
// Find the insertion point for the new bb.
Function::iterator header_it = std::find_if(
fn->begin(), fn->end(),
@@ -197,10 +204,7 @@
++iter;
}
- BasicBlock* new_header =
- bb->SplitBasicBlock(context, context->TakeNextId(), iter);
-
- uint32_t new_header_id = new_header->id();
+ BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
context->AnalyzeDefUse(new_header->GetLabelInst());
// Update cfg
diff --git a/source/opt/cfg.h b/source/opt/cfg.h
index 7bb8ecb..5ff3aa0 100644
--- a/source/opt/cfg.h
+++ b/source/opt/cfg.h
@@ -128,7 +128,8 @@
// id as |block| and will become a preheader for the loop. The other block
// is a new block that will be the new loop header.
//
- // Returns a pointer to the new loop header.
+ // Returns a pointer to the new loop header. Returns |nullptr| if the new
+ // loop pointer could not be created.
BasicBlock* SplitLoopHeader(BasicBlock* bb);
private:
diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp
index ecb5f97..768364b 100644
--- a/source/opt/constants.cpp
+++ b/source/opt/constants.cpp
@@ -165,6 +165,7 @@
Instruction* ConstantManager::BuildInstructionAndAddToModule(
const Constant* new_const, Module::inst_iterator* pos, uint32_t type_id) {
+ // TODO(1841): Handle id overflow.
uint32_t new_id = context()->TakeNextId();
auto new_inst = CreateInstruction(new_id, new_const, type_id);
if (!new_inst) {
diff --git a/source/opt/ir_builder.h b/source/opt/ir_builder.h
index 434c380..2f741d8 100644
--- a/source/opt/ir_builder.h
+++ b/source/opt/ir_builder.h
@@ -59,6 +59,7 @@
preserved_analyses) {}
Instruction* AddNullaryOp(uint32_t type_id, SpvOp opcode) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newUnOp(new Instruction(
GetContext(), opcode, type_id,
opcode == SpvOpReturn ? 0 : GetContext()->TakeNextId(), {}));
@@ -66,6 +67,7 @@
}
Instruction* AddUnaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newUnOp(new Instruction(
GetContext(), opcode, type_id, GetContext()->TakeNextId(),
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}}}));
@@ -74,6 +76,7 @@
Instruction* AddBinaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1,
uint32_t operand2) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newBinOp(new Instruction(
GetContext(), opcode, type_id,
opcode == SpvOpStore ? 0 : GetContext()->TakeNextId(),
@@ -84,6 +87,7 @@
Instruction* AddTernaryOp(uint32_t type_id, SpvOp opcode, uint32_t operand1,
uint32_t operand2, uint32_t operand3) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newTernOp(new Instruction(
GetContext(), opcode, type_id, GetContext()->TakeNextId(),
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}},
@@ -95,6 +99,7 @@
Instruction* AddQuadOp(uint32_t type_id, SpvOp opcode, uint32_t operand1,
uint32_t operand2, uint32_t operand3,
uint32_t operand4) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newQuadOp(new Instruction(
GetContext(), opcode, type_id, GetContext()->TakeNextId(),
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}},
@@ -106,6 +111,7 @@
Instruction* AddIdLiteralOp(uint32_t type_id, SpvOp opcode, uint32_t operand1,
uint32_t operand2) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> newBinOp(new Instruction(
GetContext(), opcode, type_id, GetContext()->TakeNextId(),
{{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {operand1}},
@@ -124,6 +130,7 @@
for (size_t i = 0; i < operands.size(); i++) {
ops.push_back({SPV_OPERAND_TYPE_ID, {operands[i]}});
}
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_inst(new Instruction(
GetContext(), opcode, type_id,
result != 0 ? result : GetContext()->TakeNextId(), ops));
@@ -251,6 +258,7 @@
// The id |op1| is the left hand side of the operation.
// The id |op2| is the right hand side of the operation.
Instruction* AddIAdd(uint32_t type, uint32_t op1, uint32_t op2) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> inst(new Instruction(
GetContext(), SpvOpIAdd, type, GetContext()->TakeNextId(),
{{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}));
@@ -264,6 +272,7 @@
Instruction* AddULessThan(uint32_t op1, uint32_t op2) {
analysis::Bool bool_type;
uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type);
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> inst(new Instruction(
GetContext(), SpvOpULessThan, type, GetContext()->TakeNextId(),
{{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}));
@@ -277,6 +286,7 @@
Instruction* AddSLessThan(uint32_t op1, uint32_t op2) {
analysis::Bool bool_type;
uint32_t type = GetContext()->get_type_mgr()->GetId(&bool_type);
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> inst(new Instruction(
GetContext(), SpvOpSLessThan, type, GetContext()->TakeNextId(),
{{SPV_OPERAND_TYPE_ID, {op1}}, {SPV_OPERAND_TYPE_ID, {op2}}}));
@@ -306,6 +316,7 @@
// bool) for |type|.
Instruction* AddSelect(uint32_t type, uint32_t cond, uint32_t true_value,
uint32_t false_value) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> select(new Instruction(
GetContext(), SpvOpSelect, type, GetContext()->TakeNextId(),
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cond}},
@@ -330,6 +341,7 @@
ops.emplace_back(SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
}
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> construct(
new Instruction(GetContext(), SpvOpCompositeConstruct, type,
GetContext()->TakeNextId(), ops));
@@ -401,6 +413,7 @@
operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
}
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_inst(
new Instruction(GetContext(), SpvOpCompositeExtract, type,
GetContext()->TakeNextId(), operands));
@@ -424,6 +437,7 @@
operands.push_back({SPV_OPERAND_TYPE_ID, {index_id}});
}
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_inst(
new Instruction(GetContext(), SpvOpAccessChain, type_id,
GetContext()->TakeNextId(), operands));
@@ -434,6 +448,7 @@
std::vector<Operand> operands;
operands.push_back({SPV_OPERAND_TYPE_ID, {base_ptr_id}});
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_inst(
new Instruction(GetContext(), SpvOpLoad, type_id,
GetContext()->TakeNextId(), operands));
diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp
index c1158e7..31dbe5b 100644
--- a/source/opt/ir_context.cpp
+++ b/source/opt/ir_context.cpp
@@ -665,6 +665,7 @@
uint32_t type_id = type_mgr->GetTypeInstruction(reg_type);
uint32_t varTyPtrId =
type_mgr->FindPointerToType(type_id, SpvStorageClassInput);
+ // TODO(1841): Handle id overflow.
var_id = TakeNextId();
std::unique_ptr<Instruction> newVarOp(
new Instruction(this, SpvOpVariable, varTyPtrId, var_id,
diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h
index 83e06b8..94bfd01 100644
--- a/source/opt/ir_context.h
+++ b/source/opt/ir_context.h
@@ -450,7 +450,8 @@
post_dominator_trees_.erase(f);
}
- // Return the next available SSA id and increment it.
+ // Return the next available SSA id and increment it. Returns 0 if the
+ // maximum SSA id has been reached.
inline uint32_t TakeNextId() { return module()->TakeNextIdBound(); }
FeatureManager* get_feature_mgr() {
diff --git a/source/opt/licm_pass.cpp b/source/opt/licm_pass.cpp
index d825667..c553221 100644
--- a/source/opt/licm_pass.cpp
+++ b/source/opt/licm_pass.cpp
@@ -23,70 +23,81 @@
namespace spvtools {
namespace opt {
-Pass::Status LICMPass::Process() {
- return ProcessIRContext() ? Status::SuccessWithChange
- : Status::SuccessWithoutChange;
-}
+Pass::Status LICMPass::Process() { return ProcessIRContext(); }
-bool LICMPass::ProcessIRContext() {
- bool modified = false;
+Pass::Status LICMPass::ProcessIRContext() {
+ Status status = Status::SuccessWithoutChange;
Module* module = get_module();
// Process each function in the module
- for (Function& f : *module) {
- modified |= ProcessFunction(&f);
+ for (auto func = module->begin();
+ func != module->end() && status != Status::Failure; ++func) {
+ status = CombineStatus(status, ProcessFunction(&*func));
}
- return modified;
+ return status;
}
-bool LICMPass::ProcessFunction(Function* f) {
- bool modified = false;
+Pass::Status LICMPass::ProcessFunction(Function* f) {
+ Status status = Status::SuccessWithoutChange;
LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f);
// Process each loop in the function
- for (Loop& loop : *loop_descriptor) {
+ for (auto it = loop_descriptor->begin();
+ it != loop_descriptor->end() && status != Status::Failure; ++it) {
+ Loop& loop = *it;
// Ignore nested loops, as we will process them in order in ProcessLoop
if (loop.IsNested()) {
continue;
}
- modified |= ProcessLoop(&loop, f);
+ status = CombineStatus(status, ProcessLoop(&loop, f));
}
- return modified;
+ return status;
}
-bool LICMPass::ProcessLoop(Loop* loop, Function* f) {
- bool modified = false;
+Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) {
+ Status status = Status::SuccessWithoutChange;
// Process all nested loops first
- for (Loop* nested_loop : *loop) {
- modified |= ProcessLoop(nested_loop, f);
+ for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure;
+ ++nl) {
+ Loop* nested_loop = *nl;
+ status = CombineStatus(status, ProcessLoop(nested_loop, f));
}
std::vector<BasicBlock*> loop_bbs{};
- modified |= AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs);
+ status = CombineStatus(
+ status,
+ AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs));
- for (size_t i = 0; i < loop_bbs.size(); ++i) {
+ for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) {
BasicBlock* bb = loop_bbs[i];
// do not delete the element
- modified |= AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs);
+ status =
+ CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs));
}
- return modified;
+ return status;
}
-bool LICMPass::AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb,
- std::vector<BasicBlock*>* loop_bbs) {
+Pass::Status LICMPass::AnalyseAndHoistFromBB(
+ Loop* loop, Function* f, BasicBlock* bb,
+ std::vector<BasicBlock*>* loop_bbs) {
bool modified = false;
- std::function<void(Instruction*)> hoist_inst =
+ std::function<bool(Instruction*)> hoist_inst =
[this, &loop, &modified](Instruction* inst) {
if (loop->ShouldHoistInstruction(this->context(), inst)) {
- HoistInstruction(loop, inst);
+ if (!HoistInstruction(loop, inst)) {
+ return false;
+ }
modified = true;
}
+ return true;
};
if (IsImmediatelyContainedInLoop(loop, f, bb)) {
- bb->ForEachInst(hoist_inst, false);
+ if (!bb->WhileEachInst(hoist_inst, false)) {
+ return Status::Failure;
+ }
}
DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f);
@@ -98,7 +109,7 @@
}
}
- return modified;
+ return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f,
@@ -107,10 +118,15 @@
return loop == (*loop_descriptor)[bb->id()];
}
-void LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
+bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
+ // TODO(1841): Handle failure to create pre-header.
BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock();
+ if (!pre_header_bb) {
+ return false;
+ }
inst->InsertBefore(std::move(&(*pre_header_bb->tail())));
context()->set_instr_block(inst, pre_header_bb);
+ return true;
}
} // namespace opt
diff --git a/source/opt/licm_pass.h b/source/opt/licm_pass.h
index a174500..a94ae11 100644
--- a/source/opt/licm_pass.h
+++ b/source/opt/licm_pass.h
@@ -35,30 +35,35 @@
private:
// Searches the IRContext for functions and processes each, moving invariants
- // outside loops within the function where possible
- // Returns true if a change was made to a function within the IRContext
- bool ProcessIRContext();
+ // outside loops within the function where possible.
+ // Returns the status depending on whether or not there was a failure or
+ // change.
+ Pass::Status ProcessIRContext();
// Checks the function for loops, calling ProcessLoop on each one found.
- // Returns true if a change was made to the function, false otherwise.
- bool ProcessFunction(Function* f);
+ // Returns the status depending on whether or not there was a failure or
+ // change.
+ Pass::Status ProcessFunction(Function* f);
// Checks for invariants in the loop and attempts to move them to the loops
// preheader. Works from inner loop to outer when nested loops are found.
- // Returns true if a change was made to the loop, false otherwise.
- bool ProcessLoop(Loop* loop, Function* f);
+ // Returns the status depending on whether or not there was a failure or
+ // change.
+ Pass::Status ProcessLoop(Loop* loop, Function* f);
// Analyses each instruction in |bb|, hoisting invariants to |pre_header_bb|.
// Each child of |bb| wrt to |dom_tree| is pushed to |loop_bbs|
- bool AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb,
- std::vector<BasicBlock*>* loop_bbs);
+ // Returns the status depending on whether or not there was a failure or
+ // change.
+ Pass::Status AnalyseAndHoistFromBB(Loop* loop, Function* f, BasicBlock* bb,
+ std::vector<BasicBlock*>* loop_bbs);
// Returns true if |bb| is immediately contained in |loop|
bool IsImmediatelyContainedInLoop(Loop* loop, Function* f, BasicBlock* bb);
// Move the instruction to the given BasicBlock
// This method will update the instruction to block mapping for the context
- void HoistInstruction(Loop* loop, Instruction* inst);
+ bool HoistInstruction(Loop* loop, Instruction* inst);
};
} // namespace opt
diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp
index efc56bd..5aff34c 100644
--- a/source/opt/loop_descriptor.cpp
+++ b/source/opt/loop_descriptor.cpp
@@ -914,6 +914,7 @@
for (auto& loop : *this) {
if (!loop.GetPreHeaderBlock()) {
modified = true;
+ // TODO(1841): Handle failure to create pre-header.
loop.GetOrCreatePreHeaderBlock();
}
}
diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h
index 45a175a..38f017b 100644
--- a/source/opt/loop_descriptor.h
+++ b/source/opt/loop_descriptor.h
@@ -132,7 +132,7 @@
void SetPreHeaderBlock(BasicBlock* preheader);
// Returns the loop pre-header, if there is no suitable preheader it will be
- // created.
+ // created. Returns |nullptr| if it fails to create the preheader.
BasicBlock* GetOrCreatePreHeaderBlock();
// Returns true if this loop contains any nested loops.
diff --git a/source/opt/loop_fission.cpp b/source/opt/loop_fission.cpp
index 0052406..0678113 100644
--- a/source/opt/loop_fission.cpp
+++ b/source/opt/loop_fission.cpp
@@ -367,6 +367,7 @@
cloned_loop->UpdateLoopMergeInst();
// Add the loop_ to the module.
+ // TODO(1841): Handle failure to create pre-header.
Function::iterator it =
util.GetFunction()->FindBlock(loop_->GetOrCreatePreHeaderBlock()->id());
util.GetFunction()->AddBasicBlocks(clone_results.cloned_bb_.begin(),
diff --git a/source/opt/loop_peeling.cpp b/source/opt/loop_peeling.cpp
index 227ba4a..b640542 100644
--- a/source/opt/loop_peeling.cpp
+++ b/source/opt/loop_peeling.cpp
@@ -39,6 +39,7 @@
assert(CanPeelLoop() && "Cannot peel loop!");
std::vector<BasicBlock*> ordered_loop_blocks;
+ // TODO(1841): Handle failure to create pre-header.
BasicBlock* pre_header = loop_->GetOrCreatePreHeaderBlock();
loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks);
@@ -131,6 +132,7 @@
// Force the creation of a new preheader for the original loop and set it as
// the merge block for the cloned loop.
+ // TODO(1841): Handle failure to create pre-header.
cloned_loop_->SetMergeBlock(loop_->GetOrCreatePreHeaderBlock());
}
@@ -345,6 +347,7 @@
CFG& cfg = *context_->cfg();
assert(cfg.preds(bb->id()).size() == 1 && "More than one predecessor");
+ // TODO(1841): Handle id overflow.
std::unique_ptr<BasicBlock> new_bb =
MakeUnique<BasicBlock>(std::unique_ptr<Instruction>(new Instruction(
context_, SpvOpLabel, 0, context_->TakeNextId(), {})));
@@ -391,6 +394,7 @@
BasicBlock* LoopPeeling::ProtectLoop(Loop* loop, Instruction* condition,
BasicBlock* if_merge) {
+ // TODO(1841): Handle failure to create pre-header.
BasicBlock* if_block = loop->GetOrCreatePreHeaderBlock();
// Will no longer be a pre-header because of the if.
loop->SetPreHeaderBlock(nullptr);
diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp
index d3e733a..0d49d88 100644
--- a/source/opt/loop_unroller.cpp
+++ b/source/opt/loop_unroller.cpp
@@ -377,6 +377,7 @@
// number of bodies.
void LoopUnrollerUtilsImpl::PartiallyUnrollResidualFactor(Loop* loop,
size_t factor) {
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_label{new Instruction(
context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})};
std::unique_ptr<BasicBlock> new_exit_bb{new BasicBlock(std::move(new_label))};
@@ -834,6 +835,7 @@
// Label instructions aren't covered by normal traversal of the
// instructions.
+ // TODO(1841): Handle id overflow.
uint32_t new_label_id = context_->TakeNextId();
// Assign a new id to the label.
@@ -850,6 +852,7 @@
}
// Give the instruction a new id.
+ // TODO(1841): Handle id overflow.
inst.SetResultId(context_->TakeNextId());
def_use_mgr->AnalyzeInstDef(&inst);
diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp
index 59a0cbc..7e374d9 100644
--- a/source/opt/loop_unswitch_pass.cpp
+++ b/source/opt/loop_unswitch_pass.cpp
@@ -99,6 +99,7 @@
BasicBlock* CreateBasicBlock(Function::iterator ip) {
analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
+ // TODO(1841): Handle id overflow.
BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
@@ -459,7 +460,10 @@
std::vector<BasicBlock*> ordered_loop_blocks_;
// Returns the next usable id for the context.
- uint32_t TakeNextId() { return context_->TakeNextId(); }
+ uint32_t TakeNextId() {
+ // TODO(1841): Handle id overflow.
+ return context_->TakeNextId();
+ }
// Patches |bb|'s phi instruction by removing incoming value from unexisting
// or tagged as dead branches.
@@ -474,28 +478,28 @@
return dead_blocks.count(id) ||
std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
};
- bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi,
- this](Instruction* insn) {
- uint32_t i = 0;
- while (i < insn->NumInOperands()) {
- uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
- if (is_branch_dead(incoming_id)) {
- // Remove the incoming block id operand.
- insn->RemoveInOperand(i + 1);
- // Remove the definition id operand.
- insn->RemoveInOperand(i);
- continue;
- }
- i += 2;
- }
- // If there is only 1 remaining edge, propagate the value and
- // kill the instruction.
- if (insn->NumInOperands() == 2 && !preserve_phi) {
- phi_to_kill.push_back(insn);
- context_->ReplaceAllUsesWith(insn->result_id(),
- insn->GetSingleWordInOperand(0));
- }
- });
+ bb->ForEachPhiInst(
+ [&phi_to_kill, &is_branch_dead, preserve_phi, this](Instruction* insn) {
+ uint32_t i = 0;
+ while (i < insn->NumInOperands()) {
+ uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
+ if (is_branch_dead(incoming_id)) {
+ // Remove the incoming block id operand.
+ insn->RemoveInOperand(i + 1);
+ // Remove the definition id operand.
+ insn->RemoveInOperand(i);
+ continue;
+ }
+ i += 2;
+ }
+ // If there is only 1 remaining edge, propagate the value and
+ // kill the instruction.
+ if (insn->NumInOperands() == 2 && !preserve_phi) {
+ phi_to_kill.push_back(insn);
+ context_->ReplaceAllUsesWith(insn->result_id(),
+ insn->GetSingleWordInOperand(0));
+ }
+ });
for (Instruction* insn : phi_to_kill) {
context_->KillInst(insn);
}
diff --git a/source/opt/loop_utils.cpp b/source/opt/loop_utils.cpp
index 482335f..8c6d355 100644
--- a/source/opt/loop_utils.cpp
+++ b/source/opt/loop_utils.cpp
@@ -352,6 +352,7 @@
assert(insert_pt != function->end() && "Basic Block not found");
// Create the dedicate exit basic block.
+ // TODO(1841): Handle id overflow.
BasicBlock& exit = *insert_pt.InsertBefore(std::unique_ptr<BasicBlock>(
new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
@@ -491,6 +492,7 @@
Loop* new_loop = CloneLoop(cloning_result);
// Create a new exit block/label for the new loop.
+ // TODO(1841): Handle id overflow.
std::unique_ptr<Instruction> new_label{new Instruction(
context_, SpvOp::SpvOpLabel, 0, context_->TakeNextId(), {})};
std::unique_ptr<BasicBlock> new_exit_bb{new BasicBlock(std::move(new_label))};
@@ -528,6 +530,7 @@
inst->SetOperand(operand, {new_header});
});
+ // TODO(1841): Handle failure to create pre-header.
def_use->ForEachUse(
loop_->GetOrCreatePreHeaderBlock()->id(),
[new_merge_block, this](Instruction* inst, uint32_t operand) {
@@ -560,6 +563,7 @@
// between old and new ids.
BasicBlock* new_bb = old_bb->Clone(context_);
new_bb->SetParent(&function_);
+ // TODO(1841): Handle id overflow.
new_bb->GetLabelInst()->SetResultId(context_->TakeNextId());
def_use_mgr->AnalyzeInstDef(new_bb->GetLabelInst());
context_->set_instr_block(new_bb->GetLabelInst(), new_bb);
@@ -575,6 +579,7 @@
new_inst != new_bb->end(); ++new_inst, ++old_inst) {
cloning_result->ptr_map_[&*new_inst] = &*old_inst;
if (new_inst->HasResultId()) {
+ // TODO(1841): Handle id overflow.
new_inst->SetResultId(context_->TakeNextId());
cloning_result->value_map_[old_inst->result_id()] =
new_inst->result_id();
diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp
index 820760c..d0434f2 100644
--- a/source/opt/merge_return_pass.cpp
+++ b/source/opt/merge_return_pass.cpp
@@ -121,7 +121,9 @@
// Predicate successors of the original return blocks as necessary.
if (std::find(return_blocks.begin(), return_blocks.end(), block) !=
return_blocks.end()) {
- PredicateBlocks(block, &predicated, &order);
+ if (!PredicateBlocks(block, &predicated, &order)) {
+ return false;
+ }
}
// Generate state for next block
@@ -288,14 +290,14 @@
}
}
-void MergeReturnPass::PredicateBlocks(
+bool MergeReturnPass::PredicateBlocks(
BasicBlock* return_block, std::unordered_set<BasicBlock*>* predicated,
std::list<BasicBlock*>* order) {
// The CFG is being modified as the function proceeds so avoid caching
// successors.
if (predicated->count(return_block)) {
- return;
+ return true;
}
BasicBlock* block = nullptr;
@@ -328,12 +330,15 @@
while (state->LoopMergeId() == next->id()) {
state++;
}
- BreakFromConstruct(block, next, predicated, order);
+ if (!BreakFromConstruct(block, next, predicated, order)) {
+ return false;
+ }
block = next;
}
+ return true;
}
-void MergeReturnPass::BreakFromConstruct(
+bool MergeReturnPass::BreakFromConstruct(
BasicBlock* block, BasicBlock* merge_block,
std::unordered_set<BasicBlock*>* predicated,
std::list<BasicBlock*>* order) {
@@ -353,7 +358,9 @@
// If |block| is a loop header, then the back edge must jump to the original
// code, not the new header.
if (block->GetLoopMergeInst()) {
- cfg()->SplitLoopHeader(block);
+ if (cfg()->SplitLoopHeader(block) == nullptr) {
+ return false;
+ }
}
// Leave the phi instructions behind.
@@ -407,6 +414,7 @@
assert(old_body->begin() != old_body->end());
assert(block->begin() != block->end());
+ return true;
}
void MergeReturnPass::RecordReturned(BasicBlock* block) {
diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h
index f27332f..264e8b7 100644
--- a/source/opt/merge_return_pass.h
+++ b/source/opt/merge_return_pass.h
@@ -212,7 +212,9 @@
//
// If new blocks that are created will be added to |order|. This way a call
// can traverse these new block in structured order.
- void PredicateBlocks(BasicBlock* return_block,
+ //
+ // Returns true if successful.
+ bool PredicateBlocks(BasicBlock* return_block,
std::unordered_set<BasicBlock*>* pSet,
std::list<BasicBlock*>* order);
@@ -222,7 +224,9 @@
//
// If new blocks that are created will be added to |order|. This way a call
// can traverse these new block in structured order.
- void BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block,
+ //
+ // Returns true if successful.
+ bool BreakFromConstruct(BasicBlock* block, BasicBlock* merge_block,
std::unordered_set<BasicBlock*>* predicated,
std::list<BasicBlock*>* order);
diff --git a/source/opt/pass.h b/source/opt/pass.h
index aabc645..c95f502 100644
--- a/source/opt/pass.h
+++ b/source/opt/pass.h
@@ -122,6 +122,7 @@
virtual Status Process() = 0;
// Return the next available SSA id and increment it.
+ // TODO(1841): Handle id overflow.
uint32_t TakeNextId() { return context_->TakeNextId(); }
private:
@@ -136,6 +137,10 @@
bool already_run_;
};
+inline Pass::Status CombineStatus(Pass::Status a, Pass::Status b) {
+ return std::min(a, b);
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/ssa_rewrite_pass.cpp b/source/opt/ssa_rewrite_pass.cpp
index 83d2433..0a5d390 100644
--- a/source/opt/ssa_rewrite_pass.cpp
+++ b/source/opt/ssa_rewrite_pass.cpp
@@ -90,6 +90,7 @@
SSARewriter::PhiCandidate& SSARewriter::CreatePhiCandidate(uint32_t var_id,
BasicBlock* bb) {
+ // TODO(1841): Handle id overflow.
uint32_t phi_result_id = pass_->context()->TakeNextId();
auto result = phi_candidates_.emplace(
phi_result_id, PhiCandidate(var_id, phi_result_id, bb));
diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp
index cb19cca..19eb47e 100644
--- a/source/opt/type_manager.cpp
+++ b/source/opt/type_manager.cpp
@@ -205,6 +205,7 @@
if (id != 0) return id;
std::unique_ptr<Instruction> typeInst;
+ // TODO(1841): Handle id overflow.
id = context()->TakeNextId();
RegisterType(id, *type);
switch (type->kind()) {
@@ -397,6 +398,7 @@
}
// Must create the pointer type.
+ // TODO(1841): Handle id overflow.
uint32_t resultId = context()->TakeNextId();
std::unique_ptr<Instruction> type_inst(
new Instruction(context(), SpvOpTypePointer, 0, resultId,
diff --git a/test/opt/ir_builder.cpp b/test/opt/ir_builder.cpp
index 3fd792e..4c3b9b4 100644
--- a/test/opt/ir_builder.cpp
+++ b/test/opt/ir_builder.cpp
@@ -198,6 +198,7 @@
BasicBlock& bb_merge = *fn.begin();
+ // TODO(1841): Handle id overflow.
fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
@@ -207,6 +208,7 @@
builder.AddBranch(bb_merge.id());
}
+ // TODO(1841): Handle id overflow.
fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
diff --git a/test/opt/ir_context_test.cpp b/test/opt/ir_context_test.cpp
index f66b16e..4e2f5b2 100644
--- a/test/opt/ir_context_test.cpp
+++ b/test/opt/ir_context_test.cpp
@@ -567,6 +567,103 @@
EXPECT_THAT(processed, UnorderedElementsAre(10));
}
+TEST_F(IRContextTest, IdBoundTestAtLimit) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ uint32_t current_bound = context->module()->id_bound();
+ context->set_max_id_bound(current_bound);
+ uint32_t next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, 0);
+ EXPECT_EQ(current_bound, context->module()->id_bound());
+ next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, 0);
+}
+
+TEST_F(IRContextTest, IdBoundTestBelowLimit) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ uint32_t current_bound = context->module()->id_bound();
+ context->set_max_id_bound(current_bound + 100);
+ uint32_t next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, current_bound);
+ EXPECT_EQ(current_bound + 1, context->module()->id_bound());
+ next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, current_bound + 1);
+}
+
+TEST_F(IRContextTest, IdBoundTestNearLimit) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4 = OpLabel
+OpReturn
+OpFunctionEnd)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ uint32_t current_bound = context->module()->id_bound();
+ context->set_max_id_bound(current_bound + 1);
+ uint32_t next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, current_bound);
+ EXPECT_EQ(current_bound + 1, context->module()->id_bound());
+ next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, 0);
+}
+
+TEST_F(IRContextTest, IdBoundTestUIntMax) {
+ const std::string text = R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%1 = OpTypeVoid
+%2 = OpTypeFunction %1
+%3 = OpFunction %1 None %2
+%4294967294 = OpLabel ; ID is UINT_MAX-1
+OpReturn
+OpFunctionEnd)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ uint32_t current_bound = context->module()->id_bound();
+
+ // Expecting |BuildModule| to preserve the numeric ids.
+ EXPECT_EQ(current_bound, std::numeric_limits<uint32_t>::max());
+
+ context->set_max_id_bound(current_bound);
+ uint32_t next_id_bound = context->TakeNextId();
+ EXPECT_EQ(next_id_bound, 0);
+ EXPECT_EQ(current_bound, context->module()->id_bound());
+}
} // namespace
} // namespace opt
} // namespace spvtools
diff --git a/test/opt/loop_optimizations/hoist_without_preheader.cpp b/test/opt/loop_optimizations/hoist_without_preheader.cpp
index 9a14996..2e34b01 100644
--- a/test/opt/loop_optimizations/hoist_without_preheader.cpp
+++ b/test/opt/loop_optimizations/hoist_without_preheader.cpp
@@ -117,6 +117,81 @@
SinglePassRunAndMatch<LICMPass>(text, false);
}
+TEST_F(PassClassTest, HoistWithoutPreheaderAtIdBound) {
+ const std::string text = R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 440
+OpName %main "main"
+%void = OpTypeVoid
+%4 = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_1 = OpConstant %int 1
+%int_2 = OpConstant %int 2
+%int_0 = OpConstant %int 0
+%int_10 = OpConstant %int 10
+%bool = OpTypeBool
+%int_5 = OpConstant %int 5
+%main = OpFunction %void None %4
+%13 = OpLabel
+OpBranch %14
+%14 = OpLabel
+%15 = OpPhi %int %int_0 %13 %16 %17
+OpLoopMerge %25 %17 None
+OpBranch %19
+%19 = OpLabel
+%20 = OpSLessThan %bool %15 %int_10
+OpBranchConditional %20 %21 %25
+%21 = OpLabel
+%22 = OpIEqual %bool %15 %int_5
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %23
+%24 = OpLabel
+OpBranch %25
+%23 = OpLabel
+OpBranch %17
+%17 = OpLabel
+%16 = OpIAdd %int %15 %int_1
+OpBranch %14
+%25 = OpLabel
+%26 = OpPhi %int %int_0 %24 %int_0 %19 %27 %28
+%29 = OpPhi %int %int_0 %24 %int_0 %19 %30 %28
+OpLoopMerge %31 %28 None
+OpBranch %32
+%32 = OpLabel
+%33 = OpSLessThan %bool %29 %int_10
+OpBranchConditional %33 %34 %31
+%34 = OpLabel
+%27 = OpIAdd %int %int_1 %int_2
+OpBranch %28
+%28 = OpLabel
+%30 = OpIAdd %int %29 %int_1
+OpBranch %25
+%31 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ std::unique_ptr<IRContext> context =
+ BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ uint32_t current_bound = context->module()->id_bound();
+ context->set_max_id_bound(current_bound);
+
+ auto pass = MakeUnique<LICMPass>();
+ auto result = pass->Run(context.get());
+ EXPECT_EQ(result, Pass::Status::Failure);
+
+ std::vector<uint32_t> binary;
+ context->module()->ToBinary(&binary, false);
+ std::string optimized_asm;
+ SpirvTools tools_(SPV_ENV_UNIVERSAL_1_1);
+ tools_.Disassemble(binary, &optimized_asm);
+ std::cout << optimized_asm << std::endl;
+}
} // namespace
} // namespace opt
} // namespace spvtools