| // Copyright (c) 2018 Google LLC. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "loop_unswitch_pass.h" |
| |
| #include <functional> |
| #include <list> |
| #include <memory> |
| #include <type_traits> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| #include "basic_block.h" |
| #include "dominator_tree.h" |
| #include "fold.h" |
| #include "function.h" |
| #include "instruction.h" |
| #include "ir_builder.h" |
| #include "ir_context.h" |
| #include "loop_descriptor.h" |
| |
| #include "loop_utils.h" |
| |
| namespace spvtools { |
| namespace opt { |
| namespace { |
| |
// Input operand index of the storage class in an OpTypePointer instruction.
static const uint32_t kTypePointerStorageClassInIdx = 0;
// Input operand indices of the true and false target labels in an
// OpBranchConditional instruction.
static const uint32_t kBranchCondTrueLabIdInIdx = 1;
static const uint32_t kBranchCondFalseLabIdInIdx = 2;
| |
| } // anonymous namespace |
| |
| namespace { |
| |
// This class handles the unswitch procedure for a given loop.
// The unswitch will not happen if:
//  - The loop has any instruction that will prevent it;
//  - The loop invariant condition is not uniform.
class LoopUnswitch {
 public:
  // |loop| is the candidate loop inside |function|; |loop_desc| is the loop
  // descriptor of |function| and is kept in sync by the transformation.
  LoopUnswitch(ir::IRContext* context, ir::Function* function, ir::Loop* loop,
               ir::LoopDescriptor* loop_desc)
      : function_(function),
        loop_(loop),
        loop_desc_(*loop_desc),
        context_(context),
        switch_block_(nullptr) {}
| |
| // Returns true if the loop can be unswitched. |
| // Can be unswitch if: |
| // - The loop has no instructions that prevents it (such as barrier); |
| // - The loop has one conditional branch or switch that do not depends on the |
| // loop; |
| // - The loop invariant condition is uniform; |
| bool CanUnswitchLoop() { |
| if (switch_block_) return true; |
| if (loop_->IsSafeToClone()) return false; |
| |
| ir::CFG& cfg = *context_->cfg(); |
| |
| for (uint32_t bb_id : loop_->GetBlocks()) { |
| ir::BasicBlock* bb = cfg.block(bb_id); |
| if (bb->terminator()->IsBranch() && |
| bb->terminator()->opcode() != SpvOpBranch) { |
| if (IsConditionLoopInvariant(bb->terminator())) { |
| switch_block_ = bb; |
| break; |
| } |
| } |
| } |
| |
| return switch_block_; |
| } |
| |
| // Return the iterator to the basic block |bb|. |
| ir::Function::iterator FindBasicBlockPosition(ir::BasicBlock* bb_to_find) { |
| ir::Function::iterator it = function_->FindBlock(bb_to_find->id()); |
| assert(it != function_->end() && "Basic Block not found"); |
| return it; |
| } |
| |
| // Creates a new basic block and insert it into the function |fn| at the |
| // position |ip|. This function preserves the def/use and instr to block |
| // managers. |
| ir::BasicBlock* CreateBasicBlock(ir::Function::iterator ip) { |
| analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); |
| |
| ir::BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<ir::BasicBlock>( |
| new ir::BasicBlock(std::unique_ptr<ir::Instruction>(new ir::Instruction( |
| context_, SpvOpLabel, 0, context_->TakeNextId(), {}))))); |
| bb->SetParent(function_); |
| def_use_mgr->AnalyzeInstDef(bb->GetLabelInst()); |
| context_->set_instr_block(bb->GetLabelInst(), bb); |
| |
| return bb; |
| } |
| |
  // Unswitches |loop_|: hoists the loop-invariant conditional branch/switch
  // found by CanUnswitchLoop() above the loop, then creates one specialized
  // clone of the loop per non-default branch target. The original loop is
  // kept and specialized for the remaining (true/default) target.
  // Precondition: the loop has a pre-header and is in LCSSA form.
  void PerformUnswitch() {
    assert(CanUnswitchLoop() &&
           "Cannot unswitch if there is not constant condition");
    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");

    ir::CFG& cfg = *context_->cfg();
    DominatorTree* dom_tree =
        &context_->GetDominatorAnalysis(function_)->GetDomTree();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    LoopUtils loop_utils(context_, loop_);

    //////////////////////////////////////////////////////////////////////////////
    // Step 1: Create the if merge block for structured modules.
    //    To do so, the |loop_| merge block will become the if's one and we
    //    create a merge for the loop. This will limit the amount of duplicated
    //    code the structured control flow imposes.
    //    For non structured program, the new loop will be connected to
    //    the old loop's exit blocks.
    //////////////////////////////////////////////////////////////////////////////

    // Get the merge block if it exists.
    ir::BasicBlock* if_merge_block = loop_->GetMergeBlock();
    // The merge block is only created if the loop has a unique exit block. We
    // have this guarantee for structured loops, for compute loop it will
    // trivially help maintain both a structured-like form and LCSAA.
    ir::BasicBlock* loop_merge_block =
        if_merge_block
            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
            : nullptr;
    if (loop_merge_block) {
      // Add the instruction and update managers.
      opt::InstructionBuilder builder(
          context_, loop_merge_block,
          ir::IRContext::kAnalysisDefUse |
              ir::IRContext::kAnalysisInstrToBlockMapping);
      builder.AddBranch(if_merge_block->id());
      builder.SetInsertPoint(&*loop_merge_block->begin());
      cfg.RegisterBlock(loop_merge_block);
      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
      // Update CFG.
      // Move each phi of the old merge block into the new loop merge block,
      // leaving behind a single-predecessor phi that reads the clone.
      if_merge_block->ForEachPhiInst(
          [loop_merge_block, &builder, this](ir::Instruction* phi) {
            ir::Instruction* cloned = phi->Clone(context_);
            builder.AddInstruction(std::unique_ptr<ir::Instruction>(cloned));
            phi->SetInOperand(0, {cloned->result_id()});
            phi->SetInOperand(1, {loop_merge_block->id()});
            // Drop the now-redundant extra incoming pairs (operands 2..N).
            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
              phi->RemoveInOperand(j);
          });
      // Copy the predecessor list (will get invalidated otherwise).
      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
      for (uint32_t pid : preds) {
        if (pid == loop_merge_block->id()) continue;
        ir::BasicBlock* p_bb = cfg.block(pid);
        // Retarget every branch to the old merge block onto the new one.
        p_bb->ForEachSuccessorLabel(
            [if_merge_block, loop_merge_block](uint32_t* id) {
              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
            });
        cfg.AddEdge(pid, loop_merge_block->id());
      }
      cfg.RemoveNonExistingEdges(if_merge_block->id());
      // Update loop descriptor.
      if (ir::Loop* ploop = loop_->GetParent()) {
        ploop->AddBasicBlock(loop_merge_block);
        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
      }

      // Update the dominator tree: the new block takes the old merge block's
      // place, and the old merge block becomes its child.
      DominatorTreeNode* loop_merge_dtn =
          dom_tree->GetOrInsertNode(loop_merge_block);
      DominatorTreeNode* if_merge_block_dtn =
          dom_tree->GetOrInsertNode(if_merge_block);
      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
      if_merge_block_dtn->parent_->children_.erase(std::find(
          if_merge_block_dtn->parent_->children_.begin(),
          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));

      loop_->SetMergeBlock(loop_merge_block);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Step 2: Build a new preheader for |loop_|, use the old one
    //         for the constant branch.
    ////////////////////////////////////////////////////////////////////////////

    ir::BasicBlock* if_block = loop_->GetPreHeaderBlock();
    // If this preheader is the parent loop header,
    // we need to create a dedicated block for the if.
    ir::BasicBlock* loop_pre_header =
        CreateBasicBlock(++FindBasicBlockPosition(if_block));
    opt::InstructionBuilder(context_, loop_pre_header,
                            ir::IRContext::kAnalysisDefUse |
                                ir::IRContext::kAnalysisInstrToBlockMapping)
        .AddBranch(loop_->GetHeaderBlock()->id());

    // The old preheader now branches to the new one.
    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});

    // Update loop descriptor.
    if (ir::Loop* ploop = loop_desc_[if_block]) {
      ploop->AddBasicBlock(loop_pre_header);
      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
    }

    // Update the CFG.
    cfg.RegisterBlock(loop_pre_header);
    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
    cfg.AddEdge(if_block->id(), loop_pre_header->id());
    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());

    // Redirect the header's phi incoming edges from the old preheader onto
    // the new one.
    loop_->GetHeaderBlock()->ForEachPhiInst(
        [loop_pre_header, if_block](ir::Instruction* phi) {
          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
            if (*id == if_block->id()) {
              *id = loop_pre_header->id();
            }
          });
        });
    loop_->SetPreHeaderBlock(loop_pre_header);

    // Update the dominator tree: the new preheader is spliced between the old
    // preheader and the loop header.
    DominatorTreeNode* loop_pre_header_dtn =
        dom_tree->GetOrInsertNode(loop_pre_header);
    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
    loop_pre_header_dtn->parent_ = if_block_dtn;
    assert(
        if_block_dtn->children_.size() == 1 &&
        "A loop preheader should only have the header block as a child in the "
        "dominator tree");
    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
    if_block_dtn->children_.clear();
    if_block_dtn->children_.push_back(loop_pre_header_dtn);

    // Make domination queries valid.
    dom_tree->ResetDFNumbering();

    // Compute an ordered list of basic block to clone: loop blocks + pre-header
    // + merge block.
    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);

    /////////////////////////////
    // Do the actual unswitch: //
    //   - Clone the loop      //
    //   - Connect exits       //
    //   - Specialize the loop //
    /////////////////////////////

    ir::Instruction* iv_condition = &*switch_block_->tail();
    SpvOp iv_opcode = iv_condition->opcode();
    ir::Instruction* condition =
        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);

    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
    const analysis::Type* cond_type =
        context_->get_type_mgr()->GetType(condition->type_id());

    // Build the list of value for which we need to clone and specialize the
    // loop.
    std::vector<std::pair<ir::Instruction*, ir::BasicBlock*>> constant_branch;
    // Special case for the original loop
    ir::Instruction* original_loop_constant_value;
    ir::BasicBlock* original_loop_target;
    if (iv_opcode == SpvOpBranchConditional) {
      // One clone specialized for "false" ({0}); the original loop is kept
      // for the "true" ({1}) branch.
      constant_branch.emplace_back(
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
          nullptr);
      original_loop_constant_value =
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
    } else {
      // We are looking to take the default branch, so we can't provide a
      // specific value.
      original_loop_constant_value = nullptr;
      // OpSwitch in-operands: selector (0), default label (1), then
      // literal/label pairs starting at index 2 — one clone per case literal.
      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
        constant_branch.emplace_back(
            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
                cond_type, iv_condition->GetInOperand(i).words)),
            nullptr);
      }
    }

    // Get the loop landing pads.
    std::unordered_set<uint32_t> if_merging_blocks;
    std::function<bool(uint32_t)> is_from_original_loop;
    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
      // Structured loop: the unique merge block is the only landing pad.
      if_merging_blocks.insert(if_merge_block->id());
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
      };
    } else {
      // Unstructured loop: every exit block is a landing pad.
      loop_->GetExitBlocks(&if_merging_blocks);
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id);
      };
    }

    for (auto& specialisation_pair : constant_branch) {
      ir::Instruction* specialisation_value = specialisation_pair.first;
      //////////////////////////////////////////////////////////
      // Step 3: Duplicate |loop_|.
      //////////////////////////////////////////////////////////
      LoopUtils::LoopCloningResult clone_result;

      ir::Loop* cloned_loop =
          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();

      ////////////////////////////////////
      // Step 4: Specialize the loop.   //
      ////////////////////////////////////

      {
        std::unordered_set<uint32_t> dead_blocks;
        std::unordered_set<uint32_t> unreachable_merges;
        SimplifyLoop(
            ir::make_range(
                ir::UptrVectorIterator<ir::BasicBlock>(
                    &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()),
                ir::UptrVectorIterator<ir::BasicBlock>(
                    &clone_result.cloned_bb_, clone_result.cloned_bb_.end())),
            cloned_loop, condition, specialisation_value, &dead_blocks);

        // We tagged dead blocks, create the loop before we invalidate any basic
        // block.
        cloned_loop =
            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
        CleanUpCFG(
            ir::UptrVectorIterator<ir::BasicBlock>(
                &clone_result.cloned_bb_, clone_result.cloned_bb_.begin()),
            dead_blocks, unreachable_merges);

        ///////////////////////////////////////////////////////////
        // Step 5: Connect convergent edges to the landing pads. //
        ///////////////////////////////////////////////////////////

        for (uint32_t merge_bb_id : if_merging_blocks) {
          ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id);
          // We are in LCSSA so we only care about phi instructions.
          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
                                 &clone_result](ir::Instruction* phi) {
            uint32_t num_in_operands = phi->NumInOperands();
            for (uint32_t i = 0; i < num_in_operands; i += 2) {
              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
              if (is_from_original_loop(pred)) {
                // Mirror the incoming edge with its cloned counterpart, when
                // the cloned predecessor survived specialization.
                pred = clone_result.value_map_.at(pred);
                if (!dead_blocks.count(pred)) {
                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
                  // Not all the incoming value are coming from the loop.
                  ValueMapTy::iterator new_value =
                      clone_result.value_map_.find(incoming_value_id);
                  if (new_value != clone_result.value_map_.end()) {
                    incoming_value_id = new_value->second;
                  }
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
                }
              }
            }
          });
        }
      }
      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
                                clone_result.cloned_bb_.end(),
                                ++FindBasicBlockPosition(if_block));
    }

    // Same as above but specialize the existing loop
    {
      std::unordered_set<uint32_t> dead_blocks;
      std::unordered_set<uint32_t> unreachable_merges;
      SimplifyLoop(ir::make_range(function_->begin(), function_->end()), loop_,
                   condition, original_loop_constant_value, &dead_blocks);

      for (uint32_t merge_bb_id : if_merging_blocks) {
        ir::BasicBlock* merge = context_->cfg()->block(merge_bb_id);
        // LCSSA, so we only care about phi instructions.
        // If we the phi is reduced to a single incoming branch, do not
        // propagate it to preserve LCSSA.
        PatchPhis(merge, dead_blocks, true);
      }
      // If every predecessor of the if merge block died, mark it unreachable.
      if (if_merge_block) {
        bool has_live_pred = false;
        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
          if (!dead_blocks.count(pid)) {
            has_live_pred = true;
            break;
          }
        }
        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
      }
      original_loop_target = loop_->GetPreHeaderBlock();
      // We tagged dead blocks, prune the loop descriptor from any dead loops.
      // After this call, |loop_| can be nullptr (i.e. the unswitch killed this
      // loop).
      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);

      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
    }

    /////////////////////////////////////
    // Finally: connect the new loops. //
    /////////////////////////////////////

    // Delete the old jump
    context_->KillInst(&*if_block->tail());
    opt::InstructionBuilder builder(context_, if_block);
    if (iv_opcode == SpvOpBranchConditional) {
      assert(constant_branch.size() == 1);
      builder.AddConditionalBranch(
          condition->result_id(), original_loop_target->id(),
          constant_branch[0].second->id(),
          if_merge_block ? if_merge_block->id() : kInvalidId);
    } else {
      // Rebuild the switch: same case literals, but each case now targets its
      // specialized clone; the default targets the original loop.
      std::vector<std::pair<std::vector<uint32_t>, uint32_t>> targets;
      for (auto& t : constant_branch) {
        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
      }

      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
                        targets,
                        if_merge_block ? if_merge_block->id() : kInvalidId);
    }

    switch_block_ = nullptr;
    ordered_loop_blocks_.clear();

    // Only the loop analysis is kept in sync by this pass; everything else
    // must be recomputed.
    context_->InvalidateAnalysesExceptFor(
        ir::IRContext::Analysis::kAnalysisLoopAnalysis);
  }
