1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <iterator>
21 #include <limits>
22 #include <string>
23 #include <utility>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/compiler/xla/debug_options_flags.h"
29 #include "tensorflow/compiler/xla/service/memory_space_assignment_tuning_utils.h"
30 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
31 #include "tensorflow/compiler/xla/service/tuple_util.h"
32 #include "tensorflow/core/lib/math/math_util.h"
33 namespace xla {
34
35 namespace memory_space_assignment {
36
37 namespace {
38 // Define a dummy chunk for buffers that will be allocated in the default
39 // memory space and for keeping track of the number of asynchronous copies.
40 const HeapSimulator::Chunk kDummyChunk{-1, -1};
41 // For a cross-program prefetched buffer, we only perform the freeing
42 // optimization if the buffer occupies less than this fraction of the execution time.
43 const float kCrossProgramPrefetchOccupyFreeingLimit = 0.6;
44 // Each time we retry compilation, increase the preferred eviction end time by
45 // this amount multiplied by the preferred overlap-to-async-copy ratio.
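// Illustrative example (hypothetical numbers): with a preferred
// overlap-to-async-copy ratio of 1.5 and an async copy elapsed time of 10us,
// CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime requires a
// logical interval of at least (1 + 2.0 * retry_number) * 1.5 * 10us, i.e.
// 15us on the first attempt and 45us on the first retry.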
46 const float kEvictionRetryMultiplier = 2.0;
47
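// Heuristically returns true if |inst| (looking through bitcasts, broadcasts,
// transposes, and into fusions) is used the way an activation typically is:
// as the lhs operand of a convolution or dot, as the index operand of a
// gather, as any operand other than the first of a dynamic-slice or
// dynamic-update-slice, or as an init value of a reduce. Any user not
// explicitly handled below is conservatively treated as activation-like.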
48 bool LooksLikeAnActivation(const HloInstruction* inst) {
49 for (HloInstruction* user : inst->users()) {
50 switch (user->opcode()) {
51 case HloOpcode::kConvolution:
52 case HloOpcode::kDot:
53 if (user->operand(0) == inst) {
54 return true;
55 }
56 break;
57 case HloOpcode::kGather:
58 if (user->operand(1) == inst) {
59 return true;
60 }
61 break;
62 case HloOpcode::kFusion:
63 for (int i = 0; i < user->operand_count(); ++i) {
64 if (user->operand(i) == inst &&
65 LooksLikeAnActivation(user->fused_parameter(i))) {
66 return true;
67 }
68 }
69 break;
70 case HloOpcode::kBitcast:
71 case HloOpcode::kBroadcast:
72 case HloOpcode::kTranspose:
73 if (LooksLikeAnActivation(user)) {
74 return true;
75 }
76 break;
77 case HloOpcode::kDynamicUpdateSlice:
78 case HloOpcode::kDynamicSlice:
79 if (std::find(user->operands().begin() + 1, user->operands().end(),
80 inst) != user->operands().end()) {
81 return true;
82 }
83 if (LooksLikeAnActivation(user)) {
84 return true;
85 }
86 break;
87 case HloOpcode::kReduce:
88 // Check init operands.
89 if (std::find(user->operands().begin() + user->operand_count() / 2,
90 user->operands().end(), inst) != user->operands().end()) {
91 return true;
92 }
93 if (LooksLikeAnActivation(user)) {
94 return true;
95 }
96 break;
97 default:
98 return true;
99 }
100 }
101 return false;
102 }
103
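// Returns true if |value| is a candidate for cross-program prefetching: it
// must be (part of) an entry computation parameter whose layout (if any) is
// not already in the alternate memory space, an array nested at most one
// level deep in a tuple, no larger than options.max_size_in_bytes, and it
// must have at least one use; every use must be a parameter or
// get-tuple-element (possibly through a bitcast) that does not look like an
// activation.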
104 bool IsCrossProgramPrefetchCandidate(const HloValue& value,
105 const Options& options) {
106 return value.defining_instruction()->parent() ==
107 value.defining_instruction()->GetModule()->entry_computation() &&
108 value.defining_instruction()->opcode() == HloOpcode::kParameter &&
109 (!value.shape().has_layout() ||
110 value.shape().layout().memory_space() !=
111 options.alternate_memory_space) &&
112 value.index().size() <= 1 && value.shape().IsArray() &&
113 !value.GetUses().empty() &&
114 options.size_fn(value) <= options.max_size_in_bytes &&
115 absl::c_all_of(value.GetUses(), [&](const HloUse& use) {
116 const HloInstruction* inst =
117 use.instruction->operand(use.operand_number);
118
119 // Skip the LooksLikeAnActivation test since we're testing the
120 // parent GTE/parameter and its children below.
121 if (inst->opcode() == HloOpcode::kBitcast &&
122 ((inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
123 inst->operand(0)->operand(0)->opcode() ==
124 HloOpcode::kParameter) ||
125 inst->operand(0)->opcode() == HloOpcode::kParameter)) {
126 return true;
127 }
128
129 return (inst->opcode() == HloOpcode::kGetTupleElement ||
130 inst->opcode() == HloOpcode::kParameter) &&
131 !LooksLikeAnActivation(inst);
132 });
133 }
134
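// Scans the buffers in |alias_analysis| and returns the most promising
// cross-program prefetch candidate as a BufferInterval spanning the whole
// program, or std::nullopt if there is none. Candidates are ranked either by
// options.buffer_interval_compare or by the size-based heuristic below,
// depending on the options.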
135 std::optional<MemorySpaceAssignment::BufferInterval>
136 FindCrossProgramPrefetchCandidate(const HloAliasAnalysis& alias_analysis,
137 const HloLiveRange& hlo_live_range,
138 const Options& options) {
139 std::vector<MemorySpaceAssignment::BufferInterval> candidates;
140 for (const HloBuffer& buffer : alias_analysis.buffers()) {
141 CHECK_GE(buffer.values().size(), 1);
142 const HloValue* value = buffer.values().at(0);
143 if (IsCrossProgramPrefetchCandidate(*value, options)) {
144 MemorySpaceAssignment::BufferInterval interval;
145 interval.buffer = value;
146 interval.size = options.size_fn(*value);
147 interval.start = 0;
148 interval.end = hlo_live_range.schedule_end_time();
149 interval.need_allocation = true;
150 interval.colocations = {++buffer.values().begin(), buffer.values().end()};
151 candidates.emplace_back(interval);
152 }
153 }
154
155 // The BufferIntervalCompare function used to sort buffers implements the
156 // greater-than operator so that the most beneficial buffers are allocated
157 // first. The size_compare function below hence uses the greater-than operator
158 // to pick the largest buffer.
159 auto size_compare = [](const auto& x, const auto& y) {
160 if (x.size == y.size) {
161 // When both buffers are of same size, we prefer the one that is used to
162 // produce larger tensors in its consumer instructions.
163 auto get_use_size =
164 [](const MemorySpaceAssignment::BufferInterval& bi) -> int64_t {
165 int64_t use_size = 0;
166 for (const auto& use : bi.buffer->GetUses()) {
167 use_size += ShapeUtil::ElementsInRecursive(use.instruction->shape());
168 }
169 return use_size;
170 };
171 return get_use_size(x) > get_use_size(y);
172 }
173 return x.size > y.size;
174 };
175 auto& compare = options.default_cross_program_prefetch_heuristic &&
176 options.buffer_interval_compare
177 ? *options.buffer_interval_compare
178 : size_compare;
179
180 auto best_candidate = absl::c_min_element(candidates, compare);
181 if (best_candidate == candidates.end()) {
182 return std::nullopt;
183 }
184 VLOG(3) << "Cross-program prefetch candidate picked: "
185 << best_candidate->buffer->ToString();
186 return *best_candidate;
187 }
188
189 Status InsertInstructionAndEnsureOperandsInserted(
190 HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
191 absl::flat_hash_set<HloInstruction*>* inserted_instructions);
192
193 // Insert an instruction into the schedule, and make sure its dependencies
194 // (operands) are already in the schedule. If they are not, insert those
195 // operands before the instruction.
196 Status EnsureInstructionAndOperandsInserted(
197 HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
198 absl::flat_hash_set<HloInstruction*>* inserted_instructions) {
199 if (inserted_instructions->contains(new_instruction)) {
200 return OkStatus();
201 }
202 return InsertInstructionAndEnsureOperandsInserted(
203 new_instruction, new_sequence, inserted_instructions);
204 }
205
206 // Same as above, but does not check if the instruction is already inserted.
207 // This is used when the caller already knows that the instruction has not
208 // been inserted yet, to speed up compilation.
209 Status InsertInstructionAndEnsureOperandsInserted(
210 HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
211 absl::flat_hash_set<HloInstruction*>* inserted_instructions) {
212 for (HloInstruction* operand : new_instruction->operands()) {
213 // CopyStart/CopyDone dependencies should always have been inserted already;
214 // it is a red flag if they have not been.
215 if (operand->opcode() == HloOpcode::kCopyStart ||
216 operand->opcode() == HloOpcode::kCopyDone) {
217 TF_RET_CHECK(inserted_instructions->contains(operand))
218 << "Inserted instruction " << new_instruction->ToString()
219 << " has un-inserted dependency: " << operand->ToString();
220 continue;
221 }
222 TF_RETURN_IF_ERROR(EnsureInstructionAndOperandsInserted(
223 operand, new_sequence, inserted_instructions));
224 }
225 VLOG(4) << "inserting: " << new_instruction->ToShortString();
226 new_sequence->push_back(new_instruction);
227 TF_RET_CHECK(inserted_instructions->insert(new_instruction).second);
228 return OkStatus();
229 }
230
231 std::string UsesToString(const std::vector<HloUse>& uses) {
232 if (uses.empty()) {
233 return "none";
234 }
235 std::vector<std::string> uses_str;
236 uses_str.reserve(uses.size());
237 for (const auto& use : uses) {
238 uses_str.push_back(use.ToString());
239 }
240 return absl::StrJoin(uses_str, ",");
241 }
242
243 } // namespace
244
245 /*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
246 MemorySpaceAssignmentCostAnalysis::Create(const HloCostAnalysis& cost_analysis,
247 const Options& options,
248 const HloModule& module) {
249 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
250 TF_ASSIGN_OR_RETURN(auto hlo_live_range,
251 HloLiveRange::Run(module.schedule(), *alias_analysis,
252 module.entry_computation()));
253 auto call_graph = CallGraph::Build(&module);
254 return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
255 cost_analysis, options, std::move(alias_analysis),
256 std::move(hlo_live_range), std::move(call_graph)));
257 }
258
259 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
260 const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
261 MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
262 float elapsed_time_due_to_compute =
263 GetInstructionElapsedDueToCompute(instruction);
264 float elapsed_time_due_to_memory =
265 GetInstructionElapsedDueToMemory(instruction);
266 if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
267 // Memory bound, return how much alternate memory is better.
268 float while_nest_multiplier;
269 if (cache) {
270 // If there is a cache provided, memoize the while nest multiplier.
271 auto it = cache->while_nest_multiplier.find(&instruction);
272 if (it != cache->while_nest_multiplier.end()) {
273 while_nest_multiplier = it->second;
274 } else {
275 while_nest_multiplier = IPow<float>(
276 options_.xla_tpu_memory_space_assignment_while_execution_count,
277 CalculateComputationNestLevel(&instruction,
278 /*while_only=*/true));
279 cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
280 }
281 } else {
282 while_nest_multiplier = IPow<float>(
283 options_.xla_tpu_memory_space_assignment_while_execution_count,
284 CalculateComputationNestLevel(&instruction,
285 /*while_only=*/true));
286 }
287 return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
288 while_nest_multiplier;
289 } else {
290 // Compute bound, return how far off we are from memory boundedness.
291 return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
292 }
293 }
294
295 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
296 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
297 MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
298 float alternate_mem_benefit =
299 GetAlternateMemoryBenefit(interval.buffer->defining_position(), cache);
300
301 for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
302 interval.buffer->defining_position().instruction,
303 interval.buffer->defining_position().index)) {
304 for (const HloValue* value : buffer->values()) {
305 for (const HloUse& use : value->GetUses()) {
306 // We look inside the called computations of while and conditional, so
307 // don't use the benefit of while and conditional directly.
308 if (use.instruction->opcode() == HloOpcode::kWhile ||
309 use.instruction->opcode() == HloOpcode::kConditional) {
310 continue;
311 }
312 float use_alternate_mem_benefit = GetAlternateMemoryBenefit(use, cache);
313 // If the benefit is positive (memory bound), add it to this buffer's
314 // benefit. If the benefit is negative (compute bound), calculate the
315 // maximum.
316 if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
317 alternate_mem_benefit += use_alternate_mem_benefit;
318 } else {
319 alternate_mem_benefit =
320 std::max(alternate_mem_benefit, use_alternate_mem_benefit);
321 }
322 }
323 }
324 }
325
326 // Penalize larger buffers by dividing the benefit by the square root of the
327 // size. Empirically, we observed this resulted in better performance compared
328 // to dividing by the size.
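// (For example, doubling a buffer's size divides its benefit by sqrt(2),
// roughly 1.41, rather than by 2.)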
329 return alternate_mem_benefit / std::sqrt(interval.size);
330 }
331
332 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
333 const HloPosition& position,
334 MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
335 return GetAlternateMemoryBenefit(
336 *position.instruction,
337 GetInstructionElapsedDueToMemory(
338 *position.instruction,
339 /*operands_in_alternate_mem=*/{},
340 /*outputs_in_alternate_mem=*/{position.index}),
341 cache);
342 }
343
344 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
345 const HloUse& use, MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
346 return GetAlternateMemoryBenefit(
347 *use.instruction,
348 GetInstructionElapsedDueToMemory(
349 *use.instruction,
350 /*operands_in_alternate_mem=*/{std::make_pair(use.operand_number,
351 use.operand_index)}),
352 cache);
353 }
354
355 int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
356 const HloInstruction* instruction, bool while_only) const {
357 int nest_level = 0;
358 const HloComputation* computation = instruction->parent();
359 while (!computation->IsEntryComputation()) {
360 auto& node = call_graph_->GetNode(computation);
361 auto callsites = node.caller_callsites();
362 CHECK(node.computation()->IsAsyncComputation() || callsites.size() == 1)
363 << "The module is not flattened!";
364 auto& callsite = callsites[0];
365 if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
366 ++nest_level;
367 }
368 computation = callsite.instruction()->parent();
369 }
370 return nest_level;
371 }
372
373 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
374 const HloInstruction& instruction) const {
375 return std::max(
376 cost_analysis_.flop_count(instruction) /
377 cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
378 cost_analysis_.transcendental_count(instruction) /
379 cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
380 }
381
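// Illustrative example (hypothetical numbers): if an instruction accesses
// 100 MiB in total, 40 MiB of which are in alternate memory, with an alternate
// memory bandwidth of 400 GiB/s and a default memory bandwidth of 100 GiB/s,
// the memory-bound elapsed time is 40MiB/(400GiB/s) + 60MiB/(100GiB/s), i.e.
// roughly 0.1ms + 0.59ms = 0.68ms.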
382 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
383 const HloInstruction& instruction,
384 absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
385 absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
386 float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
387 float bytes_accessed_from_alternate_mem = 0.0;
388 for (auto& operand : operands_in_alternate_mem) {
389 float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
390 instruction, operand.first, operand.second);
391 bytes_accessed_from_alternate_mem += operand_bytes_accessed;
392 }
393
394 for (auto& shape_idx : outputs_in_alternate_mem) {
395 float output_bytes_accessed =
396 cost_analysis_.output_bytes_accessed(instruction, shape_idx);
397 bytes_accessed_from_alternate_mem += output_bytes_accessed;
398 }
399 float elapsed_due_to_alternate_mem =
400 bytes_accessed_from_alternate_mem /
401 options().alternate_mem_bandwidth_bytes_per_second;
402 float elapsed_due_to_default_mem =
403 (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
404 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
405 return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
406 }
407
408 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
409 const HloInstruction& instruction,
410 IsInAlternateMemoryFun is_in_alternate_mem) const {
411 float total_bytes_accessed = cost_analysis_.bytes_accessed(instruction);
412 float bytes_accessed_from_alternate_mem = 0.0;
413 for (int operand_num = 0; operand_num < instruction.operand_count();
414 ++operand_num) {
415 ShapeUtil::ForEachSubshape(
416 instruction.operand(operand_num)->shape(),
417 [&](const Shape& subshape, const ShapeIndex& index) {
418 if (!subshape.IsArray()) {
419 return;
420 }
421 if (is_in_alternate_mem(operand_num, index, subshape)) {
422 bytes_accessed_from_alternate_mem +=
423 cost_analysis_.operand_bytes_accessed(instruction, operand_num,
424 index);
425 }
426 });
427 }
428 ShapeUtil::ForEachSubshape(instruction.shape(), [&](const Shape& subshape,
429 const ShapeIndex& index) {
430 if (!subshape.IsArray()) {
431 return;
432 }
433 if (is_in_alternate_mem(/*operand_num=*/std::nullopt, index, subshape)) {
434 bytes_accessed_from_alternate_mem +=
435 cost_analysis_.output_bytes_accessed(instruction, index);
436 }
437 });
438 float elapsed_due_to_alternate_mem =
439 bytes_accessed_from_alternate_mem /
440 options().alternate_mem_bandwidth_bytes_per_second;
441 float elapsed_due_to_default_mem =
442 (total_bytes_accessed - bytes_accessed_from_alternate_mem) /
443 cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
444 return elapsed_due_to_alternate_mem + elapsed_due_to_default_mem;
445 }
446
447 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
448 const HloInstruction& instruction) const {
449 return std::max(GetInstructionElapsedDueToCompute(instruction),
450 GetInstructionElapsedDueToMemory(instruction));
451 }
452
453 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
454 const HloInstruction& instruction,
455 absl::Span<const std::pair<int64_t, ShapeIndex>> operands_in_alternate_mem,
456 absl::Span<const ShapeIndex> outputs_in_alternate_mem) const {
457 return std::max(
458 GetInstructionElapsedDueToCompute(instruction),
459 GetInstructionElapsedDueToMemory(instruction, operands_in_alternate_mem,
460 outputs_in_alternate_mem));
461 }
462
463 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
464 const HloInstruction& instruction,
465 IsInAlternateMemoryFun is_in_alternate_mem) const {
466 return std::max(
467 GetInstructionElapsedDueToCompute(instruction),
468 GetInstructionElapsedDueToMemory(instruction, is_in_alternate_mem));
469 }
470
471 float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
472 const Shape& shape) const {
473 int64_t size_in_bytes = cost_analysis_.GetShapeSize(shape);
474 return static_cast<float>(size_in_bytes) /
475 options().async_copy_bandwidth_bytes_per_second;
476 }
477
478 int64_t MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
479 return hlo_live_range_->schedule_end_time();
480 }
481
482 bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
483 const Shape& shape, int64_t start_time, int64_t end_time) const {
484 return end_time - start_time <= max_overlap_count_;
485 }
486
487 int64_t InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
488 const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
489 return std::min(start_time + min_overlap_count_, latest_end_time);
490 }
491
492 int64_t InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
493 const Shape& shape, int64_t start_time, int64_t end_time,
494 const HloUse* use) const {
495 return end_time - min_overlap_count_;
496 }
497
498 int64_t InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
499 const Shape& shape, int64_t earliest_prefetch_start_time,
500 int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
501 return std::max(earliest_prefetch_start_time,
502 prefetch_end_time - max_overlap_count_);
503 }
504
505 int64_t InstructionCountPrefetchIntervalPicker::EstimatedPrefetchEndTime(
506 const Shape& shape, int64_t start_time, int64_t end_time) const {
507 // For testing, assume the end time is the estimated prefetch end time.
508 return end_time;
509 }
510
511 float InstructionCountPrefetchIntervalPicker::GetLogicalIntervalElapsed(
512 int64_t start_time, int64_t end_time) const {
513 // For testing, just assume every HLO takes 1 second.
514 return static_cast<float>(end_time - start_time - 1);
515 }
516
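// The instruction-count picker iterates over candidate prefetch start times
// in increasing order, starting max_overlap_count_ instructions before the
// use (clamped to the earliest allowed start time) and stopping once
// min_overlap_count_ or fewer instructions remain before the use.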
517 void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
518 int64_t start_time,
519 int64_t end_time) {
520 end_time_ = end_time;
521 const Shape& shape = ShapeUtil::GetSubshape(
522 use.instruction->operand(use.operand_number)->shape(), use.operand_index);
523 current_prefetch_time_ =
524 PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
525 }
526
527 int64_t InstructionCountPrefetchIntervalPicker::Next() {
528 CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
529 "Done() is false";
530 return current_prefetch_time_++;
531 }
532
533 bool InstructionCountPrefetchIntervalPicker::Done() const {
534 return end_time_ - current_prefetch_time_ <= min_overlap_count_;
535 }
536
537 int64_t InstructionCountPrefetchIntervalPicker::latest_time() const {
538 return end_time_ - min_overlap_count_ - 1;
539 }
540
541 std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
542 return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
543 }
544
545 std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
546 const Shape& shape, int64_t start_time, int64_t end_time) const {
547 return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
548 }
549
550 CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
551 const MemorySpaceAssignmentCostAnalysis& cost_analysis,
552 float min_overlap_to_async_copy_ratio,
553 float preferred_overlap_to_async_copy_ratio,
554 float max_overlap_to_mem_size_async_copy_ratio, int64_t mem_size_bytes)
555 : while_nest_level_(
556 cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
557 computation_nest_level_(
558 cost_analysis.hlo_live_range().instruction_schedule().size() + 1, 0),
559 cost_analysis_(cost_analysis),
560 min_overlap_to_async_copy_ratio_(min_overlap_to_async_copy_ratio),
561 preferred_overlap_to_async_copy_ratio_(
562 preferred_overlap_to_async_copy_ratio),
563 max_async_copy_elapsed_(
564 cost_analysis_.GetAsyncCopyElapsed(
565 ShapeUtil::MakeShape(S32, {mem_size_bytes / 4})) *
566 max_overlap_to_mem_size_async_copy_ratio) {
567 instruction_schedule_ =
568 &cost_analysis_.hlo_live_range().instruction_schedule();
569
570 // Create a vector of elapsed times and while nesting levels of HLO
571 // instructions. The elapsed times are multiplied by
572 // pow(while_execution_count, nest_level) to account for executing the HLOs
573 // multiple times in while loops.
574 std::vector<float> instructions_elapsed_time(
575 instruction_schedule_->size() + 1, 0.0);
576 int max_while_nest_level = 0;
577 for (const auto& instruction_and_logical_time : *instruction_schedule_) {
578 // To avoid double counting, don't include the elapsed time of while and
579 // conditional HLOs.
580 const HloInstruction* instruction = instruction_and_logical_time.first;
581 int64_t logical_time = instruction_and_logical_time.second;
582 if (logical_time >= instructions_elapsed_time.size()) {
583 instructions_elapsed_time.resize(logical_time + 1, 0.0);
584 while_nest_level_.resize(logical_time + 1, 0);
585 }
586 int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
587 instruction_and_logical_time.first, /*while_only=*/true);
588 while_nest_level_[logical_time] = while_nest_level;
589 max_while_nest_level = std::max(max_while_nest_level, while_nest_level);
590 int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
591 instruction_and_logical_time.first, /*while_only=*/false);
592 computation_nest_level_[logical_time] = computation_nest_level;
593 if (instruction->opcode() == HloOpcode::kWhile ||
594 instruction->opcode() == HloOpcode::kConditional) {
595 continue;
596 }
597 float elapsed_time = cost_analysis_.GetInstructionElapsed(
598 *instruction_and_logical_time.first);
599 instructions_elapsed_time[logical_time] =
600 elapsed_time *
601 IPow<float>(cost_analysis_.options()
602 .xla_tpu_memory_space_assignment_while_execution_count,
603 while_nest_level);
604 }
605 // As an optimization, create a cumulative sum vector of elapsed time.
606 float cumsum = 0.0;
607 elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
608 for (float elapsed_time : instructions_elapsed_time) {
609 cumsum += elapsed_time;
610 elapsed_time_cumsum_.push_back(cumsum);
611 }
612 // To be able to accurately determine the minimum nest level between a start
613 // time and an end time efficiently, populate a data structure that stores the
614 // closest 'smaller' nest level change index.
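// For example (hypothetical values): if while_nest_level_ is {0, 1, 1, 0, 2},
// the loop below produces while_nest_level_change_ = {-1, 0, 0, -1, 3};
// position 4 (nest level 2) records index 3 as the closest earlier position
// with a smaller nest level.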
615 const int64_t size = instructions_elapsed_time.size();
616 CHECK_EQ(size, while_nest_level_.size());
617 std::vector<int> most_recent_by_level(while_nest_level_.size(), -1);
618 int prev_nest_level = 0;
619 int change_idx = -1;
620 while_nest_level_change_.reserve(size);
621 for (int i = 0; i < size; ++i) {
622 int nest_level = while_nest_level_[i];
623 if (nest_level != prev_nest_level) {
624 prev_nest_level = nest_level;
625 // Compute the last change index by choosing the most recent instruction
626 // index with a smaller nesting level. Note that even if there were several
627 // regions with other nest levels before, all of them may be at the same or
628 // a deeper level than this one, in which case we end up with -1; e.g., for
629 // nest level 0 there is nothing else to check.
630 change_idx = -1;
631 for (int smaller_level = 0; smaller_level < nest_level; smaller_level++) {
632 change_idx = std::max(change_idx, most_recent_by_level[smaller_level]);
633 }
634 }
635 most_recent_by_level[nest_level] = i;
636 while_nest_level_change_.push_back(change_idx);
637 }
638 for (int i = 0; i <= max_while_nest_level; ++i) {
639 while_execution_counts_.push_back(
640 IPow<float>(cost_analysis_.options()
641 .xla_tpu_memory_space_assignment_while_execution_count,
642 i));
643 }
644 }
645
646 float CostAnalysisPrefetchIntervalPicker::GetMaxElapsedInAlternateMemory(
647 float async_copy_elapsed) const {
648 return max_async_copy_elapsed_;
649 }
650
651 bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
652 const Shape& shape, int64_t start_time, int64_t end_time) const {
653 // Even though this method returns whether we allow the buffer in alternate
654 // memory _without_ asynchronous copies, calculate how long it would have taken
655 // to copy it and compare that to the elapsed time in the logical interval.
656 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
657 float logical_interval_elapsed =
658 GetLogicalIntervalElapsed(start_time, end_time);
659 return GetMaxElapsedInAlternateMemory(async_copy_elapsed) >
660 logical_interval_elapsed;
661 }
662
663 int64_t CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
664 const Shape& shape, int64_t start_time, int64_t latest_end_time) const {
665 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
666 int64_t end_time;
667 for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
668 float logical_interval_elapsed =
669 GetLogicalIntervalElapsed(start_time, end_time);
670 if (logical_interval_elapsed >=
671 (1 + kEvictionRetryMultiplier * retry_number_) *
672 preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed) {
673 break;
674 }
675 }
676 return end_time;
677 }
678
679 int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
680 const Shape& shape, int64_t start_time, int64_t end_time,
681 const HloUse* use) const {
682 // Find the latest time that satisfies min_overlap_to_async_copy_ratio_.
683 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
684 // If there is a use, estimate the time we would save by having this op in
685 // alternate memory.
686 float inst_elapsed_reduction = 0.0f;
687 if (use) {
688 float elapsed_time =
689 cost_analysis_.GetInstructionElapsed(*use->instruction);
690 float elapsed_time_in_alternate_mem =
691 cost_analysis_.GetInstructionElapsedInAlternateMemory(
692 *use->instruction,
693 /*operands_in_alternate_mem=*/
694 {std::make_pair(use->operand_number, use->operand_index)},
695 /*outputs_in_alternate_mem=*/{});
696 inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
697 }
698 int end_nest_level = computation_nest_level_[end_time];
699
700 // Find the latest time we're allowed to start prefetching.
701 float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed;
702 int latest_prefetch_time;
703 for (latest_prefetch_time = end_time - 1;
704 latest_prefetch_time >= start_time &&
705 (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
706 min_interval >
707 GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
708 inst_elapsed_reduction);
709 --latest_prefetch_time) {
710 }
711
712 return latest_prefetch_time;
713 }
714
715 int64_t CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
716 const Shape& shape, int64_t earliest_prefetch_start_time,
717 int64_t latest_prefetch_start_time, int64_t prefetch_end_time) const {
718 // Between the earliest and latest prefetch interval, find the interval
719 // closest to the preferred interval and start iterating from there.
720 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
721 int64_t preferred_prefetch_start_time = earliest_prefetch_start_time;
722 float preferred_interval =
723 preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed;
724 float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
725 prefetch_end_time);
726 int end_nest_level = computation_nest_level_[prefetch_end_time];
727 for (int64_t prefetch_start_time = earliest_prefetch_start_time + 1;
728 prefetch_start_time <= latest_prefetch_start_time;
729 ++prefetch_start_time) {
730 float interval =
731 GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
732 if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
733 std::abs(preferred_interval - interval) <
734 std::abs(preferred_interval - best_interval)) {
735 best_interval = interval;
736 preferred_prefetch_start_time = prefetch_start_time;
737 }
738 }
739 return preferred_prefetch_start_time;
740 }
741
742 int64_t CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
743 int64_t original_prefetch_end_time,
744 int64_t proposed_prefetch_end_time) const {
745 // Iterate towards the beginning until we find a suitable end time that is at
746 // the same computation nest level as the original prefetch end time.
747 int64_t original_nest_level =
748 computation_nest_level_[original_prefetch_end_time];
749 int64_t new_prefetch_end_time;
750 for (new_prefetch_end_time = proposed_prefetch_end_time;
751 computation_nest_level_[new_prefetch_end_time] != original_nest_level;
752 --new_prefetch_end_time) {
753 }
754 return new_prefetch_end_time;
755 }
756
757 int64_t CostAnalysisPrefetchIntervalPicker::EstimatedPrefetchEndTime(
758 const Shape& shape, int64_t start_time, int64_t end_time) const {
759 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
760 int64_t estimated_end_time;
761 for (estimated_end_time = start_time + 1; estimated_end_time < end_time;
762 ++estimated_end_time) {
763 float interval = GetLogicalIntervalElapsed(start_time, estimated_end_time);
764 if (interval >= async_copy_elapsed) {
765 break;
766 }
767 }
768 return estimated_end_time;
769 }
770
771 void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
772 int64_t start_time,
773 int64_t end_time) {
774 const Shape& shape = ShapeUtil::GetSubshape(
775 use.instruction->operand(use.operand_number)->shape(), use.operand_index);
776 // Find the earliest time that satisfies max_overlap_to_async_copy_ratio_.
777 async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
778 // Estimate the time we would save by having this op in alternate memory.
779 float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
780 float elapsed_time_in_alternate_mem =
781 cost_analysis_.GetInstructionElapsedInAlternateMemory(
782 *use.instruction, /*operands_in_alternate_mem=*/
783 {std::make_pair(use.operand_number, use.operand_index)},
784 /*outputs_in_alternate_mem=*/{});
785 inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
786 end_logical_time_ = end_time;
787 int end_nest_level = computation_nest_level_[end_logical_time_];
788
789 // Find the latest time we're allowed to start prefetching.
790 float min_interval = min_overlap_to_async_copy_ratio_ * async_copy_elapsed_;
791 latest_prefetch_time_ =
792 LatestPrefetchStartTime(shape, start_time, end_time, &use);
793
794 // Find the earliest time we're allowed to start prefetching.
795 float max_interval = GetMaxElapsedInAlternateMemory(async_copy_elapsed_);
796 for (earliest_prefetch_time_ = start_time;
797 earliest_prefetch_time_ < latest_prefetch_time_ &&
798 (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
799 max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
800 end_logical_time_));
801 ++earliest_prefetch_time_) {
802 }
803 if (earliest_prefetch_time_ > latest_prefetch_time_) {
804 // There is no available prefetch interval for the given start and end
805 // times. Set the iterators accordingly to ensure Done() returns true.
806 increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
807 decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
808 CHECK(Done());
809 return;
810 }
811
812 int64_t starting_prefetch_time = PreferredPrefetchStartTime(
813 shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
814 float preferred_interval =
815 preferred_overlap_to_async_copy_ratio_ * async_copy_elapsed_;
816 VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
817 << max_interval << " " << preferred_interval
818 << " prefetch time earliest/latest/starting = "
819 << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
820 << starting_prefetch_time;
821
822 increasing_prefetch_time_iterator_ = starting_prefetch_time;
823 decreasing_prefetch_time_iterator_ = starting_prefetch_time;
824 using_increasing_prefetch_time_iterator_ = true;
825 // Since both iterators start at the same position, call Next() once to
826 // advance one of the iterators.
827 Next();
828 }
829
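// Candidate prefetch start times are explored outward from the preferred
// start time chosen in Begin(): Next() alternates between the increasing and
// decreasing iterators, skipping positions whose computation nest level
// differs from that of the use.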
830 int64_t CostAnalysisPrefetchIntervalPicker::Next() {
831 CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
832 "Done() is false";
833 if (using_increasing_prefetch_time_iterator_) {
834 int64_t prefetch_time = increasing_prefetch_time_iterator_++;
835 while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
836 computation_nest_level_[increasing_prefetch_time_iterator_] !=
837 computation_nest_level_[end_logical_time_]) {
838 ++increasing_prefetch_time_iterator_;
839 }
840 if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
841 using_increasing_prefetch_time_iterator_ = false;
842 }
843 return prefetch_time;
844 } else {
845 int64_t prefetch_time = decreasing_prefetch_time_iterator_--;
846 while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
847 computation_nest_level_[decreasing_prefetch_time_iterator_] !=
848 computation_nest_level_[end_logical_time_]) {
849 --decreasing_prefetch_time_iterator_;
850 }
851 if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
852 using_increasing_prefetch_time_iterator_ = true;
853 }
854 return prefetch_time;
855 }
856 }
857
858 bool CostAnalysisPrefetchIntervalPicker::Done() const {
859 return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
860 decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
861 }
862
863 int64_t CostAnalysisPrefetchIntervalPicker::latest_time() const {
864 return latest_prefetch_time_;
865 }
866
867 void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
868 retry_number_ = retry_number;
869 }
870
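// Returns the minimum while nest level anywhere in [start_time, end_time].
// Rather than scanning the whole interval, this starts from end_time and
// repeatedly jumps to the closest earlier position with a smaller nest level
// via while_nest_level_change_ (see the constructor), stopping once the jump
// leaves the interval.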
871 int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
872 int64_t start_time, int64_t end_time) const {
873 int min_nest_level =
874 std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
875 int change_idx = while_nest_level_change_[end_time];
876 while (change_idx >= start_time) {
877 min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
878 change_idx = while_nest_level_change_[change_idx];
879 }
880 return min_nest_level;
881 }
882
883 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
884 int64_t start_time, int64_t end_time) const {
885 CHECK_LE(start_time, end_time);
886 if (start_time == end_time) {
887 return 0.0;
888 }
889 if (start_time < 0) {
890 start_time = 0;
891 }
892 // Since elapsed_time_cumsum_ is already weighted by the while loop nesting
893 // level, normalize the elapsed time by dividing by the nesting factor of
894 // the interval (start and end times).
895 int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
896 return (elapsed_time_cumsum_[end_time - 1] -
897 elapsed_time_cumsum_[start_time]) /
898 while_execution_counts_[interval_while_nest_level];
899 }
900
901 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
902 int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
903 ? increasing_prefetch_time_iterator_
904 : decreasing_prefetch_time_iterator_;
905 float logical_interval_elapsed = GetLogicalIntervalElapsed(
906 current_logical_prefetch_time, end_logical_time_);
907 return absl::StrCat(
908 "Async copy elapsed (s) = ", async_copy_elapsed_,
909 ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
910 ", logical interval elapsed (s) = ", logical_interval_elapsed,
911 ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
912 ")");
913 }
914
915 std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
916 const Shape& shape, int64_t start_time, int64_t end_time) const {
917 float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
918 float logical_interval_elapsed =
919 GetLogicalIntervalElapsed(start_time, end_time);
920 return absl::StrCat(
921 "Async copy elapsed (s) = ", async_copy_elapsed,
922 ", logical interval elapsed (s) = ", logical_interval_elapsed);
923 }
924
925 std::optional<float>
926 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
927 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
928 const {
929 return cost_analysis_.GetMemoryBoundedness(interval);
930 }
931
932 bool MemorySpaceAssignment::Allocation::operator==(
933 const MemorySpaceAssignment::Allocation& other) const {
934 return defining_position() == other.defining_position() &&
935 uses() == other.uses() && memory_space() == other.memory_space() &&
936 chunk() == other.chunk() && start_time() == other.start_time() &&
937 end_time() == other.end_time() &&
938 earliest_available_time() == other.earliest_available_time() &&
939 is_copy_allocation() == other.is_copy_allocation() &&
940 is_scoped_allocation() == other.is_scoped_allocation();
941 }
942
943 bool MemorySpaceAssignment::CopyAllocation::operator==(
944 const MemorySpaceAssignment::CopyAllocation& other) const {
945 return static_cast<const Allocation&>(*this) ==
946 static_cast<const Allocation&>(other) &&
947 copy_done_schedule_before() == other.copy_done_schedule_before() &&
948 copy_start_schedule_after() == other.copy_start_schedule_after() &&
949 copy_start() == other.copy_start() && copy_done() == other.copy_done();
950 }
951
952 std::string MemorySpaceAssignment::AllocationValue::ToString() const {
953 std::string out = absl::StrCat("computation = ", computation()->name());
954 absl::StrAppend(&out,
955 (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
956 absl::StrAppend(&out, "\n position:\n");
957 absl::StrAppend(&out, " ", defining_position_.ToString(), "\n");
958 absl::StrAppend(&out, " uses:\n");
959 for (const Use& use : uses_) {
960 absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n");
961 }
962 return out;
963 }
964
965 std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
966 return absl::StrCat("computation = ", computation()->name(),
967 ", position = ", defining_position_.ToString(),
968 ", value = ", value_->ToShortString(),
969 (requires_contiguous_allocation_ ? " (cont alloc)" : ""));
970 }
971
972 AlternateMemoryBestFitHeap::AlternateMemoryBestFitHeap(
973 MemorySpaceAssignment::AllocationSequence* allocations,
974 const Options& options, const HloAliasAnalysis& alias_analysis,
975 const HloLiveRange& hlo_live_range)
976 : GlobalDecreasingSizeBestFitHeap(options.alignment_in_bytes),
977 allocations_(allocations),
978 options_(options),
979 alias_analysis_(alias_analysis),
980 hlo_live_range_(hlo_live_range),
981 peak_memory_usage_(hlo_live_range.schedule_end_time() + 1) {
982 // Override buffer interval compare if provided.
983 if (options.buffer_interval_compare) {
984 buffer_interval_compare_ = *options.buffer_interval_compare;
985 }
986
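// Seed the per-logical-time copy resources used by the asynchronous copy
// trackers below: 1.0 per schedule position by default, or the instruction's
// elapsed time when a cost analysis is available (while and conditional ops
// contribute 0).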
987 std::vector<float> initial_resources(hlo_live_range.schedule_end_time(), 1.0);
988 if (options.cost_analysis) {
989 const std::vector<HloInstruction*>& flattened_instructions =
990 hlo_live_range.flattened_instruction_sequence().instructions();
991 for (int i = 0; i < flattened_instructions.size(); ++i) {
992 const HloInstruction* inst = flattened_instructions[i];
993 if (inst->opcode() == HloOpcode::kWhile ||
994 inst->opcode() == HloOpcode::kConditional) {
995 initial_resources[i] = 0;
996 } else {
997 initial_resources[i] =
998 options.cost_analysis->GetInstructionElapsed(*inst);
999 }
1000 VLOG(2) << "Initial resource[" << i << "] = " << initial_resources[i]
1001 << " (" << inst->name() << ")";
1002 }
1003 }
1004 prefetch_async_copy_resource_ = AsynchronousCopyResource(initial_resources);
1005 eviction_async_copy_resource_ = AsynchronousCopyResource(initial_resources);
1006 }
1007
1008 void AlternateMemoryBestFitHeap::CreateAllocationValues(
1009 const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
1010 std::vector<AllocationValue>& allocation_values) const {
1011 const HloValue* value = buffer_interval.buffer;
1012 VLOG(3) << "Creating AllocationValues for: " << value->ToString();
1013
1014 // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
1015 // positions. We create an AllocationValue object for each non-trivial
1016 // position, and for each AllocationValue object, we create an
1017 // AllocationSequence consisting of one or more Allocation objects. The reason
1018 // we exclude the trivial positions from AllocationValue is that Allocation
1019 // objects have special support for tuples and bitcasts.
1020 const absl::flat_hash_map<const HloInstruction*, int64_t>&
1021 instruction_schedule = hlo_live_range_.instruction_schedule();
1022 std::vector<HloPosition> positions;
1023 for (const HloPosition& position : value->positions()) {
1024 const HloInstruction* instruction = position.instruction;
1025 if (instruction->opcode() != HloOpcode::kGetTupleElement &&
1026 instruction->opcode() != HloOpcode::kTuple &&
1027 instruction->opcode() != HloOpcode::kBitcast) {
1028 positions.push_back(position);
1029 }
1030 }
1031 absl::c_stable_sort(positions,
1032 [&](const HloPosition& pos1, const HloPosition& pos2) {
1033 return instruction_schedule.at(pos1.instruction) <
1034 instruction_schedule.at(pos2.instruction);
1035 });
1036
1037 // Create an AllocationValue for each non-trivial position.
1038 absl::flat_hash_set<const HloComputation*> computations;
1039 int beginning_idx = allocation_values.size();
1040 for (int i = 0; i < positions.size(); ++i) {
1041 const HloPosition& position = positions.at(i);
1042 allocation_values.emplace_back(value, position, buffer_interval.size);
1043 }
1044
1045 std::vector<HloUse> uses(value->GetUses().begin(), value->GetUses().end());
1046 absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
1047 return instruction_schedule.at(use1.instruction) <
1048 instruction_schedule.at(use2.instruction);
1049 });
1050
1051 // Associate each use with an AllocationValue. Each AllocationValue contains a
1052 // position and the uses in the same computation. Furthermore, if the original
1053 // HloValue had multiple non-trivial positions in the same computation, each of
1054 // those gets its own AllocationValue as well. We split these HloValues so
1055 // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
1056 // point to the latest position, and we then replace the use's operand with a
1057 // CopyStart/CopyDone that takes the latest position as its operand.
1058 for (const HloUse& use : uses) {
1059 int64_t use_time = instruction_schedule.at(use.instruction);
1060 HloComputation* use_computation = use.instruction->parent();
1061
1062 AllocationValue* last_allocation_value = nullptr;
1063 for (int i = beginning_idx; i < allocation_values.size(); ++i) {
1064 AllocationValue* allocation_value = &allocation_values.at(i);
1065 if (HloDataflowAnalysis::IsAsynchronousOperationDone(
1066 use.instruction->opcode())) {
1067 if (allocation_value->defining_instruction() ==
1068 use.instruction->operand(0)) {
1069 last_allocation_value = allocation_value;
1070 }
1071 } else if (!HloDataflowAnalysis::IsAsynchronousOperationStart(
1072 allocation_value->defining_instruction()->opcode()) &&
1073 allocation_value->computation() == use_computation &&
1074 instruction_schedule.at(
1075 allocation_value->defining_position().instruction) <
1076 use_time) {
1077 last_allocation_value = allocation_value;
1078 }
1079 }
1080 CHECK(last_allocation_value != nullptr);
1081 last_allocation_value->AddUse(use, use_time);
1082 }
1083
1084 for (int i = beginning_idx; i < allocation_values.size(); ++i) {
1085 AllocationValue& allocation_value = allocation_values.at(i);
1086 if (HloDataflowAnalysis::IsAsynchronousOperationStart(
1087 allocation_value.defining_instruction()->opcode())) {
1088 CHECK_EQ(allocation_value.uses().size(), 1);
1089 CHECK(HloDataflowAnalysis::IsAsynchronousOperationDone(
1090 allocation_value.uses().at(0).hlo_use.instruction->opcode()));
1091 VLOG(3) << "Mark " << allocation_value.ToShortString()
1092 << " to require contiguous allocation.";
1093 allocation_value.set_requires_contiguous_allocation(true);
1094 }
1095 VLOG(3) << "Created allocation value: "
1096 << allocation_values.at(i).ToString();
1097 }
1098 }
1099
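// Populates the aliases of each use in |allocation_values|: a use aliases any
// allocation value defined by the using instruction itself (its operand and
// output must alias), any allocation value defined by a parameter of a
// computation called by the use, and, for kWhile, the while body's root at the
// use's operand index.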
1100 void AlternateMemoryBestFitHeap::FindAliases(
1101 std::vector<AllocationValue>* allocation_values) const {
1102 absl::flat_hash_map<const HloInstruction*,
1103 std::vector<const AllocationValue*>>
1104 values_by_defining_inst;
1105 for (AllocationValue& value : *allocation_values) {
1106 values_by_defining_inst[value.defining_instruction()].push_back(&value);
1107 }
1108 auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
1109 AllocationValue::Use* use) {
1110 auto aliased_values_it = values_by_defining_inst.find(instruction);
1111 if (aliased_values_it != values_by_defining_inst.end()) {
1112 for (const AllocationValue* aliased_value : aliased_values_it->second) {
1113 VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString()
1114 << " to " << aliased_value->ToShortString();
1115 use->aliases.push_back(aliased_value->defining_position());
1116 }
1117 }
1118 };
1119
1120 for (AllocationValue& value : *allocation_values) {
1121 for (AllocationValue::Use& use : value.uses()) {
1122 // Find any aliases with the instruction itself (operand and output must
1123 // alias).
1124 maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
1125
1126 // Find any aliases with the parameters of called computations.
1127 for (const HloComputation* called_computation :
1128 use.hlo_use.instruction->called_computations()) {
1129 for (const HloInstruction* parameter_instruction :
1130 called_computation->parameter_instructions()) {
1131 maybe_add_alias_with_instruction(parameter_instruction, &use);
1132 }
1133 }
1134
1135 // Special case for kWhile: the root of the body computation must alias as
1136 // well.
1137 if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1138 HloPosition root_alias{
1139 use.hlo_use.instruction->while_body()->root_instruction(),
1140 use.hlo_use.operand_index};
1141 VLOG(3) << "Adding while body root aliasing for use "
1142 << use.hlo_use.ToString() << " to " << root_alias;
1143 use.aliases.push_back(root_alias);
1144 }
1145 }
1146 }
1147 }
1148
1149 std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
1150 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
1151 const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
1152 std::vector<const BufferInterval*> colocated_intervals;
1153 std::vector<const BufferInterval*> worklist = {&interval};
1154 while (!worklist.empty()) {
1155 const BufferInterval* item = worklist.back();
1156 worklist.pop_back();
1157 colocated_intervals.push_back(item);
1158 for (const HloValue* buffer_colocated : item->colocations) {
1159 worklist.push_back(&buffer_intervals_.at(buffer_colocated));
1160 }
1161 }
1162
1163 absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
1164 const BufferInterval* y) {
1165 return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
1166 });
1167 return colocated_intervals;
1168 }
1169
1170 bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
1171 const AllocationValue& value, const HloUse& use) const {
1172 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1173 if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
1174 return false;
1175 }
1176 if (use.instruction->opcode() == HloOpcode::kWhile) {
1177 HloComputation* while_body = use.instruction->while_body();
1178
1179 // We don't want to allocate this buffer in alternate memory if it will be
1180 // evicted anyway. Find out if it has an early use or a late definition that
1181 // would make it worthwhile to keep in the alternate memory.
1182 HloValue* parameter_value =
1183 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1184 while_body->parameter_instruction(0), use.operand_index);
1185 int64_t parameter_time =
1186 instruction_schedule.at(while_body->parameter_instruction(0));
1187 int64_t root_time = instruction_schedule.at(while_body->root_instruction());
1188 int64_t min_use_time = root_time;
1189 for (const HloUse& parameter_use : parameter_value->GetUses()) {
1190 int64_t use_time = instruction_schedule.at(parameter_use.instruction);
1191 if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
1192 parameter_use.instruction->opcode() != HloOpcode::kTuple &&
1193 parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
1194 use_time > parameter_time) {
1195 min_use_time = std::min(min_use_time, use_time);
1196 }
1197 }
1198 // If there is no use of this buffer inside the while loop, there is no need
1199 // to allocate it in the loop.
1200 if (min_use_time == root_time) {
1201 VLOG(4) << "While allocation not allowed in alternate memory. "
1202 << "use time = " << min_use_time << ", root time = " << root_time;
1203 return false;
1204 }
1205 const Shape& shape = parameter_value->shape();
1206 // Allow the buffer in alternate memory if the buffer has a short live range
1207 // either at the beginning or end of the while loop body.
1208 if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
1209 shape, parameter_time, min_use_time)) {
1210 VLOG(4) << "While allocation not allowed in alternate memory. "
1211 << "use time = " << min_use_time << ", root time = " << root_time;
1212 return false;
1213 }
1214 // Check if there is a required assignment for the while loop output.
1215 HloValue* while_value =
1216 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1217 use.instruction, use.operand_index);
1218 int64_t while_time = instruction_schedule.at(use.instruction);
1219 auto existing_required_assignment =
1220 RequiredMemoryAssignmentAt(while_value, while_time);
1221 if (existing_required_assignment &&
1222 existing_required_assignment->memory_space == MemorySpace::kDefault) {
1223 VLOG(4) << "While allocation not allowed in alternate memory because "
1224 "there is a required default memory assignment.";
1225 return false;
1226 }
1227 } else if (use.instruction->opcode() == HloOpcode::kConditional) {
1228 // For any use of this conditional (the same value might be passed into
1229 // multiple called computations), determine if the parameter->first use
1230 // dependency is short.
1231 int64_t conditional_time = instruction_schedule.at(use.instruction);
1232 for (const AllocationValue::Use& other_use : value.uses()) {
1233 if (other_use.hlo_use.instruction != use.instruction) {
1234 continue;
1235 }
1236 HloComputation* called_computation =
1237 use.instruction->called_computations().at(
1238 other_use.hlo_use.operand_number - 1);
1239 const HloInstruction* parameter_instruction =
1240 called_computation->parameter_instruction(0);
1241 HloValue* parameter_value =
1242 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1243 parameter_instruction, other_use.hlo_use.operand_index);
1244 int64_t parameter_time = instruction_schedule.at(parameter_instruction);
1245 int64_t min_use_time = conditional_time;
1246 for (const HloUse& parameter_use : parameter_value->GetUses()) {
1247 if (parameter_use.instruction->parent() == called_computation &&
1248 parameter_use.instruction->opcode() !=
1249 HloOpcode::kGetTupleElement &&
1250 parameter_use.instruction->opcode() != HloOpcode::kTuple &&
1251 parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
1252 min_use_time = std::min(
1253 min_use_time, instruction_schedule.at(parameter_use.instruction));
1254 }
1255 }
1256 if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
1257 parameter_value->shape(), parameter_time, min_use_time)) {
1258 VLOG(4) << "Conditional allocation allowed in alternate memory for "
1259 "computation = "
1260 << called_computation->name()
1261 << ", parameter time = " << parameter_time
1262 << ", min use time = " << min_use_time;
1263 return true;
1264 } else {
1265 VLOG(4) << "Conditional allocation not allowed in alternate memory for "
1266 "computation = "
1267 << called_computation->name()
1268 << ", parameter time = " << parameter_time
1269 << ", min use time = " << min_use_time;
1270 }
1271 }
1272 return false;
1273 }
1274
1275 return true;
1276 }
1277
1278 namespace {
1279 // Columns in buffer information:
1280 // buffer_id: int. This value can be used to match the allocation in
1281 // allocation information.
1282 // buffer_name: string.
1283 // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
1284 // thought it would be beneficial to put this in the alternate memory. The
1285 // higher the value, the more it is memory bound.
1286 // size: int. In bytes.
1287 // definition_time: int. Logical time this value was defined in the schedule.
1288 // use_times: string. This is a semicolon-separated list of integers for all
1289 // the use times.
1290 // use_names: string. This is a semicolon-separated list of string
1291 // representation of uses.
1292 // is_scoped: int. A value of 1 indicates that the buffer is a scoped
1293 // allocation.
1294 constexpr absl::string_view kBufferInfoColumnNames =
1295 "buffer_id,buffer_name,alt_mem_benefit,size,definition_time,use_times,use_"
1296 "names,is_scoped";
1297 } // namespace
1298
1299 void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
1300 const AlternateMemoryBestFitHeap::BufferInterval& interval,
1301 std::string* debug_str) const {
1302 if (debug_str->empty()) {
1303 // Append the column names.
1304 absl::StrAppend(debug_str, kBufferInfoColumnNames, "\n");
1305 }
1306 const HloBuffer& buffer =
1307 alias_analysis_.GetBufferContainingValue(*interval.buffer);
1308 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1309 int64_t definition_time =
1310 instruction_schedule.at(interval.buffer->defining_position().instruction);
1311 std::vector<std::pair<int64_t, std::string>> uses;
1312 for (const HloValue* value : buffer.values()) {
1313 for (const HloUse& use : value->GetUses()) {
1314 uses.push_back(
1315 {instruction_schedule.at(use.instruction), use.ToString()});
1316 }
1317 }
1318 absl::c_sort(uses);
1319 std::vector<int64_t> use_times;
1320 std::vector<std::string> use_names;
1321 use_times.reserve(uses.size());
1322 use_names.reserve(uses.size());
1323 for (const auto& use : uses) {
1324 use_times.push_back(use.first);
1325 use_names.push_back(use.second);
1326 }
1327
1328 absl::StrAppend(debug_str, buffer.id(), ",");
1329 absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
1330 auto alternate_memory_benefit =
1331 options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
1332 interval);
1333 absl::StrAppend(
1334 debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
1335 absl::StrAppend(debug_str, interval.size, ",");
1336 absl::StrAppend(debug_str, definition_time, ",");
1337 absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
1338 absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\",");
1339 absl::StrAppend(debug_str, "0"); // is_scoped
1340 absl::StrAppend(debug_str, "\n");
1341 }
1342
1343 void AlternateMemoryBestFitHeap::AppendScopedAllocationBufferInfoDebugString(
1344 const HloInstruction* instruction, int64_t time, int64_t size,
1345 std::string& debug_str) const {
1346 if (debug_str.empty()) {
1347 // Append the column names.
1348 absl::StrAppend(&debug_str, kBufferInfoColumnNames, "\n");
1349 }
1350 const HloBuffer& buffer = alias_analysis_.GetUniqueBufferAt(instruction);
1351
1352 // As a convention, we use negative values for scoped allocations.
1353 absl::StrAppend(&debug_str, -buffer.id(), ",");
1354 absl::StrAppend(&debug_str, "\"scoped allocation for ", instruction->name(),
1355 "\",");
1356 absl::StrAppend(&debug_str, 0, ","); // alt_mem_benefit
1357 absl::StrAppend(&debug_str, size, ",");
1358 absl::StrAppend(&debug_str, time, ",");
1359 absl::StrAppend(&debug_str, "\"\","); // use_times
1360 absl::StrAppend(&debug_str, "\"\","); // use_names
1361 absl::StrAppend(&debug_str, "1"); // is_scoped
1362 absl::StrAppend(&debug_str, "\n");
1363 }
1364
1365 void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
1366 const MemorySpaceAssignment::Allocation& allocation,
1367 std::string& debug_str) const {
1368 // Columns in allocation information:
1369 // buffer_id: int. This value can be used to match with the buffer info.
1370 // size: int. In bytes.
1371 // offset: int. In bytes.
1372 // start_time: int. Logical start time of the allocation.
1373 // end_time: int. Logical end time of the allocation.
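// Illustrative example of a single row (made-up values): 42,16384,32768,12,31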
1374 if (debug_str.empty()) {
1375 // Append the column names.
1376 absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
1377 }
1378 if (allocation.memory_space() == MemorySpace::kAlternate) {
1379 const HloPosition& position = allocation.defining_position();
1380 const HloBuffer& buffer =
1381 alias_analysis_.GetUniqueBufferAt(position.instruction, position.index);
1382 // As a convention, we use negative values for scoped allocations.
1383 absl::StrAppend(
1384 &debug_str,
1385 allocation.is_scoped_allocation() ? -buffer.id() : buffer.id(), ",");
1386 absl::StrAppend(&debug_str, allocation.chunk().size, ",");
1387 absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
1388 absl::StrAppend(&debug_str, allocation.start_time(), ",");
1389 absl::StrAppend(&debug_str, allocation.end_time(), "\n");
1390 }
1391 }
1392
1393 void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
1394 if (!options_.dump_fn) {
1395 return;
1396 }
1397 options_.dump_fn("bufferinfo", buffer_info_str_);
1398 options_.dump_fn("allocinfo", allocation_info_str_);
1399 }
1400
1401 HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
1402 if (options_.autotuning_config.has_value()) {
1403 CHECK_EQ((*options_.autotuning_config).size(), buffer_intervals_.size());
1404 }
1405
1406 AllocateReservedScopedAllocations();
1407 std::vector<BufferInterval> sorted_buffer_intervals =
1408 GetSortedBufferIntervals();
1409 memory_space_assignment::CustomizeSortedBufferInterval(
1410 options_.autotuning_config, sorted_buffer_intervals);
1411
1412 // Calculate the memory pressure for the buffers that can be assigned in the
1413 // alternate memory.
1414 memory_pressure_ = 0;
1415 for (auto& interval : sorted_buffer_intervals) {
1416 if (!interval.need_allocation ||
1417 !MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1418 interval) ||
1419 interval.size > available_heap_size()) {
1420 continue;
1421 }
1422 memory_pressure_ += interval.size;
1423 }
1424 VLOG(1) << "Memory pressure = " << memory_pressure_;
1425
1426 if (options_.enable_cross_program_prefetch) {
1427 std::optional<AlternateMemoryBestFitHeap::BufferInterval>
1428 prefetch_candidate = FindCrossProgramPrefetchCandidate(
1429 alias_analysis_, hlo_live_range_, options_);
1430 if (prefetch_candidate) {
1431 HloModule* module =
1432 prefetch_candidate->buffer->instruction()->GetModule();
1433 AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
1434 }
1435 }
1436
1437 VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
1438 << options_.max_size_in_bytes;
1439
1440 AddInputAndOutputRequiredAssignments();
1441
1442 if (VLOG_IS_ON(3)) {
1443 VLOG(3) << "Flattened instruction sequence:";
1444 const auto& instruction_sequence =
1445 hlo_live_range_.flattened_instruction_sequence().instructions();
1446 for (int i = 0; i < instruction_sequence.size(); ++i) {
1447 VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
1448 << " " << instruction_sequence[i]->name();
1449 }
1450 }
1451
1452 for (const auto& interval : sorted_buffer_intervals) {
1453 auto colocated_intervals = GetSortedColocatedIntervals(interval);
1454 if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1455 // Increment the reserved part of alternate memory so that it is not
1456 // available for other buffers.
1457 reserved_in_bytes_ += options_.size_fn(*interval.buffer);
1458 }
1459 }
1460 VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;
1461
1462 for (auto& interval : sorted_buffer_intervals) {
1463 if (!interval.need_allocation) {
1464 continue;
1465 }
1466
1467 if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
1468 interval)) {
1469 continue;
1470 }
1471
1472 HloInstruction* inst = interval.buffer->instruction();
1473 HloModule* module = inst->GetModule();
1474
1475 // Don't intra-program prefetch a cross-program prefetched buffer.
1476 if (inst->opcode() == HloOpcode::kParameter &&
1477 absl::c_count(module->CrossProgramPrefetches(),
1478 std::make_pair(inst->parameter_number(),
1479 interval.buffer->index())) > 0) {
1480 VLOG(3) << "Skip " << interval.buffer->ToShortString()
1481 << " because it is cross-program prefetched.";
1482 continue;
1483 }
1484
1485 if (interval.size > available_heap_size()) {
1486 VLOG(3) << "Skip " << interval.buffer->ToShortString()
1487 << " because the buffer is larger than the heap size.";
1488 continue;
1489 }
1490
1491 auto colocated_intervals = GetSortedColocatedIntervals(interval);
1492
1493 if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
1494 VLOG(3) << "Interval " << interval.buffer->ToShortString()
1495 << " is reserved in the alternate memory.";
1496 for (const BufferInterval* colocated_interval : colocated_intervals) {
1497 const HloValue* value = colocated_interval->buffer;
1498 // Color all of the aliased reserved buffers here because reserved
1499 // alternate memory allocations will not have an entry in preset
1500 // allocations that is normally used for coloring.
1501 for (auto& position : value->positions()) {
1502 VLOG(4) << "Coloring " << position.ToString();
1503 Shape* shape = ShapeUtil::GetMutableSubshape(
1504 position.instruction->mutable_shape(), position.index);
1505 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
1506 << position.ToString();
1507 shape->mutable_layout()->set_memory_space(
1508 options_.alternate_memory_space);
1509 }
1510 }
1511 continue;
1512 }
1513
1514 if (colocated_intervals.size() > 1 &&
1515 !options_.allocate_across_sequential_calls) {
1516 VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
1517 << " because it aliases with another interval and "
1518 << " allocate_across_sequential_calls is false.";
1519 continue;
1520 }
1521
1522 if (!ConsumeFuel("memory_space_assignment", [&] {
1523 return absl::StrCat("Ran out of fuel at buffer: ",
1524 colocated_intervals[0]->buffer->ToShortString());
1525 })) {
1526 continue;
1527 }
1528
1529 if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
1530 // Only fill buffer_info_str_ if needed.
1531 AppendBufferInfoDebugString(interval, &buffer_info_str_);
1532 }
1533
1534 std::vector<AllocationValue> allocation_values;
1535 CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
1536 allocation_values);
1537
1538 // Retry allocating this value with larger limits if allocation fails.
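// If the failure is due to running out of alternate memory (or if repacking
// after every allocation was requested), repack the already-committed
// allocations (at most once per value, subject to the global repack budget)
// and, when the repacker modified the assignment, retry the same attempt
// without consuming a retry.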
1539 bool repacked = false;
1540 for (int retry_number = 0; retry_number < options_.max_retries;
1541 retry_number++) {
1542 AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
1543 options_.prefetch_interval_picker->SetRetryNumber(retry_number);
1544 Result result =
1545 AllocateAllocationValues(absl::MakeSpan(allocation_values));
1546 VLOG(2) << "Allocation result = "
1547 << absl::StrFormat("%x", static_cast<int>(result));
1548 if (result_requires_uncommit(result)) {
1549 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1550 VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
1551 } else if ((result_is(result, Result::kFailOutOfMemory) ||
1552 options_.repack_after_every_allocation) &&
1553 num_repacks_ < options_.max_repacks && !repacked) {
1554 UncommitPendingChunks(absl::MakeSpan(allocation_values));
1555 ++num_repacks_;
1556 repacked = true;
1557 CHECK_NE(options_.repacker, nullptr);
1558 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
1559 repack_allocation_blocks;
1560 ExportAllocationsForRepacking(repack_allocation_blocks);
1561 VLOG(2) << "Repacking.";
1562 auto repack_status =
1563 options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
1564 CHECK_EQ(repack_status.status(), OkStatus());
1565 VLOG(2) << "Repack complete. Modified = " << *repack_status;
1566 if (*repack_status) {
1567 ImportRepackedAllocations();
1568 --retry_number;
1569 }
1570 } else {
1571 FinalizeAllocations(absl::MakeSpan(allocation_values));
1572 break;
1573 }
1574 }
1575 }
1576
1577 if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
1578 for (auto& allocation : *allocations_) {
1579 // Only fill allocation_info_str_ if needed.
1580 AppendAllocationInfoDebugString(*allocation, allocation_info_str_);
1581 }
1582 }
1583
1584 VLOG(3) << "Debug buffer info: ";
1585 XLA_VLOG_LINES(3, buffer_info_str_);
1586 VLOG(3) << "Debug allocation info: ";
1587 XLA_VLOG_LINES(3, allocation_info_str_);
1588 DumpDebugStringsIfEnabled();
1589
1590 HeapSimulator::Result<HloValue> result;
1591 result.heap_size = result_.heap_size;
1592 result.heap_results.emplace_back(std::move(result_));
1593 return result;
1594 }
1595
1596 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1597 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1598 colocated_intervals) {
1599 // TODO(berkin): For now, place the phi values due to conditionals in
1600 // default memory.
1601 for (const BufferInterval* colocated_interval : colocated_intervals) {
1602 const HloValue* value = colocated_interval->buffer;
1603 for (const auto& position : value->positions()) {
1604 if (position.instruction->opcode() == HloOpcode::kConditional) {
1605 VLOG(3) << "Adding required assignment for condition output: "
1606 << value->ToShortString();
1607 AddRequiredAssignment(position.instruction, position.index,
1608 MemorySpace::kDefault);
1609 for (const HloComputation* called_computation :
1610 position.instruction->called_computations()) {
1611 AddRequiredAssignment(called_computation->root_instruction(),
1612 position.index, MemorySpace::kDefault);
1613 }
1614 }
1615 }
1616 }
1617 }
1618
1619 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1620 absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1621 colocated_intervals,
1622 std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1623 // Create AllocationValues for all the colocated intervals.
1624 for (const auto& colocated_interval : colocated_intervals) {
1625 CreateAllocationValues(*colocated_interval, allocation_values);
1626 }
1627 // Go through the AllocationValues and delete the ones that have identical
1628 // defining instructions and use instructions. This is useful for async
1629 // operations that can read and write to the same buffer, e.g., in-place
1630 // asynchronous collective permute. The AllocationValues that correspond to
1631 // collective-permute-start{0} (the input) and collective-permute-start{1}
1632 // (the output) refer to the same buffer by definition (since they are created
1633 // from colocated intervals). If we don't delete one of these values, then
1634 // when we try to allocate the AllocationValues, we would think they overlap.
1635 auto create_instruction_vector = [](const AllocationValue& allocation_value) {
1636 std::vector<const HloInstruction*> instruction_vector;
1637 instruction_vector.push_back(allocation_value.defining_instruction());
1638 for (const AllocationValue::Use& use : allocation_value.uses()) {
1639 instruction_vector.push_back(use.hlo_use.instruction);
1640 }
1641 return instruction_vector;
1642 };
1643 for (int i = 0; i < allocation_values.size() - 1; ++i) {
1644 for (int j = i + 1; j < allocation_values.size(); ++j) {
1645 const AllocationValue& allocation_value_1 = allocation_values[i];
1646 const AllocationValue& allocation_value_2 = allocation_values[j];
1647 if (create_instruction_vector(allocation_value_1) ==
1648 create_instruction_vector(allocation_value_2)) {
1649 VLOG(3) << "Allocation values " << allocation_value_1.ToShortString()
1650 << " and " << allocation_value_2.ToShortString()
1651 << " are equivalent, deleting the second one.";
1652 allocation_values.erase(allocation_values.begin() + j);
1653 --j;
1654 }
1655 }
1656 }
1657
1658 FindAliases(&allocation_values);
1659 }
1660
1661 AlternateMemoryBestFitHeap::Result
1662 AlternateMemoryBestFitHeap::AllocateAllocationValues(
1663 absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
1664 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1665
1666 // Find the use times across all of the related AllocationValues and sort
1667 // them. We use these to find allocations that are available throughout the
1668 // entire live range of all the AllocationValues.
1669 std::vector<int64_t> all_use_times;
1670 for (const AllocationValue& allocation_value : allocation_values) {
1671 absl::c_transform(allocation_value.uses(),
1672 std::back_inserter(all_use_times),
1673 [](const AllocationValue::Use& use) { return use.time; });
1674 }
1675 absl::c_sort(all_use_times);
1676
1677 // Data structure to contain the preferred offset for a given computation.
1678 // We ensure that the same offset will be allocated outside the while loop
1679 // as well as inside the while loop.
1680 absl::flat_hash_map<const HloComputation*, AliasedOffset*>
1681 preferred_offset_for_computation;
1682
1683 Result result = Result::kSuccess;
1684 for (AllocationValue& allocation_value : allocation_values) {
1685 int64_t definition_time =
1686 instruction_schedule.at(allocation_value.defining_instruction());
1687
1688 AliasedOffset* preferred_offset = nullptr;
1689 auto preferred_offset_it =
1690 preferred_offset_for_computation.find(allocation_value.computation());
1691 if (preferred_offset_it != preferred_offset_for_computation.end()) {
1692 preferred_offset = preferred_offset_it->second;
1693 }
1694
1695 // Iterate over the uses.
1696 for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
1697 const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
1698 const HloUse hlo_use = use.hlo_use;
1699 int64_t use_time = instruction_schedule.at(hlo_use.instruction);
1700 int64_t latest_prefetch_time = use_time;
1701 bool allow_no_copy_alternate_mem_allocation = true;
1702 std::optional<int64_t> earliest_prefetch_time = std::nullopt;
1703
1704 // Control flow calls include kWhile, kCall, and kConditional opcodes.
1705 bool is_sequential_call =
1706 (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
1707 CallContext::kControlFlow);
1708 if (is_sequential_call) {
1709 for (const HloComputation* called_computation :
1710 hlo_use.instruction->called_computations()) {
1711 const HloLiveRange::TimeBound& computation_span =
1712 hlo_live_range_.computation_span_times().at(called_computation);
1713 latest_prefetch_time =
1714 std::min(computation_span.start - 1, latest_prefetch_time);
1715 }
1716 if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
1717 // Given an example while loop and flattened schedule (logical times
1718 // shown on the left):
1719 //
1720 // 0: a = ...
1721 // 1: ...
1722 // cond {
1723 // 2: p = param(0)
1724 // 3: ...
1725 // }
1726 // body {
1727 // 4: p = param(0)
1728 // 5: ...
1729 // 6: ROOT ...
1730 // }
1731 // 7: w = while(a), body=body, cond=cond
1732 //
1733 // When processing "a" (time 0) and its while use (time 7), we update
1734 // the interval to time 0-4. This is so that the remaining interval
1735 // (5-6) can be allocated separately and this buffer doesn't waste
1736 // alternate memory space within the while loop body.
1737 HloComputation* while_body = hlo_use.instruction->while_body();
1738 // We require while body ROOTs to be the last in the schedule.
1739 CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
1740 instruction_schedule.at(hlo_use.instruction))
1741 << "While body ROOTs need to be the last in the schedule! "
1742 "Please run RootInstructionSinker.";
1743 // Replace the use time with the parameter time so that we can decide
1744 // on alternate memory allocations when we later look at the uses within
1745 // the while loop body.
1746 use_time =
1747 instruction_schedule.at(while_body->parameter_instruction(0));
1748 } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
1749 // Replace the use time with the earliest parameter of called
1750 // computations.
1751 for (const HloComputation* called_computation :
1752 hlo_use.instruction->called_computations()) {
1753 use_time = std::min(
1754 use_time, instruction_schedule.at(
1755 called_computation->parameter_instruction(0)));
1756 }
1757 }
1758 }
1759
1760 // Add a required assignment in default memory if the use is not allowed in
1761 // alternate memory.
1762 if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
1763 AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
1764 MemorySpace::kDefault, use_time);
1765 } else if (use_idx > 0) {
1766 // We allow buffers in alternate memory that are passed into
1767 // conditionals to give up their alternate memory allocation inside the
1768 // called computation. This means that if a conditional operator has an
1769 // alternate memory allocation, subsequent uses cannot use the same
1770 // alternate memory allocation in order not to clobber data. So we force
1771 // default memory allocation for these subsequent uses.
1772 const AllocationValue::Use& previous_use =
1773 allocation_value.uses().at(use_idx - 1);
1774 if (previous_use.hlo_use.instruction->opcode() ==
1775 HloOpcode::kConditional &&
1776 previous_use.hlo_use.instruction != hlo_use.instruction) {
1777 allow_no_copy_alternate_mem_allocation = false;
1778 earliest_prefetch_time =
1779 instruction_schedule.at(previous_use.hlo_use.instruction);
1780 VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
1781 << ") of use (" << hlo_use.ToString()
1782 << ") is a conditional, so this use will need to evict. "
1783 << "Earliest prefetch time = " << *earliest_prefetch_time;
1784 }
1785 }
1786
1787 // Bitcasts don't define buffers and don't directly consume buffers. Skip
1788 // allocating buffers for bitcast uses (unless they are the root
1789 // instruction). The uses that feed from bitcasts will be handled
1790 // specially.
1791 if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
1792 hlo_use.instruction ==
1793 hlo_use.instruction->parent()->root_instruction()) {
1794 AllocationRequest request;
1795 // Rarely (e.g., when the conditional's true and false parameters are the
1796 // same), the definition time can be the time of the conditional while the
1797 // use time (the parameter use) is earlier.
1798 request.start_time = std::min(definition_time, use_time);
1799 request.end_time = use_time;
1800 request.latest_prefetch_time = latest_prefetch_time;
1801 request.size = allocation_value.size();
1802 request.allow_no_copy_alternate_mem_allocation =
1803 allow_no_copy_alternate_mem_allocation;
1804 request.earliest_prefetch_time = earliest_prefetch_time;
1805 request.preferred_offset = preferred_offset;
1806 request.use = &use;
1807 request.allocation_value = &allocation_value;
1808 request.all_use_times = all_use_times;
1809 result_mark(AllocateSegment(request), result);
1810 if (result_requires_uncommit(result)) {
1811 // If the allocation finding failed (e.g., due to running out of
1812 // asynchronous copies), then fall back to allocating the buffer
1813 // entirely in the default memory.
1814 return result;
1815 }
1816
1817 // If there are multiple uses, subsequent uses can try to reuse the
1818 // allocation that is already in the alternate memory.
1819 definition_time = instruction_schedule.at(hlo_use.instruction);
1820 }
1821
1822 // Propagate the allocation to any aliases this use might have had.
1823 MemorySpaceAssignment::Allocation* aliased_allocation =
1824 GetLiveAllocationAt(*allocation_value.allocation_sequence(),
1825 use_time);
1826 for (const HloPosition& aliased_position : use.aliases) {
1827 AddAliasedRequiredAssignment(aliased_position.instruction,
1828 aliased_position.index,
1829 aliased_allocation);
1830 }
1831
1832 if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
1833 aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1834 // For while uses that are allocated in the alternate memory space, if
1835 // they also have an allocation in the default memory space in their
1836 // allocation sequence, create a "parent" allocation that mirrors this
1837 // default memory space allocation. When we process the parent
1838 // allocation, we add an additional parameter to the while that is a
1839 // reference to the buffer in the default memory space. With parent
1840 // allocations, we don't need to unnecessarily evict buffers since they
1841 // already have a copy in the default memory space. We search backwards
1842 // (latest to earliest in execution time) for a suitable allocation in
1843 // order to find the most recent one.
1844 if (options_.enable_while_redundant_eviction_elimination &&
1845 absl::c_find_if(allocation_value.value()->positions(),
1846 [&hlo_use](const HloPosition& position) {
1847 return position.instruction ==
1848 hlo_use.instruction &&
1849 position.index == hlo_use.operand_index;
1850 }) != allocation_value.value()->positions().end()) {
1851 auto allocation_sequence = allocation_value.allocation_sequence();
1852 auto prev_allocation_in_default_mem_it = std::find_if(
1853 allocation_sequence->rbegin(), allocation_sequence->rend(),
1854 [&](const auto& allocation) {
1855 return allocation->memory_space() == MemorySpace::kDefault &&
1856 allocation->defining_position() ==
1857 allocation_value.defining_position();
1858 });
1859 if (prev_allocation_in_default_mem_it !=
1860 allocation_sequence->rend()) {
1861 VLOG(3) << "Found a prev allocation in default mem for while use: "
1862 << (*prev_allocation_in_default_mem_it)->ToString();
1863 auto body_allocation_value_it = absl::c_find_if(
1864 allocation_values, [&](const AllocationValue& value) {
1865 return value.computation() ==
1866 hlo_use.instruction->while_body() &&
1867 value.defining_instruction()->opcode() ==
1868 HloOpcode::kParameter;
1869 });
1870 CHECK_NE(body_allocation_value_it, allocation_values.end());
1871 VLOG(3) << "Body allocation value: "
1872 << body_allocation_value_it->ToShortString();
1873 int64_t body_parameter_time = instruction_schedule.at(
1874 body_allocation_value_it->defining_instruction());
1875 body_allocation_value_it->allocation_sequence()->push_back(
1876 std::make_unique<MemorySpaceAssignment::ParentAllocation>(
1877 **prev_allocation_in_default_mem_it, hlo_use.instruction,
1878 body_allocation_value_it->defining_position(),
1879 body_parameter_time));
1880 VLOG(3) << "Created: "
1881 << body_allocation_value_it->allocation_sequence()
1882 ->back()
1883 ->ToString();
1884
1885 auto after_while_allocation_value_it = absl::c_find_if(
1886 allocation_values, [&](const AllocationValue& value) {
1887 return value.defining_instruction() == hlo_use.instruction;
1888 });
1889 CHECK_NE(after_while_allocation_value_it, allocation_values.end());
1890 VLOG(3) << "After while allocation value: "
1891 << after_while_allocation_value_it->ToShortString();
1892 int64_t while_time = instruction_schedule.at(hlo_use.instruction);
1893 after_while_allocation_value_it->allocation_sequence()->push_back(
1894 std::make_unique<MemorySpaceAssignment::MirroredAllocation>(
1895 **prev_allocation_in_default_mem_it, while_time));
1896 VLOG(3) << "Created: "
1897 << after_while_allocation_value_it->allocation_sequence()
1898 ->back()
1899 ->ToString();
1900 }
1901 }
1902 // Special case for while loops since the root offset must agree with
1903 // other offsets: remember the preferred offset for the while loop body.
1904 preferred_offset_for_computation[hlo_use.instruction->while_body()] =
1905 GetAliasedOffset(*aliased_allocation);
1906 }
1907 }
1908 }
1909 return result;
1910 }
1911
1912 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1913 return a.AsTuple() < b.AsTuple();
1914 }
1915
1916 bool operator==(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1917 return a.AsTuple() == b.AsTuple();
1918 }
1919
1920 bool operator!=(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1921 return a.AsTuple() != b.AsTuple();
1922 }
1923
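// Illustrative sketch of the resource bookkeeping below (made-up numbers):
// with initial_resources_ = {2, 2, 2, 2} and a copy over the exclusive
// interval (0, 3) that needs 4 units of copy resource, 2 units are consumed at
// time 1 and 2 at time 2; delay_[1] and delay_[2] record how much of the copy
// is still outstanding at those times, which is the amount by which a later
// copy starting there would be delayed.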
1924 bool AsynchronousCopyResource::ConsumeResource(
1925 int64_t start_time, int64_t end_time, float resource,
1926 bool update_current_resource,
1927 const std::list<AsynchronousCopy>::iterator* current_copy,
1928 float resource_to_free) {
1929 VLOG(3) << "Consume resource: " << start_time << ", " << end_time << ", "
1930 << resource << ", delay: " << delay_[start_time + 1]
1931 << ", free: " << resource_to_free;
1932
1933 // Nothing to do if we're not adding or removing any resources.
1934 if (resource == 0.0 && resource_to_free == 0.0) {
1935 return true;
1936 }
1937
1938 // For the async copy we're adding, check the delay_ array to see how much
1939 // this copy would have to be delayed because of an earlier copy that wasn't
1940 // finished when this copy starts.
1941 if (current_copy == nullptr) {
1942 resource += delay_[start_time + 1];
1943 }
1944
1945 // Find the copy that is right after this one. If there are leftover resources
1946 // by the time the next copy starts, the next copy will be pushed later in
1947 // time.
1948 auto next_copy = async_copies_.end();
1949 if (current_copy != nullptr) {
1950 next_copy = std::next(*current_copy);
1951 } else {
1952 auto async_copy_time_it = async_copy_time_map_.upper_bound(start_time);
1953 if (async_copy_time_it != async_copy_time_map_.end()) {
1954 next_copy = async_copy_time_it->second;
1955 }
1956 }
1957
1958 // Check if this copy will push the next copy later in time (or, if removing
1959 // the resource, whether the removal of this copy moves the next copy earlier
1960 // in time).
1961 std::optional<float> delay_for_next_copy = std::nullopt;
1962 float resource_freed = 0.0;
1963 for (int64_t time = start_time + 1; time < end_time && resource != 0;
1964 ++time) {
1965 // Iterate over the logical times that this copy spans. Note that the
1966 // interval is exclusive of both the start and the end time.
1967 float used_resource = std::min(resource, initial_resources_[time]);
1968 if (next_copy != async_copies_.end() && next_copy->start_time == time - 1) {
1969 // This is the time where the next copy begins. If the resource is
1970 // non-zero at this point, the copy didn't finish by the time the next
1971 // copy started, so the next copy would need to be pushed later in time.
1972 delay_for_next_copy = resource;
1973 resource_to_free -= resource_freed;
1974 }
1975 if (update_current_resource && !delay_for_next_copy.has_value()) {
1976 // Update the delay_ vector and resource_freed variable with the amount
1977 // that was freed when removing the copy.
1978 float old_resource =
1979 std::max(0.0f, initial_resources_[time] - delay_[time]);
1980 delay_[time] = std::max(0.0f, resource - resource_to_free);
1981 float new_resource =
1982 std::max(0.0f, initial_resources_[time] - delay_[time]);
1983 resource_freed += std::max(0.0f, new_resource - old_resource);
1984 }
1985 // Update the resource with the used amount in this logical time.
1986 resource -= used_resource;
1987 }
1988
1989 // If resource isn't satisfied by the end, we didn't have enough resources.
1990 if (resource > 0) {
1991 VLOG(3) << "Doesn't have enough resource; leftover resource = " << resource;
1992 return false;
1993 }
1994
1995 // If this copy overlapped with another one, we recursively call
1996 // ConsumeResource with the amount of resource that needs to be added or
1997 // removed.
1998 if (delay_for_next_copy.has_value()) {
1999 return ConsumeResource(next_copy->start_time, next_copy->end_time,
2000 *delay_for_next_copy + next_copy->resource,
2001 update_current_resource, &next_copy,
2002 resource_to_free);
2003 }
2004 return true;
2005 }
2006
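// Note: AddCopy CHECK-fails if the copy's resource cannot be consumed, so
// callers should check HasEnoughResource() first; RemoveCopy() undoes a
// previously added copy and gives its resource back. Illustrative usage sketch
// (hypothetical values; AsynchronousCopy fields elided):
//   if (resource.HasEnoughResource(/*start_time=*/10, /*end_time=*/20,
//                                  /*resource=*/3.0f)) {
//     resource.AddCopy(copy);  // copy spans (10, 20) and needs 3.0 units.
//   }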
2007 void AsynchronousCopyResource::AddCopy(const AsynchronousCopy& copy) {
2008 CHECK(ConsumeResource(copy.start_time, copy.end_time, copy.resource,
2009 /*update_current_resource=*/true));
2010 // Find the iterator for the copy that would be right after this copy and put
2011 // this copy right before it in async_copies_.
2012 auto async_copy_time_it = async_copy_time_map_.upper_bound(copy.start_time);
2013 auto insertion_it = (async_copy_time_it == async_copy_time_map_.end())
2014 ? async_copies_.end()
2015 : async_copy_time_it->second;
2016 auto inserted_it = async_copies_.insert(insertion_it, copy);
2017 // If this copy is the first copy we have seen with the start time, add the
2018 // inserted iterator into async_copy_time_map_ for fast lookups. Note that
2019 // async_copy_time_map_ always points to the very first copy with the same
2020 // start index. If there are multiple asynchronous copies that have the same
2021 // start time, the memory space assignment algorithm schedules them in the
2022 // same order that AddCopy was called.
2023 if (async_copy_time_map_.find(copy.start_time) ==
2024 async_copy_time_map_.end()) {
2025 async_copy_time_map_[copy.start_time] = inserted_it;
2026 }
2027 }
2028
2029 void AsynchronousCopyResource::RemoveCopy(const AsynchronousCopy& copy) {
2030 CHECK(ConsumeResource(copy.start_time, copy.end_time, /*resource=*/0,
2031 /*update_current_resource=*/true,
2032 /*current_copy=*/nullptr,
2033 /*resource_to_free=*/copy.resource));
2034 // Using async_copy_time_map_, find this copy to be removed. Note that the
2035 // iterator in async_copy_time_map_ points to the first-seen copy with the
2036 // given start time, so the copy to be removed might be later than the first
2037 // one.
2038 auto async_copy_time_it = async_copy_time_map_.find(copy.start_time);
2039 CHECK(async_copy_time_it != async_copy_time_map_.end());
2040 auto it = async_copy_time_it->second;
2041 for (; it != async_copies_.end() && *it != copy; ++it) {
2042 }
2043 CHECK(it != async_copies_.end());
2044 // If the copy to be removed is the one pointed to by async_copy_time_map_,
2045 // make async_copy_time_map_ point to the next copy with the same start time.
2046 // If there is no such copy, remove the key for this copy's start time.
2048 if (it == async_copy_time_it->second) {
2049 if (std::next(it) != async_copies_.end() &&
2050 std::next(it)->start_time == copy.start_time) {
2051 async_copy_time_it->second = std::next(it);
2052 } else {
2053 async_copy_time_map_.erase(async_copy_time_it);
2054 }
2055 }
2056 async_copies_.erase(it);
2057 }
2058
2059 bool AsynchronousCopyResource::HasEnoughResource(int64_t start_time,
2060 int64_t end_time,
2061 float resource) {
2062 return ConsumeResource(start_time, end_time, resource,
2063 /*update_current_resource=*/false);
2064 }
2065
2066 AlternateMemoryBestFitHeap::AliasedOffset*
2067 AlternateMemoryBestFitHeap::GetAliasedOffset(
2068 const MemorySpaceAssignment::Allocation& allocation) {
2069 auto aliased_offset_it = aliased_offset_map_.find(&allocation);
2070 CHECK(aliased_offset_it != aliased_offset_map_.end());
2071 return aliased_offset_it->second;
2072 }
2073
2074 void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
2075 const MemorySpaceAssignment::Allocation& allocation,
2076 AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
2077 CHECK(allocation.memory_space() == MemorySpace::kAlternate);
2078 CHECK(!aliased_offset_map_.contains(&allocation));
2079 if (!aliased_offset) {
2080 aliased_offsets_.push_back({allocation.chunk().offset});
2081 aliased_offset = &aliased_offsets_.back();
2082 }
2083 CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
2084 CHECK(aliased_offset->allocations.insert(&allocation).second);
2085 aliased_offset_map_[&allocation] = aliased_offset;
2086 }
2087
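// Returns the last-added allocation in the sequence that is live at the given
// time, or nullptr if no allocation covers that time.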
2088 /*static*/ MemorySpaceAssignment::Allocation*
2089 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
2090 const MemorySpaceAssignment::AllocationSequence& allocations,
2091 int64_t time) {
2092 for (auto allocation_it = allocations.rbegin();
2093 allocation_it != allocations.rend(); ++allocation_it) {
2094 if ((*allocation_it)->start_time() <= time &&
2095 (*allocation_it)->end_time() >= time) {
2096 return allocation_it->get();
2097 }
2098 }
2099 return nullptr;
2100 }
2101
2102 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
2103 HloModule* module, std::optional<BufferInterval> prefetch_candidate) {
2104 if (!prefetch_candidate) {
2105 return;
2106 }
2107
2108 Chunk chunk_candidate = FindChunkCandidate(*prefetch_candidate);
2109 if (chunk_candidate.chunk_end() > available_heap_size()) {
2110 LOG(WARNING)
2111 << "Could not allocate preferred memory for cross program prefetch";
2112 return;
2113 }
2114
2115 const HloValue* buffer = prefetch_candidate->buffer;
2116 int64_t parameter = buffer->instruction()->parameter_number();
2117 module->AddCrossProgramPrefetch(parameter, buffer->index());
2118
2119 MemorySpaceAssignment::AllocationSequence allocations;
2120 allocations.push_back(std::make_unique<MemorySpaceAssignment::Allocation>(
2121 buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
2122 prefetch_candidate->start, prefetch_candidate->end,
2123 /*is_scoped_allocation=*/false));
2124
2125 // Find the earliest use.
2126 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
2127 auto uses = buffer->GetUses();
2128 auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
2129 return instruction_schedule.at(lhs.instruction) <
2130 instruction_schedule.at(rhs.instruction);
2131 };
2132 auto first_use = absl::c_min_element(uses, use_schedule_compare);
2133 int64_t latest_prefetch_time =
2134 instruction_schedule.at(first_use->instruction);
2135
2136 // Find the latest use time.
2137 int64_t last_use_time = instruction_schedule.at(
2138 absl::c_max_element(uses, use_schedule_compare)->instruction);
2139 for (const HloValue* colocation : prefetch_candidate->colocations) {
2140 auto colocation_uses = colocation->GetUses();
2141 if (!colocation_uses.empty()) {
2142 last_use_time = std::max(
2143 last_use_time,
2144 instruction_schedule.at(
2145 absl::c_max_element(colocation_uses, use_schedule_compare)
2146 ->instruction));
2147 }
2148 }
2149
2150 int64_t end_of_program_prefetch_end_time = instruction_schedule.size();
2151 int64_t end_of_program_prefetch_latest_start_time =
2152 options_.prefetch_interval_picker->LatestPrefetchStartTime(
2153 buffer->defining_position().shape(), last_use_time,
2154 end_of_program_prefetch_end_time, nullptr);
2155 int64_t end_of_program_prefetch_start_time =
2156 options_.prefetch_interval_picker->PreferredPrefetchStartTime(
2157 buffer->defining_position().shape(), last_use_time,
2158 end_of_program_prefetch_latest_start_time,
2159 end_of_program_prefetch_end_time);
2160 VLOG(2) << "last use time = " << last_use_time
2161 << ", end-of-program prefetch start time = "
2162 << end_of_program_prefetch_start_time;
2163 float total_execution_time =
2164 options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2165 0, instruction_schedule.size());
2166 float buffer_occupied_time =
2167 options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2168 0, last_use_time) +
2169 options_.prefetch_interval_picker->GetLogicalIntervalElapsed(
2170 end_of_program_prefetch_start_time, end_of_program_prefetch_end_time);
2171 float buffer_occupied_ratio = buffer_occupied_time / total_execution_time;
2172 VLOG(2) << "Total execution time = " << total_execution_time
2173 << ", buffer occupied time = " << buffer_occupied_time
2174 << ", buffer occupied ratio = " << buffer_occupied_ratio;
2175 // Freeing the buffer only makes sense if it will be free for a substantial
2176 // amount of time. Only perform this optimization if the occupied ratio is
2177 // below the limit and the memory pressure is above the alternate memory size.
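// Illustrative example (made-up numbers): if the program takes 100 time units,
// the last use of the buffer is at time 20, and the end-of-program prefetch
// starts at time 90, then the occupied ratio is (20 + 10) / 100 = 0.3, which
// is below kCrossProgramPrefetchOccupyFreeingLimit, so the buffer may be freed
// between times 20 and 90.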
2178 bool free_buffer =
2179 (options_.enable_cross_program_prefetch_freeing &&
2180 memory_pressure_ > options_.max_size_in_bytes &&
2181 buffer_occupied_ratio < kCrossProgramPrefetchOccupyFreeingLimit &&
2182 end_of_program_prefetch_start_time > last_use_time &&
2183 end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
2184 int64_t cross_program_prefetch_end_time =
2185 free_buffer ? last_use_time : prefetch_candidate->end;
2186
2187 AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate,
2188 prefetch_candidate->start, cross_program_prefetch_end_time,
2189 latest_prefetch_time, &allocations, /*aliased_offset=*/nullptr,
2190 /*resource=*/0.0,
2191 /*is_cross_program_prefetch=*/true);
2192
2193 HloInstruction* root_instruction =
2194 module->entry_computation()->root_instruction();
2195 absl::c_for_each(uses, [&](auto& use) {
2196 if (use.instruction != root_instruction) {
2197 allocations.back()->AddUse(use);
2198 }
2199 });
2200 AliasedOffset* cross_program_prefetch_offset =
2201 GetAliasedOffset(*allocations.back());
2202
2203 if (free_buffer) {
2204 VLOG(2) << "Adding an end-of-program prefetch for freed "
2205 "cross-program-prefetched buffer.";
2206 AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, chunk_candidate,
2207 end_of_program_prefetch_start_time,
2208 end_of_program_prefetch_end_time,
2209 end_of_program_prefetch_end_time, &allocations,
2210 cross_program_prefetch_offset,
2211 /*resource=*/0.0);
2212 CHECK_EQ(cross_program_prefetch_offset->offset,
2213 allocations.back()->chunk().offset);
2214 }
2215
2216 const int allocations_initial_size = allocations_->size();
2217 for (auto& allocation : allocations) {
2218 if (allocation->memory_space() == MemorySpace::kAlternate) {
2219 BufferInterval buffer_interval;
2220 buffer_interval.start = allocation->start_time();
2221 buffer_interval.end = allocation->end_time();
2222 buffer_interval.size = allocation->chunk().size;
2223 buffer_interval.buffer = prefetch_candidate->buffer;
2224 AddToPendingChunks(buffer_interval, chunk_candidate);
2225 }
2226 allocations_->push_back(std::move(allocation));
2227 }
2228
2229 // Add a repack allocation block for the Allocation objects in alternate
2230 // memory.
2231 for (int i = allocations_initial_size; i < allocations_->size(); ++i) {
2232 const auto& allocation = allocations_->at(i);
2233 if (allocation->memory_space() == MemorySpace::kAlternate) {
2234 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2235 allocation->start_time(), allocation->end_time(),
2236 allocation->chunk().size, allocation->chunk().offset,
2237 static_cast<int64_t>(repack_allocation_blocks_.size()),
2238 allocation.get()));
2239 RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
2240 for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
2241 colocation.colocations.push_back(inserted);
2242 if (&colocation != inserted) {
2243 inserted->colocations.push_back(&colocation);
2244 }
2245 }
2246 }
2247 }
2248
2249 ClearPendingChunks();
2250 }
2251
2252 void AlternateMemoryBestFitHeap::AllocateReservedScopedAllocations() {
2253 const auto& instruction_sequence =
2254 hlo_live_range_.flattened_instruction_sequence().instructions();
2255 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2256 for (int i = 0; i < instruction_sequence.size(); ++i) {
2257 const HloInstruction* instruction = instruction_sequence[i];
2258 int64_t reserved_scoped_memory =
2259 options_.reserved_scoped_memory_fn(instruction);
2260 if (reserved_scoped_memory != 0) {
2261 VLOG(1) << "Allocate reserved scoped memory at " << i << " ("
2262 << instruction->name() << "): " << reserved_scoped_memory;
2263 MemorySpaceAssignment::BufferInterval interval;
2264 interval.buffer = nullptr;
2265 interval.size = reserved_scoped_memory;
2266 interval.start = i;
2267 interval.end = i;
2268 interval.need_allocation = true;
2269 interval.colocations = {};
2270 Chunk chunk_candidate =
2271 FindChunkCandidate(interval, /*preferred_offset=*/0);
2272 CHECK_EQ(chunk_candidate.offset, 0);
2273 AddToPendingChunks(interval, chunk_candidate);
2274
2275 if (options_.dump_fn != nullptr || VLOG_IS_ON(3)) {
2276 AppendScopedAllocationBufferInfoDebugString(
2277 instruction, i, reserved_scoped_memory, buffer_info_str_);
2278 }
2279
2280 allocations_->push_back(
2281 std::make_unique<MemorySpaceAssignment::Allocation>(
2282 HloPosition{instruction_sequence[i], {}}, MemorySpace::kAlternate,
2283 chunk_candidate, i, i, /*is_scoped_allocation=*/true));
2284
2285 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2286 i, i, reserved_scoped_memory,
2287 /*initial_offset=*/0,
2288 static_cast<int64_t>(repack_allocation_blocks_.size()),
2289 allocations_->back().get()));
2290 colocations.push_back(&repack_allocation_blocks_.back());
2291 }
2292 }
2293 // If requested, make all scoped allocations colocate with each other so that
2294 // when we repack, all scoped allocations get the same offsets. Since they
2295 // will all have the same scoped memory addresses, this increases the
2296 // opportunity to deduplicate different ops. However, this may hurt memory
2297 // packing efficiency.
2298 if (options_.allocate_reserved_scoped_memory_at_same_offset) {
2299 for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2300 colocations) {
2301 repack_block->colocations = colocations;
2302 }
2303 }
2304 ClearPendingChunks();
2305 }
2306
2307 std::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
2308 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
2309 int64_t time) const {
2310 auto required_assignment_it = required_assignments_.find(buffer);
2311 std::optional<RequiredMemoryAssignment> required_assignment_at_time;
2312 if (required_assignment_it != required_assignments_.end()) {
2313 for (const RequiredMemoryAssignment& required_assignment :
2314 required_assignment_it->second) {
2315 if (required_assignment.time == time) {
2316 // Sanity check that there is only one required assignment at this time.
2317 CHECK(!required_assignment_at_time);
2318 required_assignment_at_time = required_assignment;
2319 }
2320 }
2321 }
2322 return required_assignment_at_time;
2323 }
2324
2325 std::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
2326 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
2327 const AllocationValue::Use& use) const {
2328 std::optional<RequiredMemoryAssignment> required_assignment;
2329 for (const HloPosition& position : use.aliases) {
2330 const HloValue* value =
2331 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
2332 position.instruction, position.index);
2333 int64_t time =
2334 hlo_live_range_.instruction_schedule().at(position.instruction);
2335 std::optional<RequiredMemoryAssignment> required_assignment_for_alias =
2336 RequiredMemoryAssignmentAt(value, time);
2337 if (required_assignment == std::nullopt) {
2338 required_assignment = required_assignment_for_alias;
2339 } else {
2340 CHECK(required_assignment_for_alias == std::nullopt ||
2341 required_assignment->equals_ignoring_time(
2342 *required_assignment_for_alias));
2343 }
2344 }
2345 return required_assignment;
2346 }
2347
2348 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
2349 const HloInstruction* instruction, ShapeIndex index,
2350 const MemorySpaceAssignment::Allocation* aliased_allocation) {
2351 AliasedOffset* offset = nullptr;
2352 if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
2353 offset = GetAliasedOffset(*aliased_allocation);
2354 }
2355 AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
2356 offset);
2357 }
2358
2359 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
2360 const HloValue* value, const HloInstruction* instruction,
2361 MemorySpaceAssignment::MemorySpace memory_space, int64_t time,
2362 AliasedOffset* offset) {
2363 // Check for an existing required assignment at this time and, if there is
2364 // one, make sure it is the same as this one.
2365 auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
2366 if (existing_required_assignment) {
2367 CHECK(memory_space == existing_required_assignment->memory_space)
2368 << "inst = " << instruction->ToString() << " at " << time;
2369 CHECK((!offset && !existing_required_assignment->offset) ||
2370 offset == existing_required_assignment->offset);
2371 VLOG(3) << "Not adding required assignment because there is one already: "
2372 << value->ToShortString() << " at " << time << " at "
2373 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2374 } else {
2375 VLOG(3) << "Adding required assignment: " << value->ToShortString()
2376 << " at " << time << " at "
2377 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2378 RequiredMemoryAssignment required_assignment{memory_space, time, offset};
2379 required_assignments_[value].push_back(required_assignment);
2380 pending_required_assignments_.push_back({value, required_assignment});
2381 }
2382 }
2383
2384 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
2385 const HloInstruction* instruction, ShapeIndex index,
2386 MemorySpace memory_space, AliasedOffset* offset) {
2387 const HloValue* value =
2388 &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
2389 int64_t instruction_time =
2390 hlo_live_range_.instruction_schedule().at(instruction);
2391 AddRequiredAssignment(value, instruction, memory_space, instruction_time,
2392 offset);
2393 }
2394
2395 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
2396 // Go through the parameters, outputs, and constants and pin them to the
2397 // corresponding memory by adding a required assignment.
2398 const HloModule& module = alias_analysis_.dataflow_analysis().module();
2399 const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
2400 HloComputation* entry_computation = module.entry_computation();
2401 for (HloInstruction* parameter_instruction :
2402 entry_computation->parameter_instructions()) {
2403 int64_t parameter_instruction_time =
2404 instruction_schedule.at(parameter_instruction);
2405 ShapeUtil::ForEachSubshape(
2406 parameter_instruction->shape(),
2407 [&](const Shape& subshape, const ShapeIndex& index) {
2408 MemorySpace memory_space = MemorySpace::kDefault;
2409 if (subshape.has_layout() && subshape.layout().memory_space() ==
2410 options_.alternate_memory_space) {
2411 memory_space = MemorySpace::kAlternate;
2412 }
2413 for (const HloBuffer* buffer :
2414 alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
2415 for (const HloValue* value : buffer->values()) {
2416 VLOG(3) << "Adding required assignment for parameter value = "
2417 << value->ToShortString()
2418 << " time = " << parameter_instruction_time << " space = "
2419 << (memory_space == MemorySpace::kDefault ? "def"
2420 : "alt");
2421 required_assignments_[value].push_back(
2422 {memory_space, /*time=*/parameter_instruction_time});
2423 }
2424 }
2425 });
2426 }
2427 HloInstruction* root_instruction = entry_computation->root_instruction();
2428 int64_t root_instruction_time = instruction_schedule.at(root_instruction);
2429 ShapeUtil::ForEachSubshape(
2430 root_instruction->shape(),
2431 [&](const Shape& subshape, const ShapeIndex& index) {
2432 MemorySpace memory_space = MemorySpace::kDefault;
2433 if (subshape.has_layout() && subshape.layout().memory_space() ==
2434 options_.alternate_memory_space) {
2435 memory_space = MemorySpace::kAlternate;
2436 }
2437 for (const HloBuffer* buffer :
2438 alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
2439 for (const HloValue* value : buffer->values()) {
2440 VLOG(3) << "Adding required assignment for output value = "
2441 << value->ToShortString()
2442 << " time = " << root_instruction_time << " space = "
2443 << (memory_space == MemorySpace::kDefault ? "def" : "alt");
2444 required_assignments_[value].push_back(
2445 {memory_space, /*time=*/root_instruction_time});
2446 }
2447 }
2448 });
2449
2450 for (const HloComputation* computation : module.MakeNonfusionComputations()) {
2451 for (HloInstruction* instruction : computation->instructions()) {
2452 if (instruction->opcode() == HloOpcode::kConstant) {
2453 auto constant_instruction_it = instruction_schedule.find(instruction);
2454 if (constant_instruction_it == instruction_schedule.end()) {
2455 continue;
2456 }
2457 int64_t constant_instruction_time = constant_instruction_it->second;
2458 for (const auto& indexed_shape :
2459 ShapeUtil::GetLeafShapes(instruction->shape())) {
2460 const ShapeIndex& index = indexed_shape.index;
2461 for (const HloBuffer* buffer :
2462 alias_analysis_.ComputeBuffersAt(instruction, index)) {
2463 for (const HloValue* value : buffer->values()) {
2464 VLOG(3) << "Adding required assignment for constant value = "
2465 << value->ToShortString()
2466 << " time = " << constant_instruction_time
2467 << " space = def";
2468 required_assignments_[value].push_back(
2469 {MemorySpace::kDefault, /*time=*/constant_instruction_time});
2470 }
2471 }
2472 }
2473 }
2474 }
2475 }
2476
2477   // Go through all of the values and pin them to the default memory if they
2478   // are not allowed in the alternate memory.
2479 for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) {
2480 if (!options_.is_allowed_in_alternate_mem_fn(*value)) {
2481 // We won't find the instruction in the schedule if it's inside a fusion.
2482 // If so, just skip.
2483 auto instruction_time_it =
2484 instruction_schedule.find(value->instruction());
2485 if (instruction_time_it == instruction_schedule.end()) {
2486 continue;
2487 }
2488 int64_t instruction_time = instruction_time_it->second;
2489 auto& required_assignments = required_assignments_[value];
2490 // Check if there is an existing matching required assignment (e.g.
2491 // inserted by the logic above) and if so ensure it requires a default
2492 // memory allocation.
2493 auto matching_assignment = absl::c_find_if(
2494 required_assignments,
2495 [&](const RequiredMemoryAssignment& required_assignment) {
2496 return required_assignment.time == instruction_time;
2497 });
2498 if (matching_assignment != required_assignments.end()) {
2499 CHECK(matching_assignment->memory_space == MemorySpace::kDefault)
2500 << "Mismatch in required assignments at time " << instruction_time
2501 << " value: " << value->ToString();
2502 } else {
2503 required_assignments.push_back(
2504 {MemorySpace::kDefault, instruction_time});
2505 }
2506 }
2507 }
2508 }
2509
2510 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
2511 absl::Span<const BufferInterval* const> colocated_intervals) const {
2512 auto is_position_in_alternate_memory = [&](const HloPosition& position) {
2513 const Shape& shape = position.shape();
2514 return shape.has_layout() &&
2515 shape.layout().memory_space() == options_.alternate_memory_space;
2516 };
2517
2518 const HloModule& module = alias_analysis_.dataflow_analysis().module();
2519 const HloComputation* entry_computation = module.entry_computation();
2520 const HloInstruction* root_instruction =
2521 entry_computation->root_instruction();
2522 for (const BufferInterval* colocated_interval : colocated_intervals) {
2523 const HloValue* value = colocated_interval->buffer;
2524 if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
2525 value->defining_instruction()->parent() == entry_computation &&
2526 is_position_in_alternate_memory(value->defining_position())) {
2527 return true;
2528 }
2529
2530 for (const HloPosition& position : value->positions()) {
2531 if (position.instruction == root_instruction &&
2532 is_position_in_alternate_memory(position)) {
2533 return true;
2534 }
2535 }
2536 }
2537 return false;
2538 }
2539
2540 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
2541 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
2542 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2543 allocations.push_back(&allocation_block);
2544 }
2545 }
2546
2547 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
2548 interval_tree_ = {};
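  // Rebuild the interval tree from scratch using the offsets chosen by the
  // repacker, and reset each block's offset so it is ready for a later repack.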
2549 for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
2550 MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
2551 VLOG(3) << "Moved " << allocation->ToString() << ", size "
2552 << allocation->chunk().size << ", (" << allocation_block.start_time
2553 << ", " << allocation_block.end_time << ") from "
2554 << allocation_block.initial_offset << " to "
2555 << allocation_block.offset;
2556 allocation_block.allocation->mutable_chunk()->offset =
2557 allocation_block.offset;
2558 interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
2559 {allocation_block.offset, allocation_block.size});
2560 allocation_block.initial_offset = allocation_block.offset;
2561 allocation_block.offset = -1;
2562 }
2563 }
2564
2565 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
2566 absl::Span<AllocationValue> allocation_values) {
2567   // Clear the allocation sequences of the allocation values so that they are
2568   // in a clean state in case we retry allocation after uncommitting.
2569 for (AllocationValue& allocation_value : allocation_values) {
2570 allocation_value.allocation_sequence()->clear();
2571 }
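  // Roll back the heap bookkeeping for every chunk that was tentatively
  // committed during this allocation attempt.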
2572 for (const auto& interval_and_chunk : pending_chunks_) {
2573 const BufferInterval& interval = interval_and_chunk.first;
2574 const Chunk& chunk = interval_and_chunk.second;
2575 VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
2576 << ") off = " << chunk.offset << " size = " << chunk.size;
2577 for (int i = interval.start; i <= interval.end; ++i) {
2578 peak_memory_usage_[i] -= chunk.size;
2579 CHECK_GE(peak_memory_usage_[i], 0)
2580 << "Peak memory usage at " << i
2581 << " is below zero after uncommitting. " << interval.start << "-"
2582 << interval.end << " : [" << chunk.offset << ", " << chunk.size
2583 << "]";
2584 }
2585 interval_tree_.Remove(interval.start, interval.end, chunk);
2586 }
2587 for (const auto& interval : pending_async_copies_) {
2588 if (interval.destination == MemorySpace::kAlternate) {
2589 prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
2590 kDummyChunk);
2591 prefetch_async_copy_resource_.RemoveCopy(interval);
2592 } else {
2593 eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
2594 kDummyChunk);
2595 eviction_async_copy_resource_.RemoveCopy(interval);
2596 }
2597 }
2598 for (const auto& value_and_required_assignment :
2599 pending_required_assignments_) {
2600 auto& required_assignment_vector =
2601 required_assignments_[value_and_required_assignment.first];
2602 const RequiredMemoryAssignment& required_assignment =
2603 value_and_required_assignment.second;
2604 VLOG(3) << "Removing required assignment: "
2605 << (required_assignment.memory_space == MemorySpace::kDefault
2606 ? "def"
2607 : "alt")
2608 << " time = " << required_assignment.time << " off = "
2609 << (required_assignment.offset ? required_assignment.offset->offset
2610 : -1);
2611 for (auto it = required_assignment_vector.begin();
2612 it != required_assignment_vector.end(); ++it) {
2613 if (*it == value_and_required_assignment.second) {
2614 required_assignment_vector.erase(it);
2615 break;
2616 }
2617 }
2618 }
2619 ClearPendingChunks();
2620 }
2621
2622 void AlternateMemoryBestFitHeap::FinalizeAllocations(
2623 absl::Span<AllocationValue> allocation_values) {
2624 absl::flat_hash_map<const AliasedOffset*,
2625 std::vector<MemorySpaceAssignment::Allocation*>>
2626 colocation_map;
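  // Move the finalized allocations into the global allocation sequence and
  // group the alternate-memory ones by their aliased offset.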
2627 for (AllocationValue& allocation_value : allocation_values) {
2628 for (auto& allocation : *allocation_value.allocation_sequence()) {
2629 allocations_->push_back(std::move(allocation));
2630 MemorySpaceAssignment::Allocation* inserted_allocation =
2631 allocations_->back().get();
2632 if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
2633 colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
2634 inserted_allocation);
2635 }
2636 }
2637 }
2638 // The allocations that have the same AliasedOffset need to be colocated.
2639 // Export these to repack_allocation_blocks_ so that we can repack them to
2640 // reduce fragmentation.
2641 for (auto& colocation : colocation_map) {
2642 std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
2643 for (MemorySpaceAssignment::Allocation* colocated_allocation :
2644 colocation.second) {
2645 repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
2646 colocated_allocation->start_time(), colocated_allocation->end_time(),
2647 colocated_allocation->chunk().size,
2648 colocated_allocation->chunk().offset,
2649 static_cast<int64_t>(repack_allocation_blocks_.size()),
2650 colocated_allocation));
2651 colocations.push_back(&repack_allocation_blocks_.back());
2652 }
2653 for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
2654 colocations) {
2655 repack_block->colocations = colocations;
2656 }
2657 }
2658 ClearPendingChunks();
2659 }
2660
2661 void AlternateMemoryBestFitHeap::ClearPendingChunks() {
2662 pending_chunks_.clear();
2663 pending_async_copies_.clear();
2664 pending_required_assignments_.clear();
2665 aliased_offset_map_.clear();
2666 aliased_offsets_.clear();
2667 }
2668
2669 void AlternateMemoryBestFitHeap::AddToPendingChunks(
2670 const BufferInterval& buffer_interval, const Chunk& chunk_candidate) {
2671 VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
2672 << buffer_interval.end << " : [" << chunk_candidate.offset << ", "
2673 << chunk_candidate.size << "]";
2674 pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
2675 for (int i = buffer_interval.start; i <= buffer_interval.end; ++i) {
2676 peak_memory_usage_[i] += chunk_candidate.size;
2677 CHECK_LE(peak_memory_usage_[i], options_.max_size_in_bytes)
2678 << "Peak memory usage at " << i
2679 << " exceeds the max size of alternate memory. "
2680 << buffer_interval.start << "-" << buffer_interval.end << " : ["
2681 << chunk_candidate.offset << ", " << chunk_candidate.size << "]";
2682 }
2683 CommitChunk(buffer_interval, chunk_candidate);
2684 }
2685
2686 std::optional<int>
2687 AlternateMemoryBestFitHeap::FindEarliestTimeToSatisfyPeakMemory(
2688 int start_time, int end_time, int64_t size) const {
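  // Scan backwards from end_time while adding `size` still fits under the
  // peak-memory limit. The loop stops at the first time that would overflow
  // (or at start_time - 1), so the earliest satisfying time is one past that.
  // If even end_time cannot accommodate the size, report std::nullopt.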
2689 int earliest_time;
2690 for (earliest_time = end_time;
2691 earliest_time >= start_time &&
2692 peak_memory_usage_[earliest_time] + size <= options_.max_size_in_bytes;
2693 --earliest_time) {
2694 }
2695 if (earliest_time == end_time) {
2696 return std::nullopt;
2697 }
2698 return earliest_time + 1;
2699 }
2700
2701 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
2702 const AllocationRequest& request) {
2703 auto allocation_sequence = request.allocation_value->allocation_sequence();
2704 // start_time == end_time is a special case where the value is consumed
2705 // multiple times by the same instruction. We can just find the previous
2706 // allocation and use that allocation.
2707 if (request.start_time == request.end_time) {
2708 MemorySpaceAssignment::Allocation* allocation =
2709 GetLiveAllocationAt(*allocation_sequence, request.end_time);
2710 CHECK_NE(allocation, nullptr);
2711 allocation->AddUse(request.use->hlo_use);
2712 return Result::kSuccess;
2713 }
2714
2715 const HloPosition& defining_position =
2716 request.allocation_value->defining_position();
2717 VLOG(2) << "Finding allocation for "
2718 << request.allocation_value->ToShortString() << " ("
2719 << request.start_time << ", " << request.end_time
2720 << ") latest prefetch = " << request.latest_prefetch_time
2721 << " last use = " << request.allocation_value->uses().back().time
2722 << " use = " << request.use->hlo_use.ToString()
2723 << ". Size = " << request.size
2724 << ", def pos = " << defining_position.ToString();
2725 CHECK_LE(request.start_time, request.end_time);
2726 if (VLOG_IS_ON(3) && options_.cost_analysis) {
2727 VLOG(3) << "Definition benefit = "
2728 << options_.cost_analysis->GetAlternateMemoryBenefit(
2729 request.allocation_value->defining_position())
2730 << " use benefit = "
2731 << options_.cost_analysis->GetAlternateMemoryBenefit(
2732 request.use->hlo_use);
2733 }
2734
2735 // There could be a requirement to pin this buffer to default memory either
2736 // because it is a parameter or an output. If the buffer is a parameter, then
2737 // we're allowed to prefetch. If the use expects the output to be in default
2738 // memory, we cannot prefetch it because if we did, it would be in alternate
2739 // memory instead.
2740 auto required_assignment_at_start = RequiredMemoryAssignmentAt(
2741 request.allocation_value->value(), request.start_time);
2742 std::optional<MemorySpace> required_memory_space_at_start;
2743 if (required_assignment_at_start) {
2744 required_memory_space_at_start = required_assignment_at_start->memory_space;
2745 }
2746 // Find required assignment both for the use and its aliases. If they are both
2747 // non-nullopt, then make sure they require the same assignment.
2748 auto required_assignment_at_end = RequiredMemoryAssignmentAt(
2749 request.allocation_value->value(), request.end_time);
2750 auto aliased_required_assignment_at_end =
2751 AliasedRequiredAssignmentForUse(*request.use);
2752 if (required_assignment_at_end != aliased_required_assignment_at_end) {
2753 if (required_assignment_at_end == std::nullopt) {
2754 required_assignment_at_end = aliased_required_assignment_at_end;
2755 } else {
2756 CHECK(aliased_required_assignment_at_end == std::nullopt ||
2757 aliased_required_assignment_at_end->equals_ignoring_time(
2758 *required_assignment_at_end));
2759 }
2760 }
2761 std::optional<MemorySpace> required_memory_space_at_end;
2762 if (required_assignment_at_end) {
2763 required_memory_space_at_end = required_assignment_at_end->memory_space;
2764 }
2765
2766 if (required_assignment_at_start) {
2767 bool needs_required_allocation = true;
2768 if (!allocation_sequence->empty()) {
2769 auto prev_allocation_it = std::find_if(
2770 allocation_sequence->rbegin(), allocation_sequence->rend(),
2771 [&](const auto& allocation) {
2772 return allocation->memory_space() ==
2773 required_memory_space_at_start &&
2774 allocation->defining_position() == defining_position;
2775 });
2776 if (prev_allocation_it != allocation_sequence->rend()) {
2777 (*prev_allocation_it)->Extend(request.start_time);
2778 needs_required_allocation = false;
2779 }
2780 }
2781 if (needs_required_allocation) {
2782 std::optional<Chunk> aliased_chunk = std::nullopt;
2783 if (required_assignment_at_start->memory_space ==
2784 MemorySpace::kAlternate) {
2785 aliased_chunk =
2786 Chunk{required_assignment_at_start->offset->offset, request.size};
2787 }
2788 allocation_sequence->push_back(
2789 std::make_unique<MemorySpaceAssignment::Allocation>(
2790 defining_position, required_assignment_at_start->memory_space,
2791 aliased_chunk, request.start_time, request.start_time,
2792 /*is_scoped_allocation=*/false));
2793 if (required_assignment_at_start->memory_space ==
2794 MemorySpace::kAlternate) {
2795 CreateOrAddToAliasedOffset(*allocation_sequence->back(),
2796 required_assignment_at_start->offset);
2797 }
2798 }
2799 }
2800
2801 Result allocation_result = Result::kSuccess;
2802 // First try keeping the allocation entirely in the alternate memory.
2803 if (required_memory_space_at_start != MemorySpace::kDefault &&
2804 required_memory_space_at_end != MemorySpace::kDefault &&
2805 request.allow_no_copy_alternate_mem_allocation) {
2806 allocation_result = AllocateInAlternateMemoryNoCopy(request);
2807 if (allocation_result == Result::kSuccess) {
2808 return Result::kSuccess;
2809 }
2810 }
2811
2812 auto prev_allocation_it = allocation_sequence->rbegin();
2813 // Find a previous allocation that is in the default memory space (not
2814 // necessarily the very last allocation).
2815 auto prev_allocation_in_default_mem_it =
2816 std::find_if(allocation_sequence->rbegin(), allocation_sequence->rend(),
2817 [&](const auto& allocation) {
2818 return allocation->memory_space() == MemorySpace::kDefault;
2819 });
2820
2821 if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
2822 prev_allocation_it != allocation_sequence->rend() &&
2823 (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
2824 (*prev_allocation_it)->defining_position() == defining_position &&
2825 !request.allocation_value->requires_contiguous_allocation()) {
2826 // If there was an allocation for this HloValue that was in the alternate
2827 // memory space, we also need to perform an eviction.
2828 Result eviction_result = Evict(request);
2829 if (eviction_result != Result::kSuccess) {
2830 // A non-success eviction requires us to uncommit previous allocations.
2831 return result_mark(Result::kFailRequiresUncommit, eviction_result);
2832 }
2833 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2834 } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
2835 allocation_sequence->push_back(
2836 std::make_unique<MemorySpaceAssignment::Allocation>(
2837 defining_position, MemorySpace::kDefault, /*chunk=*/std::nullopt,
2838 request.start_time, request.end_time,
2839 /*is_scoped_allocation=*/false));
2840 prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
2841 }
2842
2843 CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
2844 CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
2845 MemorySpace::kDefault);
2846
2847 // If the buffer must be in default memory at the end_time, don't prefetch.
2848 if (required_memory_space_at_end == MemorySpace::kDefault) {
2849 VLOG(3)
2850 << "Not trying to prefetch because use requires buffer in default mem.";
2851 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2852 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2853 return Result::kSuccess;
2854 }
2855
2856 // Finally, try to prefetch the buffer into alternate memory.
2857 if (!request.allocation_value->requires_contiguous_allocation()) {
2858 Result prefetch_result =
2859 Prefetch(request, **prev_allocation_in_default_mem_it);
2860 if (prefetch_result == Result::kSuccess) {
2861 return Result::kSuccess;
2862 }
2863 result_mark(prefetch_result, allocation_result);
2864 }
2865
2866 // If the end assignment was required to be in alternate memory but that
2867 // wasn't possible, then this allocation is invalid.
2868 if (required_memory_space_at_end == MemorySpace::kAlternate) {
2869 return result_mark(Result::kFailRequiresUncommit, allocation_result);
2870 }
2871
2872 // If the start assignment was required to be in alternate memory and the
2873 // buffer needs a contiguous assignment, we couldn't satisfy this requirement
2874 // and must abort.
2875 if (required_memory_space_at_start == MemorySpace::kAlternate &&
2876 request.allocation_value->requires_contiguous_allocation()) {
2877 return result_mark(Result::kFailRequiresUncommit, allocation_result);
2878 }
2879
2880 // If a copy wasn't inserted, then add this use to the latest allocation in
2881 // default memory.
2882 (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
2883 (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
2884 return allocation_result;
2885 }
2886
2887 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2888 const MemorySpaceAssignment::Allocation& prev_allocation,
2889 MemorySpace memory_space, std::optional<Chunk> chunk, int64_t start_time,
2890 int64_t end_time, int64_t copy_done_schedule_before_time,
2891 MemorySpaceAssignment::AllocationSequence* allocations,
2892 AliasedOffset* aliased_offset, float resource,
2893 bool is_cross_program_prefetch) {
2894 VLOG(3) << "Copy to "
2895 << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2896 ? "default"
2897 : "alternate")
2898 << " memory between " << start_time << " and "
2899 << copy_done_schedule_before_time << " keeping until " << end_time
2900 << ", estimated copy resource is " << resource;
2901 CHECK_LT(start_time, copy_done_schedule_before_time);
2902
2903 allocations->push_back(
2904 std::make_unique<MemorySpaceAssignment::CopyAllocation>(
2905 prev_allocation, memory_space, chunk, start_time, end_time,
2906 copy_done_schedule_before_time, is_cross_program_prefetch));
2907
2908 // Register the additional async copy with the interval tree to keep track of
2909 // the limit at any given time.
2910 pending_async_copies_.push_back({start_time, copy_done_schedule_before_time,
2911 resource, memory_space,
2912 next_async_copy_id_++});
2913 if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2914 prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2915 kDummyChunk);
2916 prefetch_async_copy_resource_.AddCopy(pending_async_copies_.back());
2917 CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2918 } else {
2919 eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2920 kDummyChunk);
2921 eviction_async_copy_resource_.AddCopy(pending_async_copies_.back());
2922 }
2923 }
2924
2925 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2926 int64_t start_time, int64_t end_time, bool is_prefetch,
2927 int64_t extra_async_copy_limit) const {
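  // A negative limit means the corresponding copy count is unbounded.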
2928 if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2929 return false;
2930 }
2931 if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2932 return false;
2933 }
2934
2935 // Count the prefetches/evictions in the interval tree for the given interval.
2936 if (is_prefetch) {
2937 int64_t num_prefetches =
2938 prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2939 .size();
2940 return num_prefetches >=
2941 options_.max_outstanding_prefetches + extra_async_copy_limit;
2942 } else {
2943 int64_t num_evictions =
2944 eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2945 .size();
2946 return num_evictions >=
2947 options_.max_outstanding_evictions + extra_async_copy_limit;
2948 }
2949 }
2950
2951 AlternateMemoryBestFitHeap::Result
2952 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2953 const AllocationRequest& request) {
2954 MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2955 bool can_eliminate_copy = false;
2956 if (request.allocation_value->allocation_sequence()->empty()) {
2957     // There haven't been any allocations for this interval so far. We can
2958     // eliminate the copy if the value can be placed in the alternate memory.
2959 can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2960 *request.allocation_value->value());
2961 } else {
2962 // If there has been a previous allocation, we can eliminate the copy if the
2963 // previous allocation was also in the alternate memory.
2964 prev_allocation =
2965 request.allocation_value->allocation_sequence()->back().get();
2966 can_eliminate_copy =
2967 (prev_allocation->memory_space() == MemorySpace::kAlternate);
2968 }
2969
2970 if (!can_eliminate_copy) {
2971 return Result::kFailPrevAllocationNotInAlternateMem;
2972 }
2973
2974 const HloPosition& defining_position =
2975 request.allocation_value->defining_position();
2976 if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2977 defining_position.shape(), request.start_time + 1,
2978 request.end_time)) {
2979 return Result::kFailLiveRangeTooLong;
2980 }
2981
2982 BufferInterval alternate_mem_interval;
2983 alternate_mem_interval.buffer = request.allocation_value->value();
2984 alternate_mem_interval.size = request.size;
2985 alternate_mem_interval.end = request.end_time;
2986 alternate_mem_interval.start = request.start_time;
2987
2988   // Prefer the offset that was used by the previous allocation.
2989 AliasedOffset* preferred_offset = nullptr;
2990 if (prev_allocation != nullptr) {
2991 preferred_offset = GetAliasedOffset(*prev_allocation);
2992     // If there is a previous allocation, set the start time to one past the
2993     // previous allocation's end time.
2994 alternate_mem_interval.start = prev_allocation->end_time() + 1;
2995 }
2996
2997 if (request.preferred_offset) {
2998     // Sanity check that if a preferred offset is provided in the request, it
2999     // matches the offset of the previous allocation.
3000 CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
3001 << "preferred_offset = " << preferred_offset->offset
3002 << ", request.preferred_offset = " << request.preferred_offset->offset;
3003 preferred_offset = request.preferred_offset;
3004 }
3005
3006 VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
3007 << (preferred_offset ? preferred_offset->offset : -1);
3008 // In case there are additional uses after this use, we rely on the last use
3009 // time to try to reserve a chunk in the heap simulator. This is to prevent
3010 // the following scenario:
3011 //
3012 // +-------+
3013 // / \
3014 // Producer--->Use1 +-->Use2
3015 // +---------+---------+
3016 // New buffer: | | |
3017 // +---------+---------+
3018 //
3019 // +-----------+
3020 // Current heap: | offset: 0 |
3021 // --------------------------+-----------+------
3022 //
3023 // Because we allocate buffers greedily, Producer to Use1 segment first, and
3024 // then Use1 to Use2 segment, it is possible to allocate the first segment at
3025 // an offset that is available for the first segment (e.g. offset 0) but not
3026 // for the entire live range. This can result in unnecessary copies. By using
3027 // the last use time, we try to find an allocation that is available for the
3028 // entire Producer to Use2 range.
3029 std::optional<Chunk> chunk_candidate = FindBestChunkCandidate(
3030 request, preferred_offset, &alternate_mem_interval);
3031   // Check if the new heap size fits within limits. Also ensure that, if a
3032   // preferred offset was provided, that offset was used.
3033 if (chunk_candidate) {
3034 VLOG(3) << "Keep the buffer in alternate memory. Offset = "
3035 << chunk_candidate->offset << ", size = " << chunk_candidate->size
3036 << ", heap_size = " << result_.UpdatedHeapSize(*chunk_candidate)
3037 << ", prefetch picker = "
3038 << options_.prefetch_interval_picker->ToNoCopyDebugString(
3039 defining_position.shape(), request.start_time,
3040 request.end_time);
3041 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
3042
3043 // If there was a previous allocation, the buffer location is the
3044 // same as the previous. Otherwise, it is the operand.
3045 if (prev_allocation != nullptr &&
3046 (prev_allocation->is_copy_allocation() ||
3047 prev_allocation->defining_position() == defining_position)) {
3048 prev_allocation->Extend(request.end_time);
3049 } else {
3050 request.allocation_value->allocation_sequence()->push_back(
3051 std::make_unique<MemorySpaceAssignment::Allocation>(
3052 defining_position, MemorySpace::kAlternate, chunk_candidate,
3053 request.start_time, request.end_time,
3054 /*is_scoped_allocation=*/false));
3055 CreateOrAddToAliasedOffset(
3056 *request.allocation_value->allocation_sequence()->back(),
3057 preferred_offset);
3058 }
3059 request.allocation_value->allocation_sequence()->back()->AddUse(
3060 request.use->hlo_use);
3061 return Result::kSuccess;
3062 }
3063 return Result::kFailOutOfMemory;
3064 }
3065
3066 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
3067 const AllocationRequest& request) {
3068 CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
3069 MemorySpaceAssignment::Allocation* prev_allocation =
3070 request.allocation_value->allocation_sequence()->back().get();
3071 int64_t eviction_start_time = prev_allocation->start_time();
3072 int64_t eviction_end_time = prev_allocation->end_time();
3073 CHECK(eviction_start_time <= eviction_end_time);
3074
3075 int64_t preferred_eviction_end_time =
3076 std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
3077 request.allocation_value->defining_position().shape(),
3078 eviction_start_time, request.end_time),
3079 eviction_end_time);
3080 // Evictions must complete by the time of this use.
3081 preferred_eviction_end_time =
3082 std::min(preferred_eviction_end_time, request.latest_prefetch_time);
3083
3084 BufferInterval eviction_mem_interval;
3085 eviction_mem_interval.buffer = request.allocation_value->value();
3086 eviction_mem_interval.size = request.size;
3087 // Try to reserve a buffer from the end of the previous allocation to the
3088 // preferred eviction end time.
3089 eviction_mem_interval.start = eviction_end_time + 1;
3090 eviction_mem_interval.end = preferred_eviction_end_time;
3091 int64_t preferred_offset = prev_allocation->chunk().offset;
3092 VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
3093 << ") preferred end time = " << eviction_mem_interval.end;
3094
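  // Walk the requested end time back toward the actual eviction end time until
  // the heap offers the preferred offset again; committing that chunk keeps the
  // buffer at the same alternate-memory offset while the eviction completes.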
3095 for (; eviction_mem_interval.end > eviction_end_time;
3096 --eviction_mem_interval.end) {
3097 Chunk chunk_candidate =
3098 FindChunkCandidate(eviction_mem_interval, preferred_offset);
3099 if (chunk_candidate.offset == preferred_offset) {
3100 AddToPendingChunks(eviction_mem_interval, chunk_candidate);
3101 break;
3102 }
3103 }
3104 eviction_end_time = eviction_mem_interval.end;
3105
3106 VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
3107 << eviction_start_time << ", " << eviction_end_time << ")";
3108
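  // Without a cost analysis, fall back to a nominal copy resource of 0.1.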
3109 float eviction_resource =
3110 options_.cost_analysis
3111 ? options_.cost_analysis->GetAsyncCopyElapsed(
3112 request.allocation_value->defining_position().shape())
3113 : 0.1;
3114
3115 bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
3116 bool eviction_violates_resource =
3117 !eviction_async_copy_resource_.HasEnoughResource(
3118 eviction_start_time, eviction_end_time, eviction_resource);
3119 if (eviction_violates_resource) {
3120 // If we're in the last retry, set resource to 0.
3121 if (options_.prefetch_interval_picker->retry_number() ==
3122 options_.max_retries - 1) {
3123 VLOG(3) << "Violates resource in last retry, setting resource = 0";
3124 eviction_resource = 0;
3125 }
3126 eviction_violates_resource =
3127 !eviction_async_copy_resource_.HasEnoughResource(
3128 eviction_start_time, eviction_end_time, eviction_resource);
3129 }
3130 bool eviction_violates_outstanding_copies =
3131 ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
3132 eviction_end_time,
3133 /*is_prefetch=*/false);
3134
3135 // See if this interval would violate the asynchronous copy limit.
3136 if (!eviction_interval_too_short && !eviction_violates_outstanding_copies &&
3137 !eviction_violates_resource) {
3138 prev_allocation->Extend(eviction_end_time);
3139 AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
3140 /*chunk=*/std::nullopt, eviction_start_time,
3141 prev_allocation->end_time(), eviction_end_time,
3142 request.allocation_value->allocation_sequence(),
3143 /*aliased_offset=*/nullptr, eviction_resource);
3144 } else {
3145 if (eviction_violates_outstanding_copies) {
3146 VLOG(3) << "This violates the maximum async copies.";
3147 } else if (eviction_violates_resource) {
3148 VLOG(3) << "This violates resource.";
3149 } else {
3150 VLOG(3) << "Eviction interval is too short (" << eviction_start_time
3151 << ", " << eviction_end_time << ").";
3152 }
3153 // If the original interval violated the limit, try sub-intervals within
3154 // this interval.
3155 bool eviction_scheduled = false;
3156
3157 if (!eviction_scheduled) {
3158 // If the eviction couldn't be scheduled, then fail. This buffer will be
3159 // kept in the default memory.
3160 VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
3161 << " because we hit the limit of maximum asynchronous copies "
3162 << "between "
3163 << hlo_live_range_.flattened_instruction_sequence()
3164 .instructions()[eviction_start_time]
3165 << " and "
3166 << hlo_live_range_.flattened_instruction_sequence()
3167 .instructions()[eviction_end_time];
3169 return Result::kFailOutOfAsyncCopies;
3170 }
3171 }
3173 return Result::kSuccess;
3174 }
3175
3176 int64_t AlternateMemoryBestFitHeap::FindPrefetchEndTime(
3177 const AllocationRequest& request, int64_t earliest_prefetch_time) const {
3178 return request.latest_prefetch_time;
3179 }
3180
3181 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
3182 const AllocationRequest& request,
3183 const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
3184 // Try partially placing the buffer in the alternate space. The time that is
3185 // overlapped will be used to asynchronously copy the buffer from the
3186 // default memory to the alternate memory.
3187 //
3188 // start end
3189 // time time
3190 // X---------------------X
3191 // Alternate: +------+
3192 // Default: +---------------------+
3193 // ^ ^
3194 // Copy Copy
3195 // Start Done
3196 int64_t earliest_prefetch_time =
3197 prev_allocation_in_default_mem.earliest_available_time();
3198 if (request.earliest_prefetch_time) {
3199 earliest_prefetch_time =
3200 std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
3201 }
3202 int64_t prefetch_end_time =
3203 FindPrefetchEndTime(request, earliest_prefetch_time);
3204
3205 // As a compile time optimization, use the peak memory usage to filter out
3206 // allocation times that would push us to OOM.
3207 std::optional<int> earliest_non_oom_prefetch_time =
3208 FindEarliestTimeToSatisfyPeakMemory(earliest_prefetch_time,
3209 prefetch_end_time, request.size);
3210 Result result = Result::kSuccess;
3211 if (!earliest_non_oom_prefetch_time) {
3212 VLOG(3) << "Any prefetch in range (" << earliest_prefetch_time << ", "
3213 << prefetch_end_time << ") for size " << request.size
3214 << " would go out of memory.";
3215 result_mark(Result::kFailOutOfMemory, result);
3216 return result;
3217 }
3218 VLOG(4) << "After peak memory check, prefetch range is ("
3219 << *earliest_non_oom_prefetch_time << ", " << prefetch_end_time
3220 << "). Original earliest prefetch time is " << earliest_prefetch_time;
3221 earliest_prefetch_time = *earliest_non_oom_prefetch_time;
3222 options_.prefetch_interval_picker->Begin(
3223 request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
3224 VLOG(3) << "Trying prefetch picker = "
3225 << options_.prefetch_interval_picker->ToDebugString();
3226
3227 // Create an alternate memory interval that starts at the earliest
3228 // possible position, given by max_prefetch_interval.
3229 BufferInterval alternate_mem_interval;
3230 alternate_mem_interval.buffer = request.allocation_value->value();
3231 alternate_mem_interval.size = request.size;
3232 // As a compile time optimization, try a prefetch allocation that is as late
3233 // as possible. If this is not able to find a chunk candidate, none of the
3234 // earlier tries will succeed either.
3235 alternate_mem_interval.start =
3236 options_.prefetch_interval_picker->latest_time();
3237 auto chunk_candidate = FindBestChunkCandidate(
3238 request, request.preferred_offset, &alternate_mem_interval);
3239 if (!chunk_candidate) {
3240 VLOG(3) << "The latest prefetch (" << alternate_mem_interval.start << ", "
3241 << request.end_time << ") cannot find a valid chunk. Giving up.";
3242 result_mark(Result::kFailOutOfMemory, result);
3243 return result;
3244 }
3245 const HloUse& use = request.use->hlo_use;
3246 const Shape& shape = ShapeUtil::GetSubshape(
3247 use.instruction->operand(use.operand_number)->shape(), use.operand_index);
3248   // Uses by while ops might be allowed additional outstanding prefetches.
3249 int64_t extra_async_copy_limit =
3250 request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
3251 ? options_.while_use_extra_outstanding_prefetch_limit
3252 : 0;
3253   // As a compilation time optimization, store the prefetch start time where we
3254   // first ran out of memory. There is no point in exploring prefetch start
3255   // times earlier than this point.
3256 std::optional<int64_t> out_of_mem_start;
3257 while (!options_.prefetch_interval_picker->Done()) {
3258 alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
3259 CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
3260 if (out_of_mem_start.has_value() &&
3261 alternate_mem_interval.start <= *out_of_mem_start) {
3262 VLOG(4) << "This would OOM (cached).";
3263 result_mark(Result::kFailOutOfMemory, result);
3264 continue;
3265 }
3266 int64_t estimated_prefetch_end_time =
3267 options_.prefetch_interval_picker->EstimatedPrefetchEndTime(
3268 shape, alternate_mem_interval.start, prefetch_end_time);
3269 VLOG(4) << "Trying alternate memory allocation ("
3270 << alternate_mem_interval.start << ", " << request.end_time
3271 << "), estimated prefetch end time = "
3272 << estimated_prefetch_end_time;
3273 float prefetch_resource =
3274 options_.cost_analysis
3275 ? options_.cost_analysis->GetAsyncCopyElapsed(shape)
3276 : 0.1;
3277 if (!prefetch_async_copy_resource_.HasEnoughResource(
3278 alternate_mem_interval.start, prefetch_end_time,
3279 prefetch_resource)) {
3280 VLOG(4) << "This would violate asynchronous copy resource = "
3281 << prefetch_resource;
3282 result_mark(Result::kFailViolatesAsyncCopyResource, result);
3283 continue;
3284 }
3285 if (ViolatesMaximumOutstandingAsyncCopies(
3286 alternate_mem_interval.start, prefetch_end_time,
3287 /*is_prefetch=*/true, extra_async_copy_limit)) {
3288 VLOG(4) << "This would violate the outstanding async copy limit.";
3289 result_mark(Result::kFailOutOfAsyncCopies, result);
3290 continue;
3291 }
3292
3293 auto chunk_candidate = FindBestChunkCandidate(
3294 request, request.preferred_offset, &alternate_mem_interval);
3295 // Check if we could find a suitable chunk.
3296 if (chunk_candidate) {
3297 VLOG(3) << "Move the buffer to alternate memory at "
3298 << alternate_mem_interval.start
3299 << ". Offset = " << chunk_candidate->offset
3300 << ", size = " << chunk_candidate->size
3301 << ", heap_size = " << result_.UpdatedHeapSize(*chunk_candidate)
3302 << ", prefetch picker = "
3303 << options_.prefetch_interval_picker->ToDebugString();
3304 AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
3305
3306 AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
3307 chunk_candidate, alternate_mem_interval.start,
3308 request.end_time, prefetch_end_time,
3309 request.allocation_value->allocation_sequence(),
3310 request.preferred_offset, prefetch_resource);
3311
3312 request.allocation_value->allocation_sequence()->back()->AddUse(
3313 request.use->hlo_use);
3314 return Result::kSuccess;
3315 } else {
3316 // Mark the out of memory start with the prefetch start time so that we
3317 // don't explore prefetch start times earlier than this point.
3318 out_of_mem_start =
3319 std::max(out_of_mem_start.has_value() ? *out_of_mem_start : -1,
3320 alternate_mem_interval.start);
3321 }
3322 result_mark(Result::kFailOutOfMemory, result);
3323 }
3324 // If we didn't consider any prefetch intervals, then the live range was too
3325 // short.
3326 if (result == Result::kSuccess) {
3327 return Result::kFailLiveRangeTooShort;
3328 } else {
3329 return result;
3330 }
3331 }
3332
3333 std::optional<AlternateMemoryBestFitHeap::Chunk>
3334 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
3335 const AllocationRequest& request, const AliasedOffset* preferred_offset,
3336 BufferInterval* alternate_mem_interval) const {
3337 int64_t end_time = request.end_time;
3338 if (!preferred_offset) {
3339     // First find the earliest use that is at or later than the end time.
3340 const auto& use_times = request.all_use_times;
3341 auto use_time_it = absl::c_lower_bound(use_times, end_time);
3342 CHECK(use_time_it != use_times.end());
3343 int64_t earliest_use = *use_time_it;
3344 auto earliest_use_it = use_time_it;
3345
3346 // Then find the latest use that can be allocated contiguously without
3347 // copies.
3348 const Shape& shape = request.allocation_value->defining_position().shape();
3349 for (;
3350 (use_time_it + 1) != use_times.end() &&
3351 options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
3352 shape, *use_time_it, *(use_time_it + 1));
3353 ++use_time_it) {
3354 }
3355 CHECK(use_time_it != use_times.end());
3356 int64_t latest_contiguous_use_time = *use_time_it;
3357
3358     // Find a chunk that is as long-lived as possible.
3359 std::optional<Chunk> last_chunk_candidate;
3360 int64_t latest_matching_use = std::numeric_limits<int64_t>::min();
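    // std::lower_bound is used as a binary search over the use times: the
    // predicate tentatively extends the interval to each probed use time and
    // checks whether a chunk still fits in the heap, recording the latest use
    // for which a fitting chunk was found.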
3361 std::lower_bound(
3362 earliest_use_it, std::next(use_time_it), -1, [&](int64_t use, int64_t) {
3363 alternate_mem_interval->end = use;
3364 Chunk chunk_candidate = FindChunkCandidate(*alternate_mem_interval);
3365 if (chunk_candidate.chunk_end() <= available_heap_size()) {
3366 if (use > latest_matching_use) {
3367 last_chunk_candidate = chunk_candidate;
3368 latest_matching_use = use;
3369 }
3370 return true;
3371 }
3372 return false;
3373 });
3374 if (last_chunk_candidate.has_value()) {
3375 VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
3376 << ", latest contiguous use = " << latest_contiguous_use_time
3377 << ", use with available mem = " << latest_matching_use
3378 << ", offset = " << last_chunk_candidate->offset;
3379 }
3380 alternate_mem_interval->end = end_time;
3381 return last_chunk_candidate;
3382 }
3383 // If a preferred offset is given, try to find an allocation at that offset
3384 // only.
3385 alternate_mem_interval->end = end_time;
3386 Chunk chunk_candidate =
3387 FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
3388 if (chunk_candidate.offset == preferred_offset->offset) {
3389 return chunk_candidate;
3390 }
3391 return std::nullopt;
3392 }
3393
3394 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
3395 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
3396 AsyncCopyStats stats;
3397 stats.max_outstanding_async_copies = 0;
3398 stats.num_prefetches = 0;
3399 stats.prefetch_bytes = 0;
3400 stats.num_evictions = 0;
3401 stats.eviction_bytes = 0;
3402 int64_t current_copies = 0;
3403 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
3404 HloDataflowAnalysis::Run(*module_));
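  // Walk the schedule counting outstanding copy-start/copy-done pairs, and
  // classify each completed copy as a prefetch or an eviction based on the
  // memory space of its destination.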
3405 for (const HloComputation* computation :
3406 module_->MakeNonfusionComputations()) {
3407 for (HloInstruction* instruction : computation->instructions()) {
3408 if (instruction->opcode() == HloOpcode::kCopyStart) {
3409 current_copies++;
3410 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
3411 current_copies--;
3412 int64_t size =
3413 options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
3414 if (instruction->shape().layout().memory_space() ==
3415 options_.alternate_memory_space) {
3416 ++stats.num_prefetches;
3417 stats.prefetch_bytes += size;
3418 } else {
3419 ++stats.num_evictions;
3420 stats.eviction_bytes += size;
3421 }
3422 }
3423 stats.max_outstanding_async_copies =
3424 std::max(stats.max_outstanding_async_copies, current_copies);
3425 }
3426 }
3427 return stats;
3428 }
3429
3430 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
3431 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
3432 const MemorySpaceAssignmentCostAnalysis& cost_analysis,
3433 MemorySpaceAssignmentCostAnalysis::Cache* cache) {
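  // Order buffers by decreasing memory boundedness so that the buffers that
  // benefit most from the alternate (fast) memory are considered first.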
3434 return [&cost_analysis, cache](const BufferInterval& x,
3435 const BufferInterval& y) {
3436 float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
3437 float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
3438 if (x_memory_boundedness != y_memory_boundedness) {
3439 return x_memory_boundedness > y_memory_boundedness;
3440 }
3441 // Tie-break if the memory boundedness is the same.
3442 return GlobalDecreasingSizeBestFitHeap<
3443 HloValue>::GetSpatialBufferIntervalCompare()(x, y);
3444 };
3445 }
3446
3447 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
3448 MemorySpaceAssignment::Run(HloModule* module,
3449 const HloLiveRange& hlo_live_range,
3450 const HloAliasAnalysis& alias_analysis,
3451 const Options& options) {
3452 CHECK(module->has_schedule());
3453 VLOG(3) << "Module before memory space assignment: ";
3454 XLA_VLOG_LINES(3, module->ToString());
3455 VLOG(3) << "Schedule: " << module->schedule().ToString();
3456 MemorySpaceAssignment memory_space_assignment(module, options,
3457 hlo_live_range);
3458
3459 return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
3460 alias_analysis);
3461 }
3462
3463 StatusOr<std::unique_ptr<PresetAssignments>>
3464 MemorySpaceAssignment::RunMemorySpaceAssignment(
3465 const HloLiveRange& hlo_live_range,
3466 const HloAliasAnalysis& alias_analysis) {
3467 TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
3468
3469 if (options_.cost_analysis) {
3470 float estimated_time =
3471 ComputeEstimatedElapsedTime(hlo_live_range, allocations_);
3472 VLOG(1) << "Estimated elapsed time (sec): " << estimated_time;
3473 }
3474
3475 TF_RETURN_IF_ERROR(Process());
3476 ScheduleAsynchronousCopies();
3477 TF_RETURN_IF_ERROR(SimplifyGraph());
3478 TF_RETURN_IF_ERROR(FixSchedule());
3479 TF_RETURN_IF_ERROR(ExportAndColorBuffers());
3480
3481 VLOG(3) << "Module after memory space assignment: ";
3482 XLA_VLOG_LINES(3, module_->ToString());
3483 TF_CHECK_OK(module_->schedule().Verify());
3484 TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
3485 VLOG(1) << "Maximum number of outstanding async copies: "
3486 << stats.max_outstanding_async_copies;
3487 VLOG(1) << "Number of prefetches: " << stats.num_prefetches
3488 << ", in bytes: " << stats.prefetch_bytes;
3489 VLOG(1) << "Number of evictions: " << stats.num_evictions
3490 << ", in bytes: " << stats.eviction_bytes;
3491
3492 TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());
3493
3494 return std::move(preset_assignments_);
3495 }
3496
3497 Status MemorySpaceAssignment::FindAllocationSequence(
3498 const HloLiveRange& hlo_live_range,
3499 const HloAliasAnalysis& alias_analysis) {
3500 auto algorithm = std::make_unique<AlternateMemoryBestFitHeap>(
3501 &allocations_, options_, alias_analysis, hlo_live_range);
3502
3503 HeapSimulator::Options heap_simulator_options;
3504 heap_simulator_options.may_reuse_operand_buffers = false;
3505 heap_simulator_options.alloc_constants = true;
3506 TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
3507 module_->schedule(), alias_analysis,
3508 options_.size_fn,
3509 heap_simulator_options)
3510 .status());
3511 return OkStatus();
3512 }
3513
3514 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
3515 HloInstruction* operand =
3516 use.instruction->mutable_operand(use.operand_number);
3517 // If the use is a tuple, look inside the tuple to find the actual use.
3518 for (int64_t index : use.operand_index) {
3519 if (operand->opcode() != HloOpcode::kTuple) {
3520 break;
3521 }
3522 operand = operand->mutable_operand(index);
3523 }
3524
3525 // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
3526 std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
3527 get_simplified_operand = [&](HloInstruction* instruction) {
3528 while (instruction->opcode() == HloOpcode::kGetTupleElement) {
3529 HloInstruction* operand =
3530 get_simplified_operand(instruction->mutable_operand(0));
3531 if (operand->opcode() == HloOpcode::kTuple) {
3532 instruction = operand->mutable_operand(instruction->tuple_index());
3533 } else {
3534 return instruction;
3535 }
3536 }
3537 return instruction;
3538 };
3539 operand = get_simplified_operand(operand);
3540
3541 uses_.push_back(use);
3542 }
3543
3544 float MemorySpaceAssignment::ComputeEstimatedElapsedTime(
3545 const HloLiveRange& hlo_live_range, const AllocationSequence& allocations) {
3546 absl::flat_hash_map<const HloInstruction*, std::vector<ShapeIndex>>
3547 outputs_in_alternate_memory_map;
3548 absl::flat_hash_map<const HloInstruction*,
3549 std::vector<std::pair<int64_t, ShapeIndex>>>
3550 operands_in_alternate_memory_map;
3551
3552 for (auto& allocation : allocations) {
3553 if (!allocation->is_copy_allocation()) {
3554 if (allocation->memory_space() == MemorySpace::kAlternate) {
3555 const HloInstruction* defining_instruction =
3556 allocation->defining_position().instruction;
3557 outputs_in_alternate_memory_map[defining_instruction].push_back(
3558 allocation->defining_position().index);
3559 }
3560 }
3561 for (auto& hlo_use : allocation->uses()) {
3562 const HloInstruction* use_instruction = hlo_use.instruction;
3563 operands_in_alternate_memory_map[use_instruction].push_back(
3564 std::make_pair(hlo_use.operand_number, hlo_use.operand_index));
3565 }
3566 }
3567
3568 const auto& instruction_sequence =
3569 hlo_live_range.flattened_instruction_sequence().instructions();
3570 float total_elapsed = 0.0;
3571 for (const HloInstruction* instruction : instruction_sequence) {
3572 std::vector<ShapeIndex> outputs_in_alternate_memory;
3573 auto output_it = outputs_in_alternate_memory_map.find(instruction);
3574 if (output_it != outputs_in_alternate_memory_map.end()) {
3575 outputs_in_alternate_memory = output_it->second;
3576 }
3577 std::vector<std::pair<int64_t, ShapeIndex>> operands_in_alternate_memory;
3578 auto operand_it = operands_in_alternate_memory_map.find(instruction);
3579 if (operand_it != operands_in_alternate_memory_map.end()) {
3580 operands_in_alternate_memory = operand_it->second;
3581 }
3582 float instruction_elapsed =
3583 options_.cost_analysis->GetInstructionElapsedInAlternateMemory(
3584 *instruction, operands_in_alternate_memory,
3585 outputs_in_alternate_memory);
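    // Scale the elapsed time by the assumed while-loop execution count raised
    // to the instruction's while-nest depth.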
3586 float while_nest_multiplier = IPow<float>(
3587 options_.xla_tpu_memory_space_assignment_while_execution_count,
3588 options_.cost_analysis->CalculateComputationNestLevel(
3589 instruction,
3590 /*while_only=*/true));
3591 total_elapsed += while_nest_multiplier * instruction_elapsed;
3592 }
3593 return total_elapsed;
3594 }
3595
3596 Status MemorySpaceAssignment::Allocation::Process() {
3597 if (is_scoped_allocation()) {
3598 // Nothing to do here for scoped allocations.
3599 return OkStatus();
3600 }
3601 HloInstruction* producing_instruction = AddGetTupleElements();
3602 HloComputation* computation = producing_instruction->parent();
3603 for (const HloUse& use : uses_) {
3604 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3605 HloInstruction* replacement_instruction = producing_instruction;
3606 if (operand_shape.IsTuple()) {
3607 TF_ASSIGN_OR_RETURN(
3608 replacement_instruction,
3609 TupleUtil::ReplaceTupleWith(
3610 producing_instruction,
3611 use.instruction->mutable_operand(use.operand_number),
3612 use.operand_index));
3613 } else if (operand_shape != producing_instruction->shape()) {
3614 VLOG(4) << "Old shape = " << operand_shape.ToString()
3615 << ", new shape = " << producing_instruction->shape().ToString()
3616 << "; inserting a bitcast.";
3617 replacement_instruction = computation->AddInstruction(
3618 HloInstruction::CreateBitcast(operand_shape, producing_instruction));
3619 }
3620 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3621 use.operand_number, replacement_instruction));
3622 }
3623 return OkStatus();
3624 }
3625
3626 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() const {
3627 CHECK_NE(defining_position().instruction, nullptr);
3628
3629 Shape shape = defining_position().shape();
3630 CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
3631 << shape.ToString()
3632 << " position = " << defining_position().shape();
3633 return TupleUtil::AddGetTupleElements(defining_position());
3634 }
3635
3636 std::string MemorySpaceAssignment::Allocation::ToString() const {
3637 std::string memory_space_str = "def";
3638 if (memory_space_ == MemorySpace::kAlternate) {
3639 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3640 }
3641 return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""),
3642 "Allocation in ", memory_space_str, " defined at ",
3643 defining_position_.ToString(),
3644 ", start_time:", start_time(), ", end_time:", end_time(),
3645 ", uses: ", UsesToString(uses()));
3646 }
3647
3648 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
3649 std::string memory_space_str = "def";
3650 if (memory_space_ == MemorySpace::kAlternate) {
3651 memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
3652 }
3653 return absl::StrCat("Copy Allocation in ", memory_space_str,
3654 ", start_time:", start_time(), ", end_time:", end_time(),
3655 ", copy_start_after_time: ", copy_start_schedule_after(),
3656 ", copy_done_before_time: ", copy_done_schedule_before(),
3657 ", uses: ", UsesToString(uses()), ", from ",
3658 prev_allocation_.ToString());
3659 }
3660
3661 std::string MemorySpaceAssignment::MirroredAllocation::ToString() const {
3662 return absl::StrCat("Mirrored Allocation for ",
3663 original_allocation_.ToString());
3664 }
3665
3666 std::string MemorySpaceAssignment::ParentAllocation::ToString() const {
3667 return absl::StrCat("Parent Allocation mirrored at ",
3668 defining_position_.ToString(), ", originally ",
3669 original_allocation_.ToString());
3670 }
3671
3672 Status MemorySpaceAssignment::CopyAllocation::Process() {
3673 // Copy allocations need to insert asynchronous copy nodes.
3674 Shape shape = defining_position().shape();
3675 HloInstruction* producing_instruction = AddGetTupleElements();
3676 HloComputation* computation = producing_instruction->parent();
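  // CopyStart produces a (destination, source, context) tuple; the matching
  // CopyDone extracts the destination buffer with the copied data.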
3677 copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
3678 ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
3679 producing_instruction, is_cross_program_prefetch_));
3680 copy_done_ = computation->AddInstruction(
3681 HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
3682 VLOG(4) << "Created " << copy_start_->name()
3683 << " for copy allocation: " << ToString();
3684 // Update the allocation position with the copy done instruction so that any
3685 // further copies from this allocation can find the correct position.
3686 defining_position_ = HloPosition{copy_done_, {}};
3687
3688 // Replace all the uses with the new copy instruction.
3689 for (HloUse use : uses_) {
3690 // If the operand is a tuple, we need to descend to the actual instruction
3691 // we want to replace.
3692 HloInstruction* replacement_instruction;
3693 Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
3694 if (operand_shape.IsTuple()) {
3695 TF_ASSIGN_OR_RETURN(
3696 replacement_instruction,
3697 TupleUtil::ReplaceTupleWith(
3698 copy_done_, use.instruction->mutable_operand(use.operand_number),
3699 use.operand_index));
3700 } else if (operand_shape != copy_done_->shape()) {
3701 VLOG(4) << "Old shape = " << operand_shape.ToString()
3702 << ", new shape = " << copy_done_->shape().ToString()
3703 << "; inserting a bitcast.";
3704 replacement_instruction = computation->AddInstruction(
3705 HloInstruction::CreateBitcast(operand_shape, copy_done_));
3706 } else {
3707 replacement_instruction = copy_done_;
3708 }
3709 TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
3710 use.operand_number, replacement_instruction));
3711 }
3712
3713 return OkStatus();
3714 }
3715
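// A mirrored allocation adopts the defining position of the allocation it
// mirrors; the rest of the processing is the base Allocation logic (replacing
// uses, inserting bitcasts or tuple rewrites as needed).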
3716 Status MemorySpaceAssignment::MirroredAllocation::Process() {
3717 defining_position_ = original_allocation_.defining_position();
3718 return Allocation::Process();
3719 }
3720
3721 Status MemorySpaceAssignment::ParentAllocation::Process() {
3722 // Add an additional parameter to the while HLO with a reference to the buffer
3723 // in the default memory space.
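// For example (illustrative shapes): a while op whose operand tuple is
// (f32[2], pred[]) gets a third element appended that refers to the
// default-memory copy of the value, and the while condition/body parameter
// shapes are widened to match. ExtractPrefix below then rebuilds a tuple of
// the original shape so existing users of the while op are unaffected.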
3724 HloInstruction* producing_instruction =
3725 original_allocation_.AddGetTupleElements();
3726 int new_tuple_index = calling_instruction_->shape().tuple_shapes_size();
3727
3728 TF_ASSIGN_OR_RETURN(
3729 HloInstruction * new_while_operand,
3730 TupleUtil::ReplaceTupleWith(producing_instruction,
3731 calling_instruction_->mutable_operand(0),
3732 {new_tuple_index}));
3733 TF_RETURN_IF_ERROR(calling_instruction_->ReplaceOperandWithDifferentShape(
3734 0, new_while_operand));
3735 *calling_instruction_->mutable_shape() = new_while_operand->shape();
3736 *calling_instruction_->while_condition()
3737 ->parameter_instruction(0)
3738 ->mutable_shape() = new_while_operand->shape();
3739 *calling_instruction_->while_body()
3740 ->parameter_instruction(0)
3741 ->mutable_shape() = new_while_operand->shape();
3742 defining_position_.index = {new_tuple_index};
3743 // Also replace the while op with a tuple that has the old shape. Note that we
3744 // need to first take a snapshot of the users before calling ExtractPrefix
3745 // since ExtractPrefix introduces additional gte users.
3746 std::vector<HloInstruction*> while_users = calling_instruction_->users();
3747 HloInstruction* tuple_with_old_shape =
3748 TupleUtil::ExtractPrefix(calling_instruction_, new_tuple_index);
3749 TF_RETURN_IF_ERROR(calling_instruction_->ReplaceAllUsesWithDifferentShape(
3750 while_users, tuple_with_old_shape));
3751 return Allocation::Process();
3752 }
3753
3754 Status MemorySpaceAssignment::ParentAllocation::PostProcess() {
3755 // Update the root of the while body with the new parameter. We need a separate
3756 // post-process step for this because other allocations may have the while body
3757 // root as a use, and they would otherwise update the old root instead of the
3758 // new root. Doing the post-process step later ensures the root has already been
3759 // updated with those other changes, so we can safely add the additional parameter.
3760 HloComputation* while_body = calling_instruction_->while_body();
3761 TF_ASSIGN_OR_RETURN(HloInstruction * new_while_body_root,
3762 TupleUtil::ReplaceTupleWith(
3763 AddGetTupleElements(), while_body->root_instruction(),
3764 defining_position_.index));
3765 while_body->set_root_instruction(new_while_body_root,
3766 /*accept_different_shape=*/true);
3767 return OkStatus();
3768 }
3769
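// Needed-ness is computed in two steps: MarkIfNeeded decides whether an
// allocation should be kept at all, and MarkNeeded records it (plus, in
// subclasses, any allocation it depends on) in needed_allocations. A plain
// Allocation is always considered needed.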
3770 void MemorySpaceAssignment::Allocation::MarkIfNeeded(
3771 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3772 MarkNeeded(needed_allocations);
3773 }
3774
3775 void MemorySpaceAssignment::Allocation::MarkNeeded(
3776 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3777 needed_allocations.insert(this);
3778 }
3779
3780 void MemorySpaceAssignment::CopyAllocation::MarkNeeded(
3781 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3782 needed_allocations.insert(this);
3783 prev_allocation_.MarkNeeded(needed_allocations);
3784 }
3785
3786 void MemorySpaceAssignment::ParentAllocation::MarkIfNeeded(
3787 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3788 // Parent allocations are only needed if they have any uses or if there is a
3789 // copy allocation that copies this value (in that case, the copy allocation
3790 // will call this allocation's MarkNeeded function).
3791 if (!uses_.empty()) {
3792 MarkNeeded(needed_allocations);
3793 }
3794 }
3795
3796 void MemorySpaceAssignment::ParentAllocation::MarkNeeded(
3797 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3798 needed_allocations.insert(this);
3799 original_allocation_.MarkNeeded(needed_allocations);
3800 }
3801
3802 void MemorySpaceAssignment::MirroredAllocation::MarkNeeded(
3803 absl::flat_hash_set<const Allocation*>& needed_allocations) const {
3804 needed_allocations.insert(this);
3805 original_allocation_.MarkNeeded(needed_allocations);
3806 }
3807
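// Top-level driver over the computed allocations: first determine which
// allocations are actually needed, then materialize each needed allocation
// (e.g. by inserting CopyStart/CopyDone pairs) and record its alternate-memory
// chunk, and finally run PostProcess (used by parent allocations to patch the
// while-body root).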
3808 Status MemorySpaceAssignment::Process() {
3809 VLOG(1) << "Processing assigned buffers...";
3810 // Since some parent allocations may not be needed (e.g. when they don't
3811 // have any uses and no other (non-parent) allocation depends on them),
3812 // mark all allocations that are needed before we process the
3813 // allocations.
3814 absl::flat_hash_set<const Allocation*> needed_allocations;
3815 for (auto& allocation : allocations_) {
3816 allocation->MarkIfNeeded(needed_allocations);
3817 }
3818 // Insert CopyStart/CopyDone pairs.
3819 for (auto& allocation : allocations_) {
3820 VLOG(3) << "Processing: " << allocation->ToString();
3821 if (!needed_allocations.contains(allocation.get())) {
3822 VLOG(3) << "Allocation not needed.";
3823 continue;
3824 }
3825 TF_RETURN_IF_ERROR(allocation->Process());
3826 // Add the offset and size of the allocation in the alternate memory to
3827 // the output map.
3828 if (allocation->is_scoped_allocation()) {
3829 CHECK(allocation->memory_space() == MemorySpace::kAlternate);
3830 scoped_memory_assignments_.emplace_back(
3831 allocation->defining_position().instruction, allocation->chunk());
3832 alternate_memory_size_ =
3833 std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3834 } else if (allocation->memory_space() == MemorySpace::kAlternate) {
3835 alternate_memory_assignments_.emplace_back(
3836 allocation->defining_position(), allocation->chunk());
3837 alternate_memory_size_ =
3838 std::max(alternate_memory_size_, allocation->chunk().chunk_end());
3839 }
3840 }
3841 // Post-process allocations. This is only used for parent allocations where we
3842 // update the body root with a reference to the buffer in default memory
3843 // space.
3844 for (auto& allocation : allocations_) {
3845 if (needed_allocations.contains(allocation.get())) {
3846 VLOG(3) << "Post-Processing: " << allocation->ToString();
3847 TF_RETURN_IF_ERROR(allocation->PostProcess());
3848 }
3849 }
3850 return OkStatus();
3851 }
3852
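// Exports the collected alternate-memory assignments into preset_assignments_,
// checking that positions aliasing the same buffer agree on their offsets, and
// then colors the layouts of all aliased positions with the alternate memory
// space.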
3853 Status MemorySpaceAssignment::ExportAndColorBuffers() {
3854 VLOG(1) << "Exporting buffers...";
3855 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
3856 absl::flat_hash_map<int64_t, int64_t> seen_buffer_offsets;
3857 VLOG(3) << "Exported alternate memory allocations:";
3858 for (const auto& position_and_chunk : alternate_memory_assignments_) {
3859 const HloPosition& defining_position = position_and_chunk.first;
3860 const Chunk& chunk = position_and_chunk.second;
3861 const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
3862 defining_position.instruction, defining_position.index);
3863 auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
3864 if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
3865 CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
3866 << "Mismatch in offset for positions that map to the same value: "
3867 << buffer.ToString() << ", pos: " << defining_position.ToString();
3868 } else {
3869 VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3870 << "] : " << defining_position.ToString() << " ("
3871 << buffer.ToString() << ")";
3872 preset_assignments_->add_chunk(defining_position, chunk);
3873 seen_buffer_offsets[buffer.id()] = chunk.offset;
3874 }
3875 }
3876
3877 VLOG(3) << "Exported scoped allocations in alternate memory:";
3878 for (const auto& instruction_and_chunk : scoped_memory_assignments_) {
3879 HloInstruction* instruction = instruction_and_chunk.first;
3880 const Chunk& chunk = instruction_and_chunk.second;
3881 VLOG(3) << " [" << chunk.offset << ", " << chunk.size
3882 << "] : " << instruction->name();
3883 preset_assignments_->add_scoped_allocation_chunk(instruction, chunk);
3884 }
3885
3886 if (!preset_assignments_->chunks().empty() ||
3887 !preset_assignments_->scoped_allocation_chunks().empty()) {
3888 preset_assignments_
3889 ->assignment_information_for_space(options_.alternate_memory_space)
3890 ->size = alternate_memory_size_;
3891 }
3892
3893 VLOG(3) << "Exported alternate memory sizes:";
3894 for (auto& pair : preset_assignments_->assignment_informations()) {
3895 VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size;
3896 }
3897
3898 VLOG(1) << "Coloring buffers...";
3899 // Color the pending positions and all of their aliased buffers.
3900 for (const auto& defining_position_and_chunk :
3901 preset_assignments_->chunks()) {
3902 const HloPosition& defining_position = defining_position_and_chunk.first;
3903 for (auto& buffer : alias_analysis->ComputeBuffersAt(
3904 defining_position.instruction, defining_position.index)) {
3905 for (auto& value : buffer->values()) {
3906 for (auto& position : value->positions()) {
3907 VLOG(4) << "Coloring " << position.ToString();
3908 Shape* shape = ShapeUtil::GetMutableSubshape(
3909 position.instruction->mutable_shape(), position.index);
3910 CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
3911 << position.ToString();
3912 shape->mutable_layout()->set_memory_space(
3913 options_.alternate_memory_space);
3914 }
3915 }
3916 }
3917 }
3918 return OkStatus();
3919 }
3920
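// Removes any alternate-memory assignment whose defining position refers to
// the given instruction, using swap-with-back-and-pop since the order of the
// entries is not significant.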
3921 void MemorySpaceAssignment::RemoveAssignmentForInstruction(
3922 const HloInstruction* instruction) {
3923 auto it = alternate_memory_assignments_.begin();
3924 auto end = alternate_memory_assignments_.end();
3925 while (it != end) {
3926 const HloPosition& position = it->first;
3927 if (position.instruction == instruction) {
3928 VLOG(3) << "Removing instruction from alternate memory assignments.";
3929 if (std::next(it) == end) {
3930 alternate_memory_assignments_.pop_back();
3931 break;
3932 } else {
3933 // Swap the removed position and chunk with the back and pop back.
3934 *it = alternate_memory_assignments_.back();
3935 alternate_memory_assignments_.pop_back();
3936 end = alternate_memory_assignments_.end();
3937 }
3938 } else {
3939 ++it;
3940 }
3941 }
3942 }
3943
3944 Status MemorySpaceAssignment::SimplifyGraph() {
3945 VLOG(1) << "Simplifying graph...";
3946 for (HloComputation* computation : module_->MakeNonfusionComputations()) {
3947 // Parallel computations aren't in the schedule and don't need to be
3948 // modified.
3949 if (!computations_in_schedule_.contains(computation)) {
3950 VLOG(4) << "Not simplifying " << computation->name()
3951 << " because it's not in the schedule.";
3952 continue;
3953 }
3954 // Drop control dependencies. Since the computation is already scheduled, we
3955 // don't need control dependencies anymore, and having control
3956 // predecessors/successors prevents us from removing instructions without
3957 // users (HloComputation::IsSafelyRemovable returns false if there are
3958 // control dependencies).
3959 for (HloInstruction* instruction :
3960 computation->MakeInstructionPostOrder()) {
3961 TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
3962 }
3963 // We perform limited DCE and forward the tuple operand in patterns like
3964 // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
3965 // assignment is run late in compilation (after DCE and arithmetic
3966 // simplification passes) and we don't want to generate redundant code. Run
3967 // to fixed point.
3968 bool computation_modified = true;
3969 while (computation_modified) {
3970 computation_modified = false;
3971 VLOG(4) << "Running simplify graph loop over " << computation->name();
3972 for (HloInstruction* instruction :
3973 computation->MakeInstructionPostOrder()) {
3974 if (computation->IsSafelyRemovable(instruction) &&
3975 instruction->IsDead() && !instruction->HasSideEffect() &&
3976 instruction->opcode() != HloOpcode::kCopyStart &&
3977 instruction->opcode() != HloOpcode::kCopyDone) {
3978 VLOG(4) << "Instruction removed: " << instruction->ToString();
3979 // Ensure the alternate memory assignments don't contain a reference
3980 // to the removed instruction.
3981 RemoveAssignmentForInstruction(instruction);
3982 // Instead of deleting the instruction from the schedule, replace it
3983 // with a nullptr. This is needed because FixSchedule relies on the
3984 // logical time that is the index into flattened_instructions_ for
3985 // scheduling asynchronous copies.
3986 auto instruction_it =
3987 absl::c_find(flattened_instructions_, instruction);
3988 if (instruction_it != flattened_instructions_.end()) {
3989 *instruction_it = nullptr;
3990 }
3991 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
3992 computation_modified = true;
3993 } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
3994 HloInstruction* operand = instruction->mutable_operand(0);
3995 if (operand->opcode() == HloOpcode::kTuple) {
3996 HloInstruction* forwarded_instruction =
3997 operand->mutable_operand(instruction->tuple_index());
3998 VLOG(4) << "Replacing uses of " << instruction->ToString()
3999 << " with " << forwarded_instruction->ToString();
4000 TF_RETURN_IF_ERROR(
4001 instruction->ReplaceAllUsesWith(forwarded_instruction));
4002 computation_modified = true;
4003 }
4004 } else if (instruction->opcode() == HloOpcode::kTuple) {
4005 // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
4006 // with x.
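// For example, tuple(gte(x, 0), gte(x, 1), gte(x, 2)) can be replaced by x
// itself, but only if every operand is a gte of the same x, each tuple index
// matches the operand position, and x has exactly as many elements as the
// tuple being replaced.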
4007 bool can_replace =
4008 instruction->operand_count() > 0 &&
4009 instruction->operand(0)->opcode() ==
4010 HloOpcode::kGetTupleElement &&
4011 instruction->operand(0)
4012 ->operand(0)
4013 ->shape()
4014 .tuple_shapes_size() == instruction->operand_count();
4015 for (int operand_number = 0;
4016 operand_number < instruction->operand_count();
4017 ++operand_number) {
4018 const HloInstruction* operand =
4019 instruction->operand(operand_number);
4020 if (operand->opcode() != HloOpcode::kGetTupleElement ||
4021 operand->tuple_index() != operand_number ||
4022 operand->operand(0) != instruction->operand(0)->operand(0)) {
4023 can_replace = false;
4024 break;
4025 }
4026 }
4027 if (can_replace) {
4028 HloInstruction* forwarded_instruction =
4029 instruction->mutable_operand(0)->mutable_operand(0);
4030 VLOG(4) << "Replacing uses of " << instruction->ToString()
4031 << " with " << forwarded_instruction->ToString();
4032 TF_RETURN_IF_ERROR(
4033 instruction->ReplaceAllUsesWith(forwarded_instruction));
4034 computation_modified = true;
4035 }
4036 }
4037 }
4038 }
4039 }
4040
4041 return OkStatus();
4042 }
4043
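// Decides where in the flattened schedule each CopyStart/CopyDone lands by
// filling schedule_after_ and schedule_before_, which are keyed by logical
// time (the index into flattened_instructions_). Copies are handled per memory
// space and stable-sorted so that copies with earlier copy-done deadlines are
// scheduled first.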
4044 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
4045 VLOG(1) << "Scheduling asynchronous copies...";
4046 for (MemorySpace memory_space :
4047 {MemorySpace::kDefault, MemorySpace::kAlternate}) {
4048 std::vector<CopyAllocation*> copy_allocations;
4049 for (auto& allocation : allocations_) {
4050 if (allocation->is_copy_allocation()) {
4051 auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
4052 if (copy_allocation->memory_space() == memory_space) {
4053 copy_allocations.push_back(copy_allocation);
4054 }
4055 }
4056 }
4057
4058 absl::c_stable_sort(
4059 copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
4060 return std::forward_as_tuple(first->copy_done_schedule_before(),
4061 first->copy_start_schedule_after()) <
4062 std::forward_as_tuple(second->copy_done_schedule_before(),
4063 second->copy_start_schedule_after());
4064 });
4065 for (CopyAllocation* copy_allocation : copy_allocations) {
4066 // If the copy start doesn't happen to be scheduled at the correct
4067 // computation, delay it until the correct computation starts.
4068 int64_t copy_start_schedule_after =
4069 copy_allocation->copy_start_schedule_after();
4070 // Accessing elements of flattened_instructions_ without checking for nullptr
4071 // is safe here because this method is called before SimplifyGraph.
4072 while (copy_allocation->defining_position().instruction->parent() !=
4073 flattened_instructions_[copy_start_schedule_after]->parent()) {
4074 VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
4075 << (copy_start_schedule_after + 1) << ") for "
4076 << copy_allocation->copy_start()->ToString()
4077 << " because it is not in the correct computation.";
4078 copy_allocation->set_copy_start_schedule_after(
4079 ++copy_start_schedule_after);
4080 }
4081
4082 schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
4083 copy_allocation->copy_start());
4084 schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
4085 copy_allocation->copy_done());
4086 }
4087 }
4088 }
4089
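// Rebuilds the schedule of every computation that was originally scheduled:
// walks logical times, inserting the CopyStart/CopyDone instructions
// registered in schedule_after_/schedule_before_ around the surviving original
// instructions, and checks that the new sequence covers every instruction in
// the computation.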
4090 Status MemorySpaceAssignment::FixSchedule() {
4091 VLOG(1) << "Fixing schedule...";
4092 TF_RET_CHECK(module_->has_schedule());
4093 HloSchedule& schedule = module_->schedule();
4094 for (const HloComputation* computation :
4095 module_->MakeNonfusionComputations()) {
4096 // Parallel computations aren't in the schedule and don't need to be
4097 // modified.
4098 if (!computations_in_schedule_.contains(computation)) {
4099 VLOG(4) << "Not scheduling " << computation->name()
4100 << " because it's not in the schedule.";
4101 continue;
4102 }
4103 TF_RET_CHECK(schedule.is_computation_scheduled(computation));
4104 HloInstructionSequence new_sequence;
4105
4106 absl::flat_hash_set<HloInstruction*> inserted_instructions;
4107
4108 VLOG(4) << "Scheduling: " << computation->ToString();
4109
4110 for (int64_t instruction_index = 0;; ++instruction_index) {
4111 auto insts_before_iter = schedule_before_.find(instruction_index);
4112 if (insts_before_iter != schedule_before_.end()) {
4113 for (HloInstruction* new_instruction : insts_before_iter->second) {
4114 if (new_instruction->parent() == computation) {
4115 VLOG(4) << "before " << instruction_index << ": "
4116 << new_instruction->name();
4117 TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4118 new_instruction, &new_sequence, &inserted_instructions));
4119 }
4120 }
4121 }
4122 // We allow scheduling copy dones past the root instruction (for
4123 // end-of-program cross-program prefetch). So the loop exit condition is
4124 // actually here.
4125 if (instruction_index >= flattened_instructions_.size()) {
4126 break;
4127 }
4128 HloInstruction* instruction = flattened_instructions_[instruction_index];
4129 // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
4130 // it was deleted) and not previously inserted. Also bitcasts and tuples
4131 // are treated specially and only inserted as a result of operand
4132 // dependencies.
4133 if (instruction != nullptr && instruction->parent() == computation &&
4134 instruction->opcode() != HloOpcode::kBitcast &&
4135 instruction->opcode() != HloOpcode::kTuple &&
4136 !inserted_instructions.contains(instruction)) {
4137 VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
4138 TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4139 instruction, &new_sequence, &inserted_instructions));
4140 }
4141 auto insts_after_iter = schedule_after_.find(instruction_index);
4142 if (insts_after_iter != schedule_after_.end()) {
4143 for (HloInstruction* new_instruction : insts_after_iter->second) {
4144 if (new_instruction->parent() == computation) {
4145 VLOG(4) << "after " << instruction_index << ": "
4146 << new_instruction->name();
4147 TF_RETURN_IF_ERROR(InsertInstructionAndEnsureOperandsInserted(
4148 new_instruction, &new_sequence, &inserted_instructions));
4149 }
4150 }
4151 }
4152 }
4153 // For rare cases where the original sequence is empty, ensure the root
4154 // instruction and its dependencies are scheduled.
4155 TF_RETURN_IF_ERROR(EnsureInstructionAndOperandsInserted(
4156 computation->root_instruction(), &new_sequence,
4157 &inserted_instructions));
4158 CHECK_EQ(new_sequence.size(), computation->instruction_count())
4159 << "New sequence for computation " << computation->name() << " has "
4160 << new_sequence.size() << " instructions, expected "
4161 << computation->instruction_count() << ".";
4162 schedule.set_sequence(computation, new_sequence);
4163 }
4164
4165 return OkStatus();
4166 }
4167
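// Independently re-verifies the produced assignment: checks that asynchronous
// copies move data between different memory spaces and that no two live
// alternate-memory chunks overlap in both time and space, and exports the
// resulting alloc/free events as a heap simulator trace.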
4168 Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
4169 VLOG(1) << "Verifying...";
4170 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
4171 HloAliasAnalysis::Run(module_));
4172 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
4173 HloLiveRange::Run(module_->schedule(), *alias_analysis,
4174 module_->entry_computation()));
4175
4176 BufferIntervalTree interval_tree;
4177 absl::flat_hash_set<int64_t> seen_buffers;
4178 // The key for events is: time, is_free, value_id. This is so that the events
4179 // are sorted first by time, then within the same time, allocations are sorted
4180 // earlier than frees, and finally the value id as a tie breaker.
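// For example, at the same time t, the key (t, /*is_free=*/false, id) sorts
// before (t, /*is_free=*/true, id), so the allocation event at t is processed
// before the free event at t.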
4181 std::map<std::tuple<int64_t, bool, int64_t>,
4182 std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
4183 events;
4184
4185 auto add_allocation_and_verify = [&](int64_t start_time, int64_t end_time,
4186 const Chunk& chunk,
4187 const HloValue* value) {
4188 events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
4189 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
4190 events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
4191 std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);
4192
4193 // Get the chunks overlapping in time and search if they overlap in space
4194 // as well.
4195 // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
4196 // really should check against end_time (inclusive) for cases where the
4197 // operand can't share buffer with user (see
4198 // HloDataflowAnalysis::CanShareOperandBufferWithUser).
4199 for (const Chunk& overlapping_chunk :
4200 interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
4201 if (chunk.OverlapsWith(overlapping_chunk)) {
4202 return InternalError(
4203 ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
4204 " off: %d size: %d"),
4205 value->ToShortString(), start_time, end_time, chunk.offset,
4206 chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
4207 }
4208 }
4209 interval_tree.Add(start_time, end_time - 1, chunk);
4210 return OkStatus();
4211 };
4212
4213 // Go through all instructions in the module to ensure CopyStart/CopyDone
4214 // instructions copy between alternate memory and default memory.
4215 for (const HloComputation* computation :
4216 module_->MakeNonfusionComputations()) {
4217 for (const HloInstruction* instruction : computation->instructions()) {
4218 if (instruction->opcode() == HloOpcode::kCopyStart) {
4219 int64_t from_memory_space =
4220 ShapeUtil::GetSubshape(instruction->shape(), {1})
4221 .layout()
4222 .memory_space();
4223 int64_t to_memory_space =
4224 ShapeUtil::GetSubshape(instruction->shape(), {0})
4225 .layout()
4226 .memory_space();
4227 CHECK_NE(from_memory_space, to_memory_space)
4228 << "Asynchronous copy to the same memory space: "
4229 << instruction->ToString();
4230 }
4231 }
4232 }
4233
4234 for (const auto& position_and_chunk : preset_assignments_->chunks()) {
4235 const HloPosition& position = position_and_chunk.first;
4236 const Chunk& chunk = position_and_chunk.second;
4237 const HloBuffer& buffer =
4238 alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
4239 CHECK(!seen_buffers.contains(buffer.id()))
4240 << "Multiple preset assignments for the same buffer: "
4241 << buffer.ToString() << ", pos: " << position.ToString()
4242 << ", off: " << chunk.offset << ", size: " << chunk.size;
4243 seen_buffers.insert(buffer.id());
4244
4245 for (const HloValue* value : buffer.values()) {
4246 const HloLiveRange::TimeBound& time_bound =
4247 hlo_live_range->buffer_live_ranges().at(value);
4248 const HloInstruction* last_use_instruction = nullptr;
4249 int64_t last_use_time = time_bound.start;
4250 for (const HloUse& use : value->GetUses()) {
4251 int64_t use_time =
4252 hlo_live_range->instruction_schedule().at(use.instruction);
4253 if (use_time > last_use_time) {
4254 last_use_time = use_time;
4255 last_use_instruction = use.instruction;
4256 }
4257 }
4258
4259 std::function<Status(const HloInstruction*, int64_t, int64_t,
4260 absl::string_view)>
4261 split_conditional_buffer;
4262 split_conditional_buffer = [&](const HloInstruction* use_instruction,
4263 int64_t start_time, int64_t end_time,
4264 absl::string_view indent_string) {
4265 // Special case when verifying conditional: we internally split the use
4266 // of alternate memory in conditionals, so fish them out from the
4267 // conditionals.
4268 VLOG(3) << indent_string
4269 << "Splitting conditional buffer: " << buffer.ToString()
4270 << " value: " << value->ToShortString() << ": (" << start_time
4271 << ", " << end_time << ") off: " << chunk.offset
4272 << ", size: " << chunk.size;
4273 int64_t earliest_computation_start_time = end_time;
4274 for (const HloComputation* called_computation :
4275 use_instruction->called_computations()) {
4276 int64_t computation_start_time =
4277 hlo_live_range->computation_span_times()
4278 .at(called_computation)
4279 .start;
4280 earliest_computation_start_time =
4281 std::min(earliest_computation_start_time, computation_start_time);
4282 int64_t last_use_time = -1;
4283 const HloInstruction* last_use_instruction = nullptr;
4284 for (const HloUse& use : value->GetUses()) {
4285 int64_t use_time =
4286 hlo_live_range->instruction_schedule().at(use.instruction);
4287 if (use.instruction->parent() == called_computation &&
4288 use_time > last_use_time) {
4289 last_use_time = use_time;
4290 last_use_instruction = use.instruction;
4291 }
4292 }
4293 if (last_use_time != -1) {
4294 VLOG(3) << indent_string
4295 << " computation: " << called_computation->name() << ": ("
4296 << computation_start_time << ", " << last_use_time << ")";
4297 CHECK(last_use_instruction);
4298 if (last_use_instruction->opcode() == HloOpcode::kConditional) {
4299 // The last use is another (nested) conditional. Call this
4300 // function recursively.
4301 TF_RETURN_IF_ERROR(split_conditional_buffer(
4302 last_use_instruction, computation_start_time, last_use_time,
4303 absl::StrCat(indent_string, " ")));
4304 } else {
4305 last_use_time = std::min(last_use_time, end_time);
4306 TF_RETURN_IF_ERROR(add_allocation_and_verify(
4307 computation_start_time, last_use_time, chunk, value));
4308 }
4309 }
4310 }
4311 VLOG(3) << indent_string << " from beginning until first computation: ("
4312 << start_time << ", " << (earliest_computation_start_time - 1)
4313 << ")";
4314 TF_RETURN_IF_ERROR(add_allocation_and_verify(
4315 start_time, earliest_computation_start_time - 1, chunk, value));
4316 return OkStatus();
4317 };
4318
4319 if (last_use_instruction &&
4320 last_use_instruction->opcode() == HloOpcode::kConditional) {
4321 TF_RETURN_IF_ERROR(split_conditional_buffer(
4322 last_use_instruction, time_bound.start, time_bound.end, " "));
4323 } else if (!value->GetUses().empty()) {
4324 last_use_time = std::min(last_use_time, time_bound.end);
4325 VLOG(3) << " buffer: " << buffer.ToString()
4326 << " value: " << value->ToShortString() << ": ("
4327 << time_bound.start << ", " << last_use_time
4328 << ") off: " << chunk.offset << ", size: " << chunk.size;
4329 TF_RETURN_IF_ERROR(add_allocation_and_verify(
4330 time_bound.start, last_use_time, chunk, value));
4331 }
4332 }
4333 }
4334
4335 HeapSimulatorTrace* heap_trace =
4336 &preset_assignments_
4337 ->assignment_information_for_space(options_.alternate_memory_space)
4338 ->heap_simulator_trace;
4339 int64_t memory_usage = 0;
4340 int64_t max_memory_usage = 0;
4341 for (const auto& event : events) {
4342 int64_t time;
4343 bool is_free;
4344 int64_t buffer_id;
4345 std::tie(time, is_free, buffer_id) = event.first;
4346 const HloValue* value;
4347 Chunk chunk;
4348 HeapSimulatorTrace::Event::Kind kind;
4349 std::tie(value, chunk, kind) = event.second;
4350 HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
4351 heap_trace_event->set_kind(kind);
4352 heap_trace_event->set_buffer_id(buffer_id);
4353 heap_trace_event->set_instruction_name(value->instruction()->name());
4354 heap_trace_event->set_computation_name(
4355 value->instruction()->parent()->name());
4356
4357 if (kind == HeapSimulatorTrace::Event::ALLOC) {
4358 memory_usage += chunk.size;
4359 } else {
4360 CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
4361 memory_usage -= chunk.size;
4362 }
4363 max_memory_usage = std::max(max_memory_usage, memory_usage);
4364 VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
4365 }
4366 VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;
4367
4368 return OkStatus();
4369 }
4370 } // namespace memory_space_assignment
4371 } // namespace xla
4372