xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/memory_space_assignment.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// TODO(b/210891274): Use btree_map after build issue in Windows is resolved.
#if defined(__GNUC__) || defined(__clang__)
#include "absl/container/btree_map.h"
#else
#include <map>
#endif
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h"
32 
33 namespace xla {
34 
35 namespace memory_space_assignment {
36 // Forward Declaration of Options.
37 class Options;
38 
39 // This class contains pre-set assignments determined by memory space
40 // assignment. It contains two data structures: (1) a chunks vector that maps a
41 // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info
42 // vector that maps the memory space to information like its allocated size and
43 // heap memory trace. If there is only one alternate memory space like there is
44 // currently, there will be one entry in assignment_info.
45 class PresetAssignments {
46  public:
47   // Contains per-memory-space information like the allocated size and heap
48   // simulator trace.
49   struct AssignmentInformation {
50     int64_t size;
51     HeapSimulatorTrace heap_simulator_trace;
52   };
53 
54   PresetAssignments() = default;
55 
add_chunk(const HloPosition & position,const HeapSimulator::Chunk & chunk)56   void add_chunk(const HloPosition& position,
57                  const HeapSimulator::Chunk& chunk) {
58     chunks_.emplace_back(position, chunk);
59   }
60 
add_scoped_allocation_chunk(HloInstruction * instruction,const HeapSimulator::Chunk & chunk)61   void add_scoped_allocation_chunk(HloInstruction* instruction,
62                                    const HeapSimulator::Chunk& chunk) {
63     scoped_allocation_chunks_.emplace_back(instruction, chunk);
64   }
65 
assignment_information_for_space(int64_t memory_space)66   AssignmentInformation* assignment_information_for_space(
67       int64_t memory_space) {
68     for (auto& space_and_info : assignment_info_) {
69       if (space_and_info.first == memory_space) {
70         return &space_and_info.second;
71       }
72     }
73     assignment_info_.emplace_back(memory_space, AssignmentInformation());
74     return &assignment_info_.back().second;
75   }
76 
chunks()77   absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks()
78       const {
79     return chunks_;
80   }
81 
82   absl::Span<const std::pair<HloInstruction*, HeapSimulator::Chunk>>
scoped_allocation_chunks()83   scoped_allocation_chunks() const {
84     return scoped_allocation_chunks_;
85   }
86 
87   absl::Span<const std::pair<int64_t, AssignmentInformation>>
assignment_informations()88   assignment_informations() const {
89     return assignment_info_;
90   }
91 
92   // Get debugging information.
buffer_info_str()93   std::string buffer_info_str() const { return buffer_info_str_; }
allocation_info_str()94   std::string allocation_info_str() const { return allocation_info_str_; }
95 
96  private:
97   std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_;
98   std::vector<std::pair<HloInstruction*, HeapSimulator::Chunk>>
99       scoped_allocation_chunks_;
100   std::vector<std::pair<int64_t, AssignmentInformation>> assignment_info_;
101   std::string buffer_info_str_;
102   std::string allocation_info_str_;
103 };
104 
105 // A wrapper class around HloCostAnalysis with additional knowledge about the
106 // bandwidths of different memory spaces.
107 class MemorySpaceAssignmentCostAnalysis {
108  public:
109   // An optional Cache object may be provided to some of the methods below to
110   // speed up the lookup.
111   struct Cache {
112     absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
113   };
114 
115   // Function type that can be used to indicate which input/output values are in
116   // the alternate memory.
117   using IsInAlternateMemoryFun =
118       std::function<bool(std::optional<int> /*operand_num*/,
119                          const ShapeIndex& /*index*/, const Shape& /*shape*/)>;
120 
121   virtual ~MemorySpaceAssignmentCostAnalysis() = default;
122 
123   static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
124       const HloCostAnalysis& cost_analysis, const Options& options,
125       const HloModule& module);
126 
cost_analysis()127   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
128 
129   // Returns a heuristic value that captures how much putting this tensor to the
130   // alternate memory would help if the op is memory bound, or otherwise how far
131   // off is the op to memory boundedness. The larger this number, the higher
132   // priority it will be placed in the alternate memory.
133   float GetAlternateMemoryBenefit(const HloInstruction& instruction,
134                                   float elapsed_time_due_to_alternate_mem,
135                                   Cache* cache = nullptr) const;
136   // Like above, return the benefit of putting the output tensor in the
137   // alternate memory.
138   float GetAlternateMemoryBenefit(const HloPosition& position,
139                                   Cache* cache = nullptr) const;
140   // Like above, return the benefit of putting the input tensor in the alternate
141   // memory.
142   float GetAlternateMemoryBenefit(const HloUse& use,
143                                   Cache* cache = nullptr) const;
144 
145   // Returns a heuristic value of memory boundedness for the given
146   // BufferInterval.  The larger this number, the higher priority it will be
147   // placed in the alternate memory.
148   float GetMemoryBoundedness(
149       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
150       Cache* cache = nullptr) const;
151 
152   // Returns the elapsed time in seconds due to compute only.
153   float GetInstructionElapsedDueToCompute(
154       const HloInstruction& instruction) const;
155 
156   // Returns the elapsed time in seconds due to memory only. If
157   // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will
158   // assume that the corresponding operands or output will be in the alternate
159   // memory space. This is useful for calculating the benefit of placing the
160   // buffer in alternate memory.
161   float GetInstructionElapsedDueToMemory(
162       const HloInstruction& instruction,
163       absl::Span<const std::pair<int64_t, ShapeIndex>>
164           operands_in_alternate_mem = {},
165       absl::Span<const ShapeIndex> outputs_in_alternate_mem = {}) const;
166 
167   // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in
168   // the alternate memory.
169   float GetInstructionElapsedDueToMemory(
170       const HloInstruction& instruction,
171       IsInAlternateMemoryFun is_in_alternate_mem) const;
172 
173   // Returns the estimated elapsed duration of the instruction in seconds.  It
174   // assumes all operands and outputs of the instruction are in the default
175   // memory.
176   virtual float GetInstructionElapsed(const HloInstruction& instruction) const;
177 
178   // Returns the estimated elapsed duration of the instruction in seconds.  It
179   // assumes all operands and outputs of the instruction are in the default
180   // memory, except for the operands and outputs specified to be in the
181   // alternate memory.
182   virtual float GetInstructionElapsedInAlternateMemory(
183       const HloInstruction& instruction,
184       absl::Span<const std::pair<int64_t, ShapeIndex>>
185           operands_in_alternate_mem,
186       absl::Span<const ShapeIndex> outputs_in_alternate_mem) const;
187 
188   // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in
189   // the alternate memory.
190   float GetInstructionElapsedInAlternateMemory(
191       const HloInstruction& instruction,
192       IsInAlternateMemoryFun is_in_alternate_mem) const;
193 
194   // Returns the elapsed time it would take to asynchronously copy the shape
195   // from default to alternate memory space (or vice versa).
196   virtual float GetAsyncCopyElapsed(const Shape& shape) const;
197 
198   int64_t GetScheduleEndTime() const;
199 
200   // Returns the number of nested computation levels this instruction resides
201   // in. If while_only is true, it returns the while loop nest level and 0
202   // means the instruction is not in a while loop.
203   int CalculateComputationNestLevel(const HloInstruction* instruction,
204                                     bool while_only) const;
205 
hlo_live_range()206   const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
options()207   const Options& options() const { return options_; }
208 
209  protected:
MemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,const Options & options,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)210   MemorySpaceAssignmentCostAnalysis(
211       const HloCostAnalysis& cost_analysis, const Options& options,
212       std::unique_ptr<HloAliasAnalysis> alias_analysis,
213       std::unique_ptr<HloLiveRange> hlo_live_range,
214       std::unique_ptr<CallGraph> call_graph)
215       : cost_analysis_(cost_analysis),
216         options_(options),
217         alias_analysis_(std::move(alias_analysis)),
218         hlo_live_range_(std::move(hlo_live_range)),
219         call_graph_(std::move(call_graph)) {}
220 
221  private:
222   const HloCostAnalysis& cost_analysis_;
223   const Options& options_;
224   std::unique_ptr<HloAliasAnalysis> alias_analysis_;
225   std::unique_ptr<HloLiveRange> hlo_live_range_;
226   std::unique_ptr<CallGraph> call_graph_;
227 };
228 
229 // Abstract base class that memory space assignment uses to pick prefetch
230 // intervals.
231 class PrefetchIntervalPicker {
232  public:
233   PrefetchIntervalPicker() = default;
234   virtual ~PrefetchIntervalPicker() = default;
235 
236   // Returns true if the buffer can be allocated in alternate memory space
237   // without any copies (prefetches).
238   virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
239                                                   int64_t start_time,
240                                                   int64_t end_time) const = 0;
241 
242   // Returns the preferred end time for an eviction that starts at a given time
243   // and must end by the given end time.
244   virtual int64_t PreferredEvictionEndTime(const Shape& shape,
245                                            int64_t start_time,
246                                            int64_t latest_end_time) const = 0;
247 
248   // Returns the latest time that a prefetch can start.
249   virtual int64_t LatestPrefetchStartTime(const Shape& shape,
250                                           int64_t start_time, int64_t end_time,
251                                           const HloUse* use) const = 0;
252 
253   // Returns the preferred time that a prefetch can start.
254   virtual int64_t PreferredPrefetchStartTime(
255       const Shape& shape, int64_t earliest_prefetch_start_time,
256       int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0;
257 
258   // Returns the latest time that a prefetch can end that is less than or equal
259   // to proposed_prefetch_end_time.
LatestPrefetchEndTime(int64_t original_prefetch_end_time,int64_t proposed_prefetch_end_time)260   virtual int64_t LatestPrefetchEndTime(
261       int64_t original_prefetch_end_time,
262       int64_t proposed_prefetch_end_time) const {
263     return proposed_prefetch_end_time;
264   }
265 
266   // Returns the estimated end time of a prefetch that starts at the given time.
267   virtual int64_t EstimatedPrefetchEndTime(const Shape& shape,
268                                            int64_t start_time,
269                                            int64_t end_time) const = 0;
270 
271   // Returns the elapsed time in seconds between the logical interval that
272   // corresponds to the instruction schedule.
273   virtual float GetLogicalIntervalElapsed(int64_t start_time,
274                                           int64_t end_time) const = 0;
275 
276   // Begins the iterator for the first start time of the prefetch.
277   virtual void Begin(const HloUse& use, int64_t start_time,
278                      int64_t end_time) = 0;
279 
280   // Advances the start time of the prefetch and returns that value.
281   virtual int64_t Next() = 0;
282 
283   // Returns true if the available prefetch intervals have been exhausted.
284   virtual bool Done() const = 0;
285 
286   // Returns the latest time the prefetch interval picker will have pick.
287   virtual int64_t latest_time() const = 0;
288 
289   // The retry number can be used to modify the interval picking policies. The
290   // first attempt will have a retry_number of 0, then 1, etc.
SetRetryNumber(int retry_number)291   virtual void SetRetryNumber(int retry_number) {
292     retry_number_ = retry_number;
293   }
retry_number()294   int retry_number() const { return retry_number_; }
295 
296   // Returns a debug string for the current state of the prefetch interval
297   // picker.
298   virtual std::string ToDebugString() const = 0;
299 
300   // Returns a debug string for no-copy allocation.
301   virtual std::string ToNoCopyDebugString(const Shape& shape,
302                                           int64_t start_time,
303                                           int64_t end_time) const = 0;
304 
305   // Prefetch interval pickers may return a value corresponding to the benefit
306   // of placing the BufferInterval in the alternate memory. The larger value,
307   // the more beneficial.
BufferIntervalAlternateMemoryBenefit(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)308   virtual std::optional<float> BufferIntervalAlternateMemoryBenefit(
309       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
310       const {
311     return std::nullopt;
312   }
313 
314  protected:
315   const absl::flat_hash_map<const HloInstruction*, int64_t>*
316       instruction_schedule_ = nullptr;
317   int retry_number_ = 0;
318 };
319 
320 // Prefetch interval picker that uses instruction count to overlap asynchronous
321 // copies with independent computation. The min and max overlap counts describe
322 // the number of independent HLOs overlapped while a value is being prefetched
323 // into the alternate memory (between CopyStart and CopyDone HLO instructions).
324 // max_overlap_count attempts to prevent bringing tensors into the alternate
325 // memory too eagerly and hence occupying the space for other tensors which
326 // might use it.  min_overlap_count attempts to prevent cases where tensors are
327 // prefetched into the alternate memory without sufficient time for the copy to
328 // take place.  In those cases, it's just better to keep the tensor in the
329 // default memory instead of hurting the critical path with this copy that
330 // likely won't finish in time.
331 class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker {
332  public:
InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,int64_t max_overlap_count)333   InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,
334                                          int64_t max_overlap_count)
335       : min_overlap_count_(min_overlap_count),
336         max_overlap_count_(max_overlap_count) {}
337 
338   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
339                                           int64_t start_time,
340                                           int64_t end_time) const override;
341 
342   int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
343                                    int64_t latest_end_time) const override;
344 
345   int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
346                                   int64_t end_time,
347                                   const HloUse* use) const override;
348 
349   int64_t PreferredPrefetchStartTime(const Shape& shape,
350                                      int64_t earliest_prefetch_start_time,
351                                      int64_t latest_prefetch_start_time,
352                                      int64_t prefetch_end_time) const override;
353 
354   int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time,
355                                    int64_t end_time) const override;
356   float GetLogicalIntervalElapsed(int64_t start_time,
357                                   int64_t end_time) const override;
358 
359   void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;
360 
361   int64_t Next() override;
362   bool Done() const override;
363 
364   int64_t latest_time() const override;
365 
366   std::string ToDebugString() const override;
367   std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
368                                   int64_t end_time) const override;
369 
370  private:
371   int64_t min_overlap_count_;
372   int64_t max_overlap_count_;
373   int64_t end_time_;
374   int64_t current_prefetch_time_;
375 };
376 
377 // Forward Declaration of MemorySpaceAssignmentCostAnalysis
378 class MemorySpaceAssignmentCostAnalysis;
379 // Prefetch interval picker that uses cost analysis to overlap asynchronous
380 // copies with independent computation. It uses min (independent computation
381 // duration) / (asynchronous copy duration) ratio to guide whether the prefetch
382 // is within the lower bound. For the upper bound, it restricts the maximum
383 // duration that a buffer may occupy the alternate memory space as a multiple of
384 // the time it would take to copy a buffer that is the size of the alternate
385 // memory. It starts with the preferred ratio in Begin() and works its way for
386 // alternately earlier and later prefetches until hitting min and max ratios.
387 // The value for buffer size for max async copy is a mechanism to prevent
388 // copying small buffers between the two memories unnecessarily. For calculating
389 // the max time that the buffer can reside in alternate memory, we use the
390 // larger of this value and the actual size of the buffer.
391 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
392  public:
393   CostAnalysisPrefetchIntervalPicker(
394       const MemorySpaceAssignmentCostAnalysis& cost_analysis,
395       float min_overlap_to_async_copy_ratio,
396       float preferred_overlap_to_async_copy_ratio,
397       float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes);
398 
399   bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape,
400                                           int64_t start_time,
401                                           int64_t end_time) const override;
402 
403   int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time,
404                                    int64_t latest_end_time) const override;
405 
406   int64_t LatestPrefetchEndTime(
407       int64_t original_prefetch_end_time,
408       int64_t proposed_prefetch_end_time) const override;
409 
410   int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time,
411                                   int64_t end_time,
412                                   const HloUse* use) const override;
413 
414   int64_t PreferredPrefetchStartTime(const Shape& shape,
415                                      int64_t earliest_prefetch_start_time,
416                                      int64_t latest_prefetch_start_time,
417                                      int64_t prefetch_end_time) const override;
418 
419   int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time,
420                                    int64_t end_time) const override;
421   float GetLogicalIntervalElapsed(int64_t start_time,
422                                   int64_t end_time) const override;
423 
424   void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override;
425 
426   int64_t Next() override;
427   bool Done() const override;
428 
429   int64_t latest_time() const override;
430 
431   void SetRetryNumber(int retry_number) override;
432 
433   std::string ToDebugString() const override;
434   std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time,
435                                   int64_t end_time) const override;
436 
437   std::optional<float> BufferIntervalAlternateMemoryBenefit(
438       const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
439       const override;
440 
441  private:
442   // Finds the minimum nest level in the given interval.
443   int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const;
444 
445   // Given the elapsed time to copy this buffer to the alternate memory, returns
446   // the longest time that this buffer may reside in the alternate memory space.
447   float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const;
448 
449   // For each instruction in the flattened schedule, maintain their elapsed time
450   // (in cumulative sum) and while nesting level.
451   std::vector<float> elapsed_time_cumsum_;
452   std::vector<int> while_nest_level_;
453   std::vector<int> computation_nest_level_;
454   // Maintain the index of the most recent (before this instruction) nest level
455   // change in order to efficiently determine the minimum nest level in an
456   // interval.
457   std::vector<int> while_nest_level_change_;
458 
459   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
460   float min_overlap_to_async_copy_ratio_;
461   float preferred_overlap_to_async_copy_ratio_;
462   float max_async_copy_elapsed_;
463   float max_overlap_multiplier_ = 1.0;
464 
465   float async_copy_elapsed_;
466   float inst_elapsed_reduction_;
467   int64_t end_logical_time_;
468   int64_t earliest_prefetch_time_;
469   int64_t latest_prefetch_time_;
470   bool using_increasing_prefetch_time_iterator_ = true;
471   int64_t increasing_prefetch_time_iterator_;
472   int64_t decreasing_prefetch_time_iterator_;
473 
474   std::vector<float> while_execution_counts_;
475 };
476 
477 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each
478 // instruction in the module. It will greedily try placing as as many values in
479 // the alternate memory space as possible. It uses the heap simulator to
480 // determine the actual allocation offsets of values in the alternate memory
481 // space to account for fragmentation. The default memory space is assumed to be
482 // large enough to hold the values that could not be placed in the alternate
483 // memory space.
484 class MemorySpaceAssignment {
485  public:
486   using Chunk = HeapSimulator::Chunk;
487   using BufferInterval =
488       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval;
489   using BufferIntervalCompare =
490       GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare;
491   using IsAllowedInAlternateMemoryFunction =
492       std::function<bool(const HloValue&)>;
493   using IsUseAllowedInAlternateMemoryFunction =
494       std::function<bool(const HloUse&)>;
495   using ReservedScopedMemoryFunction =
496       std::function<int64_t(const HloInstruction*)>;
497 
498   // MemorySpaceAssignment uses a notion of a slow and large default memory
499   // space and a fast and small alternate memory space.
500   enum class MemorySpace { kDefault, kAlternate };
501 
502   // Forward declaration for Allocation.
503   class Allocation;
504   class ParentAllocation;
505 
506   // This class represents an allocation that might either be in the default or
507   // alternate memory. An HloValue might live in multiple different allocations
508   // over its lifetime. The lifetimes of the allocations are defined using
509   // start_time and end_time, which corresponds to the instruction indexes in
510   // the flattened schedule. Each of these allocations might partially overlap
511   // with each other. CopyAllocation defined below represents asynchronous
512   // copies between Allocations.
513   //
514   // Consider an instruction Foo, and its users Bar and Baz, and the times given
515   // in terms of the flattened schedule of the entire module:
516   //
517   //      Foo:10
518   //       /   \
519   //    Bar:14  \
520   //           Baz:25
521   //
522   // A valid memory space assignment could be like the following:
523   //
524   //  Time:         10 ... 14        ...      25
525   //                Foo    Bar                Baz
526   //  Alternate     +-------+           +-----+
527   //  Default           +---------------------+
528   //                    ^   ^           ^     ^
529   //                    |   |           |     |
530   //                evict   evict  prefetch  prefetch
531   //                start    end    start      end
532   //
533   // This would be represented with:
534   //   - Allocation(memory_space=kAlternate, start_time=10, end_time=14)
535   //   - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25)
536   //   - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25)
537   class Allocation {
538     friend class ParentAllocation;
539 
540    public:
Allocation(HloPosition defining_position,MemorySpace memory_space,std::optional<Chunk> chunk,int64_t start_time,int64_t end_time,bool is_scoped_allocation)541     Allocation(HloPosition defining_position, MemorySpace memory_space,
542                std::optional<Chunk> chunk, int64_t start_time, int64_t end_time,
543                bool is_scoped_allocation)
544         : defining_position_(defining_position),
545           memory_space_(memory_space),
546           chunk_(chunk),
547           start_time_(start_time),
548           end_time_(end_time),
549           is_scoped_allocation_(is_scoped_allocation) {
550       CHECK(!is_scoped_allocation || defining_position.index == ShapeIndex({}));
551     }
552     virtual ~Allocation() = default;
553 
is_copy_allocation()554     virtual bool is_copy_allocation() const { return false; }
555 
556     // Adds a use to this allocation.
557     void AddUse(HloUse use);
558 
559     // Extends the end time of this allocation.
Extend(int64_t end_time)560     void Extend(int64_t end_time) { end_time_ = end_time; }
561 
562     // After all of the time ranges for the allocations have been assigned,
563     // Process morphs the instructions affected to assign the memory spaces and
564     // insert asynchronous copy instructions if necessary.
565     virtual Status Process();
566 
567     // An optional post-process step that will be called after all allocations
568     // have been processed.
PostProcess()569     virtual Status PostProcess() { return OkStatus(); }
570 
571     // Marks (adds this allocation to needed_allocations) if this allocation is
572     // needed. Allocation and CopyAllocations are always needed and
573     // ParentAllocations are needed if they have any uses or if other
574     // CopyAllocation or ParentAllocations depend on them.
575     virtual void MarkIfNeeded(
576         absl::flat_hash_set<const Allocation*>& needed_allocations) const;
577 
578     // Marks this allocation as needed.
579     virtual void MarkNeeded(
580         absl::flat_hash_set<const Allocation*>& needed_allocations) const;
581 
582     // Returns the defining position for this allocation.
defining_position()583     virtual HloPosition defining_position() const { return defining_position_; }
584 
585     // Returns the time the buffer is first available to be used. For
586     // Allocation, this is start_time.
earliest_available_time()587     virtual int64_t earliest_available_time() const { return start_time_; }
588 
uses()589     const std::vector<HloUse>& uses() const { return uses_; }
memory_space()590     MemorySpace memory_space() const { return memory_space_; }
591     // Returns the associated chunk that may be a nullopt if the allocation is
592     // in the default memory space.
maybe_chunk()593     std::optional<Chunk> maybe_chunk() const { return chunk_; }
594     // Returns the associated chunk. The caller should ensure that the chunk is
595     // defined (the allocation should be in the alternate memory space).
chunk()596     Chunk chunk() const {
597       CHECK(chunk_.has_value());
598       return *chunk_;
599     }
mutable_chunk()600     Chunk* mutable_chunk() { return &*chunk_; }
set_start_time(int64_t start_time)601     void set_start_time(int64_t start_time) { start_time_ = start_time; }
start_time()602     int64_t start_time() const { return start_time_; }
end_time()603     int64_t end_time() const { return end_time_; }
is_scoped_allocation()604     bool is_scoped_allocation() const { return is_scoped_allocation_; }
605 
606     bool operator==(const Allocation& other) const;
607     virtual std::string ToString() const;
608 
609    protected:
610     // Recursively create kGetTupleElement instructions if the defining position
611     // shape is not an array. Returns the new instruction that has array shape.
612     HloInstruction* AddGetTupleElements() const;
613 
614     HloPosition defining_position_;
615     std::vector<HloUse> uses_;
616     MemorySpace memory_space_;
617     std::optional<Chunk> chunk_;
618     int64_t start_time_;
619     int64_t end_time_;
620     const bool is_scoped_allocation_;
621   };
622 
  // This class represents an allocation as a result of an asynchronous copy.
  // Note: CopyStart instructions are inserted after `start_time` or later,
  // while CopyDone instructions are inserted before
  // `copy_done_schedule_before_time` or earlier.
  class CopyAllocation : public Allocation {
   public:
    // Creates a copy of `prev_allocation` into `memory_space`. The defining
    // position is initially left unset ({nullptr, {}}); see
    // defining_position() below for why.
    CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
                   std::optional<Chunk> chunk, int64_t start_time,
                   int64_t end_time, int64_t copy_done_schedule_before_time,
                   bool is_cross_program_prefetch = false)
        : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
                     start_time, end_time, /*is_scoped_allocation=*/false),
          prev_allocation_(prev_allocation),
          copy_start_schedule_after_(start_time),
          copy_done_schedule_before_(copy_done_schedule_before_time),
          is_cross_program_prefetch_(is_cross_program_prefetch) {}

    bool is_copy_allocation() const override { return true; }

    Status Process() override;

    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    HloPosition defining_position() const override {
      // Unless explicitly set, the defining position of a copy allocation is
      // retrieved from the previous allocation. This is because we don't create
      // new CopyStart/CopyDone instructions until later and the position should
      // point to the previous (copy or otherwise) allocation's position for the
      // original defining position.
      if (defining_position_.instruction == nullptr) {
        return prev_allocation_.defining_position();
      }
      return defining_position_;
    }

    HloInstruction* copy_start() const { return copy_start_; }
    HloInstruction* copy_done() const { return copy_done_; }

    // Returns the time the buffer is first available to be used. For
    // CopyAllocation, this is when the copy ends, which is
    // copy_done_schedule_before.
    int64_t earliest_available_time() const override {
      return copy_done_schedule_before_;
    }

    int64_t copy_start_schedule_after() const {
      return copy_start_schedule_after_;
    }
    int64_t copy_done_schedule_before() const {
      return copy_done_schedule_before_;
    }

    void set_copy_start_schedule_after(int64_t copy_start_schedule_after) {
      copy_start_schedule_after_ = copy_start_schedule_after;
    }

    void set_copy_done_schedule_before(int64_t copy_done_schedule_before) {
      copy_done_schedule_before_ = copy_done_schedule_before;
    }

    bool is_cross_program_prefetch() const {
      return is_cross_program_prefetch_;
    }

    bool operator==(const CopyAllocation& other) const;
    std::string ToString() const override;

   private:
    const Allocation& prev_allocation_;
    // These variables define the scheduling boundaries where CopyStart and
    // CopyDone can be scheduled. The earliest CopyStart can be scheduled is
    // after copy_start_schedule_after_ and the latest CopyDone can be scheduled
    // is before copy_done_schedule_before_.
    int64_t copy_start_schedule_after_;
    int64_t copy_done_schedule_before_;
    bool is_cross_program_prefetch_;
    // NOTE(review): not initialized by the constructor; presumably populated
    // when the CopyStart/CopyDone instructions are created (Process()) —
    // confirm against the .cc implementation.
    HloInstruction* copy_start_;
    HloInstruction* copy_done_;
  };
