1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 23 // TODO(b/210891274): Use btree_map after build issue in Windows is resolved. 24 #if defined(__GNUC__) || defined(__clang__) 25 #include "absl/container/btree_map.h" 26 #else 27 #include <map> 28 #endif 29 #include "tensorflow/compiler/xla/service/heap_simulator.h" 30 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" 31 #include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" 32 33 namespace xla { 34 35 namespace memory_space_assignment { 36 // Forward Declaration of Options. 37 class Options; 38 39 // This class contains pre-set assignments determined by memory space 40 // assignment. It contains two data structures: (1) a chunks vector that maps a 41 // defining HloPosition to a Chunk (offset and size), and (2) an assignment_info 42 // vector that maps the memory space to information like its allocated size and 43 // heap memory trace. If there is only one alternate memory space like there is 44 // currently, there will be one entry in assignment_info. 
45 class PresetAssignments { 46 public: 47 // Contains per-memory-space information like the allocated size and heap 48 // simulator trace. 49 struct AssignmentInformation { 50 int64_t size; 51 HeapSimulatorTrace heap_simulator_trace; 52 }; 53 54 PresetAssignments() = default; 55 add_chunk(const HloPosition & position,const HeapSimulator::Chunk & chunk)56 void add_chunk(const HloPosition& position, 57 const HeapSimulator::Chunk& chunk) { 58 chunks_.emplace_back(position, chunk); 59 } 60 add_scoped_allocation_chunk(HloInstruction * instruction,const HeapSimulator::Chunk & chunk)61 void add_scoped_allocation_chunk(HloInstruction* instruction, 62 const HeapSimulator::Chunk& chunk) { 63 scoped_allocation_chunks_.emplace_back(instruction, chunk); 64 } 65 assignment_information_for_space(int64_t memory_space)66 AssignmentInformation* assignment_information_for_space( 67 int64_t memory_space) { 68 for (auto& space_and_info : assignment_info_) { 69 if (space_and_info.first == memory_space) { 70 return &space_and_info.second; 71 } 72 } 73 assignment_info_.emplace_back(memory_space, AssignmentInformation()); 74 return &assignment_info_.back().second; 75 } 76 chunks()77 absl::Span<const std::pair<HloPosition, HeapSimulator::Chunk>> chunks() 78 const { 79 return chunks_; 80 } 81 82 absl::Span<const std::pair<HloInstruction*, HeapSimulator::Chunk>> scoped_allocation_chunks()83 scoped_allocation_chunks() const { 84 return scoped_allocation_chunks_; 85 } 86 87 absl::Span<const std::pair<int64_t, AssignmentInformation>> assignment_informations()88 assignment_informations() const { 89 return assignment_info_; 90 } 91 92 // Get debugging information. 
buffer_info_str()93 std::string buffer_info_str() const { return buffer_info_str_; } allocation_info_str()94 std::string allocation_info_str() const { return allocation_info_str_; } 95 96 private: 97 std::vector<std::pair<HloPosition, HeapSimulator::Chunk>> chunks_; 98 std::vector<std::pair<HloInstruction*, HeapSimulator::Chunk>> 99 scoped_allocation_chunks_; 100 std::vector<std::pair<int64_t, AssignmentInformation>> assignment_info_; 101 std::string buffer_info_str_; 102 std::string allocation_info_str_; 103 }; 104 105 // A wrapper class around HloCostAnalysis with additional knowledge about the 106 // bandwidths of different memory spaces. 107 class MemorySpaceAssignmentCostAnalysis { 108 public: 109 // An optional Cache object may be provided to some of the methods below to 110 // speed up the lookup. 111 struct Cache { 112 absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier; 113 }; 114 115 // Function type that can be used to indicate which input/output values are in 116 // the alternate memory. 117 using IsInAlternateMemoryFun = 118 std::function<bool(std::optional<int> /*operand_num*/, 119 const ShapeIndex& /*index*/, const Shape& /*shape*/)>; 120 121 virtual ~MemorySpaceAssignmentCostAnalysis() = default; 122 123 static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create( 124 const HloCostAnalysis& cost_analysis, const Options& options, 125 const HloModule& module); 126 cost_analysis()127 const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } 128 129 // Returns a heuristic value that captures how much putting this tensor to the 130 // alternate memory would help if the op is memory bound, or otherwise how far 131 // off is the op to memory boundedness. The larger this number, the higher 132 // priority it will be placed in the alternate memory. 
133 float GetAlternateMemoryBenefit(const HloInstruction& instruction, 134 float elapsed_time_due_to_alternate_mem, 135 Cache* cache = nullptr) const; 136 // Like above, return the benefit of putting the output tensor in the 137 // alternate memory. 138 float GetAlternateMemoryBenefit(const HloPosition& position, 139 Cache* cache = nullptr) const; 140 // Like above, return the benefit of putting the input tensor in the alternate 141 // memory. 142 float GetAlternateMemoryBenefit(const HloUse& use, 143 Cache* cache = nullptr) const; 144 145 // Returns a heuristic value of memory boundedness for the given 146 // BufferInterval. The larger this number, the higher priority it will be 147 // placed in the alternate memory. 148 float GetMemoryBoundedness( 149 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval, 150 Cache* cache = nullptr) const; 151 152 // Returns the elapsed time in seconds due to compute only. 153 float GetInstructionElapsedDueToCompute( 154 const HloInstruction& instruction) const; 155 156 // Returns the elapsed time in seconds due to memory only. If 157 // operands_in_alternate_mem or outputs_in_alternate_mem is provided, it will 158 // assume that the corresponding operands or output will be in the alternate 159 // memory space. This is useful for calculating the benefit of placing the 160 // buffer in alternate memory. 161 float GetInstructionElapsedDueToMemory( 162 const HloInstruction& instruction, 163 absl::Span<const std::pair<int64_t, ShapeIndex>> 164 operands_in_alternate_mem = {}, 165 absl::Span<const ShapeIndex> outputs_in_alternate_mem = {}) const; 166 167 // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in 168 // the alternate memory. 169 float GetInstructionElapsedDueToMemory( 170 const HloInstruction& instruction, 171 IsInAlternateMemoryFun is_in_alternate_mem) const; 172 173 // Returns the estimated elapsed duration of the instruction in seconds. 
It 174 // assumes all operands and outputs of the instruction are in the default 175 // memory. 176 virtual float GetInstructionElapsed(const HloInstruction& instruction) const; 177 178 // Returns the estimated elapsed duration of the instruction in seconds. It 179 // assumes all operands and outputs of the instruction are in the default 180 // memory, except for the operands and outputs specified to be in the 181 // alternate memory. 182 virtual float GetInstructionElapsedInAlternateMemory( 183 const HloInstruction& instruction, 184 absl::Span<const std::pair<int64_t, ShapeIndex>> 185 operands_in_alternate_mem, 186 absl::Span<const ShapeIndex> outputs_in_alternate_mem) const; 187 188 // Like above, only the inputs/outputs indicated by is_in_alternate_mem are in 189 // the alternate memory. 190 float GetInstructionElapsedInAlternateMemory( 191 const HloInstruction& instruction, 192 IsInAlternateMemoryFun is_in_alternate_mem) const; 193 194 // Returns the elapsed time it would take to asynchronously copy the shape 195 // from default to alternate memory space (or vice versa). 196 virtual float GetAsyncCopyElapsed(const Shape& shape) const; 197 198 int64_t GetScheduleEndTime() const; 199 200 // Returns the number of nested computation levels this instruction resides 201 // in. If while_only is true, it returns the while loop nest level and 0 202 // means the instruction is not in a while loop. 
203 int CalculateComputationNestLevel(const HloInstruction* instruction, 204 bool while_only) const; 205 hlo_live_range()206 const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } options()207 const Options& options() const { return options_; } 208 209 protected: MemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,const Options & options,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)210 MemorySpaceAssignmentCostAnalysis( 211 const HloCostAnalysis& cost_analysis, const Options& options, 212 std::unique_ptr<HloAliasAnalysis> alias_analysis, 213 std::unique_ptr<HloLiveRange> hlo_live_range, 214 std::unique_ptr<CallGraph> call_graph) 215 : cost_analysis_(cost_analysis), 216 options_(options), 217 alias_analysis_(std::move(alias_analysis)), 218 hlo_live_range_(std::move(hlo_live_range)), 219 call_graph_(std::move(call_graph)) {} 220 221 private: 222 const HloCostAnalysis& cost_analysis_; 223 const Options& options_; 224 std::unique_ptr<HloAliasAnalysis> alias_analysis_; 225 std::unique_ptr<HloLiveRange> hlo_live_range_; 226 std::unique_ptr<CallGraph> call_graph_; 227 }; 228 229 // Abstract base class that memory space assignment uses to pick prefetch 230 // intervals. 231 class PrefetchIntervalPicker { 232 public: 233 PrefetchIntervalPicker() = default; 234 virtual ~PrefetchIntervalPicker() = default; 235 236 // Returns true if the buffer can be allocated in alternate memory space 237 // without any copies (prefetches). 238 virtual bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, 239 int64_t start_time, 240 int64_t end_time) const = 0; 241 242 // Returns the preferred end time for an eviction that starts at a given time 243 // and must end by the given end time. 
244 virtual int64_t PreferredEvictionEndTime(const Shape& shape, 245 int64_t start_time, 246 int64_t latest_end_time) const = 0; 247 248 // Returns the latest time that a prefetch can start. 249 virtual int64_t LatestPrefetchStartTime(const Shape& shape, 250 int64_t start_time, int64_t end_time, 251 const HloUse* use) const = 0; 252 253 // Returns the preferred time that a prefetch can start. 254 virtual int64_t PreferredPrefetchStartTime( 255 const Shape& shape, int64_t earliest_prefetch_start_time, 256 int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const = 0; 257 258 // Returns the latest time that a prefetch can end that is less than or equal 259 // to proposed_prefetch_end_time. LatestPrefetchEndTime(int64_t original_prefetch_end_time,int64_t proposed_prefetch_end_time)260 virtual int64_t LatestPrefetchEndTime( 261 int64_t original_prefetch_end_time, 262 int64_t proposed_prefetch_end_time) const { 263 return proposed_prefetch_end_time; 264 } 265 266 // Returns the estimated end time of a prefetch that starts at the given time. 267 virtual int64_t EstimatedPrefetchEndTime(const Shape& shape, 268 int64_t start_time, 269 int64_t end_time) const = 0; 270 271 // Returns the elapsed time in seconds between the logical interval that 272 // corresponds to the instruction schedule. 273 virtual float GetLogicalIntervalElapsed(int64_t start_time, 274 int64_t end_time) const = 0; 275 276 // Begins the iterator for the first start time of the prefetch. 277 virtual void Begin(const HloUse& use, int64_t start_time, 278 int64_t end_time) = 0; 279 280 // Advances the start time of the prefetch and returns that value. 281 virtual int64_t Next() = 0; 282 283 // Returns true if the available prefetch intervals have been exhausted. 284 virtual bool Done() const = 0; 285 286 // Returns the latest time the prefetch interval picker will have pick. 
287 virtual int64_t latest_time() const = 0; 288 289 // The retry number can be used to modify the interval picking policies. The 290 // first attempt will have a retry_number of 0, then 1, etc. SetRetryNumber(int retry_number)291 virtual void SetRetryNumber(int retry_number) { 292 retry_number_ = retry_number; 293 } retry_number()294 int retry_number() const { return retry_number_; } 295 296 // Returns a debug string for the current state of the prefetch interval 297 // picker. 298 virtual std::string ToDebugString() const = 0; 299 300 // Returns a debug string for no-copy allocation. 301 virtual std::string ToNoCopyDebugString(const Shape& shape, 302 int64_t start_time, 303 int64_t end_time) const = 0; 304 305 // Prefetch interval pickers may return a value corresponding to the benefit 306 // of placing the BufferInterval in the alternate memory. The larger value, 307 // the more beneficial. BufferIntervalAlternateMemoryBenefit(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)308 virtual std::optional<float> BufferIntervalAlternateMemoryBenefit( 309 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) 310 const { 311 return std::nullopt; 312 } 313 314 protected: 315 const absl::flat_hash_map<const HloInstruction*, int64_t>* 316 instruction_schedule_ = nullptr; 317 int retry_number_ = 0; 318 }; 319 320 // Prefetch interval picker that uses instruction count to overlap asynchronous 321 // copies with independent computation. The min and max overlap counts describe 322 // the number of independent HLOs overlapped while a value is being prefetched 323 // into the alternate memory (between CopyStart and CopyDone HLO instructions). 324 // max_overlap_count attempts to prevent bringing tensors into the alternate 325 // memory too eagerly and hence occupying the space for other tensors which 326 // might use it. 
min_overlap_count attempts to prevent cases where tensors are 327 // prefetched into the alternate memory without sufficient time for the copy to 328 // take place. In those cases, it's just better to keep the tensor in the 329 // default memory instead of hurting the critical path with this copy that 330 // likely won't finish in time. 331 class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { 332 public: InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count,int64_t max_overlap_count)333 InstructionCountPrefetchIntervalPicker(int64_t min_overlap_count, 334 int64_t max_overlap_count) 335 : min_overlap_count_(min_overlap_count), 336 max_overlap_count_(max_overlap_count) {} 337 338 bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, 339 int64_t start_time, 340 int64_t end_time) const override; 341 342 int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, 343 int64_t latest_end_time) const override; 344 345 int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, 346 int64_t end_time, 347 const HloUse* use) const override; 348 349 int64_t PreferredPrefetchStartTime(const Shape& shape, 350 int64_t earliest_prefetch_start_time, 351 int64_t latest_prefetch_start_time, 352 int64_t prefetch_end_time) const override; 353 354 int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, 355 int64_t end_time) const override; 356 float GetLogicalIntervalElapsed(int64_t start_time, 357 int64_t end_time) const override; 358 359 void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override; 360 361 int64_t Next() override; 362 bool Done() const override; 363 364 int64_t latest_time() const override; 365 366 std::string ToDebugString() const override; 367 std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, 368 int64_t end_time) const override; 369 370 private: 371 int64_t min_overlap_count_; 372 int64_t max_overlap_count_; 373 int64_t end_time_; 374 int64_t 
current_prefetch_time_; 375 }; 376 377 // Forward Declaration of MemorySpaceAssignmentCostAnalysis 378 class MemorySpaceAssignmentCostAnalysis; 379 // Prefetch interval picker that uses cost analysis to overlap asynchronous 380 // copies with independent computation. It uses min (independent computation 381 // duration) / (asynchronous copy duration) ratio to guide whether the prefetch 382 // is within the lower bound. For the upper bound, it restricts the maximum 383 // duration that a buffer may occupy the alternate memory space as a multiple of 384 // the time it would take to copy a buffer that is the size of the alternate 385 // memory. It starts with the preferred ratio in Begin() and works its way for 386 // alternately earlier and later prefetches until hitting min and max ratios. 387 // The value for buffer size for max async copy is a mechanism to prevent 388 // copying small buffers between the two memories unnecessarily. For calculating 389 // the max time that the buffer can reside in alternate memory, we use the 390 // larger of this value and the actual size of the buffer. 
391 class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { 392 public: 393 CostAnalysisPrefetchIntervalPicker( 394 const MemorySpaceAssignmentCostAnalysis& cost_analysis, 395 float min_overlap_to_async_copy_ratio, 396 float preferred_overlap_to_async_copy_ratio, 397 float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes); 398 399 bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, 400 int64_t start_time, 401 int64_t end_time) const override; 402 403 int64_t PreferredEvictionEndTime(const Shape& shape, int64_t start_time, 404 int64_t latest_end_time) const override; 405 406 int64_t LatestPrefetchEndTime( 407 int64_t original_prefetch_end_time, 408 int64_t proposed_prefetch_end_time) const override; 409 410 int64_t LatestPrefetchStartTime(const Shape& shape, int64_t start_time, 411 int64_t end_time, 412 const HloUse* use) const override; 413 414 int64_t PreferredPrefetchStartTime(const Shape& shape, 415 int64_t earliest_prefetch_start_time, 416 int64_t latest_prefetch_start_time, 417 int64_t prefetch_end_time) const override; 418 419 int64_t EstimatedPrefetchEndTime(const Shape& shape, int64_t start_time, 420 int64_t end_time) const override; 421 float GetLogicalIntervalElapsed(int64_t start_time, 422 int64_t end_time) const override; 423 424 void Begin(const HloUse& use, int64_t start_time, int64_t end_time) override; 425 426 int64_t Next() override; 427 bool Done() const override; 428 429 int64_t latest_time() const override; 430 431 void SetRetryNumber(int retry_number) override; 432 433 std::string ToDebugString() const override; 434 std::string ToNoCopyDebugString(const Shape& shape, int64_t start_time, 435 int64_t end_time) const override; 436 437 std::optional<float> BufferIntervalAlternateMemoryBenefit( 438 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) 439 const override; 440 441 private: 442 // Finds the minimum nest level in the given interval. 
443 int GetMinWhileNestLevel(int64_t start_time, int64_t end_time) const; 444 445 // Given the elapsed time to copy this buffer to the alternate memory, returns 446 // the longest time that this buffer may reside in the alternate memory space. 447 float GetMaxElapsedInAlternateMemory(float async_copy_elapsed) const; 448 449 // For each instruction in the flattened schedule, maintain their elapsed time 450 // (in cumulative sum) and while nesting level. 451 std::vector<float> elapsed_time_cumsum_; 452 std::vector<int> while_nest_level_; 453 std::vector<int> computation_nest_level_; 454 // Maintain the index of the most recent (before this instruction) nest level 455 // change in order to efficiently determine the minimum nest level in an 456 // interval. 457 std::vector<int> while_nest_level_change_; 458 459 const MemorySpaceAssignmentCostAnalysis& cost_analysis_; 460 float min_overlap_to_async_copy_ratio_; 461 float preferred_overlap_to_async_copy_ratio_; 462 float max_async_copy_elapsed_; 463 float max_overlap_multiplier_ = 1.0; 464 465 float async_copy_elapsed_; 466 float inst_elapsed_reduction_; 467 int64_t end_logical_time_; 468 int64_t earliest_prefetch_time_; 469 int64_t latest_prefetch_time_; 470 bool using_increasing_prefetch_time_iterator_ = true; 471 int64_t increasing_prefetch_time_iterator_; 472 int64_t decreasing_prefetch_time_iterator_; 473 474 std::vector<float> while_execution_counts_; 475 }; 476 477 // MemorySpaceAssignment assigns memory spaces (default or alternate) to each 478 // instruction in the module. It will greedily try placing as as many values in 479 // the alternate memory space as possible. It uses the heap simulator to 480 // determine the actual allocation offsets of values in the alternate memory 481 // space to account for fragmentation. The default memory space is assumed to be 482 // large enough to hold the values that could not be placed in the alternate 483 // memory space. 
484 class MemorySpaceAssignment { 485 public: 486 using Chunk = HeapSimulator::Chunk; 487 using BufferInterval = 488 GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval; 489 using BufferIntervalCompare = 490 GlobalDecreasingSizeBestFitHeap<HloValue>::BufferIntervalCompare; 491 using IsAllowedInAlternateMemoryFunction = 492 std::function<bool(const HloValue&)>; 493 using IsUseAllowedInAlternateMemoryFunction = 494 std::function<bool(const HloUse&)>; 495 using ReservedScopedMemoryFunction = 496 std::function<int64_t(const HloInstruction*)>; 497 498 // MemorySpaceAssignment uses a notion of a slow and large default memory 499 // space and a fast and small alternate memory space. 500 enum class MemorySpace { kDefault, kAlternate }; 501 502 // Forward declaration for Allocation. 503 class Allocation; 504 class ParentAllocation; 505 506 // This class represents an allocation that might either be in the default or 507 // alternate memory. An HloValue might live in multiple different allocations 508 // over its lifetime. The lifetimes of the allocations are defined using 509 // start_time and end_time, which corresponds to the instruction indexes in 510 // the flattened schedule. Each of these allocations might partially overlap 511 // with each other. CopyAllocation defined below represents asynchronous 512 // copies between Allocations. 513 // 514 // Consider an instruction Foo, and its users Bar and Baz, and the times given 515 // in terms of the flattened schedule of the entire module: 516 // 517 // Foo:10 518 // / \ 519 // Bar:14 \ 520 // Baz:25 521 // 522 // A valid memory space assignment could be like the following: 523 // 524 // Time: 10 ... 14 ... 
25 525 // Foo Bar Baz 526 // Alternate +-------+ +-----+ 527 // Default +---------------------+ 528 // ^ ^ ^ ^ 529 // | | | | 530 // evict evict prefetch prefetch 531 // start end start end 532 // 533 // This would be represented with: 534 // - Allocation(memory_space=kAlternate, start_time=10, end_time=14) 535 // - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) 536 // - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) 537 class Allocation { 538 friend class ParentAllocation; 539 540 public: Allocation(HloPosition defining_position,MemorySpace memory_space,std::optional<Chunk> chunk,int64_t start_time,int64_t end_time,bool is_scoped_allocation)541 Allocation(HloPosition defining_position, MemorySpace memory_space, 542 std::optional<Chunk> chunk, int64_t start_time, int64_t end_time, 543 bool is_scoped_allocation) 544 : defining_position_(defining_position), 545 memory_space_(memory_space), 546 chunk_(chunk), 547 start_time_(start_time), 548 end_time_(end_time), 549 is_scoped_allocation_(is_scoped_allocation) { 550 CHECK(!is_scoped_allocation || defining_position.index == ShapeIndex({})); 551 } 552 virtual ~Allocation() = default; 553 is_copy_allocation()554 virtual bool is_copy_allocation() const { return false; } 555 556 // Adds a use to this allocation. 557 void AddUse(HloUse use); 558 559 // Extends the end time of this allocation. Extend(int64_t end_time)560 void Extend(int64_t end_time) { end_time_ = end_time; } 561 562 // After all of the time ranges for the allocations have been assigned, 563 // Process morphs the instructions affected to assign the memory spaces and 564 // insert asynchronous copy instructions if necessary. 565 virtual Status Process(); 566 567 // An optional post-process step that will be called after all allocations 568 // have been processed. 
PostProcess()569 virtual Status PostProcess() { return OkStatus(); } 570 571 // Marks (adds this allocation to needed_allocations) if this allocation is 572 // needed. Allocation and CopyAllocations are always needed and 573 // ParentAllocations are needed if they have any uses or if other 574 // CopyAllocation or ParentAllocations depend on them. 575 virtual void MarkIfNeeded( 576 absl::flat_hash_set<const Allocation*>& needed_allocations) const; 577 578 // Marks this allocation as needed. 579 virtual void MarkNeeded( 580 absl::flat_hash_set<const Allocation*>& needed_allocations) const; 581 582 // Returns the defining position for this allocation. defining_position()583 virtual HloPosition defining_position() const { return defining_position_; } 584 585 // Returns the time the buffer is first available to be used. For 586 // Allocation, this is start_time. earliest_available_time()587 virtual int64_t earliest_available_time() const { return start_time_; } 588 uses()589 const std::vector<HloUse>& uses() const { return uses_; } memory_space()590 MemorySpace memory_space() const { return memory_space_; } 591 // Returns the associated chunk that may be a nullopt if the allocation is 592 // in the default memory space. maybe_chunk()593 std::optional<Chunk> maybe_chunk() const { return chunk_; } 594 // Returns the associated chunk. The caller should ensure that the chunk is 595 // defined (the allocation should be in the alternate memory space). 
chunk()596 Chunk chunk() const { 597 CHECK(chunk_.has_value()); 598 return *chunk_; 599 } mutable_chunk()600 Chunk* mutable_chunk() { return &*chunk_; } set_start_time(int64_t start_time)601 void set_start_time(int64_t start_time) { start_time_ = start_time; } start_time()602 int64_t start_time() const { return start_time_; } end_time()603 int64_t end_time() const { return end_time_; } is_scoped_allocation()604 bool is_scoped_allocation() const { return is_scoped_allocation_; } 605 606 bool operator==(const Allocation& other) const; 607 virtual std::string ToString() const; 608 609 protected: 610 // Recursively create kGetTupleElement instructions if the defining position 611 // shape is not an array. Returns the new instruction that has array shape. 612 HloInstruction* AddGetTupleElements() const; 613 614 HloPosition defining_position_; 615 std::vector<HloUse> uses_; 616 MemorySpace memory_space_; 617 std::optional<Chunk> chunk_; 618 int64_t start_time_; 619 int64_t end_time_; 620 const bool is_scoped_allocation_; 621 }; 622 623 // This class represents an allocation as a result of an asynchronous copy. 624 // Note: CopyStart instructions are inserted after `start_time` or later, 625 // while CopyDone instructions are inserted before 626 // `copy_done_schedule_before_time` or earlier. 
627 class CopyAllocation : public Allocation { 628 public: 629 CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space, 630 std::optional<Chunk> chunk, int64_t start_time, 631 int64_t end_time, int64_t copy_done_schedule_before_time, 632 bool is_cross_program_prefetch = false) 633 : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk, 634 start_time, end_time, /*is_scoped_allocation=*/false), 635 prev_allocation_(prev_allocation), 636 copy_start_schedule_after_(start_time), 637 copy_done_schedule_before_(copy_done_schedule_before_time), 638 is_cross_program_prefetch_(is_cross_program_prefetch) {} 639 is_copy_allocation()640 bool is_copy_allocation() const override { return true; } 641 642 Status Process() override; 643 644 void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations) 645 const override; 646 defining_position()647 HloPosition defining_position() const override { 648 // Unless explicitly set, the defining position of a copy allocation in 649 // retrieved from the previous allocation. This is because we don't create 650 // new CopyStart/CopyDone instructions until later and the position should 651 // point to the previous (copy or otherwise) allocation's position for the 652 // original defining position. 653 if (defining_position_.instruction == nullptr) { 654 return prev_allocation_.defining_position(); 655 } 656 return defining_position_; 657 } 658 copy_start()659 HloInstruction* copy_start() const { return copy_start_; } copy_done()660 HloInstruction* copy_done() const { return copy_done_; } 661 662 // Returns the time the buffer is first available to be used. For 663 // CopyAllocation, this is when the copy ends, which is 664 // copy_done_schedule_before. 
earliest_available_time()665 int64_t earliest_available_time() const override { 666 return copy_done_schedule_before_; 667 } 668 copy_start_schedule_after()669 int64_t copy_start_schedule_after() const { 670 return copy_start_schedule_after_; 671 } copy_done_schedule_before()672 int64_t copy_done_schedule_before() const { 673 return copy_done_schedule_before_; 674 } 675 set_copy_start_schedule_after(int64_t copy_start_schedule_after)676 void set_copy_start_schedule_after(int64_t copy_start_schedule_after) { 677 copy_start_schedule_after_ = copy_start_schedule_after; 678 } 679 set_copy_done_schedule_before(int64_t copy_done_schedule_before)680 void set_copy_done_schedule_before(int64_t copy_done_schedule_before) { 681 copy_done_schedule_before_ = copy_done_schedule_before; 682 } 683 is_cross_program_prefetch()684 bool is_cross_program_prefetch() const { 685 return is_cross_program_prefetch_; 686 } 687 688 bool operator==(const CopyAllocation& other) const; 689 std::string ToString() const override; 690 691 private: 692 const Allocation& prev_allocation_; 693 // These variables define the scheduling boundaries where CopyStart and 694 // CopyDone can be scheduled. The earliest CopyStart can be scheduled is 695 // after copy_start_schedule_after_ and the latest CopyDone can be scheduled 696 // is before copy_done_schedule_before_. 697 int64_t copy_start_schedule_after_; 698 int64_t copy_done_schedule_before_; 699 bool is_cross_program_prefetch_; 700 HloInstruction* copy_start_; 701 HloInstruction* copy_done_; 702 }; 703 704 // An allocation in the default memory space that mirrors another Allocation 705 // object. This is useful to model an eviction that happens before a while op 706 // so that we don't need to redundantly evict the buffer after the while op as 707 // well. 
  class MirroredAllocation : public Allocation {
   public:
    // Creates a default-memory allocation at the original allocation's
    // defining position. Note start_time == end_time == `time`, so this
    // occupies a single logical time point in the schedule.
    MirroredAllocation(const Allocation& original_allocation, int64_t time)
        : Allocation(original_allocation.defining_position(),
                     MemorySpace::kDefault, original_allocation.maybe_chunk(),
                     /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation) {}

    Status Process() override;

    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    std::string ToString() const override;

   private:
    // The allocation being mirrored. Held by reference; it must outlive this
    // object.
    const Allocation& original_allocation_;
  };

  // An allocation in default memory space that is defined in the parent
  // computation. If a value has a copy in the default memory space in the
  // parent computation, we don't need to evict this buffer in a while loop.
  class ParentAllocation : public Allocation {
   public:
    // `position` is the position in the parent computation;
    // `calling_instruction` is the instruction (e.g., the while op) in the
    // parent that calls into the computation where the original allocation
    // lives. Start and end times are both `time`.
    ParentAllocation(const Allocation& original_allocation,
                     HloInstruction* calling_instruction, HloPosition position,
                     int64_t time)
        : Allocation(position, MemorySpace::kDefault,
                     original_allocation.maybe_chunk(), /*start_time=*/time,
                     /*end_time=*/time, /*is_scoped_allocation=*/false),
          original_allocation_(original_allocation),
          calling_instruction_(calling_instruction) {}

    Status Process() override;
    Status PostProcess() override;

    void MarkIfNeeded(absl::flat_hash_set<const Allocation*>&
                          needed_allocations) const override;
    void MarkNeeded(absl::flat_hash_set<const Allocation*>& needed_allocations)
        const override;

    std::string ToString() const override;

   private:
    // The allocation in the called computation that this shadows. Held by
    // reference; it must outlive this object.
    const Allocation& original_allocation_;
    HloInstruction* calling_instruction_;
  };

  using AllocationSequence = std::vector<std::unique_ptr<Allocation>>;
  // AllocationValue is used to break up HloValues for each non-trivial position
  // (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An
  // HloValue may include positions and uses that alias with each other across
  // multiple computations. We use this class to break these HloValues such that
  // every AllocationValue has one defining position (that may alias with other
  // AllocationValues). The uses field of the AllocationValue contains only the
  // direct uses of the AllocationValue's defining position.
  //
  // For example, consider the following HLO snippet:
  //
  // Body {
  //   body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //   get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param),
  //   index=0
  //   ...
  //   ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...)
  // }
  //
  // Cond {
  //   cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
  //   ...
  // }
  //
  // add.4 = f32[4,3]{1,0} add(...)
  // tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...)
  // while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond
  // get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0
  // add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...)
  //
  // This contains an HloValue that looks like the following:
  // positions:
  //  add.4
  //  body_param {0}
  //  get-tuple-element.3
  //  tuple {0}
  //  cond_param {0}
  //  tuple.1 {0}
  //  while {0}
  //  get-tuple-element.5
  // uses:
  //  add.1, operand 0
  //  tuple, operand 0
  //  while, operand 0 {0}
  //  add.5, operand 0
  //
  // We break this HloValue up into the following AllocationValues for each
  // non-trivial position:
  // AllocationValue1: computation = Entry
  //  position:
  //   add.4
  //  uses:
  //   while, operand 0 {0}
  // AllocationValue2: computation = Cond
  //  position:
  //   cond_param {0}
  //  uses:
  // AllocationValue3: computation = Body
  //  position:
  //   body_param {0}
  //  uses:
  //   add.1, operand 0
  //   tuple, operand 0
  // AllocationValue4: computation = Entry
  //  position:
  //   while {0}
  //  uses:
  //   add.5, operand 0
  class AllocationValue {
   public:
    // This data structure wraps an HloUse and adds additional metadata that are
    // useful for allocation.
    struct Use {
      // The wrapped HloUse object.
      HloUse hlo_use;
      // The logical time this use is scheduled.
      int64_t time;
      // All the positions where this use aliases with. The aliased positions
      // must get the same allocation.
      std::vector<HloPosition> aliases;

      bool operator==(const Use& other) const {
        return hlo_use == other.hlo_use && time == other.time &&
               aliases == other.aliases;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Use& s) {
        return H::combine(std::move(h), s.hlo_use, s.time, s.aliases);
      }
    };

    AllocationValue(const HloValue* value, const HloPosition& position,
                    int64_t size)
        : value_(value),
          defining_position_(position),
          size_(size),
          requires_contiguous_allocation_(false) {}

    const HloPosition& defining_position() const { return defining_position_; }
    const HloInstruction* defining_instruction() const {
      return defining_position().instruction;
    }
    int64_t size() const { return size_; }
    const std::vector<Use>& uses() const { return uses_; }
    std::vector<Use>& uses() { return uses_; }
    const HloValue* value() const { return value_; }
    const HloComputation* computation() const {
      return defining_instruction()->parent();
    }
    AllocationSequence* allocation_sequence() { return &allocation_sequence_; }

    // Sets/gets whether this AllocationValue requires allocating it
    // contiguously throughout its live range (without any copies).
    bool requires_contiguous_allocation() const {
      return requires_contiguous_allocation_;
    }
    void set_requires_contiguous_allocation(
        bool requires_contiguous_allocation) {
      requires_contiguous_allocation_ = requires_contiguous_allocation;
    }

    // Appends a use with an empty alias list; aliases are discovered later
    // (see AlternateMemoryBestFitHeap::FindAliases).
    void AddUse(const HloUse& use, int64_t use_time) {
      uses_.push_back({use, use_time, {}});
    }

    std::string ToString() const;
    std::string ToShortString() const;

   private:
    const HloValue* value_;
    HloPosition defining_position_;
    int64_t size_;
    // If true, there must be a contiguous allocation for this buffer without
    // any copies.
    bool requires_contiguous_allocation_;
    std::vector<Use> uses_;
    AllocationSequence allocation_sequence_;
  };

  // Statistics of asynchronous copies.
  struct AsyncCopyStats {
    int64_t max_outstanding_async_copies;
    int64_t num_prefetches;
    int64_t prefetch_bytes;
    int64_t num_evictions;
    int64_t eviction_bytes;
  };

  virtual ~MemorySpaceAssignment() = default;

  // Runs the MemorySpaceAssignment pass.
  static StatusOr<std::unique_ptr<PresetAssignments>> Run(
      HloModule* module, const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis, const Options& options);

  // Calculates asynchronous copy statistics.
  StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

  // Returns a buffer-interval comparator that orders buffers by memory
  // boundedness, using the given cost analysis (and optional cache).
  static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare(
      const MemorySpaceAssignmentCostAnalysis& cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr);

  // Verify that the memory space assignment is free of overlapping buffers and
  // export heap simulator trace to be used by buffer_assignment.
  Status VerifyAndExportHeapSimulatorTrace();

 protected:
  // Main driver of the memory space assignment pass.
  virtual StatusOr<std::unique_ptr<PresetAssignments>> RunMemorySpaceAssignment(
      const HloLiveRange& hlo_live_range,
      const HloAliasAnalysis& alias_analysis);

  // Finds an AllocationSequence for placing buffers in alternate memory using
  // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
  // called.
  virtual Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
                                        const HloAliasAnalysis& alias_analysis);

  const Options& options() const { return options_; }

  // Protected constructor; use the static Run() entry point instead. Copies
  // the flattened instruction sequence and records which computations appear
  // in the schedule.
  MemorySpaceAssignment(HloModule* module, const Options& options,
                        const HloLiveRange& hlo_live_range)
      : module_(module),
        options_(options),
        flattened_instructions_(hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .begin(),
                                hlo_live_range.flattened_instruction_sequence()
                                    .instructions()
                                    .end()),
        computations_in_schedule_(),
        preset_assignments_(std::make_unique<PresetAssignments>()) {
    for (const auto& computation_and_bound :
         hlo_live_range.computation_span_times()) {
      computations_in_schedule_.insert(computation_and_bound.first);
    }
  }

  AllocationSequence allocations_;

  HloModule* module() { return module_; }

 private:
  // Process calls Process methods of the allocations after the allocations have
  // been finalized.
  Status Process();

  // Process() might have altered the computation graph by inserting kTuple and
  // kGetTupleElement instructions. SimplifyGraph performs a simple DCE and
  // tuple simplification operation (e.g., given GetTupleElement(Tuple(a, b),
  // 1), simply forwards b). Runs to fixed point.
  Status SimplifyGraph();

  // FixSchedule inserts asynchronous copies in the schedule.
  Status FixSchedule();

  // Export the alternate memory assignments to the PresetAssignments and color
  // the HLO graph with the determined memory spaces.
  Status ExportAndColorBuffers();

  // Schedules asynchronous copies and ensures that the CopyStarts and their
  // corresponding CopyDones follow the same order.
  void ScheduleAsynchronousCopies();

  // Remove the positions and chunks associated with the instruction from
  // alternate_memory_assignments_.
  void RemoveAssignmentForInstruction(const HloInstruction* instruction);

  // Returns the estimated elapsed duration of the hlo module in seconds. It
  // uses the 'allocations' argument to determine the location (default memory
  // or alternate memory) of each operand and output of an instruction.
  float ComputeEstimatedElapsedTime(const HloLiveRange& hlo_live_range,
                                    const AllocationSequence& allocations);

  HloModule* module_;
  const Options& options_;
  // The module's instructions in schedule order.
  std::vector<HloInstruction*> flattened_instructions_;
  absl::flat_hash_set<const HloComputation*> computations_in_schedule_;
  std::unique_ptr<PresetAssignments> preset_assignments_;
  // Position -> chunk assignments in the alternate memory space.
  std::vector<std::pair<HloPosition, Chunk>> alternate_memory_assignments_;
  // Instruction -> chunk assignments for reserved scoped allocations.
  std::vector<std::pair<HloInstruction*, Chunk>> scoped_memory_assignments_;
  int64_t alternate_memory_size_ = 0;

  // These maps hold vectors of new instructions that need to be scheduled after
  // (or before) the instruction index in the key. FixSchedule uses these maps
  // to modify and fix the schedule.
  absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> schedule_after_;
  absl::flat_hash_map<int64_t, std::vector<HloInstruction*>> schedule_before_;
};