| |
  // Returns true if the unswitch killed the original |loop_|.
  bool WasLoopKilled() const { return loop_ == nullptr; }

 private:
  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  using BlockMapTy = std::unordered_map<uint32_t, ir::BasicBlock*>;

  // Function containing the loop being unswitched.
  ir::Function* function_;
  // The loop being unswitched; set to nullptr if the unswitch kills it.
  ir::Loop* loop_;
  // Loop descriptor of |function_|, kept in sync by the transformation.
  ir::LoopDescriptor& loop_desc_;
  ir::IRContext* context_;

  // Block holding the loop-invariant conditional branch/switch to hoist
  // (found by CanUnswitchLoop, consumed by PerformUnswitch).
  ir::BasicBlock* switch_block_;
  // Map between instructions and if they are dynamically uniform.
  std::unordered_map<uint32_t, bool> dynamically_uniform_;
  // The loop basic blocks in structured order.
  std::vector<ir::BasicBlock*> ordered_loop_blocks_;

  // Returns the next usable id for the context.
  uint32_t TakeNextId() { return context_->TakeNextId(); }
| |
  // Patches |bb|'s phi instruction by removing incoming value from unexisting
  // or tagged as dead branches.
  // If |preserve_phi| is false, a phi reduced to a single incoming edge is
  // replaced by its remaining value and killed.
  void PatchPhis(ir::BasicBlock* bb,
                 const std::unordered_set<uint32_t>& dead_blocks,
                 bool preserve_phi) {
    ir::CFG& cfg = *context_->cfg();

    std::vector<ir::Instruction*> phi_to_kill;
    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
    // An incoming branch is dead if its source block is tagged dead or is no
    // longer a CFG predecessor of |bb|.
    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
      return dead_blocks.count(id) ||
             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
    };
    bb->ForEachPhiInst([&phi_to_kill, &is_branch_dead, preserve_phi,
                        this](ir::Instruction* insn) {
      // Phi in-operands come in (value, predecessor-label) pairs; walk them
      // and drop every pair whose predecessor is dead. |i| only advances when
      // the pair is kept.
      uint32_t i = 0;
      while (i < insn->NumInOperands()) {
        uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
        if (is_branch_dead(incoming_id)) {
          // Remove the incoming block id operand.
          insn->RemoveInOperand(i + 1);
          // Remove the definition id operand.
          insn->RemoveInOperand(i);
          continue;
        }
        i += 2;
      }
      // If there is only 1 remaining edge, propagate the value and
      // kill the instruction.
      if (insn->NumInOperands() == 2 && !preserve_phi) {
        phi_to_kill.push_back(insn);
        context_->ReplaceAllUsesWith(insn->result_id(),
                                     insn->GetSingleWordInOperand(0));
      }
    });
    // Kill the collected phis after the traversal so the iteration above is
    // not invalidated.
    for (ir::Instruction* insn : phi_to_kill) {
      context_->KillInst(insn);
    }
  }