703 
  // An allocation in the default memory space that mirrors another Allocation
  // object. This is useful to model an eviction that happens before a while op
  // so that we don't need to redundantly evict the buffer after the while op as
  // well.
  class MirroredAllocation : public Allocation {
   public:
    // Mirrors `original_allocation` in default memory at the single logical
    // time `time` (used as both start and end time).
    MirroredAllocation(const Allocation& original_allocation, int64_t time)
        : Allocation(original_allocation.defining_position(),
                     MemorySpace::kDefault, original_allocation.maybe_chunk(),
                     /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation) {}

    Status Process() override;

    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    std::string ToString() const override;

   private:
    // The allocation being mirrored; must outlive this object.
    const Allocation& original_allocation_;
  };
727 
  // An allocation in default memory space that is defined in the parent
  // computation. If a value has a copy in the default memory space in the
  // parent computation, we don't need to evict this buffer in a while loop.
  class ParentAllocation : public Allocation {
   public:
    // Creates a default-memory allocation at `position` in the parent
    // computation of `calling_instruction`, mirroring `original_allocation`.
    ParentAllocation(const Allocation& original_allocation,
                     HloInstruction* calling_instruction, HloPosition position,
                     int64_t time)
        : Allocation(position, MemorySpace::kDefault,
                     original_allocation.maybe_chunk(), /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation),
          calling_instruction_(calling_instruction) {}

    Status Process() override;
    Status PostProcess() override;

    void MarkIfNeeded(absl::flat_hash_set<const Allocation*>&
                          needed_allocations) const override;
    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    std::string ToString() const override;

   private:
    // The allocation being shadowed in the parent; must outlive this object.
    const Allocation& original_allocation_;
    // The instruction (e.g. the while op) whose parent computation holds this
    // allocation.
    HloInstruction* calling_instruction_;
  };