// The different options to be passed to the Run() API.
struct Options {
  // Backend-specific integer value that describes the alternate memory.
  int64_t alternate_memory_space = 0;

  // Maximum size of the alternate memory space.
  int64_t max_size_in_bytes = 0;

  // Memory alignment of the alternate memory space.
  int64_t alignment_in_bytes = 1;

  // If provided, we sort the buffers using this comparison function;
  // otherwise, we use GlobalDecreasingSizeBestFitHeap::kSpatial.
  std::optional<MemorySpaceAssignment::BufferIntervalCompare>
      buffer_interval_compare = std::nullopt;

  // This object determines how early and how late prefetches can occur.
  PrefetchIntervalPicker* prefetch_interval_picker = nullptr;

  // This object is used to determine the benefit of a particular allocation.
  MemorySpaceAssignmentCostAnalysis* cost_analysis = nullptr;

  // Size function for buffer values.
  BufferValue::SizeFunction size_fn;

  // This function can be used to prevent certain HloValues (e.g., based on
  // the opcode) from being placed in the alternate memory.
  MemorySpaceAssignment::IsAllowedInAlternateMemoryFunction
      is_allowed_in_alternate_mem_fn;

  // This function can be used to prevent certain HloUses (e.g., based on
  // the opcode) from being placed in the alternate memory. Defaults to
  // allowing every use.
  MemorySpaceAssignment::IsUseAllowedInAlternateMemoryFunction
      is_use_allowed_in_alternate_mem_fn = [](const HloUse&) { return true; };