| |
  // Removes any block that is tagged as dead, if the block is in
  // |unreachable_merges| then all block's instructions are replaced by a
  // OpUnreachable. |bb_it| iterates over the candidate blocks.
  void CleanUpCFG(ir::UptrVectorIterator<ir::BasicBlock> bb_it,
                  const std::unordered_set<uint32_t>& dead_blocks,
                  const std::unordered_set<uint32_t>& unreachable_merges) {
    ir::CFG& cfg = *context_->cfg();

    while (bb_it != bb_it.End()) {
      ir::BasicBlock& bb = *bb_it;

      if (unreachable_merges.count(bb.id())) {
        // Required merge block of a live loop: it must stay, but becomes a
        // label + OpUnreachable (skip if it already has that shape).
        if (bb.begin() != bb.tail() ||
            bb.terminator()->opcode() != SpvOpUnreachable) {
          // Make unreachable, but leave the label.
          bb.KillAllInsts(false);
          opt::InstructionBuilder(context_, &bb).AddUnreachable();
          cfg.RemoveNonExistingEdges(bb.id());
        }
        ++bb_it;
      } else if (dead_blocks.count(bb.id())) {
        cfg.ForgetBlock(&bb);
        // Kill this block.
        bb.KillAllInsts(true);
        bb_it = bb_it.Erase();
      } else {
        // Live block: just drop CFG edges whose source no longer exists.
        cfg.RemoveNonExistingEdges(bb.id());
        ++bb_it;
      }
    }
  }