756 
  // An ordered sequence of Allocation objects; the sequence owns the
  // allocations (elements are unique_ptrs).
  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
  // AllocationValue is used to break up HloValues for each non-trivial position
  // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
  // HloValue may include positions and uses that alias with each other across
  // multiple computations. We use this class to break these HloValues such that
  // every AllocationValue has one defining position (that may alias with other
  // AllocationValues). The uses field of the AllocationValue contains only the
  // direct uses of the AllocationValue's defining position.
  //
  // For example, consider the following HLO snippet:
  //
  // Body {
  //   body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //   get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
  //   index=0
  //   ...
  //   ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
  // }
  //
  // Cond {
  //   cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //   ...
  // }
  //
  // add.4 = f32[4,3]{1,0} add(...)
  // tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
  // while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
  // get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
  // add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
  //
  // This contains an HloValue that looks like the following:
  // positions:
  //  add.4
  //  body_param {0}
  //  get-tuple-element.3
  //  tuple {0}
  //  cond_param {0}
  //  tuple.1 {0}
  //  while {0}
  //  get-tuple-element.5
  // uses:
  //  add.1, operand 0
  //  tuple, operand 0
  //  while, operand 0 {0}
  //  add.5, operand 0
  //
  // We break this HloValue up into the following AllocationValues for each
  // non-trivial position:
  // AllocationValue1: computation = Entry
  //  position:
  //   add.4
  //  uses:
  //   while, operand 0 {0}
  // AllocationValue2: computation = Cond
  //  position:
  //   cond_param {0}
  //  uses:
  // AllocationValue3: computation = Body
  //  position:
  //   body_param {0}
  //  uses:
  //   add.1, operand 0
  //   tuple, operand 0
  // AllocationValue4: computation = Entry
  //  position:
  //   while {0}
  //  uses:
  //   add.5, operand 0
  class AllocationValue {
   public:
    // This data structure wraps an HloUse and adds additional metadata that are
    // useful for allocation.
    struct Use {
      // The wrapped HloUse object.
      HloUse hlo_use;
      // The logical time this use is scheduled.
      int64_t time;
      // All the positions where this use aliases with. The aliased positions
      // must get the same allocation.
      std::vector<HloPosition> aliases;

      bool operator==(const Use& other) const {
        return hlo_use == other.hlo_use && time == other.time &&
               aliases == other.aliases;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Use& s) {
        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
      }
    };

    AllocationValue(const HloValue* value, const HloPosition& position,
                    int64_t size)
        : value_(value),
          defining_position_(position),
          size_(size),
          requires_contiguous_allocation_(false) {}

    const HloPosition& defining_position() const { return defining_position_; }
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
    int64_t size() const { return size_; }
    const std::vector<Use>& uses() const { return uses_; }
    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    // The computation that contains the defining instruction.
    const HloComputation* computation() const {
      return defining_instruction()->parent();
    }
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

    // Sets/gets whether this AllocationValue requires allocating it
    // contiguously throughout its live range (without any copies).
    bool requires_contiguous_allocation() const {
      return requires_contiguous_allocation_;
    }
    void set_requires_contiguous_allocation(
        bool requires_contiguous_allocation) {
      requires_contiguous_allocation_ = requires_contiguous_allocation;
    }

    // Appends a use occurring at logical time `use_time`. The use starts with
    // no aliases; aliased positions are populated separately.
    void AddUse(const HloUse& use, int64_t use_time) {
      uses_.push_back({use, use_time, {}});
    }

    std::string ToString() const;
    std::string ToShortString() const;

   private:
    const HloValue* value_;
    HloPosition defining_position_;
    int64_t size_;
    // If true, there must be a contiguous allocation for this buffer without
    // any copies.
    bool requires_contiguous_allocation_;
    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };
896 
  // Statistics of asynchronous copies.
  struct AsyncCopyStats {
    // Maximum number of asynchronous copies outstanding at any point.
    int64_t max_outstanding_async_copies;
    // Count and total byte size of prefetches.
    int64_t num_prefetches;
    int64_t prefetch_bytes;
    // Count and total byte size of evictions.
    int64_t num_evictions;
    int64_t eviction_bytes;
  };
905 
  virtual ~MemorySpaceAssignment() = default;

  // Runs the MemorySpaceAssignment pass.
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Calculates asynchronous copy statistics.
  StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

  // Returns a buffer ordering based on the cost analysis' notion of memory
  // boundedness. The optional cache can be passed to reuse previously
  // computed results across invocations.
  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);

  // Verify that the memory space assignment is free of overlapping buffers and
  // export heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();

 protected:
  // Main driver of the memory space assignment pass.
  virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
      const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis);

  // Finds an AllocationSequence for placing buffers in alternate memory using
  // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
  // called.
  virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
                                        const HloAliasAnalysis& alias_analysis);
935 
  const Options& options() const { return options_; }

  // Protected constructor (see the static Run() entry point). Snapshots the
  // flattened instruction sequence and the set of scheduled computations from
  // `hlo_live_range`.
  MemorySpaceAssignment(HloModule* module, const Options& options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(std::make_unique<PresetAssignments>()) {
    // Record every computation that has a span in the live range's schedule.
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }
955 
  // The allocation sequence produced by FindAllocationSequence.
  AllocationSequence allocations_;

  HloModule* module() { return module_; }

 private:
  // Process calls Process methods of the allocations after the allocations have
  // been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple and
  // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
  // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
  // 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule.
  Status FixSchedule();

  // Export the alternate memory assignments to the PresetAssignments and color
  // the HLO graph with the determined memory spaces.
  Status ExportAndColorBuffers();

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  // Remove the positions and chunks associated with the instruction from
  // alternate_memory_assignments_.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);

  // Returns the estimated elapsed duration of the hlo module in seconds. It
  // uses the 'allocations' argument to determine the location (default memory
  // or alternate memory) of each operand and output of an instruction.
  float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range,
                                    const AllocationSequence& allocations);

  HloModule* module_;
  const Options& options_;
  // All instructions in schedule order (copied from the live range's
  // flattened instruction sequence in the constructor).
  std::vector<HloInstruction*> flattened_instructions_;
  // Computations that appear in the schedule (see the constructor).
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  std::unique_ptr<PresetAssignments> preset_assignments_;
  // Position -> chunk assignments in the alternate memory space.
  std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
  // Instruction -> chunk reservations for scoped memory.
  std::vector<std::pair<HloInstruction*, Chunk>> scoped_memory_assignments_;
  int64_t alternate_memory_size_ = 0;

  // These maps hold vectors of new instructions that need to be scheduled after
  // (or before) the instruction index in the key. FixSchedule uses these maps
  // to modify and fix the schedule.
  absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> schedule_before_;
