xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Analysis/userange_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Analysis/userange_analysis.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "llvm/ADT/SetOperations.h"
22 #include "mlir/IR/Block.h"
23 #include "mlir/IR/Region.h"
24 #include "mlir/Interfaces/LoopLikeInterface.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Builds a userange information from the given value and its liveness. The
30 /// information includes all operations that are within the userange.
31 struct UserangeInfoBuilder {
32   using OperationListT = Liveness::OperationListT;
33   using ValueSetT = BufferViewFlowAnalysis::ValueSetT;
34 
35  public:
36   /// Constructs an Userange builder.
UserangeInfoBuilder__anon77d1cd940111::UserangeInfoBuilder37   UserangeInfoBuilder(Liveness pLiveness, ValueSetT pValues,
38                       OperationListT pOpList)
39       : values(std::move(pValues)),
40         opList(std::move(pOpList)),
41         liveness(std::move(pLiveness)) {}
42 
43   /// Computes the userange of the current value by iterating over all of its
44   /// uses.
computeUserange__anon77d1cd940111::UserangeInfoBuilder45   Liveness::OperationListT computeUserange() {
46     Region *topRegion = findTopRegion();
47     // Iterate over all associated uses.
48     for (Operation *use : opList) {
49       // If one of the parents implements a LoopLikeOpInterface we need to add
50       // all operations inside of its regions to the userange.
51       Operation *loopParent = use->getParentOfType<LoopLikeOpInterface>();
52       if (loopParent && topRegion->isProperAncestor(use->getParentRegion()))
53         addAllOperationsInRegion(loopParent);
54 
55       // Check if the parent block has already been processed.
56       Block *useBlock = findTopLiveBlock(use);
57       if (!startBlocks.insert(useBlock).second || visited.contains(useBlock))
58         continue;
59 
60       // Add all operations inside the block that are within the userange.
61       findOperationsInUse(useBlock);
62     }
63     return currentUserange;
64   }
65 
66  private:
67   /// Find the top most Region of all values stored in the values set.
findTopRegion__anon77d1cd940111::UserangeInfoBuilder68   Region *findTopRegion() const {
69     Region *topRegion = nullptr;
70     llvm::for_each(values, [&](Value v) {
71       Region *other = v.getParentRegion();
72       if (!topRegion || topRegion->isAncestor(other)) topRegion = other;
73     });
74     return topRegion;
75   }
76 
77   /// Finds the highest level block that has the current value in its liveOut
78   /// set.
findTopLiveBlock__anon77d1cd940111::UserangeInfoBuilder79   Block *findTopLiveBlock(Operation *op) const {
80     Operation *topOp = op;
81     while (const LivenessBlockInfo *blockInfo =
82                liveness.getLiveness(op->getBlock())) {
83       if (llvm::any_of(values,
84                        [&](Value v) { return blockInfo->isLiveOut(v); }))
85         topOp = op;
86       op = op->getParentOp();
87     }
88     return topOp->getBlock();
89   }
90 
91   /// Adds all operations from start to end to the userange of the current
92   /// value. If an operation implements a nested region all operations inside of
93   /// it are included as well. If includeEnd is false the end operation is not
94   /// added.
addAllOperationsBetween__anon77d1cd940111::UserangeInfoBuilder95   void addAllOperationsBetween(Operation *start, Operation *end) {
96     currentUserange.push_back(start);
97     addAllOperationsInRegion(start);
98 
99     while (start != end) {
100       start = start->getNextNode();
101       addAllOperationsInRegion(start);
102       currentUserange.push_back(start);
103     }
104   }
105 
106   /// Adds all operations that are uses of the value in the given block to the
107   /// userange of the current value. Additionally iterate over all successors
108   /// where the value is live.
findOperationsInUse__anon77d1cd940111::UserangeInfoBuilder109   void findOperationsInUse(Block *block) {
110     SmallVector<Block *, 8> blocksToProcess;
111     addOperationsInBlockAndFindSuccessors(
112         block, block, getStartOperation(block), blocksToProcess);
113     while (!blocksToProcess.empty()) {
114       Block *toProcess = blocksToProcess.pop_back_val();
115       addOperationsInBlockAndFindSuccessors(
116           block, toProcess, &toProcess->front(), blocksToProcess);
117     }
118   }
119 
120   /// Adds the operations between the given start operation and the computed end
121   /// operation to the userange. If the current value is live out, add all
122   /// successor blocks that have the value live in to the process queue. If we
123   /// find a loop, add the operations before the first use in block to the
124   /// userange (if any). The startBlock is the block where the iteration over
125   /// all successors started and is propagated further to find potential loops.
addOperationsInBlockAndFindSuccessors__anon77d1cd940111::UserangeInfoBuilder126   void addOperationsInBlockAndFindSuccessors(
127       const Block *startBlock, Block *toProcess, Operation *start,
128       SmallVector<Block *, 8> &blocksToProcess) {
129     const LivenessBlockInfo *blockInfo = liveness.getLiveness(toProcess);
130     Operation *end = getEndOperation(toProcess);
131 
132     addAllOperationsBetween(start, end);
133 
134     // If the value is live out we need to process all successors at which the
135     // value is live in.
136     if (!llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); }))
137       return;
138     for (Block *successor : toProcess->getSuccessors()) {
139       // If the successor is the startBlock, we found a loop and only have to
140       // add the operations from the block front to the first use of the
141       // value.
142       if (!llvm::any_of(values, [&](Value v) {
143             return liveness.getLiveness(successor)->isLiveIn(v);
144           }))
145         continue;
146       if (successor == startBlock) {
147         start = &successor->front();
148         end = getStartOperation(successor);
149         if (start != end) addAllOperationsBetween(start, end->getPrevNode());
150         // Else we need to check if the value is live in and the successor
151         // has not been visited before. If so we also need to process it.
152       } else if (visited.insert(successor).second) {
153         blocksToProcess.emplace_back(successor);
154       }
155     }
156   }
157 
158   /// Iterates over all regions of a given operation and adds all operations
159   /// inside those regions to the userange of the current value.
addAllOperationsInRegion__anon77d1cd940111::UserangeInfoBuilder160   void addAllOperationsInRegion(Operation *parentOp) {
161     // Iterate over all regions of the parentOp.
162     for (Region &region : parentOp->getRegions()) {
163       // Iterate over blocks inside the region.
164       for (Block &block : region) {
165         // If the blocks have been used as a startBlock before, we need to add
166         // all operations between the block front and the startOp of the value.
167         if (startBlocks.contains(&block)) {
168           Operation *start = &block.front();
169           Operation *end = getStartOperation(&block);
170           if (start != end) addAllOperationsBetween(start, end->getPrevNode());
171 
172           // If the block has never been seen before, we need to add all
173           // operations inside.
174         } else if (visited.insert(&block).second) {
175           for (Operation &op : block) {
176             addAllOperationsInRegion(&op);
177             currentUserange.emplace_back(&op);
178           }
179           continue;
180         }
181         // If the block has either been visited before or was used as a
182         // startBlock, we need to add all operations between the endOp of the
183         // value and the end of the block.
184         Operation *end = getEndOperation(&block);
185         if (end == &block.back()) continue;
186         addAllOperationsBetween(end->getNextNode(), &block.back());
187       }
188     }
189   }
190 
191   /// Find the start operation of the current value inside the given block.
getStartOperation__anon77d1cd940111::UserangeInfoBuilder192   Operation *getStartOperation(Block *block) {
193     Operation *startOperation = &block->back();
194     for (Operation *useOp : opList) {
195       // Find the associated operation in the current block (if any).
196       useOp = block->findAncestorOpInBlock(*useOp);
197       // Check whether the use is in our block and after the current end
198       // operation.
199       if (useOp && useOp->isBeforeInBlock(startOperation))
200         startOperation = useOp;
201     }
202     return startOperation;
203   }
204 
205   /// Find the end operation of the current value inside the given block.
getEndOperation__anon77d1cd940111::UserangeInfoBuilder206   Operation *getEndOperation(Block *block) {
207     const LivenessBlockInfo *blockInfo = liveness.getLiveness(block);
208     if (llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); }))
209       return &block->back();
210 
211     Operation *endOperation = &block->front();
212     for (Operation *useOp : opList) {
213       // Find the associated operation in the current block (if any).
214       useOp = block->findAncestorOpInBlock(*useOp);
215       // Check whether the use is in our block and after the current end
216       // operation.
217       if (useOp && endOperation->isBeforeInBlock(useOp)) endOperation = useOp;
218     }
219     return endOperation;
220   }
221 
222   /// The current Value.
223   ValueSetT values;
224 
225   /// The list of all operations used by the values.
226   OperationListT opList;
227 
228   /// The result list of the userange computation.
229   OperationListT currentUserange;
230 
231   /// The set of visited blocks during the userange computation.
232   SmallPtrSet<Block *, 32> visited;
233 
234   /// The set of blocks that the userange computation started from.
235   SmallPtrSet<Block *, 8> startBlocks;
236 
237   /// The current liveness info.
238   Liveness liveness;
239 };
240 }  // namespace
241 
UserangeAnalysis(Operation * op,const BufferPlacementAllocs & allocs,const BufferViewFlowAnalysis & aliases)242 UserangeAnalysis::UserangeAnalysis(Operation *op,
243                                    const BufferPlacementAllocs &allocs,
244                                    const BufferViewFlowAnalysis &aliases)
245     : liveness(op) {
246   // Walk over all operations and map them to an ID.
247   op->walk([&](Operation *operation) {
248     gatherMemoryEffects(operation);
249     operationIds.insert({operation, operationIds.size()});
250   });
251 
252   // Compute the use range for every allocValue and its aliases. Merge them
253   // and compute an interval. Add all computed intervals to the useIntervalMap.
254   for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
255     Value allocValue = std::get<0>(entry);
256     OperationListT useList;
257     for (auto &use : allocValue.getUses()) useList.emplace_back(use.getOwner());
258     UserangeInfoBuilder builder(liveness, {allocValue}, useList);
259     OperationListT liveOperations = builder.computeUserange();
260 
261     // Sort the operation list by ids.
262     std::sort(liveOperations.begin(), liveOperations.end(),
263               [&](Operation *left, Operation *right) {
264                 return operationIds[left] < operationIds[right];
265               });
266 
267     IntervalVector allocInterval = computeInterval(allocValue, liveOperations);
268     // Iterate over all aliases and add their useranges to the userange of the
269     // current value. Also add the useInterval of each alias to the
270     // useIntervalMap.
271     ValueSetT aliasSet = aliases.resolve(allocValue);
272     for (Value alias : aliasSet) {
273       if (alias == allocValue) continue;
274       if (!aliasUseranges.count(alias)) {
275         OperationListT aliasOperations;
276         // If the alias is a BlockArgument then the value is live with the first
277         // operation inside that block. Otherwise the liveness analysis is
278         // sufficient for the use range.
279         if (alias.isa<BlockArgument>()) {
280           aliasOperations.emplace_back(&alias.getParentBlock()->front());
281           for (auto &use : alias.getUses())
282             aliasOperations.emplace_back(use.getOwner());
283           // Compute the use range for the alias and sort the operations
284           // afterwards.
285           UserangeInfoBuilder aliasBuilder(liveness, {alias}, aliasOperations);
286           aliasOperations = aliasBuilder.computeUserange();
287           std::sort(aliasOperations.begin(), aliasOperations.end(),
288                     [&](Operation *left, Operation *right) {
289                       return operationIds[left] < operationIds[right];
290                     });
291         } else {
292           aliasOperations = liveness.resolveLiveness(alias);
293         }
294 
295         aliasUseranges.insert({alias, aliasOperations});
296         useIntervalMap.insert(
297             {alias, computeInterval(alias, aliasUseranges[alias])});
298       }
299       allocInterval =
300           std::get<0>(intervalMerge(allocInterval, useIntervalMap[alias]));
301     }
302     aliasCache.insert(std::make_pair(allocValue, aliasSet));
303 
304     // Map the current allocValue to the computed useInterval.
305     useIntervalMap.insert(std::make_pair(allocValue, allocInterval));
306   }
307 }
308 
309 /// Checks if the use intervals of the given values interfere.
rangesInterfere(Value itemA,Value itemB) const310 bool UserangeAnalysis::rangesInterfere(Value itemA, Value itemB) const {
311   return intervalUnion(itemA, itemB);
312 }
313 
314 /// Merges the userange of itemB into the userange of itemA.
315 /// Note: This assumes that there is no interference between the two
316 /// ranges.
unionRanges(Value itemA,Value itemB)317 void UserangeAnalysis::unionRanges(Value itemA, Value itemB) {
318   IntervalVector unionInterval =
319       std::get<0>(intervalMerge(useIntervalMap[itemA], useIntervalMap[itemB]));
320 
321   llvm::set_union(aliasCache[itemA], aliasCache[itemB]);
322   for (Value alias : aliasCache[itemA])
323     unionInterval =
324         std::get<0>(intervalMerge(unionInterval, useIntervalMap[alias]));
325 
326   // Compute new interval.
327   useIntervalMap[itemA] = unionInterval;
328 }
329 
330 /// Builds an IntervalVector corresponding to the given OperationList.
computeInterval(Value value,const Liveness::OperationListT & operationList)331 UserangeAnalysis::IntervalVector UserangeAnalysis::computeInterval(
332     Value value, const Liveness::OperationListT &operationList) {
333   assert(!operationList.empty() && "Operation list must not be empty");
334   size_t start = computeID(value, *operationList.begin());
335   size_t last = start;
336   UserangeAnalysis::IntervalVector intervals;
337   // Iterate over all operations in the operationList. If the gap between the
338   // respective operationIds is greater 1 create a new interval.
339   for (auto opIter = ++operationList.begin(), e = operationList.end();
340        opIter != e; ++opIter) {
341     size_t current = computeID(value, *opIter);
342     if (current - last > 2) {
343       intervals.emplace_back(UserangeAnalysis::UseInterval(start, last));
344       start = current;
345     }
346     last = current;
347   }
348   intervals.emplace_back(UserangeAnalysis::UseInterval(start, last));
349   return intervals;
350 }
351 
352 /// Checks each operand inside the operation for its memory effects and
353 /// separates them into read and write. Operands with read effects are added to
354 /// the opToReadMap.
gatherMemoryEffects(Operation * op)355 void UserangeAnalysis::gatherMemoryEffects(Operation *op) {
356   if (OpTrait::hasElementwiseMappableTraits(op)) {
357     if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
358       SmallPtrSet<Value, 2> readEffectSet;
359       SmallPtrSet<Value, 2> writeEffectSet;
360       for (auto operand : op->getOperands()) {
361         SmallVector<MemoryEffects::EffectInstance, 2> effects;
362         effectInterface.getEffectsOnValue(operand, effects);
363         for (auto effect : effects) {
364           if (isa<MemoryEffects::Write>(effect.getEffect()))
365             writeEffectSet.insert(operand);
366           else if (isa<MemoryEffects::Read>(effect.getEffect()))
367             readEffectSet.insert(operand);
368         }
369       }
370       opReadWriteMap.insert(
371           {op, std::make_pair(readEffectSet, writeEffectSet)});
372     }
373   }
374 }
375 
376 /// Computes the ID for the operation. If the operation contains operands which
377 /// have read effects, the returning ID will be odd. This allows us to
378 /// perform a replace in place.
computeID(Value v,Operation * op) const379 size_t UserangeAnalysis::computeID(Value v, Operation *op) const {
380   size_t doubledID = operationIds.find(op)->second * 2;
381   auto mapIter = opReadWriteMap.find(op);
382   if (mapIter == opReadWriteMap.end()) return doubledID;
383   auto reads = mapIter->second.first;
384   auto writes = mapIter->second.second;
385   if (reads.contains(v) && !writes.contains(v)) return doubledID - 1;
386   return doubledID;
387 }
388 
389 /// Merge two IntervalVectors into a new IntervalVector. Return a pair with the
390 /// resulting IntervalVector and a boolean if there were interferences during
391 /// merging.
392 std::pair<UserangeAnalysis::IntervalVector, bool>
intervalMerge(const IntervalVector & intervalA,const IntervalVector & intervalB) const393 UserangeAnalysis::intervalMerge(const IntervalVector &intervalA,
394                                 const IntervalVector &intervalB) const {
395   IntervalVector mergeResult;
396 
397   bool interference = false;
398   auto iterA = intervalA.begin();
399   auto iterB = intervalB.begin();
400   auto endA = intervalA.end();
401   auto endB = intervalB.end();
402   UseInterval current;
403   while (iterA != endA || iterB != endB) {
404     if (iterA == endA) {
405       // Only intervals from intervalB are left.
406       current = *iterB;
407       ++iterB;
408     } else if (iterB == endB) {
409       // Only intervals from intervalA are left.
410       current = *iterA;
411       ++iterA;
412     } else if (iterA->second < iterB->first) {
413       // A is strict before B: A(0,2), B(4,6)
414       current = *iterA;
415       ++iterA;
416     } else if (iterB->second < iterA->first) {
417       // B is strict before A: A(6,8), B(2,4)
418       current = *iterB;
419       ++iterB;
420     } else {
421       // A and B interfere.
422       interference = true;
423       current = UseInterval(std::min(iterA->first, iterB->first),
424                             std::max(iterA->second, iterB->second));
425       ++iterA;
426       ++iterB;
427     }
428     // Merge current with last element in mergeResult, if the intervals are
429     // consecutive and there is no gap.
430     if (mergeResult.empty()) {
431       mergeResult.emplace_back(current);
432       continue;
433     }
434     UseInterval *mergeResultLast = (mergeResult.end() - 1);
435     int diff = current.first - mergeResultLast->second;
436     if (diff <= 2 && mergeResultLast->second < current.second)
437       mergeResultLast->second = current.second;
438     else if (diff > 2)
439       mergeResult.emplace_back(current);
440   }
441 
442   return std::make_pair(mergeResult, interference);
443 }
444 
445 /// Performs an interval union of the interval vectors from the given values.
446 /// Returns an empty Optional if there is an interval interference.
intervalUnion(Value itemA,Value itemB) const447 bool UserangeAnalysis::intervalUnion(Value itemA, Value itemB) const {
448   ValueSetT intersect = aliasCache.find(itemA)->second;
449   llvm::set_intersect(intersect, aliasCache.find(itemB)->second);
450   IntervalVector tmpIntervalA = useIntervalMap.find(itemA)->second;
451 
452   // If the two values share a common alias, then the alias does not count as
453   // interference and should be removed.
454   if (!intersect.empty()) {
455     for (Value alias : intersect) {
456       IntervalVector aliasInterval = useIntervalMap.find(alias)->second;
457       intervalSubtract(tmpIntervalA, aliasInterval);
458     }
459   }
460 
461   return std::get<1>(
462       intervalMerge(tmpIntervalA, useIntervalMap.find(itemB)->second));
463 }
464 
465 /// Performs an interval subtraction => A = A - B.
466 /// Note: This assumes that all intervals of b are included in some interval
467 ///       of a.
intervalSubtract(IntervalVector & a,const IntervalVector & b) const468 void UserangeAnalysis::intervalSubtract(IntervalVector &a,
469                                         const IntervalVector &b) const {
470   auto iterB = b.begin();
471   auto endB = b.end();
472   for (auto iterA = a.begin(), endA = a.end();
473        iterA != endA && iterB != endB;) {
474     // iterA is strictly before iterB => increment iterA.
475     if (iterA->second < iterB->first) {
476       ++iterA;
477     } else if (iterA->first == iterB->first && iterA->second > iterB->second) {
478       // Usually, we would expect the case of iterB beeing strictly before
479       // iterA. However, due to the initial assumption that all intervals of b
480       // are included in some interval of a, we do not need to check if iterB is
481       // strictly before iterA.
482       // iterB is at the start of iterA, but iterA has some values that go
483       // beyond those of iterB. We have to set the lower bound of iterA to the
484       // upper bound of iterB + 1 and increment iterB.
485       // A(3, 100) - B(3, 5) => A(6,100)
486       iterA->first = iterB->second + 1;
487       ++iterB;
488     } else if (iterA->second == iterB->second && iterA->first < iterB->first) {
489       // iterB is at the end of iterA, but iterA has some values that come
490       // before iterB. We have to set the end of iterA to the start of iterB - 1
491       // and increment both iterators.
492       // A(4, 50) - B(40, 50) => A(4, 39)
493       iterA->second = iterB->first - 1;
494       ++iterA;
495       ++iterB;
496     } else if (iterA->first < iterB->first && iterA->second > iterB->second) {
497       // iterB is in the middle of iterA. We have to split iterA and increment
498       // iterB.
499       // A(2, 10) - B(5, 7) => (2, 4), (8, 10)
500       size_t endA = iterA->second;
501       iterA->second = iterB->first - 1;
502       iterA = a.insert(iterA, UseInterval(iterB->second + 1, endA));
503       ++iterB;
504     } else {
505       // Both intervals are equal. We have to erase the whole interval.
506       // A(5, 5) - B(5, 5) => {}
507       iterA = a.erase(iterA);
508       ++iterB;
509     }
510   }
511 }
512 
dump(raw_ostream & os)513 void UserangeAnalysis::dump(raw_ostream &os) {
514   os << "// ---- UserangeAnalysis -----\n";
515   std::vector<Value> values;
516   for (auto const &item : useIntervalMap) {
517     values.emplace_back(item.first);
518   }
519   std::sort(values.begin(), values.end(), [&](Value left, Value right) {
520     if (left.getDefiningOp()) {
521       if (right.getDefiningOp())
522         return operationIds[left.getDefiningOp()] <
523                operationIds[right.getDefiningOp()];
524       else
525         return true;
526     }
527     if (right.getDefiningOp()) return false;
528     return operationIds[&left.getParentBlock()->front()] <
529            operationIds[&right.getParentBlock()->front()];
530   });
531   for (auto value : values) {
532     os << "Value: " << value << (value.getDefiningOp() ? "\n" : "");
533     auto rangeIt = useIntervalMap[value].begin();
534     os << "Userange: {(" << rangeIt->first << ", " << rangeIt->second << ")";
535     rangeIt++;
536     for (auto e = useIntervalMap[value].end(); rangeIt != e; ++rangeIt) {
537       os << ", (" << rangeIt->first << ", " << rangeIt->second << ")";
538     }
539     os << "}\n";
540   }
541   os << "// ---------------------------\n";
542 }
543