| |
| // Return true if |c_inst| is a Boolean constant and set |cond_val| with the |
| // value that |c_inst| |
| bool GetConstCondition(const ir::Instruction* c_inst, bool* cond_val) { |
| bool cond_is_const; |
| switch (c_inst->opcode()) { |
| case SpvOpConstantFalse: { |
| *cond_val = false; |
| cond_is_const = true; |
| } break; |
| case SpvOpConstantTrue: { |
| *cond_val = true; |
| cond_is_const = true; |
| } break; |
| default: { cond_is_const = false; } break; |
| } |
| return cond_is_const; |
| } |
| |
  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
  // value |cst_value|. |block_range| is an iterator range returning the loop
  // basic blocks in a structured order (dominator first).
  // The function will ignore basic blocks returned by |block_range| if they
  // does not belong to the loop.
  // The set |dead_blocks| will contain all the dead basic blocks.
  //
  // Requirements:
  //   - |loop| must be in the LCSSA form;
  //   - |cst_value| must be constant or null (to represent the default target
  //   of an OpSwitch).
  void SimplifyLoop(
      ir::IteratorRange<ir::UptrVectorIterator<ir::BasicBlock>> block_range,
      ir::Loop* loop, ir::Instruction* to_version_insn,
      ir::Instruction* cst_value, std::unordered_set<uint32_t>* dead_blocks) {
    ir::CFG& cfg = *context_->cfg();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    std::function<bool(uint32_t)> ignore_node;
    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };

    // Collect the in-loop uses of |to_version_insn|; the specialization only
    // applies to those.
    std::vector<std::pair<ir::Instruction*, uint32_t>> use_list;
    def_use_mgr->ForEachUse(
        to_version_insn, [&use_list, &ignore_node, this](
                             ir::Instruction* inst, uint32_t operand_index) {
          ir::BasicBlock* bb = context_->get_instr_block(inst);

          if (!bb || ignore_node(bb->id())) {
            // Out of the loop, the specialization does not apply any more.
            return;
          }
          use_list.emplace_back(inst, operand_index);
        });

    // First pass: inject the specialized value into the loop (and only the
    // loop).
    for (auto use : use_list) {
      ir::Instruction* inst = use.first;
      uint32_t operand_index = use.second;
      ir::BasicBlock* bb = context_->get_instr_block(inst);

      // If it is not a branch, simply inject the value.
      if (!inst->IsBranch()) {
        // To also handle switch, cst_value can be nullptr: this case
        // means that we are looking to branch to the default target of
        // the switch. We don't actually know its value so we don't touch
        // it if it not a switch.
        if (cst_value) {
          inst->SetOperand(operand_index, {cst_value->result_id()});
          def_use_mgr->AnalyzeInstUse(inst);
        }
      }

      // The user is a branch, kill dead branches.
      uint32_t live_target = 0;
      std::unordered_set<uint32_t> dead_branches;
      switch (inst->opcode()) {
        case SpvOpBranchConditional: {
          assert(cst_value && "No constant value to specialize !");
          bool branch_cond = false;
          if (GetConstCondition(cst_value, &branch_cond)) {
            uint32_t true_label =
                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
            uint32_t false_label =
                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
            live_target = branch_cond ? true_label : false_label;
            uint32_t dead_target = !branch_cond ? true_label : false_label;
            cfg.RemoveEdge(bb->id(), dead_target);
          }
          break;
        }
        case SpvOpSwitch: {
          // Start with the default target (in-operand 1); a case literal
          // matching |cst_value| overrides it.
          live_target = inst->GetSingleWordInOperand(1);
          if (cst_value) {
            if (!cst_value->IsConstant()) break;
            const ir::Operand& cst = cst_value->GetInOperand(0);
            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
              const ir::Operand& literal = inst->GetInOperand(i);
              if (literal == cst) {
                live_target = inst->GetSingleWordInOperand(i + 1);
                break;
              }
            }
          }
          // Remove CFG edges to every non-selected target (labels are at odd
          // in-operand indices). NOTE: this case intentionally falls through
          // to |default|, which only breaks.
          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
            uint32_t id = inst->GetSingleWordInOperand(i);
            if (id != live_target) {
              cfg.RemoveEdge(bb->id(), id);
            }
          }
        }
        default:
          break;
      }
      if (live_target != 0) {
        // Check for the presence of the merge block.
        if (ir::Instruction* merge = bb->GetMergeInst())
          context_->KillInst(merge);
        // Replace the conditional terminator by a direct branch to the only
        // live target.
        context_->KillInst(&*bb->tail());
        opt::InstructionBuilder builder(
            context_, bb,
            ir::IRContext::kAnalysisDefUse |
                ir::IRContext::kAnalysisInstrToBlockMapping);
        builder.AddBranch(live_target);
      }
    }

    // Go through the loop basic block and tag all blocks that are obviously
    // dead.
    std::unordered_set<uint32_t> visited;
    for (ir::BasicBlock& bb : block_range) {
      if (ignore_node(bb.id())) continue;
      visited.insert(bb.id());

      // Check if this block is dead, if so tag it as dead otherwise patch phi
      // instructions.
      bool has_live_pred = false;
      for (uint32_t pid : cfg.preds(bb.id())) {
        if (!dead_blocks->count(pid)) {
          has_live_pred = true;
          break;
        }
      }
      if (!has_live_pred) {
        dead_blocks->insert(bb.id());
        const ir::BasicBlock& cbb = bb;
        // Patch the phis for any back-edge.
        cbb.ForEachSuccessorLabel(
            [dead_blocks, &visited, &cfg, this](uint32_t id) {
              // Only already-visited, still-live successors can hold a
              // back-edge phi referencing this dead block.
              if (!visited.count(id) || dead_blocks->count(id)) return;
              ir::BasicBlock* succ = cfg.block(id);
              PatchPhis(succ, *dead_blocks, false);
            });
        continue;
      }
      // Update the phi instructions, some incoming branch have/will disappear.
      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
    }
  }
