xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/buffer_assignment.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
18 
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
42 
43 namespace xla {
44 
45 // Walk the call graph of the HLO module and place each computation into either
46 // thread_local_computations or global_computations depending upon whether the
47 // computation requires thread-local allocations or global allocations. The
48 // elements in thread_local_computations and global_computations are in post
49 // order (if computation A has an instruction which calls computation B, then A
50 // will appear after B in the vector).
51 Status GatherComputationsByAllocationType(
52     const HloModule* module,
53     std::vector<const HloComputation*>* thread_local_computations,
54     std::vector<const HloComputation*>* global_computations);
55 
56 // This class abstracts an allocation of contiguous memory which can hold the
57 // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range
58 // of the allocation, represented by a Slice. A single BufferAllocation may hold
59 // LogicalBuffers with disjoint liveness, which may have overlapping Slices. A
60 // single BufferAllocation may also hold LogicalBuffers with overlapping
61 // liveness, which must have disjoint Slices.
62 //
63 // The abstraction includes information required by the backends for allocation,
64 // use, and deallocation of the buffer. This includes the LogicalBuffers which
65 // are held in this allocation through the execution of the computation.
66 class BufferAllocation {
67  public:
68   // Holds a unique identifier for each allocation. Values are assigned
69   // contiguously and can be used as array indexes.
70   using Index = int64_t;
71 
BufferAllocation(Index index,int64_t size,LogicalBuffer::Color color)72   BufferAllocation(Index index, int64_t size, LogicalBuffer::Color color)
73       : index_(index), size_(size), color_(color) {}
~BufferAllocation()74   ~BufferAllocation() {}
75 
76   // Returns the index of this allocation.
index()77   Index index() const { return index_; }
78 
79   // Whether this allocation is used in a parallel calling context such as
80   // inside of a map or reduce computation. Such allocations need to be thread
81   // local.
is_thread_local()82   bool is_thread_local() const { return is_thread_local_; }
set_is_thread_local(bool is_thread_local)83   void set_is_thread_local(bool is_thread_local) {
84     is_thread_local_ = is_thread_local;
85   }
86 
87   // Whether this allocation can be used by more than one logical buffer.
is_reusable()88   bool is_reusable() const {
89     // We do not reuse thread-local buffers for now, because they are
90     // dynamically allocated and their lifetimes are hard to compute.
91     //
92     // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
93     // assumes longer buffer liveness than indicated by the analysis.
94     return !is_thread_local() && !is_tuple();
95   }
96 
97   // Whether this allocation is readonly i.e. backed by memory we cannot write
98   // to.
is_readonly()99   bool is_readonly() const {
100     // Entry parameters are generally readonly, except when they are aliased
101     // with any output.
102     return (is_entry_computation_parameter() &&
103             !is_parameter_aliased_with_output_) ||
104            is_constant();
105   }
106 
is_tuple()107   bool is_tuple() const { return is_tuple_; }
set_is_tuple(bool is_tuple)108   void set_is_tuple(bool is_tuple) { is_tuple_ = is_tuple; }
109 
110   // Whether this allocation holds a LogicalBuffer from a parameter of the entry
111   // computation. These buffers have lifetimes which may be longer than the
112   // XLA computation.
is_entry_computation_parameter()113   bool is_entry_computation_parameter() const {
114     return is_entry_computation_parameter_;
115   }
116 
117   // Whether this allocation holds a constant.  On the CPU and GPU backends
118   // constant allocations are not allocated dynamically, instead we resolve
119   // references to these buffer allocations to a global in the readonly section
120   // of the binary.
is_constant()121   bool is_constant() const { return is_constant_; }
122 
123   // If this allocation holds a Buffer from a parameter of the entry
124   // computation, this methods returns the parameter number. CHECKs otherwise.
parameter_number()125   int64_t parameter_number() const {
126     CHECK(is_entry_computation_parameter_);
127     return parameter_number_;
128   }
129 
130   // If this allocation is for a parameter of the entry computation, this
131   // function returns which subshape of the parameter the allocation is for.
param_shape_index()132   const ShapeIndex& param_shape_index() const {
133     CHECK(is_entry_computation_parameter_);
134     return param_shape_index_;
135   }
136 
137   // Returns whether this allocation is assigned a LogicalBuffer which may
138   // be live out of the entry computation.
maybe_live_out()139   bool maybe_live_out() const { return maybe_live_out_; }
140 
set_maybe_live_out(bool value)141   void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
142 
143   // Returns the size of the allocation. Necessarily this must be at least as
144   // large as any LogicalBuffer assigned to this allocation.
size()145   int64_t size() const { return size_; }
146 
147   // Returns the color of the allocation. Only logical buffers with a matching
148   // color can reside in this allocation.
color()149   LogicalBuffer::Color color() const { return color_; }
150 
151   struct OffsetSize {
152     int64_t offset = 0;
153     int64_t size = 0;
154   };
155 
156   // Access to the logical buffers assigned to this allocation, and their
157   // associated logical offsets and sizes.
assigned_buffers()158   const absl::flat_hash_map<const HloValue*, OffsetSize>& assigned_buffers()
159       const {
160     return assigned_buffers_;
161   }
162 
163   // A Slice represents a contiguous portion of a memory allocation. It is used
164   // to identify the memory range that a LogicalBuffer corresponds to.
165   class Slice {
166    public:
Slice()167     Slice() {}
Slice(const BufferAllocation * allocation,int64_t offset,int64_t size)168     Slice(const BufferAllocation* allocation, int64_t offset, int64_t size)
169         : allocation_(allocation), offset_(offset), size_(size) {}
170 
allocation()171     const BufferAllocation* allocation() const { return allocation_; }
index()172     Index index() const { return allocation_->index(); }
offset()173     int64_t offset() const { return offset_; }
size()174     int64_t size() const { return size_; }
175 
176     bool operator==(const Slice& other) const {
177       return index() == other.index() && offset_ == other.offset_ &&
178              size_ == other.size_;
179     }
180     bool operator!=(const Slice& other) const { return !(*this == other); }
181     bool operator<(const Slice& other) const {
182       if (index() != other.index()) return index() < other.index();
183       if (offset_ != other.offset_) return offset_ < other.offset_;
184       return size_ < other.size_;
185     }
186 
187     // Returns true iff this slice's memory range has a non-empty intersection
188     // with the other slice's memory range.
OverlapsWith(const Slice & other)189     bool OverlapsWith(const Slice& other) const {
190       const int64_t end = offset_ + size_;
191       const int64_t other_end = other.offset_ + other.size_;
192       return index() == other.index() && offset_ < other_end &&
193              end > other.offset_;
194     }
195 
196     template <typename H>
AbslHashValue(H h,const Slice & s)197     friend H AbslHashValue(H h, const Slice& s) {
198       return H::combine(std::move(h), s.index(), s.offset(), s.size());
199     }
200 
201     std::string ToString() const;
202 
203    private:
204     const BufferAllocation* allocation_ = nullptr;
205     int64_t offset_ = 0;
206     int64_t size_ = 0;
207   };
208 
209   // GetSlice returns the Slice of contiguous memory that holds the value
210   // described by the given 'buffer'.
211   // REQUIRES: 'buffer' must be assigned to this allocation.
212   Slice GetSlice(const HloValue& buffer) const;
213 
214   std::string ToString() const;
215   BufferAllocationProto ToProto() const;
216 
217   // Whether the buffer is a parameter to or live out of the entry computation.
IsInputOrOutput()218   bool IsInputOrOutput() const {
219     return is_entry_computation_parameter() || maybe_live_out();
220   }
221 
222   // Whether the buffer is a temporary buffer allocated before
223   // Executable::ExecuteOnStream.
IsPreallocatedTempBuffer()224   bool IsPreallocatedTempBuffer() const {
225     // Parameters do not need temporary buffers.
226     return !is_entry_computation_parameter() &&
227            // LogicalBuffers that maybe pointed to by the output should live out
228            // of the computation.
229            !maybe_live_out() &&
230            // Thread-local buffers are allocated using `alloca`s.
231            !is_thread_local() &&
232            // Constant buffers are allocated as global values.
233            !is_constant();
234   }
235 
236   // Add a heap trace which was used to assign slices to logical buffers in this
237   // allocation. A single BufferAllocation may include multiple heap traces
238   // in the case of the temporary block where there is a heap trace per
239   // computation.
AddHeapTrace(const HeapSimulatorTrace & heap_trace)240   void AddHeapTrace(const HeapSimulatorTrace& heap_trace) {
241     heap_traces_.push_back(heap_trace);
242     heap_traces_.back().set_buffer_allocation_index(index());
243   }
244 
245   // Return the set of heap traces used to assign slices to logical buffers in
246   // this allocation.
HeapTraces()247   const std::vector<HeapSimulatorTrace> HeapTraces() const {
248     return heap_traces_;
249   }
250 
251   // Returns the LogicalBuffers which are live at the point of peak memory usage
252   // for this allocation. The point of peak memory usage is the point at which
253   // the total size of all live logical buffers is maximal. If peak memory is
254   // reached at multiple points, the set of logical buffers live at the earliest
255   // maximal point is returned. The vector is stably sorted by
256   // BufferValue::Index.
PeakMemoryLogicalBuffers()257   const std::vector<const HloValue*>& PeakMemoryLogicalBuffers() const {
258     return peak_buffers_;
259   }
260 
261   // Get the number of bytes lost to fragmentation. This is equal to the
262   // difference between the size of the allocation and the size of the maximal
263   // live set.
fragmentation_bytes()264   int64_t fragmentation_bytes() const { return fragmentation_bytes_; }
265 
266   bool operator==(const BufferAllocation& other) const {
267     return index_ == other.index_;
268   }
269   bool operator!=(const BufferAllocation& other) const {
270     return !(*this == other);
271   }
272   bool operator<(const BufferAllocation& other) const {
273     return index() < other.index();
274   }
275 
set_entry_computation_parameter(int64_t parameter_number,ShapeIndex param_shape_index,bool parameter_aliased_with_output)276   void set_entry_computation_parameter(int64_t parameter_number,
277                                        ShapeIndex param_shape_index,
278                                        bool parameter_aliased_with_output) {
279     is_entry_computation_parameter_ = true;
280     is_parameter_aliased_with_output_ = parameter_aliased_with_output;
281     parameter_number_ = parameter_number;
282     param_shape_index_ = std::move(param_shape_index);
283   }
284 
set_constant(bool is_constant)285   void set_constant(bool is_constant) { is_constant_ = is_constant; }
286 
287  private:
288   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
289   friend class BufferAssigner;
290   friend class BufferAssignment;
291 
292   // Adds a LogicalBuffer to the set assigned to this buffer.
293   void AddAssignment(const HloValue& buffer, int64_t offset, int64_t size);
294 
set_index(Index index)295   void set_index(Index index) { index_ = index; }
set_size(int64_t size)296   void set_size(int64_t size) { size_ = size; }
297 
298   // The index of the allocation in the BufferAssignment.
299   Index index_;
300 
301   // Size of the allocation in bytes.
302   int64_t size_;
303 
304   // Whether this buffer needs to be thread-local.
305   bool is_thread_local_ = false;
306 
307   // Whether this buffer holds a tuple.
308   bool is_tuple_ = false;
309 
310   // Color of the allocation.
311   LogicalBuffer::Color color_;
312 
313   // Whether this allocation holds an entry computation parameter. Entry
314   // computation parameters are special because they have lifetimes which may
315   // outlast the computation.
316   bool is_entry_computation_parameter_ = false;
317 
318   // Whether this entry computation parameter is aliased with output.
319   bool is_parameter_aliased_with_output_ = false;
320 
321   // If this allocation holds an entry computation parameter, this field
322   // indicates the index (starting from 0) of the parameter.
323   int64_t parameter_number_ = 0;
324 
325   // If this buffer is for an entry computation parameter, which subshape of the
326   // parameter is it for?
327   ShapeIndex param_shape_index_;
328 
329   // Whether the allocation contains a LogicalBuffer which may be live-out of
330   // the entry computation. Note that this flag is conservatively computed by
331   // TuplePointsToAnalysis.  That is, an allocation marked `maybe_live_out_`
332   // might not actually escape.
333   bool maybe_live_out_ = false;
334 
335   // See comment on the is_constant() accessor.
336   bool is_constant_ = false;
337 
338   // Mapping from the set of buffers assigned to this allocation to their
339   // logical offsets and sizes.
340   absl::flat_hash_map<const HloValue*, OffsetSize> assigned_buffers_;
341 
342   int64_t fragmentation_bytes_ = 0;
343   std::vector<HeapSimulatorTrace> heap_traces_;
344 
345   // Set of buffers live at the point of peak memory usage for this allocation.
346   std::vector<const HloValue*> peak_buffers_;
347 };
348 
349 // Add stream operators for nicer output of CHECK/RET_CHECK failures.
350 std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
351 std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s);
352 
353 // This class encapsulates an assignment of the LogicalBuffers in an XLA
354 // module to a set of BufferAllocations.
355 class BufferAssignment {
356  public:
357   // Returns the vector containing all buffer allocations in this assignment.
Allocations()358   const std::vector<BufferAllocation>& Allocations() const {
359     return allocations_;
360   }
361 
362   // This is similar to copying Allocations(), but since it's moved out, it
363   // preserves the addresses. Since BufferAllocation::Slice keeps a
364   // BufferAllocation*, and some backends keep BufferAllocation::Slice in
365   // xla::Executables, migrating off the use of addresses can be hard.
ReleaseAllocations()366   std::vector<BufferAllocation> ReleaseAllocations() {
367     return std::move(allocations_);
368   }
369 
370   // Returns the total size allocation holding all temporary buffers.
temp_allocation_total_size()371   int64_t temp_allocation_total_size() const {
372     return temp_allocation_total_size_;
373   }
374 
multiheap_size_constraint_per_heap()375   uint64_t multiheap_size_constraint_per_heap() const {
376     return multiheap_size_constraint_per_heap_;
377   }
378 
379   // Returns whether the given buffer has been assigned an allocation.
380   bool HasAllocation(const HloValue& value) const;
381 
382   bool HasAllocation(const HloBuffer& buffer) const;
383 
384   // Returns the allocation that a particular LogicalBuffer has been assigned
385   // to. CHECKs if buffer has not been assigned an allocation.
386   const BufferAllocation& GetAssignedAllocation(const HloValue& value) const;
387 
388   const BufferAllocation& GetAssignedAllocation(
389       const HloBuffer& hlo_buffer) const;
390 
391   // Returns the allocation with the given index. CHECKs if no allocation exists
392   // with the given index.
393   const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
394 
395   // Returns the allocation with the given instruction and shape index. nullptr
396   // if no allocation exists.
397   const BufferAllocation* GetInstructionAllocation(
398       const HloInstruction* hlo, const ShapeIndex& shape_index) const;
399 
400   // Builds and returns a vector containing the slices which might contain the
401   // subvalue at the given index of given instruction.
402   std::set<BufferAllocation::Slice> GetAllSlices(
403       const HloInstruction* instruction, const ShapeIndex& index) const;
404 
405   // Convenience function which returns whether the buffer of the
406   // instruction at the given index is assigned an allocation.
407   bool HasAllocationAt(const HloInstruction* instruction,
408                        const ShapeIndex& index) const;
409 
410   // Convenience function which returns whether the top-level buffer of the
411   // instruction (index == {}) is assigned an allocation.
412   bool HasTopLevelAllocation(const HloInstruction* instruction) const;
413 
414   // Convenience function which returns the unique slice containing the buffer
415   // at the given index of the given instruction. If a slice is not assigned or
416   // the slice cannot be determined at compile time then an error is returned.
417   StatusOr<BufferAllocation::Slice> GetUniqueSlice(
418       const HloInstruction* instruction, const ShapeIndex& index) const;
419   // Like GetUniqueSlice but fixes the index to the top-level of the shape
420   // (index = {}).
421   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelSlice(
422       const HloInstruction* instruction) const;
423   // Like GetUniqueTopLevelSlice but returns the slice for the output of the
424   // entry computation of the HLO module (ie, the result of the XLA
425   // computation).
426   StatusOr<BufferAllocation::Slice> GetUniqueTopLevelOutputSlice() const;
427 
428   // Returns the set BufferValues which may be the source of the value at the
429   // given index and instruction.
GetSourceBuffers(const HloInstruction * instruction,const ShapeIndex & index)430   const std::vector<const HloValue*>& GetSourceBuffers(
431       const HloInstruction* instruction, const ShapeIndex& index) const {
432     return dataflow_analysis().GetValueSet(instruction, index).values();
433   }
434 
435   // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
436   // share the same BufferAllocation::Slice.
437   // Returns false otherwise.
438   // REQUIRES: BufferAssignment assigned allocations to both instructions.
439   bool SharesSliceAtIndex(const HloInstruction* hlo_a,
440                           const ShapeIndex& shape_index_a,
441                           const HloInstruction* hlo_b,
442                           const ShapeIndex& shape_index_b) const;
443 
444   // Returns true if the top-level buffers of hlo_a and hlo_b are the same.
445   // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b).
SharesTopLevelSlice(const HloInstruction * hlo_a,const HloInstruction * hlo_b)446   bool SharesTopLevelSlice(const HloInstruction* hlo_a,
447                            const HloInstruction* hlo_b) const {
448     return SharesSliceAtIndex(hlo_a, {}, hlo_b, {});
449   }
450 
451   // Returns true if hlo_a and hlo_b both have at least one buffer assigned for
452   // their top-level and each of their nested shape indices, and if hlo_a's
453   // buffers are all different from hlo_b's buffers.
454   bool HaveDisjointSlices(const HloInstruction* hlo_a,
455                           const HloInstruction* hlo_b) const;
456 
dataflow_analysis()457   const HloDataflowAnalysis& dataflow_analysis() const {
458     return alias_analysis_->dataflow_analysis();
459   }
460 
alias_analysis()461   HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; }
462 
hlo_ordering()463   const HloOrdering& hlo_ordering() const { return *hlo_ordering_; }
464 
465   // Returns the HloLiveRange object used to construct this assignment.
hlo_live_range()466   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
467 
468   std::string ToString() const;
469   // Verbose string tailored to debugging OOMs, includes the Hlo op metadata for
470   // every buffer associated with each allocation.
471   std::string ToVerboseString() const;
472   std::string BufferInfoString() const;
473   BufferAssignmentProto ToProto() const;
474 
475   // Statistics for the assignment.  Values initialized to -1 are not always
476   // collected; fragmentation is only collected for instructions that have a
477   // sequential total ordering.
478   struct Stats {
479     int64_t parameter_allocation_count = 0;
480     int64_t parameter_allocation_bytes = 0;
481     int64_t constant_allocation_count = 0;
482     int64_t constant_allocation_bytes = 0;
483     int64_t maybe_live_out_allocation_count = 0;
484     int64_t maybe_live_out_allocation_bytes = 0;
485     int64_t preallocated_temp_allocation_count = 0;
486     int64_t preallocated_temp_allocation_bytes = 0;
487     int64_t preallocated_temp_fragmentation_bytes = -1;
488     int64_t total_allocation_count = 0;
489     int64_t total_allocation_bytes = 0;
490     int64_t total_fragmentation_bytes = -1;
491 
492     std::string ToString() const;
493   };
GetStats()494   const Stats& GetStats() const { return stats_; }
495 
496  private:
497   // Only BufferAssigner can build or modify BufferAssignments.
498   friend class BufferAssigner;
499 
BufferAssignment(const HloModule * module,std::unique_ptr<HloOrdering> hlo_ordering,BufferValue::SizeFunction buffer_size,LogicalBuffer::AlignmentFunction color_alignment,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range)500   BufferAssignment(const HloModule* module,
501                    std::unique_ptr<HloOrdering> hlo_ordering,
502                    BufferValue::SizeFunction buffer_size,
503                    LogicalBuffer::AlignmentFunction color_alignment,
504                    std::unique_ptr<HloAliasAnalysis> alias_analysis,
505                    std::unique_ptr<HloLiveRange> hlo_live_range)
506       : module_(module),
507         hlo_ordering_(std::move(hlo_ordering)),
508         buffer_size_(std::move(buffer_size)),
509         color_alignment_(std::move(color_alignment)),
510         alias_analysis_(std::move(alias_analysis)),
511         hlo_live_range_(std::move(hlo_live_range)) {
512     int32_t raw_value = module->config()
513                             .debug_options()
514                             .xla_multiheap_size_constraint_per_heap();
515     // -1 means no constraint.
516     multiheap_size_constraint_per_heap_ =
517         (raw_value == -1) ? UINT64_MAX : raw_value;
518   }
519 
520   // Creates and returns a new BufferAllocation, with no assigned
521   // LogicalBuffers. Ownership is maintained internally.
522   BufferAllocation* NewEmptyAllocation(int64_t size,
523                                        LogicalBuffer::Color color);
524 
525   // Helper that calls NewEmptyAllocation and AddAssignment in one call,
526   // creating an allocation containing a single LogicalBuffer.
527   BufferAllocation* NewAllocation(const HloBuffer& buffer, int64_t size);
528 
529   // Adds a LogicalBuffer to the set assigned to the given allocation.
530   void AddAssignment(BufferAllocation* allocation, const HloBuffer& buffer,
531                      int64_t offset, int64_t size);
532 
533   void AddAssignment(BufferAllocation* allocation, const HloValue& value,
534                      int64_t offset, int64_t size);
535 
536   // Returns the HloModule used to construct this assignment.
module()537   const HloModule& module() const { return *module_; }
538 
539   // Mutable accessors for allocations.
540   BufferAllocation* GetMutableAssignedAllocation(const HloBuffer& buffer);
541   BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
542 
HloBufferSize(const HloBuffer & buffer)543   int64_t HloBufferSize(const HloBuffer& buffer) {
544     int64_t result = buffer_size_(*buffer.values()[0]);
545     for (const HloValue* value : buffer.values()) {
546       DCHECK_EQ(result, buffer_size_(*value));
547     }
548     return result;
549   }
550 
551   // Combines allocations of temporary buffers into one big BufferAllocation.
552   void CombineTempAllocations();
553 
554   // Computes stats for the assignment, to be retrieved by GetStats.
555   Status ComputeSummaryStats();
556 
557   // The vector of buffer allocations. Indexed by BufferAllocation::Index.
558   std::vector<BufferAllocation> allocations_;
559 
560   // The total size of all temporary buffers.
561   int64_t temp_allocation_total_size_ = 0;
562 
563   uint64_t multiheap_size_constraint_per_heap_;
564 
565   // Maps Buffers to the index of the BufferAllocation which holds the buffer.
566   absl::flat_hash_map<const HloValue*, BufferAllocation::Index>
567       allocation_index_for_value_;
568 
569   const HloModule* module_;
570 
571   const std::unique_ptr<HloOrdering> hlo_ordering_;
572 
573   // Function which returns the buffer size for a given logical buffer (shape).
574   BufferValue::SizeFunction buffer_size_;
575 
576   // Function which returns the alignment for a given logical buffer color.
577   LogicalBuffer::AlignmentFunction color_alignment_;
578 
579   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
580 
581   std::unique_ptr<HloLiveRange> hlo_live_range_;
582 
583   Stats stats_;
584 
585   BufferAssignment(const BufferAssignment&) = delete;
586   BufferAssignment& operator=(const BufferAssignment&) = delete;
587 };
588 
589 // A class which constructs a buffer assignment.
590 class BufferAssigner {
591  public:
592   using Colorer = std::function<Status(HloAliasAnalysis*, const HloOrdering&)>;
593   using MustNotLiveOut =
594       std::function<bool(const HloInstruction*, const ShapeIndex&)>;
595 
DefaultColorer()596   static Colorer DefaultColorer() {
597     return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) {
598       for (HloValue* value : alias_analysis->dataflow_analysis().values()) {
599         const HloPosition& defining_position = value->defining_position();
600         if (defining_position.shape().has_layout()) {
601           value->set_color(BufferValue::Color(
602               defining_position.shape().layout().memory_space()));
603         } else {
604           value->set_color(BufferValue::Color(0));
605         }
606       }
607       return OkStatus();
608     };
609   }
610 
611   // Returns false if a buffer cannot be assigned to given allocation.
612 
613   // Build and return a BufferAssignment for the given module. The given
614   // HloOrdering is used to determine buffer liveness. buffer_size and
615   // color_alignment are functions which returns the size and alignment of a
616   // LogicalBuffer. If preset_assignments is provided, those pre-set assignment
617   // offsets will be used. The caller guarantees that those assignments are
618   // valid and they do not overwrite each other.
619   static StatusOr<std::unique_ptr<BufferAssignment>> Run(
620       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
621       BufferValue::SizeFunction buffer_size,
622       LogicalBuffer::AlignmentFunction color_alignment,
623       bool allocate_buffers_for_constants = false,
624       Colorer colorer = DefaultColorer(),
625       std::optional<MustNotLiveOut> must_not_live_out = std::nullopt,
626       HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr,
627       std::unique_ptr<memory_space_assignment::PresetAssignments>
628           preset_assignments = {});
629 
630  private:
BufferAssigner(bool allocate_buffers_for_constants,Colorer colorer,std::optional<MustNotLiveOut> must_not_live_out,std::unique_ptr<memory_space_assignment::PresetAssignments> preset_assignments)631   BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
632                  std::optional<MustNotLiveOut> must_not_live_out,
633                  std::unique_ptr<memory_space_assignment::PresetAssignments>
634                      preset_assignments)
635       : allocate_buffers_for_constants_(allocate_buffers_for_constants),
636         colorer_(colorer),
637         must_not_live_out_(must_not_live_out),
638         preset_assignments_(std::move(preset_assignments)) {}
639   virtual ~BufferAssigner() = default;
640 
641   // Create a buffer assignment.
642   StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
643       const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
644       BufferValue::SizeFunction buffer_size,
645       LogicalBuffer::AlignmentFunction color_alignment,
646       HloDataflowAnalysis::CanShareBuffer can_share_buffer);
647 
648   // Assigns buffers to the instructions in the given computations. "assignment"
649   // is modified to reflect the new buffer assignments. If is_thread_local is
650   // true, then all assigned buffers have the is_thread_local flag set to
651   // true.
652   Status AssignBuffersForComputations(
653       const std::vector<const HloComputation*>& computations,
654       bool is_thread_local,
655       absl::flat_hash_map<const HloComputation*,
656                           absl::flat_hash_set<const HloValue*>>*
657           buffers_to_assign_sequentially,
658       BufferAssignment* assignment);
659 
660   // Returns true if buffer's live range interferences with buffer2's.
661   bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2,
662                            BufferAssignment* assignment);
663 
664   // Assigns pre-set assignments, if provided. These assignments will be added
665   // to assigned_buffers and skip buffer allocation.
666   Status AssignPresetBuffers(
667       absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
668       BufferAssignment* assignment);
669 
670   // Assigns a single hlo buffer to an HLO allocation.
671   Status AssignSingleHloBuffer(
672       const HloBuffer* hlo_buffer, bool is_thread_local,
673       absl::flat_hash_map<const HloComputation*,
674                           absl::flat_hash_set<const HloValue*>>*
675           buffers_to_assign_sequentially,
676       std::vector<BufferAllocation::Index>* allocation_indices,
677       BufferAssignment* assignment);
678 
679   // Assigns 'buffers_to_assign_sequentially' using heap simulation, assuming
680   // the HLO instructions will be executed in the sequential order given by
681   // assignment->liveness().hlo_ordering().SequentialOrder. If
682   // 'run_whole_module_heap_simulation' is true, the heap simulation will be run
683   // assuming all global computations are sequentially ordered.
684   Status AssignBuffersWithSequentialOrdering(
685       const absl::flat_hash_map<const HloComputation*,
686                                 absl::flat_hash_set<const HloValue*>>&
687           buffers_to_assign_sequentially,
688       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
689 
690   // Uses the results of the heap simulator to create a single allocation, with
691   // LogicalBuffers packed to specific offsets.
692   void AssignBuffersFromHeapSimulator(
693       const HeapSimulator::Result<HloValue>& result,
694       BufferAssignment* assignment, LogicalBuffer::Color color);
695 
696   // Tries to assign the given instruction to the given buffer. Returns if the
697   // assignment was successful.
698   bool MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& buffer,
699                          BufferAssignment* assignment);
700 
701   // Split a set of buffers into several sets, each of which contains buffers
702   // colored with the same color.
703   absl::flat_hash_map<LogicalBuffer::Color,
704                       absl::flat_hash_set<const HloValue*>>
705   SplitBuffersByColor(const absl::flat_hash_set<const HloValue*>& buffers);
706 
707   // If true, allocate buffers for constant instructions.
708   bool allocate_buffers_for_constants_;
709 
710   // Functor used to assign colors to newly allocated logical buffers.
711   Colorer colorer_;
712 
713   // An optional function that returns true if the given instruction can't live
714   // out of a computation.
715   std::optional<MustNotLiveOut> must_not_live_out_;
716 
717   // Description of any buffer offsets that are already set by an earlier pass.
718   std::unique_ptr<memory_space_assignment::PresetAssignments>
719       preset_assignments_;
720 
721   BufferAssigner(const BufferAssigner&) = delete;
722   BufferAssigner& operator=(const BufferAssigner&) = delete;
723 };
724 
725 }  // namespace xla
726 
727 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
728