1006 };
1007 
// The different options to be passed to the Run() API.
struct Options {
  // Backend-specific integer value that describes the alternate memory.
  int64_t alternate_memory_space = 0;

  // Maximum size of the alternate memory space.
  int64_t max_size_in_bytes = 0;

  // Memory alignment of the alternate memory space.
  int64_t alignment_in_bytes = 1;

  // If provided, we sort the buffers using this comparison function
  // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
  std::optional<MemorySpaceAssignment::BufferIntervalCompare>
      buffer_interval_compare = std::nullopt;

  // This object determines how early and how late prefetches can occur.
  PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

  // This object is used to determine the benefit of a particular allocation.
  MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr;

  // Size function for buffer values.
  BufferValue::SizeFunction size_fn;

  // This function can be used to prevent certain HloValues (e.g., based on
  // the opcode) from being placed in the alternate memory.
  MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction
      is_allowed_in_alternate_mem_fn;

  // This function can be used to prevent certain HloUses (e.g., based on
  // the opcode) from being placed in the alternate memory.
  MemorySpaceAssignment::IsUseAllowedInAlternateMemoryFunction
      is_use_allowed_in_alternate_mem_fn = [](const HloUse&) { return true; };

  // This function returns the amount of scoped memory in bytes that should be
  // reserved during the execution of this instruction.
  MemorySpaceAssignment::ReservedScopedMemoryFunction
      reserved_scoped_memory_fn = [](const HloInstruction*) { return 0; };

  // If true, we allocate the reserved scoped memory at the same offset. This
  // is useful to enable more deduplication between HLOs that have reserved
  // scoped memories, but may result in less efficient memory packing.
  bool allocate_reserved_scoped_memory_at_same_offset = true;

  // Specifies the upper bound for number of outstanding prefetches and
  // evictions, -1 for unlimited.
  int64_t max_outstanding_prefetches = -1;
  int64_t max_outstanding_evictions = -1;

  // Extra outstanding prefetch limit for while uses (in addition to
  // max_outstanding_prefetches).
  int64_t while_use_extra_outstanding_prefetch_limit = 0;

  // Specifies the maximum number of retries that will be performed for each
  // value in case prefetching failed due to running out of asynchronous
  // copies or asynchronous copy resource.
  int64_t max_retries = 1;

  // The maximum number of repacks that we are willing to perform in case we
  // can't allocate a buffer due to running out of memory. If this value is
  // greater than 0, repacker must be non-nullptr.
  int64_t max_repacks = 0;

  // This variable is used by the cost analysis in estimating how many times
  // each while loop will execute. Nested loops will be assumed to have
  // executed pow(while_execution_count, nesting_level) times.
  uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL;

  float async_copy_bandwidth_bytes_per_second = 0.0f;

  float alternate_mem_bandwidth_bytes_per_second = 0.0f;

  // The repacking algorithm to reduce fragmentation. Must be non-null if
  // max_repacks is greater than 0.
  MemorySpaceAssignmentRepacker* repacker = nullptr;

  // This is only useful for testing, repack after every allocation.
  bool repack_after_every_allocation = false;

  // If true, tries allocating buffers across (e.g., before and inside a while
  // loop body) sequential calls (kWhile, kCall, and kConditional).
  bool allocate_across_sequential_calls = false;

  // If true, verifies the memory space assignment against overlapping
  // buffers.
  bool verify = false;

  // If not nullptr, this function is called to dump debugging information.
  // The first argument is appended to the file name and the second argument
  // is the contents of the file.
  std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;

  // Enable prefetching buffers into preferred memory across program
  // boundaries.
  bool enable_cross_program_prefetch = true;