  // This function returns the amount of scoped memory in bytes that should be
  // reserved during the execution of this instruction. Defaults to reserving
  // nothing.
  MemorySpaceAssignment::ReservedScopedMemoryFunction
      reserved_scoped_memory_fn = [](const HloInstruction*) { return 0; };

  // If true, we allocate the reserved scoped memory at the same offset. This
  // is useful to enable more deduplication between HLOs that have reserved
  // scoped memories, but may result in less efficient memory packing.
  bool allocate_reserved_scoped_memory_at_same_offset = true;

  // Specifies the upper bound for number of outstanding prefetches and
  // evictions, -1 for unlimited.
  int64_t max_outstanding_prefetches = -1;
  int64_t max_outstanding_evictions = -1;

  // Extra outstanding prefetch limit for while uses (in addition to
  // max_outstanding_prefetches).
  int64_t while_use_extra_outstanding_prefetch_limit = 0;

  // Specifies the maximum number of retries that will be performed for each
  // value in case prefetching failed due to running out of asynchronous
  // copies or asynchronous copy resource.
  int64_t max_retries = 1;

  // The maximum number of repacks that we are willing to perform in case we
  // can't allocate a buffer due to running out of memory. If this value is
  // greater than 0, repacker must be non-nullptr.
  int64_t max_repacks = 0;

