xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/memory_space_assignment.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <iterator>
21 #include <limits>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/debug_options_flags.h"
29 #include "tensorflow/compiler/xla/service/memory_space_assignment_tuning_utils.h"
30 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
31 #include "tensorflow/compiler/xla/service/tuple_util.h"
32 #include "tensorflow/core/lib/math/math_util.h"
33 namespace xla {
34 
35 namespace memory_space_assignment {
36 
37 namespace {
38 // Define a dummy chunk for chunks that will be allocated in the default memory
39 // space and for keeping track of the number of asynchronous copies.
40 const HeapSimulator::Chunk kDummyChunk{-1, -1};
41 // For a cross-program prefetched buffer, we only perform the freeing optimization
42 // if the buffer is occupied for less than this fraction of the execution time.
43 const float kCrossProgramPrefetchOccupyFreeingLimit = 0.6;
44 // Each time we retry compilation, increase the preferred eviction end time by
45 // this amount multiplied by preferred overlap to async copy ratio.
46 const float kEvictionRetryMultiplier = 2.0;
47 
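// Heuristic used when selecting cross-program prefetch candidates: returns
// true if this value is consumed like an activation rather than like a
// weight, i.e. if any user (looking through bitcasts, broadcasts, transposes,
// and fusion parameters) uses it as the first operand of a convolution/dot,
// as a non-first operand of a gather or dynamic (update) slice, as a reduce
// init value, or via any opcode not explicitly handled below.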
48 bool LooksLikeAnActivation(const HloInstruction* inst) {
49   for (HloInstruction* user : inst->users()) {
50     switch (user->opcode()) {
51       case HloOpcode::kConvolution:
52       case HloOpcode::kDot:
53         if (user->operand(0) == inst) {
54           return true;
55         }
56         break;
57       case HloOpcode::kGather:
58         if (user->operand(1) == inst) {
59           return true;
60         }
61         break;
62       case HloOpcode::kFusion:
63         for (int i = 0; i < user->operand_count(); ++i) {
64           if (user->operand(i) == inst &&
65               LooksLikeAnActivation(user->fused_parameter(i))) {
66             return true;
67           }
68         }
69         break;
70       case HloOpcode::kBitcast:
71       case HloOpcode::kBroadcast:
72       case HloOpcode::kTranspose:
73         if (LooksLikeAnActivation(user)) {
74           return true;
75         }
76         break;
77       case HloOpcode::kDynamicUpdateSlice:
78       case HloOpcode::kDynamicSlice:
79         if (std::find(user->operands().begin() + 1, user->operands().end(),
80                       inst) != user->operands().end()) {
81           return true;
82         }
83         if (LooksLikeAnActivation(user)) {
84           return true;
85         }
86         break;
87       case HloOpcode::kReduce:
88         // Check init operands.
89         if (std::find(user->operands().begin() + user->operand_count() / 2,
90                       user->operands().end(), inst) != user->operands().end()) {
91           return true;
92         }
93         if (LooksLikeAnActivation(user)) {
94           return true;
95         }
96         break;
97       default:
98         return true;
99     }
100   }
101   return false;
102 }
103 
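// Returns true if the value is eligible for cross-program prefetching: it
// must be (an element of) an entry computation parameter that is an array,
// small enough to fit in alternate memory, not already placed in alternate
// memory by its layout, and it must have at least one use, all of which look
// like weights rather than activations.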
104 bool IsCrossProgramPrefetchCandidate(const HloValue& value,
105                                      const Options& options) {
106   return value.defining_instruction()->parent() ==
107              value.defining_instruction()->GetModule()->entry_computation() &&
108          value.defining_instruction()->opcode() == HloOpcode::kParameter &&
109          (!value.shape().has_layout() ||
110           value.shape().layout().memory_space() !=
111               options.alternate_memory_space) &&
112          value.index().size() <= 1 && value.shape().IsArray() &&
113          !value.GetUses().empty() &&
114          options.size_fn(value) <= options.max_size_in_bytes &&
115          absl::c_all_of(value.GetUses(), [&](const HloUse& use) {
116            const HloInstruction* inst =
117                use.instruction->operand(use.operand_number);
118 
119            // Skip the LooksLikeAnActivation test since we're testing the
120            // parent GTE/parameter and its children below.
121            if (inst->opcode() == HloOpcode::kBitcast &&
122                ((inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
123                  inst->operand(0)->operand(0)->opcode() ==
124                      HloOpcode::kParameter) ||
125                 inst->operand(0)->opcode() == HloOpcode::kParameter)) {
126              return true;
127            }
128 
129            return (inst->opcode() == HloOpcode::kGetTupleElement ||
130                    inst->opcode() == HloOpcode::kParameter) &&
131                   !LooksLikeAnActivation(inst);
132          });
133 }
134 
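// Scans all buffers in the module for cross-program prefetch candidates and
// returns the most promising one, or std::nullopt if there is none. By
// default the largest candidate wins (ties broken by the total number of
// elements its consumers produce), unless the configured buffer interval
// comparator is requested instead.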
135 std::optional<MemorySpaceAssignment::BufferInterval>
136 FindCrossProgramPrefetchCandidate(const HloAliasAnalysis& alias_analysis,
137                                   const HloLiveRange& hlo_live_range,
138                                   const Options& options) {
139   std::vector<MemorySpaceAssignment::BufferInterval> candidates;
140   for (const HloBuffer& buffer : alias_analysis.buffers()) {
141     CHECK_GE(buffer.values().size(), 1);
142     const HloValue* value = buffer.values().at(0);
143     if (IsCrossProgramPrefetchCandidate(*value, options)) {
144       MemorySpaceAssignment::BufferInterval interval;
145       interval.buffer = value;
146       interval.size = options.size_fn(*value);
147       interval.start = 0;
148       interval.end = hlo_live_range.schedule_end_time();
149       interval.need_allocation = true;
150       interval.colocations = {++buffer.values().begin(), buffer.values().end()};
151       candidates.emplace_back(interval);
152     }
153   }
154 
155   // The BufferIntervalCompare function used to sort buffers implements the
156   // greater-than operator so that the most beneficial buffers are allocated
157   // first. The size_compare function below hence uses the greater-than operator
158   // to pick the largest buffer.
159   auto size_compare = [](const auto& x, const auto& y) {
160     if (x.size == y.size) {
161       // When both buffers are of same size, we prefer the one that is used to
162       // produce larger tensors in its consumer instructions.
163       auto get_use_size =
164           [](const MemorySpaceAssignment::BufferInterval& bi) -> int64_t {
165         int64_t use_size = 0;
166         for (const auto& use : bi.buffer->GetUses()) {
167           use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape());
168         }
169         return use_size;
170       };
171       return get_use_size(x) > get_use_size(y);
172     }
173     return x.size > y.size;
174   };
175   auto& compare = options.default_cross_program_prefetch_heuristic &&
176                           options.buffer_interval_compare
177                       ? *options.buffer_interval_compare
178                       : size_compare;
179 
180   auto best_candidate = absl::c_min_element(candidates, compare);
181   if (best_candidate == candidates.end()) {
182     return std::nullopt;
183   }
184   VLOG(3) << "Cross-program prefetch candidate picked: "
185           << best_candidate->buffer->ToString();
186   return *best_candidate;
187 }
188 
189 Status InsertInstructionAndEnsureOperandsInserted(
190     HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
191     absl::flat_hash_set<HloInstruction*>* inserted_instructions);
192 
193 // Inserts an instruction into the schedule and makes sure its dependencies
194 // (operands) are already in the schedule; if not, inserts those operands
195 // before the instruction.
196 Status EnsureInstructionAndOperandsInserted(
197     HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
198     absl::flat_hash_set<HloInstruction*>* inserted_instructions) {
199   if (inserted_instructions->contains(new_instruction)) {
200     return OkStatus();
201   }
202   return InsertInstructionAndEnsureOperandsInserted(
203       new_instruction, new_sequence, inserted_instructions);
204 }
205 
206 // Same as above, but does not check if instruction is already inserted. This is
207 // used when the caller already knows the instruction isn't inserted yet, to
208 // speed up compilation.
209 Status InsertInstructionAndEnsureOperandsInserted(
210     HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
211     absl::flat_hash_set<HloInstruction*>* inserted_instructions) {
212   for (HloInstruction* operand : new_instruction->operands()) {
213     // CopyStart/CopyDone dependencies should always be already inserted; it is
214     // a red flag when they haven't already been inserted.
215     if (operand->opcode() == HloOpcode::kCopyStart ||
216         operand->opcode() == HloOpcode::kCopyDone) {
217       TF_RET_CHECK(inserted_instructions->contains(operand))
218           << "Inserted instruction " << new_instruction->ToString()
219           << " has un-inserted dependency: " << operand->ToString();
220       continue;
221     }
222     TF_RETURN_IF_ERROR(EnsureInstructionAndOperandsInserted(
223         operand, new_sequence, inserted_instructions));
224   }
225   VLOG(4) << "inserting: " << new_instruction->ToShortString();
226   new_sequence->push_back(new_instruction);
227   TF_RET_CHECK(inserted_instructions->insert(new_instruction).second);
228   return OkStatus();
229 }
230 
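// Formats a list of HloUses as a comma-separated string for logging.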
231 std::string UsesToString(const std::vector<HloUse>& uses) {
232   if (uses.empty()) {
233     return "none";
234   }
235   std::vector<std::string> uses_str;
236   uses_str.reserve(uses.size());
237   for (const auto& use : uses) {
238     uses_str.push_back(use.ToString());
239   }
240   return absl::StrJoin(uses_str, ",");
241 }
242 
243 }  // namespace
244 
245 /*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
246 MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis,
247                                           const Options& options,
248                                           const HloModule& module) {
249   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
250   TF_ASSIGN_OR_RETURN(auto hlo_live_range,
251                       HloLiveRange::Run(module.schedule(), *alias_analysis,
252                                         module.entry_computation()));
253   auto call_graph = CallGraph::Build(&module);
254   return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
255       cost_analysis, options, std::move(alias_analysis),
256       std::move(hlo_live_range), std::move(call_graph)));
257 }
258 
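// Returns how much the given instruction would benefit (in seconds of saved
// elapsed time) from the operand/output placement implied by
// elapsed_time_due_to_alternate_mem. The result is scaled by the while-loop
// execution count when the instruction is nested in while loops; a
// non-positive value means the instruction is compute bound.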
259 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
260     const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
261     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
262   float elapsed_time_due_to_compute =
263       GetInstructionElapsedDueToCompute(instruction);
264   float elapsed_time_due_to_memory =
265       GetInstructionElapsedDueToMemory(instruction);
266   if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
267     // Memory bound, return how much alternate memory is better.
268     float while_nest_multiplier;
269     if (cache) {
270       // If there is a cache provided, memoize the while nest multiplier.
271       auto it = cache->while_nest_multiplier.find(&instruction);
272       if (it != cache->while_nest_multiplier.end()) {
273         while_nest_multiplier = it->second;
274       } else {
275         while_nest_multiplier = IPow<float>(
276             options_.xla_tpu_memory_space_assignment_while_execution_count,
277             CalculateComputationNestLevel(&instruction,
278                                           /*while_only=*/true));
279         cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
280       }
281     } else {
282       while_nest_multiplier = IPow<float>(
283           options_.xla_tpu_memory_space_assignment_while_execution_count,
284           CalculateComputationNestLevel(&instruction,
285                                         /*while_only=*/true));
286     }
287     return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
288            while_nest_multiplier;
289   } else {
290     // Compute bound, return how far off are we to memory boundedness.
291     return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
292   }
293 }
294 
295 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
296     const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
297     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
298   float alternate_mem_benefit =
299       GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache);
300 
301   for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
302            interval.buffer->defining_position().instruction,
303            interval.buffer->defining_position().index)) {
304     for (const HloValue* value : buffer->values()) {
305       for (const HloUse& use : value->GetUses()) {
306         // We look inside the called computations of while and conditional, so
307         // don't use the benefit of while and conditional directly.
308         if (use.instruction->opcode() == HloOpcode::kWhile ||
309             use.instruction->opcode() == HloOpcode::kConditional) {
310           continue;
311         }
312         float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache);
313         // If the benefit is positive (memory bound), add it to this buffer's
314         // benefit. If the benefit is negative (compute bound), calculate the
315         // maximum.
316         if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
317           alternate_mem_benefit += use_alternate_mem_benefit;
318         } else {
319           alternate_mem_benefit =
320               std::max(alternate_mem_benefit, use_alternate_mem_benefit);
321         }
322       }
323     }
324   }
325 
326   // Penalize larger buffers by dividing the benefit by the square root of the
327   // size. Empirically, we observed this resulted in better performance compared
328   // to dividing by the size.
329   return alternate_mem_benefit / std::sqrt(interval.size);
330 }
331 
332 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
333     const HloPosition& position,
334     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
335   return GetAlternateMemoryBenefit(
336       *position.instruction,
337       GetInstructionElapsedDueToMemory(
338           *position.instruction,
339           /*operands_in_alternate_mem=*/{},
340           /*outputs_in_alternate_mem=*/{position.index}),
341       cache);
342 }
343 
344 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
345     const HloUse& use, MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
346   return GetAlternateMemoryBenefit(
347       *use.instruction,
348       GetInstructionElapsedDueToMemory(
349           *use.instruction,
350           /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number,
351                                                         use.operand_index)}),
352       cache);
353 }
354 
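// Walks up the call graph from the instruction's computation to the entry
// computation and counts the nesting depth, counting only while-loop callers
// when while_only is true.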
355 int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
356     const HloInstruction* instruction, bool while_only) const {
357   int nest_level = 0;
358   const HloComputation* computation = instruction->parent();
359   while (!computation->IsEntryComputation()) {
360     auto& node = call_graph_->GetNode(computation);
361     auto callsites = node.caller_callsites();
362     CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1)
363         << "The module is not flattened!";
364     auto& callsite = callsites[0];
365     if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
366       ++nest_level;
367     }
368     computation = callsite.instruction()->parent();
369   }
370   return nest_level;
371 }
372 
373 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
374     const HloInstruction& instruction) const {
375   return std::max(
376       cost_analysis_.flop_count(instruction) /
377           cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
378       cost_analysis_.transcendental_count(instruction) /
379           cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
380 }
381 
382 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
383     const HloInstruction& instruction,
384     absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
385     absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
386   float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
387   float bytes_accessed_from_alternate_mem = 0.0;
388   for (auto& operand : operands_in_alternate_mem) {
389     float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
390         instruction, operand.first, operand.second);
391     bytes_accessed_from_alternate_mem += operand_bytes_accessed;
392   }
393 
394   for (auto& shape_idx : outputs_in_alternate_mem) {
395     float output_bytes_accessed =
396         cost_analysis_.output_bytes_accessed(instruction, shape_idx);
397     bytes_accessed_from_alternate_mem += output_bytes_accessed;
398   }
399   float elapsed_due_to_alternate_mem =
400       bytes_accessed_from_alternate_mem /
401       options().alternate_mem_bandwidth_bytes_per_second;
402   float elapsed_due_to_default_mem =
403       (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
404       cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
405   return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
406 }
407 
408 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
409     const HloInstruction& instruction,
410     IsInAlternateMemoryFun is_in_alternate_mem) const {
411   float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
412   float bytes_accessed_from_alternate_mem = 0.0;
413   for (int operand_num = 0; operand_num < instruction.operand_count();
414        ++operand_num) {
415     ShapeUtil::ForEachSubshape(
416         instruction.operand(operand_num)->shape(),
417         [&](const Shape& subshape, const ShapeIndex& index) {
418           if (!subshape.IsArray()) {
419             return;
420           }
421           if (is_in_alternate_mem(operand_num, index, subshape)) {
422             bytes_accessed_from_alternate_mem +=
423                 cost_analysis_.operand_bytes_accessed(instruction, operand_num,
424                                                       index);
425           }
426         });
427   }
428   ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape,
429                                                       const ShapeIndex& index) {
430     if (!subshape.IsArray()) {
431       return;
432     }
433     if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) {
434       bytes_accessed_from_alternate_mem +=
435           cost_analysis_.output_bytes_accessed(instruction, index);
436     }
437   });
438   float elapsed_due_to_alternate_mem =
439       bytes_accessed_from_alternate_mem /
440       options().alternate_mem_bandwidth_bytes_per_second;
441   float elapsed_due_to_default_mem =
442       (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
443       cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
444   return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
445 }
446 
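// The overall elapsed time of an instruction is modeled as the maximum of its
// compute-bound and memory-bound estimates, assuming all operands and outputs
// live in default memory.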
447 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
448     const HloInstruction& instruction) const {
449   return std::max(GetInstructionElapsedDueToCompute(instruction),
450                   GetInstructionElapsedDueToMemory(instruction));
451 }
452 
453 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
454     const HloInstruction& instruction,
455     absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
456     absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
457   return std::max(
458       GetInstructionElapsedDueToCompute(instruction),
459       GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem,
460                                        outputs_in_alternate_mem));
461 }
462 
463 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
464     const HloInstruction& instruction,
465     IsInAlternateMemoryFun is_in_alternate_mem) const {
466   return std::max(
467       GetInstructionElapsedDueToCompute(instruction),
468       GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem));
469 }
470 
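// Models an asynchronous copy as pure bandwidth: the shape's size in bytes
// divided by the configured async copy bandwidth.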
471 float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
472     const Shape& shape) const {
473   int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape);
474   return static_cast<float>(size_in_bytes) /
475          options().async_copy_bandwidth_bytes_per_second;
476 }
477 
478 int64_t MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
479   return hlo_live_range_->schedule_end_time();
480 }
481 
482 bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
483     const Shape& shape, int64_t start_time, int64_t end_time) const {
484   return end_time - start_time <= max_overlap_count_;
485 }
486 
487 int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
488     const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
489   return std::min(start_time + min_overlap_count_, latest_end_time);
490 }
491 
492 int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
493     const Shape& shape, int64_t start_time, int64_t end_time,
494     const HloUse* use) const {
495   return end_time - min_overlap_count_;
496 }
497 
498 int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
499     const Shape& shape, int64_t earliest_prefetch_start_time,
500     int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
501   return std::max(earliest_prefetch_start_time,
502                   prefetch_end_time - max_overlap_count_);
503 }
504 
505 int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime(
506     const Shape& shape, int64_t start_time, int64_t end_time) const {
507   // For testing, assume the end time is the estimated prefetch end time.
508   return end_time;
509 }
510 
511 float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed(
512     int64_t start_time, int64_t end_time) const {
513   // For testing, just assume every HLO takes 1 second.
514   return static_cast<float>(end_time - start_time - 1);
515 }
516 
517 void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
518                                                    int64_t start_time,
519                                                    int64_t end_time) {
520   end_time_ = end_time;
521   const Shape& shape = ShapeUtil::GetSubshape(
522       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
523   current_prefetch_time_ =
524       PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
525 }
526 
527 int64_t InstructionCountPrefetchIntervalPicker::Next() {
528   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
529                     "Done() is true";
530   return current_prefetch_time_++;
531 }
532 
533 bool InstructionCountPrefetchIntervalPicker::Done() const {
534   return end_time_ - current_prefetch_time_ <= min_overlap_count_;
535 }
536 
537 int64_t InstructionCountPrefetchIntervalPicker::latest_time() const {
538   return end_time_ - min_overlap_count_ - 1;
539 }
540 
541 std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
542   return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
543 }
544 
545 std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
546     const Shape& shape, int64_t start_time, int64_t end_time) const {
547   return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
548 }
549 
550 CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
551     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
552     float min_overlap_to_async_copy_ratio,
553     float preferred_overlap_to_async_copy_ratio,
554     float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes)
555     : while_nest_level_(
556           cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
557       computation_nest_level_(
558           cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
559       cost_analysis_(cost_analysis),
560       min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio),
561       preferred_overlap_to_async_copy_ratio_(
562           preferred_overlap_to_async_copy_ratio),
563       max_async_copy_elapsed_(
564           cost_analysis_.GetAsyncCopyElapsed(
565               ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) *
566           max_overlap_to_mem_size_async_copy_ratio) {
567   instruction_schedule_ =
568       &cost_analysis_.hlo_live_range().instruction_schedule();
569 
570   // Create a vector of elapsed times and while nesting levels of HLO
571   // instructions. The elapsed times are multiplied by
572   // pow(while_execution_count, nest_level) to account for executing the HLOs
573   // multiple times in while loops.
574   std::vector<float> instructions_elapsed_time(
575       instruction_schedule_->size() + 1, 0.0);
576   int max_while_nest_level = 0;
577   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
578     // To avoid double counting, don't include the elapsed time of while and
579     // conditional HLOs.
580     const HloInstruction* instruction = instruction_and_logical_time.first;
581     int64_t logical_time = instruction_and_logical_time.second;
582     if (logical_time >= instructions_elapsed_time.size()) {
583       instructions_elapsed_time.resize(logical_time + 1, 0.0);
584       while_nest_level_.resize(logical_time + 1, 0);
585     }
586     int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
587         instruction_and_logical_time.first, /*while_only=*/true);
588     while_nest_level_[logical_time] = while_nest_level;
589     max_while_nest_level = std::max(max_while_nest_level, while_nest_level);
590     int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
591         instruction_and_logical_time.first, /*while_only=*/false);
592     computation_nest_level_[logical_time] = computation_nest_level;
593     if (instruction->opcode() == HloOpcode::kWhile ||
594         instruction->opcode() == HloOpcode::kConditional) {
595       continue;
596     }
597     float elapsed_time = cost_analysis_.GetInstructionElapsed(
598         *instruction_and_logical_time.first);
599     instructions_elapsed_time[logical_time] =
600         elapsed_time *
601         IPow<float>(cost_analysis_.options()
602                         .xla_tpu_memory_space_assignment_while_execution_count,
603                     while_nest_level);
604   }
605   // As an optimization, create a cumulative sum vector of elapsed time.
606   float cumsum = 0.0;
607   elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
608   for (float elapsed_time : instructions_elapsed_time) {
609     cumsum += elapsed_time;
610     elapsed_time_cumsum_.push_back(cumsum);
611   }
612   // To be able to accurately determine the minimum nest level between a start
613   // time and an end time efficiently, populate a data structure that stores the
614   // closest 'smaller' nest level change index.
615   const int64_t size = instructions_elapsed_time.size();
616   CHECK_EQ(size, while_nest_level_.size());
617   std::vector<int> most_recent_by_level(while_nest_level_.size(), -1);
618   int prev_nest_level = 0;
619   int change_idx = -1;
620   while_nest_level_change_.reserve(size);
621   for (int i = 0; i < size; ++i) {
622     int nest_level = while_nest_level_[i];
623     if (nest_level != prev_nest_level) {
624       prev_nest_level = nest_level;
625       // Compute the last change index by choosing the most recent instruction
626       // index with a smaller nesting level. Note that even if there were several
627       // regions with other nest levels before this one, they may all be at the
628       // same or a higher level, in which case we end up with -1 (e.g. at nest
629       // level 0 there is nothing else to check).
630       change_idx = -1;
631       for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) {
632         change_idx = std::max(change_idx, most_recent_by_level[smaller_level]);
633       }
634     }
635     most_recent_by_level[nest_level] = i;
636     while_nest_level_change_.push_back(change_idx);
637   }
638   for (int i = 0; i <= max_while_nest_level; ++i) {
639     while_execution_counts_.push_back(
640         IPow<float>(cost_analysis_.options()
641                         .xla_tpu_memory_space_assignment_while_execution_count,
642                     i));
643   }
644 }
645 
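// Upper bound on how long a prefetched buffer may occupy alternate memory
// before its use; precomputed in the constructor from the memory size and
// max_overlap_to_mem_size_async_copy_ratio.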
646 float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory(
647     float async_copy_elapsed) const {
648   return max_async_copy_elapsed_;
649 }
650 
651 bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
652     const Shape& shape, int64_t start_time, int64_t end_time) const {
653   // Even though this method returns whether we allow the buffer in alternate
654   // memory _without_ asynchronous copies, calculate how long it would have
655   // taken to copy it and compare that to the elapsed time in the logical interval.
656   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
657   float logical_interval_elapsed =
658       GetLogicalIntervalElapsed(start_time, end_time);
659   return GetMaxElapsedInAlternateMemory(async_copy_elapsed) >
660          logical_interval_elapsed;
661 }
662 
663 int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
664     const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
665   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
666   int64_t end_time;
667   for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
668     float logical_interval_elapsed =
669         GetLogicalIntervalElapsed(start_time, end_time);
670     if (logical_interval_elapsed >=
671         (1 + kEvictionRetryMultiplier * retry_number_) *
672             preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) {
673       break;
674     }
675   }
676   return end_time;
677 }
678 
679 int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
680     const Shape& shape, int64_t start_time, int64_t end_time,
681     const HloUse* use) const {
682   // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_.
683   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
684   // If there is a use, estimate the time we would save by having this op in
685   // alternate memory.
686   float inst_elapsed_reduction = 0.0f;
687   if (use) {
688     float elapsed_time =
689         cost_analysis_.GetInstructionElapsed(*use->instruction);
690     float elapsed_time_in_alternate_mem =
691         cost_analysis_.GetInstructionElapsedInAlternateMemory(
692             *use->instruction,
693             /*operands_in_alternate_mem=*/
694             {std::make_pair(use->operand_number, use->operand_index)},
695             /*outputs_in_alternate_mem=*/{});
696     inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
697   }
698   int end_nest_level = computation_nest_level_[end_time];
699 
700   // Find the latest time we're allowed to start prefetching.
701   float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed;
702   int latest_prefetch_time;
703   for (latest_prefetch_time = end_time - 1;
704        latest_prefetch_time >= start_time &&
705        (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
706         min_interval >
707             GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
708                 inst_elapsed_reduction);
709        --latest_prefetch_time) {
710   }
711 
712   return latest_prefetch_time;
713 }
714 
715 int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
716     const Shape& shape, int64_t earliest_prefetch_start_time,
717     int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
718   // Between the earliest and latest prefetch interval, find the interval
719   // closest to the preferred interval and start iterating from there.
720   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
721   int64_t preferred_prefetch_start_time = earliest_prefetch_start_time;
722   float preferred_interval =
723       preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed;
724   float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
725                                                   prefetch_end_time);
726   int end_nest_level = computation_nest_level_[prefetch_end_time];
727   for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1;
728        prefetch_start_time <= latest_prefetch_start_time;
729        ++prefetch_start_time) {
730     float interval =
731         GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
732     if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
733         std::abs(preferred_interval - interval) <
734             std::abs(preferred_interval - best_interval)) {
735       best_interval = interval;
736       preferred_prefetch_start_time = prefetch_start_time;
737     }
738   }
739   return preferred_prefetch_start_time;
740 }
741 
742 int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
743     int64_t original_prefetch_end_time,
744     int64_t proposed_prefetch_end_time) const {
745   // Iterate toward the beginning until we find a suitable end time that is at
746   // the same computation nest level as the original prefetch end time.
747   int64_t original_nest_level =
748       computation_nest_level_[original_prefetch_end_time];
749   int64_t new_prefetch_end_time;
750   for (new_prefetch_end_time = proposed_prefetch_end_time;
751        computation_nest_level_[new_prefetch_end_time] != original_nest_level;
752        --new_prefetch_end_time) {
753   }
754   return new_prefetch_end_time;
755 }
756 
757 int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime(
758     const Shape& shape, int64_t start_time, int64_t end_time) const {
759   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
760   int64_t estimated_end_time;
761   for (estimated_end_time = start_time + 1; estimated_end_time < end_time;
762        ++estimated_end_time) {
763     float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time);
764     if (interval >= async_copy_elapsed) {
765       break;
766     }
767   }
768   return estimated_end_time;
769 }
770 
771 void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
772                                                int64_t start_time,
773                                                int64_t end_time) {
774   const Shape& shape = ShapeUtil::GetSubshape(
775       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
776   // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_.
777   async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
778   // Estimate the time we would save by having this op in alternate memory.
779   float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
780   float elapsed_time_in_alternate_mem =
781       cost_analysis_.GetInstructionElapsedInAlternateMemory(
782           *use.instruction, /*operands_in_alternate_mem=*/
783           {std::make_pair(use.operand_number, use.operand_index)},
784           /*outputs_in_alternate_mem=*/{});
785   inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
786   end_logical_time_ = end_time;
787   int end_nest_level = computation_nest_level_[end_logical_time_];
788 
789   // Find the latest time we're allowed to start prefetching.
790   float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_;
791   latest_prefetch_time_ =
792       LatestPrefetchStartTime(shape, start_time, end_time, &use);
793 
794   // Find the earliest time we're allowed to start prefetching.
795   float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_);
796   for (earliest_prefetch_time_ = start_time;
797        earliest_prefetch_time_ < latest_prefetch_time_ &&
798        (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
799         max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
800                                                  end_logical_time_));
801        ++earliest_prefetch_time_) {
802   }
803   if (earliest_prefetch_time_ > latest_prefetch_time_) {
804     // There is no available prefetch interval for the given start and end
805     // times. Set the iterators accordingly to ensure Done() returns true.
806     increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
807     decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
808     CHECK(Done());
809     return;
810   }
811 
812   int64_t starting_prefetch_time = PreferredPrefetchStartTime(
813       shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
814   float preferred_interval =
815       preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_;
816   VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
817           << max_interval << " " << preferred_interval
818           << " prefetch time earliest/latest/starting = "
819           << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
820           << starting_prefetch_time;
821 
822   increasing_prefetch_time_iterator_ = starting_prefetch_time;
823   decreasing_prefetch_time_iterator_ = starting_prefetch_time;
824   using_increasing_prefetch_time_iterator_ = true;
825   // Since both iterators start at the same position, call Next() once to
826   // advance one of the iterators.
827   Next();
828 }
829 
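// Returns the next candidate prefetch start time, alternating between the
// increasing and decreasing iterators around the preferred start time and
// skipping times whose computation nest level differs from the end time's.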
830 int64_t CostAnalysisPrefetchIntervalPicker::Next() {
831   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
832                     "Done() is true";
833   if (using_increasing_prefetch_time_iterator_) {
834     int64_t prefetch_time = increasing_prefetch_time_iterator_++;
835     while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
836            computation_nest_level_[increasing_prefetch_time_iterator_] !=
837                computation_nest_level_[end_logical_time_]) {
838       ++increasing_prefetch_time_iterator_;
839     }
840     if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
841       using_increasing_prefetch_time_iterator_ = false;
842     }
843     return prefetch_time;
844   } else {
845     int64_t prefetch_time = decreasing_prefetch_time_iterator_--;
846     while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
847            computation_nest_level_[decreasing_prefetch_time_iterator_] !=
848                computation_nest_level_[end_logical_time_]) {
849       --decreasing_prefetch_time_iterator_;
850     }
851     if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
852       using_increasing_prefetch_time_iterator_ = true;
853     }
854     return prefetch_time;
855   }
856 }
857 
858 bool CostAnalysisPrefetchIntervalPicker::Done() const {
859   return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
860          decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
861 }
862 
863 int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const {
864   return latest_prefetch_time_;
865 }
866 
867 void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
868   retry_number_ = retry_number;
869 }
870 
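// Returns the minimum while-loop nest level seen between start_time and
// end_time, following the precomputed while_nest_level_change_ links instead
// of scanning every instruction in the interval.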
871 int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
872     int64_t start_time, int64_t end_time) const {
873   int min_nest_level =
874       std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
875   int change_idx = while_nest_level_change_[end_time];
876   while (change_idx >= start_time) {
877     min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
878     change_idx = while_nest_level_change_[change_idx];
879   }
880   return min_nest_level;
881 }
882 
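// Returns the estimated elapsed time between two logical times using the
// precomputed cumulative sums, normalized by the while-loop execution count
// of the interval's minimum nest level.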
883 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
884     int64_t start_time, int64_t end_time) const {
885   CHECK_LE(start_time, end_time);
886   if (start_time == end_time) {
887     return 0.0;
888   }
889   if (start_time < 0) {
890     start_time = 0;
891   }
892   // Since elapsed_time_cumsum_ is already weighted by the while loop nesting
893   // level, normalize the elapsed time by dividing by the nesting factor of
894   // the interval (start and end times).
895   int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
896   return (elapsed_time_cumsum_[end_time - 1] -
897           elapsed_time_cumsum_[start_time]) /
898          while_execution_counts_[interval_while_nest_level];
899 }
900 
901 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
902   int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
903                                           ? increasing_prefetch_time_iterator_
904                                           : decreasing_prefetch_time_iterator_;
905   float logical_interval_elapsed = GetLogicalIntervalElapsed(
906       current_logical_prefetch_time, end_logical_time_);
907   return absl::StrCat(
908       "Async copy elapsed (s) = ", async_copy_elapsed_,
909       ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
910       ", logical interval elapsed (s) = ", logical_interval_elapsed,
911       ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
912       ")");
913 }
914 
915 std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
916     const Shape& shape, int64_t start_time, int64_t end_time) const {
917   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
918   float logical_interval_elapsed =
919       GetLogicalIntervalElapsed(start_time, end_time);
920   return absl::StrCat(
921       "Async copy elapsed (s) = ", async_copy_elapsed,
922       ", logical interval elapsed (s) = ", logical_interval_elapsed);
923 }
924 
925 std::optional<float>
926 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
927     const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
928     const {
929   return cost_analysis_.GetMemoryBoundedness(interval);
930 }
931 
932 bool MemorySpaceAssignment::Allocation::operator==(
933     const MemorySpaceAssignment::Allocation& other) const {
934   return defining_position() == other.defining_position() &&
935          uses() == other.uses() && memory_space() == other.memory_space() &&
936          chunk() == other.chunk() && start_time() == other.start_time() &&
937          end_time() == other.end_time() &&
938          earliest_available_time() == other.earliest_available_time() &&
939          is_copy_allocation() == other.is_copy_allocation() &&
940          is_scoped_allocation() == other.is_scoped_allocation();
941 }
942 
943 bool MemorySpaceAssignment::CopyAllocation::operator==(
944     const MemorySpaceAssignment::CopyAllocation& other) const {
945   return static_cast<const Allocation&>(*this) ==
946              static_cast<const Allocation&>(other) &&
947          copy_done_schedule_before() == other.copy_done_schedule_before() &&
948          copy_start_schedule_after() == other.copy_start_schedule_after() &&
949          copy_start() == other.copy_start() && copy_done() == other.copy_done();
950 }
951 
952 std::string MemorySpaceAssignment::AllocationValue::ToString() const {
953   std::string out = absl::StrCat("computation = ", computation()->name());
954   absl::StrAppend(&out,
955                   (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
956   absl::StrAppend(&out, "\n position:\n");
957   absl::StrAppend(&out, "  ", defining_position_.ToString(), "\n");
958   absl::StrAppend(&out, " uses:\n");
959   for (const Use& use : uses_) {
960     absl::StrAppend(&out, "  ", use.hlo_use.ToString(), "\n");
961   }
962   return out;
963 }
964 
965 std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
966   return absl::StrCat("computation = ", computation()->name(),
967                       ", position = ", defining_position_.ToString(),
968                       ", value = ", value_->ToShortString(),
969                       (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
970 }
971 
972 AlternateMemoryBestFitHeap::AlternateMemoryBestFitHeap(
973     MemorySpaceAssignment::AllocationSequence* allocations,
974     const Options& options, const HloAliasAnalysis& alias_analysis,
975     const HloLiveRange& hlo_live_range)
976     : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
977       allocations_(allocations),
978       options_(options),
979       alias_analysis_(alias_analysis),
980       hlo_live_range_(hlo_live_range),
981       peak_memory_usage_(hlo_live_range.schedule_end_time() + 1) {
982   // Override buffer interval compare if provided.
983   if (options.buffer_interval_compare) {
984     buffer_interval_compare_ = *options.buffer_interval_compare;
985   }
986 
987   std::vector<float> initial_resources(hlo_live_range.schedule_end_time(), 1.0);
988   if (options.cost_analysis) {
989     const std::vector<HloInstruction*>& flattened_instructions =
990         hlo_live_range.flattened_instruction_sequence().instructions();
991     for (int i = 0; i < flattened_instructions.size(); ++i) {
992       const HloInstruction* inst = flattened_instructions[i];
993       if (inst->opcode() == HloOpcode::kWhile ||
994           inst->opcode() == HloOpcode::kConditional) {
995         initial_resources[i] = 0;
996       } else {
997         initial_resources[i] =
998             options.cost_analysis->GetInstructionElapsed(*inst);
999       }
1000       VLOG(2) << "Initial resource[" << i << "] = " << initial_resources[i]
1001               << " (" << inst->name() << ")";
1002     }
1003   }
1004   prefetch_async_copy_resource_ = AsynchronousCopyResource(initial_resources);
1005   eviction_async_copy_resource_ = AsynchronousCopyResource(initial_resources);
1006 }
1007 
1008 void AlternateMemoryBestFitHeap::CreateAllocationValues(
1009     const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
1010     std::vector<AllocationValue>& allocation_values) const {
1011   const HloValue* value = buffer_interval.buffer;
1012   VLOG(3) << "Creating AllocationValues for: " << value->ToString();
1013 
1014   // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
1015   // positions. We create an AllocationValue object for each non-trivial
1016   // position, and for each AllocationValue object we create an
1017   // AllocationSequence consisting of one or more Allocation objects. The reason
1018   // we exclude the trivial positions from AllocationValue is that Allocation
1019   // objects have special support for tuples and bitcasts.
1020   const absl::flat_hash_map<const HloInstruction*, int64_t>&
1021       instruction_schedule = hlo_live_range_.instruction_schedule();
1022   std::vector<HloPosition> positions;
1023   for (const HloPosition& position : value->positions()) {
1024     const HloInstruction* instruction = position.instruction;
1025     if (instruction->opcode() != HloOpcode::kGetTupleElement &&
1026         instruction->opcode() != HloOpcode::kTuple &&
1027         instruction->opcode() != HloOpcode::kBitcast) {
1028       positions.push_back(position);
1029     }
1030   }
1031   absl::c_stable_sort(positions,
1032                       [&](const HloPosition& pos1, const HloPosition& pos2) {
1033                         return instruction_schedule.at(pos1.instruction) <
1034                                instruction_schedule.at(pos2.instruction);
1035                       });
1036 
1037   // Create an AllocationValue for each non-trivial position.
1038   absl::flat_hash_set<const HloComputation*> computations;
1039   int beginning_idx = allocation_values.size();
1040   for (int i = 0; i < positions.size(); ++i) {
1041     const HloPosition& position = positions.at(i);
1042     allocation_values.emplace_back(value, position, buffer_interval.size);
1043   }
1044 
1045   std::vector<HloUse> uses(value->GetUses().begin(), value->GetUses().end());
1046   absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
1047     return instruction_schedule.at(use1.instruction) <
1048            instruction_schedule.at(use2.instruction);
1049   });
1050 
1051   // Associate each use with an AllocationValue. Each AllocationValue contains a
1052   // position and uses in the same computation. Furthermore, if the original
1053   // HloValue had multiple non-trivial positions in the same computation, those
1054   // will get their own AllocationValue as well. We split these HloValues so
1055   // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
1056   // point to the latest position. The use's operand is then replaced with the
1057   // CopyDone created for that latest position.
1058   for (const HloUse& use : uses) {
1059     int64_t use_time = instruction_schedule.at(use.instruction);
1060     HloComputation* use_computation = use.instruction->parent();
1061 
1062     AllocationValue* last_allocation_value = nullptr;
1063     for (int i = beginning_idx; i < allocation_values.size(); ++i) {
1064       AllocationValue* allocation_value = &allocation_values.at(i);
1065       if (HloDataflowAnalysis::IsAsynchronousOperationDone(
1066               use.instruction->opcode())) {
1067         if (allocation_value->defining_instruction() ==
1068             use.instruction->operand(0)) {
1069           last_allocation_value = allocation_value;
1070         }
1071       } else if (!HloDataflowAnalysis::IsAsynchronousOperationStart(
1072                      allocation_value->defining_instruction()->opcode()) &&
1073                  allocation_value->computation() == use_computation &&
1074                  instruction_schedule.at(
1075                      allocation_value->defining_position().instruction) <
1076                      use_time) {
1077         last_allocation_value = allocation_value;
1078       }
1079     }
1080     CHECK(last_allocation_value != nullptr);
1081     last_allocation_value->AddUse(use, use_time);
1082   }
1083 
1084   for (int i = beginning_idx; i < allocation_values.size(); ++i) {
1085     AllocationValue& allocation_value = allocation_values.at(i);
1086     if (HloDataflowAnalysis::IsAsynchronousOperationStart(
1087             allocation_value.defining_instruction()->opcode())) {
1088       CHECK_EQ(allocation_value.uses().size(), 1);
1089       CHECK(HloDataflowAnalysis::IsAsynchronousOperationDone(
1090           allocation_value.uses().at(0).hlo_use.instruction->opcode()));
1091       VLOG(3) << "Mark " << allocation_value.ToShortString()
1092               << " to require contiguous allocation.";
1093       allocation_value.set_requires_contiguous_allocation(true);
1094     }
1095     VLOG(3) << "Created allocation value: "
1096             << allocation_values.at(i).ToString();
1097   }
1098 }
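
// Illustrative sketch (hypothetical HLO, not taken from a real module): for a
// value defined by a fusion inside a while body and forwarded through a tuple,
// the positions {fusion, {}} and {tuple, {1}} would be found above, the
// trivial tuple position filtered out, and a single AllocationValue rooted at
// the fusion created; its uses are then attached to it in schedule order.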
1099 
1100 void AlternateMemoryBestFitHeap::FindAliases(
1101     std::vector<AllocationValue>* allocation_values) const {
1102   absl::flat_hash_map<const HloInstruction*,
1103                       std::vector<const AllocationValue*>>
1104       values_by_defining_inst;
1105   for (AllocationValue& value : *allocation_values) {
1106     values_by_defining_inst[value.defining_instruction()].push_back(&value);
1107   }
1108   auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
1109                                               AllocationValue::Use* use) {
1110     auto aliased_values_it = values_by_defining_inst.find(instruction);
1111     if (aliased_values_it != values_by_defining_inst.end()) {
1112       for (const AllocationValue* aliased_value : aliased_values_it->second) {
1113         VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString()
1114                 << " to " << aliased_value->ToShortString();
1115         use->aliases.push_back(aliased_value->defining_position());
1116       }
1117     }
1118   };
1119 
1120   for (AllocationValue& value : *allocation_values) {
1121     for (AllocationValue::Use& use : value.uses()) {
1122       // Find any aliases with the instruction itself (operand and output must
1123       // alias).
1124       maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
1125 
1126       // Find any aliases with the parameters of called computations.
1127       for (const HloComputation* called_computation :
1128            use.hlo_use.instruction->called_computations()) {
1129         for (const HloInstruction* parameter_instruction :
1130              called_computation->parameter_instructions()) {
1131           maybe_add_alias_with_instruction(parameter_instruction, &use);
1132         }
1133       }
1134 
1135       // Special case for kWhile: the root of the body computation must alias as
1136       // well.
1137       if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1138         HloPosition root_alias{
1139             use.hlo_use.instruction->while_body()->root_instruction(),
1140             use.hlo_use.operand_index};
1141         VLOG(3) << "Adding while body root aliasing for use "
1142                 << use.hlo_use.ToString() << " to " << root_alias;
1143         use.aliases.push_back(root_alias);
1144       }
1145     }
1146   }
1147 }
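
// Illustrative example (hypothetical HLO): for a use that is operand 0 of
// `w = while(v)`, the logic above aliases the use with the value defined by
// the while instruction itself, with the parameters of the called
// body/condition computations (when those parameters define AllocationValues
// in this group), and with the while body root at the same operand index, so
// they all receive a consistent allocation decision.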
1148 
1149 std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
1150 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
1151     const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
1152   std::vector<const BufferInterval*> colocated_intervals;
1153   std::vector<const BufferInterval*> worklist = {&interval};
1154   while (!worklist.empty()) {
1155     const BufferInterval* item = worklist.back();
1156     worklist.pop_back();
1157     colocated_intervals.push_back(item);
1158     for (const HloValue* buffer_colocated : item->colocations) {
1159       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
1160     }
1161   }
1162 
1163   absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
1164                                                const BufferInterval* y) {
1165     return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
1166   });
1167   return colocated_intervals;
1168 }
1169 
1170 bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
1171     const AllocationValue& value, const HloUse& use) const {
1172   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1173   if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
1174     return false;
1175   }
1176   if (use.instruction->opcode() == HloOpcode::kWhile) {
1177     HloComputation* while_body = use.instruction->while_body();
1178 
1179     // We don't want to allocate this buffer in alternate memory if it will be
1180     // evicted anyway. Find out if it has an early use or a late definition that
1181     // would make it worthwhile to keep in the alternate memory.
1182     HloValue* parameter_value =
1183         &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1184             while_body->parameter_instruction(0), use.operand_index);
1185     int64_t parameter_time =
1186         instruction_schedule.at(while_body->parameter_instruction(0));
1187     int64_t root_time = instruction_schedule.at(while_body->root_instruction());
1188     int64_t min_use_time = root_time;
1189     for (const HloUse& parameter_use : parameter_value->GetUses()) {
1190       int64_t use_time = instruction_schedule.at(parameter_use.instruction);
1191       if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
1192           parameter_use.instruction->opcode() != HloOpcode::kTuple &&
1193           parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
1194           use_time > parameter_time) {
1195         min_use_time = std::min(min_use_time, use_time);
1196       }
1197     }
1198     // If there is no use of this buffer inside the while loop, there is no need
1199     // to allocate it in the loop.
1200     if (min_use_time == root_time) {
1201       VLOG(4) << "While allocation not allowed in alternate memory. "
1202               << "use time = " << min_use_time << ", root time = " << root_time;
1203       return false;
1204     }
1205     const Shape& shape = parameter_value->shape();
1206     // Allow the buffer in alternate memory if the buffer has a short live range
1207     // either at the beginning or end of the while loop body.
1208     if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
1209             shape, parameter_time, min_use_time)) {
1210       VLOG(4) << "While allocation not allowed in alternate memory. "
1211               << "use time = " << min_use_time << ", root time = " << root_time;
1212       return false;
1213     }
1214     // Check if there is a required assignment for the while loop output.
1215     HloValue* while_value =
1216         &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1217             use.instruction, use.operand_index);
1218     int64_t while_time = instruction_schedule.at(use.instruction);
1219     auto existing_required_assignment =
1220         RequiredMemoryAssignmentAt(while_value, while_time);
1221     if (existing_required_assignment &&
1222         existing_required_assignment->memory_space == MemorySpace::kDefault) {
1223       VLOG(4) << "While allocation not allowed in alternate memory because "
1224                  "there is a required default memory assignment.";
1225       return false;
1226     }
1227   } else if (use.instruction->opcode() == HloOpcode::kConditional) {
1228     // For any use of this conditional (the same value might be passed into
1229     // multiple called computations), determine if the parameter->first use
1230     // dependency is short.
1231     int64_t conditional_time = instruction_schedule.at(use.instruction);
1232     for (const AllocationValue::Use& other_use : value.uses()) {
1233       if (other_use.hlo_use.instruction != use.instruction) {
1234         continue;
1235       }
1236       HloComputation* called_computation =
1237           use.instruction->called_computations().at(
1238               other_use.hlo_use.operand_number - 1);
1239       const HloInstruction* parameter_instruction =
1240           called_computation->parameter_instruction(0);
1241       HloValue* parameter_value =
1242           &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1243               parameter_instruction, other_use.hlo_use.operand_index);
1244       int64_t parameter_time = instruction_schedule.at(parameter_instruction);
1245       int64_t min_use_time = conditional_time;
1246       for (const HloUse& parameter_use : parameter_value->GetUses()) {
1247         if (parameter_use.instruction->parent() == called_computation &&
1248             parameter_use.instruction->opcode() !=
1249                 HloOpcode::kGetTupleElement &&
1250             parameter_use.instruction->opcode() != HloOpcode::kTuple &&
1251             parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
1252           min_use_time = std::min(
1253               min_use_time, instruction_schedule.at(parameter_use.instruction));
1254         }
1255       }
1256       if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
1257               parameter_value->shape(), parameter_time, min_use_time)) {
1258         VLOG(4) << "Conditional allocation allowed in alternate memory for "
1259                    "computation = "
1260                 << called_computation->name()
1261                 << ", parameter time = " << parameter_time
1262                 << ", min use time = " << min_use_time;
1263         return true;
1264       } else {
1265         VLOG(4) << "Conditional allocation not allowed in alternate memory for "
1266                    "computation = "
1267                 << called_computation->name()
1268                 << ", parameter time = " << parameter_time
1269                 << ", min use time = " << min_use_time;
1270       }
1271     }
1272     return false;
1273   }
1274 
1275   return true;
1276 }
1277 
1278 namespace {
1279 // Columns in buffer information:
1280 // buffer_id: int. This value can be used to match the allocation in
1281 // allocation information.
1282 // buffer_name: string.
1283 // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
1284 // thought it would be beneficial to put this in the alternate memory. The
1285 // higher the value, the more it is memory bound.
1286 // higher the value, the more memory bound it is.
1287 // definition_time: int. Logical time this value was defined in the schedule.
1288 // use_times: string. This is a semicolon-separated list of integers for all
1289 // the use times.
1290 // use_names: string. This is a semicolon-separated list of string
1291 // representation of uses.
1292 // is_scoped: int. A value of 1 indicates that the buffer is a scoped
1293 // allocation.
1294 constexpr absl::string_view kBufferInfoColumnNames =
1295     "buffer_id,buffer_name,alt_mem_benefit,size,definition_time,use_times,use_"
1296     "names,is_scoped";
1297 }  // namespace
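
// For illustration only (hypothetical values), a single buffer information row
// appended by AppendBufferInfoDebugString might look like:
//   42,"param0 {1}",1.5,8192,3,"10;25;31","use0;use1;use2",0
// following the column order in kBufferInfoColumnNames above.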
1298 
1299 void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
1300     const AlternateMemoryBestFitHeap::BufferInterval& interval,
1301     std::string* debug_str) const {
1302   if (debug_str->empty()) {
1303     // Append the column names.
1304     absl::StrAppend(debug_str, kBufferInfoColumnNames, "\n");
1305   }
1306   const HloBuffer& buffer =
1307       alias_analysis_.GetBufferContainingValue(*interval.buffer);
1308   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1309   int64_t definition_time =
1310       instruction_schedule.at(interval.buffer->defining_position().instruction);
1311   std::vector<std::pair<int64_t, std::string>> uses;
1312   for (const HloValue* value : buffer.values()) {
1313     for (const HloUse& use : value->GetUses()) {
1314       uses.push_back(
1315           {instruction_schedule.at(use.instruction), use.ToString()});
1316     }
1317   }
1318   absl::c_sort(uses);
1319   std::vector<int64_t> use_times;
1320   std::vector<std::string> use_names;
1321   use_times.reserve(uses.size());
1322   use_names.reserve(uses.size());
1323   for (const auto& use : uses) {
1324     use_times.push_back(use.first);
1325     use_names.push_back(use.second);
1326   }
1327 
1328   absl::StrAppend(debug_str, buffer.id(), ",");
1329   absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
1330   auto alternate_memory_benefit =
1331       options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
1332           interval);
1333   absl::StrAppend(
1334       debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
1335   absl::StrAppend(debug_str, interval.size, ",");
1336   absl::StrAppend(debug_str, definition_time, ",");
1337   absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
1338   absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\",");
1339   absl::StrAppend(debug_str, "0");  // is_scoped
1340   absl::StrAppend(debug_str, "\n");
1341 }
1342 
1343 void AlternateMemoryBestFitHeap::AppendScopedAllocationBufferInfoDebugString(
1344     const HloInstruction* instruction, int64_t time, int64_t size,
1345     std::string& debug_str) const {
1346   if (debug_str.empty()) {
1347     // Append the column names.
1348     absl::StrAppend(&debug_str, kBufferInfoColumnNames, "\n");
1349   }
1350   const HloBuffer& buffer = alias_analysis_.GetUniqueBufferAt(instruction);
1351 
1352   // As a convention, we use negative values for scoped allocations.
1353   absl::StrAppend(&debug_str, -buffer.id(), ",");
1354   absl::StrAppend(&debug_str, "\"scoped allocation for ", instruction->name(),
1355                   "\",");
1356   absl::StrAppend(&debug_str, 0, ",");  // alt_mem_benefit
1357   absl::StrAppend(&debug_str, size, ",");
1358   absl::StrAppend(&debug_str, time, ",");
1359   absl::StrAppend(&debug_str, "\"\",");  // use_times
1360   absl::StrAppend(&debug_str, "\"\",");  // use_names
1361   absl::StrAppend(&debug_str, "1");      // is_scoped
1362   absl::StrAppend(&debug_str, "\n");
1363 }
1364 
1365 void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
1366     const MemorySpaceAssignment::Allocation& allocation,
1367     std::string& debug_str) const {
1368   // Columns in allocation information:
1369   // buffer_id: int. This value can be used to match with the buffer info.
1370   // size: int. In bytes.
1371   // offset: int. In bytes.
1372   // start_time: int. Logical start time of the allocation.
1373   // end_time: int. Logical end time of the allocation.
1374   if (debug_str.empty()) {
1375     // Append the column names.
1376     absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
1377   }
1378   if (allocation.memory_space() == MemorySpace::kAlternate) {
1379     const HloPosition& position = allocation.defining_position();
1380     const HloBuffer& buffer =
1381         alias_analysis_.GetUniqueBufferAt(position.instruction, position.index);
1382     // As a convention, we use negative values for scoped allocations.
1383     absl::StrAppend(
1384         &debug_str,
1385         allocation.is_scoped_allocation() ? -buffer.id() : buffer.id(), ",");
1386     absl::StrAppend(&debug_str, allocation.chunk().size, ",");
1387     absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
1388     absl::StrAppend(&debug_str, allocation.start_time(), ",");
1389     absl::StrAppend(&debug_str, allocation.end_time(), "\n");
1390   }
1391 }
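
// For illustration only (hypothetical values), an allocation information row
// appended above might look like:
//   42,8192,0,3,31
// i.e. buffer_id, size, offset, start_time and end_time for an allocation in
// the alternate memory space; scoped allocations would show a negated
// buffer_id per the convention noted above.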
1392 
1393 void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
1394   if (!options_.dump_fn) {
1395     return;
1396   }
1397   options_.dump_fn("bufferinfo", buffer_info_str_);
1398   options_.dump_fn("allocinfo", allocation_info_str_);
1399 }
1400 
1401 HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
1402   if (options_.autotuning_config.has_value()) {
1403     CHECK_EQ((*options_.autotuning_config).size(), buffer_intervals_.size());
1404   }
1405 
1406   AllocateReservedScopedAllocations();
1407   std::vector<BufferInterval> sorted_buffer_intervals =
1408       GetSortedBufferIntervals();
1409   memory_space_assignment::CustomizeSortedBufferInterval(
1410       options_.autotuning_config, sorted_buffer_intervals);
1411 
1412   // Calculate the memory pressure for the buffers that can be assigned in the
1413   // alternate memory.
1414   memory_pressure_ = 0;
1415   for (auto& interval : sorted_buffer_intervals) {
1416     if (!interval.need_allocation ||
1417         !MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1418             interval) ||
1419         interval.size > available_heap_size()) {
1420       continue;
1421     }
1422     memory_pressure_ += interval.size;
1423   }
1424   VLOG(1) << "Memory pressure = " << memory_pressure_;
1425 
1426   if (options_.enable_cross_program_prefetch) {
1427     std::optional<AlternateMemoryBestFitHeap::BufferInterval>
1428         prefetch_candidate = FindCrossProgramPrefetchCandidate(
1429             alias_analysis_, hlo_live_range_, options_);
1430     if (prefetch_candidate) {
1431       HloModule* module =
1432           prefetch_candidate->buffer->instruction()->GetModule();
1433       AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
1434     }
1435   }
1436 
1437   VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
1438           << options_.max_size_in_bytes;
1439 
1440   AddInputAndOutputRequiredAssignments();
1441 
1442   if (VLOG_IS_ON(3)) {
1443     VLOG(3) << "Flattened instruction sequence:";
1444     const auto& instruction_sequence =
1445         hlo_live_range_.flattened_instruction_sequence().instructions();
1446     for (int i = 0; i < instruction_sequence.size(); ++i) {
1447       VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
1448               << " " << instruction_sequence[i]->name();
1449     }
1450   }
1451 
1452   for (const auto& interval : sorted_buffer_intervals) {
1453     auto colocated_intervals = GetSortedColocatedIntervals(interval);
1454     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1455       // Increment the reserved part of alternate memory so that it is not
1456       // available for other buffers.
1457       reserved_in_bytes_ += options_.size_fn(*interval.buffer);
1458     }
1459   }
1460   VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
1461 
1462   for (auto& interval : sorted_buffer_intervals) {
1463     if (!interval.need_allocation) {
1464       continue;
1465     }
1466 
1467     if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1468             interval)) {
1469       continue;
1470     }
1471 
1472     HloInstruction* inst = interval.buffer->instruction();
1473     HloModule* module = inst->GetModule();
1474 
1475     // Don't intra-program prefetch a cross-program prefetch.
1476     if (inst->opcode() == HloOpcode::kParameter &&
1477         absl::c_count(module->CrossProgramPrefetches(),
1478                       std::make_pair(inst->parameter_number(),
1479                                      interval.buffer->index())) > 0) {
1480       VLOG(3) << "Skip " << interval.buffer->ToShortString()
1481               << " because it is cross-program prefetched.";
1482       continue;
1483     }
1484 
1485     if (interval.size > available_heap_size()) {
1486       VLOG(3) << "Skip " << interval.buffer->ToShortString()
1487               << " because the buffer is larger than the heap size.";
1488       continue;
1489     }
1490 
1491     auto colocated_intervals = GetSortedColocatedIntervals(interval);
1492 
1493     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1494       VLOG(3) << "Interval " << interval.buffer->ToShortString()
1495               << " is reserved in the alternate memory.";
1496       for (const BufferInterval* colocated_interval : colocated_intervals) {
1497         const HloValue* value = colocated_interval->buffer;
1498         // Color all of the aliased reserved buffers here because reserved
1499         // alternate memory allocations will not have an entry in preset
1500         // allocations that is normally used for coloring.
1501         for (auto& position : value->positions()) {
1502           VLOG(4) << "Coloring " << position.ToString();
1503           Shape* shape = ShapeUtil::GetMutableSubshape(
1504               position.instruction->mutable_shape(), position.index);
1505           CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
1506                                   << position.ToString();
1507           shape->mutable_layout()->set_memory_space(
1508               options_.alternate_memory_space);
1509         }
1510       }
1511       continue;
1512     }
1513 
1514     if (colocated_intervals.size() > 1 &&
1515         !options_.allocate_across_sequential_calls) {
1516       VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
1517               << " because it aliases with another interval and "
1518               << "allocate_across_sequential_calls is false.";
1519       continue;
1520     }
1521 
1522     if (!ConsumeFuel("memory_space_assignment", [&] {
1523           return absl::StrCat("Ran out of fuel at buffer: ",
1524                               colocated_intervals[0]->buffer->ToShortString());
1525         })) {
1526       continue;
1527     }
1528 
1529     if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
1530       // Only fill buffer_info_str_ if needed.
1531       AppendBufferInfoDebugString(interval, &buffer_info_str_);
1532     }
1533 
1534     std::vector<AllocationValue> allocation_values;
1535     CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
1536                                                  allocation_values);
1537 
1538     // Retry allocating this value with larger limits if allocation fails.
1539     bool repacked = false;
1540     for (int retry_number = 0; retry_number < options_.max_retries;
1541          retry_number++) {
1542       AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
1543       options_.prefetch_interval_picker->SetRetryNumber(retry_number);
1544       Result result =
1545           AllocateAllocationValues(absl::MakeSpan(allocation_values));
1546       VLOG(2) << "Allocation result = "
1547               << absl::StrFormat("%x", static_cast<int>(result));
1548       if (result_requires_uncommit(result)) {
1549         UncommitPendingChunks(absl::MakeSpan(allocation_values));
1550         VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1551       } else if ((result_is(result, Result::kFailOutOfMemory) ||
1552                   options_.repack_after_every_allocation) &&
1553                  num_repacks_ < options_.max_repacks && !repacked) {
1554         UncommitPendingChunks(absl::MakeSpan(allocation_values));
1555         ++num_repacks_;
1556         repacked = true;
1557         CHECK_NE(options_.repacker, nullptr);
1558         std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
1559             repack_allocation_blocks;
1560         ExportAllocationsForRepacking(repack_allocation_blocks);
1561         VLOG(2) << "Repacking.";
1562         auto repack_status =
1563             options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
1564         CHECK_EQ(repack_status.status(), OkStatus());
1565         VLOG(2) << "Repack complete. Modified = " << *repack_status;
1566         if (*repack_status) {
1567           ImportRepackedAllocations();
1568           --retry_number;
1569         }
1570       } else {
1571         FinalizeAllocations(absl::MakeSpan(allocation_values));
1572         break;
1573       }
1574     }
1575   }
1576 
1577   if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
1578     for (auto& allocation : *allocations_) {
1579       // Only fill allocation_info_str_ if needed.
1580       AppendAllocationInfoDebugString(*allocation, allocation_info_str_);
1581     }
1582   }
1583 
1584   VLOG(3) << "Debug buffer info: ";
1585   XLA_VLOG_LINES(3, buffer_info_str_);
1586   VLOG(3) << "Debug allocation info: ";
1587   XLA_VLOG_LINES(3, allocation_info_str_);
1588   DumpDebugStringsIfEnabled();
1589 
1590   HeapSimulator::Result<HloValue> result;
1591   result.heap_size = result_.heap_size;
1592   result.heap_results.emplace_back(std::move(result_));
1593   return result;
1594 }
1595 
1596 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1597     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1598         colocated_intervals) {
1599   // TODO(berkin): For now, place the phi values due to conditionals in
1600   // default memory.
1601   for (const BufferInterval* colocated_interval : colocated_intervals) {
1602     const HloValue* value = colocated_interval->buffer;
1603     for (const auto& position : value->positions()) {
1604       if (position.instruction->opcode() == HloOpcode::kConditional) {
1605         VLOG(3) << "Adding required assignment for condition output: "
1606                 << value->ToShortString();
1607         AddRequiredAssignment(position.instruction, position.index,
1608                               MemorySpace::kDefault);
1609         for (const HloComputation* called_computation :
1610              position.instruction->called_computations()) {
1611           AddRequiredAssignment(called_computation->root_instruction(),
1612                                 position.index, MemorySpace::kDefault);
1613         }
1614       }
1615     }
1616   }
1617 }
1618 
1619 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1620     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1621         colocated_intervals,
1622     std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1623   // Create AllocationValues for all the colocated intervals.
1624   for (const auto& colocated_interval : colocated_intervals) {
1625     CreateAllocationValues(*colocated_interval, allocation_values);
1626   }
1627   // Go through the AllocationValues and delete the ones that have identical
1628   // defining instructions and use instructions. This is useful for async
1629   // operations that can read and write to the same buffer, e.g., in-place
1630   // asynchronous collective permute. The AllocationValues that correspond to
1631   // collective-permute-start{0} (the input) and collective-permute-start{1}
1632   // (the output) refer to the same buffer by definition (since they are created
1633   // from colocated intervals). If we don't delete one of these buffers, then
1634   // when we try to allocate the AllocationValue, we would think they overlap.
1635   auto create_instruction_vector = [](const AllocationValue& allocation_value) {
1636     std::vector<const HloInstruction*> instruction_vector;
1637     instruction_vector.push_back(allocation_value.defining_instruction());
1638     for (const AllocationValue::Use& use : allocation_value.uses()) {
1639       instruction_vector.push_back(use.hlo_use.instruction);
1640     }
1641     return instruction_vector;
1642   };
1643   for (int i = 0; i < allocation_values.size() - 1; ++i) {
1644     for (int j = i + 1; j < allocation_values.size(); ++j) {
1645       const AllocationValue& allocation_value_1 = allocation_values[i];
1646       const AllocationValue& allocation_value_2 = allocation_values[j];
1647       if (create_instruction_vector(allocation_value_1) ==
1648           create_instruction_vector(allocation_value_2)) {
1649         VLOG(3) << "Allocation values " << allocation_value_1.ToShortString()
1650                 << " and " << allocation_value_2.ToShortString()
1651                 << " are equivalent, deleting the second one.";
1652         allocation_values.erase(allocation_values.begin() + j);
1653         --j;
1654       }
1655     }
1656   }
1657 
1658   FindAliases(&allocation_values);
1659 }
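
// Illustrative example (hypothetical HLO): for an in-place asynchronous
// collective permute, the AllocationValues for collective-permute-start{0}
// (the input) and collective-permute-start{1} (the output) are both defined by
// collective-permute-start and used only by collective-permute-done, so
// create_instruction_vector() returns the same vector for both and the loop
// above deletes the second, duplicate AllocationValue.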
1660 
1661 AlternateMemoryBestFitHeap::Result
1662 AlternateMemoryBestFitHeap::AllocateAllocationValues(
1663     absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
1664   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1665 
1666   // Find the use times across all of the related AllocationValues and sort
1667   // them. We use these to find allocations that are available throughout the
1668   // entire live range of all the AllocationValues.
1669   std::vector<int64_t> all_use_times;
1670   for (const AllocationValue& allocation_value : allocation_values) {
1671     absl::c_transform(allocation_value.uses(),
1672                       std::back_inserter(all_use_times),
1673                       [](const AllocationValue::Use& use) { return use.time; });
1674   }
1675   absl::c_sort(all_use_times);
1676 
1677   // Data structure to contain the preferred offset for a given computation.
1678   // We ensure that the same offset will be allocated outside the while loop
1679   // as well as inside the while loop.
1680   absl::flat_hash_map<const HloComputation*, AliasedOffset*>
1681       preferred_offset_for_computation;
1682 
1683   Result result = Result::kSuccess;
1684   for (AllocationValue& allocation_value : allocation_values) {
1685     int64_t definition_time =
1686         instruction_schedule.at(allocation_value.defining_instruction());
1687 
1688     AliasedOffset* preferred_offset = nullptr;
1689     auto preferred_offset_it =
1690         preferred_offset_for_computation.find(allocation_value.computation());
1691     if (preferred_offset_it != preferred_offset_for_computation.end()) {
1692       preferred_offset = preferred_offset_it->second;
1693     }
1694 
1695     // Iterate over the uses.
1696     for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
1697       const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
1698       const HloUse hlo_use = use.hlo_use;
1699       int64_t use_time = instruction_schedule.at(hlo_use.instruction);
1700       int64_t latest_prefetch_time = use_time;
1701       bool allow_no_copy_alternate_mem_allocation = true;
1702       std::optional<int64_t> earliest_prefetch_time = std::nullopt;
1703 
1704       // Control flow calls include kWhile, kCall, and kConditional opcodes.
1705       bool is_sequential_call =
1706           (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
1707            CallContext::kControlFlow);
1708       if (is_sequential_call) {
1709         for (const HloComputation* called_computation :
1710              hlo_use.instruction->called_computations()) {
1711           const HloLiveRange::TimeBound& computation_span =
1712               hlo_live_range_.computation_span_times().at(called_computation);
1713           latest_prefetch_time =
1714               std::min(computation_span.start - 1, latest_prefetch_time);
1715         }
1716         if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1717           // Given an example while loop and flattened schedule (logical times
1718           // shown on the left):
1719           //
1720           // 0:  a = ...
1721           // 1:  ...
1722           //     cond {
1723           // 2:   p = param(0)
1724           // 3:   ...
1725           //     }
1726           //     body {
1727           // 4:   p = param(0)
1728           // 5:   ...
1729           // 6:   ROOT ...
1730           //     }
1731           // 7:  w = while(a), body=body, cond=cond
1732           //
1733           // When processing "a" (time 0) and its while use (time 7), we update
1734           // the interval to time 0-4. This is so that the remaining interval
1735           // (5-6) can be allocated separately and this buffer doesn't waste
1736           // alternate memory space within the while loop body.
1737           HloComputation* while_body = hlo_use.instruction->while_body();
1738           // We require while body ROOTs to be the last in the schedule.
1739           CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
1740                    instruction_schedule.at(hlo_use.instruction))
1741               << "While body ROOTs need to be the last in the schedule!  "
1742               << "While body ROOTs need to be the last in the schedule! "
1743           // Replace the use time with the parameter time so that we can decide
1744           // on alternate memory allocations within the while loop body when we
1745           // look at uses within the while loop body.
1746           use_time =
1747               instruction_schedule.at(while_body->parameter_instruction(0));
1748         } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
1749           // Replace the use time with the earliest parameter of called
1750           // computations.
1751           for (const HloComputation* called_computation :
1752                hlo_use.instruction->called_computations()) {
1753             use_time = std::min(
1754                 use_time, instruction_schedule.at(
1755                               called_computation->parameter_instruction(0)));
1756           }
1757         }
1758       }
1759 
1760       // Add a required assignment in default memory if the use is not
1761       // allowed in alternate memory.
1762       if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
1763         AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
1764                               MemorySpace::kDefault, use_time);
1765       } else if (use_idx > 0) {
1766         // We allow buffers in alternate memory that are passed into
1767         // conditionals to give up their alternate memory allocation inside the
1768         // called computation. This means that if a conditional operator has an
1769         // alternate memory allocation, subsequent uses cannot use the same
1770         // alternate memory allocation in order not to clobber data. So we force
1771         // default memory allocation for these subsequent uses.
1772         const AllocationValue::Use& previous_use =
1773             allocation_value.uses().at(use_idx - 1);
1774         if (previous_use.hlo_use.instruction->opcode() ==
1775                 HloOpcode::kConditional &&
1776             previous_use.hlo_use.instruction != hlo_use.instruction) {
1777           allow_no_copy_alternate_mem_allocation = false;
1778           earliest_prefetch_time =
1779               instruction_schedule.at(previous_use.hlo_use.instruction);
1780           VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
1781                   << ") of use (" << hlo_use.ToString()
1782                   << ") is a conditional, so this use will need to evict. "
1783                   << "Earliest prefetch time = " << *earliest_prefetch_time;
1784         }
1785       }
1786 
1787       // Bitcasts don't define buffers and don't directly consume buffers. Skip
1788       // allocating buffers for bitcast uses (unless they are the root
1789       // instruction). The uses that feed from bitcasts will be handled
1790       // specially.
1791       if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
1792           hlo_use.instruction ==
1793               hlo_use.instruction->parent()->root_instruction()) {
1794         AllocationRequest request;
1795         // Rarely (e.g., when conditional true and false parameters are the
1796         // same), the definition time can be the time of the conditional and
1797         // the use time is that of the parameter use, which is earlier.
1798         request.start_time = std::min(definition_time, use_time);
1799         request.end_time = use_time;
1800         request.latest_prefetch_time = latest_prefetch_time;
1801         request.size = allocation_value.size();
1802         request.allow_no_copy_alternate_mem_allocation =
1803             allow_no_copy_alternate_mem_allocation;
1804         request.earliest_prefetch_time = earliest_prefetch_time;
1805         request.preferred_offset = preferred_offset;
1806         request.use = &use;
1807         request.allocation_value = &allocation_value;
1808         request.all_use_times = all_use_times;
1809         result_mark(AllocateSegment(request), result);
1810         if (result_requires_uncommit(result)) {
1811           // If the allocation finding failed (e.g., due to running out of
1812           // asynchronous copies), then fall back to allocating the buffer
1813           // entirely in the default memory.
1814           return result;
1815         }
1816 
1817         // If there are multiple uses, they can try using the memory allocation
1818         // already in the alternate memory.
1819         definition_time = instruction_schedule.at(hlo_use.instruction);
1820       }
1821 
1822       // Propagate the allocation to any aliases this use might have had.
1823       MemorySpaceAssignment::Allocation* aliased_allocation =
1824           GetLiveAllocationAt(*allocation_value.allocation_sequence(),
1825                               use_time);
1826       for (const HloPosition& aliased_position : use.aliases) {
1827         AddAliasedRequiredAssignment(aliased_position.instruction,
1828                                      aliased_position.index,
1829                                      aliased_allocation);
1830       }
1831 
1832       if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
1833           aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1834         // For while uses that are allocated in the alternate memory space, if
1835         // they also have an allocation in the default memory space in their
1836         // allocation sequence, create a "parent" allocation that mirrors this
1837         // default memory space allocation. When we process the parent
1838         // allocation, we add an additional parameter to the while that is a
1839         // reference to the buffer in the default memory space. With parent
1840         // allocations, we don't need to unnecessarily evict buffers since they
1841         // already have a copy in the default memory space. We search backwards
1842         // (latest to earliest in execution time) for a suitable allocation in
1843         // order to find the most recent one.
1844         if (options_.enable_while_redundant_eviction_elimination &&
1845             absl::c_find_if(allocation_value.value()->positions(),
1846                             [&hlo_use](const HloPosition& position) {
1847                               return position.instruction ==
1848                                          hlo_use.instruction &&
1849                                      position.index == hlo_use.operand_index;
1850                             }) != allocation_value.value()->positions().end()) {
1851           auto allocation_sequence = allocation_value.allocation_sequence();
1852           auto prev_allocation_in_default_mem_it = std::find_if(
1853               allocation_sequence->rbegin(), allocation_sequence->rend(),
1854               [&](const auto& allocation) {
1855                 return allocation->memory_space() == MemorySpace::kDefault &&
1856                        allocation->defining_position() ==
1857                            allocation_value.defining_position();
1858               });
1859           if (prev_allocation_in_default_mem_it !=
1860               allocation_sequence->rend()) {
1861             VLOG(3) << "Found a prev allocation in default mem for while use: "
1862                     << (*prev_allocation_in_default_mem_it)->ToString();
1863             auto body_allocation_value_it = absl::c_find_if(
1864                 allocation_values, [&](const AllocationValue& value) {
1865                   return value.computation() ==
1866                              hlo_use.instruction->while_body() &&
1867                          value.defining_instruction()->opcode() ==
1868                              HloOpcode::kParameter;
1869                 });
1870             CHECK_NE(body_allocation_value_it, allocation_values.end());
1871             VLOG(3) << "Body allocation value: "
1872                     << body_allocation_value_it->ToShortString();
1873             int64_t body_parameter_time = instruction_schedule.at(
1874                 body_allocation_value_it->defining_instruction());
1875             body_allocation_value_it->allocation_sequence()->push_back(
1876                 std::make_unique<MemorySpaceAssignment::ParentAllocation>(
1877                     **prev_allocation_in_default_mem_it, hlo_use.instruction,
1878                     body_allocation_value_it->defining_position(),
1879                     body_parameter_time));
1880             VLOG(3) << "Created: "
1881                     << body_allocation_value_it->allocation_sequence()
1882                            ->back()
1883                            ->ToString();
1884 
1885             auto after_while_allocation_value_it = absl::c_find_if(
1886                 allocation_values, [&](const AllocationValue& value) {
1887                   return value.defining_instruction() == hlo_use.instruction;
1888                 });
1889             CHECK_NE(after_while_allocation_value_it, allocation_values.end());
1890             VLOG(3) << "After while allocation value: "
1891                     << after_while_allocation_value_it->ToShortString();
1892             int64_t while_time = instruction_schedule.at(hlo_use.instruction);
1893             after_while_allocation_value_it->allocation_sequence()->push_back(
1894                 std::make_unique<MemorySpaceAssignment::MirroredAllocation>(
1895                     **prev_allocation_in_default_mem_it, while_time));
1896             VLOG(3) << "Created: "
1897                     << after_while_allocation_value_it->allocation_sequence()
1898                            ->back()
1899                            ->ToString();
1900           }
1901         }
1902         // Special case for while loops since the root offset must agree with
1903         // other offsets: remember the preferred offset for the while loop body.
1904         preferred_offset_for_computation[hlo_use.instruction->while_body()] =
1905             GetAliasedOffset(*aliased_allocation);
1906       }
1907     }
1908   }
1909   return result;
1910 }
1911 
1912 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1913   return a.AsTuple() < b.AsTuple();
1914 }
1915 
1916 bool operator==(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1917   return a.AsTuple() == b.AsTuple();
1918 }
1919 
1920 bool operator!=(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1921   return a.AsTuple() != b.AsTuple();
1922 }
1923 
1924 bool AsynchronousCopyResource::ConsumeResource(
1925     int64_t start_time, int64_t end_time, float resource,
1926     bool update_current_resource,
1927     const std::list<AsynchronousCopy>::iterator* current_copy,
1928     float resource_to_free) {
1929   VLOG(3) << "Consume resource: " << start_time << ", " << end_time << ", "
1930           << resource << ", delay: " << delay_[start_time + 1]
1931           << ", free: " << resource_to_free;
1932 
1933   // Nothing to do if we're not adding or removing any resources.
1934   if (resource == 0.0 && resource_to_free == 0.0) {
1935     return true;
1936   }
1937 
1938   // For the async copy we're adding, check the delay_ array to see how much
1939   // this copy would have to be delayed because of an earlier copy that wasn't
1940   // finished when this copy starts.
1941   if (current_copy == nullptr) {
1942     resource += delay_[start_time + 1];
1943   }
1944 
1945   // Find the copy that is right after this one. If there are leftover resources
1946   // by the time the next copy starts, the next copy will be pushed further
1947   // out in time.
1948   auto next_copy = async_copies_.end();
1949   if (current_copy != nullptr) {
1950     next_copy = std::next(*current_copy);
1951   } else {
1952     auto async_copy_time_it = async_copy_time_map_.upper_bound(start_time);
1953     if (async_copy_time_it != async_copy_time_map_.end()) {
1954       next_copy = async_copy_time_it->second;
1955     }
1956   }
1957 
1958   // Check if this copy will push the next copy later in time (or, if removing
1959   // the resource, whether the removal of this copy moves the next copy earlier
1960   // in time).
1961   std::optional<float> delay_for_next_copy = std::nullopt;
1962   float resource_freed = 0.0;
1963   for (int64_t time = start_time + 1; time < end_time && resource != 0;
1964        ++time) {
1965     // Iterate over the logical times that this copy spans. Note that the start
1966     // and end time ranges are exclusive.
1967     float used_resource = std::min(resource, initial_resources_[time]);
1968     if (next_copy != async_copies_.end() && next_copy->start_time == time - 1) {
1969       // This is the time where the next copy begins. If the resource is
1970       // non-zero at this point, the copy didn't finish by the time the next
1971       // copy started, so the next copy would need to be pushed later in time.
1972       delay_for_next_copy = resource;
1973       resource_to_free -= resource_freed;
1974     }
1975     if (update_current_resource && !delay_for_next_copy.has_value()) {
1976       // Update the delay_ vector and resource_freed variable with the amount
1977       // that was freed when removing the copy.
1978       float old_resource =
1979           std::max(0.0f, initial_resources_[time] - delay_[time]);
1980       delay_[time] = std::max(0.0f, resource - resource_to_free);
1981       float new_resource =
1982           std::max(0.0f, initial_resources_[time] - delay_[time]);
1983       resource_freed += std::max(0.0f, new_resource - old_resource);
1984     }
1985     // Update the resource with the used amount in this logical time.
1986     resource -= used_resource;
1987   }
1988 
1989   // If resource isn't satisfied by the end, we didn't have enough resources.
1990   if (resource > 0) {
1991     VLOG(3) << "Doesn't have enough resource; leftover resource = " << resource;
1992     return false;
1993   }
1994 
1995   // If this copy overlapped with another one, we recursively call
1996   // ConsumeResource with the amount of resource that needs to be added or
1997   // removed.
1998   if (delay_for_next_copy.has_value()) {
1999     return ConsumeResource(next_copy->start_time, next_copy->end_time,
2000                            *delay_for_next_copy + next_copy->resource,
2001                            update_current_resource, &next_copy,
2002                            resource_to_free);
2003   }
2004   return true;
2005 }
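
// Illustrative walkthrough (hypothetical numbers): with
// initial_resources_ = {2, 2, 2} and no previously added copies, consuming
// resource = 3 over (start_time = 0, end_time = 3) uses 2 units at time 1 and
// the remaining 1 unit at time 2; delay_ (when update_current_resource is
// true) records the backlog still outstanding at each time so that a later
// copy starting in that window knows how much it would be pushed back.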
2006 
2007 void AsynchronousCopyResource::AddCopy(const AsynchronousCopy& copy) {
2008   CHECK(ConsumeResource(copy.start_time, copy.end_time, copy.resource,
2009                         /*update_current_resource=*/true));
2010   // Find the iterator for the copy that would be right after this copy and put
2011   // this copy right before it in async_copies_.
2012   auto async_copy_time_it = async_copy_time_map_.upper_bound(copy.start_time);
2013   auto insertion_it = (async_copy_time_it == async_copy_time_map_.end())
2014                           ? async_copies_.end()
2015                           : async_copy_time_it->second;
2016   auto inserted_it = async_copies_.insert(insertion_it, copy);
2017   // If this copy is the first copy we have seen with the start time, add the
2018   // inserted iterator into async_copy_time_map_ for fast lookups. Note that
2019   // async_copy_time_map_ always points to the very first copy with the same
2020   // start time. If there are multiple asynchronous copies that have the same
2021   // start time, the memory space assignment algorithm schedules them in the
2022   // same order in which AddCopy was called.
2023   if (async_copy_time_map_.find(copy.start_time) ==
2024       async_copy_time_map_.end()) {
2025     async_copy_time_map_[copy.start_time] = inserted_it;
2026   }
2027 }
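
// Illustrative example (hypothetical copies): after AddCopy is called for
// copies with start times 3, 3, and 5 in that order, async_copies_ holds them
// in the same order, and async_copy_time_map_ maps 3 to the first copy with
// start time 3 and 5 to the last copy, preserving insertion order among copies
// that share a start time.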
2028 
2029 void AsynchronousCopyResource::RemoveCopy(const AsynchronousCopy& copy) {
2030   CHECK(ConsumeResource(copy.start_time, copy.end_time, /*resource=*/0,
2031                         /*update_current_resource=*/true,
2032                         /*current_copy=*/nullptr,
2033                         /*resource_to_free=*/copy.resource));
2034   // Using async_copy_time_map_, find this copy to be removed. Note that the
2035   // iterator in async_copy_time_map_ points to the first-seen copy with the
2036   // given start time, so the copy to be removed might be later than the first
2037   // one.
2038   auto async_copy_time_it = async_copy_time_map_.find(copy.start_time);
2039   CHECK(async_copy_time_it != async_copy_time_map_.end());
2040   auto it = async_copy_time_it->second;
2041   for (; it != async_copies_.end() && *it != copy; ++it) {
2042   }
2043   CHECK(it != async_copies_.end());
2044   // If the copy to be removed is the one pointed to by async_copy_time_map_,
2045   // we point async_copy_time_map_ at the next copy with the same start
2046   // time. If there are no such copies, we remove the key for this copy's
2047   // start time.
2048   if (it == async_copy_time_it->second) {
2049     if (std::next(it) != async_copies_.end() &&
2050         std::next(it)->start_time == copy.start_time) {
2051       async_copy_time_it->second = std::next(it);
2052     } else {
2053       async_copy_time_map_.erase(async_copy_time_it);
2054     }
2055   }
2056   async_copies_.erase(it);
2057 }
2058 
2059 bool AsynchronousCopyResource::HasEnoughResource(int64_t start_time,
2060                                                  int64_t end_time,
2061                                                  float resource) {
2062   return ConsumeResource(start_time, end_time, resource,
2063                          /*update_current_resource=*/false);
2064 }
2065 
2066 AlternateMemoryBestFitHeap::AliasedOffset*
2067 AlternateMemoryBestFitHeap::GetAliasedOffset(
2068     const MemorySpaceAssignment::Allocation& allocation) {
2069   auto aliased_offset_it = aliased_offset_map_.find(&allocation);
2070   CHECK(aliased_offset_it != aliased_offset_map_.end());
2071   return aliased_offset_it->second;
2072 }
2073 
2074 void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
2075     const MemorySpaceAssignment::Allocation& allocation,
2076     AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
2077   CHECK(allocation.memory_space() == MemorySpace::kAlternate);
2078   CHECK(!aliased_offset_map_.contains(&allocation));
2079   if (!aliased_offset) {
2080     aliased_offsets_.push_back({allocation.chunk().offset});
2081     aliased_offset = &aliased_offsets_.back();
2082   }
2083   CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
2084   CHECK(aliased_offset->allocations.insert(&allocation).second);
2085   aliased_offset_map_[&allocation] = aliased_offset;
2086 }
2087 
2088 /*static*/ MemorySpaceAssignment::Allocation*
2089 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
2090     const MemorySpaceAssignment::AllocationSequence& allocations,
2091     int64_t time) {
2092   for (auto allocation_it = allocations.rbegin();
2093        allocation_it != allocations.rend(); ++allocation_it) {
2094     if ((*allocation_it)->start_time() <= time &&
2095         (*allocation_it)->end_time() >= time) {
2096       return allocation_it->get();
2097     }
2098   }
2099   return nullptr;
2100 }
2101 
2102 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
2103     HloModule* module, std::optional<BufferInterval> prefetch_candidate) {
2104   if (!prefetch_candidate) {
2105     return;
2106   }
2107 
2108   Chunk chunk_candidate = FindChunkCandidate(*prefetch_candidate);
2109   if (chunk_candidate.chunk_end() > available_heap_size()) {
2110     LOG(WARNING)
2111         << "Could not allocate preferred memory for cross program prefetch";
2112     return;
2113   }
2114 
2115   const HloValue* buffer = prefetch_candidate->buffer;
2116   int64_t parameter = buffer->instruction()->parameter_number();
2117   module->AddCrossProgramPrefetch(parameter, buffer->index());
2118 
2119   MemorySpaceAssignment::AllocationSequence allocations;
2120   allocations.push_back(std::make_unique<MemorySpaceAssignment::Allocation>(
2121       buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
2122       prefetch_candidate->start, prefetch_candidate->end,
2123       /*is_scoped_allocation=*/false));
2124 
2125   // Find the earliest use.
2126   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
2127   auto uses = buffer->GetUses();
2128   auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
2129     return instruction_schedule.at(lhs.instruction) <
2130            instruction_schedule.at(rhs.instruction);
2131   };
2132   auto first_use = absl::c_min_element(uses, use_schedule_compare);
2133   int64_t latest_prefetch_time =
2134       instruction_schedule.at(first_use->instruction);
2135 
2136   // Find the latest use time.
2137   int64_t last_use_time = instruction_schedule.at(
2138       absl::c_max_element(uses, use_schedule_compare)->instruction);
2139   for (const HloValue* colocation : prefetch_candidate->colocations) {
2140     auto colocation_uses = colocation->GetUses();
2141     if (!colocation_uses.empty()) {
2142       last_use_time = std::max(
2143           last_use_time,
2144           instruction_schedule.at(
2145               absl::c_max_element(colocation_uses, use_schedule_compare)
2146                   ->instruction));
2147     }
2148   }
2149 
2150   int64_t end_of_program_prefetch_end_time = instruction_schedule.size();
2151   int64_t end_of_program_prefetch_latest_start_time =
2152       options_.prefetch_interval_picker->LatestPrefetchStartTime(
2153           buffer->defining_position().shape(), last_use_time,
2154           end_of_program_prefetch_end_time, nullptr);
2155   int64_t end_of_program_prefetch_start_time =
2156       options_.prefetch_interval_picker->PreferredPrefetchStartTime(
2157           buffer->defining_position().shape(), last_use_time,
2158           end_of_program_prefetch_latest_start_time,
2159           end_of_program_prefetch_end_time);
2160   VLOG(2) << "last use time = " << last_use_time
2161           << ", end-of-program prefetch start time = "
2162           << end_of_program_prefetch_start_time;
2163   float total_execution_time =
2164       options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2165           0, instruction_schedule.size());
2166   float buffer_occupied_time =
2167       options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2168           0, last_use_time) +
2169       options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2170           end_of_program_prefetch_start_time, end_of_program_prefetch_end_time);
2171   float buffer_occupied_ratio = buffer_occupied_time / total_execution_time;
2172   VLOG(2) << "Total execution time = " << total_execution_time
2173           << ", buffer occupied time = " << buffer_occupied_time
2174           << ", buffer occupied ratio = " << buffer_occupied_ratio;
2175   // Freeing the buffer only makes sense if it stays free for a substantial
2176   // time. Only perform this optimization if the occupied ratio is below the
2177   // limit and the memory pressure is above the alternate memory size.
2178   bool free_buffer =
2179       (options_.enable_cross_program_prefetch_freeing &&
2180        memory_pressure_ > options_.max_size_in_bytes &&
2181        buffer_occupied_ratio < kCrossProgramPrefetchOccupyFreeingLimit &&
2182        end_of_program_prefetch_start_time > last_use_time &&
2183        end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
2184   int64_t cross_program_prefetch_end_time =
2185       free_buffer ? last_use_time : prefetch_candidate->end;
2186 
2187   AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate,
2188                prefetch_candidate->start, cross_program_prefetch_end_time,
2189                latest_prefetch_time, &allocations, /*aliased_offset=*/nullptr,
2190                /*resource=*/0.0,
2191                /*is_cross_program_prefetch=*/true);
2192 
2193   HloInstruction* root_instruction =
2194       module->entry_computation()->root_instruction();
2195   absl::c_for_each(uses, [&](auto& use) {
2196     if (use.instruction != root_instruction) {
2197       allocations.back()->AddUse(use);
2198     }
2199   });
2200   AliasedOffset* cross_program_prefetch_offset =
2201       GetAliasedOffset(*allocations.back());
2202 
2203   if (free_buffer) {
2204     VLOG(2) << "Adding an end-of-program prefetch for freed "
2205                "cross-program-prefetched buffer.";
2206     AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, chunk_candidate,
2207                  end_of_program_prefetch_start_time,
2208                  end_of_program_prefetch_end_time,
2209                  end_of_program_prefetch_end_time, &allocations,
2210                  cross_program_prefetch_offset,
2211                  /*resource=*/0.0);
2212     CHECK_EQ(cross_program_prefetch_offset->offset,
2213              allocations.back()->chunk().offset);
2214   }
2215 
2216   const int allocations_initial_size = allocations_->size();
2217   for (auto& allocation : allocations) {
2218     if (allocation->memory_space() == MemorySpace::kAlternate) {
2219       BufferInterval buffer_interval;
2220       buffer_interval.start = allocation->start_time();
2221       buffer_interval.end = allocation->end_time();
2222       buffer_interval.size = allocation->chunk().size;
2223       buffer_interval.buffer = prefetch_candidate->buffer;
2224       AddToPendingChunks(buffer_interval, chunk_candidate);
2225     }
2226     allocations_->push_back(std::move(allocation));
2227   }
2228 
2229   // Add a repack allocation block for the Allocation objects in alternate
2230   // memory.
2231   for (int i = allocations_initial_size; i < allocations_->size(); ++i) {
2232     const auto& allocation = allocations_->at(i);
2233     if (allocation->memory_space() == MemorySpace::kAlternate) {
2234       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2235           allocation->start_time(), allocation->end_time(),
2236           allocation->chunk().size, allocation->chunk().offset,
2237           static_cast<int64_t>(repack_allocation_blocks_.size()),
2238           allocation.get()));
2239       RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
2240       for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
2241         colocation.colocations.push_back(inserted);
2242         if (&colocation != inserted) {
2243           inserted->colocations.push_back(&colocation);
2244         }
2245       }
2246     }
2247   }
2248 
2249   ClearPendingChunks();
2250 }
2251 
2252 void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() {
2253   const auto& instruction_sequence =
2254       hlo_live_range_.flattened_instruction_sequence().instructions();
2255   std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2256   for (int i = 0; i < instruction_sequence.size(); ++i) {
2257     const HloInstruction* instruction = instruction_sequence[i];
2258     int64_t reserved_scoped_memory =
2259         options_.reserved_scoped_memory_fn(instruction);
2260     if (reserved_scoped_memory != 0) {
2261       VLOG(1) << "Allocate reserved scoped memory at " << i << " ("
2262               << instruction->name() << "): " << reserved_scoped_memory;
2263       MemorySpaceAssignment::BufferInterval interval;
2264       interval.buffer = nullptr;
2265       interval.size = reserved_scoped_memory;
2266       interval.start = i;
2267       interval.end = i;
2268       interval.need_allocation = true;
2269       interval.colocations = {};
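      // Reserve the scoped memory at offset 0 for the duration of this single
      // instruction (the interval starts and ends at i).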
2270       Chunk chunk_candidate =
2271           FindChunkCandidate(interval, /*preferred_offset=*/0);
2272       CHECK_EQ(chunk_candidate.offset, 0);
2273       AddToPendingChunks(interval, chunk_candidate);
2274 
2275       if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
2276         AppendScopedAllocationBufferInfoDebugString(
2277             instruction, i, reserved_scoped_memory, buffer_info_str_);
2278       }
2279 
2280       allocations_->push_back(
2281           std::make_unique<MemorySpaceAssignment::Allocation>(
2282               HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate,
2283               chunk_candidate, i, i, /*is_scoped_allocation=*/true));
2284 
2285       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2286           i, i, reserved_scoped_memory,
2287           /*initial_offset=*/0,
2288           static_cast<int64_t>(repack_allocation_blocks_.size()),
2289           allocations_->back().get()));
2290       colocations.push_back(&repack_allocation_blocks_.back());
2291     }
2292   }
2293   // If requested, make all scoped allocations colocate with each other so
2294   // that when we repack, they all get the same offset. Since they will then
2295   // share the same scoped memory addresses, this increases the opportunity
2296   // to deduplicate different ops. However, it may hurt memory packing
2297   // efficiency.
2298   if (options_.allocate_reserved_scoped_memory_at_same_offset) {
2299     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2300          colocations) {
2301       repack_block->colocations = colocations;
2302     }
2303   }
2304   ClearPendingChunks();
2305 }
2306 
2307 std::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
2308 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
2309                                                        int64_t time) const {
2310   auto required_assignment_it = required_assignments_.find(buffer);
2311   std::optional<RequiredMemoryAssignment> required_assignment_at_time;
2312   if (required_assignment_it != required_assignments_.end()) {
2313     for (const RequiredMemoryAssignment& required_assignment :
2314          required_assignment_it->second) {
2315       if (required_assignment.time == time) {
2316         // Sanity check that there is only one required assignment at this time.
2317         CHECK(!required_assignment_at_time);
2318         required_assignment_at_time = required_assignment;
2319       }
2320     }
2321   }
2322   return required_assignment_at_time;
2323 }
2324 
2325 std::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
2326 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
2327     const AllocationValue::Use& use) const {
2328   std::optional<RequiredMemoryAssignment> required_assignment;
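  // All positions aliased with this use must agree on the required memory
  // space (ignoring time). Take the first required assignment found and check
  // the remaining aliases against it.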
2329   for (const HloPosition& position : use.aliases) {
2330     const HloValue* value =
2331         &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
2332             position.instruction, position.index);
2333     int64_t time =
2334         hlo_live_range_.instruction_schedule().at(position.instruction);
2335     std::optional<RequiredMemoryAssignment> required_assignment_for_alias =
2336         RequiredMemoryAssignmentAt(value, time);
2337     if (required_assignment == std::nullopt) {
2338       required_assignment = required_assignment_for_alias;
2339     } else {
2340       CHECK(required_assignment_for_alias == std::nullopt ||
2341             required_assignment->equals_ignoring_time(
2342                 *required_assignment_for_alias));
2343     }
2344   }
2345   return required_assignment;
2346 }
2347 
2348 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
2349     const HloInstruction* instruction, ShapeIndex index,
2350     const MemorySpaceAssignment::Allocation* aliased_allocation) {
2351   AliasedOffset* offset = nullptr;
2352   if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
2353     offset = GetAliasedOffset(*aliased_allocation);
2354   }
2355   AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
2356                         offset);
2357 }
2358 
2359 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
2360     const HloValue* value, const HloInstruction* instruction,
2361     MemorySpaceAssignment::MemorySpace memory_space, int64_t time,
2362     AliasedOffset* offset) {
2363   // Check for an existing required assignment at this time and, if there is
2364   // one, make sure it matches the assignment being added.
2365   auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
2366   if (existing_required_assignment) {
2367     CHECK(memory_space == existing_required_assignment->memory_space)
2368         << "inst = " << instruction->ToString() << " at " << time;
2369     CHECK((!offset && !existing_required_assignment->offset) ||
2370           offset == existing_required_assignment->offset);
2371     VLOG(3) << "Not adding required assignment because there is one already: "
2372             << value->ToShortString() << " at " << time << " at "
2373             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2374   } else {
2375     VLOG(3) << "Adding required assignment: " << value->ToShortString()
2376             << " at " << time << " at "
2377             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2378     RequiredMemoryAssignment required_assignment{memory_space, time, offset};
2379     required_assignments_[value].push_back(required_assignment);
2380     pending_required_assignments_.push_back({value, required_assignment});
2381   }
2382 }
2383 
2384 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
2385     const HloInstruction* instruction, ShapeIndex index,
2386     MemorySpace memory_space, AliasedOffset* offset) {
2387   const HloValue* value =
2388       &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
2389   int64_t instruction_time =
2390       hlo_live_range_.instruction_schedule().at(instruction);
2391   AddRequiredAssignment(value, instruction, memory_space, instruction_time,
2392                         offset);
2393 }
2394 
2395 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
2396   // Go through the parameters, outputs, and constants and pin them to the
2397   // corresponding memory by adding a required assignment.
2398   const HloModule& module = alias_analysis_.dataflow_analysis().module();
2399   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
2400   HloComputation* entry_computation = module.entry_computation();
2401   for (HloInstruction* parameter_instruction :
2402        entry_computation->parameter_instructions()) {
2403     int64_t parameter_instruction_time =
2404         instruction_schedule.at(parameter_instruction);
2405     ShapeUtil::ForEachSubshape(
2406         parameter_instruction->shape(),
2407         [&](const Shape& subshape, const ShapeIndex& index) {
2408           MemorySpace memory_space = MemorySpace::kDefault;
2409           if (subshape.has_layout() && subshape.layout().memory_space() ==
2410                                            options_.alternate_memory_space) {
2411             memory_space = MemorySpace::kAlternate;
2412           }
2413           for (const HloBuffer* buffer :
2414                alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
2415             for (const HloValue* value : buffer->values()) {
2416               VLOG(3) << "Adding required assignment for parameter value = "
2417                       << value->ToShortString()
2418                       << " time = " << parameter_instruction_time << " space = "
2419                       << (memory_space == MemorySpace::kDefault ? "def"
2420                                                                 : "alt");
2421               required_assignments_[value].push_back(
2422                   {memory_space, /*time=*/parameter_instruction_time});
2423             }
2424           }
2425         });
2426   }
2427   HloInstruction* root_instruction = entry_computation->root_instruction();
2428   int64_t root_instruction_time = instruction_schedule.at(root_instruction);
2429   ShapeUtil::ForEachSubshape(
2430       root_instruction->shape(),
2431       [&](const Shape& subshape, const ShapeIndex& index) {
2432         MemorySpace memory_space = MemorySpace::kDefault;
2433         if (subshape.has_layout() && subshape.layout().memory_space() ==
2434                                          options_.alternate_memory_space) {
2435           memory_space = MemorySpace::kAlternate;
2436         }
2437         for (const HloBuffer* buffer :
2438              alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
2439           for (const HloValue* value : buffer->values()) {
2440             VLOG(3) << "Adding required assignment for output value = "
2441                     << value->ToShortString()
2442                     << " time = " << root_instruction_time << " space = "
2443                     << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2444             required_assignments_[value].push_back(
2445                 {memory_space, /*time=*/root_instruction_time});
2446           }
2447         }
2448       });
2449 
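  // Constants are always pinned to the default memory space at their
  // definition time.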
2450   for (const HloComputation* computation : module.MakeNonfusionComputations()) {
2451     for (HloInstruction* instruction : computation->instructions()) {
2452       if (instruction->opcode() == HloOpcode::kConstant) {
2453         auto constant_instruction_it = instruction_schedule.find(instruction);
2454         if (constant_instruction_it == instruction_schedule.end()) {
2455           continue;
2456         }
2457         int64_t constant_instruction_time = constant_instruction_it->second;
2458         for (const auto& indexed_shape :
2459              ShapeUtil::GetLeafShapes(instruction->shape())) {
2460           const ShapeIndex& index = indexed_shape.index;
2461           for (const HloBuffer* buffer :
2462                alias_analysis_.ComputeBuffersAt(instruction, index)) {
2463             for (const HloValue* value : buffer->values()) {
2464               VLOG(3) << "Adding required assignment for constant value = "
2465                       << value->ToShortString()
2466                       << " time = " << constant_instruction_time
2467                       << " space = def";
2468               required_assignments_[value].push_back(
2469                   {MemorySpace::kDefault, /*time=*/constant_instruction_time});
2470             }
2471           }
2472         }
2473       }
2474     }
2475   }
2476 
2477   // Go through all of the values and pin them to the default memory if they
2478   // are not allowed in the alternate memory.
2479   for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
2480     if (!options_.is_allowed_in_alternate_mem_fn(*value)) {
2481       // We won't find the instruction in the schedule if it's inside a fusion.
2482       // If so, just skip.
2483       auto instruction_time_it =
2484           instruction_schedule.find(value->instruction());
2485       if (instruction_time_it == instruction_schedule.end()) {
2486         continue;
2487       }
2488       int64_t instruction_time = instruction_time_it->second;
2489       auto& required_assignments = required_assignments_[value];
2490       // Check if there is an existing matching required assignment (e.g.
2491       // inserted by the logic above) and if so ensure it requires a default
2492       // memory allocation.
2493       auto matching_assignment = absl::c_find_if(
2494           required_assignments,
2495           [&](const RequiredMemoryAssignment& required_assignment) {
2496             return required_assignment.time == instruction_time;
2497           });
2498       if (matching_assignment != required_assignments.end()) {
2499         CHECK(matching_assignment->memory_space == MemorySpace::kDefault)
2500             << "Mismatch in required assignments at time " << instruction_time
2501             << " value: " << value->ToString();
2502       } else {
2503         required_assignments.push_back(
2504             {MemorySpace::kDefault, instruction_time});
2505       }
2506     }
2507   }
2508 }
2509 
2510 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
2511     absl::Span<const BufferInterval* const> colocated_intervals) const {
2512   auto is_position_in_alternate_memory = [&](const HloPosition& position) {
2513     const Shape& shape = position.shape();
2514     return shape.has_layout() &&
2515            shape.layout().memory_space() == options_.alternate_memory_space;
2516   };
2517 
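  // A colocated interval set is considered reserved in alternate memory if it
  // contains an entry parameter or a root output whose layout already assigns
  // it to the alternate memory space.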
2518   const HloModule& module = alias_analysis_.dataflow_analysis().module();
2519   const HloComputation* entry_computation = module.entry_computation();
2520   const HloInstruction* root_instruction =
2521       entry_computation->root_instruction();
2522   for (const BufferInterval* colocated_interval : colocated_intervals) {
2523     const HloValue* value = colocated_interval->buffer;
2524     if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
2525         value->defining_instruction()->parent() == entry_computation &&
2526         is_position_in_alternate_memory(value->defining_position())) {
2527       return true;
2528     }
2529 
2530     for (const HloPosition& position : value->positions()) {
2531       if (position.instruction == root_instruction &&
2532           is_position_in_alternate_memory(position)) {
2533         return true;
2534       }
2535     }
2536   }
2537   return false;
2538 }
2539 
2540 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
2541     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
2542   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2543     allocations.push_back(&allocation_block);
2544   }
2545 }
2546 
2547 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
2548   interval_tree_ = {};
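  // Rebuild the interval tree using the offsets chosen by the repacker and
  // remember each new offset as the initial offset for any later repacking.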
2549   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2550     MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
2551     VLOG(3) << "Moved " << allocation->ToString() << ", size "
2552             << allocation->chunk().size << ", (" << allocation_block.start_time
2553             << ", " << allocation_block.end_time << ") from "
2554             << allocation_block.initial_offset << " to "
2555             << allocation_block.offset;
2556     allocation_block.allocation->mutable_chunk()->offset =
2557         allocation_block.offset;
2558     interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
2559                        {allocation_block.offset, allocation_block.size});
2560     allocation_block.initial_offset = allocation_block.offset;
2561     allocation_block.offset = -1;
2562   }
2563 }
2564 
2565 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
2566     absl::Span<AllocationValue> allocation_values) {
2567   // Clear the allocation sequences of the allocation values so that we start
2568   // from a clean state in case we retry allocation after uncommitting.
2569   for (AllocationValue& allocation_value : allocation_values) {
2570     allocation_value.allocation_sequence()->clear();
2571   }
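  // Roll back the heap bookkeeping for every chunk that was committed during
  // this (now abandoned) allocation attempt.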
2572   for (const auto& interval_and_chunk : pending_chunks_) {
2573     const BufferInterval& interval = interval_and_chunk.first;
2574     const Chunk& chunk = interval_and_chunk.second;
2575     VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
2576             << ") off = " << chunk.offset << " size = " << chunk.size;
2577     for (int i = interval.start; i <= interval.end; ++i) {
2578       peak_memory_usage_[i] -= chunk.size;
2579       CHECK_GE(peak_memory_usage_[i], 0)
2580           << "Peak memory usage at " << i
2581           << " is below zero after uncommitting. " << interval.start << "-"
2582           << interval.end << " : [" << chunk.offset << ", " << chunk.size
2583           << "]";
2584     }
2585     interval_tree_.Remove(interval.start, interval.end, chunk);
2586   }
2587   for (const auto& interval : pending_async_copies_) {
2588     if (interval.destination == MemorySpace::kAlternate) {
2589       prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
2590                                      kDummyChunk);
2591       prefetch_async_copy_resource_.RemoveCopy(interval);
2592     } else {
2593       eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
2594                                      kDummyChunk);
2595       eviction_async_copy_resource_.RemoveCopy(interval);
2596     }
2597   }
2598   for (const auto& value_and_required_assignment :
2599        pending_required_assignments_) {
2600     auto& required_assignment_vector =
2601         required_assignments_[value_and_required_assignment.first];
2602     const RequiredMemoryAssignment& required_assignment =
2603         value_and_required_assignment.second;
2604     VLOG(3) << "Removing required assignment: "
2605             << (required_assignment.memory_space == MemorySpace::kDefault
2606                     ? "def"
2607                     : "alt")
2608             << " time = " << required_assignment.time << " off = "
2609             << (required_assignment.offset ? required_assignment.offset->offset
2610                                            : -1);
2611     for (auto it = required_assignment_vector.begin();
2612          it != required_assignment_vector.end(); ++it) {
2613       if (*it == value_and_required_assignment.second) {
2614         required_assignment_vector.erase(it);
2615         break;
2616       }
2617     }
2618   }
2619   ClearPendingChunks();
2620 }
2621 
2622 void AlternateMemoryBestFitHeap::FinalizeAllocations(
2623     absl::Span<AllocationValue> allocation_values) {
2624   absl::flat_hash_map<const AliasedOffset*,
2625                       std::vector<MemorySpaceAssignment::Allocation*>>
2626       colocation_map;
2627   for (AllocationValue& allocation_value : allocation_values) {
2628     for (auto& allocation : *allocation_value.allocation_sequence()) {
2629       allocations_->push_back(std::move(allocation));
2630       MemorySpaceAssignment::Allocation* inserted_allocation =
2631           allocations_->back().get();
2632       if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
2633         colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
2634             inserted_allocation);
2635       }
2636     }
2637   }
2638   // The allocations that have the same AliasedOffset need to be colocated.
2639   // Export these to repack_allocation_blocks_ so that we can repack them to
2640   // reduce fragmentation.
2641   for (auto& colocation : colocation_map) {
2642     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2643     for (MemorySpaceAssignment::Allocation* colocated_allocation :
2644          colocation.second) {
2645       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2646           colocated_allocation->start_time(), colocated_allocation->end_time(),
2647           colocated_allocation->chunk().size,
2648           colocated_allocation->chunk().offset,
2649           static_cast<int64_t>(repack_allocation_blocks_.size()),
2650           colocated_allocation));
2651       colocations.push_back(&repack_allocation_blocks_.back());
2652     }
2653     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2654          colocations) {
2655       repack_block->colocations = colocations;
2656     }
2657   }
2658   ClearPendingChunks();
2659 }
2660 
2661 void AlternateMemoryBestFitHeap::ClearPendingChunks() {
2662   pending_chunks_.clear();
2663   pending_async_copies_.clear();
2664   pending_required_assignments_.clear();
2665   aliased_offset_map_.clear();
2666   aliased_offsets_.clear();
2667 }
2668 
2669 void AlternateMemoryBestFitHeap::AddToPendingChunks(
2670     const BufferInterval& buffer_interval, const Chunk& chunk_candidate) {
2671   VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
2672           << buffer_interval.end << " : [" << chunk_candidate.offset << ", "
2673           << chunk_candidate.size << "]";
2674   pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
2675   for (int i = buffer_interval.start; i <= buffer_interval.end; ++i) {
2676     peak_memory_usage_[i] += chunk_candidate.size;
2677     CHECK_LE(peak_memory_usage_[i], options_.max_size_in_bytes)
2678         << "Peak memory usage at " << i
2679         << " exceeds the max size of alternate memory. "
2680         << buffer_interval.start << "-" << buffer_interval.end << " : ["
2681         << chunk_candidate.offset << ", " << chunk_candidate.size << "]";
2682   }
2683   CommitChunk(buffer_interval, chunk_candidate);
2684 }
2685 
2686 std::optional<int>
2687 AlternateMemoryBestFitHeap::FindEarliestTimeToSatisfyPeakMemory(
2688     int start_time, int end_time, int64_t size) const {
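  // Scan backwards from end_time and stop at the first time where adding
  // `size` bytes would exceed options_.max_size_in_bytes; every later time up
  // to end_time has enough headroom. As an illustrative example, with
  // start_time = 2, end_time = 5, and headroom available only at times 3, 4,
  // and 5, this returns 3. If even end_time has no headroom, returns nullopt.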
2689   int earliest_time;
2690   for (earliest_time = end_time;
2691        earliest_time >= start_time &&
2692        peak_memory_usage_[earliest_time] + size <= options_.max_size_in_bytes;
2693        --earliest_time) {
2694   }
2695   if (earliest_time == end_time) {
2696     return std::nullopt;
2697   }
2698   return earliest_time + 1;
2699 }
2700 
2701 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
2702     const AllocationRequest& request) {
2703   auto allocation_sequence = request.allocation_value->allocation_sequence();
2704   // start_time == end_time is a special case where the value is consumed
2705   // multiple times by the same instruction. We can just find the previous
2706   // allocation and use that allocation.
2707   if (request.start_time == request.end_time) {
2708     MemorySpaceAssignment::Allocation* allocation =
2709         GetLiveAllocationAt(*allocation_sequence, request.end_time);
2710     CHECK_NE(allocation, nullptr);
2711     allocation->AddUse(request.use->hlo_use);
2712     return Result::kSuccess;
2713   }
2714 
2715   const HloPosition& defining_position =
2716       request.allocation_value->defining_position();
2717   VLOG(2) << "Finding allocation for "
2718           << request.allocation_value->ToShortString() << " ("
2719           << request.start_time << ", " << request.end_time
2720           << ") latest prefetch = " << request.latest_prefetch_time
2721           << " last use = " << request.allocation_value->uses().back().time
2722           << " use = " << request.use->hlo_use.ToString()
2723           << ". Size = " << request.size
2724           << ", def pos = " << defining_position.ToString();
2725   CHECK_LE(request.start_time, request.end_time);
2726   if (VLOG_IS_ON(3) && options_.cost_analysis) {
2727     VLOG(3) << "Definition benefit = "
2728             << options_.cost_analysis->GetAlternateMemoryBenefit(
2729                    request.allocation_value->defining_position())
2730             << " use benefit = "
2731             << options_.cost_analysis->GetAlternateMemoryBenefit(
2732                    request.use->hlo_use);
2733   }
2734 
2735   // There could be a requirement to pin this buffer to default memory because
2736   // it is either a parameter or an output. If the buffer is a parameter, we
2737   // are still allowed to prefetch. If the use expects the output to be in
2738   // default memory, we cannot prefetch it because if we did, it would end up
2739   // in alternate memory instead.
2740   auto required_assignment_at_start = RequiredMemoryAssignmentAt(
2741       request.allocation_value->value(), request.start_time);
2742   std::optional<MemorySpace> required_memory_space_at_start;
2743   if (required_assignment_at_start) {
2744     required_memory_space_at_start = required_assignment_at_start->memory_space;
2745   }
2746   // Find the required assignments both for the use and its aliases. If both
2747   // are non-nullopt, make sure they require the same assignment.
2748   auto required_assignment_at_end = RequiredMemoryAssignmentAt(
2749       request.allocation_value->value(), request.end_time);
2750   auto aliased_required_assignment_at_end =
2751       AliasedRequiredAssignmentForUse(*request.use);
2752   if (required_assignment_at_end != aliased_required_assignment_at_end) {
2753     if (required_assignment_at_end == std::nullopt) {
2754       required_assignment_at_end = aliased_required_assignment_at_end;
2755     } else {
2756       CHECK(aliased_required_assignment_at_end == std::nullopt ||
2757             aliased_required_assignment_at_end->equals_ignoring_time(
2758                 *required_assignment_at_end));
2759     }
2760   }
2761   std::optional<MemorySpace> required_memory_space_at_end;
2762   if (required_assignment_at_end) {
2763     required_memory_space_at_end = required_assignment_at_end->memory_space;
2764   }
2765 
2766   if (required_assignment_at_start) {
2767     bool needs_required_allocation = true;
2768     if (!allocation_sequence->empty()) {
2769       auto prev_allocation_it = std::find_if(
2770           allocation_sequence->rbegin(), allocation_sequence->rend(),
2771           [&](const auto& allocation) {
2772             return allocation->memory_space() ==
2773                        required_memory_space_at_start &&
2774                    allocation->defining_position() == defining_position;
2775           });
2776       if (prev_allocation_it != allocation_sequence->rend()) {
2777         (*prev_allocation_it)->Extend(request.start_time);
2778         needs_required_allocation = false;
2779       }
2780     }
2781     if (needs_required_allocation) {
2782       std::optional<Chunk> aliased_chunk = std::nullopt;
2783       if (required_assignment_at_start->memory_space ==
2784           MemorySpace::kAlternate) {
2785         aliased_chunk =
2786             Chunk{required_assignment_at_start->offset->offset, request.size};
2787       }
2788       allocation_sequence->push_back(
2789           std::make_unique<MemorySpaceAssignment::Allocation>(
2790               defining_position, required_assignment_at_start->memory_space,
2791               aliased_chunk, request.start_time, request.start_time,
2792               /*is_scoped_allocation=*/false));
2793       if (required_assignment_at_start->memory_space ==
2794           MemorySpace::kAlternate) {
2795         CreateOrAddToAliasedOffset(*allocation_sequence->back(),
2796                                    required_assignment_at_start->offset);
2797       }
2798     }
2799   }
2800 
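  // Allocation strategy, roughly in order of preference:
  //   1. Keep the value in alternate memory for the whole segment (no copy).
  //   2. If the previous allocation was in alternate memory, evict it to
  //      default memory.
  //   3. Prefetch the value from default memory back into alternate memory.
  //   4. Otherwise, extend the default-memory allocation and attach the use.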
2801   Result allocation_result = Result::kSuccess;
2802   // First try keeping the allocation entirely in the alternate memory.
2803   if (required_memory_space_at_start != MemorySpace::kDefault &&
2804       required_memory_space_at_end != MemorySpace::kDefault &&
2805       request.allow_no_copy_alternate_mem_allocation) {
2806     allocation_result = AllocateInAlternateMemoryNoCopy(request);
2807     if (allocation_result == Result::kSuccess) {
2808       return Result::kSuccess;
2809     }
2810   }
2811 
2812   auto prev_allocation_it = allocation_sequence->rbegin();
2813   // Find a previous allocation that is in the default memory space (not
2814   // necessarily the very last allocation).
2815   auto prev_allocation_in_default_mem_it =
2816       std::find_if(allocation_sequence->rbegin(), allocation_sequence->rend(),
2817                    [&](const auto& allocation) {
2818                      return allocation->memory_space() == MemorySpace::kDefault;
2819                    });
2820 
2821   if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
2822       prev_allocation_it != allocation_sequence->rend() &&
2823       (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
2824       (*prev_allocation_it)->defining_position() == defining_position &&
2825       !request.allocation_value->requires_contiguous_allocation()) {
2826     // If there was an allocation for this HloValue that was in the alternate
2827     // memory space, we also need to perform an eviction.
2828     Result eviction_result = Evict(request);
2829     if (eviction_result != Result::kSuccess) {
2830       // A non-success eviction requires us to uncommit previous allocations.
2831       return result_mark(Result::kFailRequiresUncommit, eviction_result);
2832     }
2833     prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2834   } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
2835     allocation_sequence->push_back(
2836         std::make_unique<MemorySpaceAssignment::Allocation>(
2837             defining_position, MemorySpace::kDefault, /*chunk=*/std::nullopt,
2838             request.start_time, request.end_time,
2839             /*is_scoped_allocation=*/false));
2840     prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2841   }
2842 
2843   CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2844   CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
2845         MemorySpace::kDefault);
2846 
2847   // If the buffer must be in default memory at the end_time, don't prefetch.
2848   if (required_memory_space_at_end == MemorySpace::kDefault) {
2849     VLOG(3)
2850         << "Not trying to prefetch because use requires buffer in default mem.";
2851     (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2852     (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2853     return Result::kSuccess;
2854   }
2855 
2856   // Finally, try to prefetch the buffer into alternate memory.
2857   if (!request.allocation_value->requires_contiguous_allocation()) {
2858     Result prefetch_result =
2859         Prefetch(request, **prev_allocation_in_default_mem_it);
2860     if (prefetch_result == Result::kSuccess) {
2861       return Result::kSuccess;
2862     }
2863     result_mark(prefetch_result, allocation_result);
2864   }
2865 
2866   // If the end assignment was required to be in alternate memory but that
2867   // wasn't possible, then this allocation is invalid.
2868   if (required_memory_space_at_end == MemorySpace::kAlternate) {
2869     return result_mark(Result::kFailRequiresUncommit, allocation_result);
2870   }
2871 
2872   // If the start assignment was required to be in alternate memory and the
2873   // buffer needs a contiguous assignment, we couldn't satisfy this requirement
2874   // and must abort.
2875   if (required_memory_space_at_start == MemorySpace::kAlternate &&
2876       request.allocation_value->requires_contiguous_allocation()) {
2877     return result_mark(Result::kFailRequiresUncommit, allocation_result);
2878   }
2879 
2880   // If a copy wasn't inserted, then add this use to the latest allocation in
2881   // default memory.
2882   (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2883   (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2884   return allocation_result;
2885 }
2886 
2887 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2888     const MemorySpaceAssignment::Allocation& prev_allocation,
2889     MemorySpace memory_space, std::optional<Chunk> chunk, int64_t start_time,
2890     int64_t end_time, int64_t copy_done_schedule_before_time,
2891     MemorySpaceAssignment::AllocationSequence* allocations,
2892     AliasedOffset* aliased_offset, float resource,
2893     bool is_cross_program_prefetch) {
2894   VLOG(3) << "Copy to "
2895           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2896                   ? "default"
2897                   : "alternate")
2898           << " memory between " << start_time << " and "
2899           << copy_done_schedule_before_time << " keeping until " << end_time
2900           << ", estimated copy resource is " << resource;
2901   CHECK_LT(start_time, copy_done_schedule_before_time);
2902 
2903   allocations->push_back(
2904       std::make_unique<MemorySpaceAssignment::CopyAllocation>(
2905           prev_allocation, memory_space, chunk, start_time, end_time,
2906           copy_done_schedule_before_time, is_cross_program_prefetch));
2907 
2908   // Register the additional async copy with the interval tree to keep track of
2909   // the limit at any given time.
2910   pending_async_copies_.push_back({start_time, copy_done_schedule_before_time,
2911                                    resource, memory_space,
2912                                    next_async_copy_id_++});
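  // Copies into alternate memory are prefetches and copies into default memory
  // are evictions; each kind is tracked in its own interval tree and async
  // copy resource.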
2913   if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2914     prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2915                                 kDummyChunk);
2916     prefetch_async_copy_resource_.AddCopy(pending_async_copies_.back());
2917     CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2918   } else {
2919     eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2920                                 kDummyChunk);
2921     eviction_async_copy_resource_.AddCopy(pending_async_copies_.back());
2922   }
2923 }
2924 
2925 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2926     int64_t start_time, int64_t end_time, bool is_prefetch,
2927     int64_t extra_async_copy_limit) const {
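  // A negative limit means the number of outstanding prefetches or evictions
  // is unbounded.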
2928   if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2929     return false;
2930   }
2931   if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2932     return false;
2933   }
2934 
2935   // Count the prefetches/evictions in the interval tree for the given interval.
2936   if (is_prefetch) {
2937     int64_t num_prefetches =
2938         prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2939             .size();
2940     return num_prefetches >=
2941            options_.max_outstanding_prefetches + extra_async_copy_limit;
2942   } else {
2943     int64_t num_evictions =
2944         eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2945             .size();
2946     return num_evictions >=
2947            options_.max_outstanding_evictions + extra_async_copy_limit;
2948   }
2949 }
2950 
2951 AlternateMemoryBestFitHeap::Result
2952 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2953     const AllocationRequest& request) {
2954   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2955   bool can_eliminate_copy = false;
2956   if (request.allocation_value->allocation_sequence()->empty()) {
2957     // There haven't been any allocations for this interval so far. We can
2958     // eliminate the copy if the value can be placed in the alternate memory.
2959     can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2960         *request.allocation_value->value());
2961   } else {
2962     // If there has been a previous allocation, we can eliminate the copy if the
2963     // previous allocation was also in the alternate memory.
2964     prev_allocation =
2965         request.allocation_value->allocation_sequence()->back().get();
2966     can_eliminate_copy =
2967         (prev_allocation->memory_space() == MemorySpace::kAlternate);
2968   }
2969 
2970   if (!can_eliminate_copy) {
2971     return Result::kFailPrevAllocationNotInAlternateMem;
2972   }
2973 
2974   const HloPosition& defining_position =
2975       request.allocation_value->defining_position();
2976   if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2977           defining_position.shape(), request.start_time + 1,
2978           request.end_time)) {
2979     return Result::kFailLiveRangeTooLong;
2980   }
2981 
2982   BufferInterval alternate_mem_interval;
2983   alternate_mem_interval.buffer = request.allocation_value->value();
2984   alternate_mem_interval.size = request.size;
2985   alternate_mem_interval.end = request.end_time;
2986   alternate_mem_interval.start = request.start_time;
2987 
2988   // Prefer the offset that was used for the previous allocation.
2989   AliasedOffset* preferred_offset = nullptr;
2990   if (prev_allocation != nullptr) {
2991     preferred_offset = GetAliasedOffset(*prev_allocation);
2992     // If there is a previous allocation, set the start time to one past the
2993     // previous allocation's end time.
2994     alternate_mem_interval.start = prev_allocation->end_time() + 1;
2995   }
2996 
2997   if (request.preferred_offset) {
2998     // Sanity check that if there is a preferred offset provided in the request,
2999     // it matches with the previous allocation.
3000     CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
3001         << "preferred_offset = " << preferred_offset->offset
3002         << ", request.preferred_offset = " << request.preferred_offset->offset;
3003     preferred_offset = request.preferred_offset;
3004   }
3005 
3006   VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
3007           << (preferred_offset ? preferred_offset->offset : -1);
3008   // In case there are additional uses after this use, we rely on the last use
3009   // time to try to reserve a chunk in the heap simulator. This is to prevent
3010   // the following scenario:
3011   //
3012   //                            +-------+
3013   //                           /         \
3014   //                   Producer--->Use1   +-->Use2
3015   //                       +---------+---------+
3016   // New buffer:           |         |         |
3017   //                       +---------+---------+
3018   //
3019   //                                     +-----------+
3020   // Current heap:                       | offset: 0 |
3021   //           --------------------------+-----------+------
3022   //
3023   // Because we allocate buffers greedily, Producer to Use1 segment first, and
3024   // then Use1 to Use2 segment, it is possible to allocate the first segment at
3025   // an offset that is available for the first segment (e.g. offset 0) but not
3026   // for the entire live range. This can result in unnecessary copies. By using
3027   // the last use time, we try to find an allocation that is available for the
3028   // entire Producer to Use2 range.
3029   std::optional<Chunk> chunk_candidate = FindBestChunkCandidate(
3030       request, preferred_offset, &alternate_mem_interval);
3031   // Check if the new heap size fits within limits. Also ensure that, if a
3032   // preferred offset was provided, that offset was used.
3033   if (chunk_candidate) {
3034     VLOG(3) << "Keep the buffer in alternate memory. Offset = "
3035             << chunk_candidate->offset << ", size = " << chunk_candidate->size
3036             << ", heap_size = " << result_.UpdatedHeapSize(*chunk_candidate)
3037             << ", prefetch picker = "
3038             << options_.prefetch_interval_picker->ToNoCopyDebugString(
3039                    defining_position.shape(), request.start_time,
3040                    request.end_time);
3041     AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
3042 
3043     // If there was a previous allocation, the buffer location is the
3044     // same as the previous one. Otherwise, it is the operand.
3045     if (prev_allocation != nullptr &&
3046         (prev_allocation->is_copy_allocation() ||
3047          prev_allocation->defining_position() == defining_position)) {
3048       prev_allocation->Extend(request.end_time);
3049     } else {
3050       request.allocation_value->allocation_sequence()->push_back(
3051           std::make_unique<MemorySpaceAssignment::Allocation>(
3052               defining_position, MemorySpace::kAlternate, chunk_candidate,
3053               request.start_time, request.end_time,
3054               /*is_scoped_allocation=*/false));
3055       CreateOrAddToAliasedOffset(
3056           *request.allocation_value->allocation_sequence()->back(),
3057           preferred_offset);
3058     }
3059     request.allocation_value->allocation_sequence()->back()->AddUse(
3060         request.use->hlo_use);
3061     return Result::kSuccess;
3062   }
3063   return Result::kFailOutOfMemory;
3064 }
3065 
3066 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
3067     const AllocationRequest& request) {
3068   CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
3069   MemorySpaceAssignment::Allocation* prev_allocation =
3070       request.allocation_value->allocation_sequence()->back().get();
3071   int64_t eviction_start_time = prev_allocation->start_time();
3072   int64_t eviction_end_time = prev_allocation->end_time();
3073   CHECK(eviction_start_time <= eviction_end_time);
3074 
3075   int64_t preferred_eviction_end_time =
3076       std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
3077                    request.allocation_value->defining_position().shape(),
3078                    eviction_start_time, request.end_time),
3079                eviction_end_time);
3080   // Evictions must complete by the time of this use.
3081   preferred_eviction_end_time =
3082       std::min(preferred_eviction_end_time, request.latest_prefetch_time);
3083 
3084   BufferInterval eviction_mem_interval;
3085   eviction_mem_interval.buffer = request.allocation_value->value();
3086   eviction_mem_interval.size = request.size;
3087   // Try to reserve a buffer from the end of the previous allocation to the
3088   // preferred eviction end time.
3089   eviction_mem_interval.start = eviction_end_time + 1;
3090   eviction_mem_interval.end = preferred_eviction_end_time;
3091   int64_t preferred_offset = prev_allocation->chunk().offset;
3092   VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
3093           << ") preferred end time = " << eviction_mem_interval.end;
3094 
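  // Shrink the eviction interval from the preferred end time down towards the
  // previous allocation's end time until the chunk at the preferred offset is
  // available for the whole interval.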
3095   for (; eviction_mem_interval.end > eviction_end_time;
3096        --eviction_mem_interval.end) {
3097     Chunk chunk_candidate =
3098         FindChunkCandidate(eviction_mem_interval, preferred_offset);
3099     if (chunk_candidate.offset == preferred_offset) {
3100       AddToPendingChunks(eviction_mem_interval, chunk_candidate);
3101       break;
3102     }
3103   }
3104   eviction_end_time = eviction_mem_interval.end;
3105 
3106   VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
3107           << eviction_start_time << ", " << eviction_end_time << ")";
3108 
3109   float eviction_resource =
3110       options_.cost_analysis
3111           ? options_.cost_analysis->GetAsyncCopyElapsed(
3112                 request.allocation_value->defining_position().shape())
3113           : 0.1;
3114 
3115   bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
3116   bool eviction_violates_resource =
3117       !eviction_async_copy_resource_.HasEnoughResource(
3118           eviction_start_time, eviction_end_time, eviction_resource);
3119   if (eviction_violates_resource) {
3120     // If we're in the last retry, set resource to 0.
3121     if (options_.prefetch_interval_picker->retry_number() ==
3122         options_.max_retries - 1) {
3123       VLOG(3) << "Violates resource in last retry, setting resource = 0";
3124       eviction_resource = 0;
3125     }
3126     eviction_violates_resource =
3127         !eviction_async_copy_resource_.HasEnoughResource(
3128             eviction_start_time, eviction_end_time, eviction_resource);
3129   }
3130   bool eviction_violates_outstanding_copies =
3131       ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
3132                                             eviction_end_time,
3133                                             /*is_prefetch=*/false);
3134 
3135   // See if this interval would violate the asynchronous copy limit.
3136   if (!eviction_interval_too_short && !eviction_violates_outstanding_copies &&
3137       !eviction_violates_resource) {
3138     prev_allocation->Extend(eviction_end_time);
3139     AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
3140                  /*chunk=*/std::nullopt, eviction_start_time,
3141                  prev_allocation->end_time(), eviction_end_time,
3142                  request.allocation_value->allocation_sequence(),
3143                  /*aliased_offset=*/nullptr, eviction_resource);
3144   } else {
3145     if (eviction_violates_outstanding_copies) {
3146       VLOG(3) << "This violates the maximum async copies.";
3147     } else if (eviction_violates_resource) {
3148       VLOG(3) << "This violates resource.";
3149     } else {
3150       VLOG(3) << "Eviction interval is too short (" << eviction_start_time
3151               << ", " << eviction_end_time << ").";
3152     }
3153     // If the original interval violated the limit, sub-intervals within it
3154     // could be tried, but this is not currently attempted.
3155     bool eviction_scheduled = false;
3156 
3157     if (!eviction_scheduled) {
3158       // If the eviction couldn't be scheduled, then fail. This buffer will be
3159       // kept in the default memory.
3160       VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
3161               << " because we hit the limit of maximum asynchronous copies "
3162               << "between "
3163               << hlo_live_range_.flattened_instruction_sequence()
3164                      .instructions()[eviction_start_time]
3165               << " and "
3166               << hlo_live_range_.flattened_instruction_sequence()
3167                      .instructions()[eviction_end_time];
3168       // return false;
3169       return Result::kFailOutOfAsyncCopies;
3170     }
3171   }
3172   // return true;
3173   return Result::kSuccess;
3174 }
3175 
3176 int64_t AlternateMemoryBestFitHeap::FindPrefetchEndTime(
3177     const AllocationRequest& request, int64_t earliest_prefetch_time) const {
3178   return request.latest_prefetch_time;
3179 }
3180 
3181 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
3182     const AllocationRequest& request,
3183     const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
3184   // Try partially placing the buffer in the alternate space. The time that is
3185   // overlapped will be used to asynchronously copy the buffer from the
3186   // default memory to the alternate memory.
3187   //
3188   //                      start                 end
3189   //                      time                  time
3190   //                      X---------------------X
3191   // Alternate:                          +------+
3192   // Default:             +---------------------+
3193   //                                     ^      ^
3194   //                                   Copy    Copy
3195   //                                   Start   Done
3196   int64_t earliest_prefetch_time =
3197       prev_allocation_in_default_mem.earliest_available_time();
3198   if (request.earliest_prefetch_time) {
3199     earliest_prefetch_time =
3200         std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
3201   }
3202   int64_t prefetch_end_time =
3203       FindPrefetchEndTime(request, earliest_prefetch_time);
3204 
3205   // As a compile time optimization, use the peak memory usage to filter out
3206   // allocation times that would push us to OOM.
3207   std::optional<int> earliest_non_oom_prefetch_time =
3208       FindEarliestTimeToSatisfyPeakMemory(earliest_prefetch_time,
3209                                           prefetch_end_time, request.size);
3210   Result result = Result::kSuccess;
3211   if (!earliest_non_oom_prefetch_time) {
3212     VLOG(3) << "Any prefetch in range (" << earliest_prefetch_time << ", "
3213             << prefetch_end_time << ") for size " << request.size
3214             << " would go out of memory.";
3215     result_mark(Result::kFailOutOfMemory, result);
3216     return result;
3217   }
3218   VLOG(4) << "After peak memory check, prefetch range is ("
3219           << *earliest_non_oom_prefetch_time << ", " << prefetch_end_time
3220           << "). Original earliest prefetch time is " << earliest_prefetch_time;
3221   earliest_prefetch_time = *earliest_non_oom_prefetch_time;
3222   options_.prefetch_interval_picker->Begin(
3223       request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
3224   VLOG(3) << "Trying prefetch picker = "
3225           << options_.prefetch_interval_picker->ToDebugString();
3226 
3227   // Create an alternate memory interval that starts at the earliest
3228   // possible position, given by max_prefetch_interval.
3229   BufferInterval alternate_mem_interval;
3230   alternate_mem_interval.buffer = request.allocation_value->value();
3231   alternate_mem_interval.size = request.size;
3232   // As a compile time optimization, try a prefetch allocation that is as late
3233   // as possible. If this is not able to find a chunk candidate, none of the
3234   // earlier tries will succeed either.
3235   alternate_mem_interval.start =
3236       options_.prefetch_interval_picker->latest_time();
3237   auto chunk_candidate = FindBestChunkCandidate(
3238       request, request.preferred_offset, &alternate_mem_interval);
3239   if (!chunk_candidate) {
3240     VLOG(3) << "The latest prefetch (" << alternate_mem_interval.start << ", "
3241             << request.end_time << ") cannot find a valid chunk. Giving up.";
3242     result_mark(Result::kFailOutOfMemory, result);
3243     return result;
3244   }
3245   const HloUse& use = request.use->hlo_use;
3246   const Shape& shape = ShapeUtil::GetSubshape(
3247       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
3248   // Uses by while instructions may be allowed additional outstanding prefetches.
3249   int64_t extra_async_copy_limit =
3250       request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
3251           ? options_.while_use_extra_outstanding_prefetch_limit
3252           : 0;
3253   // As a compilation time optimization, store the prefetch start time at which
3254   // we first ran out of memory. There is no point in exploring prefetch start
3255   // times earlier than this point.
3256   std::optional<int64_t> out_of_mem_start;
3257   while (!options_.prefetch_interval_picker->Done()) {
3258     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
3259     CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
3260     if (out_of_mem_start.has_value() &&
3261         alternate_mem_interval.start <= *out_of_mem_start) {
3262       VLOG(4) << "This would OOM (cached).";
3263       result_mark(Result::kFailOutOfMemory, result);
3264       continue;
3265     }
3266     int64_t estimated_prefetch_end_time =
3267         options_.prefetch_interval_picker->EstimatedPrefetchEndTime(
3268             shape, alternate_mem_interval.start, prefetch_end_time);
3269     VLOG(4) << "Trying alternate memory allocation ("
3270             << alternate_mem_interval.start << ", " << request.end_time
3271             << "), estimated prefetch end time = "
3272             << estimated_prefetch_end_time;
3273     float prefetch_resource =
3274         options_.cost_analysis
3275             ? options_.cost_analysis->GetAsyncCopyElapsed(shape)
3276             : 0.1;
3277     if (!prefetch_async_copy_resource_.HasEnoughResource(
3278             alternate_mem_interval.start, prefetch_end_time,
3279             prefetch_resource)) {
3280       VLOG(4) << "This would violate asynchronous copy resource = "
3281               << prefetch_resource;
3282       result_mark(Result::kFailViolatesAsyncCopyResource, result);
3283       continue;
3284     }
3285     if (ViolatesMaximumOutstandingAsyncCopies(
3286             alternate_mem_interval.start, prefetch_end_time,
3287             /*is_prefetch=*/true, extra_async_copy_limit)) {
3288       VLOG(4) << "This would violate the outstanding async copy limit.";
3289       result_mark(Result::kFailOutOfAsyncCopies, result);
3290       continue;
3291     }
3292 
3293     auto chunk_candidate = FindBestChunkCandidate(
3294         request, request.preferred_offset, &alternate_mem_interval);
3295     // Check if we could find a suitable chunk.
3296     if (chunk_candidate) {
3297       VLOG(3) << "Move the buffer to alternate memory at "
3298               << alternate_mem_interval.start
3299               << ". Offset = " << chunk_candidate->offset
3300               << ", size = " << chunk_candidate->size
3301               << ", heap_size = " << result_.UpdatedHeapSize(*chunk_candidate)
3302               << ", prefetch picker = "
3303               << options_.prefetch_interval_picker->ToDebugString();
3304       AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
3305 
3306       AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
3307                    chunk_candidate, alternate_mem_interval.start,
3308                    request.end_time, prefetch_end_time,
3309                    request.allocation_value->allocation_sequence(),
3310                    request.preferred_offset, prefetch_resource);
3311 
3312       request.allocation_value->allocation_sequence()->back()->AddUse(
3313           request.use->hlo_use);
3314       return Result::kSuccess;
3315     } else {
3316       // Mark the out of memory start with the prefetch start time so that we
3317       // don't explore prefetch start times earlier than this point.
3318       out_of_mem_start =
3319           std::max(out_of_mem_start.has_value() ? *out_of_mem_start : -1,
3320                    alternate_mem_interval.start);
3321     }
3322     result_mark(Result::kFailOutOfMemory, result);
3323   }
3324   // If we didn't consider any prefetch intervals, then the live range was too
3325   // short.
3326   if (result == Result::kSuccess) {
3327     return Result::kFailLiveRangeTooShort;
3328   } else {
3329     return result;
3330   }
3331 }
3332 
3333 std::optional<AlternateMemoryBestFitHeap::Chunk>
3334 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
3335     const AllocationRequest& request, const AliasedOffset* preferred_offset,
3336     BufferInterval* alternate_mem_interval) const {
3337   int64_t end_time = request.end_time;
3338   if (!preferred_offset) {
3339     // First find the earliest use that is at or later than the end time.
3340     const auto& use_times = request.all_use_times;
3341     auto use_time_it = absl::c_lower_bound(use_times, end_time);
3342     CHECK(use_time_it != use_times.end());
3343     int64_t earliest_use = *use_time_it;
3344     auto earliest_use_it = use_time_it;
3345 
3346     // Then find the latest use that can be allocated contiguously without
3347     // copies.
3348     const Shape& shape = request.allocation_value->defining_position().shape();
3349     for (;
3350          (use_time_it + 1) != use_times.end() &&
3351          options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
3352              shape, *use_time_it, *(use_time_it + 1));
3353          ++use_time_it) {
3354     }
3355     CHECK(use_time_it != use_times.end());
3356     int64_t latest_contiguous_use_time = *use_time_it;
3357 
3358     // Find a chunk that lives for as long as possible.
3359     std::optional<Chunk> last_chunk_candidate;
3360     int64_t latest_matching_use = std::numeric_limits<int64_t>::min();
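    // Note: std::lower_bound is used here as a binary search over the use
    // times in [earliest_use, latest_contiguous_use]. The comparator returns
    // true (keep searching later uses) only while a chunk still fits when the
    // interval is extended to that use, recording the latest fitting candidate
    // it visits as a side effect.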
3361     std::lower_bound(
3362         earliest_use_it, std::next(use_time_it), -1, [&](int64_t use, int64_t) {
3363           alternate_mem_interval->end = use;
3364           Chunk chunk_candidate = FindChunkCandidate(*alternate_mem_interval);
3365           if (chunk_candidate.chunk_end() <= available_heap_size()) {
3366             if (use > latest_matching_use) {
3367               last_chunk_candidate = chunk_candidate;
3368               latest_matching_use = use;
3369             }
3370             return true;
3371           }
3372           return false;
3373         });
3374     if (last_chunk_candidate.has_value()) {
3375       VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
3376               << ", latest contiguous use = " << latest_contiguous_use_time
3377               << ", use with available mem = " << latest_matching_use
3378               << ", offset = " << last_chunk_candidate->offset;
3379     }
3380     alternate_mem_interval->end = end_time;
3381     return last_chunk_candidate;
3382   }
3383   // If a preferred offset is given, try to find an allocation at that offset
3384   // only.
3385   alternate_mem_interval->end = end_time;
3386   Chunk chunk_candidate =
3387       FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
3388   if (chunk_candidate.offset == preferred_offset->offset) {
3389     return chunk_candidate;
3390   }
3391   return std::nullopt;
3392 }
3393 
3394 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
3395 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
3396   AsyncCopyStats stats;
3397   stats.max_outstanding_async_copies = 0;
3398   stats.num_prefetches = 0;
3399   stats.prefetch_bytes = 0;
3400   stats.num_evictions = 0;
3401   stats.eviction_bytes = 0;
3402   int64_t current_copies = 0;
3403   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
3404                       HloDataflowAnalysis::Run(*module_));
3405   for (const HloComputation* computation :
3406        module_->MakeNonfusionComputations()) {
3407     for (HloInstruction* instruction : computation->instructions()) {
3408       if (instruction->opcode() == HloOpcode::kCopyStart) {
3409         current_copies++;
3410       } else if (instruction->opcode() == HloOpcode::kCopyDone) {
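        // A copy-done whose result lives in the alternate memory space is
        // counted as a prefetch; otherwise it is counted as an eviction back
        // to the default memory space.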
3411         current_copies--;
3412         int64_t size =
3413             options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
3414         if (instruction->shape().layout().memory_space() ==
3415             options_.alternate_memory_space) {
3416           ++stats.num_prefetches;
3417           stats.prefetch_bytes += size;
3418         } else {
3419           ++stats.num_evictions;
3420           stats.eviction_bytes += size;
3421         }
3422       }
3423       stats.max_outstanding_async_copies =
3424           std::max(stats.max_outstanding_async_copies, current_copies);
3425     }
3426   }
3427   return stats;
3428 }
3429 
3430 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
3431 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
3432     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
3433     MemorySpaceAssignmentCostAnalysis::Cache* cache) {
3434   return [&cost_analysis, cache](const BufferInterval& x,
3435                                  const BufferInterval& y) {
3436     float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
3437     float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
3438     if (x_memory_boundedness != y_memory_boundedness) {
3439       return x_memory_boundedness > y_memory_boundedness;
3440     }
3441     // Tie-break if the memory boundedness is the same.
3442     return GlobalDecreasingSizeBestFitHeap<
3443         HloValue>::GetSpatialBufferIntervalCompare()(x, y);
3444   };
3445 }
3446 
3447 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
3448 MemorySpaceAssignment::Run(HloModule* module,
3449                            const HloLiveRange& hlo_live_range,
3450                            const HloAliasAnalysis& alias_analysis,
3451                            const Options& options) {
3452   CHECK(module->has_schedule());
3453   VLOG(3) << "Module before memory space assignment: ";
3454   XLA_VLOG_LINES(3, module->ToString());
3455   VLOG(3) << "Schedule: " << module->schedule().ToString();
3456   MemorySpaceAssignment memory_space_assignment(module, options,
3457                                                 hlo_live_range);
3458 
3459   return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
3460                                                           alias_analysis);
3461 }
3462 
3463 StatusOr<std::unique_ptr<PresetAssignments>>
3464 MemorySpaceAssignment::RunMemorySpaceAssignment(
3465     const HloLiveRange& hlo_live_range,
3466     const HloAliasAnalysis& alias_analysis) {
3467   TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
3468 
3469   if (options_.cost_analysis) {
3470     float estimated_time =
3471         ComputeEstimatedElapsedTime(hlo_live_range, allocations_);
3472     VLOG(1) << "Estimated elapsed time (sec): " << estimated_time;
3473   }
3474 
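  // The remaining steps materialize the chosen allocations: Process() rewrites
  // the HLO graph, ScheduleAsynchronousCopies() decides where the CopyStart/
  // CopyDone pairs go in the schedule, SimplifyGraph() cleans up redundant
  // tuple/get-tuple-element patterns, FixSchedule() rebuilds the instruction
  // sequences, and ExportAndColorBuffers() emits the preset assignments and
  // colors the assigned buffers with the alternate memory space.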
3475   TF_RETURN_IF_ERROR(Process());
3476   ScheduleAsynchronousCopies();
3477   TF_RETURN_IF_ERROR(SimplifyGraph());
3478   TF_RETURN_IF_ERROR(FixSchedule());
3479   TF_RETURN_IF_ERROR(ExportAndColorBuffers());
3480 
3481   VLOG(3) << "Module after memory space assignment: ";
3482   XLA_VLOG_LINES(3, module_->ToString());
3483   TF_CHECK_OK(module_->schedule().Verify());
3484   TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
3485   VLOG(1) << "Maximum number of outstanding async copies: "
3486           << stats.max_outstanding_async_copies;
3487   VLOG(1) << "Number of prefetches: " << stats.num_prefetches
3488           << ", in bytes: " << stats.prefetch_bytes;
3489   VLOG(1) << "Number of evictions: " << stats.num_evictions
3490           << ", in bytes: " << stats.eviction_bytes;
3491 
3492   TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());
3493 
3494   return std::move(preset_assignments_);
3495 }
3496 
3497 Status MemorySpaceAssignment::FindAllocationSequence(
3498     const HloLiveRange& hlo_live_range,
3499     const HloAliasAnalysis& alias_analysis) {
3500   auto algorithm = std::make_unique<AlternateMemoryBestFitHeap>(
3501       &allocations_, options_, alias_analysis, hlo_live_range);
3502 
3503   HeapSimulator::Options heap_simulator_options;
3504   heap_simulator_options.may_reuse_operand_buffers = false;
3505   heap_simulator_options.alloc_constants = true;
3506   TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
3507                                         module_->schedule(), alias_analysis,
3508                                         options_.size_fn,
3509                                         heap_simulator_options)
3510                          .status());
3511   return OkStatus();
3512 }
3513 
3514 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
3515   HloInstruction* operand =
3516       use.instruction->mutable_operand(use.operand_number);
3517   // If the use is a tuple, look inside the tuple to find the actual use.
3518   for (int64_t index : use.operand_index) {
3519     if (operand->opcode() != HloOpcode::kTuple) {
3520       break;
3521     }
3522     operand = operand->mutable_operand(index);
3523   }
3524 
3525   // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
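  // For example, an operand of get-tuple-element(tuple(a, b), 1) is resolved
  // to b before the use is recorded.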
3526   std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
3527   get_simplified_operand = [&](HloInstruction* instruction) {
3528     while (instruction->opcode() == HloOpcode::kGetTupleElement) {
3529       HloInstruction* operand =
3530           get_simplified_operand(instruction->mutable_operand(0));
3531       if (operand->opcode() == HloOpcode::kTuple) {
3532         instruction = operand->mutable_operand(instruction->tuple_index());
3533       } else {
3534         return instruction;
3535       }
3536     }
3537     return instruction;
3538   };
3539   operand = get_simplified_operand(operand);
3540 
3541   uses_.push_back(use);
3542 }
3543 
3544 float MemorySpaceAssignment::ComputeEstimatedElapsedTime(
3545     const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) {
3546   absl::flat_hash_map<const HloInstruction*, std::vector<ShapeIndex>>
3547       outputs_in_alternate_memory_map;
3548   absl::flat_hash_map<const HloInstruction*,
3549                       std::vector<std::pair<int64_t, ShapeIndex>>>
3550       operands_in_alternate_memory_map;
3551 
3552   for (auto& allocation : allocations) {
3553     if (!allocation->is_copy_allocation()) {
3554       if (allocation->memory_space() == MemorySpace::kAlternate) {
3555         const HloInstruction* defining_instruction =
3556             allocation->defining_position().instruction;
3557         outputs_in_alternate_memory_map[defining_instruction].push_back(
3558             allocation->defining_position().index);
3559       }
3560     }
3561     for (auto& hlo_use : allocation->uses()) {
3562       const HloInstruction* use_instruction = hlo_use.instruction;
3563       operands_in_alternate_memory_map[use_instruction].push_back(
3564           std::make_pair(hlo_use.operand_number, hlo_use.operand_index));
3565     }
3566   }
3567 
3568   const auto& instruction_sequence =
3569       hlo_live_range.flattened_instruction_sequence().instructions();
3570   float total_elapsed = 0.0;
3571   for (const HloInstruction* instruction : instruction_sequence) {
3572     std::vector<ShapeIndex> outputs_in_alternate_memory;
3573     auto output_it = outputs_in_alternate_memory_map.find(instruction);
3574     if (output_it != outputs_in_alternate_memory_map.end()) {
3575       outputs_in_alternate_memory = output_it->second;
3576     }
3577     std::vector<std::pair<int64_t, ShapeIndex>> operands_in_alternate_memory;
3578     auto operand_it = operands_in_alternate_memory_map.find(instruction);
3579     if (operand_it != operands_in_alternate_memory_map.end()) {
3580       operands_in_alternate_memory = operand_it->second;
3581     }
3582     float instruction_elapsed =
3583         options_.cost_analysis->GetInstructionElapsedInAlternateMemory(
3584             *instruction, operands_in_alternate_memory,
3585             outputs_in_alternate_memory);
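    // Weight the instruction cost by how often it is expected to execute:
    // roughly the assumed while execution count raised to the number of
    // enclosing while loops.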
3586     float while_nest_multiplier = IPow<float>(
3587         options_.xla_tpu_memory_space_assignment_while_execution_count,
3588         options_.cost_analysis->CalculateComputationNestLevel(
3589             instruction,
3590             /*while_only=*/true));
3591     total_elapsed += while_nest_multiplier * instruction_elapsed;
3592   }
3593   return total_elapsed;
3594 }
3595 
3596 Status MemorySpaceAssignment::Allocation::Process() {
3597   if (is_scoped_allocation()) {
3598     // Nothing to do here for scoped allocations.
3599     return OkStatus();
3600   }
3601   HloInstruction* producing_instruction = AddGetTupleElements();
3602   HloComputation* computation = producing_instruction->parent();
3603   for (const HloUse& use : uses_) {
3604     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3605     HloInstruction* replacement_instruction = producing_instruction;
3606     if (operand_shape.IsTuple()) {
3607       TF_ASSIGN_OR_RETURN(
3608           replacement_instruction,
3609           TupleUtil::ReplaceTupleWith(
3610               producing_instruction,
3611               use.instruction->mutable_operand(use.operand_number),
3612               use.operand_index));
3613     } else if (operand_shape != producing_instruction->shape()) {
3614       VLOG(4) << "Old shape = " << operand_shape.ToString()
3615               << ", new shape = " << producing_instruction->shape().ToString()
3616               << "; inserting a bitcast.";
3617       replacement_instruction = computation->AddInstruction(
3618           HloInstruction::CreateBitcast(operand_shape, producing_instruction));
3619     }
3620     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3621         use.operand_number, replacement_instruction));
3622   }
3623   return OkStatus();
3624 }
3625 
3626 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const {
3627   CHECK_NE(defining_position().instruction, nullptr);
3628 
3629   Shape shape = defining_position().shape();
3630   CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
3631                          << shape.ToString()
3632                          << " position = " << defining_position().shape();
3633   return TupleUtil::AddGetTupleElements(defining_position());
3634 }
3635 
3636 std::string MemorySpaceAssignment::Allocation::ToString() const {
3637   std::string memory_space_str = "def";
3638   if (memory_space_ == MemorySpace::kAlternate) {
3639     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3640   }
3641   return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""),
3642                       "Allocation in ", memory_space_str, " defined at ",
3643                       defining_position_.ToString(),
3644                       ", start_time:", start_time(), ", end_time:", end_time(),
3645                       ", uses: ", UsesToString(uses()));
3646 }
3647 
3648 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
3649   std::string memory_space_str = "def";
3650   if (memory_space_ == MemorySpace::kAlternate) {
3651     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3652   }
3653   return absl::StrCat("Copy Allocation in ", memory_space_str,
3654                       ", start_time:", start_time(), ", end_time:", end_time(),
3655                       ", copy_start_after_time: ", copy_start_schedule_after(),
3656                       ", copy_done_before_time: ", copy_done_schedule_before(),
3657                       ", uses: ", UsesToString(uses()), ", from ",
3658                       prev_allocation_.ToString());
3659 }
3660 
3661 std::string MemorySpaceAssignment::MirroredAllocation::ToString() const {
3662   return absl::StrCat("Mirrored Allocation for ",
3663                       original_allocation_.ToString());
3664 }
3665 
3666 std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
3667   return absl::StrCat("Parent Allocation mirrored at ",
3668                       defining_position_.ToString(), ", originally ",
3669                       original_allocation_.ToString());
3670 }
3671 
3672 Status MemorySpaceAssignment::CopyAllocation::Process() {
3673   // Copy allocations need to insert asynchronous copy nodes.
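  // The emitted pattern is roughly (HLO sketch; names are illustrative):
  //   copy_start = copy-start(producing_instruction)
  //   copy_done  = copy-done(copy_start)
  // with every use of this allocation rewired to copy_done below.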
3674   Shape shape = defining_position().shape();
3675   HloInstruction* producing_instruction = AddGetTupleElements();
3676   HloComputation* computation = producing_instruction->parent();
3677   copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
3678       ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
3679       producing_instruction, is_cross_program_prefetch_));
3680   copy_done_ = computation->AddInstruction(
3681       HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
3682   VLOG(4) << "Created " << copy_start_->name()
3683           << " for copy allocation: " << ToString();
3684   // Update the allocation position with the copy done instruction so that if
3685   // there are further copies from it, it can find the correct position.
3686   defining_position_ = HloPosition{copy_done_, {}};
3687 
3688   // Replace all the uses with the new copy instruction.
3689   for (HloUse use : uses_) {
3690     // If the operand is a tuple, we need to descend to the actual instruction
3691     // we want to replace.
3692     HloInstruction* replacement_instruction;
3693     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3694     if (operand_shape.IsTuple()) {
3695       TF_ASSIGN_OR_RETURN(
3696           replacement_instruction,
3697           TupleUtil::ReplaceTupleWith(
3698               copy_done_, use.instruction->mutable_operand(use.operand_number),
3699               use.operand_index));
3700     } else if (operand_shape != copy_done_->shape()) {
3701       VLOG(4) << "Old shape = " << operand_shape.ToString()
3702               << ", new shape = " << copy_done_->shape().ToString()
3703               << "; inserting a bitcast.";
3704       replacement_instruction = computation->AddInstruction(
3705           HloInstruction::CreateBitcast(operand_shape, copy_done_));
3706     } else {
3707       replacement_instruction = copy_done_;
3708     }
3709     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3710         use.operand_number, replacement_instruction));
3711   }
3712 
3713   return OkStatus();
3714 }
3715 
3716 Status MemorySpaceAssignment::MirroredAllocation::Process() {
3717   defining_position_ = original_allocation_.defining_position();
3718   return Allocation::Process();
3719 }
3720 
3721 Status MemorySpaceAssignment::ParentAllocation::Process() {
3722   // Add an additional parameter to the while HLO with a reference to the buffer
3723   // in the default memory space.
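  // Roughly, while(tuple(a, b, ...)) becomes while(tuple(a, b, ..., buffer)),
  // and the users of the while are later narrowed back to the original prefix
  // shape via TupleUtil::ExtractPrefix below.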
3724   HloInstruction* producing_instruction =
3725       original_allocation_.AddGetTupleElements();
3726   int new_tuple_index = calling_instruction_->shape().tuple_shapes_size();
3727 
3728   TF_ASSIGN_OR_RETURN(
3729       HloInstruction * new_while_operand,
3730       TupleUtil::ReplaceTupleWith(producing_instruction,
3731                                   calling_instruction_->mutable_operand(0),
3732                                   {new_tuple_index}));
3733   TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape(
3734       0, new_while_operand));
3735   *calling_instruction_->mutable_shape() = new_while_operand->shape();
3736   *calling_instruction_->while_condition()
3737        ->parameter_instruction(0)
3738        ->mutable_shape() = new_while_operand->shape();
3739   *calling_instruction_->while_body()
3740        ->parameter_instruction(0)
3741        ->mutable_shape() = new_while_operand->shape();
3742   defining_position_.index = {new_tuple_index};
3743   // Also replace the while op with a tuple that has the old shape. Note that we
3744   // need to first take a snapshot of the users before calling ExtractPrefix
3745   // since ExtractPrefix introduces additional gte users.
3746   std::vector<HloInstruction*> while_users = calling_instruction_->users();
3747   HloInstruction* tuple_with_old_shape =
3748       TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index);
3749   TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape(
3750       while_users, tuple_with_old_shape));
3751   return Allocation::Process();
3752 }
3753 
3754 Status MemorySpaceAssignment::ParentAllocation::PostProcess() {
3755   // Update the root of the while body with the new parameter. We need a
3756   // separate post-process step for this because other allocations may have the
3757   // while body root as a use and would otherwise update the old root instead of
3758   // the new one. Doing this step later ensures the root has been updated with
3759   // the other changes, so we can safely add the additional parameter.
3760   HloComputation* while_body = calling_instruction_->while_body();
3761   TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root,
3762                       TupleUtil::ReplaceTupleWith(
3763                           AddGetTupleElements(), while_body->root_instruction(),
3764                           defining_position_.index));
3765   while_body->set_root_instruction(new_while_body_root,
3766                                    /*accept_different_shape=*/true);
3767   return OkStatus();
3768 }
3769 
3770 void MemorySpaceAssignment::Allocation::MarkIfNeeded(
3771     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3772   MarkNeeded(needed_allocations);
3773 }
3774 
3775 void MemorySpaceAssignment::Allocation::MarkNeeded(
3776     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3777   needed_allocations.insert(this);
3778 }
3779 
3780 void MemorySpaceAssignment::CopyAllocation::MarkNeeded(
3781     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3782   needed_allocations.insert(this);
3783   prev_allocation_.MarkNeeded(needed_allocations);
3784 }
3785 
3786 void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded(
3787     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3788   // Parent allocations are only needed if they have any uses or if there is a
3789   // copy allocation that copies this value (in that case, the copy allocation
3790   // will call this allocation's MarkNeeded function).
3791   if (!uses_.empty()) {
3792     MarkNeeded(needed_allocations);
3793   }
3794 }
3795 
3796 void MemorySpaceAssignment::ParentAllocation::MarkNeeded(
3797     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3798   needed_allocations.insert(this);
3799   original_allocation_.MarkNeeded(needed_allocations);
3800 }
3801 
3802 void MemorySpaceAssignment::MirroredAllocation::MarkNeeded(
3803     absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3804   needed_allocations.insert(this);
3805   original_allocation_.MarkNeeded(needed_allocations);
3806 }
3807 
3808 Status MemorySpaceAssignment::Process() {
3809   VLOG(1) << "Processing assigned buffers...";
3810   // Some parent allocations may not be needed (e.g. when they don't have any
3811   // uses and no other, non-parent allocation depends on them), so before
3812   // processing the allocations, mark all the allocations that are actually
3813   // needed.
3814   absl::flat_hash_set<const Allocation*> needed_allocations;
3815   for (auto& allocation : allocations_) {
3816     allocation->MarkIfNeeded(needed_allocations);
3817   }
3818   // Insert CopyStart/CopyDone pairs.
3819   for (auto& allocation : allocations_) {
3820     VLOG(3) << "Processing: " << allocation->ToString();
3821     if (!needed_allocations.contains(allocation.get())) {
3822       VLOG(3) << "Allocation not needed.";
3823       continue;
3824     }
3825     TF_RETURN_IF_ERROR(allocation->Process());
3826     // Add the offset and size of the allocation in the alternate memory to
3827     // the output map.
3828     if (allocation->is_scoped_allocation()) {
3829       CHECK(allocation->memory_space() == MemorySpace::kAlternate);
3830       scoped_memory_assignments_.emplace_back(
3831           allocation->defining_position().instruction, allocation->chunk());
3832       alternate_memory_size_ =
3833           std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3834     } else if (allocation->memory_space() == MemorySpace::kAlternate) {
3835       alternate_memory_assignments_.emplace_back(
3836           allocation->defining_position(), allocation->chunk());
3837       alternate_memory_size_ =
3838           std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3839     }
3840   }
3841   // Post-process allocations. This is only used for parent allocations, where
3842   // we update the while body root with a reference to the buffer in default
3843   // memory space.
3844   for (auto& allocation : allocations_) {
3845     if (needed_allocations.contains(allocation.get())) {
3846       VLOG(3) << "Post-Processing: " << allocation->ToString();
3847       TF_RETURN_IF_ERROR(allocation->PostProcess());
3848     }
3849   }
3850   return OkStatus();
3851 }
3852 
3853 Status MemorySpaceAssignment::ExportAndColorBuffers() {
3854   VLOG(1) << "Exporting buffers...";
3855   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
3856   absl::flat_hash_map<int64_t, int64_t> seen_buffer_offsets;
3857   VLOG(3) << "Exported alternate memory allocations:";
3858   for (const auto& position_and_chunk : alternate_memory_assignments_) {
3859     const HloPosition& defining_position = position_and_chunk.first;
3860     const Chunk& chunk = position_and_chunk.second;
3861     const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
3862         defining_position.instruction, defining_position.index);
3863     auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
3864     if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
3865       CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
3866           << "Mismatch in offset for positions that map to the same value: "
3867           << buffer.ToString() << ", pos: " << defining_position.ToString();
3868     } else {
3869       VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3870               << "] : " << defining_position.ToString() << " ("
3871               << buffer.ToString() << ")";
3872       preset_assignments_->add_chunk(defining_position, chunk);
3873       seen_buffer_offsets[buffer.id()] = chunk.offset;
3874     }
3875   }
3876 
3877   VLOG(3) << "Exported scoped allocations in alternate memory:";
3878   for (const auto& instruction_and_chunk : scoped_memory_assignments_) {
3879     HloInstruction* instruction = instruction_and_chunk.first;
3880     const Chunk& chunk = instruction_and_chunk.second;
3881     VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3882             << "] : " << instruction->name();
3883     preset_assignments_->add_scoped_allocation_chunk(instruction, chunk);
3884   }
3885 
3886   if (!preset_assignments_->chunks().empty() ||
3887       !preset_assignments_->scoped_allocation_chunks().empty()) {
3888     preset_assignments_
3889         ->assignment_information_for_space(options_.alternate_memory_space)
3890         ->size = alternate_memory_size_;
3891   }
3892 
3893   VLOG(3) << "Exported alternate memory sizes:";
3894   for (auto& pair : preset_assignments_->assignment_informations()) {
3895     VLOG(3) << "  space: " << pair.first << ", size: " << pair.second.size;
3896   }
3897 
3898   VLOG(1) << "Coloring buffers...";
3899   // Color the pending positions and all of their aliased buffers.
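  // Coloring here means setting the memory_space field in the layout of every
  // aliased position's shape to the alternate memory space.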
3900   for (const auto& defining_position_and_chunk :
3901        preset_assignments_->chunks()) {
3902     const HloPosition& defining_position = defining_position_and_chunk.first;
3903     for (auto& buffer : alias_analysis->ComputeBuffersAt(
3904              defining_position.instruction, defining_position.index)) {
3905       for (auto& value : buffer->values()) {
3906         for (auto& position : value->positions()) {
3907           VLOG(4) << "Coloring " << position.ToString();
3908           Shape* shape = ShapeUtil::GetMutableSubshape(
3909               position.instruction->mutable_shape(), position.index);
3910           CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
3911                                   << position.ToString();
3912           shape->mutable_layout()->set_memory_space(
3913               options_.alternate_memory_space);
3914         }
3915       }
3916     }
3917   }
3918   return OkStatus();
3919 }
3920 
3921 void MemorySpaceAssignment::RemoveAssignmentForInstruction(
3922     const HloInstruction* instruction) {
3923   auto it = alternate_memory_assignments_.begin();
3924   auto end = alternate_memory_assignments_.end();
3925   while (it != end) {
3926     const HloPosition& position = it->first;
3927     if (position.instruction == instruction) {
3928       VLOG(3) << "Removing instruction from alternate memory assignments.";
3929       if (std::next(it) == end) {
3930         alternate_memory_assignments_.pop_back();
3931         break;
3932       } else {
3933         // Swap the removed position and chunk with the back and pop back.
3934         *it = alternate_memory_assignments_.back();
3935         alternate_memory_assignments_.pop_back();
3936         end = alternate_memory_assignments_.end();
3937       }
3938     } else {
3939       ++it;
3940     }
3941   }
3942 }
3943 
3944 Status MemorySpaceAssignment::SimplifyGraph() {
3945   VLOG(1) << "Simplifying graph...";
3946   for (HloComputation* computation : module_->MakeNonfusionComputations()) {
3947     // Parallel computations aren't in the schedule and don't need to be
3948     // modified.
3949     if (!computations_in_schedule_.contains(computation)) {
3950       VLOG(4) << "Not simplifying " << computation->name()
3951               << " because it's not in the schedule.";
3952       continue;
3953     }
3954     // Drop control dependencies. Since the computation is already scheduled, we
3955     // don't need control dependencies anymore, and having control
3956     // predecessors/successors prevents us from removing instructions without
3957     // users (HloComputation::IsSafelyRemovable returns false if there are
3958     // control dependencies).
3959     for (HloInstruction* instruction :
3960          computation->MakeInstructionPostOrder()) {
3961       TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
3962     }
3963     // We perform limited DCE and forward the tuple operand in patterns like
3964     // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
3965     // assignment is run late in compilation (after DCE and arithmetic
3966     // simplification passes) and we don't want to generate redundant code.  Run
3967     // to fixed point.
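    // For example, get-tuple-element(tuple(a, b), 0) is forwarded to a, and
    // the then-dead tuple can be removed on a later iteration of this loop.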
3968     bool computation_modified = true;
3969     while (computation_modified) {
3970       computation_modified = false;
3971       VLOG(4) << "Running simplify graph loop over " << computation->name();
3972       for (HloInstruction* instruction :
3973            computation->MakeInstructionPostOrder()) {
3974         if (computation->IsSafelyRemovable(instruction) &&
3975             instruction->IsDead() && !instruction->HasSideEffect() &&
3976             instruction->opcode() != HloOpcode::kCopyStart &&
3977             instruction->opcode() != HloOpcode::kCopyDone) {
3978           VLOG(4) << "Instruction removed: " << instruction->ToString();
3979           // Ensure the alternate memory assignments don't contain a reference
3980           // to the removed instruction.
3981           RemoveAssignmentForInstruction(instruction);
3982           // Instead of deleting the instruction from the schedule, replace it
3983           // with a nullptr. This is needed because FixSchedule relies on the
3984           // logical time that is the index into flattened_instructions_ for
3985           // scheduling asynchronous copies.
3986           auto instruction_it =
3987               absl::c_find(flattened_instructions_, instruction);
3988           if (instruction_it != flattened_instructions_.end()) {
3989             *instruction_it = nullptr;
3990           }
3991           TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
3992           computation_modified = true;
3993         } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
3994           HloInstruction* operand = instruction->mutable_operand(0);
3995           if (operand->opcode() == HloOpcode::kTuple) {
3996             HloInstruction* forwarded_instruction =
3997                 operand->mutable_operand(instruction->tuple_index());
3998             VLOG(4) << "Replacing uses of " << instruction->ToString()
3999                     << " with " << forwarded_instruction->ToString();
4000             TF_RETURN_IF_ERROR(
4001                 instruction->ReplaceAllUsesWith(forwarded_instruction));
4002             computation_modified = true;
4003           }
4004         } else if (instruction->opcode() == HloOpcode::kTuple) {
4005           // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
4006           // with x.
4007           bool can_replace =
4008               instruction->operand_count() > 0 &&
4009               instruction->operand(0)->opcode() ==
4010                   HloOpcode::kGetTupleElement &&
4011               instruction->operand(0)
4012                       ->operand(0)
4013                       ->shape()
4014                       .tuple_shapes_size() == instruction->operand_count();
4015           for (int operand_number = 0;
4016                operand_number < instruction->operand_count();
4017                ++operand_number) {
4018             const HloInstruction* operand =
4019                 instruction->operand(operand_number);
4020             if (operand->opcode() != HloOpcode::kGetTupleElement ||
4021                 operand->tuple_index() != operand_number ||
4022                 operand->operand(0) != instruction->operand(0)->operand(0)) {
4023               can_replace = false;
4024               break;
4025             }
4026           }
4027           if (can_replace) {
4028             HloInstruction* forwarded_instruction =
4029                 instruction->mutable_operand(0)->mutable_operand(0);
4030             VLOG(4) << "Replacing uses of " << instruction->ToString()
4031                     << " with " << forwarded_instruction->ToString();
4032             TF_RETURN_IF_ERROR(
4033                 instruction->ReplaceAllUsesWith(forwarded_instruction));
4034             computation_modified = true;
4035           }
4036         }
4037       }
4038     }
4039   }
4040 
4041   return OkStatus();
4042 }
4043 
4044 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
4045   VLOG(1) << "Scheduling asynchronous copies...";
4046   for (MemorySpace memory_space :
4047        {MemorySpace::kDefault, MemorySpace::kAlternate}) {
4048     std::vector<CopyAllocation*> copy_allocations;
4049     for (auto& allocation : allocations_) {
4050       if (allocation->is_copy_allocation()) {
4051         auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
4052         if (copy_allocation->memory_space() == memory_space) {
4053           copy_allocations.push_back(copy_allocation);
4054         }
4055       }
4056     }
4057 
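    // Stable-sort the copies by (copy done time, copy start time) so that
    // copies completing earlier are placed into the schedule maps first.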
4058     absl::c_stable_sort(
4059         copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
4060           return std::forward_as_tuple(first->copy_done_schedule_before(),
4061                                        first->copy_start_schedule_after()) <
4062                  std::forward_as_tuple(second->copy_done_schedule_before(),
4063                                        second->copy_start_schedule_after());
4064         });
4065     for (CopyAllocation* copy_allocation : copy_allocations) {
4066       // If the copy start doesn't happen to be scheduled at the correct
4067       // computation, delay it until the correct computation starts.
4068       int64_t copy_start_schedule_after =
4069           copy_allocation->copy_start_schedule_after();
4070       // Accessing elements of flattened_instructions_ here without checking for
4071       // nullptr is safe because this method runs before SimplifyGraph.
4072       while (copy_allocation->defining_position().instruction->parent() !=
4073              flattened_instructions_[copy_start_schedule_after]->parent()) {
4074         VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
4075                 << (copy_start_schedule_after + 1) << ") for "
4076                 << copy_allocation->copy_start()->ToString()
4077                 << " because it is not in the correct computation.";
4078         copy_allocation->set_copy_start_schedule_after(
4079             ++copy_start_schedule_after);
4080       }
4081 
4082       schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
4083           copy_allocation->copy_start());
4084       schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
4085           copy_allocation->copy_done());
4086     }
4087   }
4088 }
4089 
4090 Status MemorySpaceAssignment::FixSchedule() {
4091   VLOG(1) << "Fixing schedule...";
4092   TF_RET_CHECK(module_->has_schedule());
4093   HloSchedule& schedule = module_->schedule();
4094   for (const HloComputation* computation :
4095        module_->MakeNonfusionComputations()) {
4096     // Parallel computations aren't in the schedule and don't need to be
4097     // modified.
4098     if (!computations_in_schedule_.contains(computation)) {
4099       VLOG(4) << "Not scheduling " << computation->name()
4100               << " because it's not in the schedule.";
4101       continue;
4102     }
4103     TF_RET_CHECK(schedule.is_computation_scheduled(computation));
4104     HloInstructionSequence new_sequence;
4105 
4106     absl::flat_hash_set<HloInstruction*> inserted_instructions;
4107 
4108     VLOG(4) << "Scheduling: " << computation->ToString();
4109 
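    // Rebuild the sequence by walking logical time: at each index, first emit
    // the instructions that must be scheduled before it, then the original
    // instruction (unless it was removed or already inserted), and finally the
    // instructions that must be scheduled after it.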
4110     for (int64_t instruction_index = 0;; ++instruction_index) {
4111       auto insts_before_iter = schedule_before_.find(instruction_index);
4112       if (insts_before_iter != schedule_before_.end()) {
4113         for (HloInstruction* new_instruction : insts_before_iter->second) {
4114           if (new_instruction->parent() == computation) {
4115             VLOG(4) << "before " << instruction_index << ": "
4116                     << new_instruction->name();
4117             TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4118                 new_instruction, &new_sequence, &inserted_instructions));
4119           }
4120         }
4121       }
4122       // We allow scheduling copy dones past the root instruction (for
4123       // end-of-program cross-program prefetch). So the loop exit condition is
4124       // actually here.
4125       if (instruction_index >= flattened_instructions_.size()) {
4126         break;
4127       }
4128       HloInstruction* instruction = flattened_instructions_[instruction_index];
4129       // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
4130       // it was deleted) and not previously inserted. Also bitcasts and tuples
4131       // are treated specially and only inserted as a result of operand
4132       // dependencies.
4133       if (instruction != nullptr && instruction->parent() == computation &&
4134           instruction->opcode() != HloOpcode::kBitcast &&
4135           instruction->opcode() != HloOpcode::kTuple &&
4136           !inserted_instructions.contains(instruction)) {
4137         VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
4138         TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4139             instruction, &new_sequence, &inserted_instructions));
4140       }
4141       auto insts_after_iter = schedule_after_.find(instruction_index);
4142       if (insts_after_iter != schedule_after_.end()) {
4143         for (HloInstruction* new_instruction : insts_after_iter->second) {
4144           if (new_instruction->parent() == computation) {
4145             VLOG(4) << "after " << instruction_index << ": "
4146                     << new_instruction->name();
4147             TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4148                 new_instruction, &new_sequence, &inserted_instructions));
4149           }
4150         }
4151       }
4152     }
4153     // For rare cases where the original sequence is empty, ensure the root
4154     // instruction and its dependencies are scheduled.
4155     TF_RETURN_IF_ERROR(EnsureInstructionAndOperandsInserted(
4156         computation->root_instruction(), &new_sequence,
4157         &inserted_instructions));
4158     CHECK_EQ(new_sequence.size(), computation->instruction_count())
4159         << "New sequence for computation " << computation->name() << " has "
4160         << new_sequence.size() << " instructions, expects "
4161         << computation->instruction_count() << ".";
4162     schedule.set_sequence(computation, new_sequence);
4163   }
4164 
4165   return OkStatus();
4166 }
4167 
4168 Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
4169   VLOG(1) << "Verifying...";
4170   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
4171                       HloAliasAnalysis::Run(module_));
4172   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
4173                       HloLiveRange::Run(module_->schedule(), *alias_analysis,
4174                                         module_->entry_computation()));
4175 
4176   BufferIntervalTree interval_tree;
4177   absl::flat_hash_set<int64_t> seen_buffers;
4178   // The key for events is: time, is_free, value_id. This way the events are
4179   // sorted first by time; within the same time, allocations sort before frees;
4180   // and the value id serves as the final tie breaker.
4181   std::map<std::tuple<int64_t, bool, int64_t>,
4182            std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
4183       events;
4184 
4185   auto add_allocation_and_verify = [&](int64_t start_time, int64_t end_time,
4186                                        const Chunk& chunk,
4187                                        const HloValue* value) {
4188     events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
4189         std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
4190     events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
4191         std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
4192 
4193     // Get the chunks overlapping in time and search if they overlap in space
4194     // as well.
4195     // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
4196     // really should check against end_time (inclusive) for cases where the
4197     // operand can't share buffer with user (see
4198     // HloDataflowAnalysis::CanShareOperandBufferWithUser).
4199     for (const Chunk& overlapping_chunk :
4200          interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
4201       if (chunk.OverlapsWith(overlapping_chunk)) {
4202         return InternalError(
4203             ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
4204              " off: %d size: %d"),
4205             value->ToShortString(), start_time, end_time, chunk.offset,
4206             chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
4207       }
4208     }
4209     interval_tree.Add(start_time, end_time - 1, chunk);
4210     return OkStatus();
4211   };
4212 
4213   // Go through all instructions in the module to ensure CopyStart/CopyDone
4214   // instructions copy between alternate memory and default memory.
4215   for (const HloComputation* computation :
4216        module_->MakeNonfusionComputations()) {
4217     for (const HloInstruction* instruction : computation->instructions()) {
4218       if (instruction->opcode() == HloOpcode::kCopyStart) {
4219         int64_t from_memory_space =
4220             ShapeUtil::GetSubshape(instruction->shape(), {1})
4221                 .layout()
4222                 .memory_space();
4223         int64_t to_memory_space =
4224             ShapeUtil::GetSubshape(instruction->shape(), {0})
4225                 .layout()
4226                 .memory_space();
4227         CHECK_NE(from_memory_space, to_memory_space)
4228             << "Asynchronous copy to the same memory space: "
4229             << instruction->ToString();
4230       }
4231     }
4232   }
4233 
4234   for (const auto& position_and_chunk : preset_assignments_->chunks()) {
4235     const HloPosition& position = position_and_chunk.first;
4236     const Chunk& chunk = position_and_chunk.second;
4237     const HloBuffer& buffer =
4238         alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
4239     CHECK(!seen_buffers.contains(buffer.id()))
4240         << "Multiple preset assignments for the same buffer: "
4241         << buffer.ToString() << ", pos: " << position.ToString()
4242         << ", off: " << chunk.offset << ", size: " << chunk.size;
4243     seen_buffers.insert(buffer.id());
4244 
4245     for (const HloValue* value : buffer.values()) {
4246       const HloLiveRange::TimeBound& time_bound =
4247           hlo_live_range->buffer_live_ranges().at(value);
4248       const HloInstruction* last_use_instruction = nullptr;
4249       int64_t last_use_time = time_bound.start;
4250       for (const HloUse& use : value->GetUses()) {
4251         int64_t use_time =
4252             hlo_live_range->instruction_schedule().at(use.instruction);
4253         if (use_time > last_use_time) {
4254           last_use_time = use_time;
4255           last_use_instruction = use.instruction;
4256         }
4257       }
4258 
4259       std::function<Status(const HloInstruction*, int64_t, int64_t,
4260                            absl::string_view)>
4261           split_conditional_buffer;
4262       split_conditional_buffer = [&](const HloInstruction* use_instruction,
4263                                      int64_t start_time, int64_t end_time,
4264                                      absl::string_view indent_string) {
4265         // Special case when verifying conditionals: we internally split the
4266         // uses of alternate memory inside conditionals, so fish them out of
4267         // the called computations here.
4268         VLOG(3) << indent_string
4269                 << "Splitting conditional buffer: " << buffer.ToString()
4270                 << " value: " << value->ToShortString() << ": (" << start_time
4271                 << ", " << end_time << ") off: " << chunk.offset
4272                 << ", size: " << chunk.size;
4273         int64_t earliest_computation_start_time = end_time;
4274         for (const HloComputation* called_computation :
4275              use_instruction->called_computations()) {
4276           int64_t computation_start_time =
4277               hlo_live_range->computation_span_times()
4278                   .at(called_computation)
4279                   .start;
4280           earliest_computation_start_time =
4281               std::min(earliest_computation_start_time, computation_start_time);
4282           int64_t last_use_time = -1;
4283           const HloInstruction* last_use_instruction = nullptr;
4284           for (const HloUse& use : value->GetUses()) {
4285             int64_t use_time =
4286                 hlo_live_range->instruction_schedule().at(use.instruction);
4287             if (use.instruction->parent() == called_computation &&
4288                 use_time > last_use_time) {
4289               last_use_time = use_time;
4290               last_use_instruction = use.instruction;
4291             }
4292           }
4293           if (last_use_time != -1) {
4294             VLOG(3) << indent_string
4295                     << " computation: " << called_computation->name() << ": ("
4296                     << computation_start_time << ", " << last_use_time << ")";
4297             CHECK(last_use_instruction);
4298             if (last_use_instruction->opcode() == HloOpcode::kConditional) {
4299               // The last use is another (nested) conditional. Call this
4300               // function recursively.
4301               TF_RETURN_IF_ERROR(split_conditional_buffer(
4302                   last_use_instruction, computation_start_time, last_use_time,
4303                   absl::StrCat(indent_string, "  ")));
4304             } else {
4305               last_use_time = std::min(last_use_time, end_time);
4306               TF_RETURN_IF_ERROR(add_allocation_and_verify(
4307                   computation_start_time, last_use_time, chunk, value));
4308             }
4309           }
4310         }
4311         VLOG(3) << indent_string << " from beginning until first computation: ("
4312                 << start_time << ", " << (earliest_computation_start_time - 1)
4313                 << ")";
4314         TF_RETURN_IF_ERROR(add_allocation_and_verify(
4315             start_time, earliest_computation_start_time - 1, chunk, value));
4316         return OkStatus();
4317       };
4318 
4319       if (last_use_instruction &&
4320           last_use_instruction->opcode() == HloOpcode::kConditional) {
4321         TF_RETURN_IF_ERROR(split_conditional_buffer(
4322             last_use_instruction, time_bound.start, time_bound.end, " "));
4323       } else if (!value->GetUses().empty()) {
4324         last_use_time = std::min(last_use_time, time_bound.end);
4325         VLOG(3) << " buffer: " << buffer.ToString()
4326                 << " value: " << value->ToShortString() << ": ("
4327                 << time_bound.start << ", " << last_use_time
4328                 << ") off: " << chunk.offset << ", size: " << chunk.size;
4329         TF_RETURN_IF_ERROR(add_allocation_and_verify(
4330             time_bound.start, last_use_time, chunk, value));
4331       }
4332     }
4333   }
4334 
4335   HeapSimulatorTrace* heap_trace =
4336       &preset_assignments_
4337            ->assignment_information_for_space(options_.alternate_memory_space)
4338            ->heap_simulator_trace;
4339   int64_t memory_usage = 0;
4340   int64_t max_memory_usage = 0;
4341   for (const auto& event : events) {
4342     int64_t time;
4343     bool is_free;
4344     int64_t buffer_id;
4345     std::tie(time, is_free, buffer_id) = event.first;
4346     const HloValue* value;
4347     Chunk chunk;
4348     HeapSimulatorTrace::Event::Kind kind;
4349     std::tie(value, chunk, kind) = event.second;
4350     HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
4351     heap_trace_event->set_kind(kind);
4352     heap_trace_event->set_buffer_id(buffer_id);
4353     heap_trace_event->set_instruction_name(value->instruction()->name());
4354     heap_trace_event->set_computation_name(
4355         value->instruction()->parent()->name());
4356 
4357     if (kind == HeapSimulatorTrace::Event::ALLOC) {
4358       memory_usage += chunk.size;
4359     } else {
4360       CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
4361       memory_usage -= chunk.size;
4362     }
4363     max_memory_usage = std::max(max_memory_usage, memory_usage);
4364     VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
4365   }
4366   VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;
4367 
4368   return OkStatus();
4369 }
4370 }  // namespace memory_space_assignment
4371 }  // namespace xla
4372