  // If true, use buffer_interval_compare to determine which buffers to
  // prefetch across program boundaries.
  bool default_cross_program_prefetch_heuristic = false;

  // Enable cross-program prefetch freeing optimization where the
  // cross-program-prefetched buffer can be reused.
  bool enable_cross_program_prefetch_freeing = true;

  // Enable redundant eviction optimization in/around while loops. If enabled,
  // this optimization would keep a copy of the buffer in the default memory in
  // addition to alternate memory to eliminate redundant evictions.
  bool enable_while_redundant_eviction_elimination = true;

  // An optional memory space assignment autotuning config, which is used
  // to sort allocated buffers.
  std::optional<std::vector<uint64_t>> autotuning_config = std::nullopt;
};
1122 
1123 // A struct representing an asynchronous copy with its logical start and end
1124 // time (time that copy done is scheduled), the resource this copy would use,
1125 // its destination memory space, and a unique ID.
1126 struct AsynchronousCopy {
1127   int64_t start_time;
1128   int64_t end_time;
1129   float resource;
1130   MemorySpaceAssignment::MemorySpace destination;
1131   int64_t id;
1132 
1133   std::tuple<int64_t, int64_t, float, MemorySpaceAssignment::MemorySpace,
1134              int64_t>
AsTupleAsynchronousCopy1135   AsTuple() const {
1136     return std::make_tuple(start_time, end_time, resource, destination, id);
1137   }
1138 };
1139 
// Compare asynchronous copies such that an earlier start time has the same or
// earlier end time and an earlier end time has the same or earlier start time.
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b);

// Equality/inequality comparisons; declared here, implemented in the .cc file.
bool operator==(const AsynchronousCopy& a, const AsynchronousCopy& b);
bool operator!=(const AsynchronousCopy& a, const AsynchronousCopy& b);
1146 
1147 // Helper class to enforce asynchronous copy resources by keeping track of
1148 // available copy bandwidth and elapsed times of overlapped operations. It
1149 // maintains a list of initial resources that correspond to the elapsed times of
1150 // overlapped operations. As asynchronous copies are added, the available
1151 // resource is subtracted to keep track of the current state.
1152 class AsynchronousCopyResource {
1153  public:
1154   AsynchronousCopyResource() = default;
1155 
1156   // The constructor needs the initial resources.
AsynchronousCopyResource(absl::Span<const float> initial_resources)1157   explicit AsynchronousCopyResource(absl::Span<const float> initial_resources)
1158       : initial_resources_(initial_resources.begin(), initial_resources.end()),
1159         delay_(initial_resources.size(), 0) {}
1160 
1161   // Adds the given asynchronous copy and updates the current resources. CHECK
1162   // fails if there aren't enough resources to satisfy this copy (the caller
1163   // should use HasEnoughResource first to ensure there is enough resource).
1164   void AddCopy(const AsynchronousCopy& copy);
1165 
1166   // Removes the given copy and frees the resource.
1167   void RemoveCopy(const AsynchronousCopy& copy);
1168 
1169   // Returns true if a copy with the given start and end times and resource can
1170   // be satisfied.
1171   bool HasEnoughResource(int64_t start_time, int64_t end_time, float resource);
1172 
1173   // This is only used for debugging and testing purposes, it returns the
1174   // currently available resource at each logical time.
GetCurrentResources()1175   std::vector<float> GetCurrentResources() const {
1176     std::vector<float> current_resources(initial_resources_.begin(),
1177                                          initial_resources_.end());
1178     for (int i = 0; i < current_resources.size(); ++i) {
1179       current_resources[i] -= std::min(current_resources[i], delay_[i]);
1180     }
1181     return current_resources;
1182   }
1183 
1184  private:
1185   // Internal helper method to implement adding/removing/checking resources.
1186   // Only updates the current resources if update_current_resource is true. The
1187   // current_copy points to an iterator in async_copies_ and this
1188   bool ConsumeResource(
1189       int64_t start_time, int64_t end_time, float resource,
1190       bool update_current_resource,
1191       const std::list<AsynchronousCopy>::iterator* current_copy = nullptr,
1192       float resource_to_free = 0.0);
1193 
1194   // We maintain a linked list of asynchronous copies sorted by the start times.
1195   // This allows us to efficiently find the copy that starts right after another
1196   // one because adding a copy might push a copy further into the future.
1197   std::list<AsynchronousCopy> async_copies_;
1198 // To make the lookups into async_copies_ more efficient, we also maintain a
1199 // binary tree that is indexed by the start time, containing iterators into
1200 // async_copies_.
1201 // TODO(b/210891274): Use btree_map after build issue in Windows is resolved.
1202 #if defined(__GNUC__) || defined(__clang__)
1203   absl::btree_map<int64_t, std::list<AsynchronousCopy>::iterator>
1204       async_copy_time_map_;
1205 #else
1206   std::map<int64_t, std::list<AsynchronousCopy>::iterator> async_copy_time_map_;
1207 #endif
1208   std::vector<float> initial_resources_;
1209   std::vector<float> delay_;
1210 };
1211 
// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of
// maximum size.
class AlternateMemoryBestFitHeap
    : public GlobalDecreasingSizeBestFitHeap<HloValue> {
 public:
  using MemorySpace = MemorySpaceAssignment::MemorySpace;
  using AllocationValue = MemorySpaceAssignment::AllocationValue;

  AlternateMemoryBestFitHeap(
      MemorySpaceAssignment::AllocationSequence* allocations,
      const Options& options, const HloAliasAnalysis& alias_analysis,
      const HloLiveRange& hlo_live_range);

  // Allocates a buffer in preferred memory with whole program lifetime and
  // enables prefetching prefetch_candidate from default memory across program
  // boundaries.
  void AllocateCrossProgramPrefetchBuffer(
      HloModule* module, std::optional<BufferInterval> prefetch_candidate);

  HeapSimulator::Result<HloValue> Finish() override;

 protected:
  // Given a buffer interval, returns the colocated intervals. Unlike the
  // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it
  // returns the colocated intervals sorted by scheduled time.
  std::vector<const BufferInterval*> GetSortedColocatedIntervals(
      const BufferInterval& interval) const;

  // Given a BufferInterval, creates AllocationValue objects and corresponding
  // AllocationSequences and appends them into `allocation_values`.
  void CreateAllocationValues(
      const BufferInterval& buffer_interval,
      std::vector<AllocationValue>& allocation_values) const;

  // Given colocated intervals, populates allocation_values with the
  // corresponding AllocationValue objects.
  virtual void CreateAllocationValuesFromColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals,
      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);

  // Go through all the uses in the AllocationValues and find the aliasing
  // positions.
  void FindAliases(std::vector<AllocationValue>* allocation_values) const;

  // Accessors for state shared with subclasses.
  MemorySpaceAssignment::AllocationSequence* allocations() {
    return allocations_;
  }
  const Options& options() const { return options_; }
  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }
1263 
 private:
  // We inherit AllocationBlock struct to attach the Allocation information to
  // make importing repacked offsets easier.
  struct RepackAllocationBlock
      : MemorySpaceAssignmentRepacker::AllocationBlock {
    MemorySpaceAssignment::Allocation* allocation;
  };

  // A data structure we use to associate Allocation objects that are aliased
  // and must get the same offset.
  struct AliasedOffset {
    // The offset shared by all aliased allocations.
    int64_t offset;
    // The set of allocations that share this offset.
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };

  // An allocation request for a use segment. A use segment is the time segment
  // between the definition and the first use, and the time segment between the
  // uses of a buffer. For example, the time between the definition and Use1, is
  // the first segment, and the time between Use1 and Use2 is the second segment
  // and so on:
  //
  //        +------+----------+-------+
  //       /        \          \       \
  //      /          v          v       v
  //    Def         Use1       Use2    Use3
  //     <----------> <--------> <----->
  //        Segment    Segment   Segment
  //
  // start_time and end_time are the start and end logical times of the segment.
  // use_times is a sorted sequence of the times of all uses.
  // latest_prefetch_time is the latest time we can schedule the CopyDone for a
  // prefetch.
  // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced.
  // If earliest_prefetch_time is set, prefetches cannot start before this
  // value.
  struct AllocationRequest {
    int64_t start_time;
    int64_t end_time;
    int64_t latest_prefetch_time;
    int64_t size;
    bool allow_no_copy_alternate_mem_allocation;
    std::optional<int64_t> earliest_prefetch_time;
    // If non-null, the allocation should be placed at this aliased offset.
    AliasedOffset* preferred_offset;
    // The use this segment ends at, and the value being allocated.
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
    absl::Span<const int64_t> all_use_times;
  };