  // This variable is used by the cost analysis in estimating how many times
  // each while loop will execute. Nested loops will be assumed to have
  // executed pow(while_execution_count, nesting_level) times.
  uint64_t xla_tpu_memory_space_assignment_while_execution_count = 5ULL;

  float async_copy_bandwidth_bytes_per_second = 0.0f;

  float alternate_mem_bandwidth_bytes_per_second = 0.0f;

  // The repacking algorithm to reduce fragmentation. Must be non-null if
  // max_repacks is greater than 0.
  MemorySpaceAssignmentRepacker* repacker = nullptr;

  // This is only useful for testing, repack after every allocation.
  bool repack_after_every_allocation = false;

  // If true, tries allocating buffers across (e.g., before and inside a while
  // loop body) sequential calls (kWhile, kCall, and kConditional).
  bool allocate_across_sequential_calls = false;

  // If true, verifies the memory space assignment against overlapping
  // buffers.
  bool verify = false;

  // If not nullptr, this function is called to dump debugging information.
  // The first argument is appended to the file name and the second argument
  // is the contents of the file.
  std::function<void(absl::string_view, absl::string_view)> dump_fn = nullptr;

  // Enable prefetching buffers into preferred memory across program
  // boundaries.
  bool enable_cross_program_prefetch = true;