| |
| // Returns true if the header is not reachable or tagged as dead or if we |
| // never loop back. |
| bool IsLoopDead(ir::BasicBlock* header, ir::BasicBlock* latch, |
| const std::unordered_set<uint32_t>& dead_blocks) { |
| if (!header || dead_blocks.count(header->id())) return true; |
| if (!latch || dead_blocks.count(latch->id())) return true; |
| for (uint32_t pid : context_->cfg()->preds(header->id())) { |
| if (!dead_blocks.count(pid)) { |
| // Seems reachable. |
| return false; |
| } |
| } |
| return true; |
| } |
| |
  // Cleans the loop nest under |loop| and reflect changes to the loop
  // descriptor. This will kill all descriptors that represent dead loops.
  // If |loop_| is killed, it will be set to nullptr.
  // Any merge blocks that become unreachable will be added to
  // |unreachable_merges|.
  // The function returns the pointer to |loop| or nullptr if the loop was
  // killed.
  ir::Loop* CleanLoopNest(ir::Loop* loop,
                          const std::unordered_set<uint32_t>& dead_blocks,
                          std::unordered_set<uint32_t>* unreachable_merges) {
    // This represent the pair of dead loop and nearest alive parent (nullptr if
    // no parent).
    std::unordered_map<ir::Loop*, ir::Loop*> dead_loops;
    // Nearest alive ancestor of |l|, or nullptr if |l| was not recorded.
    auto get_parent = [&dead_loops](ir::Loop* l) -> ir::Loop* {
      std::unordered_map<ir::Loop*, ir::Loop*>::iterator it =
          dead_loops.find(l);
      if (it != dead_loops.end()) return it->second;
      return nullptr;
    };

    bool is_main_loop_dead =
        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
    if (is_main_loop_dead) {
      if (ir::Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
        context_->KillInst(merge);
      }
      dead_loops[loop] = loop->GetParent();
    } else
      // Temporary self-mapping so get_parent resolves children of a live
      // |loop| to |loop| itself; erased again below.
      dead_loops[loop] = loop;
    // For each loop, check if we killed it. If we did, find a suitable parent
    // for its children.
    for (ir::Loop& sub_loop :
         ir::make_range(++opt::TreeDFIterator<ir::Loop>(loop),
                        opt::TreeDFIterator<ir::Loop>())) {
      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
                     dead_blocks)) {
        if (ir::Instruction* merge =
                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
          context_->KillInst(merge);
        }
        dead_loops[&sub_loop] = get_parent(&sub_loop);
      } else {
        // The loop is alive, check if its merge block is dead, if it is, tag it
        // as required.
        if (sub_loop.GetMergeBlock()) {
          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
          if (dead_blocks.count(merge_id)) {
            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
          }
        }
      }
    }
    if (!is_main_loop_dead) dead_loops.erase(loop);

    // Remove dead blocks from live loops.
    for (uint32_t bb_id : dead_blocks) {
      ir::Loop* l = loop_desc_[bb_id];
      if (l) {
        l->RemoveBasicBlock(bb_id);
        loop_desc_.ForgetBasicBlock(bb_id);
      }
    }

    // Drop every dead loop from the descriptor; clear |loop| if it was one of
    // them.
    std::for_each(
        dead_loops.begin(), dead_loops.end(),
        [&loop, this](
            std::unordered_map<ir::Loop*, ir::Loop*>::iterator::reference it) {
          if (it.first == loop) loop = nullptr;
          loop_desc_.RemoveLoop(it.first);
        });

    return loop;
  }