1311 
  // This struct contains mandatory memory assignments at a given time. E.g., an
  // input's required memory assignment time would correspond to the definition
  // time of the parameter instruction, and an output's time would correspond to
  // the time of last use.
  struct RequiredMemoryAssignment {
    MemorySpaceAssignment::MemorySpace memory_space;
    int64_t time;
    AliasedOffset* offset;

    // True if the other assignment has the same memory space and offset,
    // regardless of time.
    bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && offset == other.offset;
    }

    bool operator==(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && time == other.time &&
             offset == other.offset;
    }

    bool operator!=(const RequiredMemoryAssignment& other) const {
      return !(*this == other);
    }
  };
1334 
1335   // Result of an allocation, prefetch, eviction etc. request.  The result is
1336   // either kSuccess or a bitwise OR of one or more failures. The values are
1337   // unique powers of two. To check if a result contains a particular failure,
1338   // use the result_is method. To add a new failure to a result, use the
1339   // result_mark method.
1340   enum class Result {
1341     // Successful allocation.
1342     kSuccess = 0,
1343     // Allocation failed because we ran out of alternate memory.
1344     kFailOutOfMemory = 1,
1345     // A no-copy allocation couldn't be performed because the previous
1346     // allocation wasn't in the alternate memory space.
1347     kFailPrevAllocationNotInAlternateMem = 2,
1348     // A no-copy allocation couldn't be performed because the live range was too
1349     // long.
1350     kFailLiveRangeTooLong = 4,
1351     // A prefetching couldn't be performed because the live range was too short.
1352     kFailLiveRangeTooShort = 8,
1353     // Ran out of outstanding asynchronous copy limit either during prefetching
1354     // or eviction.
1355     kFailOutOfAsyncCopies = 16,
1356     // A prefetching couldn't be performed because the asynchronous copy
1357     // resource was violated.
1358     kFailViolatesAsyncCopyResource = 32,
1359     // An allocation failure happened that requires uncommitting all the pending
1360     // allocations. Usually this is due to a situation requiring an eviction but
1361     // the eviction couldn't be performed.
1362     kFailRequiresUncommit = 64
1363   };
1364 
1365   // Return true if the result belongs to a failure.
result_is(Result result,Result failure)1366   static bool result_is(Result result, Result failure) {
1367     return static_cast<int>(result) & static_cast<int>(failure);
1368   }
1369 
1370   // Mark (bitwise OR) a failure to the result.
result_mark(Result failure,Result & result)1371   static Result result_mark(Result failure, Result& result) {
1372     result = static_cast<Result>(static_cast<int>(result) |
1373                                  static_cast<int>(failure));
1374     return result;
1375   }
1376 
1377   // Return true if the result is a failure that requires us to uncommit pending
1378   // chunks.
result_requires_uncommit(Result result)1379   static bool result_requires_uncommit(Result result) {
1380     return result_is(result, Result::kFailRequiresUncommit);
1381   }
1382 
1383   // Return true if the result is a failure either due to running out of
1384   // outstanding asynchronous copies or due to violating asynchronous copy
1385   // ordering.
result_failed_because_of_async_copy(Result result)1386   static bool result_failed_because_of_async_copy(Result result) {
1387     return result_is(result, Result::kFailOutOfAsyncCopies) ||
1388            result_is(result, Result::kFailViolatesAsyncCopyResource);
1389   }
1390 
1391   // Allocates buffers for instructions that need reserved scoped allocations in
1392   // the alternate memory space.
1393   void AllocateReservedScopedAllocations();
1394 
1395   // Returns the AliasedOffset object associated with the allocation.
1396   AliasedOffset* GetAliasedOffset(
1397       const MemorySpaceAssignment::Allocation& allocation);
1398 
1399   // If aliased_offset is non-null, this method adds the allocation to
1400   // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
1401   // the allocation to this new AliasedOffset.
1402   void CreateOrAddToAliasedOffset(
1403       const MemorySpaceAssignment::Allocation& allocation,
1404       AliasedOffset* aliased_offset);
1405 
1406   // Given an allocation sequence, returns the live allocation at time with a
1407   // preference towards allocations in alternate memory. Returns nullptr if no
1408   // allocation is alive at that time.
1409   static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
1410       const MemorySpaceAssignment::AllocationSequence& allocations,
1411       int64_t time);
1412 
1413   // Returns true if the use is allowed in the alternate memory.
1414   bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
1415                                      const HloUse& use) const;
1416 
1417   // Finds allocations for allocation values generated from colocated intervals.
1418   // All of the allocation values have a must-alias relationship with each
1419   // other. Returns either kSuccess if all of the sites could be placed in the
  // alternate memory, or a bitwise OR of the failure reasons why they couldn't.
1421   Result AllocateAllocationValues(
1422       absl::Span<AllocationValue> allocation_values);
1423 
1424   // Finds an allocation for an allocation request for a segment (see the
1425   // documentation for AllocationRequest above how a segment is defined).
1426   //
1427   // It performs three things in the following order:
1428   //  1- Allocate the allocation request entirely in the alternate memory, if
1429   //     there is enough space and if the prefetch interval picker allows.
1430   //  2- If (1) was unsuccessful, and the only allocation for
1431   //     this buffer was in the alternate memory, we try to perform a prefetch.
1432   //  3- If (1) was unsuccessful, prefetch the buffer into the alternate memory,
1433   //     if there is enough space and if the prefetch interval picker allows.
1434   //
1435   // If an eviction (2) was requested and was unsuccessful, this method returns
1436   // Result::kFailRequiresUncommit. This means we could not find a suitable
1437   // allocation, so all previous allocations for this buffer must be removed and
1438   // allocated in the default memory. Otherwise, this method may return
1439   // Result::kSuccess if the buffer could be placed in alternate memory or some
1440   // other Result with an OR of reasons why the buffer couldn't be placed in
1441   // alternate memory.
1442   Result AllocateSegment(const AllocationRequest& request);
1443 
1444   // Try allocating in alternate memory without any copies.
1445   Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request);
1446 
1447   // Try evicting to default memory space.
1448   Result Evict(const AllocationRequest& request);
1449 
1450   // Returns the time a copy done of a prefetch should be scheduled.
1451   int64_t FindPrefetchEndTime(const AllocationRequest& request,
1452                               int64_t earliest_prefetch_time) const;
1453 
1454   // Try prefetching to alternate memory space.
1455   Result Prefetch(
1456       const AllocationRequest& request,
1457       const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem);
1458 
1459   // Find the best possible chunk candidate, where it has the longest possible
1460   // availability if no preferred offset is given, or at the preferred_offset if
1461   // it is given.
1462   std::optional<Chunk> FindBestChunkCandidate(
1463       const AllocationRequest& request, const AliasedOffset* preferred_offset,
1464       BufferInterval* alternate_mem_interval) const;
1465 
1466   // Returns the required assignment at a particular time, if available.
1467   std::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt(
1468       const HloValue* buffer, int64_t time) const;
1469 
1470   // Searches for aliases in the use for a required assignment, and returns it
1471   // if found.
1472   std::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse(
1473       const AllocationValue::Use& use) const;
1474 
1475   // Goes through the colocated intervals and adds any required assignment.
1476   void AddRequiredAssignmentsForColocatedIntervals(
1477       absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1478           colocated_intervals);
1479 
1480   // Propagates aliased required assignment for a given position.
1481   void AddAliasedRequiredAssignment(
1482       const HloInstruction* instruction, ShapeIndex index,
1483       const MemorySpaceAssignment::Allocation* aliased_allocation);
1484 
1485   // This sets a required assignment. CHECK fails if there is a conflicting
1486   // required assignment at the same time.
1487   void AddRequiredAssignment(const HloValue* value,
1488                              const HloInstruction* instruction,
1489                              MemorySpace memory_space, int64_t time,
1490                              AliasedOffset* offset = nullptr);
1491   void AddRequiredAssignment(const HloInstruction* instruction,
1492                              ShapeIndex index, MemorySpace memory_space,
1493                              AliasedOffset* offset = nullptr);
1494 
1495   // Adds input and outputs as required assignments.
1496   void AddInputAndOutputRequiredAssignments();
1497 
1498   // Returns true if the colocated intervals in the argument are in a parameter
1499   // or root instruction of the entry computation and are reserved by the user
1500   // to be in the alternate memory space.
1501   bool AreIntervalsReservedInAlternateMemory(
1502       absl::Span<const BufferInterval* const> colocated_intervals) const;
1503 
  // Since the allocations are recorded to the AllocationSequence, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
  // as a no-op to avoid unnecessarily adding the chunk to the chunk map.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}
