xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/buffer_packing.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <list>
17 
18 #include "mlir-hlo/Analysis/userange_analysis.h"
19 #include "mlir-hlo/Transforms/PassDetail.h"
20 #include "mlir-hlo/Transforms/passes.h"
21 #include "mlir-hlo/utils/hlo_utils.h"
22 #include "mlir/Analysis/BufferViewFlowAnalysis.h"
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
24 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/Pass/Pass.h"
29 
30 namespace mlir {
31 
32 namespace {
33 
34 /// Returns the length of an userange interval.
computeUserangeSize(const UseInterval & interval)35 size_t computeUserangeSize(const UseInterval &interval) {
36   return interval.end - interval.start + 1;
37 }
38 
39 /// Compute the byte size of a given Value.
computeByteSize(const Value & v)40 size_t computeByteSize(const Value &v) {
41   auto type = v.getType().cast<ShapedType>();
42   return type.getSizeInBits() / 8;
43 }
44 
45 /// Compute the 64 byte alinged segments of a given Value.
computeAlignedSegments(const Value & v)46 size_t computeAlignedSegments(const Value &v) {
47   size_t padding = 64;
48   size_t bytes = computeByteSize(v);
49   return std::ceil(bytes / (double)padding);
50 }
51 
/// The buffer offset information.
/// Associates one original allocation with the byte offset at which it is
/// placed inside a packed buffer.
struct AllocBufferOffset {
 public:
  AllocBufferOffset(Value source, size_t offset)
      : source(source), offset(offset) {}

  /// The original allocation value that will be replaced by a view into the
  /// packed buffer.
  Value source;

  /// The byte offset of this allocation inside the packed buffer.
  size_t offset;
};
61 
/// Contains the information to create a new buffer, that is used to pack
/// other buffers.
struct PackedBuffer {
 public:
  PackedBuffer(size_t numSegments,
               std::vector<AllocBufferOffset> &packedBuffers)
      : numSegments(numSegments), allocBufferOffsets(packedBuffers) {}

  // NOTE(review): despite the name, the packing strategy passes a byte count
  // here (segments * 64), and BufferPacking uses it directly as the size of
  // the i8 memref it allocates — confirm before renaming or reusing.
  size_t numSegments;

  /// The buffers (with their byte offsets) placed into this packed buffer.
  std::vector<AllocBufferOffset> allocBufferOffsets;
};
73 
/// Contains the information about a buffers allocation for sorting and checking
/// if it fits into other buffers and vise versa.
/// This structure contains the allocation value, the first and last userangeid
/// of a buffer, the window id, the number of alligned 64 byte segments and all
/// userange intervals.
struct AllocationInfo {
 public:
  AllocationInfo(Value alloc, size_t allocUserangeId, size_t firstUse,
                 size_t lastUse, size_t numSegments, size_t windowId,
                 const UseInterval::Vector *userangeIntervals)
      : alloc(alloc),
        allocUserangeId(allocUserangeId),
        firstUse(firstUse),
        lastUse(lastUse),
        numSegments(numSegments),
        windowId(windowId),
        userangeIntervals(userangeIntervals) {}

  /// The allocation value.
  Value alloc;

  /// The id of allocation based on the Userange Analysis.
  size_t allocUserangeId;

  /// The first use of the buffer.
  size_t firstUse;

  /// The last use of the buffer based on the Userange Analysis.
  size_t lastUse;

  /// The number of 64 byte aligned segments of contigous memory.
  size_t numSegments;

  /// The window id of the allocation position.
  size_t windowId;

  /// The userange intervals of the buffer (not owned; assumed sorted by
  /// interval start, as produced by the userange analysis).
  const UseInterval::Vector *userangeIntervals;

  /// Compute the gaps of the alloc userange with the number of segments. The
  /// maxUserangeId is used to add a dummy gap from the last used id to the
  /// maxUserangeId. By default the maxUserangeId is zero and no gap is added.
  /// Each gap is paired with this allocation's full segment count, i.e. the
  /// free capacity initially available inside that gap.
  std::list<std::pair<UseInterval, size_t>> computeGaps(
      size_t maxUserangeId = 0) {
    std::list<std::pair<UseInterval, size_t>> gaps;

    // The previous gap ending, initially set to 0.
    size_t gapEnd = 0;

    // The space between consecutive use intervals becomes a gap.
    for (const auto *useRangeIter = userangeIntervals->begin();
         useRangeIter < userangeIntervals->end(); ++useRangeIter) {
      // Add a gap if the end is not equal to the start.
      if (gapEnd < useRangeIter->start)
        gaps.emplace_back(UseInterval(gapEnd, useRangeIter->start - 1),
                          numSegments);
      gapEnd = useRangeIter->end + 1;
    }

    // Add a dummy gap behind the last use of the buffer.
    if (gapEnd < maxUserangeId) {
      gaps.emplace_back(UseInterval(gapEnd, maxUserangeId), numSegments);
    }

    return gaps;
  }

  /// Compute the userange size (inclusive distance from first to last use).
  size_t getUserangeSize() const { return lastUse - firstUse + 1; }
};
143 
144 // Comparator to sort allocation informations by window id, userange and by
145 // number of memory segments.
146 class AllocInfoWinIdComparator {
147  public:
operator ()(const AllocationInfo & a,const AllocationInfo & b)148   bool operator()(const AllocationInfo &a, const AllocationInfo &b) {
149     if (a.windowId == b.windowId) {
150       if (a.allocUserangeId == b.allocUserangeId)
151         return a.numSegments > b.numSegments;
152       return a.allocUserangeId > b.allocUserangeId;
153     }
154     return a.windowId < b.windowId;
155   }
156 };
157 
158 // Comparator to sort the allocation informations by number of segments.
159 class AllocInfoMemSizeCompare {
160  public:
operator ()(const AllocationInfo & a,const AllocationInfo & b)161   bool operator()(const AllocationInfo &a, const AllocationInfo &b) {
162     return a.numSegments > b.numSegments;
163   }
164 };
165 
/// This approach computes an allocation information list and sorts it by
/// a given comparator. From top to bottom the algortihm tries to fill userange
/// gaps with appropriate buffers behind it, to optimze the memory. It is a bin
/// packing approach.
template <typename CompareT>
class SortedPackingStrategy {
 public:
  using AllocInfoList = std::vector<AllocationInfo>;

 public:
  /// Constructs the Sorted Packing Strategy. The window size is used as sliding
  /// window size. Allocation userangepositions that are in the same range are
  /// mapped to the same window id. So the information of the allocation
  /// starting position is blured.
  SortedPackingStrategy(size_t windowSize, CompareT compare)
      : windowSize(windowSize), compare(compare) {}

  /// Optimize the buffer allocations.
  /// For each buffer (in comparator order) the userange gaps are computed and
  /// any later buffer that fits into a gap is packed into it and removed from
  /// the candidate list. One PackedBuffer is appended to packedBuffers per
  /// surviving candidate (even if nothing else was packed into it).
  void optimze(const mlir::bufferization::BufferPlacementAllocs &allocs,
               const UserangeAnalysis &userangeAnalysis,
               std::vector<PackedBuffer> &packedBuffers) {
    AllocInfoList allocInfos;
    allocInfos.reserve(std::distance(allocs.begin(), allocs.end()));

    // Create allocInformations and store them in allocInfos.
    size_t maxUserangeId =
        computeAllocationInfos(allocInfos, userangeAnalysis, allocs);

    // Sort the allocation infos.
    std::sort(allocInfos.begin(), allocInfos.end(), compare);

    for (auto currentIter = allocInfos.begin(); currentIter != allocInfos.end();
         ++currentIter) {
      // The current buffer always occupies offset 0 of its own packed buffer.
      std::vector<AllocBufferOffset> allocBufferOffsets{
          AllocBufferOffset(currentIter->alloc, 0)};

      // Compute userange gaps.
      std::list<std::pair<UseInterval, size_t>> gaps =
          currentIter->computeGaps(maxUserangeId);

      if (gaps.empty()) continue;

      for (auto checkedAllocInfoIter = std::next(currentIter);
           checkedAllocInfoIter != allocInfos.end();) {
        // Check if a gap exists to pack the memory into.
        // If not continue.
        if (!findGapAndUpdate(gaps, allocBufferOffsets, *checkedAllocInfoIter,
                              *currentIter)) {
          ++checkedAllocInfoIter;
          continue;
        }
        // A packed candidate is consumed and must not be packed twice.
        checkedAllocInfoIter = allocInfos.erase(checkedAllocInfoIter);
      }
      // Add the current buffer offets to the packed infos.
      // NOTE(review): numSegments * 64 is a byte count; it lands in
      // PackedBuffer::numSegments and is used as the i8 memref size.
      packedBuffers.emplace_back(currentIter->numSegments * 64,
                                 allocBufferOffsets);
    }
  }

 private:
  const size_t windowSize;
  const CompareT compare;

  /// We try to find an appropriate userange gap to pack the buffer into it.
  /// If we find one we update only the gaps and the buffer offset map.
  /// Returns true iff allocToPack was placed into a gap of allocToPackInto.
  bool findGapAndUpdate(std::list<std::pair<UseInterval, size_t>> &gaps,
                        std::vector<AllocBufferOffset> &allocBufferOffsets,
                        const AllocationInfo &allocToPack,
                        const AllocationInfo &allocToPackInto) {
    // Check if the buffer to pack into has enough memory.
    if (allocToPackInto.numSegments < allocToPack.numSegments) return false;
    for (auto gapIter = gaps.begin(); gapIter != gaps.end();) {
      // The list is sorted, so we can break here.
      if (gapIter->first.start > allocToPack.firstUse) break;

      // Checks if enough contiguous memory segments are free or if the current
      // gap is out of bounds.
      if (gapIter->second < allocToPack.numSegments ||
          allocToPack.firstUse < gapIter->first.start ||
          allocToPack.lastUse > gapIter->first.end) {
        ++gapIter;
        continue;
      }

      // Stores the packed buffer with the offset.
      // The offset is counted from the top of the packed buffer: segments
      // already consumed in this gap times the 64-byte segment size.
      allocBufferOffsets.emplace_back(
          allocToPack.alloc,
          (allocToPackInto.numSegments - gapIter->second) * 64);

      // Update gap segments, will removed later if no free contigous memory
      // exists. It is needed to split the interval, if not the full gap is
      // used.
      size_t freeContiguousMemory = gapIter->second;
      gapIter->second = freeContiguousMemory - allocToPack.numSegments;

      // Check if the gap must be splitted. If so, then the current gap must be
      // trimmed accordingly. Therefore, new gaps are created in front and after
      // the current gap.
      if (computeUserangeSize(gapIter->first) > allocToPack.getUserangeSize()) {
        size_t oldStart = gapIter->first.start;
        size_t oldEnd = gapIter->first.end;
        gapIter->first.end = allocToPack.lastUse;
        gapIter->first.start = allocToPack.firstUse;

        // Insert a new gap behind. It keeps the capacity the gap had before
        // this placement, since the placed buffer does not live there.
        if (allocToPack.lastUse < oldEnd)
          gaps.insert(
              std::next(gapIter),
              std::make_pair(UseInterval(allocToPack.lastUse + 1, oldEnd),
                             freeContiguousMemory));
        // Insert a new gap before.
        if (allocToPack.firstUse > oldStart)
          gaps.insert(
              gapIter,
              std::make_pair(UseInterval(oldStart, allocToPack.firstUse - 1),
                             freeContiguousMemory));
      }

      // If a gap interval has no free contiguous memory anymore, erease it from
      // list.
      if (gapIter->second <= 0) gapIter = gaps.erase(gapIter);

      return true;
    }
    return false;
  }

  /// Aggreagtes the allocation informations of the allocs and returns the
  /// maximal userange.
  size_t computeAllocationInfos(
      AllocInfoList &allocInfos, const UserangeAnalysis &userangeAnalysis,
      const mlir::bufferization::BufferPlacementAllocs &allocs) {
    // Create allocInformations and store them in allocInfos.
    size_t maxUserangeId = 0;

    for (auto &allocEntry : allocs) {
      Value v = std::get<0>(allocEntry);
      auto userangeIntervals = userangeAnalysis.getUserangeInterval(v);

      // Allocations without userange information cannot be packed; skip them.
      if (!userangeIntervals) continue;

      // Computes the userange id of the allocation.
      size_t allocUserangeId = userangeAnalysis.computeId(v, v.getDefiningOp());

      // Computes the last use of the allocated buffer.
      size_t lastUse = std::prev((*userangeIntervals.value()).end())->end;

      // Computes the first use of the allocated buffer.
      size_t firstUse = (*userangeIntervals.value()).begin()->start;

      // Computes the number of aligend segments of the buffer.
      size_t numSegments = computeAlignedSegments(v);
      maxUserangeId = std::max(maxUserangeId, lastUse);
      allocInfos.emplace_back(v, allocUserangeId, firstUse, lastUse,
                              numSegments, 0, userangeIntervals.value());
    }

    // If the window size is zero we need no sorting anymore.
    if (windowSize == 0) return maxUserangeId;
    // Sorts the allocation informations to compute the window id. The window id
    // is used to blur the userange starting position of an allocation.
    std::sort(allocInfos.begin(), allocInfos.end(),
              [](const AllocationInfo &a, const AllocationInfo &b) {
                return a.allocUserangeId < b.allocUserangeId;
              });

    // Assign window ids: a new window starts whenever the userange id jumps
    // more than windowSize past the previous allocation's id.
    size_t windowId = 0;
    size_t lastAllocUserangeId = 0;
    for (auto &allocationInfo : allocInfos) {
      if (allocationInfo.allocUserangeId > lastAllocUserangeId + windowSize)
        ++windowId;

      lastAllocUserangeId = allocationInfo.allocUserangeId;
      allocationInfo.windowId = windowId;
    }
    return maxUserangeId;
  }
};
345 
/// Pass to pack buffer together to optimize the memeory consumption and to
/// save allocation operations. A strategy must be passed as a template
/// argument.
class BufferPacking : bufferization::BufferPlacementTransformationBase {
 public:
  /// Runs the packing strategy and rewrites the IR: for every packed buffer a
  /// single i8 alloc is created, and each original alloc is replaced by a
  /// memref.view at its byte offset, then erased.
  template <typename StrategyT>
  BufferPacking(Operation *op, StrategyT strategy)
      : BufferPlacementTransformationBase(op),
        userangeAnalysis(op, allocs, aliases),
        dominators(op) {
    std::vector<PackedBuffer> packedBuffers;
    strategy.optimze(allocs, userangeAnalysis, packedBuffers);

    for (auto &packedBuffer : packedBuffers) {
      // Find common dominators.
      Block *block = findAllocationsDominator(packedBuffer.allocBufferOffsets);
      // Find alloc position operation. The packed alloc is inserted before the
      // dominator block's first operation so it dominates all views.
      mlir::OpBuilder packBuilder(&(block->front()));
      auto location = block->front().getLoc();
      // NOTE(review): numSegments holds the byte size here (see PackedBuffer),
      // used as the length of a 1-D i8 memref.
      auto memrefType =
          MemRefType::get({static_cast<int64_t>(packedBuffer.numSegments)},
                          packBuilder.getIntegerType(8));
      Value targetBuffer =
          packBuilder.create<memref::AllocOp>(location, memrefType);

      for (auto &packInfo : packedBuffer.allocBufferOffsets) {
        Value currentAlloc = packInfo.source;
        size_t offset = packInfo.offset;
        Operation *viewDefOp = currentAlloc.getDefiningOp();
        Location loc = viewDefOp->getLoc();
        mlir::OpBuilder viewBuilder(viewDefOp);

        // Create a arithmetic ConstantOp with the aligned offset.
        Value constantOp = viewBuilder.create<mlir::arith::ConstantOp>(
            loc, viewBuilder.getIndexType(),
            viewBuilder.getIntegerAttr(viewBuilder.getIndexType(), offset));

        // Store the operands for the ViewOp.
        SmallVector<Value, 4> newOperands{targetBuffer};
        newOperands.push_back(constantOp);

        auto shape = currentAlloc.getType().cast<MemRefType>();

        // Create a ViewOp with the shape of the old alloc and use the created
        // packed alloc and the constant for the operands.
        Value viewOp =
            viewBuilder.create<memref::ViewOp>(loc, shape, newOperands);

        // Replace all old allocs references with the created ViewOp and
        // afterwards remove the old allocs.
        currentAlloc.replaceAllUsesWith(viewOp);
        viewDefOp->erase();
      }
    }
  }

 private:
  UserangeAnalysis userangeAnalysis;
  /// The current dominance info.
  DominanceInfo dominators;

  /// Find the block that dominates all buffer allocations.
  /// Assumes packingInfos is non-empty (it always contains at least the
  /// buffer packed at offset 0).
  Block *findAllocationsDominator(
      const std::vector<AllocBufferOffset> &packingInfos) {
    SmallPtrSet<Value, 16> allocValues;
    for (auto &packInfo : packingInfos) {
      allocValues.insert(packInfo.source);
    }

    // Find common dominators.
    return findCommonDominator(packingInfos.begin()->source, allocValues,
                               dominators);
  }
};
420 
421 /// Tries to pack allocated buffer together to save allocation operations and
422 /// memory. The window size is used as sliding window size. Allocation
423 /// userangepoitions that are in the same range are mapped to the same window
424 /// id. The information of the allocation starting position is blured.
425 struct BufferPackingPass : public BufferPackingBase<BufferPackingPass> {
BufferPackingPassmlir::__anonf6a1f69c0111::BufferPackingPass426   explicit BufferPackingPass(unsigned windowSize) {
427     this->window_size_ = windowSize;
428   }
429 
runOnOperationmlir::__anonf6a1f69c0111::BufferPackingPass430   void runOnOperation() override {
431     if (window_size_ == 0) {
432       SortedPackingStrategy<AllocInfoMemSizeCompare> strategy(
433           window_size_, AllocInfoMemSizeCompare());
434       BufferPacking packing(getOperation(), strategy);
435     } else {
436       SortedPackingStrategy<AllocInfoWinIdComparator> strategy(
437           window_size_, AllocInfoWinIdComparator());
438       BufferPacking packing(getOperation(), strategy);
439     }
440   }
441 };
442 
443 /// Pass to find all allocations and to compute memory usage.
444 struct MemoryCountPass : MemoryCountBase<MemoryCountPass> {
runOnOperationmlir::__anonf6a1f69c0111::MemoryCountPass445   void runOnOperation() override {
446     Operation *op = getOperation();
447     std::vector<Value> allocs;
448     op->walk([&](MemoryEffectOpInterface opInterface) {
449       // Try to find a single allocation result.
450       SmallVector<MemoryEffects::EffectInstance, 2> effects;
451       opInterface.getEffects(effects);
452 
453       SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
454       llvm::copy_if(
455           effects, std::back_inserter(allocateResultEffects),
456           [=](MemoryEffects::EffectInstance &it) {
457             Value value = it.getValue();
458             return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
459                    value.isa<OpResult>() &&
460                    it.getResource() !=
461                        SideEffects::AutomaticAllocationScopeResource::get();
462           });
463 
464       if (allocateResultEffects.size() != 1) return;
465       // Insert allocation.
466       allocs.push_back(allocateResultEffects[0].getValue());
467     });
468     auto output = mlir::hlo::computeMemory(allocs);
469     llvm::outs() << "Memory Count Pass:\n"
470                  << output.first << ";" << output.second << "\n";
471   }
472 };
473 
474 }  // namespace
475 
createBufferPackingPass(unsigned windowSize)476 std::unique_ptr<OperationPass<func::FuncOp>> createBufferPackingPass(
477     unsigned windowSize) {
478   return std::make_unique<BufferPackingPass>(windowSize);
479 }
480 
createMemoryCountPass()481 std::unique_ptr<OperationPass<func::FuncOp>> createMemoryCountPass() {
482   return std::make_unique<MemoryCountPass>();
483 }
484 
485 }  // namespace mlir
486