1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Analysis/userange_analysis.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "llvm/ADT/SetOperations.h"
22 #include "mlir/IR/Block.h"
23 #include "mlir/IR/Region.h"
24 #include "mlir/Interfaces/LoopLikeInterface.h"
25
26 using namespace mlir;
27
28 namespace {
29 /// Builds a userange information from the given value and its liveness. The
30 /// information includes all operations that are within the userange.
31 struct UserangeInfoBuilder {
32 using OperationListT = Liveness::OperationListT;
33 using ValueSetT = BufferViewFlowAnalysis::ValueSetT;
34
35 public:
36 /// Constructs an Userange builder.
UserangeInfoBuilder__anon77d1cd940111::UserangeInfoBuilder37 UserangeInfoBuilder(Liveness pLiveness, ValueSetT pValues,
38 OperationListT pOpList)
39 : values(std::move(pValues)),
40 opList(std::move(pOpList)),
41 liveness(std::move(pLiveness)) {}
42
43 /// Computes the userange of the current value by iterating over all of its
44 /// uses.
computeUserange__anon77d1cd940111::UserangeInfoBuilder45 Liveness::OperationListT computeUserange() {
46 Region *topRegion = findTopRegion();
47 // Iterate over all associated uses.
48 for (Operation *use : opList) {
49 // If one of the parents implements a LoopLikeOpInterface we need to add
50 // all operations inside of its regions to the userange.
51 Operation *loopParent = use->getParentOfType<LoopLikeOpInterface>();
52 if (loopParent && topRegion->isProperAncestor(use->getParentRegion()))
53 addAllOperationsInRegion(loopParent);
54
55 // Check if the parent block has already been processed.
56 Block *useBlock = findTopLiveBlock(use);
57 if (!startBlocks.insert(useBlock).second || visited.contains(useBlock))
58 continue;
59
60 // Add all operations inside the block that are within the userange.
61 findOperationsInUse(useBlock);
62 }
63 return currentUserange;
64 }
65
66 private:
67 /// Find the top most Region of all values stored in the values set.
findTopRegion__anon77d1cd940111::UserangeInfoBuilder68 Region *findTopRegion() const {
69 Region *topRegion = nullptr;
70 llvm::for_each(values, [&](Value v) {
71 Region *other = v.getParentRegion();
72 if (!topRegion || topRegion->isAncestor(other)) topRegion = other;
73 });
74 return topRegion;
75 }
76
77 /// Finds the highest level block that has the current value in its liveOut
78 /// set.
findTopLiveBlock__anon77d1cd940111::UserangeInfoBuilder79 Block *findTopLiveBlock(Operation *op) const {
80 Operation *topOp = op;
81 while (const LivenessBlockInfo *blockInfo =
82 liveness.getLiveness(op->getBlock())) {
83 if (llvm::any_of(values,
84 [&](Value v) { return blockInfo->isLiveOut(v); }))
85 topOp = op;
86 op = op->getParentOp();
87 }
88 return topOp->getBlock();
89 }
90
91 /// Adds all operations from start to end to the userange of the current
92 /// value. If an operation implements a nested region all operations inside of
93 /// it are included as well. If includeEnd is false the end operation is not
94 /// added.
addAllOperationsBetween__anon77d1cd940111::UserangeInfoBuilder95 void addAllOperationsBetween(Operation *start, Operation *end) {
96 currentUserange.push_back(start);
97 addAllOperationsInRegion(start);
98
99 while (start != end) {
100 start = start->getNextNode();
101 addAllOperationsInRegion(start);
102 currentUserange.push_back(start);
103 }
104 }
105
106 /// Adds all operations that are uses of the value in the given block to the
107 /// userange of the current value. Additionally iterate over all successors
108 /// where the value is live.
findOperationsInUse__anon77d1cd940111::UserangeInfoBuilder109 void findOperationsInUse(Block *block) {
110 SmallVector<Block *, 8> blocksToProcess;
111 addOperationsInBlockAndFindSuccessors(
112 block, block, getStartOperation(block), blocksToProcess);
113 while (!blocksToProcess.empty()) {
114 Block *toProcess = blocksToProcess.pop_back_val();
115 addOperationsInBlockAndFindSuccessors(
116 block, toProcess, &toProcess->front(), blocksToProcess);
117 }
118 }
119
120 /// Adds the operations between the given start operation and the computed end
121 /// operation to the userange. If the current value is live out, add all
122 /// successor blocks that have the value live in to the process queue. If we
123 /// find a loop, add the operations before the first use in block to the
124 /// userange (if any). The startBlock is the block where the iteration over
125 /// all successors started and is propagated further to find potential loops.
addOperationsInBlockAndFindSuccessors__anon77d1cd940111::UserangeInfoBuilder126 void addOperationsInBlockAndFindSuccessors(
127 const Block *startBlock, Block *toProcess, Operation *start,
128 SmallVector<Block *, 8> &blocksToProcess) {
129 const LivenessBlockInfo *blockInfo = liveness.getLiveness(toProcess);
130 Operation *end = getEndOperation(toProcess);
131
132 addAllOperationsBetween(start, end);
133
134 // If the value is live out we need to process all successors at which the
135 // value is live in.
136 if (!llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); }))
137 return;
138 for (Block *successor : toProcess->getSuccessors()) {
139 // If the successor is the startBlock, we found a loop and only have to
140 // add the operations from the block front to the first use of the
141 // value.
142 if (!llvm::any_of(values, [&](Value v) {
143 return liveness.getLiveness(successor)->isLiveIn(v);
144 }))
145 continue;
146 if (successor == startBlock) {
147 start = &successor->front();
148 end = getStartOperation(successor);
149 if (start != end) addAllOperationsBetween(start, end->getPrevNode());
150 // Else we need to check if the value is live in and the successor
151 // has not been visited before. If so we also need to process it.
152 } else if (visited.insert(successor).second) {
153 blocksToProcess.emplace_back(successor);
154 }
155 }
156 }
157
158 /// Iterates over all regions of a given operation and adds all operations
159 /// inside those regions to the userange of the current value.
addAllOperationsInRegion__anon77d1cd940111::UserangeInfoBuilder160 void addAllOperationsInRegion(Operation *parentOp) {
161 // Iterate over all regions of the parentOp.
162 for (Region ®ion : parentOp->getRegions()) {
163 // Iterate over blocks inside the region.
164 for (Block &block : region) {
165 // If the blocks have been used as a startBlock before, we need to add
166 // all operations between the block front and the startOp of the value.
167 if (startBlocks.contains(&block)) {
168 Operation *start = &block.front();
169 Operation *end = getStartOperation(&block);
170 if (start != end) addAllOperationsBetween(start, end->getPrevNode());
171
172 // If the block has never been seen before, we need to add all
173 // operations inside.
174 } else if (visited.insert(&block).second) {
175 for (Operation &op : block) {
176 addAllOperationsInRegion(&op);
177 currentUserange.emplace_back(&op);
178 }
179 continue;
180 }
181 // If the block has either been visited before or was used as a
182 // startBlock, we need to add all operations between the endOp of the
183 // value and the end of the block.
184 Operation *end = getEndOperation(&block);
185 if (end == &block.back()) continue;
186 addAllOperationsBetween(end->getNextNode(), &block.back());
187 }
188 }
189 }
190
191 /// Find the start operation of the current value inside the given block.
getStartOperation__anon77d1cd940111::UserangeInfoBuilder192 Operation *getStartOperation(Block *block) {
193 Operation *startOperation = &block->back();
194 for (Operation *useOp : opList) {
195 // Find the associated operation in the current block (if any).
196 useOp = block->findAncestorOpInBlock(*useOp);
197 // Check whether the use is in our block and after the current end
198 // operation.
199 if (useOp && useOp->isBeforeInBlock(startOperation))
200 startOperation = useOp;
201 }
202 return startOperation;
203 }
204
205 /// Find the end operation of the current value inside the given block.
getEndOperation__anon77d1cd940111::UserangeInfoBuilder206 Operation *getEndOperation(Block *block) {
207 const LivenessBlockInfo *blockInfo = liveness.getLiveness(block);
208 if (llvm::any_of(values, [&](Value v) { return blockInfo->isLiveOut(v); }))
209 return &block->back();
210
211 Operation *endOperation = &block->front();
212 for (Operation *useOp : opList) {
213 // Find the associated operation in the current block (if any).
214 useOp = block->findAncestorOpInBlock(*useOp);
215 // Check whether the use is in our block and after the current end
216 // operation.
217 if (useOp && endOperation->isBeforeInBlock(useOp)) endOperation = useOp;
218 }
219 return endOperation;
220 }
221
222 /// The current Value.
223 ValueSetT values;
224
225 /// The list of all operations used by the values.
226 OperationListT opList;
227
228 /// The result list of the userange computation.
229 OperationListT currentUserange;
230
231 /// The set of visited blocks during the userange computation.
232 SmallPtrSet<Block *, 32> visited;
233
234 /// The set of blocks that the userange computation started from.
235 SmallPtrSet<Block *, 8> startBlocks;
236
237 /// The current liveness info.
238 Liveness liveness;
239 };
240 } // namespace
241
UserangeAnalysis(Operation * op,const BufferPlacementAllocs & allocs,const BufferViewFlowAnalysis & aliases)242 UserangeAnalysis::UserangeAnalysis(Operation *op,
243 const BufferPlacementAllocs &allocs,
244 const BufferViewFlowAnalysis &aliases)
245 : liveness(op) {
246 // Walk over all operations and map them to an ID.
247 op->walk([&](Operation *operation) {
248 gatherMemoryEffects(operation);
249 operationIds.insert({operation, operationIds.size()});
250 });
251
252 // Compute the use range for every allocValue and its aliases. Merge them
253 // and compute an interval. Add all computed intervals to the useIntervalMap.
254 for (const BufferPlacementAllocs::AllocEntry &entry : allocs) {
255 Value allocValue = std::get<0>(entry);
256 OperationListT useList;
257 for (auto &use : allocValue.getUses()) useList.emplace_back(use.getOwner());
258 UserangeInfoBuilder builder(liveness, {allocValue}, useList);
259 OperationListT liveOperations = builder.computeUserange();
260
261 // Sort the operation list by ids.
262 std::sort(liveOperations.begin(), liveOperations.end(),
263 [&](Operation *left, Operation *right) {
264 return operationIds[left] < operationIds[right];
265 });
266
267 IntervalVector allocInterval = computeInterval(allocValue, liveOperations);
268 // Iterate over all aliases and add their useranges to the userange of the
269 // current value. Also add the useInterval of each alias to the
270 // useIntervalMap.
271 ValueSetT aliasSet = aliases.resolve(allocValue);
272 for (Value alias : aliasSet) {
273 if (alias == allocValue) continue;
274 if (!aliasUseranges.count(alias)) {
275 OperationListT aliasOperations;
276 // If the alias is a BlockArgument then the value is live with the first
277 // operation inside that block. Otherwise the liveness analysis is
278 // sufficient for the use range.
279 if (alias.isa<BlockArgument>()) {
280 aliasOperations.emplace_back(&alias.getParentBlock()->front());
281 for (auto &use : alias.getUses())
282 aliasOperations.emplace_back(use.getOwner());
283 // Compute the use range for the alias and sort the operations
284 // afterwards.
285 UserangeInfoBuilder aliasBuilder(liveness, {alias}, aliasOperations);
286 aliasOperations = aliasBuilder.computeUserange();
287 std::sort(aliasOperations.begin(), aliasOperations.end(),
288 [&](Operation *left, Operation *right) {
289 return operationIds[left] < operationIds[right];
290 });
291 } else {
292 aliasOperations = liveness.resolveLiveness(alias);
293 }
294
295 aliasUseranges.insert({alias, aliasOperations});
296 useIntervalMap.insert(
297 {alias, computeInterval(alias, aliasUseranges[alias])});
298 }
299 allocInterval =
300 std::get<0>(intervalMerge(allocInterval, useIntervalMap[alias]));
301 }
302 aliasCache.insert(std::make_pair(allocValue, aliasSet));
303
304 // Map the current allocValue to the computed useInterval.
305 useIntervalMap.insert(std::make_pair(allocValue, allocInterval));
306 }
307 }
308
309 /// Checks if the use intervals of the given values interfere.
rangesInterfere(Value itemA,Value itemB) const310 bool UserangeAnalysis::rangesInterfere(Value itemA, Value itemB) const {
311 return intervalUnion(itemA, itemB);
312 }
313
314 /// Merges the userange of itemB into the userange of itemA.
315 /// Note: This assumes that there is no interference between the two
316 /// ranges.
unionRanges(Value itemA,Value itemB)317 void UserangeAnalysis::unionRanges(Value itemA, Value itemB) {
318 IntervalVector unionInterval =
319 std::get<0>(intervalMerge(useIntervalMap[itemA], useIntervalMap[itemB]));
320
321 llvm::set_union(aliasCache[itemA], aliasCache[itemB]);
322 for (Value alias : aliasCache[itemA])
323 unionInterval =
324 std::get<0>(intervalMerge(unionInterval, useIntervalMap[alias]));
325
326 // Compute new interval.
327 useIntervalMap[itemA] = unionInterval;
328 }
329
330 /// Builds an IntervalVector corresponding to the given OperationList.
computeInterval(Value value,const Liveness::OperationListT & operationList)331 UserangeAnalysis::IntervalVector UserangeAnalysis::computeInterval(
332 Value value, const Liveness::OperationListT &operationList) {
333 assert(!operationList.empty() && "Operation list must not be empty");
334 size_t start = computeID(value, *operationList.begin());
335 size_t last = start;
336 UserangeAnalysis::IntervalVector intervals;
337 // Iterate over all operations in the operationList. If the gap between the
338 // respective operationIds is greater 1 create a new interval.
339 for (auto opIter = ++operationList.begin(), e = operationList.end();
340 opIter != e; ++opIter) {
341 size_t current = computeID(value, *opIter);
342 if (current - last > 2) {
343 intervals.emplace_back(UserangeAnalysis::UseInterval(start, last));
344 start = current;
345 }
346 last = current;
347 }
348 intervals.emplace_back(UserangeAnalysis::UseInterval(start, last));
349 return intervals;
350 }
351
352 /// Checks each operand inside the operation for its memory effects and
353 /// separates them into read and write. Operands with read effects are added to
354 /// the opToReadMap.
gatherMemoryEffects(Operation * op)355 void UserangeAnalysis::gatherMemoryEffects(Operation *op) {
356 if (OpTrait::hasElementwiseMappableTraits(op)) {
357 if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
358 SmallPtrSet<Value, 2> readEffectSet;
359 SmallPtrSet<Value, 2> writeEffectSet;
360 for (auto operand : op->getOperands()) {
361 SmallVector<MemoryEffects::EffectInstance, 2> effects;
362 effectInterface.getEffectsOnValue(operand, effects);
363 for (auto effect : effects) {
364 if (isa<MemoryEffects::Write>(effect.getEffect()))
365 writeEffectSet.insert(operand);
366 else if (isa<MemoryEffects::Read>(effect.getEffect()))
367 readEffectSet.insert(operand);
368 }
369 }
370 opReadWriteMap.insert(
371 {op, std::make_pair(readEffectSet, writeEffectSet)});
372 }
373 }
374 }
375
376 /// Computes the ID for the operation. If the operation contains operands which
377 /// have read effects, the returning ID will be odd. This allows us to
378 /// perform a replace in place.
computeID(Value v,Operation * op) const379 size_t UserangeAnalysis::computeID(Value v, Operation *op) const {
380 size_t doubledID = operationIds.find(op)->second * 2;
381 auto mapIter = opReadWriteMap.find(op);
382 if (mapIter == opReadWriteMap.end()) return doubledID;
383 auto reads = mapIter->second.first;
384 auto writes = mapIter->second.second;
385 if (reads.contains(v) && !writes.contains(v)) return doubledID - 1;
386 return doubledID;
387 }
388
389 /// Merge two IntervalVectors into a new IntervalVector. Return a pair with the
390 /// resulting IntervalVector and a boolean if there were interferences during
391 /// merging.
392 std::pair<UserangeAnalysis::IntervalVector, bool>
intervalMerge(const IntervalVector & intervalA,const IntervalVector & intervalB) const393 UserangeAnalysis::intervalMerge(const IntervalVector &intervalA,
394 const IntervalVector &intervalB) const {
395 IntervalVector mergeResult;
396
397 bool interference = false;
398 auto iterA = intervalA.begin();
399 auto iterB = intervalB.begin();
400 auto endA = intervalA.end();
401 auto endB = intervalB.end();
402 UseInterval current;
403 while (iterA != endA || iterB != endB) {
404 if (iterA == endA) {
405 // Only intervals from intervalB are left.
406 current = *iterB;
407 ++iterB;
408 } else if (iterB == endB) {
409 // Only intervals from intervalA are left.
410 current = *iterA;
411 ++iterA;
412 } else if (iterA->second < iterB->first) {
413 // A is strict before B: A(0,2), B(4,6)
414 current = *iterA;
415 ++iterA;
416 } else if (iterB->second < iterA->first) {
417 // B is strict before A: A(6,8), B(2,4)
418 current = *iterB;
419 ++iterB;
420 } else {
421 // A and B interfere.
422 interference = true;
423 current = UseInterval(std::min(iterA->first, iterB->first),
424 std::max(iterA->second, iterB->second));
425 ++iterA;
426 ++iterB;
427 }
428 // Merge current with last element in mergeResult, if the intervals are
429 // consecutive and there is no gap.
430 if (mergeResult.empty()) {
431 mergeResult.emplace_back(current);
432 continue;
433 }
434 UseInterval *mergeResultLast = (mergeResult.end() - 1);
435 int diff = current.first - mergeResultLast->second;
436 if (diff <= 2 && mergeResultLast->second < current.second)
437 mergeResultLast->second = current.second;
438 else if (diff > 2)
439 mergeResult.emplace_back(current);
440 }
441
442 return std::make_pair(mergeResult, interference);
443 }
444
445 /// Performs an interval union of the interval vectors from the given values.
446 /// Returns an empty Optional if there is an interval interference.
intervalUnion(Value itemA,Value itemB) const447 bool UserangeAnalysis::intervalUnion(Value itemA, Value itemB) const {
448 ValueSetT intersect = aliasCache.find(itemA)->second;
449 llvm::set_intersect(intersect, aliasCache.find(itemB)->second);
450 IntervalVector tmpIntervalA = useIntervalMap.find(itemA)->second;
451
452 // If the two values share a common alias, then the alias does not count as
453 // interference and should be removed.
454 if (!intersect.empty()) {
455 for (Value alias : intersect) {
456 IntervalVector aliasInterval = useIntervalMap.find(alias)->second;
457 intervalSubtract(tmpIntervalA, aliasInterval);
458 }
459 }
460
461 return std::get<1>(
462 intervalMerge(tmpIntervalA, useIntervalMap.find(itemB)->second));
463 }
464
465 /// Performs an interval subtraction => A = A - B.
466 /// Note: This assumes that all intervals of b are included in some interval
467 /// of a.
intervalSubtract(IntervalVector & a,const IntervalVector & b) const468 void UserangeAnalysis::intervalSubtract(IntervalVector &a,
469 const IntervalVector &b) const {
470 auto iterB = b.begin();
471 auto endB = b.end();
472 for (auto iterA = a.begin(), endA = a.end();
473 iterA != endA && iterB != endB;) {
474 // iterA is strictly before iterB => increment iterA.
475 if (iterA->second < iterB->first) {
476 ++iterA;
477 } else if (iterA->first == iterB->first && iterA->second > iterB->second) {
478 // Usually, we would expect the case of iterB beeing strictly before
479 // iterA. However, due to the initial assumption that all intervals of b
480 // are included in some interval of a, we do not need to check if iterB is
481 // strictly before iterA.
482 // iterB is at the start of iterA, but iterA has some values that go
483 // beyond those of iterB. We have to set the lower bound of iterA to the
484 // upper bound of iterB + 1 and increment iterB.
485 // A(3, 100) - B(3, 5) => A(6,100)
486 iterA->first = iterB->second + 1;
487 ++iterB;
488 } else if (iterA->second == iterB->second && iterA->first < iterB->first) {
489 // iterB is at the end of iterA, but iterA has some values that come
490 // before iterB. We have to set the end of iterA to the start of iterB - 1
491 // and increment both iterators.
492 // A(4, 50) - B(40, 50) => A(4, 39)
493 iterA->second = iterB->first - 1;
494 ++iterA;
495 ++iterB;
496 } else if (iterA->first < iterB->first && iterA->second > iterB->second) {
497 // iterB is in the middle of iterA. We have to split iterA and increment
498 // iterB.
499 // A(2, 10) - B(5, 7) => (2, 4), (8, 10)
500 size_t endA = iterA->second;
501 iterA->second = iterB->first - 1;
502 iterA = a.insert(iterA, UseInterval(iterB->second + 1, endA));
503 ++iterB;
504 } else {
505 // Both intervals are equal. We have to erase the whole interval.
506 // A(5, 5) - B(5, 5) => {}
507 iterA = a.erase(iterA);
508 ++iterB;
509 }
510 }
511 }
512
dump(raw_ostream & os)513 void UserangeAnalysis::dump(raw_ostream &os) {
514 os << "// ---- UserangeAnalysis -----\n";
515 std::vector<Value> values;
516 for (auto const &item : useIntervalMap) {
517 values.emplace_back(item.first);
518 }
519 std::sort(values.begin(), values.end(), [&](Value left, Value right) {
520 if (left.getDefiningOp()) {
521 if (right.getDefiningOp())
522 return operationIds[left.getDefiningOp()] <
523 operationIds[right.getDefiningOp()];
524 else
525 return true;
526 }
527 if (right.getDefiningOp()) return false;
528 return operationIds[&left.getParentBlock()->front()] <
529 operationIds[&right.getParentBlock()->front()];
530 });
531 for (auto value : values) {
532 os << "Value: " << value << (value.getDefiningOp() ? "\n" : "");
533 auto rangeIt = useIntervalMap[value].begin();
534 os << "Userange: {(" << rangeIt->first << ", " << rangeIt->second << ")";
535 rangeIt++;
536 for (auto e = useIntervalMap[value].end(); rangeIt != e; ++rangeIt) {
537 os << ", (" << rangeIt->first << ", " << rangeIt->second << ")";
538 }
539 os << "}\n";
540 }
541 os << "// ---------------------------\n";
542 }
543