1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h"
17 
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/node_hash_map.h"
26 #include "absl/container/node_hash_set.h"
27 #include "absl/status/status.h"
28 #include "absl/status/statusor.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/types/optional.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/service/hlo.pb.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/core/lib/math/math_util.h"
39 #include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h"
40 
41 namespace tensorflow {
42 namespace profiler {
43 namespace {
44 
45 using absl::StrFormat;
46 using ::xla::BufferAllocationProto;
47 using ::xla::HloInstructionProto;
48 using ::xla::HloProto;
49 using ::xla::LayoutUtil;
50 using ::xla::LogicalBufferProto;
51 using ::xla::Shape;
52 using ::xla::ShapeUtil;
53 
// Converts a byte count to mebibytes (1 MiB = 2^20 = 1048576 bytes).
double BytesToMiB(int64_t bytes) {
  constexpr double kBytesPerMiB = 1024.0 * 1024.0;
  return static_cast<double>(bytes) / kBytesPerMiB;
}
57 
58 // Get buffer allocation property.
GetAllocationGroupName(const BufferAllocationProto * buffer_allocation)59 std::string GetAllocationGroupName(
60     const BufferAllocationProto* buffer_allocation) {
61   if (buffer_allocation == nullptr) {
62     return "";
63   }
64   if (buffer_allocation->is_entry_computation_parameter()) {
65     return "Parameter";
66   } else if (buffer_allocation->maybe_live_out()) {
67     return "Output";
68   } else if (buffer_allocation->is_thread_local()) {
69     return "Thread-local";
70   } else {
71     return "Temporary";
72   }
73 }
74 
75 // Get the instruction associated with the logical buffer.
GetInstructionName(const LogicalBufferProto * logical_buffer)76 std::string GetInstructionName(const LogicalBufferProto* logical_buffer) {
77   if (logical_buffer == nullptr) {
78     return "";
79   }
80   if (logical_buffer->defined_at().shape_index().empty()) {
81     return logical_buffer->defined_at().instruction_name();
82   } else {
83     return absl::StrCat(
84         logical_buffer->defined_at().instruction_name(), "{",
85         absl::StrJoin(logical_buffer->defined_at().shape_index(), ""), "}");
86   }
87 }
88 
MakeHeapObjectCommon(std::string label,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)89 HeapObject MakeHeapObjectCommon(std::string label, int logical_buffer_id,
90                                 int64_t logical_buffer_size_bytes,
91                                 int64_t unpadded_shape_bytes) {
92   HeapObject result;
93   result.set_label(std::move(label));
94   result.set_logical_buffer_id(logical_buffer_id);
95   result.set_logical_buffer_size_mib(BytesToMiB(logical_buffer_size_bytes));
96   result.set_unpadded_shape_mib(BytesToMiB(unpadded_shape_bytes));
97   return result;
98 }
99 
MakeHeapObject(const std::string & tf_op_name,const std::string & shape_string,const std::string & op_code,std::string instruction_name,std::string group_name,std::string label,int color,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)100 HeapObject MakeHeapObject(const std::string& tf_op_name,
101                           const std::string& shape_string,
102                           const std::string& op_code,
103                           std::string instruction_name, std::string group_name,
104                           std::string label, int color, int logical_buffer_id,
105                           int64_t logical_buffer_size_bytes,
106                           int64_t unpadded_shape_bytes) {
107   HeapObject result =
108       MakeHeapObjectCommon(std::move(label), logical_buffer_id,
109                            logical_buffer_size_bytes, unpadded_shape_bytes);
110   result.set_numbered(color);
111   result.set_instruction_name(std::move(instruction_name));
112   result.set_group_name(std::move(group_name));
113   result.set_tf_op_name(tf_op_name);
114   result.set_shape_string(shape_string);
115   result.set_op_code(op_code);
116   return result;
117 }
118 
MakeHeapObject(std::string color,std::string label,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)119 HeapObject MakeHeapObject(std::string color, std::string label,
120                           int logical_buffer_id,
121                           int64_t logical_buffer_size_bytes,
122                           int64_t unpadded_shape_bytes) {
123   HeapObject result =
124       MakeHeapObjectCommon(std::move(label), logical_buffer_id,
125                            logical_buffer_size_bytes, unpadded_shape_bytes);
126   result.set_named(std::move(color));
127   return result;
128 }
129 
MakeBufferSpan(int32 start,int32 limit)130 BufferSpan MakeBufferSpan(int32 start, int32 limit) {
131   BufferSpan result;
132   result.set_start(start);
133   result.set_limit(limit);
134   return result;
135 }
136 
// Descends from `shape` through nested tuple shapes, following one element of
// `shape_index` per level, and returns the addressed sub-shape. An empty index
// returns `shape` itself.
// NOTE(review): the result presumably aliases storage owned by `*shape`, so it
// must not outlive it. Index validity is assumed, not checked here.
const Shape* ResolveShapeIndex(const Shape* shape,
                               absl::Span<const int64_t> shape_index) {
  const Shape* subshape = shape;
  for (const int64_t tuple_element : shape_index) {
    subshape = &subshape->tuple_shapes(tuple_element);
  }
  return subshape;
}
144 
145 // A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding,
146 // since that is considered in the ByteSizeOf calculation.
UnpaddedSize(Shape shape)147 int64_t UnpaddedSize(Shape shape) {
148   // Ensure the layout has no padding by making it the default layout.
149   LayoutUtil::SetToDefaultLayout(&shape);
150   // Note: we make a simplifying assumption here that a "minimal" size for a
151   // tuple member would be the size of a `void*` -- there may be even fancier
152   // ways of doing things, but this should give a good enough approximation of
153   // what a minimal tuple size is.
154   return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
155 }
156 
Convert(const xla::BufferAllocationProto_Assigned & assigned,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,LogicalBuffer * result)157 void Convert(const xla::BufferAllocationProto_Assigned& assigned,
158              const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
159                  id_to_logical_buffer,
160              const absl::node_hash_map<std::string, const HloInstructionProto*>&
161                  name_to_hlo,
162              LogicalBuffer* result) {
163   result->set_id(assigned.logical_buffer_id()),
164       result->set_size_mib(BytesToMiB(assigned.size()));
165   const LogicalBufferProto* proto =
166       id_to_logical_buffer.at(assigned.logical_buffer_id());
167   const std::string& instruction_name = proto->defined_at().instruction_name();
168   result->set_hlo_name(instruction_name);
169   result->mutable_shape_index()->CopyFrom(proto->defined_at().shape_index());
170   const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
171   const Shape* shape =
172       ResolveShapeIndex(&top_level_shape, proto->defined_at().shape_index());
173   result->set_shape(ShapeUtil::HumanStringWithLayout(*shape));
174 }
175 
IsReusable(const BufferAllocationProto & buffer_allocation)176 bool IsReusable(const BufferAllocationProto& buffer_allocation) {
177   return !buffer_allocation.is_thread_local() && !buffer_allocation.is_tuple();
178 }
179 
Convert(const BufferAllocationProto & proto,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,BufferAllocation * result)180 void Convert(const BufferAllocationProto& proto,
181              const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
182                  id_to_logical_buffer,
183              const absl::node_hash_map<std::string, const HloInstructionProto*>&
184                  name_to_hlo,
185              BufferAllocation* result) {
186   result->set_id(proto.index());
187   result->set_size_mib(BytesToMiB(proto.size()));
188   if (proto.is_entry_computation_parameter()) {
189     result->add_attributes("entry computation parameter");
190   }
191   if (proto.maybe_live_out()) {
192     result->add_attributes("may-be live out");
193   }
194   if (IsReusable(proto)) {
195     result->add_attributes("reusable");
196   }
197   for (const auto& assigned : proto.assigned()) {
198     Convert(assigned, id_to_logical_buffer, name_to_hlo,
199             result->add_logical_buffers());
200   }
201   // Check whether all logical buffers for this buffer allocation have a common
202   // shape.
203   if (!result->logical_buffers().empty()) {
204     std::string common_shape = result->logical_buffers(0).shape();
205     for (int64_t i = 1; i < result->logical_buffers_size(); ++i) {
206       if (result->logical_buffers(i).shape() != common_shape) {
207         common_shape = "";
208         break;
209       }
210     }
211     if (!common_shape.empty()) {
212       result->set_common_shape(common_shape);
213     }
214   }
215 }
216 
NoteSpecialAllocations(const absl::flat_hash_set<const BufferAllocationProto * > & all_buffer_allocations,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,int64_t small_buffer_size,PreprocessResult * result)217 void NoteSpecialAllocations(
218     const absl::flat_hash_set<const BufferAllocationProto*>&
219         all_buffer_allocations,
220     const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
221         id_to_logical_buffer,
222 
223     const absl::node_hash_map<std::string, const HloInstructionProto*>&
224         name_to_hlo,
225     int64_t small_buffer_size, PreprocessResult* result) {
226   int64_t entry_parameters_bytes = 0;
227   int64_t non_reusable_bytes = 0;
228   int64_t maybe_live_out_bytes = 0;
229   for (const BufferAllocationProto* buffer_allocation :
230        all_buffer_allocations) {
231     if (buffer_allocation->is_entry_computation_parameter()) {
232       entry_parameters_bytes += buffer_allocation->size();
233     }
234     if (!IsReusable(*buffer_allocation)) {
235       non_reusable_bytes += buffer_allocation->size();
236     }
237     if (buffer_allocation->maybe_live_out()) {
238       if (buffer_allocation->size() > small_buffer_size) {
239         VLOG(1) << "Maybe live out buffer allocation: "
240                 << buffer_allocation->size()
241                 << " bytes :: " << buffer_allocation->ShortDebugString();
242       }
243       maybe_live_out_bytes += buffer_allocation->size();
244     }
245     Convert(*buffer_allocation, id_to_logical_buffer, name_to_hlo,
246             result->add_indefinite_lifetimes());
247   }
248 
249   result->set_entry_computation_parameters_mib(
250       BytesToMiB(entry_parameters_bytes));
251   result->set_non_reusable_mib(BytesToMiB(non_reusable_bytes));
252   result->set_maybe_live_out_mib(BytesToMiB(maybe_live_out_bytes));
253 }
254 
255 }  // namespace
256 
ConvertHloProtoToPreprocessResult(const HloProto & hlo_proto,int64_t small_buffer_size,int64_t heap_simulator_trace_id,int64_t memory_color)257 absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult(
258     const HloProto& hlo_proto, int64_t small_buffer_size,
259     int64_t heap_simulator_trace_id, int64_t memory_color) {
260   // Construct a mapping from name to HLO proto.
261   absl::node_hash_map<std::string, const HloInstructionProto*> name_to_hlo;
262   for (const auto& computation : hlo_proto.hlo_module().computations()) {
263     for (const auto& instruction : computation.instructions()) {
264       name_to_hlo[instruction.name()] = &instruction;
265       VLOG(1) << "HLO: " << instruction.ShortDebugString();
266     }
267   }
268 
269   // Mapping from logical buffer ID to logical buffer, and set of all logical
270   // buffer protos.
271   absl::flat_hash_map<int64_t, const LogicalBufferProto*> id_to_logical_buffer;
272   absl::flat_hash_set<const LogicalBufferProto*> all_logical_buffers;
273   for (const auto& logical_buffer :
274        hlo_proto.buffer_assignment().logical_buffers()) {
275     VLOG(1) << "Logical buffer: " << logical_buffer.ShortDebugString();
276     id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
277     all_logical_buffers.insert(&logical_buffer);
278   }
279 
280   // Mapping from logocal buffer proto to the buffer allocation that it exists
281   // inside (there must be only one).
282   //
283   // Also a reverse mapping from buffer allocation proto to the set of logical
284   // buffer protos that exist inside of it.
285   absl::flat_hash_map<const LogicalBufferProto*, const BufferAllocationProto*>
286       logical_buffer_to_buffer_allocation;
287   absl::node_hash_map<const BufferAllocationProto*,
288                       absl::flat_hash_set<const LogicalBufferProto*>>
289       buffer_allocation_to_logical_buffers;
290   absl::flat_hash_set<const BufferAllocationProto*> all_buffer_allocations;
291   for (const BufferAllocationProto& buffer_allocation :
292        hlo_proto.buffer_assignment().buffer_allocations()) {
293     all_buffer_allocations.insert(&buffer_allocation);
294     for (const xla::BufferAllocationProto_Assigned& assigned :
295          buffer_allocation.assigned()) {
296       const LogicalBufferProto* logical_buffer =
297           id_to_logical_buffer.at(assigned.logical_buffer_id());
298       buffer_allocation_to_logical_buffers[&buffer_allocation].insert(
299           logical_buffer);
300       auto insert_result = logical_buffer_to_buffer_allocation.insert(
301           {logical_buffer, &buffer_allocation});
302       if (!insert_result.second) {
303         return absl::InvalidArgumentError(
304             "A logical buffer appears to be associated with multiple buffer "
305             "allocations.");
306       }
307     }
308   }
309 
310   std::vector<int64_t> logical_buffers;
311   std::vector<int64_t> peak_logical_buffers;
312 
313   int64_t heap_size_bytes = 0;
314   int64_t unpadded_heap_size_bytes = 0;
315 
316   int64_t peak_heap_size_bytes = 0;
317   int64_t unpadded_peak_heap_size_bytes = 0;  // Unpadded size at peak.
318   int64_t peak_heap_size_position = 0;
319   std::vector<double> heap_sizes;
320   std::vector<double> unpadded_heap_sizes;
321 
322   absl::node_hash_map<int64_t, std::pair<int64_t, absl::optional<int64_t>>>
323       logical_buffer_spans;
324   absl::flat_hash_set<const LogicalBufferProto*> seen;
325   absl::flat_hash_set<const BufferAllocationProto*> seen_buffer_allocations;
326 
327   // Run through all the simulator events in the given trace, and simulate the
328   // heap in order to find the point of peak memory usage and record its
329   // associated metadata.
330   if (heap_simulator_trace_id >= 0 &&
331       heap_simulator_trace_id <
332           hlo_proto.buffer_assignment().heap_simulator_traces_size()) {
333     const auto& simulator_events =
334         hlo_proto.buffer_assignment()
335             .heap_simulator_traces(heap_simulator_trace_id)
336             .events();
337     for (const auto& event : simulator_events) {
338       heap_sizes.push_back(BytesToMiB(heap_size_bytes));
339       unpadded_heap_sizes.push_back(BytesToMiB(unpadded_heap_size_bytes));
340       const auto* logical_buffer = id_to_logical_buffer.at(event.buffer_id());
341       seen.insert(logical_buffer);
342       seen_buffer_allocations.insert(
343           logical_buffer_to_buffer_allocation.at(logical_buffer));
344       const auto& instruction_name =
345           logical_buffer->defined_at().instruction_name();
346       const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
347       const Shape* shape = ResolveShapeIndex(
348           &top_level_shape, logical_buffer->defined_at().shape_index());
349       if (event.kind() == xla::HeapSimulatorTrace_Event::ALLOC ||
350           event.kind() == xla::HeapSimulatorTrace_Event::SHARE_WITH) {
351         logical_buffers.push_back(event.buffer_id());
352         heap_size_bytes += logical_buffer->size();
353         unpadded_heap_size_bytes += UnpaddedSize(*shape);
354         // Initialize the buffer span from the current event to the last event.
355         logical_buffer_spans[event.buffer_id()] = {heap_sizes.size() - 1,
356                                                    simulator_events.size() - 1};
357         int64_t prior_peak_heap_size_bytes = peak_heap_size_bytes;
358         peak_heap_size_bytes = std::max(peak_heap_size_bytes, heap_size_bytes);
359         if (prior_peak_heap_size_bytes != peak_heap_size_bytes) {
360           peak_heap_size_position = heap_sizes.size() - 1;
361           unpadded_peak_heap_size_bytes = unpadded_heap_size_bytes;
362           VLOG(1) << StrFormat("New peak heap size on %d: %s :: %d bytes",
363                                peak_heap_size_position, instruction_name,
364                                peak_heap_size_bytes);
365           peak_logical_buffers = logical_buffers;
366         }
367       } else if (event.kind() == xla::HeapSimulatorTrace_Event::FREE) {
368         logical_buffers.erase(
369             std::remove(logical_buffers.begin(), logical_buffers.end(),
370                         event.buffer_id()),
371             logical_buffers.end());
372         heap_size_bytes -= logical_buffer->size();
373         unpadded_heap_size_bytes -= UnpaddedSize(*shape);
374         logical_buffer_spans[event.buffer_id()].second = heap_sizes.size() - 1;
375         if (heap_size_bytes < 0) {
376           return absl::InvalidArgumentError(absl::StrCat(
377               "heap_size_bytes should be non-negative: ", heap_size_bytes));
378         }
379       } else {
380         return absl::InvalidArgumentError(
381             absl::StrCat("Unhandled event kind: ", event.kind()));
382       }
383     }
384 
385     // Add the final heap size after simulating the entire heap trace.
386     heap_sizes.push_back(BytesToMiB(heap_size_bytes));
387     unpadded_heap_sizes.push_back(BytesToMiB(unpadded_heap_size_bytes));
388 
389     if (seen_buffer_allocations.size() != 1) {
390       return absl::InvalidArgumentError(
391           absl::StrCat("All heap simulation should work out of a single buffer "
392                        "allocation, actual seen_buffer_allocations.size():",
393                        seen_buffer_allocations.size()));
394     }
395   }
396 
397   VLOG(1) << "Found " << peak_logical_buffers.size()
398           << " logical buffers alive at point of peak heap usage.";
399 
400   VLOG(1) << "Peak logical buffers: ["
401           << absl::StrJoin(peak_logical_buffers, ",") << "]";
402 
403   int64_t indefinite_memory_usage_bytes = 0;
404   std::vector<HeapObject> max_heap;
405   int colorno = 0;
406   int64_t rest = 0;
407 
408   // Helper lambda that adds the logical buffer as an element in the "max heap"
409   // view with constitutent logical buffers.
410   auto add_heap_object = [&](const LogicalBufferProto* logical_buffer,
411                              const BufferAllocationProto* buffer_allocation) {
412     if (logical_buffer->size() <= small_buffer_size) {
413       rest += logical_buffer->size();
414       return;
415     }
416     const std::string& instruction_name =
417         logical_buffer->defined_at().instruction_name();
418     const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
419     const Shape* shape = ResolveShapeIndex(
420         &top_level_shape, logical_buffer->defined_at().shape_index());
421     std::string shape_string = ShapeUtil::HumanStringWithLayout(*shape);
422     int64 unpadded_shape_bytes = UnpaddedSize(*shape);
423     const HloInstructionProto* hlo_instruction =
424         name_to_hlo.at(instruction_name);
425     std::string label = StrFormat("%s: %s # %s", instruction_name, shape_string,
426                                   hlo_instruction->metadata().op_name());
427     max_heap.push_back(MakeHeapObject(
428         hlo_instruction->metadata().op_name(), shape_string,
429         hlo_instruction->opcode(), GetInstructionName(logical_buffer),
430         GetAllocationGroupName(buffer_allocation), std::move(label), colorno++,
431         logical_buffer->id(), logical_buffer->size(), unpadded_shape_bytes));
432   };
433 
434   // Now look for all logical buffers which have not been seen, and assume they
435   // have indefinite lifetime if they are not in thread-local buffer
436   // allocations.
437   absl::flat_hash_set<const LogicalBufferProto*> unseen;
438   for (const LogicalBufferProto* logical_buffer : all_logical_buffers) {
439     if (!seen.contains(logical_buffer)) {
440       unseen.insert(logical_buffer);
441     }
442   }
443   for (const LogicalBufferProto* logical_buffer : unseen) {
444     const BufferAllocationProto* buffer_allocation =
445         logical_buffer_to_buffer_allocation.at(logical_buffer);
446     if (buffer_allocation->is_thread_local()) {
447       continue;
448     }
449     if (logical_buffer->color() != memory_color) {
450       continue;
451     }
452     // Clear out the assigned logical buffers when stringifying the buffer
453     // allocation, as it can be a long list.
454     auto to_string = [](const BufferAllocationProto* p) {
455       BufferAllocationProto copy = *p;
456       copy.mutable_assigned()->Clear();
457       return copy.ShortDebugString();
458     };
459     if (seen_buffer_allocations.insert(buffer_allocation).second) {
460       indefinite_memory_usage_bytes += buffer_allocation->size();
461       const auto& logical_buffers =
462           buffer_allocation_to_logical_buffers.at(buffer_allocation);
463       if (logical_buffers.size() == 1) {
464         add_heap_object(*logical_buffers.begin(), buffer_allocation);
465       } else {
466         VLOG(1) << "Indefinite lifetime, no heap object shown due to "
467                    "multiple logical buffers in buffer allocation: "
468                 << logical_buffer->ShortDebugString()
469                 << " :: " << to_string(buffer_allocation) << std::endl;
470       }
471       if (buffer_allocation->size() > small_buffer_size) {
472         VLOG(1) << "Indefinite memory usage now: "
473                 << indefinite_memory_usage_bytes << " bytes (+"
474                 << buffer_allocation->size() << " bytes)";
475       }
476     }
477   }
478 
479   // For the buffers that have indefinite lifetime (that is, lifetime not
480   // reflected by the heap simulation) add it to the peak values and the vectors
481   // of heap sizes.
482   peak_heap_size_bytes += indefinite_memory_usage_bytes;
483   unpadded_peak_heap_size_bytes += indefinite_memory_usage_bytes;
484   double addend = BytesToMiB(indefinite_memory_usage_bytes);
485   for (int i = 0; i < heap_sizes.size(); ++i) {
486     heap_sizes[i] += addend;
487     unpadded_heap_sizes[i] += addend;
488   }
489 
490   // Accumulate data for use in a stacked bar plot.
491   //
492   // We accumulate it in "program order" -- the order in which it was placed
493   // into the logical_buffers sequence above was program order, and we iterate
494   // that order to create data points.
495   for (int logical_buffer_id : peak_logical_buffers) {
496     const auto* logical_buffer = id_to_logical_buffer.at(logical_buffer_id);
497     const auto* buffer_allocation =
498         logical_buffer_to_buffer_allocation.at(logical_buffer);
499     add_heap_object(logical_buffer, buffer_allocation);
500   }
501   if (rest != 0) {
502     max_heap.push_back(MakeHeapObject(
503         "gray", StrFormat("small (<%d bytes)", small_buffer_size), -1, rest,
504         0));
505   }
506 
507   std::vector<const HeapObject*> max_heap_by_size;
508   max_heap_by_size.reserve(max_heap.size());
509   for (const auto& object : max_heap) {
510     max_heap_by_size.push_back(&object);
511   }
512   std::sort(max_heap_by_size.begin(), max_heap_by_size.end(),
513             [](const HeapObject* a, const HeapObject* b) {
514               return a->logical_buffer_size_mib() >
515                      b->logical_buffer_size_mib();
516             });
517 
518   std::vector<int> max_heap_to_by_size;
519   max_heap_to_by_size.reserve(max_heap.size());
520   for (const auto& object : max_heap) {
521     auto it =
522         std::find(max_heap_by_size.begin(), max_heap_by_size.end(), &object);
523     int index = std::distance(max_heap_by_size.begin(), it);
524     max_heap_to_by_size.push_back(index);
525   }
526 
527   std::vector<int> by_size_to_max_heap;
528   for (const auto* object : max_heap_by_size) {
529     int index = object - &max_heap[0];
530     by_size_to_max_heap.push_back(index);
531   }
532 
533   PreprocessResult result;
534   result.set_module_name(hlo_proto.hlo_module().name());
535   result.set_entry_computation_name(
536       hlo_proto.hlo_module().entry_computation_name());
537   *result.mutable_heap_sizes() = {heap_sizes.begin(), heap_sizes.end()};
538   *result.mutable_unpadded_heap_sizes() = {unpadded_heap_sizes.begin(),
539                                            unpadded_heap_sizes.end()};
540   *result.mutable_max_heap() = {max_heap.begin(), max_heap.end()};
541   for (const HeapObject* o : max_heap_by_size) {
542     *result.add_max_heap_by_size() = *o;
543   }
544   *result.mutable_max_heap_to_by_size() = {max_heap_to_by_size.begin(),
545                                            max_heap_to_by_size.end()};
546   *result.mutable_by_size_to_max_heap() = {by_size_to_max_heap.begin(),
547                                            by_size_to_max_heap.end()};
548   result.set_peak_heap_mib(BytesToMiB(peak_heap_size_bytes));
549   result.set_peak_unpadded_heap_mib(BytesToMiB(unpadded_peak_heap_size_bytes));
550   result.set_peak_heap_size_position(peak_heap_size_position);
551 
552   for (const auto& item : logical_buffer_spans) {
553     (*result.mutable_logical_buffer_spans())[item.first] =
554         MakeBufferSpan(item.second.first, item.second.second.value());
555   }
556 
557   NoteSpecialAllocations(all_buffer_allocations, id_to_logical_buffer,
558                          name_to_hlo, small_buffer_size, &result);
559   return result;
560 }
561 
562 // From a list of heap simulator traces, identify the one that has the largest
563 // number of memory events with color <memory_color>.
564 // If unable to find the heap simulator trace, return -1, and
565 // ConvertHloProtoToPreprocessResult will not consider heap_simulator_traces
566 // during preprocess.
GetHeapSimulatorTraceIdFromEvents(const HloProto & proto,int64_t memory_color)567 int64_t GetHeapSimulatorTraceIdFromEvents(const HloProto& proto,
568                                           int64_t memory_color) {
569   absl::flat_hash_map<int64_t, const xla::LogicalBufferProto*>
570       id_to_logical_buffer;
571   for (const auto& logical_buffer :
572        proto.buffer_assignment().logical_buffers()) {
573     id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
574   }
575   int64_t best_index = -1;
576   int64_t best_event_count = 0;
577   for (int64_t i = 0;
578        i < proto.buffer_assignment().heap_simulator_traces_size(); i++) {
579     const auto& heap_simulator_trace =
580         proto.buffer_assignment().heap_simulator_traces(i);
581     int64_t event_count = 0;
582     for (const auto& event : heap_simulator_trace.events()) {
583       const auto iter = id_to_logical_buffer.find(event.buffer_id());
584       if (iter == id_to_logical_buffer.end()) {
585         continue;
586       }
587       if (iter->second->color() == memory_color) {
588         event_count++;
589       }
590     }
591     if (event_count > best_event_count) {
592       best_index = i;
593       best_event_count = event_count;
594     }
595   }
596 
597   return best_index;
598 }
599 
600 // Tries to get the correct heap simulator trace based on
601 // buffer_allocation_index.
GetHeapSimulatorTraceIdFromBufferAllocationIndex(const HloProto & proto,int64_t memory_color)602 int64_t GetHeapSimulatorTraceIdFromBufferAllocationIndex(const HloProto& proto,
603                                                          int64_t memory_color) {
604   absl::flat_hash_map<int64_t, const xla::BufferAllocationProto*>
605       id_to_buffer_allocation;
606   for (const auto& buffer_allocation :
607        proto.buffer_assignment().buffer_allocations()) {
608     id_to_buffer_allocation[buffer_allocation.index()] = &buffer_allocation;
609   }
610   for (int64_t i = 0;
611        i < proto.buffer_assignment().heap_simulator_traces_size(); ++i) {
612     int64_t buffer_allocation_index = proto.buffer_assignment()
613                                           .heap_simulator_traces(i)
614                                           .buffer_allocation_index();
615     const auto iter = id_to_buffer_allocation.find(buffer_allocation_index);
616     if (buffer_allocation_index && iter != id_to_buffer_allocation.end()) {
617       // Find the heap simulator trace that corresponds to the HLO temporaries
618       // buffer allocation, where is_thread_local,
619       // is_entry_computation_parameter, is_constant, and maybe_live_out will
620       // all be false.
621       const auto* buffer_allocation = iter->second;
622       if (buffer_allocation->color() == memory_color &&
623           !buffer_allocation->is_thread_local() &&
624           !buffer_allocation->is_entry_computation_parameter() &&
625           !buffer_allocation->is_constant() &&
626           !buffer_allocation->maybe_live_out()) {
627         return i;
628       }
629     }
630   }
631   return -1;
632 }
633 
GetHeapSimulatorTraceId(const HloProto & proto,int64_t memory_color)634 int64_t GetHeapSimulatorTraceId(const HloProto& proto, int64_t memory_color) {
635   int64_t id =
636       GetHeapSimulatorTraceIdFromBufferAllocationIndex(proto, memory_color);
637   if (id != -1) {
638     return id;
639   }
640   return GetHeapSimulatorTraceIdFromEvents(proto, memory_color);
641 }
642 
643 }  // namespace profiler
644 }  // namespace tensorflow
645