1508 
1509   // Returns true if the addition of an asynchronous copy in the given time
1510   // interval would violate the maximum number of asynchronous copies. An extra
1511   // async copy limit can be provided to increase the limit of asynchronous
1512   // copies for this instance.
1513   bool ViolatesMaximumOutstandingAsyncCopies(
1514       int64_t start_time, int64_t end_time, bool is_prefetch,
1515       int64_t extra_async_copy_limit = 0) const;
1516 
1517   // Exports the allocations for repacking and puts them into the vector in the
1518   // parameter.
1519   void ExportAllocationsForRepacking(
1520       std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
1521           allocations);
1522 
1523   // Imports repacked allocations and updates the internal data structures
1524   // consistent with the new packing.
1525   void ImportRepackedAllocations();
1526 
1527   // Adds an asynchronous copy to the allocations.
1528   void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
1529                     MemorySpace memory_space, std::optional<Chunk> chunk,
1530                     int64_t start_time, int64_t end_time,
1531                     int64_t copy_done_schedule_before_time,
1532                     MemorySpaceAssignment::AllocationSequence* allocations,
1533                     AliasedOffset* aliased_offset, float resource,
1534                     bool is_cross_program_prefetch = false);
1535 
1536   // This method is used for committing the chunk candidate but adding it to
1537   // pending_chunks_ so that we can "uncommit" them in case we need to roll back
1538   // this allocation sequence.
1539   void AddToPendingChunks(const BufferInterval& buffer_interval,
1540                           const Chunk& chunk);
1541   // If we need to remove the allocations for this allocation sequence, this
1542   // removes pending chunks and asynchronous copies in the respective pending
1543   // buffers from the interval trees. If an allocation request returns
1544   // kFailRequiresUncommit, this method must be called.
1545   void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);
1546 
1547   // Finalizes the allocations where they can no longer be uncommitted.
1548   void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);
1549 
1550   // Clears all pending chunks and asynchronous copies.
1551   void ClearPendingChunks();
1552 
1553   // Append buffer and allocation infos for debugging and dump it into a file,
1554   // if enabled.
1555   void AppendBufferInfoDebugString(const BufferInterval& interval,
1556                                    std::string* debug_str) const;
1557   void AppendScopedAllocationBufferInfoDebugString(
1558       const HloInstruction* instruction, int64_t time, int64_t size,
1559       std::string& debug_str) const;
1560   void AppendAllocationInfoDebugString(
1561       const MemorySpaceAssignment::Allocation& allocation,
1562       std::string& debug_str) const;
1563   void DumpDebugStringsIfEnabled() const;
1564 
  // Returns the available heap size in the alternate memory: the configured
  // maximum size minus the bytes already reserved.
  int64_t available_heap_size() const {
    return options_.max_size_in_bytes - reserved_in_bytes_;
  }
1569 
1570   // Returns the earliest time in the [start_time, end_time] range that a new
1571   // allocation with the given size would fit in the alternate memory. If it
1572   // doesn't fit, it returns nullopt.
1573   std::optional<int> FindEarliestTimeToSatisfyPeakMemory(int start_time,
1574                                                          int end_time,
1575                                                          int64_t size) const;
1576 
1577   // Creates and returns a RepackAllocationBlock.
MakeRepackAllocationBlock(int64_t start_time,int64_t end_time,int64_t size,int64_t initial_offset,int64_t id,MemorySpaceAssignment::Allocation * allocation)1578   static RepackAllocationBlock MakeRepackAllocationBlock(
1579       int64_t start_time, int64_t end_time, int64_t size,
1580       int64_t initial_offset, int64_t id,
1581       MemorySpaceAssignment::Allocation* allocation) {
1582     RepackAllocationBlock allocation_block;
1583     allocation_block.start_time = start_time;
1584     allocation_block.end_time = end_time;
1585     allocation_block.size = size;
1586     allocation_block.offset = -1;
1587     allocation_block.initial_offset = initial_offset;
1588     allocation_block.id = id;
1589     allocation_block.colocations = {};
1590     allocation_block.allocation = allocation;
1591     return allocation_block;
1592   }
1593 
1594   MemorySpaceAssignment::AllocationSequence* allocations_;
1595   const Options& options_;
1596   const HloAliasAnalysis& alias_analysis_;
1597   const HloLiveRange& hlo_live_range_;
1598   // We use a interval tree to keep track of the number of outstanding
1599   // prefetches and evictions.
1600   BufferIntervalTree prefetch_interval_tree_;
1601   BufferIntervalTree eviction_interval_tree_;
1602   AsynchronousCopyResource prefetch_async_copy_resource_;
1603   AsynchronousCopyResource eviction_async_copy_resource_;
1604   // A list of RepackAllocationBlock objects that mirrors allocation sequences,
1605   // used for repacking. We use a list here because we need pointer stability
1606   // for aliased allocations.
1607   std::list<RepackAllocationBlock> repack_allocation_blocks_;
1608   int64_t num_repacks_ = 0;
1609   std::vector<std::pair<BufferInterval, Chunk>> pending_chunks_;
1610   std::vector<AsynchronousCopy> pending_async_copies_;
1611   std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
1612       pending_required_assignments_;
1613   // A cache to keep the peak memory usage at each point in the graph. We use
1614   // this to see if the proposed allocation in the alternate memory would fit
1615   // ignoring fragmentation, and if not, we can skip the more expensive lookup
1616   // in the BufferIntervalTree, which also considers fragmentation.
1617   std::vector<int64_t> peak_memory_usage_;
1618   // The data structure that contains AliasedOffset objects and Allocation to
1619   // AliasedOffset map for efficient lookup.
1620   std::list<AliasedOffset> aliased_offsets_;
1621   absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
1622       aliased_offset_map_;
1623   // This map contains required memory assignments for HloValues (e.g., input
1624   // and outputs).
1625   absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
1626       required_assignments_;
1627   // Number of bytes reserved in alternate memory space.
1628   int64_t reserved_in_bytes_ = 0;
1629   // A rough measure of the memory pressure of the model, in bytes. Note that
1630   // this is pressure for memory capacity (and not accessed bytes), and for
1631   // alternate memory (not default memory).
1632   int64_t memory_pressure_ = 0;
1633   int64_t next_async_copy_id_ = 0;
1634   // Debug strings.
1635   std::string buffer_info_str_;
1636   std::string allocation_info_str_;
1637 };
1638 }  // namespace memory_space_assignment
1639 }  // namespace xla
1640 
1641 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_
1642