xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/buffer_assignment.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the data returned by the XLA buffer assignment packages.

#include "tensorflow/compiler/xla/service/buffer_assignment.h"

#include <algorithm>
#include <deque>
#include <memory>
#include <numeric>
#include <ostream>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_live_range.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace xla {
namespace {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
using memory_space_assignment::PresetAssignments;
using ::tensorflow::strings::HumanReadableNumBytes;

// Given the interference map of a graph (the list of interfering node indices
// for each node), perform graph coloring such that interfering nodes are
// assigned to different colors. Returns the assigned color of the nodes, where
// the colors are represented as integer values [0, color_count).
std::vector<int64_t> ColorInterferenceGraph(
    const std::vector<std::vector<int64_t>>& interference_map) {
  const int64_t node_count = interference_map.size();

  // Sort the nodes such that we assign nodes with more interference first. This
  // relies on the common heuristic of assigning the most constrained node
  // first, but it would be good to investigate other ordering heuristics too.
  std::vector<int64_t> nodes(node_count);
  std::iota(nodes.begin(), nodes.end(), 0);
  absl::c_sort(nodes, [&interference_map](const int64_t i, const int64_t j) {
    return interference_map[i].size() > interference_map[j].size();
  });

  const int64_t kColorUnassigned = -1;
  std::vector<int64_t> assigned_colors(node_count, kColorUnassigned);
  for (int64_t node : nodes) {
    // Mark the colors that are already assigned to the neighbors.
    std::vector<bool> available_colors(node_count, true);
    for (int64_t neighbor : interference_map[node]) {
      int64_t color = assigned_colors[neighbor];
      if (color != kColorUnassigned) {
        available_colors[color] = false;
      }
    }

    // Find the first color that is not yet assigned to any neighbor.
    int64_t color = kColorUnassigned;
    for (color = 0; color < available_colors.size(); ++color) {
      if (available_colors[color]) {
        break;
      }
    }
    CHECK_NE(color, kColorUnassigned);
    assigned_colors[node] = color;
  }
  return assigned_colors;
}
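
// For illustration only (not executed): with interference_map = {{1, 2}, {0},
// {0}}, node 0 has the most neighbors and is colored first with color 0;
// nodes 1 and 2 each interfere only with node 0, so both receive color 1.
// The returned assignment is {0, 1, 1}, using two colors in total.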

}  // namespace

Status GatherComputationsByAllocationType(
    const HloModule* module,
    std::vector<const HloComputation*>* thread_local_computations,
    std::vector<const HloComputation*>* global_computations) {
  // Create a worklist of computations paired with whether the allocation must
  // be thread-local.
  std::deque<std::pair<const HloComputation*, bool>> worklist;
  worklist.push_back(std::make_pair(module->entry_computation(),
                                    /*is_thread_local=*/false));

  // Sets for quickly checking membership. Computations are returned in vectors
  // for stable iteration.
  flat_hash_set<const HloComputation*> thread_local_set;
  flat_hash_set<const HloComputation*> global_set;

  while (!worklist.empty()) {
    auto worklist_front = worklist.front();
    worklist.pop_front();
    const HloComputation* computation = worklist_front.first;
    bool is_thread_local = worklist_front.second;
    bool in_thread_local_set = thread_local_set.contains(computation);
    bool in_global_set = global_set.contains(computation);

    // If the computation has already been added to the respective set, then
    // nothing to do.
    if ((is_thread_local && in_thread_local_set) ||
        (!is_thread_local && in_global_set)) {
      continue;
    }

    // If the computation has already been added to the other set, this is an
    // error condition: the global call to the computation (e.g., while/call)
    // may return a reference to one of the thread-local buffers to the calling
    // computation, which would become a dangling reference when the
    // thread-local buffer is deallocated with the call return.
    if ((is_thread_local && in_global_set) ||
        (!is_thread_local && in_thread_local_set)) {
      return InvalidArgument(
          "computation %s has conflicting allocation requirements (global "
          "and thread-local)",
          computation->name());
    }

    if (is_thread_local) {
      thread_local_set.insert(computation);
    } else {
      global_set.insert(computation);
    }

    for (auto* instruction : computation->instructions()) {
      for (HloComputation* subcomputation :
           instruction->called_computations()) {
        switch (instruction->opcode()) {
          case HloOpcode::kCall:
          case HloOpcode::kConditional:
          case HloOpcode::kWhile:
          case HloOpcode::kAsyncStart:
          case HloOpcode::kAsyncUpdate:
          case HloOpcode::kAsyncDone:
            // Call, conditional, while, and async operations must be called
            // from a computation with global allocations as they may return
            // references to buffers inside the called computation which cannot
            // be thread-local.
            if (is_thread_local) {
              return InvalidArgument(
                  "computation %s cannot contain call/while op because it "
                  "requires thread-local buffer allocations",
                  computation->name());
            }
            worklist.push_back(std::make_pair(subcomputation,
                                              false));  // Not thread local.
            break;
          case HloOpcode::kCustomCall:
          case HloOpcode::kAllReduce:
          case HloOpcode::kReduceScatter:
          case HloOpcode::kAllReduceStart:
          case HloOpcode::kMap:
          case HloOpcode::kReduce:
          case HloOpcode::kReduceWindow:
          case HloOpcode::kScatter:
          case HloOpcode::kSelectAndScatter:
          case HloOpcode::kSort:
          case HloOpcode::kFusion:
            // Map/reduce etc. computations are always thread-local.
            worklist.push_back(std::make_pair(subcomputation,
                                              true));  // Thread local.
            break;
          default:
            return InternalError("Unexpected calling opcode: %s",
                                 HloOpcodeString(instruction->opcode()));
        }
      }
    }
  }

  // Add the computations to the vectors in post order.
  for (auto* computation : module->MakeComputationPostOrder()) {
    if (thread_local_set.contains(computation)) {
      thread_local_computations->push_back(computation);
    } else if (global_set.contains(computation)) {
      global_computations->push_back(computation);
    }
    // If the computation is not reachable from the entry computation, then it
    // will not appear in either thread_local_set or global_set. We don't bother
    // assigning buffers for these.
  }
  return OkStatus();
}
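
// Worked example (informal): for an entry computation E that calls
// While(condition C, body B), where B contains a Reduce whose reducer is R,
// the worklist classifies E, C, and B as global (while subcomputations may
// return references to their internal buffers) and R as thread-local
// (reducer computations are always thread-local).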

std::string BufferAllocation::Slice::ToString() const {
  return absl::StrCat("{index:", index(), ", offset:", offset_,
                      ", size:", size_, "}");
}

BufferAllocation::Slice BufferAllocation::GetSlice(
    const HloValue& buffer) const {
  const OffsetSize os = FindOrDie(assigned_buffers_, &buffer);
  return Slice(this, os.offset, os.size);
}

void BufferAllocation::AddAssignment(const HloValue& buffer, int64_t offset,
                                     int64_t size) {
  VLOG(4) << "Adding the following buffer to allocation #" << index()
          << absl::StrFormat(" (size=%d, offset=%d) %s", size, offset,
                             buffer.ToShortString());
  CHECK(!assigned_buffers_.contains(&buffer))
      << "LogicalBuffer " << buffer << " already assigned to allocation "
      << index_;
  CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
                          << " offset out of range";
  CHECK_LE(offset + size, size_)
      << "LogicalBuffer " << buffer
      << " size out of range at offset: " << offset << " with size: " << size;
  CHECK_EQ(buffer.color(), color())
      << "Buffer color " << buffer.color() << " for buffer " << buffer
      << " does not match allocation color " << color() << ".";
  OffsetSize offset_size;
  offset_size.offset = offset;
  offset_size.size = size;
  assigned_buffers_.emplace(&buffer, offset_size);
  // For debugging purposes, store the assigned memory space in the
  // instruction's layout.
  for (HloPosition position : buffer.positions()) {
    Shape* shape = ShapeUtil::GetMutableSubshape(
        position.instruction->mutable_shape(), position.index);
    if (shape->has_layout()) {
      shape->mutable_layout()->set_memory_space(buffer.color());
    }
  }
}
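
// Bounds check example (informal): for an allocation with size_ == 1024, a
// 256-byte buffer may be placed at offset 768 (768 + 256 <= 1024), but
// placing it at offset 896 would trip the CHECK_LE(offset + size, size_)
// above, since 896 + 256 > 1024.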

BufferAllocationProto BufferAllocation::ToProto() const {
  BufferAllocationProto proto;
  proto.set_index(index_);
  proto.set_size(size_);
  proto.set_is_thread_local(is_thread_local_);
  proto.set_is_tuple(is_tuple_);
  proto.set_color(color_);
  if (is_entry_computation_parameter_) {
    proto.set_is_entry_computation_parameter(true);
    for (int64_t idx : param_shape_index()) {
      proto.add_parameter_shape_index(idx);
    }
    proto.set_parameter_number(parameter_number_);
  }
  proto.set_is_constant(is_constant_);
  proto.set_maybe_live_out(maybe_live_out_);
  for (const auto& buffer_offset_size : assigned_buffers_) {
    BufferAllocationProto::Assigned* proto_assigned = proto.add_assigned();
    proto_assigned->set_logical_buffer_id(buffer_offset_size.first->id());
    proto_assigned->set_offset(buffer_offset_size.second.offset);
    proto_assigned->set_size(buffer_offset_size.second.size);
  }
  absl::c_sort(*proto.mutable_assigned(),
               [](const BufferAllocationProto::Assigned& assign1,
                  const BufferAllocationProto::Assigned& assign2) {
                 return assign1.logical_buffer_id() <
                        assign2.logical_buffer_id();
               });
  return proto;
}

static bool CompareHloValuesById(const HloValue* a, const HloValue* b) {
  return a->id() < b->id();
}

// Returns the parameter instruction corresponding to the allocation, or
// nullptr.
static const HloInstruction* GetEntryParameterInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    const HloInstruction* instr = value->instruction();
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parent() == instr->parent()->parent()->entry_computation()) {
      return instr;
    }
  }
  return nullptr;
}

// Returns the module's root output instruction corresponding to the
// allocation, or nullptr.
static const HloInstruction* GetOutputInstruction(
    const BufferAllocation& alloc) {
  for (const auto& p : alloc.assigned_buffers()) {
    const HloValue* value = p.first;
    for (const HloPosition& position : value->positions()) {
      const HloInstruction* instr = position.instruction;
      if (position.index.empty() &&
          instr->parent()->root_instruction() == instr &&
          instr->parent()->IsEntryComputation()) {
        return instr;
      }
    }
  }
  return nullptr;
}

std::string BufferAllocation::ToString() const {
  std::string output;
  StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size());
  if (color() != 0) {
    StrAppend(&output, ", color ", color());
  }
  if (is_entry_computation_parameter()) {
    const HloInstruction* param = GetEntryParameterInstruction(*this);
    StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
              param ? param->shape().ToString(/*print_layout=*/false)
                    : "<unknown shape>",
              "| at ShapeIndex ", param_shape_index().ToString());
  }
  if (const HloInstruction* instr = GetOutputInstruction(*this)) {
    StrAppend(&output, ", output shape is |",
              instr->shape().ToString(/*print_layout=*/false), "|");
  }
  if (is_constant()) {
    StrAppend(&output, ", constant");
  }
  if (is_thread_local()) {
    StrAppend(&output, ", thread-local");
  }
  if (maybe_live_out()) {
    StrAppend(&output, ", maybe-live-out");
  }
  if (IsPreallocatedTempBuffer()) {
    StrAppend(&output, ", preallocated-temp");
  }
  StrAppend(&output, ":\n");
  // Dump the assigned buffers ordered by id.
  std::vector<const HloValue*> sorted_buffers;
  for (const auto& buffer_offset_size : assigned_buffers_) {
    sorted_buffers.push_back(buffer_offset_size.first);
  }
  absl::c_sort(sorted_buffers, &CompareHloValuesById);
  for (const HloValue* buffer : sorted_buffers) {
    const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
    StrAppend(&output,
              absl::StrFormat(
                  " value: %s (size=%d,offset=%d): %s\n",
                  buffer->ToShortString(), offset_size.size, offset_size.offset,
                  ShapeUtil::HumanStringWithLayout(buffer->shape())));
  }
  return output;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
  out << buffer.ToString();
  return out;
}

std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s) {
  out << s.ToString();
  return out;
}

bool BufferAssignment::HasAllocation(const HloValue& value) const {
  return allocation_index_for_value_.contains(&value);
}

bool BufferAssignment::HasAllocation(const HloBuffer& buffer) const {
  return allocation_index_for_value_.contains(buffer.values()[0]);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloValue& value) const {
  CHECK(HasAllocation(value));
  return GetAllocation(allocation_index_for_value_.at(&value));
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
    const HloBuffer& hlo_buffer) const {
  return GetAssignedAllocation(*hlo_buffer.values()[0]);
}

BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
    const HloBuffer& buffer) {
  return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
}

std::set<BufferAllocation::Slice> BufferAssignment::GetAllSlices(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  std::set<BufferAllocation::Slice> result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    if (HasAllocation(*value)) {
      result.insert(GetAssignedAllocation(*value).GetSlice(*value));
    }
  }
  return result;
}

const BufferAllocation& BufferAssignment::GetAllocation(
    BufferAllocation::Index index) const {
  CHECK_GE(index, 0);
  CHECK_LT(index, allocations_.size());
  return allocations_[index];
}

const BufferAllocation* BufferAssignment::GetInstructionAllocation(
    const HloInstruction* hlo, const ShapeIndex& shape_index) const {
  const HloValue* value =
      dataflow_analysis().GetValueSet(hlo, shape_index).values()[0];

  if (!HasAllocation(*value)) {
    return nullptr;
  }

  const BufferAllocation& instruction_allocation =
      GetAssignedAllocation(*value);
  return &instruction_allocation;
}

BufferAllocation* BufferAssignment::GetMutableAllocation(
    BufferAllocation::Index index) {
  return const_cast<BufferAllocation*>(&GetAllocation(index));
}

bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
                                       const ShapeIndex& index) const {
  return absl::c_any_of(
      dataflow_analysis().GetValueSet(instruction, index).values(),
      IsKeyIn(allocation_index_for_value_));
}

bool BufferAssignment::HasTopLevelAllocation(
    const HloInstruction* instruction) const {
  return HasAllocationAt(instruction, /*index=*/{});
}

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
          << index << "]";
  BufferAllocation::Slice result;
  for (const HloValue* value :
       dataflow_analysis().GetValueSet(instruction, index).values()) {
    VLOG(3) << "Examining value " << *value;
    if (HasAllocation(*value)) {
      VLOG(3) << "Has allocation";
      const BufferAllocation::Slice slice =
          GetAssignedAllocation(*value).GetSlice(*value);
      if (result.allocation() == nullptr) {
        result = slice;
      } else if (result != slice) {
        return FailedPrecondition(
            "BufferAllocation::Slice for instruction %s at index %s cannot "
            "be determined at compile-time.",
            instruction->name(), index.ToString());
      }
    } else {
      VLOG(3) << "No allocation";
    }
  }
  if (result.allocation() == nullptr) {
    return FailedPrecondition(
        "BufferAllocation::Slice not assigned for instruction %s at index %s",
        instruction->name(), index.ToString());
  }
  return result;
}
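
// Usage sketch (hypothetical caller): the returned StatusOr is typically
// unwrapped with the status macros, e.g.
//   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
//                       assignment.GetUniqueSlice(instr, /*index=*/{}));
// which fails cleanly when the slice is ambiguous or unassigned.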

StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
    const HloInstruction* instruction) const {
  return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
  return GetUniqueSlice(hlo_a, shape_index_a).value() ==
         GetUniqueSlice(hlo_b, shape_index_b).value();
}

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                          const HloInstruction* hlo_b) const {
  using SliceSet = flat_hash_set<BufferAllocation::Slice>;
  // Gets the slices of all of instr's subshapes. If any subshape doesn't have
  // an assigned slice, returns the empty set.
  auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
    SliceSet slices;
    Status status = ShapeUtil::ForEachSubshapeWithStatus(
        instr->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          auto shape_slices = GetAllSlices(instr, index);
          if (shape_slices.empty()) {
            return InvalidArgument("No slices assigned to part of instr.");
          }
          slices.insert(shape_slices.begin(), shape_slices.end());
          return OkStatus();
        });
    if (!status.ok()) {
      return {};
    }
    return slices;
  };

  SliceSet slices_a = collect_slices(hlo_a);
  SliceSet slices_b = collect_slices(hlo_b);
  // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e.
  // didn't return the empty set) for both HLOs, and the two resulting sets of
  // slices are disjoint.
  return !slices_a.empty() && !slices_b.empty() &&
         absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
           return slices_b.contains(slice);
         });
}
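
// Example (informal): if hlo_a's only slice is {index:0, offset:0, size:64}
// and hlo_b's only slice is {index:0, offset:64, size:64}, the sets share no
// element and the result is true. Note the check is conservative: if any
// subshape of either instruction lacks an assigned slice, collect_slices
// returns the empty set and the result is false.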

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
  return GetUniqueTopLevelSlice(
      module_->entry_computation()->root_instruction());
}

BufferAllocation* BufferAssignment::NewEmptyAllocation(
    int64_t size, LogicalBuffer::Color color) {
  BufferAllocation::Index index = allocations_.size();
  allocations_.emplace_back(index, size, color);
  BufferAllocation* allocation = &allocations_.back();
  return allocation;
}

BufferAllocation* BufferAssignment::NewAllocation(const HloBuffer& buffer,
                                                  int64_t size) {
  BufferAllocation* allocation = NewEmptyAllocation(size, buffer.color());
  AddAssignment(allocation, buffer, /*offset=*/0, size);
  allocation->peak_buffers_.push_back(buffer.values()[0]);
  return allocation;
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloBuffer& buffer, int64_t offset,
                                     int64_t size) {
  CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
      << "Non-reusable allocation already assigned a buffer: "
      << allocation->ToString();

  for (const HloValue* buffer_value : buffer.values()) {
    CHECK(!allocation_index_for_value_.contains(buffer_value))
        << "BufferValue " << buffer_value << " already has an allocation.";
    allocation->AddAssignment(*buffer_value, offset, size);
    allocation_index_for_value_[buffer_value] = allocation->index();
  }

  if (alias_analysis().BufferLivesOut(buffer)) {
    VLOG(3) << "HloBuffer lives out: " << buffer.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

void BufferAssignment::AddAssignment(BufferAllocation* allocation,
                                     const HloValue& value, int64_t offset,
                                     int64_t size) {
  allocation->AddAssignment(value, offset, size);
  allocation_index_for_value_[&value] = allocation->index();
  const HloValue& hlo_value =
      *CHECK_NOTNULL(dynamic_cast<const HloValue*>(&value));
  if (alias_analysis().ValueLivesOut(hlo_value)) {
    VLOG(3) << "HloValue lives out: " << hlo_value.ToString();
    VLOG(3) << "Set maybe live out: " << allocation->ToString();
    allocation->set_maybe_live_out(true);
  }
}

// Combines allocations of temporary buffers of the same color into one big
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
  VLOG(1) << "CombineTempAllocations()";
  // Stores the combined allocations.
  std::deque<BufferAllocation> combined_allocations;
  // Holds the pointer to a combined allocation of each color, if any.
  flat_hash_map<BufferValue::Color, BufferAllocation*> combined_allocation_map;

  // Move all temp allocations into a single run at the end of the allocations
  // vector.
  const auto first_temp_it =
      std::partition(allocations_.begin(), allocations_.end(),
                     [](const BufferAllocation& allocation) {
                       return !allocation.IsPreallocatedTempBuffer();
                     });

  // Walk over the run of temp allocations, collecting the allocations belonging
  // to the same color.
  if (first_temp_it != allocations_.end()) {
    for (auto it = first_temp_it; it != allocations_.end(); ++it) {
      BufferAllocation& temp_allocation = *it;
      BufferValue::Color color = temp_allocation.color();
      auto combined_it = combined_allocation_map.find(color);
      if (combined_it == combined_allocation_map.end()) {
        // We have found the first temp allocation of this color. Collect
        // the other temp allocations of the same color into it subject to the
        // size constraint.
        VLOG(1) << "Combined temp allocation for color " << color
                << " is: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }
      if (combined_it->second->size() + it->size() >=
          multiheap_size_constraint_per_heap_) {
        // We cannot put more into the current combined_it. So, appoint a new
        // combined_it.
        VLOG(1) << "Due to size constraint, reset temp allocation for color "
                << color << " to: " << temp_allocation;
        combined_allocations.emplace_back(temp_allocation);
        combined_allocation_map.emplace(color, &combined_allocations.back());
        continue;
      }

      BufferAllocation* combined_allocation = combined_it->second;
      VLOG(1) << "Combined allocation absorbing temp allocation: "
              << temp_allocation;

      // Each temp allocation is placed end-to-end, accounting for alignment.
      // The offset of each buffer in the combined allocation is computed from
      // the base offset of the allocation.
      int64_t alignment = color_alignment_(color);
      const int64_t base = RoundUpTo(combined_allocation->size(), alignment);
      combined_allocation->set_size(base + temp_allocation.size());
      for (const auto& buffer_offset_size : temp_allocation.assigned_buffers_) {
        const HloValue* value = buffer_offset_size.first;
        const int64_t offset = buffer_offset_size.second.offset;
        const int64_t size = buffer_offset_size.second.size;
        combined_allocation->AddAssignment(*value, base + offset, size);
      }
      if (!temp_allocation.HeapTraces().empty()) {
        CHECK_EQ(temp_allocation.HeapTraces().size(), 1);
        combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front());
      }

      combined_allocation->peak_buffers_.insert(
          combined_allocation->peak_buffers_.end(),
          temp_allocation.peak_buffers_.begin(),
          temp_allocation.peak_buffers_.end());
    }
    // Replace all existing temporary allocations with the new combined
    // allocations.
    allocations_.erase(first_temp_it, allocations_.end());
    for (BufferAllocation& combined : combined_allocations) {
      temp_allocation_total_size_ += combined.size();
      allocations_.push_back(std::move(combined));
    }
  }

  // Update allocation indices to their new positions.
  allocation_index_for_value_.erase(allocation_index_for_value_.begin(),
                                    allocation_index_for_value_.end());
  for (size_t index = 0; index < allocations_.size(); ++index) {
    BufferAllocation* allocation = &allocations_[index];
    allocation->set_index(index);
    for (const auto& buffer_offset_size : allocation->assigned_buffers_) {
      const HloValue* value = buffer_offset_size.first;
      allocation_index_for_value_[value] = index;
    }
  }
}
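
// Offset arithmetic example (informal): if the combined allocation currently
// has size 96 and color_alignment_(color) returns 64, then
// base = RoundUpTo(96, 64) = 128. Absorbing a 40-byte temp allocation grows
// the combined allocation to 168 bytes, and a buffer at offset 8 in the temp
// allocation lands at offset 128 + 8 = 136 in the combined allocation.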

Status BufferAssignment::ComputeSummaryStats() {
  for (auto& allocation : Allocations()) {
    if (allocation.is_entry_computation_parameter()) {
      stats_.parameter_allocation_count++;
      stats_.parameter_allocation_bytes += allocation.size();
    }
    if (allocation.is_constant()) {
      stats_.constant_allocation_count++;
      stats_.constant_allocation_bytes += allocation.size();
    }
    if (allocation.maybe_live_out()) {
      stats_.maybe_live_out_allocation_count++;
      stats_.maybe_live_out_allocation_bytes += allocation.size();
    }
    if (allocation.IsPreallocatedTempBuffer()) {
      stats_.preallocated_temp_allocation_count++;
      stats_.preallocated_temp_allocation_bytes += allocation.size();
    }
    stats_.total_allocation_count++;
    stats_.total_allocation_bytes += allocation.size();
  }

  // Only compute total fragmentation if all computations have schedules.
  HloSchedule schedule(module_);
  bool schedule_complete = true;
  for (const auto& computation : module_->computations()) {
    if (!computation->IsFusionComputation()) {
      const HloInstructionSequence* sequence =
          hlo_ordering().SequentialOrder(*computation);
      if (sequence == nullptr) {
        schedule_complete = false;
      } else {
        schedule.set_sequence(computation, *sequence);
      }
    }
  }
  if (schedule_complete) {
    TF_RETURN_IF_ERROR(schedule.Verify());
    TF_ASSIGN_OR_RETURN(
        const int64_t min_size,
        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
    stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
  }

  return OkStatus();
}
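
// Fragmentation example (informal): if the allocations total 1.0 MiB but the
// heap simulator's lower bound for the module is 768 KiB, then
// total_fragmentation_bytes = 1048576 - 786432 = 262144, i.e. 25% of the
// total allocation is attributable to fragmentation.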

std::string BufferAssignment::Stats::ToString() const {
  std::string s;
  StrAppendFormat(&s, "BufferAssignment stats:\n");
  StrAppendFormat(&s, "             parameter allocation: %10s\n",
                  HumanReadableNumBytes(parameter_allocation_bytes));
  StrAppendFormat(&s, "              constant allocation: %10s\n",
                  HumanReadableNumBytes(constant_allocation_bytes));
  StrAppendFormat(&s, "        maybe_live_out allocation: %10s\n",
                  HumanReadableNumBytes(maybe_live_out_allocation_bytes));
  StrAppendFormat(&s, "     preallocated temp allocation: %10s\n",
                  HumanReadableNumBytes(preallocated_temp_allocation_bytes));
  if (preallocated_temp_fragmentation_bytes >= 0) {
    const double percent = 100. * preallocated_temp_fragmentation_bytes /
                           preallocated_temp_allocation_bytes;
    StrAppendFormat(
        &s, "  preallocated temp fragmentation: %10s (%.2f%%)\n",
        HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent);
  }
  StrAppendFormat(&s, "                 total allocation: %10s\n",
                  HumanReadableNumBytes(total_allocation_bytes));
  if (total_fragmentation_bytes >= 0) {
    const double percent =
        100. * total_fragmentation_bytes / total_allocation_bytes;
    StrAppendFormat(&s, "              total fragmentation: %10s (%.2f%%)\n",
                    HumanReadableNumBytes(total_fragmentation_bytes), percent);
  }
  return s;
}

std::string BufferAssignment::ToString() const {
  std::string output;
  absl::StrAppend(&output, "BufferAssignment:\n");
  std::vector<const HloValue*> used_values;
  int64_t total_size = 0;
  for (auto& allocation : allocations_) {
    total_size += allocation.size();
    absl::StrAppend(&output, allocation.ToString());
    for (const auto& p : allocation.assigned_buffers()) {
      used_values.push_back(p.first);
    }
  }
  absl::StrAppend(&output, "\nTotal bytes used: ", total_size, " (",
                  HumanReadableNumBytes(total_size), ")\n");
  absl::StrAppend(&output, "\nUsed values:\n");
  absl::c_sort(used_values, &CompareHloValuesById);
  for (const HloValue* value : used_values) {
    absl::StrAppend(&output, value->ToString());
  }
  return output;
}

// Returns the largest k buffers present at the point of peak memory usage
// across allocations as a vector of pairs with their corresponding sizes.
std::vector<std::pair<int64_t, const HloValue*>> TopKPeakBuffers(
    uint64_t k, const std::vector<BufferAllocation>& allocations) {
  absl::btree_multimap<int64_t, const HloValue*> topk;
  for (const BufferAllocation& allocation : allocations) {
    for (const HloValue* value : allocation.PeakMemoryLogicalBuffers()) {
      int64_t size = allocation.assigned_buffers().at(value).size;
      if (topk.size() < k) {
        topk.insert({size, value});
      } else {
        auto it = topk.begin();
        if (size > it->first) {
          topk.erase(it);
          topk.insert({size, value});
        }
      }
    }
  }

  // The multimap iterates smallest-first, so reverse it to get the buffers in
  // descending size order.
  std::vector<std::pair<int64_t, const HloValue*>> topk_descending;
  topk_descending.reserve(topk.size());
  absl::c_reverse_copy(topk, std::back_inserter(topk_descending));
  return topk_descending;
}
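
// Worked example (informal): with k = 2 and peak buffer sizes 8, 32, 16
// arriving in that order, the multimap first holds {8, 32}; 16 then evicts 8,
// leaving {16, 32}. Reversing yields the pairs in descending size order:
// (32, ...), (16, ...).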

std::string BufferAssignment::ToVerboseString() const {
  // TODO(loreno): make this tunable via flag.
  const size_t kMaxBuffersToShow = 15;
  std::string output =
      absl::StrCat("BufferAssignment OOM Debugging.\n", stats_.ToString());

  std::vector<std::pair<int64_t, const HloValue*>> peak_buffers =
      TopKPeakBuffers(kMaxBuffersToShow, allocations_);
  std::vector<std::string> buf_strs;
  for (size_t i = 0; i < std::min(kMaxBuffersToShow, peak_buffers.size());
       ++i) {
    const HloValue* value = peak_buffers[i].second;
    const HloInstruction* instr = value->instruction();
    int64_t size = peak_buffers[i].first;
    buf_strs.push_back(absl::StrCat("\n\tBuffer ", i + 1, ":\n\t\tSize: ",
                                    xla::HumanReadableNumBytes(size)));
    if (!instr->metadata().op_name().empty()) {
      buf_strs.push_back(absl::StrCat(
          "\n\t\tOperator: ", xla::OpMetadataToString(instr->metadata())));
    }
    if (instr->opcode() == HloOpcode::kParameter &&
        (instr->parent() == instr->parent()->parent()->entry_computation())) {
      // Special-case entry parameters, as they sometimes have hundreds of
      // indices in their shapes and would overwhelm the output.
      buf_strs.push_back(absl::StrCat(
          "\n\t\tEntry Parameter Subshape: ",
          ShapeUtil::GetSubshape(instr->shape(), value->index()).ToString()));
    } else {
      // TODO(loreno): change this to a truncated string of the instruction.
      buf_strs.push_back(
          absl::StrCat("\n\t\tXLA Label: ", HloOpcodeString(instr->opcode()),
                       "\n\t\tShape: ", value->shape().ToString()));
    }
    buf_strs.push_back("\n\t\t==========================\n");
  }
  absl::StrAppend(&output, "Peak buffers:", absl::StrJoin(buf_strs, ""));
  return output;
}

std::string BufferAssignment::BufferInfoString() const {
  std::string binfo;
  // Columns in buffer information:
  // buffer_id: int. This value can be used to match the allocation in
  // allocation information.
  // buffer_name: string.
  // offset: int. Starting position of the buffer in the memory space.
  // size: int. Size of the buffer in bytes.
  // definition_time: int. Position in the schedule where the buffer starts
  // being live (inclusive).
  // end_time: int. Position in the schedule where the buffer stops being live
  // (exclusive).
  // num_uses: int. Number of uses of the buffer.
  // use_times: string. A semicolon-separated list of schedule positions of
  // the uses.
  // use_names: string. A semicolon-separated list of string representations
  // of the uses.
  // Append the column names.
  absl::StrAppend(&binfo,
                  "buffer_id,buffer_name,offset,size,"
                  "definition_time,end_time,num_uses,use_times,use_names\n");
  const HloLiveRange& live_ranges = hlo_live_range();
  const auto& instruction_schedule = live_ranges.instruction_schedule();
  const auto& buffer_live_ranges = live_ranges.buffer_live_ranges();
  // Sort the buffers by id.
  std::vector<std::pair<const HloValue*, BufferAllocation::OffsetSize>> buffers;
  for (const BufferAllocation& allocation : allocations_) {
    absl::c_copy(allocation.assigned_buffers(), std::back_inserter(buffers));
  }
  absl::c_sort(
      buffers,
      [](const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b1,
         const std::pair<const HloValue*, BufferAllocation::OffsetSize>& b2) {
        return b1.first->id() < b2.first->id();
      });
  for (const auto& buffer_pair : buffers) {
    const HloValue& buffer = *buffer_pair.first;
    const BufferAllocation::OffsetSize& offset_size = buffer_pair.second;
    if (!buffer_live_ranges.contains(&buffer)) {
      continue;
    }
    // Order the uses by their position in the schedule.
    std::vector<std::pair<int64_t, std::string>> uses;
    uses.reserve(buffer.GetUses().size());
    for (const HloUse& use : buffer.GetUses()) {
      uses.emplace_back(instruction_schedule.at(use.instruction),
                        use.ToString());
    }
    absl::c_sort(uses);
    std::vector<int64_t> use_positions;
    std::vector<std::string> use_names;
    use_positions.reserve(uses.size());
    use_names.reserve(uses.size());
    for (const auto& use : uses) {
      use_positions.push_back(use.first);
      use_names.push_back(use.second);
    }
    const int64_t definition_time =
        instruction_schedule.at(buffer.defining_position().instruction);
    const int64_t end_t = buffer_live_ranges.at(&buffer).end;
    absl::StrAppend(&binfo, buffer.id(), ",");
    absl::StrAppend(&binfo, "\"", buffer.ToShortString(), "\",");
    absl::StrAppend(&binfo, offset_size.offset, ",");
    absl::StrAppend(&binfo, offset_size.size, ",");
    absl::StrAppend(&binfo, definition_time, ",");
    absl::StrAppend(&binfo, end_t, ",");
    absl::StrAppend(&binfo, use_positions.size(), ",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_positions, ";"), "\",");
    absl::StrAppend(&binfo, "\"", absl::StrJoin(use_names, ";"), "\"");
    absl::StrAppend(&binfo, "\n");
  }
  return binfo;
}
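
// Hypothetical output row (names and values invented for illustration): a
// 1024-byte buffer with id 3, placed at offset 0, live over schedule
// positions [5, 9), with two uses, would be emitted roughly as:
//   3,"v3",0,1024,5,9,2,"6;8","add.1, operand 0;mul.2, operand 1"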

BufferAssignmentProto BufferAssignment::ToProto() const {
  BufferAssignmentProto proto;
  // NOTE: DataflowAnalysis state is serialized here in BufferAssignment,
  // because we need to do the HasAllocation check for each buffer. Otherwise
  // the buffer_size_ call might fail for some backends.
  const HloDataflowAnalysis& dataflow = this->dataflow_analysis();
  for (BufferValue::Id id = 0; id < dataflow.values().size(); id++) {
    auto& value = dataflow.values().at(id);
    if (HasAllocation(*value)) {
      LogicalBufferProto proto_buffer = value->ToProto(buffer_size_);
      proto.add_logical_buffers()->Swap(&proto_buffer);

      // Fill buffer aliases.
      for (const HloValue* alias :
           alias_analysis().GetBufferContainingValue(*value).values()) {
        if (alias->instruction() == value->instruction() &&
            alias->index() == value->index()) {
          continue;  // skip self-aliases
        }
        BufferAssignmentProto::BufferAlias* proto_alias =
            proto.add_buffer_aliases();
        LogicalBufferProto::Location proto_alias_location =
            BufferValue::ToLocationProto(*alias->instruction(), alias->index());
        proto_alias->set_source_buffer_id(value->id());
        proto_alias->mutable_location()->Swap(&proto_alias_location);
      }
    }
  }
  for (const BufferAllocation& allocation : Allocations()) {
    BufferAllocationProto proto_allocation = allocation.ToProto();
    proto.add_buffer_allocations()->Swap(&proto_allocation);
    for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) {
      *proto.add_heap_simulator_traces() = heap_trace;
    }
  }
  return proto;
}

/* static */
StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
    std::optional<BufferAssigner::MustNotLiveOut> must_not_live_out,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer,
    std::unique_ptr<PresetAssignments> preset_assignments) {
  BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
                          must_not_live_out, std::move(preset_assignments));
  return assigner.CreateAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(can_share_buffer));
}
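
// Call sketch (hypothetical; assumes `module` is scheduled and `kAlignment`
// is the backend's buffer alignment, and that the trailing parameters have
// defaults in the header):
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<BufferAssignment> assignment,
//       BufferAssigner::Run(
//           module,
//           std::make_unique<SequentialHloOrdering>(module->schedule()),
//           [](const BufferValue& buffer) {
//             return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*));
//           },
//           [](LogicalBuffer::Color) { return kAlignment; },
//           /*allocate_buffers_for_constants=*/true));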

bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1,
                                         const HloValue* buffer2,
                                         BufferAssignment* assignment) {
  CHECK(assignment->hlo_live_range().total_order_scheduled());
  const HloLiveRange& hlo_live_range = assignment->hlo_live_range();

  const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges();

  auto live_range_it1 = buffer_live_ranges.find(buffer1);
  CHECK(live_range_it1 != buffer_live_ranges.end())
      << "Buffer doesn't have a proper live range: " << buffer1->ToString();

  auto live_range_it2 = buffer_live_ranges.find(buffer2);
  CHECK(live_range_it2 != buffer_live_ranges.end())
      << "Buffer doesn't have a proper live range: " << buffer2->ToString();

  // Check if a user value can share the same buffer as its operand.
  auto can_share_as_operand =
      [&assignment](const HloValue* user_value, const HloValue* operand_value,
                    const HloLiveRange::TimeBound& operand_live_range) {
        // An HLO value can span multiple instructions during its lifetime. We
        // only look at the last instruction and check whether it can share the
        // buffer with the operand.
        HloPosition operand_end_position = operand_live_range.end_position;
        return user_value->instruction()->opcode() != HloOpcode::kCopy &&
               user_value->instruction()->IsUserOf(
                   operand_end_position.instruction) &&
               assignment->dataflow_analysis().CanShareOperandBufferWithUser(
                   operand_end_position.instruction, operand_end_position.index,
                   user_value->instruction(), user_value->index());
      };

  const auto& live_range_1 = live_range_it1->second;
  const auto& live_range_2 = live_range_it2->second;

  if (!(live_range_1.start > live_range_2.end ||
        live_range_2.start > live_range_1.end)) {
    if (live_range_1.end == live_range_2.start) {
      auto operand_value = buffer1;
      auto user_value = buffer2;
      if (!can_share_as_operand(user_value, operand_value, live_range_1)) {
        VLOG(4) << "End of live range of " << buffer1->ToShortString()
                << " is equal to the start of live range of "
                << buffer2->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else if (live_range_2.end == live_range_1.start) {
      auto operand_value = buffer2;
      auto user_value = buffer1;
      if (!can_share_as_operand(user_value, operand_value, live_range_2)) {
        VLOG(4) << "End of live range of " << buffer2->ToShortString()
                << " is equal to the start of live range of "
                << buffer1->ToShortString() << ", buffer cannot be shared.";
        return true;
      }
    } else {
      VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with "
              << *buffer2;
      VLOG(4) << "assigned_buffer.start: " << live_range_1.start;
      VLOG(4) << "assigned_buffer.end: " << live_range_1.end;
      VLOG(4) << "live_range_2.start: " << live_range_2.start;
      VLOG(4) << "live_range_2.end: " << live_range_2.end;
      return true;
    }
  }
  return false;
}
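
// Example (informal): live ranges [0, 5] and [6, 9] are disjoint, so there is
// no interference. Ranges [0, 5] and [5, 9] touch only at time 5; they can
// still share a buffer when the later value's instruction is a user of the
// earlier value's end position and dataflow analysis says the operand buffer
// is shareable in place (and the user is not a kCopy). Any other overlap is
// reported as interference.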

bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
                                       const HloBuffer& hlo_buffer,
                                       BufferAssignment* assignment) {
  CHECK(!assignment->HasAllocation(hlo_buffer))
      << "buffer " << hlo_buffer << " already has an allocation assigned.";

  VLOG(4) << "Trying to assign " << hlo_buffer << " size "
          << assignment->HloBufferSize(hlo_buffer)
          << " to allocation: " << *allocation;

  if (hlo_buffer.color() != allocation->color()) {
    VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color()
            << " and allocation has color " << allocation->color() << ".";
    return false;
  }

  if (assignment->HloBufferSize(hlo_buffer) > allocation->size()) {
    VLOG(4) << "Can't assign: buffer is larger than allocation ("
            << assignment->HloBufferSize(hlo_buffer) << " > "
            << allocation->size() << ")";
    return false;
  }

  if (allocation->is_readonly()) {
    VLOG(4) << "Can't assign: allocation is readonly";
    return false;
  }

  if (must_not_live_out_.has_value()) {
    if (allocation->maybe_live_out()) {
      // If a buffer may live out, the allocation cannot contain any node
      // where must_not_live_out_ returns true.
      for (const HloValue* value : hlo_buffer.values()) {
        if ((*must_not_live_out_)(value->instruction(), value->index())) {
          VLOG(4) << "Can't assign: " << value->instruction()->ToString()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
    // The above check is not sufficient: an allocation may currently not be
    // live out and yet contain an instruction for which must_not_live_out_
    // returns true; assigning a live-out buffer to that allocation would make
    // the allocation live out while it still contains an instruction for
    // which must_not_live_out_ returns true.
    if (assignment->alias_analysis().BufferLivesOut(hlo_buffer)) {
      for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
        const HloValue* value = buffer_offset_size.first;
        if ((*must_not_live_out_)(value->instruction(), value->index())) {
          VLOG(4) << "Can't assign: " << buffer_offset_size.first->instruction()
                  << " cannot live out of the module";
          return false;
        }
      }
    }
  }

  if (!allocation->is_reusable()) {
    VLOG(4) << "Can't assign: allocation is not reusable";
    return false;
  }

  for (const auto& buffer_offset_size : allocation->assigned_buffers()) {
    // Pairwise compare.
    const HloValue& assigned_buffer =
        *CHECK_NOTNULL(dynamic_cast<const HloValue*>(buffer_offset_size.first));
    for (const HloValue* new_value : hlo_buffer.values()) {
      if (assignment->hlo_live_range().total_order_scheduled()) {
        if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) {
          VLOG(4) << "Can't assign: assignee " << assigned_buffer
                  << " live range interferes with "
                  << new_value->ToShortString();
          return false;
        }
      } else if (assignment->hlo_ordering().MayInterfere(
                     assigned_buffer, *new_value,
                     assignment->dataflow_analysis())) {
        // Fall back to partial-order-based interference detection (slower)
        // when we don't have a totally ordered (scheduled) module.
        VLOG(4) << "Can't assign: assignee " << assigned_buffer
                << " may interfere with " << new_value->ToShortString();
        return false;
      }

      // Copy instructions don't share a buffer with their input operand.
      if (new_value->instruction()->opcode() == HloOpcode::kCopy) {
        for (const HloPosition& assigned_buffer_position :
             assigned_buffer.positions()) {
          if (new_value->instruction()->IsUserOf(
                  assigned_buffer_position.instruction)) {
            VLOG(4) << "Can't assign: assignee " << assigned_buffer
                    << " is used at copy instruction "
                    << new_value->ToShortString();
            return false;
          }
        }
      }
    }
  }

  // If the buffer is live out of the computation then it should only be
  // assigned a buffer which exactly fits the result to avoid wasting memory
  // (result buffers can have arbitrary lifetimes).
  if (assignment->alias_analysis().BufferLivesOut(hlo_buffer) &&
      allocation->size() != assignment->HloBufferSize(hlo_buffer)) {
    VLOG(4) << "Can't assign: buffer " << hlo_buffer
            << " is live out and its size does not match the allocation's";
    return false;
  }

  assignment->AddAssignment(allocation, hlo_buffer, /*offset=*/0,
                            assignment->HloBufferSize(hlo_buffer));
  return true;
}

Status BufferAssigner::AssignSingleHloBuffer(
    const HloBuffer* hlo_buffer, bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    std::vector<BufferAllocation::Index>* allocation_indices,
    BufferAssignment* assignment) {
  const int64_t buffer_size = assignment->HloBufferSize(*hlo_buffer);
  for (const HloValue* value : hlo_buffer->values()) {
    if (value->instruction()->opcode() == HloOpcode::kConstant) {
      if (allocate_buffers_for_constants_) {
        BufferAllocation* allocation =
            assignment->NewAllocation(*hlo_buffer, buffer_size);
        allocation->set_constant(true);
        VLOG(3) << "New allocation #" << allocation->index() << " for constant "
                << *hlo_buffer << " value ptr: " << value;
      } else {
        VLOG(3) << "Not allocating buffer for constant";
      }
      return OkStatus();
    }
1174 
1175     const HloInstruction* instruction = value->instruction();
1176     const bool is_entry_parameter =
1177         instruction->opcode() == HloOpcode::kParameter &&
1178         instruction->parent() ==
1179             instruction->parent()->parent()->entry_computation();
1180 
1181     if (is_entry_parameter) {
1182       bool parameter_has_alias =
1183           assignment->module().input_output_alias_config().ParameterHasAlias(
1184               instruction->parameter_number(), value->index());
1185       // If the hlo buffer is part of an external parameter, creates a new
1186       // allocation and sets its parameter number. Parameters of non-entry
1187       // computations do not need special allocations because they live inside
1188       // callers.
1189       BufferAllocation* allocation =
1190           assignment->NewAllocation(*hlo_buffer, buffer_size);
1191 
1192       allocation->set_entry_computation_parameter(
1193           instruction->parameter_number(), value->index(), parameter_has_alias);
1194       if (parameter_has_alias) {
1195         allocation_indices->push_back(allocation->index());
1196       }
1197       VLOG(3) << "New allocation #" << allocation->index()
1198               << " marked as entry computation parameter: " << *hlo_buffer;
1199       return OkStatus();
1200     }
1201   }

  if (is_thread_local) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation->set_is_thread_local(true);
    VLOG(3) << "New allocation #" << allocation->index()
            << " for thread-local: " << *hlo_buffer;
    return OkStatus();
  }

  for (const HloValue* value : hlo_buffer->values()) {
    if (value->shape().IsTuple()) {
      BufferAllocation* allocation =
          assignment->NewAllocation(*hlo_buffer, buffer_size);
      allocation->set_is_tuple(true);
      VLOG(3) << "New allocation #" << allocation->index()
              << " for tuple-shaped buffer: " << *hlo_buffer;
      return OkStatus();
    }

    if (value->IsTopLevel() && !value->IsTuple()) {
      const HloInstruction* instruction = value->instruction();
      for (auto* operand : instruction->operands()) {
        for (const auto& operand_slice :
             assignment->GetAllSlices(operand, /*index=*/{})) {
          BufferAllocation* allocation =
              assignment->GetMutableAllocation(operand_slice.index());
          if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
            VLOG(3) << "Reusing (operand) allocation #" << allocation->index()
                    << " for: " << *hlo_buffer;
            return OkStatus();
          }
        }
      }
    }
  }

  // Find the smallest allocation that can be reused, iterating from the end
  // of allocation_indices (smallest) to the beginning (largest).
  for (int allocation_index = allocation_indices->size() - 1;
       allocation_index >= 0; allocation_index--) {
    BufferAllocation* allocation = assignment->GetMutableAllocation(
        allocation_indices->at(allocation_index));
    if (MaybeAssignBuffer(allocation, *hlo_buffer, assignment)) {
      VLOG(3) << "Reusing allocation #" << allocation->index()
              << " for: " << *hlo_buffer;
      return OkStatus();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer) &&
      !assignment->alias_analysis().BufferLivesOut(*hlo_buffer)) {
    bool all_computations_have_sequential_order = true;
    for (const HloValue* hlo_value : hlo_buffer->values()) {
      HloComputation* computation = hlo_value->instruction()->parent();
      const bool has_sequential_order =
          assignment->hlo_ordering().SequentialOrder(*computation) != nullptr;
      all_computations_have_sequential_order &= has_sequential_order;
    }

    if (all_computations_have_sequential_order) {
      for (const HloValue* hlo_value : hlo_buffer->values()) {
        HloComputation* computation = hlo_value->instruction()->parent();
        // There is a sequential instruction ordering, so we delay assignment
        // of temp buffers until after the loop. We do this right before we
        // decide to create a new allocation, to ensure we've exhausted all
        // the buffer re-use cases above.
        //
        // Entry parameters and thread-local buffers were already handled
        // earlier in this loop iteration. See
        // BufferAllocation::IsPreallocatedTempBuffer for the definition of
        // temp buffers.
        (*buffers_to_assign_sequentially)[computation].insert(hlo_value);
        VLOG(3) << "Delaying assignment of temp buffer: " << *hlo_value;
      }
      return OkStatus();
    }
  }

  if (!assignment->HasAllocation(*hlo_buffer)) {
    BufferAllocation* allocation =
        assignment->NewAllocation(*hlo_buffer, buffer_size);
    allocation_indices->push_back(allocation->index());
    VLOG(3) << "New allocation #" << allocation->index()
            << " for: " << *hlo_buffer;
  }

  TF_RET_CHECK(assignment->HasAllocation(*hlo_buffer));
  return OkStatus();
}

Status BufferAssigner::AssignBuffersForComputations(
    const std::vector<const HloComputation*>& computations,
    bool is_thread_local,
    absl::flat_hash_map<const HloComputation*,
                        absl::flat_hash_set<const HloValue*>>*
        buffers_to_assign_sequentially,
    BufferAssignment* assignment) {
  if (computations.empty()) {
    return OkStatus();
  }
  std::vector<const HloBuffer*> sorted_buffers;

  // First assign the preset allocations.
  absl::flat_hash_set<const HloBuffer*> preset_assigned_buffers;

  TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment));

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (const HloBuffer& buffer : alias_analysis.buffers()) {
    // Skip if the buffer is already assigned since it had a preset allocation.
    if (preset_assigned_buffers.find(&buffer) !=
        preset_assigned_buffers.end()) {
      VLOG(3) << "Skip allocation for buffer: " << buffer;
      continue;
    }
    TF_RET_CHECK(!buffer.values().empty());
    const HloComputation* comp = buffer.values()[0]->instruction()->parent();
    if (absl::c_linear_search(computations, comp)) {
      sorted_buffers.push_back(&buffer);
    }
  }

  // Generate a post-order position for each instruction, used below to sort
  // the HloBuffers.
  flat_hash_map<const HloInstruction*, int> post_order_position;
  int position = 0;
  std::vector<const HloComputation*> reverse_post_order_computations;
  std::unique_ptr<CallGraph> call_graph =
      CallGraph::Build(computations[0]->parent());
  TF_RETURN_IF_ERROR(call_graph->VisitNodes([&](const CallGraphNode& node) {
    if (absl::c_linear_search(computations, node.computation())) {
      reverse_post_order_computations.push_back(node.computation());
    }
    return OkStatus();
  }));
  absl::c_reverse(reverse_post_order_computations);
  for (auto* computation : reverse_post_order_computations) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      post_order_position.emplace(instruction, position);
      position++;
    }
  }

  HloSchedule schedule(&assignment->module());

  for (const HloComputation* computation : computations) {
    const HloInstructionSequence* instruction_sequence =
        assignment->hlo_ordering().SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order && buffers_to_assign_sequentially != nullptr) {
      // Every sequential computation must get an entry in the
      // buffers_to_assign_sequentially map, even if we end up with an empty
      // set of buffers. This ensures we can correctly determine whether to
      // run whole-module heap simulation.
      buffers_to_assign_sequentially->emplace(computation,
                                              flat_hash_set<const HloValue*>());

      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

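  // Sort keys below, in priority order: (1) decreasing buffer size,
  // (2) live-out buffers first, (3) increasing post-order position of the
  // earliest defining instruction, with the stable sort preserving the
  // original order for full ties. Illustrative example (hypothetical sizes):
  // a 1024-byte buffer precedes all 512-byte buffers; among 512-byte
  // buffers, a live-out buffer precedes a temp buffer; two 512-byte temp
  // buffers are ordered by where their earliest values are defined.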
  absl::c_stable_sort(
      sorted_buffers, [&post_order_position, &alias_analysis, assignment](
                          const HloBuffer* a, const HloBuffer* b) {
        // Primary sort is by decreasing buffer size.
        const int64_t a_size = assignment->HloBufferSize(*a);
        const int64_t b_size = assignment->HloBufferSize(*b);
        if (a_size != b_size) {
          return a_size > b_size;  // use ">" for decreasing size.
        }

        const bool a_live_out = alias_analysis.BufferLivesOut(*a);
        const bool b_live_out = alias_analysis.BufferLivesOut(*b);
        if (a_live_out != b_live_out) {
          return a_live_out;
        }
        auto compare = [&post_order_position](const HloValue* value1,
                                              const HloValue* value2) {
          return post_order_position.at(value1->instruction()) <
                 post_order_position.at(value2->instruction());
        };
        const HloValue* a_min = *absl::c_min_element(a->values(), compare);
        const HloValue* b_min = *absl::c_min_element(b->values(), compare);
        return compare(a_min, b_min);
      });

  std::vector<BufferAllocation::Index> allocation_indices;

  for (const HloBuffer* buffer : sorted_buffers) {
    VLOG(3) << "=================================================";
    VLOG(3) << "Assigning buffer for " << *buffer;
    TF_RETURN_IF_ERROR(AssignSingleHloBuffer(buffer, is_thread_local,
                                             buffers_to_assign_sequentially,
                                             &allocation_indices, assignment));
  }
  return OkStatus();
}

flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>>
BufferAssigner::SplitBuffersByColor(
    const flat_hash_set<const HloValue*>& buffers) {
  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const HloValue*>> color_map;
  for (auto buffer : buffers) {
    color_map[buffer->color()].insert(buffer);
  }
  return color_map;
}

Status BufferAssigner::AssignPresetBuffers(
    absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
    BufferAssignment* assignment) {
  if (!preset_assignments_) {
    return OkStatus();
  }

  // Create an allocation for each preset color.
  absl::flat_hash_map<LogicalBuffer::Color, BufferAllocation*>
      preset_allocations;
  for (auto& color_and_info : preset_assignments_->assignment_informations()) {
    LogicalBuffer::Color color(color_and_info.first);
    auto inserted = preset_allocations.emplace(
        color,
        assignment->NewEmptyAllocation(color_and_info.second.size, color));
    BufferAllocation* inserted_allocation = inserted.first->second;
    inserted_allocation->AddHeapTrace(
        color_and_info.second.heap_simulator_trace);
    VLOG(3) << "Created preset buffer allocation "
            << inserted_allocation->index()
            << ", color: " << inserted_allocation->color()
            << ", size: " << inserted_allocation->size();
  }

  const HloAliasAnalysis& alias_analysis = assignment->alias_analysis();

  for (auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& defining_position = position_and_chunk.first;
    const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    for (const HloValue* value : buffer.values()) {
      VLOG(3) << "Preset allocation for value: " << value->ToShortString();
      const HeapSimulator::Chunk& chunk = position_and_chunk.second;
      auto preset_allocations_iter = preset_allocations.find(value->color());
      CHECK(preset_allocations_iter != preset_allocations.end())
          << "No preset value allocation for color " << value->color()
          << " for " << value->ToShortString() << " found.";
      preset_allocations_iter->second->AddAssignment(*value, chunk.offset,
                                                     chunk.size);
    }

    assigned_buffers->insert(&buffer);
  }

  // The preset assignments have now been consumed; clear them so that a
  // second call to this method does not assign the same buffers again.
  preset_assignments_ = {};

  return OkStatus();
}

Status BufferAssigner::AssignBuffersWithSequentialOrdering(
    const flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>&
        buffers_to_assign_sequentially,
    bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
  // Run the sequences of instructions through the heap simulator. The
  // heuristic that seems to give the best results is global decreasing-size
  // best-fit, tried in both its spatial and temporal variants.
  const HloOrdering& hlo_ordering = assignment->hlo_ordering();

  // Returns a heap algorithm that chooses the best result from several
  // algorithms.
  auto get_heap_algorithm = [&](int64_t alignment) {
    auto algorithms = std::make_unique<
        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
    algorithms->push_back(
        std::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kSpatial));
    algorithms->push_back(
        std::make_unique<ConstrainedGlobalDecreasingSizeBestFitHeap>(
            assignment->multiheap_size_constraint_per_heap(), alignment,
            GlobalDecreasingSizeBestFitHeap<HloValue>::kTemporal));
    return std::make_unique<ChooseBestHeapAlgorithm<HloValue>>(
        std::move(algorithms));
  };
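  // ChooseBestHeapAlgorithm runs every candidate algorithm and keeps the
  // result with the smallest total heap size.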

  if (run_whole_module_heap_simulation) {
    // Run the heap simulation over the whole module. This reduces memory
    // usage, since buffers for kCall, kWhile, and kConditional
    // sub-computations are only live for the duration of their calling
    // instructions.
    VLOG(1) << "Running whole-module heap simulation";
    HloSchedule schedule(&assignment->module());
    flat_hash_set<const HloValue*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      schedule.set_sequence(computation, *instruction_sequence);
      all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                   buffers_to_assign.end());
    }
    auto color_map = SplitBuffersByColor(all_buffers_to_assign);
    for (auto& single_colored_set : color_map) {
      auto color = single_colored_set.first;
      VLOG(2) << "Simulating heap for color " << color;
      int64_t alignment = assignment->color_alignment_(color);
      HeapSimulator::Options options;
      options.alloc_constants = allocate_buffers_for_constants_;
      options.buffers_to_assign = &single_colored_set.second;

      TF_ASSIGN_OR_RETURN(
          HeapSimulator::Result<HloValue> result,
          HeapSimulator::Run(
              get_heap_algorithm(alignment), assignment->module(), schedule,
              assignment->alias_analysis(), assignment->buffer_size_, options));
      AssignBuffersFromHeapSimulator(result, assignment,
                                     single_colored_set.first);
    }
  } else {
    // Run the heap simulation on a per-computation basis. Buffers for
    // sub-computations are assigned disjoint BufferAllocations, assuming the
    // worst case that they may all be live concurrently.
    VLOG(1) << "Running per-computation heap simulation";
    for (const auto& pair : buffers_to_assign_sequentially) {
      const HloComputation* computation = pair.first;
      const flat_hash_set<const HloValue*>& buffers_to_assign = pair.second;
      const HloInstructionSequence* instruction_sequence =
          hlo_ordering.SequentialOrder(*computation);
      CHECK(instruction_sequence != nullptr) << computation->name();
      auto color_map = SplitBuffersByColor(buffers_to_assign);
      for (auto& single_colored_set : color_map) {
        auto color = single_colored_set.first;
        VLOG(2) << "Simulating heap for color " << color;
        int64_t alignment = assignment->color_alignment_(color);
        HeapSimulator::Options options;
        options.buffers_to_assign = &single_colored_set.second;
        TF_ASSIGN_OR_RETURN(
            HeapSimulator::Result<HloValue> result,
            HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
                               *instruction_sequence,
                               assignment->alias_analysis(),
                               assignment->buffer_size_, options));
        AssignBuffersFromHeapSimulator(result, assignment,
                                       single_colored_set.first);
      }
    }
  }
  return OkStatus();
}

namespace {
// Computes and returns the set of logical buffers live at the point of
// maximal liveness in the given heap trace. LogicalBuffers are (stably)
// sorted by id.
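// Illustrative example (hypothetical trace): for events ALLOC a (8 bytes),
// ALLOC b (8 bytes), FREE a, ALLOC c (4 bytes), the maximal live size is 16
// bytes and the returned buffers are {a, b}: the live set at the first point
// the maximum is reached.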
std::vector<const HloValue*> ComputePeakMemoryLogicalBuffers(
    const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
  // Create a map from buffer id to the HloValue, and a map from HloValue to
  // its size, for the values in this allocation.
  absl::flat_hash_map<BufferValue::Id, const HloValue*> id_to_value;
  absl::flat_hash_map<const HloValue*, int64_t> buffer_sizes;
  for (const auto& pair : allocation.assigned_buffers()) {
    const HloValue* value = pair.first;
    const BufferAllocation::OffsetSize& offset_size = pair.second;
    id_to_value[value->id()] = value;
    buffer_sizes[value] = offset_size.size;
  }
  VLOG(1) << "Compute peak memory logical buffers";

  // To account properly for shared buffers, we track how many instances of
  // each shared buffer are currently live, their canonical ids, and the size
  // we returned when allocating the buffer, so that we can return the
  // negative of that size when the buffer is freed.
  absl::flat_hash_map<int64_t, int> num_outstanding_shared_buffers;
  absl::flat_hash_map<int64_t, int64_t> shared_canonical_ids;
  absl::flat_hash_map<int64_t, int64_t> allocated_sizes;
  // Returns how much the given event increases the total size of live
  // buffers. Can be negative.
  auto memory_delta = [&](const HeapSimulatorTrace::Event& event) -> int64_t {
    const HloValue* buffer = id_to_value.at(event.buffer_id());
    const int64_t buffer_size = buffer_sizes.at(buffer);
    if (event.kind() == HeapSimulatorTrace::Event::ALLOC) {
      num_outstanding_shared_buffers[event.buffer_id()] = 1;
      allocated_sizes[event.buffer_id()] = buffer_size;
      return buffer_size;
    } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) {
      shared_canonical_ids[event.buffer_id()] = event.share_with_canonical_id();
      if (++num_outstanding_shared_buffers[event.share_with_canonical_id()] ==
          1) {
        // This shared buffer is currently the only live instance of the
        // buffer with the canonical id, so we return the buffer size.
        allocated_sizes[event.buffer_id()] = buffer_size;
        return buffer_size;
      }
      // There are multiple live instances of this buffer, so return 0.
      allocated_sizes[event.buffer_id()] = 0;
      return 0;
    } else if (event.kind() == HeapSimulatorTrace::Event::FREE) {
      auto shared_canonical_id_it =
          shared_canonical_ids.find(event.buffer_id());
      // Decrement the number of outstanding instances of this buffer and
      // return the negative of the size recorded at allocation time.
      int64_t buffer_id = (shared_canonical_id_it == shared_canonical_ids.end())
                              ? event.buffer_id()
                              : shared_canonical_id_it->second;
      --num_outstanding_shared_buffers[buffer_id];
      return -1 * allocated_sizes[event.buffer_id()];
    }
    LOG(FATAL) << "Unknown event kind: " << event.kind();
  };
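
  // Illustrative example (hypothetical buffer ids and sizes): for the trace
  //   ALLOC b0 (16 bytes), SHARE_WITH b1 -> b0, FREE b0, FREE b1
  // the deltas are +16, 0, -16, 0. Sharing contributes nothing while the
  // canonical buffer is live, and each FREE returns the negative of whatever
  // size was recorded for that buffer id when it was allocated.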

  // First compute the size of the maximal live set.
  int64_t max_live_size = 0;
  int64_t live_size = 0;
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed into
      // this allocation. This can happen when size constraints are given to
      // the heap simulator.
      continue;
    }
    live_size += memory_delta(event);
    if (max_live_size < live_size) {
      max_live_size = live_size;
    }
  }

  // Next gather the set of logical buffers live at the earliest point of
  // maximal live set size.
  absl::flat_hash_set<const HloValue*> live_values;
  live_size = 0;
  num_outstanding_shared_buffers.clear();
  for (const auto& event : heap_trace.events()) {
    if (!id_to_value.contains(event.buffer_id())) {
      // Skip as the buffer associated with this trace event is not placed into
      // this allocation. This can happen when size constraints are given to
      // the heap simulator.
      continue;
    }
    const HloValue* value = id_to_value.at(event.buffer_id());
    int64_t delta = memory_delta(event);
    // To avoid including buffers that are aliases of each other in the peak
    // buffers list, only add buffers for which memory_delta returns a
    // positive size; memory_delta returns 0 when the buffer already has a
    // live alias of itself.
    if (delta > 0) {
      InsertOrDie(&live_values, value);
    } else if (delta < 0) {
      CHECK(ContainsKey(live_values, value));
      live_values.erase(value);
    }
    live_size += delta;

    if (live_size == max_live_size) {
      break;
    }
  }
  CHECK_EQ(live_size, max_live_size);

  std::vector<const HloValue*> live_values_vector;
  live_values_vector.insert(live_values_vector.end(), live_values.begin(),
                            live_values.end());

  // Stably sort the live buffers by id.
  absl::c_sort(live_values_vector, [](const HloValue* a, const HloValue* b) {
    return a->id() < b->id();
  });
  VLOG(4) << "Peak memory buffers:";
  for (auto value : live_values_vector) {
    VLOG(4) << "  " << value->ToString();
  }
  return live_values_vector;
}

}  // namespace

void BufferAssigner::AssignBuffersFromHeapSimulator(
    const HeapSimulator::Result<HloValue>& result, BufferAssignment* assignment,
    BufferValue::Color color) {
  if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) {
    assignment->stats_.preallocated_temp_fragmentation_bytes =
        result.fragmentation_size;
  } else {
    assignment->stats_.preallocated_temp_fragmentation_bytes +=
        result.fragmentation_size;
  }
  VLOG(1) << "Result size from heap simulator: " << result.heap_size;

  // Iterate through heap_results. For each heap_result, create a new
  // allocation in `assignment`.
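  // There can be more than one heap_result when a per-heap size constraint
  // is in effect (xla_multiheap_size_constraint_per_heap); each heap then
  // becomes its own BufferAllocation.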
  for (const HeapSimulator::HeapResult<HloValue>& heap_result :
       result.heap_results) {
    BufferAllocation* allocation =
        assignment->NewEmptyAllocation(heap_result.heap_size, color);
    for (const auto& buffer_chunk : heap_result.chunk_map) {
      const HloValue& value = *buffer_chunk.first;
      const HeapSimulator::Chunk& chunk = buffer_chunk.second;
      assignment->AddAssignment(allocation, value, chunk.offset, chunk.size);
    }
    allocation->peak_buffers_ =
        ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace);

    XLA_VLOG_LINES(2, allocation->ToString());

    allocation->AddHeapTrace(result.debug_trace);
  }
}

StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
    const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
    BufferValue::SizeFunction buffer_size,
    LogicalBuffer::AlignmentFunction color_alignment,
    HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
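  // Overall pipeline: run alias analysis; build a schedule and live ranges;
  // color the values; assign buffers for global computations (deferring
  // sequentially ordered temp buffers to the heap simulator); assign buffers
  // for thread-local computations; mark maybe-live-out allocations; combine
  // temp allocations; and compute summary statistics.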
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer));

  // Set up a schedule for each computation.
  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    const HloInstructionSequence* instruction_sequence =
        hlo_ordering->SequentialOrder(*computation);
    const bool has_sequential_order = instruction_sequence != nullptr;
    if (has_sequential_order) {
      schedule.set_sequence(computation, *instruction_sequence);
    }
  }

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(schedule, *alias_analysis,
                                        module->entry_computation(),
                                        /*module_scoped_analysis=*/true));

  VLOG(1) << "Assigning buffers to module " << module->name();
  XLA_VLOG_LINES(3, module->ToString());
  XLA_VLOG_LINES(3, alias_analysis->ToString());
  XLA_VLOG_LINES(3, alias_analysis->dataflow_analysis().ToString());
  VLOG(1) << "Number of buffers to assign: "
          << alias_analysis->buffers().size();

  // Can't use std::make_unique because BufferAssignment constructor is
  // private.
  std::unique_ptr<BufferAssignment> assignment(new BufferAssignment(
      module, std::move(hlo_ordering), std::move(buffer_size),
      std::move(color_alignment), std::move(alias_analysis),
      std::move(hlo_live_range)));

  TF_RETURN_IF_ERROR(
      colorer_(&assignment->alias_analysis(), assignment->hlo_ordering()));
  VLOG(3) << "After coloring:";
  XLA_VLOG_LINES(3,
                 assignment->alias_analysis().dataflow_analysis().ToString());

  std::vector<const HloComputation*> thread_local_computations;
  std::vector<const HloComputation*> global_computations;
  TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
      module, &thread_local_computations, &global_computations));

  // First assign buffers for global computations. Temporary buffers for
  // sequential computations are collected in
  // 'buffers_to_assign_sequentially'.
  flat_hash_map<const HloComputation*, flat_hash_set<const HloValue*>>
      buffers_to_assign_sequentially;
  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      global_computations,
      /*is_thread_local=*/false, &buffers_to_assign_sequentially,
      assignment.get()));
  // Assign buffers with sequential ordering, if any. If all global
  // computations are sequential, we can run heap simulation on the whole
  // module, which reduces memory usage.
  const bool run_whole_module_heap_simulation =
      buffers_to_assign_sequentially.size() == global_computations.size();
  VLOG(2) << "Running whole module heap simulation: "
          << run_whole_module_heap_simulation;
  const int32_t multiheap_size_constraint_per_heap =
      module->config().debug_options().xla_multiheap_size_constraint_per_heap();
  VLOG(2) << "Multiheap per heap size limit: "
          << multiheap_size_constraint_per_heap;
  TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering(
      buffers_to_assign_sequentially, run_whole_module_heap_simulation,
      assignment.get()));

  // Now assign buffers for thread-local computations. All LogicalBuffers get
  // their own BufferAllocation. Fusion computations are excluded; they are
  // handled as part of their fusion instruction.
  std::vector<const HloComputation*> thread_local_computations_no_fusion;
  for (auto* computation : thread_local_computations) {
    TF_RET_CHECK(computation != module->entry_computation());
    if (computation->IsFusionComputation()) {
      continue;
    }
    thread_local_computations_no_fusion.push_back(computation);
  }

  TF_RETURN_IF_ERROR(AssignBuffersForComputations(
      thread_local_computations_no_fusion,
      /*is_thread_local=*/true,
      /*buffers_to_assign_sequentially=*/nullptr, assignment.get()));

  // Mark all buffers which may be live out of the entry computation as
  // "liveout".
  for (const HloBuffer* buffer :
       assignment->alias_analysis().LiveOutBuffers()) {
    VLOG(3) << "maybe_live_out LogicalBuffer: " << *buffer;
    if (assignment->HasAllocation(*buffer)) {
      BufferAllocation* alloc =
          assignment->GetMutableAssignedAllocation(*buffer);
      alloc->set_maybe_live_out(true);
      VLOG(3) << "maybe_live_out BufferAllocation: " << *alloc;
    }
  }

  // Combines allocations of temporary buffers into big BufferAllocations
  // subject to the buffer allocation size constraint. This can only be
  // performed after all buffers have been assigned, and after maybe_live_out
  // is marked, since it is used to determine whether an allocation contains
  // temporary buffers or not.
  assignment->CombineTempAllocations();

  XLA_VLOG_LINES(2, assignment->ToString());
  TF_RETURN_IF_ERROR(assignment->ComputeSummaryStats());
  XLA_VLOG_LINES(1, assignment->GetStats().ToString());
  VLOG(1) << "Buffer assignment done.";
  return std::move(assignment);
}

}  // namespace xla