1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/loop_descriptor.h"
16 
#include <algorithm>
#include <limits>
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>
22 
23 #include "source/opt/cfg.h"
24 #include "source/opt/constants.h"
25 #include "source/opt/dominator_tree.h"
26 #include "source/opt/ir_context.h"
27 #include "source/opt/iterator.h"
28 #include "source/opt/tree_iterator.h"
29 #include "source/util/make_unique.h"
30 
31 namespace spvtools {
32 namespace opt {
33 
34 // Takes in a phi instruction |induction| and the loop |header| and returns the
35 // step operation of the loop.
GetInductionStepOperation(const Instruction * induction) const36 Instruction* Loop::GetInductionStepOperation(
37     const Instruction* induction) const {
38   // Induction must be a phi instruction.
39   assert(induction->opcode() == spv::Op::OpPhi);
40 
41   Instruction* step = nullptr;
42 
43   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
44 
45   // Traverse the incoming operands of the phi instruction.
46   for (uint32_t operand_id = 1; operand_id < induction->NumInOperands();
47        operand_id += 2) {
48     // Incoming edge.
49     BasicBlock* incoming_block =
50         context_->cfg()->block(induction->GetSingleWordInOperand(operand_id));
51 
52     // Check if the block is dominated by header, and thus coming from within
53     // the loop.
54     if (IsInsideLoop(incoming_block)) {
55       step = def_use_manager->GetDef(
56           induction->GetSingleWordInOperand(operand_id - 1));
57       break;
58     }
59   }
60 
61   if (!step || !IsSupportedStepOp(step->opcode())) {
62     return nullptr;
63   }
64 
65   // The induction variable which binds the loop must only be modified once.
66   uint32_t lhs = step->GetSingleWordInOperand(0);
67   uint32_t rhs = step->GetSingleWordInOperand(1);
68 
69   // One of the left hand side or right hand side of the step instruction must
70   // be the induction phi and the other must be an OpConstant.
71   if (lhs != induction->result_id() && rhs != induction->result_id()) {
72     return nullptr;
73   }
74 
75   if (def_use_manager->GetDef(lhs)->opcode() != spv::Op::OpConstant &&
76       def_use_manager->GetDef(rhs)->opcode() != spv::Op::OpConstant) {
77     return nullptr;
78   }
79 
80   return step;
81 }
82 
83 // Returns true if the |step| operation is an induction variable step operation
84 // which is currently handled.
IsSupportedStepOp(spv::Op step) const85 bool Loop::IsSupportedStepOp(spv::Op step) const {
86   switch (step) {
87     case spv::Op::OpISub:
88     case spv::Op::OpIAdd:
89       return true;
90     default:
91       return false;
92   }
93 }
94 
IsSupportedCondition(spv::Op condition) const95 bool Loop::IsSupportedCondition(spv::Op condition) const {
96   switch (condition) {
97     // <
98     case spv::Op::OpULessThan:
99     case spv::Op::OpSLessThan:
100     // >
101     case spv::Op::OpUGreaterThan:
102     case spv::Op::OpSGreaterThan:
103 
104     // >=
105     case spv::Op::OpSGreaterThanEqual:
106     case spv::Op::OpUGreaterThanEqual:
107     // <=
108     case spv::Op::OpSLessThanEqual:
109     case spv::Op::OpULessThanEqual:
110 
111       return true;
112     default:
113       return false;
114   }
115 }
116 
GetResidualConditionValue(spv::Op condition,int64_t initial_value,int64_t step_value,size_t number_of_iterations,size_t factor)117 int64_t Loop::GetResidualConditionValue(spv::Op condition,
118                                         int64_t initial_value,
119                                         int64_t step_value,
120                                         size_t number_of_iterations,
121                                         size_t factor) {
122   int64_t remainder =
123       initial_value + (number_of_iterations % factor) * step_value;
124 
125   // We subtract or add one as the above formula calculates the remainder if the
126   // loop where just less than or greater than. Adding or subtracting one should
127   // give a functionally equivalent value.
128   switch (condition) {
129     case spv::Op::OpSGreaterThanEqual:
130     case spv::Op::OpUGreaterThanEqual: {
131       remainder -= 1;
132       break;
133     }
134     case spv::Op::OpSLessThanEqual:
135     case spv::Op::OpULessThanEqual: {
136       remainder += 1;
137       break;
138     }
139 
140     default:
141       break;
142   }
143   return remainder;
144 }
145 
GetConditionInst() const146 Instruction* Loop::GetConditionInst() const {
147   BasicBlock* condition_block = FindConditionBlock();
148   if (!condition_block) {
149     return nullptr;
150   }
151   Instruction* branch_conditional = &*condition_block->tail();
152   if (!branch_conditional ||
153       branch_conditional->opcode() != spv::Op::OpBranchConditional) {
154     return nullptr;
155   }
156   Instruction* condition_inst = context_->get_def_use_mgr()->GetDef(
157       branch_conditional->GetSingleWordInOperand(0));
158   if (IsSupportedCondition(condition_inst->opcode())) {
159     return condition_inst;
160   }
161 
162   return nullptr;
163 }
164 
165 // Extract the initial value from the |induction| OpPhi instruction and store it
166 // in |value|. If the function couldn't find the initial value of |induction|
167 // return false.
GetInductionInitValue(const Instruction * induction,int64_t * value) const168 bool Loop::GetInductionInitValue(const Instruction* induction,
169                                  int64_t* value) const {
170   Instruction* constant_instruction = nullptr;
171   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
172 
173   for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
174        operand_id += 2) {
175     BasicBlock* bb = context_->cfg()->block(
176         induction->GetSingleWordInOperand(operand_id + 1));
177 
178     if (!IsInsideLoop(bb)) {
179       constant_instruction = def_use_manager->GetDef(
180           induction->GetSingleWordInOperand(operand_id));
181     }
182   }
183 
184   if (!constant_instruction) return false;
185 
186   const analysis::Constant* constant =
187       context_->get_constant_mgr()->FindDeclaredConstant(
188           constant_instruction->result_id());
189   if (!constant) return false;
190 
191   if (value) {
192     const analysis::Integer* type = constant->type()->AsInteger();
193     if (!type) {
194       return false;
195     }
196 
197     *value = type->IsSigned() ? constant->GetSignExtendedValue()
198                               : constant->GetZeroExtendedValue();
199   }
200 
201   return true;
202 }
203 
// Constructs the loop headed by |header|, with |continue_target| as its
// continue target and |merge_target| as its merge block. |context| and
// |dom_analysis| must be non-null.
//
// The preheader is computed from |dom_analysis| and is left as nullptr when the
// loop has none. The latch block is located from the continue target. The loop
// starts without a parent and is not marked for removal.
Loop::Loop(IRContext* context, DominatorAnalysis* dom_analysis,
           BasicBlock* header, BasicBlock* continue_target,
           BasicBlock* merge_target)
    : context_(context),
      loop_header_(header),
      loop_continue_(continue_target),
      loop_merge_(merge_target),
      loop_preheader_(nullptr),
      parent_(nullptr),
      loop_is_marked_for_removal_(false) {
  assert(context);
  assert(dom_analysis);
  // These lookups read context_, loop_header_ and loop_continue_, so they run
  // in the body, after those members have been initialized above.
  loop_preheader_ = FindLoopPreheader(dom_analysis);
  loop_latch_ = FindLatchBlock();
}
219 
FindLoopPreheader(DominatorAnalysis * dom_analysis)220 BasicBlock* Loop::FindLoopPreheader(DominatorAnalysis* dom_analysis) {
221   CFG* cfg = context_->cfg();
222   DominatorTree& dom_tree = dom_analysis->GetDomTree();
223   DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
224 
225   // The loop predecessor.
226   BasicBlock* loop_pred = nullptr;
227 
228   auto header_pred = cfg->preds(loop_header_->id());
229   for (uint32_t p_id : header_pred) {
230     DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
231     if (node && !dom_tree.Dominates(header_node, node)) {
232       // The predecessor is not part of the loop, so potential loop preheader.
233       if (loop_pred && node->bb_ != loop_pred) {
234         // If we saw 2 distinct predecessors that are outside the loop, we don't
235         // have a loop preheader.
236         return nullptr;
237       }
238       loop_pred = node->bb_;
239     }
240   }
241   // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
242   // node as header.
243   assert(loop_pred && "The header node is the entry block ?");
244 
245   // So we have a unique basic block that can enter this loop.
246   // If this loop is the unique successor of this block, then it is a loop
247   // preheader.
248   bool is_preheader = true;
249   uint32_t loop_header_id = loop_header_->id();
250   const auto* const_loop_pred = loop_pred;
251   const_loop_pred->ForEachSuccessorLabel(
252       [&is_preheader, loop_header_id](const uint32_t id) {
253         if (id != loop_header_id) is_preheader = false;
254       });
255   if (is_preheader) return loop_pred;
256   return nullptr;
257 }
258 
IsInsideLoop(Instruction * inst) const259 bool Loop::IsInsideLoop(Instruction* inst) const {
260   const BasicBlock* parent_block = context_->get_instr_block(inst);
261   if (!parent_block) return false;
262   return IsInsideLoop(parent_block);
263 }
264 
IsBasicBlockInLoopSlow(const BasicBlock * bb)265 bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
266   assert(bb->GetParent() && "The basic block does not belong to a function");
267   DominatorAnalysis* dom_analysis =
268       context_->GetDominatorAnalysis(bb->GetParent());
269   if (dom_analysis->IsReachable(bb) &&
270       !dom_analysis->Dominates(GetHeaderBlock(), bb))
271     return false;
272 
273   return true;
274 }
275 
GetOrCreatePreHeaderBlock()276 BasicBlock* Loop::GetOrCreatePreHeaderBlock() {
277   if (loop_preheader_) return loop_preheader_;
278 
279   CFG* cfg = context_->cfg();
280   loop_header_ = cfg->SplitLoopHeader(loop_header_);
281   return loop_preheader_;
282 }
283 
SetContinueBlock(BasicBlock * continue_block)284 void Loop::SetContinueBlock(BasicBlock* continue_block) {
285   assert(IsInsideLoop(continue_block));
286   loop_continue_ = continue_block;
287 }
288 
SetLatchBlock(BasicBlock * latch)289 void Loop::SetLatchBlock(BasicBlock* latch) {
290 #ifndef NDEBUG
291   assert(latch->GetParent() && "The basic block does not belong to a function");
292 
293   const auto* const_latch = latch;
294   const_latch->ForEachSuccessorLabel([this](uint32_t id) {
295     assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
296            "A predecessor of the continue block does not belong to the loop");
297   });
298 #endif  // NDEBUG
299   assert(IsInsideLoop(latch) && "The continue block is not in the loop");
300 
301   SetLatchBlockImpl(latch);
302 }
303 
SetMergeBlock(BasicBlock * merge)304 void Loop::SetMergeBlock(BasicBlock* merge) {
305 #ifndef NDEBUG
306   assert(merge->GetParent() && "The basic block does not belong to a function");
307 #endif  // NDEBUG
308   assert(!IsInsideLoop(merge) && "The merge block is in the loop");
309 
310   SetMergeBlockImpl(merge);
311   if (GetHeaderBlock()->GetLoopMergeInst()) {
312     UpdateLoopMergeInst();
313   }
314 }
315 
SetPreHeaderBlock(BasicBlock * preheader)316 void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
317   if (preheader) {
318     assert(!IsInsideLoop(preheader) && "The preheader block is in the loop");
319     assert(preheader->tail()->opcode() == spv::Op::OpBranch &&
320            "The preheader block does not unconditionally branch to the header "
321            "block");
322     assert(preheader->tail()->GetSingleWordOperand(0) ==
323                GetHeaderBlock()->id() &&
324            "The preheader block does not unconditionally branch to the header "
325            "block");
326   }
327   loop_preheader_ = preheader;
328 }
329 
FindLatchBlock()330 BasicBlock* Loop::FindLatchBlock() {
331   CFG* cfg = context_->cfg();
332 
333   DominatorAnalysis* dominator_analysis =
334       context_->GetDominatorAnalysis(loop_header_->GetParent());
335 
336   // Look at the predecessors of the loop header to find a predecessor block
337   // which is dominated by the loop continue target. There should only be one
338   // block which meets this criteria and this is the latch block, as per the
339   // SPIR-V spec.
340   for (uint32_t block_id : cfg->preds(loop_header_->id())) {
341     if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) {
342       return cfg->block(block_id);
343     }
344   }
345 
346   assert(
347       false &&
348       "Every loop should have a latch block dominated by the continue target");
349   return nullptr;
350 }
351 
GetExitBlocks(std::unordered_set<uint32_t> * exit_blocks) const352 void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
353   CFG* cfg = context_->cfg();
354   exit_blocks->clear();
355 
356   for (uint32_t bb_id : GetBlocks()) {
357     const BasicBlock* bb = cfg->block(bb_id);
358     bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
359       if (!IsInsideLoop(succ)) {
360         exit_blocks->insert(succ);
361       }
362     });
363   }
364 }
365 
GetMergingBlocks(std::unordered_set<uint32_t> * merging_blocks) const366 void Loop::GetMergingBlocks(
367     std::unordered_set<uint32_t>* merging_blocks) const {
368   assert(GetMergeBlock() && "This loop is not structured");
369   CFG* cfg = context_->cfg();
370   merging_blocks->clear();
371 
372   std::stack<const BasicBlock*> to_visit;
373   to_visit.push(GetMergeBlock());
374   while (!to_visit.empty()) {
375     const BasicBlock* bb = to_visit.top();
376     to_visit.pop();
377     merging_blocks->insert(bb->id());
378     for (uint32_t pred_id : cfg->preds(bb->id())) {
379       if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
380         to_visit.push(cfg->block(pred_id));
381       }
382     }
383   }
384 }
385 
386 namespace {
387 
IsBasicBlockSafeToClone(IRContext * context,BasicBlock * bb)388 inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) {
389   for (Instruction& inst : *bb) {
390     if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst))
391       return false;
392   }
393 
394   return true;
395 }
396 
397 }  // namespace
398 
IsSafeToClone() const399 bool Loop::IsSafeToClone() const {
400   CFG& cfg = *context_->cfg();
401 
402   for (uint32_t bb_id : GetBlocks()) {
403     BasicBlock* bb = cfg.block(bb_id);
404     assert(bb);
405     if (!IsBasicBlockSafeToClone(context_, bb)) return false;
406   }
407 
408   // Look at the merge construct.
409   if (GetHeaderBlock()->GetLoopMergeInst()) {
410     std::unordered_set<uint32_t> blocks;
411     GetMergingBlocks(&blocks);
412     blocks.erase(GetMergeBlock()->id());
413     for (uint32_t bb_id : blocks) {
414       BasicBlock* bb = cfg.block(bb_id);
415       assert(bb);
416       if (!IsBasicBlockSafeToClone(context_, bb)) return false;
417     }
418   }
419 
420   return true;
421 }
422 
IsLCSSA() const423 bool Loop::IsLCSSA() const {
424   CFG* cfg = context_->cfg();
425   analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
426 
427   std::unordered_set<uint32_t> exit_blocks;
428   GetExitBlocks(&exit_blocks);
429 
430   // Declare ir_context so we can capture context_ in the below lambda
431   IRContext* ir_context = context_;
432 
433   for (uint32_t bb_id : GetBlocks()) {
434     for (Instruction& insn : *cfg->block(bb_id)) {
435       // All uses must be either:
436       //  - In the loop;
437       //  - In an exit block and in a phi instruction.
438       if (!def_use_mgr->WhileEachUser(
439               &insn,
440               [&exit_blocks, ir_context, this](Instruction* use) -> bool {
441                 BasicBlock* parent = ir_context->get_instr_block(use);
442                 assert(parent && "Invalid analysis");
443                 if (IsInsideLoop(parent)) return true;
444                 if (use->opcode() != spv::Op::OpPhi) return false;
445                 return exit_blocks.count(parent->id());
446               }))
447         return false;
448     }
449   }
450   return true;
451 }
452 
ShouldHoistInstruction(const Instruction & inst) const453 bool Loop::ShouldHoistInstruction(const Instruction& inst) const {
454   return inst.IsOpcodeCodeMotionSafe() && AreAllOperandsOutsideLoop(inst) &&
455          (!inst.IsLoad() || inst.IsReadOnlyLoad());
456 }
457 
AreAllOperandsOutsideLoop(const Instruction & inst) const458 bool Loop::AreAllOperandsOutsideLoop(const Instruction& inst) const {
459   analysis::DefUseManager* def_use_mgr = GetContext()->get_def_use_mgr();
460 
461   const std::function<bool(const uint32_t*)> operand_outside_loop =
462       [this, &def_use_mgr](const uint32_t* id) {
463         return !this->IsInsideLoop(def_use_mgr->GetDef(*id));
464       };
465 
466   return inst.WhileEachInId(operand_outside_loop);
467 }
468 
ComputeLoopStructuredOrder(std::vector<BasicBlock * > * ordered_loop_blocks,bool include_pre_header,bool include_merge) const469 void Loop::ComputeLoopStructuredOrder(
470     std::vector<BasicBlock*>* ordered_loop_blocks, bool include_pre_header,
471     bool include_merge) const {
472   CFG& cfg = *context_->cfg();
473 
474   // Reserve the memory: all blocks in the loop + extra if needed.
475   ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header +
476                                include_merge);
477 
478   if (include_pre_header && GetPreHeaderBlock())
479     ordered_loop_blocks->push_back(loop_preheader_);
480 
481   bool is_shader =
482       context_->get_feature_mgr()->HasCapability(spv::Capability::Shader);
483   if (!is_shader) {
484     cfg.ForEachBlockInReversePostOrder(
485         loop_header_, [ordered_loop_blocks, this](BasicBlock* bb) {
486           if (IsInsideLoop(bb)) ordered_loop_blocks->push_back(bb);
487         });
488   } else {
489     // If this is a shader, it is possible that there are unreachable merge and
490     // continue blocks that must be copied to retain the structured order.
491     // The structured order will include these.
492     std::list<BasicBlock*> order;
493     cfg.ComputeStructuredOrder(loop_header_->GetParent(), loop_header_,
494                                loop_merge_, &order);
495     for (BasicBlock* bb : order) {
496       if (bb == GetMergeBlock()) {
497         break;
498       }
499       ordered_loop_blocks->push_back(bb);
500     }
501   }
502   if (include_merge && GetMergeBlock())
503     ordered_loop_blocks->push_back(loop_merge_);
504 }
505 
// Builds the loop descriptor of function |f|: starts with no loops and an empty
// placeholder root, then discovers all loops of |f| through PopulateList.
LoopDescriptor::LoopDescriptor(IRContext* context, const Function* f)
    : loops_(), placeholder_top_loop_(nullptr) {
  PopulateList(context, f);
}
510 
~LoopDescriptor()511 LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
512 
// (Re)builds the loop nest of function |f|, discarding previously recorded
// loops via ClearLoops().
//
// A Loop is created for every block that carries an OpLoopMerge and has at
// least one reachable predecessor that it dominates (a back-edge). The
// dominator tree is walked in post-order, so inner loops are created before
// the loops that enclose them, which lets each new loop adopt the still
// parentless loops it encloses. Loops that end up without a parent are
// attached to placeholder_top_loop_.
void LoopDescriptor::PopulateList(IRContext* context, const Function* f) {
  DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f);

  ClearLoops();

  // Post-order traversal of the dominator tree to find all the OpLoopMerge
  // instructions.
  DominatorTree& dom_tree = dom_analysis->GetDomTree();
  for (DominatorTreeNode& node :
       make_range(dom_tree.post_begin(), dom_tree.post_end())) {
    Instruction* merge_inst = node.bb_->GetLoopMergeInst();
    if (merge_inst) {
      // A back-edge is a reachable predecessor dominated by this header. If
      // there is none, the body never branches back to the header.
      bool all_backedge_unreachable = true;
      for (uint32_t pid : context->cfg()->preds(node.bb_->id())) {
        if (dom_analysis->IsReachable(pid) &&
            dom_analysis->Dominates(node.bb_->id(), pid)) {
          all_backedge_unreachable = false;
          break;
        }
      }
      if (all_backedge_unreachable)
        continue;  // ignore this one, we actually never branch back.

      // The id of the merge basic block of this loop (OpLoopMerge operand 0).
      uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);

      // The id of the continue basic block of this loop (OpLoopMerge
      // operand 1).
      uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);

      // The merge target of this loop.
      BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);

      // The continue target of this loop.
      BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);

      // The basic block containing the merge instruction.
      BasicBlock* header_bb = context->get_instr_block(merge_inst);

      // Add the loop to the list of all the loops in the function.
      // NOTE(review): raw `new`; ownership presumably released by ClearLoops().
      Loop* current_loop =
          new Loop(context, dom_analysis, header_bb, continue_bb, merge_bb);
      loops_.push_back(current_loop);

      // We have a bottom-up construction, so if this loop has nested loops,
      // they are by construction at the tail of the loop list.
      for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
        Loop* previous_loop = *itr;

        // If the loop already has a parent, then it has been processed.
        if (previous_loop->HasParent()) continue;

        // If the current loop does not dominate the previous loop, then the
        // previous loop is not nested in it.
        if (!dom_analysis->Dominates(header_bb,
                                     previous_loop->GetHeaderBlock()))
          continue;
        // If the current loop's merge block dominates the previous loop, then
        // the previous loop comes after this one and is not nested in it.
        if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
          continue;

        current_loop->AddNestedLoop(previous_loop);
      }
      // The loop's blocks are the header's dominator subtree minus everything
      // dominated by the merge block.
      DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
      for (DominatorTreeNode& loop_node :
           make_range(node.df_begin(), node.df_end())) {
        // Check if we are in the loop.
        if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
        current_loop->AddBasicBlock(loop_node.bb_);
        // Inner loops were processed first. Assuming basic_block_to_loop_ is a
        // std map type whose insert() keeps existing entries, each block stays
        // mapped to its innermost loop.
        basic_block_to_loop_.insert(
            std::make_pair(loop_node.bb_->id(), current_loop));
      }
    }
  }
  for (Loop* loop : loops_) {
    if (!loop->HasParent()) placeholder_top_loop_.nested_loops_.push_back(loop);
  }
}
591 
GetLoopsInBinaryLayoutOrder()592 std::vector<Loop*> LoopDescriptor::GetLoopsInBinaryLayoutOrder() {
593   std::vector<uint32_t> ids{};
594 
595   for (size_t i = 0; i < NumLoops(); ++i) {
596     ids.push_back(GetLoopByIndex(i).GetHeaderBlock()->id());
597   }
598 
599   std::vector<Loop*> loops{};
600   if (!ids.empty()) {
601     auto function = GetLoopByIndex(0).GetHeaderBlock()->GetParent();
602     for (const auto& block : *function) {
603       auto block_id = block.id();
604 
605       auto element = std::find(std::begin(ids), std::end(ids), block_id);
606       if (element != std::end(ids)) {
607         loops.push_back(&GetLoopByIndex(element - std::begin(ids)));
608       }
609     }
610   }
611 
612   return loops;
613 }
614 
FindConditionBlock() const615 BasicBlock* Loop::FindConditionBlock() const {
616   if (!loop_merge_) {
617     return nullptr;
618   }
619   BasicBlock* condition_block = nullptr;
620 
621   uint32_t in_loop_pred = 0;
622   for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) {
623     if (IsInsideLoop(p)) {
624       if (in_loop_pred) {
625         // 2 in-loop predecessors.
626         return nullptr;
627       }
628       in_loop_pred = p;
629     }
630   }
631   if (!in_loop_pred) {
632     // Merge block is unreachable.
633     return nullptr;
634   }
635 
636   BasicBlock* bb = context_->cfg()->block(in_loop_pred);
637 
638   if (!bb) return nullptr;
639 
640   const Instruction& branch = *bb->ctail();
641 
642   // Make sure the branch is a conditional branch.
643   if (branch.opcode() != spv::Op::OpBranchConditional) return nullptr;
644 
645   // Make sure one of the two possible branches is to the merge block.
646   if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
647       branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
648     condition_block = bb;
649   }
650 
651   return condition_block;
652 }
653 
FindNumberOfIterations(const Instruction * induction,const Instruction * branch_inst,size_t * iterations_out,int64_t * step_value_out,int64_t * init_value_out) const654 bool Loop::FindNumberOfIterations(const Instruction* induction,
655                                   const Instruction* branch_inst,
656                                   size_t* iterations_out,
657                                   int64_t* step_value_out,
658                                   int64_t* init_value_out) const {
659   // From the branch instruction find the branch condition.
660   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
661 
662   // Condition instruction from the OpConditionalBranch.
663   Instruction* condition =
664       def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
665 
666   assert(IsSupportedCondition(condition->opcode()));
667 
668   // Get the constant manager from the ir context.
669   analysis::ConstantManager* const_manager = context_->get_constant_mgr();
670 
671   // Find the constant value used by the condition variable. Exit out if it
672   // isn't a constant int.
673   const analysis::Constant* upper_bound =
674       const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
675   if (!upper_bound) return false;
676 
677   // Must be integer because of the opcode on the condition.
678   const analysis::Integer* type = upper_bound->type()->AsInteger();
679 
680   if (!type || type->width() > 64) {
681     return false;
682   }
683 
684   int64_t condition_value = type->IsSigned()
685                                 ? upper_bound->GetSignExtendedValue()
686                                 : upper_bound->GetZeroExtendedValue();
687 
688   // Find the instruction which is stepping through the loop.
689   //
690   // GetInductionStepOperation returns nullptr if |step_inst| is OpConstantNull.
691   Instruction* step_inst = GetInductionStepOperation(induction);
692   if (!step_inst) return false;
693 
694   // Find the constant value used by the condition variable.
695   const analysis::Constant* step_constant =
696       const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
697   if (!step_constant) return false;
698 
699   // Must be integer because of the opcode on the condition.
700   int64_t step_value = 0;
701 
702   const analysis::Integer* step_type =
703       step_constant->AsIntConstant()->type()->AsInteger();
704 
705   if (step_type->IsSigned()) {
706     step_value = step_constant->AsIntConstant()->GetS32BitValue();
707   } else {
708     step_value = step_constant->AsIntConstant()->GetU32BitValue();
709   }
710 
711   // If this is a subtraction step we should negate the step value.
712   if (step_inst->opcode() == spv::Op::OpISub) {
713     step_value = -step_value;
714   }
715 
716   // Find the initial value of the loop and make sure it is a constant integer.
717   int64_t init_value = 0;
718   if (!GetInductionInitValue(induction, &init_value)) return false;
719 
720   // If iterations is non null then store the value in that.
721   int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
722                                    init_value, step_value);
723 
724   // If the loop body will not be reached return false.
725   if (num_itrs <= 0) {
726     return false;
727   }
728 
729   if (iterations_out) {
730     assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
731     *iterations_out = static_cast<size_t>(num_itrs);
732   }
733 
734   if (step_value_out) {
735     *step_value_out = step_value;
736   }
737 
738   if (init_value_out) {
739     *init_value_out = init_value;
740   }
741 
742   return true;
743 }
744 
745 // We retrieve the number of iterations using the following formula, diff /
746 // |step_value| where diff is calculated differently according to the
747 // |condition| and uses the |condition_value| and |init_value|. If diff /
748 // |step_value| is NOT cleanly divisible then we add one to the sum.
GetIterations(spv::Op condition,int64_t condition_value,int64_t init_value,int64_t step_value) const749 int64_t Loop::GetIterations(spv::Op condition, int64_t condition_value,
750                             int64_t init_value, int64_t step_value) const {
751   if (step_value == 0) {
752     return 0;
753   }
754 
755   int64_t diff = 0;
756 
757   switch (condition) {
758     case spv::Op::OpSLessThan:
759     case spv::Op::OpULessThan: {
760       // If the condition is not met to begin with the loop will never iterate.
761       if (!(init_value < condition_value)) return 0;
762 
763       diff = condition_value - init_value;
764 
765       // If the operation is a less then operation then the diff and step must
766       // have the same sign otherwise the induction will never cross the
767       // condition (either never true or always true).
768       if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
769         return 0;
770       }
771 
772       break;
773     }
774     case spv::Op::OpSGreaterThan:
775     case spv::Op::OpUGreaterThan: {
776       // If the condition is not met to begin with the loop will never iterate.
777       if (!(init_value > condition_value)) return 0;
778 
779       diff = init_value - condition_value;
780 
781       // If the operation is a greater than operation then the diff and step
782       // must have opposite signs. Otherwise the condition will always be true
783       // or will never be true.
784       if ((diff < 0 && step_value < 0) || (diff > 0 && step_value > 0)) {
785         return 0;
786       }
787 
788       break;
789     }
790 
791     case spv::Op::OpSGreaterThanEqual:
792     case spv::Op::OpUGreaterThanEqual: {
793       // If the condition is not met to begin with the loop will never iterate.
794       if (!(init_value >= condition_value)) return 0;
795 
796       // We subtract one to make it the same as spv::Op::OpGreaterThan as it is
797       // functionally equivalent.
798       diff = init_value - (condition_value - 1);
799 
800       // If the operation is a greater than operation then the diff and step
801       // must have opposite signs. Otherwise the condition will always be true
802       // or will never be true.
803       if ((diff > 0 && step_value > 0) || (diff < 0 && step_value < 0)) {
804         return 0;
805       }
806 
807       break;
808     }
809 
810     case spv::Op::OpSLessThanEqual:
811     case spv::Op::OpULessThanEqual: {
812       // If the condition is not met to begin with the loop will never iterate.
813       if (!(init_value <= condition_value)) return 0;
814 
815       // We add one to make it the same as spv::Op::OpLessThan as it is
816       // functionally equivalent.
817       diff = (condition_value + 1) - init_value;
818 
819       // If the operation is a less than operation then the diff and step must
820       // have the same sign otherwise the induction will never cross the
821       // condition (either never true or always true).
822       if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
823         return 0;
824       }
825 
826       break;
827     }
828 
829     default:
830       assert(false &&
831              "Could not retrieve number of iterations from the loop condition. "
832              "Condition is not supported.");
833   }
834 
835   // Take the abs of - step values.
836   step_value = llabs(step_value);
837   diff = llabs(diff);
838   int64_t result = diff / step_value;
839 
840   if (diff % step_value != 0) {
841     result += 1;
842   }
843   return result;
844 }
845 
846 // Returns the list of induction variables within the loop.
GetInductionVariables(std::vector<Instruction * > & induction_variables) const847 void Loop::GetInductionVariables(
848     std::vector<Instruction*>& induction_variables) const {
849   for (Instruction& inst : *loop_header_) {
850     if (inst.opcode() == spv::Op::OpPhi) {
851       induction_variables.push_back(&inst);
852     }
853   }
854 }
855 
FindConditionVariable(const BasicBlock * condition_block) const856 Instruction* Loop::FindConditionVariable(
857     const BasicBlock* condition_block) const {
858   // Find the branch instruction.
859   const Instruction& branch_inst = *condition_block->ctail();
860 
861   Instruction* induction = nullptr;
862   // Verify that the branch instruction is a conditional branch.
863   if (branch_inst.opcode() == spv::Op::OpBranchConditional) {
864     // From the branch instruction find the branch condition.
865     analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
866 
867     // Find the instruction representing the condition used in the conditional
868     // branch.
869     Instruction* condition =
870         def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
871 
872     // Ensure that the condition is a less than operation.
873     if (condition && IsSupportedCondition(condition->opcode())) {
874       // The left hand side operand of the operation.
875       Instruction* variable_inst =
876           def_use_manager->GetDef(condition->GetSingleWordOperand(2));
877 
878       // Make sure the variable instruction used is a phi.
879       if (!variable_inst || variable_inst->opcode() != spv::Op::OpPhi)
880         return nullptr;
881 
882       // Make sure the phi instruction only has two incoming blocks. Each
883       // incoming block will be represented by two in operands in the phi
884       // instruction, the value and the block which that value came from. We
885       // assume the cannocalised phi will have two incoming values, one from the
886       // preheader and one from the continue block.
887       size_t max_supported_operands = 4;
888       if (variable_inst->NumInOperands() == max_supported_operands) {
889         // The operand index of the first incoming block label.
890         uint32_t operand_label_1 = 1;
891 
892         // The operand index of the second incoming block label.
893         uint32_t operand_label_2 = 3;
894 
895         // Make sure one of them is the preheader.
896         if (!IsInsideLoop(
897                 variable_inst->GetSingleWordInOperand(operand_label_1)) &&
898             !IsInsideLoop(
899                 variable_inst->GetSingleWordInOperand(operand_label_2))) {
900           return nullptr;
901         }
902 
903         // And make sure that the other is the latch block.
904         if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
905                 loop_latch_->id() &&
906             variable_inst->GetSingleWordInOperand(operand_label_2) !=
907                 loop_latch_->id()) {
908           return nullptr;
909         }
910       } else {
911         return nullptr;
912       }
913 
914       if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
915         return nullptr;
916       induction = variable_inst;
917     }
918   }
919 
920   return induction;
921 }
922 
CreatePreHeaderBlocksIfMissing()923 bool LoopDescriptor::CreatePreHeaderBlocksIfMissing() {
924   auto modified = false;
925 
926   for (auto& loop : *this) {
927     if (!loop.GetPreHeaderBlock()) {
928       modified = true;
929       // TODO(1841): Handle failure to create pre-header.
930       loop.GetOrCreatePreHeaderBlock();
931     }
932   }
933 
934   return modified;
935 }
936 
937 // Add and remove loops which have been marked for addition and removal to
938 // maintain the state of the loop descriptor class.
PostModificationCleanup()939 void LoopDescriptor::PostModificationCleanup() {
940   LoopContainerType loops_to_remove_;
941   for (Loop* loop : loops_) {
942     if (loop->IsMarkedForRemoval()) {
943       loops_to_remove_.push_back(loop);
944       if (loop->HasParent()) {
945         loop->GetParent()->RemoveChildLoop(loop);
946       }
947     }
948   }
949 
950   for (Loop* loop : loops_to_remove_) {
951     loops_.erase(std::find(loops_.begin(), loops_.end(), loop));
952     delete loop;
953   }
954 
955   for (auto& pair : loops_to_add_) {
956     Loop* parent = pair.first;
957     std::unique_ptr<Loop> loop = std::move(pair.second);
958 
959     if (parent) {
960       loop->SetParent(nullptr);
961       parent->AddNestedLoop(loop.get());
962 
963       for (uint32_t block_id : loop->GetBlocks()) {
964         parent->AddBasicBlock(block_id);
965       }
966     }
967 
968     loops_.emplace_back(loop.release());
969   }
970 
971   loops_to_add_.clear();
972 }
973 
ClearLoops()974 void LoopDescriptor::ClearLoops() {
975   for (Loop* loop : loops_) {
976     delete loop;
977   }
978   loops_.clear();
979 }
980 
981 // Adds a new loop nest to the descriptor set.
AddLoopNest(std::unique_ptr<Loop> new_loop)982 Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<Loop> new_loop) {
983   Loop* loop = new_loop.release();
984   if (!loop->HasParent()) placeholder_top_loop_.nested_loops_.push_back(loop);
985   // Iterate from inner to outer most loop, adding basic block to loop mapping
986   // as we go.
987   for (Loop& current_loop :
988        make_range(iterator::begin(loop), iterator::end(nullptr))) {
989     loops_.push_back(&current_loop);
990     for (uint32_t bb_id : current_loop.GetBlocks())
991       basic_block_to_loop_.insert(std::make_pair(bb_id, &current_loop));
992   }
993 
994   return loop;
995 }
996 
RemoveLoop(Loop * loop)997 void LoopDescriptor::RemoveLoop(Loop* loop) {
998   Loop* parent = loop->GetParent() ? loop->GetParent() : &placeholder_top_loop_;
999   parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(),
1000                                         parent->nested_loops_.end(), loop));
1001   std::for_each(
1002       loop->nested_loops_.begin(), loop->nested_loops_.end(),
1003       [loop](Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); });
1004   parent->nested_loops_.insert(parent->nested_loops_.end(),
1005                                loop->nested_loops_.begin(),
1006                                loop->nested_loops_.end());
1007   for (uint32_t bb_id : loop->GetBlocks()) {
1008     Loop* l = FindLoopForBasicBlock(bb_id);
1009     if (l == loop) {
1010       SetBasicBlockToLoop(bb_id, l->GetParent());
1011     } else {
1012       ForgetBasicBlock(bb_id);
1013     }
1014   }
1015 
1016   LoopContainerType::iterator it =
1017       std::find(loops_.begin(), loops_.end(), loop);
1018   assert(it != loops_.end());
1019   delete loop;
1020   loops_.erase(it);
1021 }
1022 
1023 }  // namespace opt
1024 }  // namespace spvtools
1025