| |
  // Returns true if |var| is dynamically uniform.
  // Note: this is currently approximated as uniform.
  bool IsDynamicallyUniform(ir::Instruction* var, const ir::BasicBlock* entry,
                            const DominatorTree& post_dom_tree) {
    assert(post_dom_tree.IsPostDominator());
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    // Return the memoized answer if |var| was already analyzed.
    auto it = dynamically_uniform_.find(var->result_id());

    if (it != dynamically_uniform_.end()) return it->second;

    analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();

    // Cache slot for |var|; unordered_map references remain valid across the
    // recursive insertions below, so writing through it is safe.
    bool& is_uniform = dynamically_uniform_[var->result_id()];
    is_uniform = false;

    // An explicit Uniform decoration settles the question.
    dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
                                 [&is_uniform](const ir::Instruction&) {
                                   is_uniform = true;
                                   return false;
                                 });
    if (is_uniform) {
      return is_uniform;
    }

    ir::BasicBlock* parent = context_->get_instr_block(var);
    if (!parent) {
      // No owning block: module-level instruction, treated as uniform.
      return is_uniform = true;
    }

    // |var| must be guaranteed to execute (its block post-dominates the
    // entry) for its value to be uniform across invocations.
    if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
      return is_uniform = false;
    }
    if (var->opcode() == SpvOpLoad) {
      // Loads are only uniform from Uniform/UniformConstant storage classes.
      const uint32_t PtrTypeId =
          def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
      const ir::Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
      uint32_t storage_class =
          PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
      if (storage_class != SpvStorageClassUniform &&
          storage_class != SpvStorageClassUniformConstant) {
        return is_uniform = false;
      }
    } else {
      if (!context_->IsCombinatorInstruction(var)) {
        return is_uniform = false;
      }
    }

    // Otherwise |var| is uniform iff all of its input operands are.
    return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
                                            this](const uint32_t* id) {
      return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
                                  entry, post_dom_tree);
    });
  }