  // If true, use buffer_interval_compare to determine which buffers to
  // prefetch across program boundaries.
  bool default_cross_program_prefetch_heuristic = false;

  // Enable cross-program prefetch freeing optimization where the
  // cross-program-prefetched buffer can be reused.
  bool enable_cross_program_prefetch_freeing = true;

  // Enable redundant eviction optimization in/around while loops. If enabled,
  // this optimization would keep a copy of the buffer in the default memory in
  // addition to alternate memory to eliminate redundant evictions.
  bool enable_while_redundant_eviction_elimination = true;

  // An optional memory space assignment autotuning config, which is used
  // to sort allocated buffers.
  std::optional<std::vector<uint64_t>> autotuning_config = std::nullopt;
};

// A struct representing an asynchronous copy with its logical start and end
// time (time that copy done is scheduled), the resource this copy would use,
// its destination memory space, and a unique ID.
1126 struct AsynchronousCopy { 1127 int64_t start_time; 1128 int64_t end_time; 1129 float resource; 1130 MemorySpaceAssignment::MemorySpace destination; 1131 int64_t id; 1132 1133 std::tuple<int64_t, int64_t, float, MemorySpaceAssignment::MemorySpace, 1134 int64_t> AsTupleAsynchronousCopy1135 AsTuple() const { 1136 return std::make_tuple(start_time, end_time, resource, destination, id); 1137 } 1138 }; 1139 1140 // Compare asynchronous copies such that an earlier start time has the same or 1141 // earlier end time and an earlier end time has the same or earlier start time. 1142 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b); 1143 1144 bool operator==(const AsynchronousCopy& a, const AsynchronousCopy& b); 1145 bool operator!=(const AsynchronousCopy& a, const AsynchronousCopy& b); 1146 1147 // Helper class to enforce asynchronous copy resources by keeping track of 1148 // available copy bandwidth and elapsed times of overlapped operations. It 1149 // maintains a list of initial resources that correspond to the elapsed times of 1150 // overlapped operations. As asynchronous copies are added, the available 1151 // resource is subtracted to keep track of the current state. 1152 class AsynchronousCopyResource { 1153 public: 1154 AsynchronousCopyResource() = default; 1155 1156 // The constructor needs the initial resources. AsynchronousCopyResource(absl::Span<const float> initial_resources)1157 explicit AsynchronousCopyResource(absl::Span<const float> initial_resources) 1158 : initial_resources_(initial_resources.begin(), initial_resources.end()), 1159 delay_(initial_resources.size(), 0) {} 1160 1161 // Adds the given asynchronous copy and updates the current resources. CHECK 1162 // fails if there aren't enough resources to satisfy this copy (the caller 1163 // should use HasEnoughResource first to ensure there is enough resource). 1164 void AddCopy(const AsynchronousCopy& copy); 1165 1166 // Removes the given copy and frees the resource. 
1167 void RemoveCopy(const AsynchronousCopy& copy); 1168 1169 // Returns true if a copy with the given start and end times and resource can 1170 // be satisfied. 1171 bool HasEnoughResource(int64_t start_time, int64_t end_time, float resource); 1172 1173 // This is only used for debugging and testing purposes, it returns the 1174 // currently available resource at each logical time. GetCurrentResources()1175 std::vector<float> GetCurrentResources() const { 1176 std::vector<float> current_resources(initial_resources_.begin(), 1177 initial_resources_.end()); 1178 for (int i = 0; i < current_resources.size(); ++i) { 1179 current_resources[i] -= std::min(current_resources[i], delay_[i]); 1180 } 1181 return current_resources; 1182 } 1183 1184 private: 1185 // Internal helper method to implement adding/removing/checking resources. 1186 // Only updates the current resources if update_current_resource is true. The 1187 // current_copy points to an iterator in async_copies_ and this 1188 bool ConsumeResource( 1189 int64_t start_time, int64_t end_time, float resource, 1190 bool update_current_resource, 1191 const std::list<AsynchronousCopy>::iterator* current_copy = nullptr, 1192 float resource_to_free = 0.0); 1193 1194 // We maintain a linked list of asynchronous copies sorted by the start times. 1195 // This allows us to efficiently find the copy that starts right after another 1196 // one because adding a copy might push a copy further into the future. 1197 std::list<AsynchronousCopy> async_copies_; 1198 // To make the lookups into async_copies_ more efficient, we also maintain a 1199 // binary tree that is indexed by the start time, containing iterators into 1200 // async_copies_. 1201 // TODO(b/210891274): Use btree_map after build issue in Windows is resolved. 
1202 #if defined(__GNUC__) || defined(__clang__) 1203 absl::btree_map<int64_t, std::list<AsynchronousCopy>::iterator> 1204 async_copy_time_map_; 1205 #else 1206 std::map<int64_t, std::list<AsynchronousCopy>::iterator> async_copy_time_map_; 1207 #endif 1208 std::vector<float> initial_resources_; 1209 std::vector<float> delay_; 1210 }; 1211 1212 // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of 1213 // maximum size. 1214 class AlternateMemoryBestFitHeap 1215 : public GlobalDecreasingSizeBestFitHeap<HloValue> { 1216 public: 1217 using MemorySpace = MemorySpaceAssignment::MemorySpace; 1218 using AllocationValue = MemorySpaceAssignment::AllocationValue; 1219 1220 AlternateMemoryBestFitHeap( 1221 MemorySpaceAssignment::AllocationSequence* allocations, 1222 const Options& options, const HloAliasAnalysis& alias_analysis, 1223 const HloLiveRange& hlo_live_range); 1224 1225 // Allocates a buffer in preferred memory with whole program lifetime and 1226 // enables prefetching prefetch_candidate from default memory across program 1227 // boundaries. 1228 void AllocateCrossProgramPrefetchBuffer( 1229 HloModule* module, std::optional<BufferInterval> prefetch_candidate); 1230 1231 HeapSimulator::Result<HloValue> Finish() override; 1232 1233 protected: 1234 // Given a buffer interval, returns the colocated intervals. Unlike the 1235 // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it 1236 // returns the colocated intervals sorted by scheduled time. 1237 std::vector<const BufferInterval*> GetSortedColocatedIntervals( 1238 const BufferInterval& interval) const; 1239 1240 // Given a BufferInterval, creates AllocationValue objects and corresponding 1241 // AllocationSequences and appends them into allocation_sequence_list_. 
  void CreateAllocationValues(
      const BufferInterval& buffer_interval,
      std::vector<AllocationValue>& allocation_values) const;

  // Given colocated intervals, populates allocation_values with the
  // corresponding AllocationValue objects.
  virtual void CreateAllocationValuesFromColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals,
      std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values);

  // Go through all the uses in the AllocationValues and find the aliasing
  // positions.
  void FindAliases(std::vector<AllocationValue>* allocation_values) const;

  MemorySpaceAssignment::AllocationSequence* allocations() {
    return allocations_;
  }
  const Options& options() const { return options_; }
  const HloAliasAnalysis& alias_analysis() { return alias_analysis_; }
  const HloLiveRange& hlo_live_range() { return hlo_live_range_; }

 private:
  // We inherit AllocationBlock struct to attach the Allocation information to
  // make importing repacked offsets easier.
  struct RepackAllocationBlock
      : MemorySpaceAssignmentRepacker::AllocationBlock {
    MemorySpaceAssignment::Allocation* allocation;
  };

  // A data structure we use to associate Allocation objects that are aliased
  // and must get the same offset.
  struct AliasedOffset {
    int64_t offset;
    absl::flat_hash_set<const MemorySpaceAssignment::Allocation*> allocations;
  };

  // An allocation request for a use segment. A use segment is the time segment
  // between the definition and the first use, and the time segment between the
  // uses of a buffer. For example, the time between the definition and Use1, is
  // the first segment, and the time between Use1 and Use2 is the second segment
  // and so on:
  //
  //        +------+----------+-------+
  //       /        \          \       \
  //      /          v          v       v
  //    Def         Use1       Use2    Use3
  //     <----------> <--------> <----->
  //        Segment    Segment   Segment
  //
  // start_time and end_time are the start and end logical times of the segment.
  // use_times is a sorted sequence of the times of all uses.
  // latest_prefetch_time is the latest time we can schedule the CopyDone for a
  // prefetch.
  // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced.
  // If earliest_prefetch_time is set, prefetches cannot start before this
  // value.
  struct AllocationRequest {
    int64_t start_time;
    int64_t end_time;
    int64_t latest_prefetch_time;
    int64_t size;
    bool allow_no_copy_alternate_mem_allocation;
    std::optional<int64_t> earliest_prefetch_time;
    AliasedOffset* preferred_offset;
    const MemorySpaceAssignment::AllocationValue::Use* use;
    MemorySpaceAssignment::AllocationValue* allocation_value;
    absl::Span<const int64_t> all_use_times;
  };

  // This struct contains mandatory memory assignments at a given time. E.g., an
  // input's required memory assignment time would correspond to the definition
  // time of the parameter instruction, and an output's time would correspond to
  // the time of last use.
  struct RequiredMemoryAssignment {
    MemorySpaceAssignment::MemorySpace memory_space;
    int64_t time;
    AliasedOffset* offset;

    // Equality that ignores the time field; used to compare the assignment
    // itself (space and offset) regardless of when it applies.
    bool equals_ignoring_time(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && offset == other.offset;
    }

    bool operator==(const RequiredMemoryAssignment& other) const {
      return memory_space == other.memory_space && time == other.time &&
             offset == other.offset;
    }

    bool operator!=(const RequiredMemoryAssignment& other) const {
      return !(*this == other);
    }
  };

  // Result of an allocation, prefetch, eviction etc. request. The result is
  // either kSuccess or a bitwise OR of one or more failures. The values are
  // unique powers of two. To check if a result contains a particular failure,
  // use the result_is method. To add a new failure to a result, use the
  // result_mark method.
  enum class Result {
    // Successful allocation.
    kSuccess = 0,
    // Allocation failed because we ran out of alternate memory.
    kFailOutOfMemory = 1,
    // A no-copy allocation couldn't be performed because the previous
    // allocation wasn't in the alternate memory space.
    kFailPrevAllocationNotInAlternateMem = 2,
    // A no-copy allocation couldn't be performed because the live range was too
    // long.
    kFailLiveRangeTooLong = 4,
    // A prefetching couldn't be performed because the live range was too short.
    kFailLiveRangeTooShort = 8,
    // Ran out of outstanding asynchronous copy limit either during prefetching
    // or eviction.
    kFailOutOfAsyncCopies = 16,
    // A prefetching couldn't be performed because the asynchronous copy
    // resource was violated.
    kFailViolatesAsyncCopyResource = 32,
    // An allocation failure happened that requires uncommitting all the pending
    // allocations. Usually this is due to a situation requiring an eviction but
    // the eviction couldn't be performed.
    kFailRequiresUncommit = 64
  };

  // Return true if the result belongs to a failure.
  static bool result_is(Result result, Result failure) {
    return static_cast<int>(result) & static_cast<int>(failure);
  }

  // Mark (bitwise OR) a failure to the result.
  static Result result_mark(Result failure, Result& result) {
    result = static_cast<Result>(static_cast<int>(result) |
                                 static_cast<int>(failure));
    return result;
  }

  // Return true if the result is a failure that requires us to uncommit pending
  // chunks.
  static bool result_requires_uncommit(Result result) {
    return result_is(result, Result::kFailRequiresUncommit);
  }

  // Return true if the result is a failure either due to running out of
  // outstanding asynchronous copies or due to violating asynchronous copy
  // ordering.
  static bool result_failed_because_of_async_copy(Result result) {
    return result_is(result, Result::kFailOutOfAsyncCopies) ||
           result_is(result, Result::kFailViolatesAsyncCopyResource);
  }

  // Allocates buffers for instructions that need reserved scoped allocations in
  // the alternate memory space.
  void AllocateReservedScopedAllocations();

  // Returns the AliasedOffset object associated with the allocation.
  AliasedOffset* GetAliasedOffset(
      const MemorySpaceAssignment::Allocation& allocation);

  // If aliased_offset is non-null, this method adds the allocation to
  // aliased_offset. Otherwise, it creates a new AliasedOffset object and adds
  // the allocation to this new AliasedOffset.
