1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/container/node_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h"
40
41 namespace tensorflow {
42 namespace profiler {
43 namespace {
44
45 using absl::StrFormat;
46 using ::xla::BufferAllocationProto;
47 using ::xla::HloInstructionProto;
48 using ::xla::HloProto;
49 using ::xla::LayoutUtil;
50 using ::xla::LogicalBufferProto;
51 using ::xla::Shape;
52 using ::xla::ShapeUtil;
53
// Converts a byte count into mebibytes (2^20 bytes) as a double.
double BytesToMiB(int64_t bytes) {
  constexpr double kBytesPerMiB = 1 << 20;  // 1048576
  return static_cast<double>(bytes) / kBytesPerMiB;
}
57
58 // Get buffer allocation property.
GetAllocationGroupName(const BufferAllocationProto * buffer_allocation)59 std::string GetAllocationGroupName(
60 const BufferAllocationProto* buffer_allocation) {
61 if (buffer_allocation == nullptr) {
62 return "";
63 }
64 if (buffer_allocation->is_entry_computation_parameter()) {
65 return "Parameter";
66 } else if (buffer_allocation->maybe_live_out()) {
67 return "Output";
68 } else if (buffer_allocation->is_thread_local()) {
69 return "Thread-local";
70 } else {
71 return "Temporary";
72 }
73 }
74
75 // Get the instruction associated with the logical buffer.
GetInstructionName(const LogicalBufferProto * logical_buffer)76 std::string GetInstructionName(const LogicalBufferProto* logical_buffer) {
77 if (logical_buffer == nullptr) {
78 return "";
79 }
80 if (logical_buffer->defined_at().shape_index().empty()) {
81 return logical_buffer->defined_at().instruction_name();
82 } else {
83 return absl::StrCat(
84 logical_buffer->defined_at().instruction_name(), "{",
85 absl::StrJoin(logical_buffer->defined_at().shape_index(), ""), "}");
86 }
87 }
88
MakeHeapObjectCommon(std::string label,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)89 HeapObject MakeHeapObjectCommon(std::string label, int logical_buffer_id,
90 int64_t logical_buffer_size_bytes,
91 int64_t unpadded_shape_bytes) {
92 HeapObject result;
93 result.set_label(std::move(label));
94 result.set_logical_buffer_id(logical_buffer_id);
95 result.set_logical_buffer_size_mib(BytesToMiB(logical_buffer_size_bytes));
96 result.set_unpadded_shape_mib(BytesToMiB(unpadded_shape_bytes));
97 return result;
98 }
99
MakeHeapObject(const std::string & tf_op_name,const std::string & shape_string,const std::string & op_code,std::string instruction_name,std::string group_name,std::string label,int color,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)100 HeapObject MakeHeapObject(const std::string& tf_op_name,
101 const std::string& shape_string,
102 const std::string& op_code,
103 std::string instruction_name, std::string group_name,
104 std::string label, int color, int logical_buffer_id,
105 int64_t logical_buffer_size_bytes,
106 int64_t unpadded_shape_bytes) {
107 HeapObject result =
108 MakeHeapObjectCommon(std::move(label), logical_buffer_id,
109 logical_buffer_size_bytes, unpadded_shape_bytes);
110 result.set_numbered(color);
111 result.set_instruction_name(std::move(instruction_name));
112 result.set_group_name(std::move(group_name));
113 result.set_tf_op_name(tf_op_name);
114 result.set_shape_string(shape_string);
115 result.set_op_code(op_code);
116 return result;
117 }
118
MakeHeapObject(std::string color,std::string label,int logical_buffer_id,int64_t logical_buffer_size_bytes,int64_t unpadded_shape_bytes)119 HeapObject MakeHeapObject(std::string color, std::string label,
120 int logical_buffer_id,
121 int64_t logical_buffer_size_bytes,
122 int64_t unpadded_shape_bytes) {
123 HeapObject result =
124 MakeHeapObjectCommon(std::move(label), logical_buffer_id,
125 logical_buffer_size_bytes, unpadded_shape_bytes);
126 result.set_named(std::move(color));
127 return result;
128 }
129
MakeBufferSpan(int32 start,int32 limit)130 BufferSpan MakeBufferSpan(int32 start, int32 limit) {
131 BufferSpan result;
132 result.set_start(start);
133 result.set_limit(limit);
134 return result;
135 }
136
// Walks `shape_index` down through nested tuple shapes starting at `shape`
// and returns the resulting subshape. An empty index returns `shape` itself.
// The returned pointer aliases into `shape`'s subtree.
const Shape* ResolveShapeIndex(const Shape* shape,
                               absl::Span<const int64_t> shape_index) {
  const Shape* current = shape;
  for (int64_t tuple_index : shape_index) {
    current = &current->tuple_shapes(tuple_index);
  }
  return current;
}
144
145 // A wrapper around ShapeUtil::ByteSizeOf that clears out the layout/padding,
146 // since that is considered in the ByteSizeOf calculation.
UnpaddedSize(Shape shape)147 int64_t UnpaddedSize(Shape shape) {
148 // Ensure the layout has no padding by making it the default layout.
149 LayoutUtil::SetToDefaultLayout(&shape);
150 // Note: we make a simplifying assumption here that a "minimal" size for a
151 // tuple member would be the size of a `void*` -- there may be even fancier
152 // ways of doing things, but this should give a good enough approximation of
153 // what a minimal tuple size is.
154 return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
155 }
156
Convert(const xla::BufferAllocationProto_Assigned & assigned,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,LogicalBuffer * result)157 void Convert(const xla::BufferAllocationProto_Assigned& assigned,
158 const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
159 id_to_logical_buffer,
160 const absl::node_hash_map<std::string, const HloInstructionProto*>&
161 name_to_hlo,
162 LogicalBuffer* result) {
163 result->set_id(assigned.logical_buffer_id()),
164 result->set_size_mib(BytesToMiB(assigned.size()));
165 const LogicalBufferProto* proto =
166 id_to_logical_buffer.at(assigned.logical_buffer_id());
167 const std::string& instruction_name = proto->defined_at().instruction_name();
168 result->set_hlo_name(instruction_name);
169 result->mutable_shape_index()->CopyFrom(proto->defined_at().shape_index());
170 const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
171 const Shape* shape =
172 ResolveShapeIndex(&top_level_shape, proto->defined_at().shape_index());
173 result->set_shape(ShapeUtil::HumanStringWithLayout(*shape));
174 }
175
IsReusable(const BufferAllocationProto & buffer_allocation)176 bool IsReusable(const BufferAllocationProto& buffer_allocation) {
177 return !buffer_allocation.is_thread_local() && !buffer_allocation.is_tuple();
178 }
179
Convert(const BufferAllocationProto & proto,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,BufferAllocation * result)180 void Convert(const BufferAllocationProto& proto,
181 const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
182 id_to_logical_buffer,
183 const absl::node_hash_map<std::string, const HloInstructionProto*>&
184 name_to_hlo,
185 BufferAllocation* result) {
186 result->set_id(proto.index());
187 result->set_size_mib(BytesToMiB(proto.size()));
188 if (proto.is_entry_computation_parameter()) {
189 result->add_attributes("entry computation parameter");
190 }
191 if (proto.maybe_live_out()) {
192 result->add_attributes("may-be live out");
193 }
194 if (IsReusable(proto)) {
195 result->add_attributes("reusable");
196 }
197 for (const auto& assigned : proto.assigned()) {
198 Convert(assigned, id_to_logical_buffer, name_to_hlo,
199 result->add_logical_buffers());
200 }
201 // Check whether all logical buffers for this buffer allocation have a common
202 // shape.
203 if (!result->logical_buffers().empty()) {
204 std::string common_shape = result->logical_buffers(0).shape();
205 for (int64_t i = 1; i < result->logical_buffers_size(); ++i) {
206 if (result->logical_buffers(i).shape() != common_shape) {
207 common_shape = "";
208 break;
209 }
210 }
211 if (!common_shape.empty()) {
212 result->set_common_shape(common_shape);
213 }
214 }
215 }
216
NoteSpecialAllocations(const absl::flat_hash_set<const BufferAllocationProto * > & all_buffer_allocations,const absl::flat_hash_map<int64_t,const LogicalBufferProto * > & id_to_logical_buffer,const absl::node_hash_map<std::string,const HloInstructionProto * > & name_to_hlo,int64_t small_buffer_size,PreprocessResult * result)217 void NoteSpecialAllocations(
218 const absl::flat_hash_set<const BufferAllocationProto*>&
219 all_buffer_allocations,
220 const absl::flat_hash_map<int64_t, const LogicalBufferProto*>&
221 id_to_logical_buffer,
222
223 const absl::node_hash_map<std::string, const HloInstructionProto*>&
224 name_to_hlo,
225 int64_t small_buffer_size, PreprocessResult* result) {
226 int64_t entry_parameters_bytes = 0;
227 int64_t non_reusable_bytes = 0;
228 int64_t maybe_live_out_bytes = 0;
229 for (const BufferAllocationProto* buffer_allocation :
230 all_buffer_allocations) {
231 if (buffer_allocation->is_entry_computation_parameter()) {
232 entry_parameters_bytes += buffer_allocation->size();
233 }
234 if (!IsReusable(*buffer_allocation)) {
235 non_reusable_bytes += buffer_allocation->size();
236 }
237 if (buffer_allocation->maybe_live_out()) {
238 if (buffer_allocation->size() > small_buffer_size) {
239 VLOG(1) << "Maybe live out buffer allocation: "
240 << buffer_allocation->size()
241 << " bytes :: " << buffer_allocation->ShortDebugString();
242 }
243 maybe_live_out_bytes += buffer_allocation->size();
244 }
245 Convert(*buffer_allocation, id_to_logical_buffer, name_to_hlo,
246 result->add_indefinite_lifetimes());
247 }
248
249 result->set_entry_computation_parameters_mib(
250 BytesToMiB(entry_parameters_bytes));
251 result->set_non_reusable_mib(BytesToMiB(non_reusable_bytes));
252 result->set_maybe_live_out_mib(BytesToMiB(maybe_live_out_bytes));
253 }
254
255 } // namespace
256
ConvertHloProtoToPreprocessResult(const HloProto & hlo_proto,int64_t small_buffer_size,int64_t heap_simulator_trace_id,int64_t memory_color)257 absl::StatusOr<PreprocessResult> ConvertHloProtoToPreprocessResult(
258 const HloProto& hlo_proto, int64_t small_buffer_size,
259 int64_t heap_simulator_trace_id, int64_t memory_color) {
260 // Construct a mapping from name to HLO proto.
261 absl::node_hash_map<std::string, const HloInstructionProto*> name_to_hlo;
262 for (const auto& computation : hlo_proto.hlo_module().computations()) {
263 for (const auto& instruction : computation.instructions()) {
264 name_to_hlo[instruction.name()] = &instruction;
265 VLOG(1) << "HLO: " << instruction.ShortDebugString();
266 }
267 }
268
269 // Mapping from logical buffer ID to logical buffer, and set of all logical
270 // buffer protos.
271 absl::flat_hash_map<int64_t, const LogicalBufferProto*> id_to_logical_buffer;
272 absl::flat_hash_set<const LogicalBufferProto*> all_logical_buffers;
273 for (const auto& logical_buffer :
274 hlo_proto.buffer_assignment().logical_buffers()) {
275 VLOG(1) << "Logical buffer: " << logical_buffer.ShortDebugString();
276 id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
277 all_logical_buffers.insert(&logical_buffer);
278 }
279
280 // Mapping from logocal buffer proto to the buffer allocation that it exists
281 // inside (there must be only one).
282 //
283 // Also a reverse mapping from buffer allocation proto to the set of logical
284 // buffer protos that exist inside of it.
285 absl::flat_hash_map<const LogicalBufferProto*, const BufferAllocationProto*>
286 logical_buffer_to_buffer_allocation;
287 absl::node_hash_map<const BufferAllocationProto*,
288 absl::flat_hash_set<const LogicalBufferProto*>>
289 buffer_allocation_to_logical_buffers;
290 absl::flat_hash_set<const BufferAllocationProto*> all_buffer_allocations;
291 for (const BufferAllocationProto& buffer_allocation :
292 hlo_proto.buffer_assignment().buffer_allocations()) {
293 all_buffer_allocations.insert(&buffer_allocation);
294 for (const xla::BufferAllocationProto_Assigned& assigned :
295 buffer_allocation.assigned()) {
296 const LogicalBufferProto* logical_buffer =
297 id_to_logical_buffer.at(assigned.logical_buffer_id());
298 buffer_allocation_to_logical_buffers[&buffer_allocation].insert(
299 logical_buffer);
300 auto insert_result = logical_buffer_to_buffer_allocation.insert(
301 {logical_buffer, &buffer_allocation});
302 if (!insert_result.second) {
303 return absl::InvalidArgumentError(
304 "A logical buffer appears to be associated with multiple buffer "
305 "allocations.");
306 }
307 }
308 }
309
310 std::vector<int64_t> logical_buffers;
311 std::vector<int64_t> peak_logical_buffers;
312
313 int64_t heap_size_bytes = 0;
314 int64_t unpadded_heap_size_bytes = 0;
315
316 int64_t peak_heap_size_bytes = 0;
317 int64_t unpadded_peak_heap_size_bytes = 0; // Unpadded size at peak.
318 int64_t peak_heap_size_position = 0;
319 std::vector<double> heap_sizes;
320 std::vector<double> unpadded_heap_sizes;
321
322 absl::node_hash_map<int64_t, std::pair<int64_t, absl::optional<int64_t>>>
323 logical_buffer_spans;
324 absl::flat_hash_set<const LogicalBufferProto*> seen;
325 absl::flat_hash_set<const BufferAllocationProto*> seen_buffer_allocations;
326
327 // Run through all the simulator events in the given trace, and simulate the
328 // heap in order to find the point of peak memory usage and record its
329 // associated metadata.
330 if (heap_simulator_trace_id >= 0 &&
331 heap_simulator_trace_id <
332 hlo_proto.buffer_assignment().heap_simulator_traces_size()) {
333 const auto& simulator_events =
334 hlo_proto.buffer_assignment()
335 .heap_simulator_traces(heap_simulator_trace_id)
336 .events();
337 for (const auto& event : simulator_events) {
338 heap_sizes.push_back(BytesToMiB(heap_size_bytes));
339 unpadded_heap_sizes.push_back(BytesToMiB(unpadded_heap_size_bytes));
340 const auto* logical_buffer = id_to_logical_buffer.at(event.buffer_id());
341 seen.insert(logical_buffer);
342 seen_buffer_allocations.insert(
343 logical_buffer_to_buffer_allocation.at(logical_buffer));
344 const auto& instruction_name =
345 logical_buffer->defined_at().instruction_name();
346 const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
347 const Shape* shape = ResolveShapeIndex(
348 &top_level_shape, logical_buffer->defined_at().shape_index());
349 if (event.kind() == xla::HeapSimulatorTrace_Event::ALLOC ||
350 event.kind() == xla::HeapSimulatorTrace_Event::SHARE_WITH) {
351 logical_buffers.push_back(event.buffer_id());
352 heap_size_bytes += logical_buffer->size();
353 unpadded_heap_size_bytes += UnpaddedSize(*shape);
354 // Initialize the buffer span from the current event to the last event.
355 logical_buffer_spans[event.buffer_id()] = {heap_sizes.size() - 1,
356 simulator_events.size() - 1};
357 int64_t prior_peak_heap_size_bytes = peak_heap_size_bytes;
358 peak_heap_size_bytes = std::max(peak_heap_size_bytes, heap_size_bytes);
359 if (prior_peak_heap_size_bytes != peak_heap_size_bytes) {
360 peak_heap_size_position = heap_sizes.size() - 1;
361 unpadded_peak_heap_size_bytes = unpadded_heap_size_bytes;
362 VLOG(1) << StrFormat("New peak heap size on %d: %s :: %d bytes",
363 peak_heap_size_position, instruction_name,
364 peak_heap_size_bytes);
365 peak_logical_buffers = logical_buffers;
366 }
367 } else if (event.kind() == xla::HeapSimulatorTrace_Event::FREE) {
368 logical_buffers.erase(
369 std::remove(logical_buffers.begin(), logical_buffers.end(),
370 event.buffer_id()),
371 logical_buffers.end());
372 heap_size_bytes -= logical_buffer->size();
373 unpadded_heap_size_bytes -= UnpaddedSize(*shape);
374 logical_buffer_spans[event.buffer_id()].second = heap_sizes.size() - 1;
375 if (heap_size_bytes < 0) {
376 return absl::InvalidArgumentError(absl::StrCat(
377 "heap_size_bytes should be non-negative: ", heap_size_bytes));
378 }
379 } else {
380 return absl::InvalidArgumentError(
381 absl::StrCat("Unhandled event kind: ", event.kind()));
382 }
383 }
384
385 // Add the final heap size after simulating the entire heap trace.
386 heap_sizes.push_back(BytesToMiB(heap_size_bytes));
387 unpadded_heap_sizes.push_back(BytesToMiB(unpadded_heap_size_bytes));
388
389 if (seen_buffer_allocations.size() != 1) {
390 return absl::InvalidArgumentError(
391 absl::StrCat("All heap simulation should work out of a single buffer "
392 "allocation, actual seen_buffer_allocations.size():",
393 seen_buffer_allocations.size()));
394 }
395 }
396
397 VLOG(1) << "Found " << peak_logical_buffers.size()
398 << " logical buffers alive at point of peak heap usage.";
399
400 VLOG(1) << "Peak logical buffers: ["
401 << absl::StrJoin(peak_logical_buffers, ",") << "]";
402
403 int64_t indefinite_memory_usage_bytes = 0;
404 std::vector<HeapObject> max_heap;
405 int colorno = 0;
406 int64_t rest = 0;
407
408 // Helper lambda that adds the logical buffer as an element in the "max heap"
409 // view with constitutent logical buffers.
410 auto add_heap_object = [&](const LogicalBufferProto* logical_buffer,
411 const BufferAllocationProto* buffer_allocation) {
412 if (logical_buffer->size() <= small_buffer_size) {
413 rest += logical_buffer->size();
414 return;
415 }
416 const std::string& instruction_name =
417 logical_buffer->defined_at().instruction_name();
418 const Shape top_level_shape(name_to_hlo.at(instruction_name)->shape());
419 const Shape* shape = ResolveShapeIndex(
420 &top_level_shape, logical_buffer->defined_at().shape_index());
421 std::string shape_string = ShapeUtil::HumanStringWithLayout(*shape);
422 int64 unpadded_shape_bytes = UnpaddedSize(*shape);
423 const HloInstructionProto* hlo_instruction =
424 name_to_hlo.at(instruction_name);
425 std::string label = StrFormat("%s: %s # %s", instruction_name, shape_string,
426 hlo_instruction->metadata().op_name());
427 max_heap.push_back(MakeHeapObject(
428 hlo_instruction->metadata().op_name(), shape_string,
429 hlo_instruction->opcode(), GetInstructionName(logical_buffer),
430 GetAllocationGroupName(buffer_allocation), std::move(label), colorno++,
431 logical_buffer->id(), logical_buffer->size(), unpadded_shape_bytes));
432 };
433
434 // Now look for all logical buffers which have not been seen, and assume they
435 // have indefinite lifetime if they are not in thread-local buffer
436 // allocations.
437 absl::flat_hash_set<const LogicalBufferProto*> unseen;
438 for (const LogicalBufferProto* logical_buffer : all_logical_buffers) {
439 if (!seen.contains(logical_buffer)) {
440 unseen.insert(logical_buffer);
441 }
442 }
443 for (const LogicalBufferProto* logical_buffer : unseen) {
444 const BufferAllocationProto* buffer_allocation =
445 logical_buffer_to_buffer_allocation.at(logical_buffer);
446 if (buffer_allocation->is_thread_local()) {
447 continue;
448 }
449 if (logical_buffer->color() != memory_color) {
450 continue;
451 }
452 // Clear out the assigned logical buffers when stringifying the buffer
453 // allocation, as it can be a long list.
454 auto to_string = [](const BufferAllocationProto* p) {
455 BufferAllocationProto copy = *p;
456 copy.mutable_assigned()->Clear();
457 return copy.ShortDebugString();
458 };
459 if (seen_buffer_allocations.insert(buffer_allocation).second) {
460 indefinite_memory_usage_bytes += buffer_allocation->size();
461 const auto& logical_buffers =
462 buffer_allocation_to_logical_buffers.at(buffer_allocation);
463 if (logical_buffers.size() == 1) {
464 add_heap_object(*logical_buffers.begin(), buffer_allocation);
465 } else {
466 VLOG(1) << "Indefinite lifetime, no heap object shown due to "
467 "multiple logical buffers in buffer allocation: "
468 << logical_buffer->ShortDebugString()
469 << " :: " << to_string(buffer_allocation) << std::endl;
470 }
471 if (buffer_allocation->size() > small_buffer_size) {
472 VLOG(1) << "Indefinite memory usage now: "
473 << indefinite_memory_usage_bytes << " bytes (+"
474 << buffer_allocation->size() << " bytes)";
475 }
476 }
477 }
478
479 // For the buffers that have indefinite lifetime (that is, lifetime not
480 // reflected by the heap simulation) add it to the peak values and the vectors
481 // of heap sizes.
482 peak_heap_size_bytes += indefinite_memory_usage_bytes;
483 unpadded_peak_heap_size_bytes += indefinite_memory_usage_bytes;
484 double addend = BytesToMiB(indefinite_memory_usage_bytes);
485 for (int i = 0; i < heap_sizes.size(); ++i) {
486 heap_sizes[i] += addend;
487 unpadded_heap_sizes[i] += addend;
488 }
489
490 // Accumulate data for use in a stacked bar plot.
491 //
492 // We accumulate it in "program order" -- the order in which it was placed
493 // into the logical_buffers sequence above was program order, and we iterate
494 // that order to create data points.
495 for (int logical_buffer_id : peak_logical_buffers) {
496 const auto* logical_buffer = id_to_logical_buffer.at(logical_buffer_id);
497 const auto* buffer_allocation =
498 logical_buffer_to_buffer_allocation.at(logical_buffer);
499 add_heap_object(logical_buffer, buffer_allocation);
500 }
501 if (rest != 0) {
502 max_heap.push_back(MakeHeapObject(
503 "gray", StrFormat("small (<%d bytes)", small_buffer_size), -1, rest,
504 0));
505 }
506
507 std::vector<const HeapObject*> max_heap_by_size;
508 max_heap_by_size.reserve(max_heap.size());
509 for (const auto& object : max_heap) {
510 max_heap_by_size.push_back(&object);
511 }
512 std::sort(max_heap_by_size.begin(), max_heap_by_size.end(),
513 [](const HeapObject* a, const HeapObject* b) {
514 return a->logical_buffer_size_mib() >
515 b->logical_buffer_size_mib();
516 });
517
518 std::vector<int> max_heap_to_by_size;
519 max_heap_to_by_size.reserve(max_heap.size());
520 for (const auto& object : max_heap) {
521 auto it =
522 std::find(max_heap_by_size.begin(), max_heap_by_size.end(), &object);
523 int index = std::distance(max_heap_by_size.begin(), it);
524 max_heap_to_by_size.push_back(index);
525 }
526
527 std::vector<int> by_size_to_max_heap;
528 for (const auto* object : max_heap_by_size) {
529 int index = object - &max_heap[0];
530 by_size_to_max_heap.push_back(index);
531 }
532
533 PreprocessResult result;
534 result.set_module_name(hlo_proto.hlo_module().name());
535 result.set_entry_computation_name(
536 hlo_proto.hlo_module().entry_computation_name());
537 *result.mutable_heap_sizes() = {heap_sizes.begin(), heap_sizes.end()};
538 *result.mutable_unpadded_heap_sizes() = {unpadded_heap_sizes.begin(),
539 unpadded_heap_sizes.end()};
540 *result.mutable_max_heap() = {max_heap.begin(), max_heap.end()};
541 for (const HeapObject* o : max_heap_by_size) {
542 *result.add_max_heap_by_size() = *o;
543 }
544 *result.mutable_max_heap_to_by_size() = {max_heap_to_by_size.begin(),
545 max_heap_to_by_size.end()};
546 *result.mutable_by_size_to_max_heap() = {by_size_to_max_heap.begin(),
547 by_size_to_max_heap.end()};
548 result.set_peak_heap_mib(BytesToMiB(peak_heap_size_bytes));
549 result.set_peak_unpadded_heap_mib(BytesToMiB(unpadded_peak_heap_size_bytes));
550 result.set_peak_heap_size_position(peak_heap_size_position);
551
552 for (const auto& item : logical_buffer_spans) {
553 (*result.mutable_logical_buffer_spans())[item.first] =
554 MakeBufferSpan(item.second.first, item.second.second.value());
555 }
556
557 NoteSpecialAllocations(all_buffer_allocations, id_to_logical_buffer,
558 name_to_hlo, small_buffer_size, &result);
559 return result;
560 }
561
562 // From a list of heap simulator traces, identify the one that has the largest
563 // number of memory events with color <memory_color>.
564 // If unable to find the heap simulator trace, return -1, and
565 // ConvertHloProtoToPreprocessResult will not consider heap_simulator_traces
566 // during preprocess.
GetHeapSimulatorTraceIdFromEvents(const HloProto & proto,int64_t memory_color)567 int64_t GetHeapSimulatorTraceIdFromEvents(const HloProto& proto,
568 int64_t memory_color) {
569 absl::flat_hash_map<int64_t, const xla::LogicalBufferProto*>
570 id_to_logical_buffer;
571 for (const auto& logical_buffer :
572 proto.buffer_assignment().logical_buffers()) {
573 id_to_logical_buffer[logical_buffer.id()] = &logical_buffer;
574 }
575 int64_t best_index = -1;
576 int64_t best_event_count = 0;
577 for (int64_t i = 0;
578 i < proto.buffer_assignment().heap_simulator_traces_size(); i++) {
579 const auto& heap_simulator_trace =
580 proto.buffer_assignment().heap_simulator_traces(i);
581 int64_t event_count = 0;
582 for (const auto& event : heap_simulator_trace.events()) {
583 const auto iter = id_to_logical_buffer.find(event.buffer_id());
584 if (iter == id_to_logical_buffer.end()) {
585 continue;
586 }
587 if (iter->second->color() == memory_color) {
588 event_count++;
589 }
590 }
591 if (event_count > best_event_count) {
592 best_index = i;
593 best_event_count = event_count;
594 }
595 }
596
597 return best_index;
598 }
599
600 // Tries to get the correct heap simulator trace based on
601 // buffer_allocation_index.
GetHeapSimulatorTraceIdFromBufferAllocationIndex(const HloProto & proto,int64_t memory_color)602 int64_t GetHeapSimulatorTraceIdFromBufferAllocationIndex(const HloProto& proto,
603 int64_t memory_color) {
604 absl::flat_hash_map<int64_t, const xla::BufferAllocationProto*>
605 id_to_buffer_allocation;
606 for (const auto& buffer_allocation :
607 proto.buffer_assignment().buffer_allocations()) {
608 id_to_buffer_allocation[buffer_allocation.index()] = &buffer_allocation;
609 }
610 for (int64_t i = 0;
611 i < proto.buffer_assignment().heap_simulator_traces_size(); ++i) {
612 int64_t buffer_allocation_index = proto.buffer_assignment()
613 .heap_simulator_traces(i)
614 .buffer_allocation_index();
615 const auto iter = id_to_buffer_allocation.find(buffer_allocation_index);
616 if (buffer_allocation_index && iter != id_to_buffer_allocation.end()) {
617 // Find the heap simulator trace that corresponds to the HLO temporaries
618 // buffer allocation, where is_thread_local,
619 // is_entry_computation_parameter, is_constant, and maybe_live_out will
620 // all be false.
621 const auto* buffer_allocation = iter->second;
622 if (buffer_allocation->color() == memory_color &&
623 !buffer_allocation->is_thread_local() &&
624 !buffer_allocation->is_entry_computation_parameter() &&
625 !buffer_allocation->is_constant() &&
626 !buffer_allocation->maybe_live_out()) {
627 return i;
628 }
629 }
630 }
631 return -1;
632 }
633
GetHeapSimulatorTraceId(const HloProto & proto,int64_t memory_color)634 int64_t GetHeapSimulatorTraceId(const HloProto& proto, int64_t memory_color) {
635 int64_t id =
636 GetHeapSimulatorTraceIdFromBufferAllocationIndex(proto, memory_color);
637 if (id != -1) {
638 return id;
639 }
640 return GetHeapSimulatorTraceIdFromEvents(proto, memory_color);
641 }
642
643 } // namespace profiler
644 } // namespace tensorflow
645