/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using memory_space_assignment::PresetAssignments;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned color of the nodes, where
// the colors are represented as integer values [0, color_count).
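//
// For example, given interference_map = {{1, 2}, {0}, {0}} (node 0 interferes
// with nodes 1 and 2), the greedy pass below assigns color 0 to node 0 and
// color 1 to nodes 1 and 2, i.e. it returns {0, 1, 1}.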
std::vector<int64_t> ColorInterferenceGraph(
    const std::vector<std::vector<int64_t>>& interference_map) {
  const int64_t node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first.
  // This relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64_t> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64_t i, const int64_t j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64_t kColorUnassigned = -1;
  std::vector<int64_t> assigned_colors(node_count, kColorUnassigned);
  for (int64_t node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64_t neighbor : interference_map[node]) {
      int64_t color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the color that is not yet assigned to the neighbors.
    int64_t color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local*/ false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition: the global call to the computation (e.g., while/call)
    // may return a reference to one of the thread-local buffers to the calling
    // computation, which would become a dangling reference when the
    // thread-local buffer is deallocated with the call return.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
          case HloOpcode::kAsyncStart:
          case HloOpcode::kAsyncUpdate:
          case HloOpcode::kAsyncDone:
            // Call, conditional, while, and async operations must be called
            // from a computation with global allocations as they may return
            // references to buffers inside the called computation which cannot
            // be thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kCustomCall:
          case HloOpcode::kAllReduce:
          case HloOpcode::kReduceScatter:
          case HloOpcode::kAllReduceStart:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't
    // bother assigning buffers for these.
  }
  return OkStatus();
}

std::string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64_t offset,
                                     int64_t size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color());
    }
  }
}

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_);
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64_t idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns parameter instruction corresponding to the allocation or nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns root module output instruction corresponding to the allocation or
// nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

std::string BufferAllocation::ToString() const {
  std::string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color() != 0) {
    StrAppend(&output, ", color ", color());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param ? param->shape().ToString(/*print_layout=*/false)
                    : "<unknown shape>",
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  return absl::c_any_of(
      dataflow_analysis().GetValueSet(instruction, index).values(),
      IsKeyIn(allocation_index_for_value_));
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).value() ==
         GetUniqueSlice(hlo_b, shape_index_b).value();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes. If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return OkStatus();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
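  // For example, if hlo_a's subshapes map to slices {s0, s1} and hlo_b's map
  // to {s2, s3}, the result is true; if hlo_b also used s1, or if either
  // collect_slices call came back empty, the result is false.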
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64_t size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64_t size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64_t offset,
                                     int64_t size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64_t offset,
                                     int64_t size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
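//
// Absorbed temp allocations are laid out end-to-end, each placed at the
// current combined size rounded up to the color's alignment. For example,
// absorbing a 40-byte temp allocation into a 96-byte combined allocation with
// 64-byte alignment places the absorbed buffers at base offset 128 and grows
// the combined allocation to 168 bytes.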
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  // Stores the combined allocations.
  std::deque<BufferAllocation> combined_allocations;
  // Holds the pointer to a combined allocation of each color, if any.
  flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations
  // belonging to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it subject to the
        // size constraint.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }
      if (combined_it->second->size() + it->size() >=
          multiheap_size_constraint_per_heap_) {
        // We cannot put more into the current combined_it. So, appoint a new
        // combined_it.
        VLOG(1) << "Due to size constraint, reset temp allocation for color "
                << color << " to: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }

      BufferAllocation* combined_allocation = combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
      int64_t alignment = color_alignment_(color);
      const int64_t base = RoundUpTo(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64_t offset = buffer_offset_size.second.offset;
        const int64_t size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (BufferAllocation& combined : combined_allocations) {
      temp_allocation_total_size_ += combined.size();
      allocations_.push_back(std::move(combined));
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64_t min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return OkStatus();
}

std::string BufferAssignment::Stats::ToString() const {
  std::string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, " parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, " constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, " maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, " preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, " preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, " total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

std::string BufferAssignment::ToString() const {
  std::string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64_t total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, " (",
                  HumanReadableNumBytes(total_size), ")\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

// Returns the largest k buffers present at the point of peak memory usage
// across allocations as a vector of pairs with their corresponding sizes.
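//
// For example, with k = 2 and peak buffers of sizes {8, 32, 16}, the returned
// vector holds the 32-byte and 16-byte buffers, largest first.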
std::vector<std::pair<int64_t, const HloValue*>> TopKPeakBuffers(
    uint64_t k, const std::vector<BufferAllocation> allocations) {
  absl::btree_multimap<int64_t, const HloValue*> topk;
  for (const BufferAllocation& allocation : allocations) {
    for (const HloValue* value : allocation.PeakMemoryLogicalBuffers()) {
      int64_t size = allocation.assigned_buffers().at(value).size;
      if (topk.size() < k) {
        topk.insert({size, value});
      } else {
        auto it = topk.begin();
        if (size > it->first) {
          topk.erase(it);
          topk.insert({size, value});
        }
      }
    }
  }

  // map will iterate smallest first, so reverse it.
  std::vector<std::pair<int64_t, const HloValue*>> topk_descending;
  topk_descending.reserve(topk.size());
  absl::c_reverse_copy(topk, std::back_inserter(topk_descending));
  return topk_descending;
}

std::string BufferAssignment::ToVerboseString() const {
  // TODO(loreno): make this tunable via flag.
  const size_t kMaxBuffersToShow = 15;
  std::string output =
      absl::StrCat("BufferAssignment OOM Debugging.\n", stats_.ToString());

  std::vector<std::pair<int64_t, const HloValue*>> peak_buffers =
      TopKPeakBuffers(kMaxBuffersToShow, allocations_);
  std::vector<std::string> buf_strs;
  for (size_t i = 0; i < std::min(kMaxBuffersToShow, peak_buffers.size());
       ++i) {
    const HloValue* value = peak_buffers[i].second;
    const HloInstruction* instr = value->instruction();
    int64_t size = peak_buffers[i].first;
    buf_strs.push_back(absl::StrCat("\n\tBuffer ", i + 1, ":\n\t\tSize: ",
                                    xla::HumanReadableNumBytes(size)));
    if (!instr->metadata().op_name().empty()) {
      buf_strs.push_back(absl::StrCat(
          "\n\t\tOperator: ", xla::OpMetadataToString(instr->metadata())));
    }
    if (instr->opcode() == HloOpcode::kParameter &&
        (instr->parent() == instr->parent()->parent()->entry_computation())) {
      // Special case on entry parameters as they sometimes have hundreds of
      // indices in their shapes, and overwhelm the output.
      buf_strs.push_back(absl::StrCat(
          "\n\t\tEntry Parameter Subshape: ",
          ShapeUtil::GetSubshape(instr->shape(), value->index()).ToString()));
    } else {
      // TODO(loreno): change this to a truncated string of the instruction.
      buf_strs.push_back(
          absl::StrCat("\n\t\tXLA Label: ", HloOpcodeString(instr->opcode()),
                       "\n\t\tShape: ", value->shape().ToString()));
    }
    buf_strs.push_back("\n\t\t==========================\n");
  }
  absl::StrAppend(&output, "Peak buffers:", absl::StrJoin(buf_strs, ""));
  return output;
}

std::string BufferAssignment::BufferInfoString() const {
  std::string binfo;
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  //            allocation information.
  // buffer_name: string.
  // offset: int. Starting position of the buffer in the memory space.
  // size: int. Size of the buffer in bytes.
  // definition_time: int. Position in the schedule where the buffer starts
  //                  being live (inclusive).
  // end_time: int. Position in the schedule where the buffer stops being live
  //           (exclusive).
  // num_uses: int. Number of uses of the buffer.
  // use_times: string. Semicolon-separated list of the schedule positions of
  //            the uses.
  // use_names: string. This is a semicolon-separated list of string
  //            representations of uses.
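  //
  // For illustration, a hypothetical row under the header appended below
  // might look like:
  //   12,"buffer",4096,1024,7,21,2,"9;20","use_a;use_b"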
  // Append the column names.
  absl::StrAppend(&binfo,
                  "buffer_id,buffer_name,offset,size,"
                  "definition_time,end_time,num_uses,use_times,use_names\n");
  const HloLiveRange& live_ranges = hlo_live_range();
  const auto& instruction_schedule = live_ranges.instruction_schedule();
  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
  // Sort the buffers by Id.
  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
  for (const BufferAllocation& allocation : allocations_) {
    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
  }
  absl::c_sort(
      buffers,
      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
        return b1.first->id() < b2.first->id();
      });
  for (const auto& buffer_pair : buffers) {
    const HloValue& buffer = *buffer_pair.first;
    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
    if (!buffer_live_ranges.contains(&buffer)) {
      continue;
    }
    // Ordering uses by their use position.
    std::vector<std::pair<int64_t, std::string>> uses;
    uses.reserve(buffer.GetUses().size());
    for (const HloUse& use : buffer.GetUses()) {
      uses.emplace_back(instruction_schedule.at(use.instruction),
                        use.ToString());
    }
    absl::c_sort(uses);
    std::vector<int64_t> use_positions;
    std::vector<std::string> use_names;
    use_positions.reserve(uses.size());
    use_names.reserve(uses.size());
    for (const auto& use : uses) {
      use_positions.push_back(use.first);
      use_names.push_back(use.second);
    }
    const int64_t definition_time =
        instruction_schedule.at(buffer.defining_position().instruction);
    const int64_t end_t = buffer_live_ranges.at(&buffer).end;
    absl::StrAppend(&binfo, buffer.id(), ",");
    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
    absl::StrAppend(&binfo, offset_size.offset, ",");
    absl::StrAppend(&binfo, offset_size.size, ",");
    absl::StrAppend(&binfo, definition_time, ",");
    absl::StrAppend(&binfo, end_t, ",");
    absl::StrAppend(&binfo, use_positions.size(), ",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
    absl::StrAppend(&binfo, "\n");
  }
  return binfo;
}

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    std::optional<BufferAssigner::MustNotLiveOut> must_not_live_out,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          must_not_live_out, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK((assignment->hlo_live_range().total_order_scheduled()));
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  auto live_range_it1 = buffer_live_ranges.find(buffer1);
  CHECK(live_range_it1 != buffer_live_ranges.end())
      << "Buffer doesn't have a proper live range:" << buffer1->ToString();

  auto live_range_it2 = buffer_live_ranges.find(buffer2);
  CHECK(live_range_it2 != buffer_live_ranges.end())
      << "Buffer doesn't have a proper live range:" << buffer2->ToString();

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand =
      [&assignment](const HloValue* user_value, const HloValue* operand_value,
                    const HloLiveRange::TimeBound& operand_live_range) {
        // An HLO value can span multiple instructions during its lifetime. We
        // only look at the last instruction and check whether it can share the
        // buffer with the operand.
        HloPosition operand_end_position = operand_live_range.end_position;
        return user_value->instruction()->opcode() != HloOpcode::kCopy &&
               user_value->instruction()->IsUserOf(
                   operand_end_position.instruction) &&
               assignment->dataflow_analysis().CanShareOperandBufferWithUser(
                   operand_end_position.instruction, operand_end_position.index,
                   user_value->instruction(), user_value->index());
      };

  const auto& live_range_1 = live_range_it1->second;
  const auto& live_range_2 = live_range_it2->second;

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value, live_range_1)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value, live_range_2)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1
              << " may interfere with " << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (must_not_live_out_.has_value()) {
    if (allocation->maybe_live_out()) {
      // If the allocation may be live out, it cannot contain any value for
      // which must_not_live_out_ returns true.
      for (const HloValue* value : hlo_buffer.values()) {
        if ((*must_not_live_out_)(value->instruction(), value->index())) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not enough: an allocation that is not yet live out
    // may already contain an instruction for which must_not_live_out_ returns
    // true. Assigning a live-out buffer to that allocation would make the
    // allocation live out while it still contains such an instruction.
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        const HloValue* value = buffer_offset_size.first;
        if ((*must_not_live_out_)(value->instruction(), value->index())) {
          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fallback to partial order based interference detection (slower) when
        // we don't have a total order scheduled module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      // Copy instructions don't share a buffer with their input operand.
      if (new_value->instruction()->opcode() == HloOpcode::kCopy) {
        for (const HloPosition& assigned_buffer_position :
             assigned_buffer.positions()) {
          if (new_value->instruction()->IsUserOf(
                  assigned_buffer_position.instruction)) {
            VLOG(4) << "Can't assign: assignee " << assigned_buffer
                    << " is used at copy instruction "
                    << new_value->ToShortString();
            return false;
          }
        }
      }
    }
  }

  // If the buffer is live out of the computation then it should only be
  // assigned a buffer which exactly fits the result to avoid wasting memory
  // (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and size not the same as allocation";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index() << " for constant "
                << *hlo_buffer << " value ptr: " << value;
      }
      VLOG(3) << "Not allocating buffer for constant";
      return OkStatus();
    }

    const HloInstruction* instruction = value->instruction();
    const bool is_entry_parameter =
        instruction->opcode() == HloOpcode::kParameter &&
        instruction->parent() ==
            instruction->parent()->parent()->entry_computation();

    if (is_entry_parameter) {
      bool parameter_has_alias =
          assignment->module().input_output_alias_config().ParameterHasAlias(
              instruction->parameter_number(), value->index());
      // If the HLO buffer is part of an external parameter, create a new
      // allocation and set its parameter number. Parameters of non-entry
      // computations do not need special allocations because they live inside
      // callers.
1189 BufferAllocation* allocation =
1190 assignment->NewAllocation(*hlo_buffer, buffer_size);
1191
1192 allocation->set_entry_computation_parameter(
1193 instruction->parameter_number(), value->index(), parameter_has_alias);
1194 if (parameter_has_alias) {
1195 allocation_indices->push_back(allocation->index());
1196 }
1197 VLOG(3) << "New allocation #" << allocation->index()
1198 << " marked as entry computation parameter: " << *hlo_buffer;
1199 return OkStatus();
1200 }
1201 }
1202
1203 if (is_thread_local) {
1204 BufferAllocation* allocation =
1205 assignment->NewAllocation(*hlo_buffer, buffer_size);
1206 allocation->set_is_thread_local(true);
1207 VLOG(3) << "New allocation #" << allocation->index()
1208 << " for thread-local: " << *hlo_buffer;
1209 return OkStatus();
1210 }
1211
1212 for (const HloValue* value : hlo_buffer->values()) {
1213 if (value->shape().IsTuple()) {
1214 BufferAllocation* allocation =
1215 assignment->NewAllocation(*hlo_buffer, buffer_size);
1216 allocation->set_is_tuple(true);
1217 VLOG(3) << "New allocation #" << allocation->index()
1218 << " for tuple-shaped buffer: " << *hlo_buffer;
1219 return OkStatus();
1220 }
1221
1222 if (value->IsTopLevel() && !value->IsTuple()) {
1223 const HloInstruction* instruction = value->instruction();
1224 for (auto* operand : instruction->operands()) {
1225 for (const auto& operand_slice :
1226 assignment->GetAllSlices(operand, /*index=*/{})) {
1227 BufferAllocation* allocation =
1228 assignment->GetMutableAllocation(operand_slice.index());
1229 if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1230 VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
1231 << " for: " << *hlo_buffer;
1232 return OkStatus();
1233 }
1234 }
1235 }
1236 }
1237 }
1238
1239 // Find the smallest buffer which can be reused iterating from end of
1240 // allocation_indices (smallest) to beginning (largest).
1241 for (int allocation_index = allocation_indices->size() - 1;
1242 allocation_index >= 0; allocation_index--) {
1243 BufferAllocation* allocation = assignment->GetMutableAllocation(
1244 allocation_indices->at(allocation_index));
1245 if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
1246 VLOG(3) << "Reusing allocation #" << allocation->index()
1247 << " for: " << *hlo_buffer;
1248 return OkStatus();
1249 }
1250 }
1251
1252 if (!assignment->HasAllocation(*hlo_buffer) &&
1253 !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
1254 bool all_computations_have_sequential_order = true;
1255 for (const HloValue* hlo_value : hlo_buffer->values()) {
1256 HloComputation* computation = hlo_value->instruction()->parent();
1257 const bool has_sequential_order =
1258 assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
1259 all_computations_have_sequential_order &= has_sequential_order;
1260 }
1261
1262 if (all_computations_have_sequential_order) {
1263 for (const HloValue* hlo_value : hlo_buffer->values()) {
1264 HloComputation* computation = hlo_value->instruction()->parent();
1265 // There is a sequential instruction ordering, so we delay assignment
1266 // of temp buffers until after the loop. We do this right before we
1267 // decide to create a new allocation, to ensure we've exhausted all
1268 // the buffer re-use cases above.
1269 //
1270 // Entry parameters and thread local buffers were already handled
1271 // earlier in this loop iteration. See
1272 // BufferAllocation::IsPreallocatedTempBuffer for the definition of
1273 // temp buffers.
1274 (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
1275 VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
1276 }
1277 return OkStatus();
1278 }
1279 }
1280
1281 if (!assignment->HasAllocation(*hlo_buffer)) {
1282 BufferAllocation* allocation =
1283 assignment->NewAllocation(*hlo_buffer, buffer_size);
1284 allocation_indices->push_back(allocation->index());
1285 VLOG(3) << "New allocation #" << allocation->index()
1286 << " for: " << *hlo_buffer;
1287 }
1288
1289 TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
1290 return OkStatus();
1291 }
1292
AssignBuffersForComputations(const std::vector<const HloComputation * > & computations,bool is_thread_local,absl::flat_hash_map<const HloComputation *,absl::flat_hash_set<const HloValue * >> * buffers_to_assign_sequentially,BufferAssignment * assignment)1293 Status BufferAssigner::AssignBuffersForComputations(
1294 const std::vector<const HloComputation*>& computations,
1295 bool is_thread_local,
1296 absl::flat_hash_map<const HloComputation*,
1297 absl::flat_hash_set<const HloValue*>>*
1298 buffers_to_assign_sequentially,
1299 BufferAssignment* assignment) {
1300 if (computations.empty()) {
1301 return OkStatus();
1302 }
1303 std::vector<const HloBuffer*> sorted_buffers;
1304
1305 // First assign the preset allocations.
1306 absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;
1307
1308 TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));
1309
1310 const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1311
1312 for (const HloBuffer& buffer : alias_analysis.buffers()) {
1313 // Skip if the buffer is already assigned since it had a preset allocation.
1314 if (preset_assigned_buffers.find(&buffer) !=
1315 preset_assigned_buffers.end()) {
1316 VLOG(3) << "Skip allocation for buffer: " << buffer;
1317 continue;
1318 }
1319 TF_RET_CHECK(!buffer.values().empty());
1320 const HloComputation* comp = buffer.values()[0]->instruction()->parent();
1321 if (absl::c_linear_search(computations, comp)) {
1322 sorted_buffers.push_back(&buffer);
1323 }
1324 }
1325
1326 // Generate a post order sort of instructions for sorting of the
1327 // HloBuffers.
1328 flat_hash_map<const HloInstruction*, int> post_order_position;
1329 int position = 0;
1330 std::vector<const HloComputation*> reverse_post_order_computations;
1331 std::unique_ptr<CallGraph> call_graph =
1332 CallGraph::Build(computations[0]->parent());
1333 TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
1334 if (absl::c_linear_search(computations, node.computation())) {
1335 reverse_post_order_computations.push_back(node.computation());
1336 }
1337 return OkStatus();
1338 }));
1339 absl::c_reverse(reverse_post_order_computations);
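  // After this reversal, calling computations precede their callees, so
  // instructions in outer computations receive smaller post-order positions
  // than instructions in the computations they call.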
1340 for (auto* computation : reverse_post_order_computations) {
1341 for (auto* instruction : computation->MakeInstructionPostOrder()) {
1342 post_order_position.emplace(instruction, position);
1343 position++;
1344 }
1345 }
1346
1347 HloSchedule schedule(&assignment->module());
1348
1349 for (const HloComputation* computation : computations) {
1350 const HloInstructionSequence* instruction_sequence =
1351 assignment->hlo_ordering().SequentialOrder(*computation);
1352 const bool has_sequential_order = instruction_sequence != nullptr;
1353 if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
1354 // Every sequential computation must get an entry in the
1355 // buffers_to_assign_sequentially map, even if we end up with an empty
1356 // set of buffers. This ensures we can correctly determine whether to
1357 // run whole-module heap simulation.
1358 buffers_to_assign_sequentially->emplace(computation,
1359 flat_hash_set<const HloValue*>());
1360
1361 schedule.set_sequence(computation, *instruction_sequence);
1362 }
1363 }
1364
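  // Order the buffers for assignment: by decreasing size, then (among equal
  // sizes) live-out buffers first, and finally by the post-order position of
  // each buffer's earliest-defined value to keep the ordering deterministic.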
1365 absl::c_stable_sort(
1366 sorted_buffers, [&post_order_position, &alias_analysis, assignment](
1367 const HloBuffer* a, const HloBuffer* b) {
1368 // Primary sort is by decreasing buffer size.
1369 const int64_t a_size = assignment->HloBufferSize(*a);
1370 const int64_t b_size = assignment->HloBufferSize(*b);
1371 if (a_size != b_size) {
1372 return a_size > b_size; // use ">" for decreasing size.
1373 }
1374
1375 const bool a_live_out = alias_analysis.BufferLivesOut(*a);
1376 const bool b_live_out = alias_analysis.BufferLivesOut(*b);
1377 if (a_live_out != b_live_out) {
1378 return a_live_out;
1379 }
1380 auto compare = [&post_order_position](const HloValue* value1,
1381 const HloValue* value2) {
1382 return post_order_position.at(value1->instruction()) <
1383 post_order_position.at(value2->instruction());
1384 };
1385 const HloValue* a_min = *absl::c_min_element(a->values(), compare);
1386 const HloValue* b_min = *absl::c_min_element(b->values(), compare);
1387 return compare(a_min, b_min);
1388 });
1389
1390 std::vector<BufferAllocation::Index> allocation_indices;
1391
1392 for (const HloBuffer* buffer : sorted_buffers) {
1393 VLOG(3) << "=================================================";
1394 VLOG(3) << "Assigning buffer for " << *buffer;
1395 TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
1396 buffers_to_assign_sequentially,
1397 &allocation_indices, assignment));
1398 }
1399 return OkStatus();
1400 }
1401
1402 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
1403 BufferAssigner::SplitBuffersByColor(
1404 const flat_hash_set<const HloValue*>& buffers) {
1405 flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
1406 for (auto buffer : buffers) {
1407 color_map[buffer->color()].insert(buffer);
1408 }
1409 return color_map;
1410 }
1411
1412 Status BufferAssigner::AssignPresetBuffers(
1413 absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
1414 BufferAssignment* assignment) {
1415 if (!preset_assignments_) {
1416 return OkStatus();
1417 }
1418
1419 // Create an allocation for each preset color.
1420 absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
1421 preset_allocations;
1422 for (auto& color_and_info : preset_assignments_->assignment_informations()) {
1423 LogicalBuffer::Color color(color_and_info.first);
1424 auto inserted = preset_allocations.emplace(
1425 color,
1426 assignment->NewEmptyAllocation(color_and_info.second.size, color));
1427 BufferAllocation* inserted_allocation = inserted.first->second;
1428 inserted_allocation->AddHeapTrace(
1429 color_and_info.second.heap_simulator_trace);
1430 VLOG(3) << "Created preset buffer allocation "
1431 << inserted_allocation->index()
1432 << ", color: " << inserted_allocation->color()
1433 << ", size: " << inserted_allocation->size();
1434 }
1435
1436 const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();
1437
1438 for (auto& position_and_chunk : preset_assignments_->chunks()) {
1439 const HloPosition& defining_position = position_and_chunk.first;
1440 const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
1441 defining_position.instruction, defining_position.index);
1442 for (const HloValue* value : buffer.values()) {
1443 VLOG(3) << "Preset allocation for value: " << value->ToShortString();
1444 const HeapSimulator::Chunk& chunk = position_and_chunk.second;
1445 auto preset_allocations_iter = preset_allocations.find(value->color());
1446 CHECK(preset_allocations_iter != preset_allocations.end())
1447 << "No preset value allocation for color " << value->color()
1448 << " for " << value->ToShortString() << " found.";
1449 preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
1450 chunk.size);
1451 }
1452
1453 assigned_buffers->insert(&buffer);
1454 }
1455
1456   // Now that the preset assignments have been consumed, clear them so that a
1457   // subsequent call to this method does not assign the same buffers again.
1458 preset_assignments_ = {};
1459
1460 return OkStatus();
1461 }
1462
1463 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
1464 const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
1465 buffers_to_assign_sequentially,
1466 bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
1467 // Run the sequence of instructions through the heap simulator. The
1468 // heuristic that seems to give the best results is lazy-best-fit, with all
1469 // runs of alloc / free calls sorted in decreasing size order.
1470 const HloOrdering& hlo_ordering = assignment->hlo_ordering();
1471
1472 // Returns a heap algorithm that chooses the best result from several
1473 // algorithms.
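  // Here the candidates are two constrained global decreasing-size best-fit
  // heaps, one favoring spatial packing (kSpatial) and one favoring temporal
  // locality (kTemporal); the result that uses the least memory is kept.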
1474 auto get_heap_algorithm = [&](int64_t alignment) {
1475 auto algorithms = std::make_unique<
1476 std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
1477 algorithms->push_back(
1478 std::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
1479 assignment->multiheap_size_constraint_per_heap(), alignment,
1480 GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
1481 algorithms->push_back(
1482 std::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
1483 assignment->multiheap_size_constraint_per_heap(), alignment,
1484 GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
1485 return std::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
1486 std::move(algorithms));
1487 };
1488
1489 if (run_whole_module_heap_simulation) {
1490 // Run the heap simulation over the whole module. This reduces memory
1491 // usage, since buffers for kCall, kWhile, and kConditional
1492 // sub-computations are only live for the duration of their calling
1493 // instructions.
1494 VLOG(1) << "Running whole-module heap simulation";
1495 HloSchedule schedule(&assignment->module());
1496 flat_hash_set<const HloValue*> all_buffers_to_assign;
1497 for (const auto& pair : buffers_to_assign_sequentially) {
1498 const HloComputation* computation = pair.first;
1499 const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
1500 const HloInstructionSequence* instruction_sequence =
1501 hlo_ordering.SequentialOrder(*computation);
1502 CHECK(instruction_sequence != nullptr) << computation->name();
1503 schedule.set_sequence(computation, *instruction_sequence);
1504 all_buffers_to_assign.insert(buffers_to_assign.begin(),
1505 buffers_to_assign.end());
1506 }
1507 auto color_map = SplitBuffersByColor(all_buffers_to_assign);
1508 for (auto& single_colored_set : color_map) {
1509 auto color = single_colored_set.first;
1510 VLOG(2) << "Simulating heap for color " << color;
1511 int64_t alignment = assignment->color_alignment_(color);
1512 HeapSimulator::Options options;
1513 options.alloc_constants = allocate_buffers_for_constants_;
1514 options.buffers_to_assign = &single_colored_set.second;
1515
1516 TF_ASSIGN_OR_RETURN(
1517 HeapSimulator::Result<HloValue> result,
1518 HeapSimulator::Run(
1519 get_heap_algorithm(alignment), assignment->module(), schedule,
1520 assignment->alias_analysis(), assignment->buffer_size_, options));
1521 AssignBuffersFromHeapSimulator(result, assignment,
1522 single_colored_set.first);
1523 }
1524 } else {
1525 // Run the heap-simulation on a per-computation basis. Buffers for
1526 // sub-computations are assigned disjoint BufferAllocations, assuming the
1527 // worst-case that they may all be live concurrently.
1528 VLOG(1) << "Running per-computation heap simulation";
1529 for (const auto& pair : buffers_to_assign_sequentially) {
1530 const HloComputation* computation = pair.first;
1531 const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
1532 const HloInstructionSequence* instruction_sequence =
1533 hlo_ordering.SequentialOrder(*computation);
1534 CHECK(instruction_sequence != nullptr) << computation->name();
1535 auto color_map = SplitBuffersByColor(buffers_to_assign);
1536 for (auto& single_colored_set : color_map) {
1537 auto color = single_colored_set.first;
1538 VLOG(2) << "Simulating heap for color " << color;
1539 int64_t alignment = assignment->color_alignment_(color);
1540 HeapSimulator::Options options;
1541 options.buffers_to_assign = &single_colored_set.second;
1542 TF_ASSIGN_OR_RETURN(
1543 HeapSimulator::Result<HloValue> result,
1544 HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
1545 *instruction_sequence,
1546 assignment->alias_analysis(),
1547 assignment->buffer_size_, options));
1548 AssignBuffersFromHeapSimulator(result, assignment,
1549 single_colored_set.first);
1550 }
1551 }
1552 }
1553 return OkStatus();
1554 }
1555
1556 namespace {
1557 // Computes and returns the set of logical buffers live at the point of
1558 // maximal liveness in the given heap trace. LogicalBuffers are (stably)
1559 // sorted by id.
1560 std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
1561 const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
1562   // Create a map from each value's id to the HloValue* for the values
1563   // assigned to this allocation, along with their sizes.
1564 absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
1565 absl::flat_hash_map<const HloValue*, int64_t> buffer_sizes;
1566 for (const auto& pair : allocation.assigned_buffers()) {
1567 const HloValue* value = pair.first;
1568 const BufferAllocation::OffsetSize& offset_size = pair.second;
1569 id_to_value[value->id()] = value;
1570 buffer_sizes[value] = offset_size.size;
1571 }
1572 VLOG(1) << "Compute peak memory logical buffers";
1573
1574   // To properly account for shared buffers, we track how many instances of
1575   // each shared buffer are currently live, their canonical ids, and the size
1576   // we returned when allocating the buffer, so that we can return the negated
1577   // size when the buffer is freed.
1578 absl::flat_hash_map<int64_t, int> num_outstanding_shared_buffers;
1579 absl::flat_hash_map<int64_t, int64_t> shared_canonical_ids;
1580 absl::flat_hash_map<int64_t, int64_t> allocated_sizes;
1581 // Returns how much the given event increases the total size of live
1582 // buffers. Can be negative.
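  // Illustrative example (hypothetical ids): for the trace
  //   ALLOC 0 (size 16), SHARE_WITH 1 -> canonical 0, FREE 1, FREE 0
  // the deltas are +16, 0, 0, -16: a shared instance adds no bytes while its
  // canonical buffer is still outstanding, and each FREE returns the negated
  // size that was recorded when that instance was allocated.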
1583 auto memory_delta = [&](const HeapSimulatorTrace::Event& event) -> int64_t {
1584 const HloValue* buffer = id_to_value.at(event.buffer_id());
1585 const int64_t buffer_size = buffer_sizes.at(buffer);
1586 if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
1587 num_outstanding_shared_buffers[event.buffer_id()] = 1;
1588 allocated_sizes[event.buffer_id()] = buffer_size;
1589 return buffer_size;
1590 } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
1591 shared_canonical_ids[event.buffer_id()] = event.share_with_canonical_id();
1592 if (++num_outstanding_shared_buffers[event.share_with_canonical_id()] ==
1593 1) {
1594 // This shared buffer is currently the only instance of the buffer with
1595 // the canonical id. So we return the buffer size.
1596 allocated_sizes[event.buffer_id()] = buffer_size;
1597 return buffer_size;
1598 }
1599 // There are multiple instances of this buffer, so return 0.
1600 allocated_sizes[event.buffer_id()] = 0;
1601 return 0;
1602 } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
1603 auto shared_canonical_id_it =
1604 shared_canonical_ids.find(event.buffer_id());
1605       // Decrement the number of outstanding instances of this buffer and
1606       // return the negated size recorded when this instance was allocated.
1607 int64_t buffer_id = (shared_canonical_id_it == shared_canonical_ids.end())
1608 ? event.buffer_id()
1609 : shared_canonical_id_it->second;
1610 --num_outstanding_shared_buffers[buffer_id];
1611 return -1 * allocated_sizes[event.buffer_id()];
1612 }
1613 LOG(FATAL) << "Unknown event kind: " << event.kind();
1614 };
1615
1616 // First compute the size of the maximal live set.
1617 int64_t max_live_size = 0;
1618 int64_t live_size = 0;
1619 for (const auto& event : heap_trace.events()) {
1620 if (!id_to_value.contains(event.buffer_id())) {
1621 // Skip as the buffer associated with this trace event is not placed into
1622 // this allocation. This can happen when size constraints are given to the
1623 // heap simulator.
1624 continue;
1625 }
1626 live_size += memory_delta(event);
1627 if (max_live_size < live_size) {
1628 max_live_size = live_size;
1629 }
1630 }
1631
1632 // Next gather the set of logical buffers live at the earliest point of
1633 // maximal live set size.
1634 absl::flat_hash_set<const HloValue*> live_values;
1635 live_size = 0;
1636 num_outstanding_shared_buffers.clear();
1637 for (const auto& event : heap_trace.events()) {
1638 if (!id_to_value.contains(event.buffer_id())) {
1639 // Skip as the buffer associated with this trace event is not placed into
1640 // this allocation. This can happen when size constraints are given to the
1641 // heap simulator.
1642 continue;
1643 }
1644 const HloValue* value = id_to_value.at(event.buffer_id());
1645 int64_t delta = memory_delta(event);
1646     // To avoid adding aliased buffers to the peak-buffers list, only add
1647     // buffers for which memory_delta returns a positive size; memory_delta
1648     // returns 0 when the buffer already has a live alias.
1650 if (delta > 0) {
1651 InsertOrDie(&live_values, value);
1652 } else if (delta < 0) {
1653 CHECK(ContainsKey(live_values, value));
1654 live_values.erase(value);
1655 }
1656 live_size += delta;
1657
1658 if (live_size == max_live_size) {
1659 break;
1660 }
1661 }
1662 CHECK_EQ(live_size, max_live_size);
1663
1664 std::vector<const HloValue*> live_values_vector;
1665 live_values_vector.insert(live_values_vector.end(), live_values.begin(),
1666 live_values.end());
1667
1668   // Stably sort the live buffers by id.
1669 absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
1670 return a->id() < b->id();
1671 });
1672 VLOG(4) << "Peak memory buffer:";
1673 for (auto value : live_values_vector) {
1674 VLOG(4) << " " << value->ToString();
1675 }
1676 return live_values_vector;
1677 }
1678
1679 } // namespace
1680
1681 void BufferAssigner::AssignBuffersFromHeapSimulator(
1682 const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
1683 BufferValue::Color color) {
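  // preallocated_temp_fragmentation_bytes starts at -1 (unset): the first
  // heap-simulation result initializes it, and subsequent calls accumulate
  // their fragmentation on top.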
1684 if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
1685 assignment->stats_.preallocated_temp_fragmentation_bytes =
1686 result.fragmentation_size;
1687 } else {
1688 assignment->stats_.preallocated_temp_fragmentation_bytes +=
1689 result.fragmentation_size;
1690 }
1691 VLOG(1) << "Result size from heap simulator: " << result.heap_size;
1692
1693 // Iterate through heap_results. For each heap_result, create a new allocation
1694 // in `assignment`.
1695 for (const HeapSimulator::HeapResult<HloValue>& heap_result :
1696 result.heap_results) {
1697 BufferAllocation* allocation =
1698 assignment->NewEmptyAllocation(heap_result.heap_size, color);
1699 for (const auto& buffer_chunk : heap_result.chunk_map) {
1700 const HloValue& value = *buffer_chunk.first;
1701 const HeapSimulator::Chunk& chunk = buffer_chunk.second;
1702 assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
1703 }
1704 allocation->peak_buffers_ =
1705 ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);
1706
1707 XLA_VLOG_LINES(2, allocation->ToString());
1708
1709 allocation->AddHeapTrace(result.debug_trace);
1710 }
1711 }
1712
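// Illustrative sketch (not part of this file) of the callbacks a caller might
// pass to CreateAssignment: a size function derived from each value's shape
// and a fixed per-color alignment; the 8-byte pointer size and 64-byte
// alignment below are placeholder values.
//
//   BufferValue::SizeFunction buffer_size = [](const BufferValue& value) {
//     return ShapeUtil::ByteSizeOf(value.shape(), /*pointer_size=*/8);
//   };
//   LogicalBuffer::AlignmentFunction color_alignment =
//       [](LogicalBuffer::Color) { return int64_t{64}; };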
1713 StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
1714 const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
1715 BufferValue::SizeFunction buffer_size,
1716 LogicalBuffer::AlignmentFunction color_alignment,
1717 HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
1718 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1719 HloAliasAnalysis::Run(module, can_share_buffer));
1720
1721 // Set up a schedule for each computation.
1722 HloSchedule schedule(module);
1723 for (const HloComputation* computation : module->computations()) {
1724 const HloInstructionSequence* instruction_sequence =
1725 hlo_ordering->SequentialOrder(*computation);
1726 const bool has_sequential_order = instruction_sequence != nullptr;
1727 if (has_sequential_order) {
1728 schedule.set_sequence(computation, *instruction_sequence);
1729 }
1730 }
1731
1732 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
1733 HloLiveRange::Run(schedule, *alias_analysis,
1734 module->entry_computation(), true));
1735
1736 VLOG(1) << "Assigning buffers to module " << module->name();
1737 XLA_VLOG_LINES(3, module->ToString());
1738 XLA_VLOG_LINES(3, alias_analysis->ToString());
1739 XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
1740 VLOG(1) << "Number of buffers to assign: "
1741 << alias_analysis->buffers().size();
1742
1743 // Can't use std::make_unique because BufferAssignment constructor is
1744 // private.
1745 std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
1746 module, std::move(hlo_ordering), std::move(buffer_size),
1747 std::move(color_alignment), std::move(alias_analysis),
1748 std::move(hlo_live_range)));
1749
1750 TF_RETURN_IF_ERROR(
1751 colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
1752 VLOG(3) << "After coloring:";
1753 XLA_VLOG_LINES(3,
1754 assignment->alias_analysis().dataflow_analysis().ToString());
1755
1756 std::vector<const HloComputation*> thread_local_computations;
1757 std::vector<const HloComputation*> global_computations;
1758 TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
1759 module, &thread_local_computations, &global_computations));
1760
1761 // First assign buffers for global computations. Temporary buffers for
1762 // sequential computations are collected in
1763 // 'buffers_to_assign_sequentially'.
1764 flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
1765 buffers_to_assign_sequentially;
1766 TF_RETURN_IF_ERROR(AssignBuffersForComputations(
1767 global_computations,
1768 /*is_thread_local=*/false, &buffers_to_assign_sequentially,
1769 assignment.get()));
1770 // Assign buffers with sequential ordering, if any. If all global
1771 // computations are sequential, we can run heap simulation on the whole
1772 // module, which reduces memory usage.
1773 const bool run_whole_module_heap_simulation =
1774 buffers_to_assign_sequentially.size() == global_computations.size();
1775 VLOG(2) << "Running whole module heap simulation: "
1776 << run_whole_module_heap_simulation;
1777 const int32_t multiheap_size_constraint_per_heap =
1778 module->config().debug_options().xla_multiheap_size_constraint_per_heap();
1779 VLOG(2) << "Multiheap per heap size limit: "
1780 << multiheap_size_constraint_per_heap;
1781 TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
1782 buffers_to_assign_sequentially, run_whole_module_heap_simulation,
1783 assignment.get()));
1784
1785 std::vector<const HloComputation*> thread_local_computations_no_fusion;
1786 // Now assign buffers for thread-local computations. All LogicalBuffers get
1787 // their own BufferAllocation.
1788
1789 for (auto* computation : thread_local_computations) {
1790 TF_RET_CHECK(computation != module->entry_computation());
1791 if (computation->IsFusionComputation()) {
1792 continue;
1793 }
1794 thread_local_computations_no_fusion.push_back(computation);
1795 }
1796
1797 TF_RETURN_IF_ERROR(AssignBuffersForComputations(
1798 thread_local_computations_no_fusion,
1799 /*is_thread_local=*/true,
1800 /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));
1801
1802 // Mark all buffers which may be live out of the entry computation as
1803 // "liveout".
1804 for (const HloBuffer* buffer :
1805 assignment->alias_analysis().LiveOutBuffers()) {
1806 VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
1807 if (assignment->HasAllocation(*buffer)) {
1808 BufferAllocation* alloc =
1809 assignment->GetMutableAssignedAllocation(*buffer);
1810 alloc->set_maybe_live_out(true);
1811 VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
1812 }
1813 }
1814
1815 // Combines allocations of temporary buffers into big BufferAllocations
1816 // subject to the buffer allocation size constraint. This can only be
1817 // performed after all buffers have been assigned, and after maybe_live_out
1818 // is marked, since it is used to determine whether an allocation contains
1819 // temporary buffers or not.
1820 assignment->CombineTempAllocations();
1821
1822 XLA_VLOG_LINES(2, assignment->ToString());
1823 TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
1824 XLA_VLOG_LINES(1, assignment->GetStats().ToString());
1825 VLOG(1) << "Buffer assignment done.";
1826 return std::move(assignment);
1827 }
1828
1829 } // namespace xla
1830