1402 void CreateOrAddToAliasedOffset( 1403 const MemorySpaceAssignment::Allocation& allocation, 1404 AliasedOffset* aliased_offset); 1405 1406 // Given an allocation sequence, returns the live allocation at time with a 1407 // preference towards allocations in alternate memory. Returns nullptr if no 1408 // allocation is alive at that time. 1409 static MemorySpaceAssignment::Allocation* GetLiveAllocationAt( 1410 const MemorySpaceAssignment::AllocationSequence& allocations, 1411 int64_t time); 1412 1413 // Returns true if the use is allowed in the alternate memory. 1414 bool IsUseAllowedInAlternateMemory(const AllocationValue& value, 1415 const HloUse& use) const; 1416 1417 // Finds allocations for allocation values generated from colocated intervals. 1418 // All of the allocation values have a must-alias relationship with each 1419 // other. Returns either kSuccess if all of the sites could be placed in the 1420 // alternate memory or a bitwise OR of failure reasons why they couldn't 1421 Result AllocateAllocationValues( 1422 absl::Span<AllocationValue> allocation_values); 1423 1424 // Finds an allocation for an allocation request for a segment (see the 1425 // documentation for AllocationRequest above how a segment is defined). 1426 // 1427 // It performs three things in the following order: 1428 // 1- Allocate the allocation request entirely in the alternate memory, if 1429 // there is enough space and if the prefetch interval picker allows. 1430 // 2- If (1) was unsuccessful, and the only allocation for 1431 // this buffer was in the alternate memory, we try to perform a prefetch. 1432 // 3- If (1) was unsuccessful, prefetch the buffer into the alternate memory, 1433 // if there is enough space and if the prefetch interval picker allows. 1434 // 1435 // If an eviction (2) was requested and was unsuccessful, this method returns 1436 // Result::kFailRequiresUncommit. 
This means we could not find a suitable 1437 // allocation, so all previous allocations for this buffer must be removed and 1438 // allocated in the default memory. Otherwise, this method may return 1439 // Result::kSuccess if the buffer could be placed in alternate memory or some 1440 // other Result with an OR of reasons why the buffer couldn't be placed in 1441 // alternate memory. 1442 Result AllocateSegment(const AllocationRequest& request); 1443 1444 // Try allocating in alternate memory without any copies. 1445 Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); 1446 1447 // Try evicting to default memory space. 1448 Result Evict(const AllocationRequest& request); 1449 1450 // Returns the time a copy done of a prefetch should be scheduled. 1451 int64_t FindPrefetchEndTime(const AllocationRequest& request, 1452 int64_t earliest_prefetch_time) const; 1453 1454 // Try prefetching to alternate memory space. 1455 Result Prefetch( 1456 const AllocationRequest& request, 1457 const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); 1458 1459 // Find the best possible chunk candidate, where it has the longest possible 1460 // availability if no preferred offset is given, or at the preferred_offset if 1461 // it is given. 1462 std::optional<Chunk> FindBestChunkCandidate( 1463 const AllocationRequest& request, const AliasedOffset* preferred_offset, 1464 BufferInterval* alternate_mem_interval) const; 1465 1466 // Returns the required assignment at a particular time, if available. 1467 std::optional<RequiredMemoryAssignment> RequiredMemoryAssignmentAt( 1468 const HloValue* buffer, int64_t time) const; 1469 1470 // Searches for aliases in the use for a required assignment, and returns it 1471 // if found. 1472 std::optional<RequiredMemoryAssignment> AliasedRequiredAssignmentForUse( 1473 const AllocationValue::Use& use) const; 1474 1475 // Goes through the colocated intervals and adds any required assignment. 
  void AddRequiredAssignmentsForColocatedIntervals(
      absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
          colocated_intervals);

  // Propagates aliased required assignment for a given position.
  void AddAliasedRequiredAssignment(
      const HloInstruction* instruction, ShapeIndex index,
      const MemorySpaceAssignment::Allocation* aliased_allocation);

  // This sets a required assignment. CHECK fails if there is a conflicting
  // required assignment at the same time. The two overloads identify the
  // buffer either by its HloValue (with an explicit time) or by a defining
  // position (instruction + shape index).
  void AddRequiredAssignment(const HloValue* value,
                             const HloInstruction* instruction,
                             MemorySpace memory_space, int64_t time,
                             AliasedOffset* offset = nullptr);
  void AddRequiredAssignment(const HloInstruction* instruction,
                             ShapeIndex index, MemorySpace memory_space,
                             AliasedOffset* offset = nullptr);

  // Adds input and outputs as required assignments.
  void AddInputAndOutputRequiredAssignments();

  // Returns true if the colocated intervals in the argument are in a parameter
  // or root instruction of the entry computation and are reserved by the user
  // to be in the alternate memory space.
  bool AreIntervalsReservedInAlternateMemory(
      absl::Span<const BufferInterval* const> colocated_intervals) const;

  // Since the allocations are recorded to the AllocationSequence, we don't
  // maintain result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap
  // to avoid unnecessarily adding the chunk to the chunk map. The empty body
  // is intentional.
  void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {}

  // Returns true if the addition of an asynchronous copy in the given time
  // interval would violate the maximum number of asynchronous copies. An extra
  // async copy limit can be provided to increase the limit of asynchronous
  // copies for this instance.
  bool ViolatesMaximumOutstandingAsyncCopies(
      int64_t start_time, int64_t end_time, bool is_prefetch,
      int64_t extra_async_copy_limit = 0) const;

  // Exports the allocations for repacking and puts them into the vector in the
  // parameter.
  void ExportAllocationsForRepacking(
      std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>&
          allocations);

  // Imports repacked allocations and updates the internal data structures
  // consistent with the new packing.
  void ImportRepackedAllocations();

  // Adds an asynchronous copy to the allocations.
  void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation,
                    MemorySpace memory_space, std::optional<Chunk> chunk,
                    int64_t start_time, int64_t end_time,
                    int64_t copy_done_schedule_before_time,
                    MemorySpaceAssignment::AllocationSequence* allocations,
                    AliasedOffset* aliased_offset, float resource,
                    bool is_cross_program_prefetch = false);

  // This method is used for committing the chunk candidate but adding it to
  // pending_chunks_ so that we can "uncommit" them in case we need to roll back
  // this allocation sequence.
  void AddToPendingChunks(const BufferInterval& buffer_interval,
                          const Chunk& chunk);

  // If we need to remove the allocations for this allocation sequence, this
  // removes pending chunks and asynchronous copies in the respective pending
  // buffers from the interval trees. If an allocation request returns
  // kFailRequiresUncommit, this method must be called.
  void UncommitPendingChunks(absl::Span<AllocationValue> allocation_values);

  // Finalizes the allocations where they can no longer be uncommitted.
  void FinalizeAllocations(absl::Span<AllocationValue> allocation_values);

  // Clears all pending chunks and asynchronous copies.
  void ClearPendingChunks();

  // Append buffer and allocation info for debugging, and dump it into a file,
  // if enabled.
  // NOTE(review): the first overload takes the output string by pointer while
  // the other two take it by reference — consider unifying the out-param style
  // (requires touching the .cc definitions, which are outside this view).
  void AppendBufferInfoDebugString(const BufferInterval& interval,
                                   std::string* debug_str) const;
  void AppendScopedAllocationBufferInfoDebugString(
      const HloInstruction* instruction, int64_t time, int64_t size,
      std::string& debug_str) const;
  void AppendAllocationInfoDebugString(
      const MemorySpaceAssignment::Allocation& allocation,
      std::string& debug_str) const;
  void DumpDebugStringsIfEnabled() const;

  // Returns the available heap size in the alternate memory: the configured
  // maximum size minus the bytes already reserved (reserved_in_bytes_).
  int64_t available_heap_size() const {
    return options_.max_size_in_bytes - reserved_in_bytes_;
  }

  // Returns the earliest time in the [start_time, end_time] range that a new
  // allocation with the given size would fit in the alternate memory. If it
  // doesn't fit, it returns nullopt.
  std::optional<int> FindEarliestTimeToSatisfyPeakMemory(int start_time,
                                                         int end_time,
                                                         int64_t size) const;

  // Creates and returns a RepackAllocationBlock.
  static RepackAllocationBlock MakeRepackAllocationBlock(
      int64_t start_time, int64_t end_time, int64_t size,
      int64_t initial_offset, int64_t id,
      MemorySpaceAssignment::Allocation* allocation) {
    RepackAllocationBlock allocation_block;
    allocation_block.start_time = start_time;
    allocation_block.end_time = end_time;
    allocation_block.size = size;
    // -1 marks the offset as not yet assigned; presumably the repacker fills
    // it in later (see ExportAllocationsForRepacking /
    // ImportRepackedAllocations).
    allocation_block.offset = -1;
    allocation_block.initial_offset = initial_offset;
    allocation_block.id = id;
    allocation_block.colocations = {};
    allocation_block.allocation = allocation;
    return allocation_block;
  }

  // Output sequence of allocations being built (pointer; ownership presumably
  // stays with the caller — confirm at the construction site).
  MemorySpaceAssignment::AllocationSequence* allocations_;
  // Analysis inputs held by reference; they must outlive this object.
  const Options& options_;
  const HloAliasAnalysis& alias_analysis_;
  const HloLiveRange& hlo_live_range_;
  // We use an interval tree to keep track of the number of outstanding
  // prefetches and evictions.
  BufferIntervalTree prefetch_interval_tree_;
  BufferIntervalTree eviction_interval_tree_;
  AsynchronousCopyResource prefetch_async_copy_resource_;
  AsynchronousCopyResource eviction_async_copy_resource_;
  // A list of RepackAllocationBlock objects that mirrors allocation sequences,
  // used for repacking. We use a list here because we need pointer stability
  // for aliased allocations.
  std::list<RepackAllocationBlock> repack_allocation_blocks_;
  // Number of repacks performed so far.
  int64_t num_repacks_ = 0;
  // Uncommitted state for the in-flight allocation sequence; rolled back by
  // UncommitPendingChunks() or made permanent by FinalizeAllocations().
  std::vector<std::pair<BufferInterval, Chunk>> pending_chunks_;
  std::vector<AsynchronousCopy> pending_async_copies_;
  std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
      pending_required_assignments_;
  // A cache to keep the peak memory usage at each point in the graph. We use
  // this to see if the proposed allocation in the alternate memory would fit
  // ignoring fragmentation, and if not, we can skip the more expensive lookup
  // in the BufferIntervalTree, which also considers fragmentation.
  std::vector<int64_t> peak_memory_usage_;
  // The data structure that contains AliasedOffset objects and Allocation to
  // AliasedOffset map for efficient lookup. The list gives pointer stability
  // for the AliasedOffset* values stored in the map.
  std::list<AliasedOffset> aliased_offsets_;
  absl::flat_hash_map<const MemorySpaceAssignment::Allocation*, AliasedOffset*>
      aliased_offset_map_;
  // This map contains required memory assignments for HloValues (e.g., input
  // and outputs).
  absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
      required_assignments_;
  // Number of bytes reserved in alternate memory space.
  int64_t reserved_in_bytes_ = 0;
  // A rough measure of the memory pressure of the model, in bytes. Note that
  // this is pressure for memory capacity (and not accessed bytes), and for
  // alternate memory (not default memory).
  int64_t memory_pressure_ = 0;
  // Presumably a monotonically increasing id handed out to new async copies —
  // confirm in the .cc file.
  int64_t next_async_copy_id_ = 0;
  // Debug strings.
  std::string buffer_info_str_;
  std::string allocation_info_str_;
};
}  // namespace memory_space_assignment
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_