| |
| // Returns true if |insn| is constant and dynamically uniform within the loop. |
| bool IsConditionLoopInvariant(ir::Instruction* insn) { |
| assert(insn->IsBranch()); |
| assert(insn->opcode() != SpvOpBranch); |
| analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr(); |
| |
| ir::Instruction* condition = |
| def_use_mgr->GetDef(insn->GetOperand(0).words[0]); |
| return !loop_->IsInsideLoop(condition) && |
| IsDynamicallyUniform( |
| condition, function_->entry().get(), |
| context_->GetPostDominatorAnalysis(function_)->GetDomTree()); |
| } |
| }; |
| |
| } // namespace |
| |
| Pass::Status LoopUnswitchPass::Process(ir::IRContext* c) { |
| InitializeProcessing(c); |
| |
| bool modified = false; |
| ir::Module* module = c->module(); |
| |
| // Process each function in the module |
| for (ir::Function& f : *module) { |
| modified |= ProcessFunction(&f); |
| } |
| |
| return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; |
| } |
| |
| bool LoopUnswitchPass::ProcessFunction(ir::Function* f) { |
| bool modified = false; |
| std::unordered_set<ir::Loop*> processed_loop; |
| |
| ir::LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f); |
| |
| bool loop_changed = true; |
| while (loop_changed) { |
| loop_changed = false; |
| for (ir::Loop& loop : |
| ir::make_range(++opt::TreeDFIterator<ir::Loop>( |
| loop_descriptor.GetDummyRootLoop()), |
| opt::TreeDFIterator<ir::Loop>())) { |
| if (processed_loop.count(&loop)) continue; |
| processed_loop.insert(&loop); |
| |
| LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor); |
| while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) { |
| if (!loop.IsLCSSA()) { |
| LoopUtils(context(), &loop).MakeLoopClosedSSA(); |
| } |
| modified = true; |
| loop_changed = true; |
| unswitcher.PerformUnswitch(); |
| } |
| if (loop_changed) break; |
| } |
| } |
| |
| return modified; |
| } |
| |
| } // namespace opt |
| } // namespace spvtools |