xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/xplane_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/utils/xplane_utils.h"
16 
17 #include <algorithm>
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/core/platform/fingerprint.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/profiler/lib/context_types.h"
31 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
32 #include "tensorflow/core/profiler/utils/math_utils.h"
33 #include "tensorflow/core/profiler/utils/timespan.h"
34 #include "tensorflow/core/profiler/utils/xplane_builder.h"
35 #include "tensorflow/core/profiler/utils/xplane_schema.h"
36 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
37 #include "tensorflow/core/util/stats_calculator.h"
38 
39 namespace tensorflow {
40 namespace profiler {
41 namespace {
42 
43 // Returns the index of the first element in array for which pred is true.
44 // Returns -1 if no such element is found.
45 template <typename T, typename Pred>
Find(const protobuf::RepeatedPtrField<T> & array,const Pred & pred)46 int Find(const protobuf::RepeatedPtrField<T>& array, const Pred& pred) {
47   for (int i = 0; i < array.size(); ++i) {
48     if (pred(&array.Get(i))) return i;
49   }
50   return -1;
51 }
52 
53 // Returns the indices of all elements in array for which pred is true.
54 template <typename T, typename Pred>
FindAll(const protobuf::RepeatedPtrField<T> & array,const Pred & pred)55 std::vector<int> FindAll(const protobuf::RepeatedPtrField<T>& array,
56                          const Pred& pred) {
57   std::vector<int> indices;
58   for (int i = 0; i < array.size(); ++i) {
59     if (pred(&array.Get(i))) indices.push_back(i);
60   }
61   return indices;
62 }
63 
64 template <typename T>
RemoveAt(protobuf::RepeatedPtrField<T> * array,const std::vector<int> & indices)65 void RemoveAt(protobuf::RepeatedPtrField<T>* array,
66               const std::vector<int>& indices) {
67   if (indices.empty()) return;
68   if (array->size() == indices.size()) {
69     // Assumes that 'indices' consists of [0 ... N-1].
70     array->Clear();
71     return;
72   }
73   auto remove_iter = indices.begin();
74   int i = *(remove_iter++);
75   for (int j = i + 1; j < array->size(); ++j) {
76     if (remove_iter != indices.end() && *remove_iter == j) {
77       ++remove_iter;
78     } else {
79       array->SwapElements(j, i++);
80     }
81   }
82   array->DeleteSubrange(i, array->size() - i);
83 }
84 
85 // Removes the given element from array.
86 template <typename T>
Remove(protobuf::RepeatedPtrField<T> * array,const T * elem)87 void Remove(protobuf::RepeatedPtrField<T>* array, const T* elem) {
88   int i = Find(*array, [elem](const T* e) { return elem == e; });
89   RemoveAt(array, {i});
90 }
91 
92 template <typename T, typename Pred>
RemoveIf(protobuf::RepeatedPtrField<T> * array,Pred && pred)93 void RemoveIf(protobuf::RepeatedPtrField<T>* array, Pred&& pred) {
94   std::vector<int> indices = FindAll(*array, pred);
95   RemoveAt(array, indices);
96 }
97 
98 // Copy XEventMetadata from source to destination. Also copies the associated
99 // XStats.
CopyEventMetadata(const XEventMetadata & src_event_metadata,const XPlaneVisitor & src_plane,XEventMetadata & dst_event_metadata,XPlaneBuilder & dst_plane)100 void CopyEventMetadata(const XEventMetadata& src_event_metadata,
101                        const XPlaneVisitor& src_plane,
102                        XEventMetadata& dst_event_metadata,
103                        XPlaneBuilder& dst_plane) {
104   if (dst_event_metadata.display_name().empty() &&
105       !src_event_metadata.display_name().empty()) {
106     dst_event_metadata.set_display_name(src_event_metadata.display_name());
107   }
108   if (dst_event_metadata.name().empty() && !src_event_metadata.name().empty()) {
109     dst_event_metadata.set_name(src_event_metadata.name());
110   }
111   if (dst_event_metadata.metadata().empty() &&
112       !src_event_metadata.metadata().empty()) {
113     dst_event_metadata.set_metadata(src_event_metadata.metadata());
114   }
115   XEventMetadataVisitor src_event_metadata_visitor(&src_plane,
116                                                    &src_event_metadata);
117   src_event_metadata_visitor.ForEachStat([&](const XStatVisitor& stat) {
118     XStatMetadata& metadata = *dst_plane.GetOrCreateStatMetadata(stat.Name());
119     XStat dst_stat = stat.RawStat();
120     if (stat.ValueCase() == XStat::kRefValue) {
121       XStatMetadata& value_metadata =
122           *dst_plane.GetOrCreateStatMetadata(stat.StrOrRefValue());
123       dst_stat.set_ref_value(value_metadata.id());
124     }
125     dst_stat.set_metadata_id(metadata.id());
126     *dst_event_metadata.add_stats() = std::move(dst_stat);
127   });
128 }
129 
IsOpLineName(absl::string_view line_name)130 bool IsOpLineName(absl::string_view line_name) {
131   return line_name == kXlaOpLineName || line_name == kTensorFlowOpLineName;
132 }
133 
134 }  // namespace
135 
FindPlaneWithName(const XSpace & space,absl::string_view name)136 const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) {
137   int i = Find(space.planes(),
138                [name](const XPlane* plane) { return plane->name() == name; });
139   return (i != -1) ? &space.planes(i) : nullptr;
140 }
141 
FindPlanesWithNames(const XSpace & space,const std::vector<absl::string_view> & names)142 std::vector<const XPlane*> FindPlanesWithNames(
143     const XSpace& space, const std::vector<absl::string_view>& names) {
144   absl::flat_hash_set<absl::string_view> names_set(names.begin(), names.end());
145   std::vector<int> indices =
146       FindAll(space.planes(), [&names_set](const XPlane* plane) {
147         return names_set.contains(plane->name());
148       });
149   std::vector<const XPlane*> planes;
150   planes.reserve(indices.size());
151   for (int i : indices) {
152     planes.push_back(&space.planes(i));
153   }
154   return planes;
155 }
156 
FindMutablePlaneWithName(XSpace * space,absl::string_view name)157 XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name) {
158   int i = Find(space->planes(),
159                [name](const XPlane* plane) { return plane->name() == name; });
160   return (i != -1) ? space->mutable_planes(i) : nullptr;
161 }
162 
FindOrAddMutablePlaneWithName(XSpace * space,absl::string_view name)163 XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name) {
164   XPlane* plane = FindMutablePlaneWithName(space, name);
165   if (plane == nullptr) {
166     plane = space->add_planes();
167     plane->set_name(name.data(), name.size());
168   }
169   return plane;
170 }
171 
FindPlanesWithPrefix(const XSpace & space,absl::string_view prefix)172 std::vector<const XPlane*> FindPlanesWithPrefix(const XSpace& space,
173                                                 absl::string_view prefix) {
174   return FindPlanes(space, [&](const XPlane& plane) {
175     return absl::StartsWith(plane.name(), prefix);
176   });
177 }
178 
FindMutablePlanesWithPrefix(XSpace * space,absl::string_view prefix)179 std::vector<XPlane*> FindMutablePlanesWithPrefix(XSpace* space,
180                                                  absl::string_view prefix) {
181   return FindMutablePlanes(space, [&](XPlane& plane) {
182     return absl::StartsWith(plane.name(), prefix);
183   });
184 }
185 
FindLineWithId(const XPlane & plane,int64_t id)186 const XLine* FindLineWithId(const XPlane& plane, int64_t id) {
187   int i =
188       Find(plane.lines(), [id](const XLine* line) { return line->id() == id; });
189   return (i != -1) ? &plane.lines(i) : nullptr;
190 }
191 
FindLineWithName(const XPlane & plane,absl::string_view name)192 const XLine* FindLineWithName(const XPlane& plane, absl::string_view name) {
193   int i = Find(plane.lines(),
194                [name](const XLine* line) { return line->name() == name; });
195   return (i != -1) ? &plane.lines(i) : nullptr;
196 }
197 
FindOrAddMutableStat(const XStatMetadata & stat_metadata,XEvent * event)198 XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event) {
199   for (auto& stat : *event->mutable_stats()) {
200     if (stat.metadata_id() == stat_metadata.id()) {
201       return &stat;
202     }
203   }
204   XStat* stat = event->add_stats();
205   stat->set_metadata_id(stat_metadata.id());
206   return stat;
207 }
208 
RemovePlane(XSpace * space,const XPlane * plane)209 void RemovePlane(XSpace* space, const XPlane* plane) {
210   DCHECK(plane != nullptr);
211   Remove(space->mutable_planes(), plane);
212 }
213 
RemovePlanes(XSpace * space,const std::vector<const XPlane * > & planes)214 void RemovePlanes(XSpace* space, const std::vector<const XPlane*>& planes) {
215   absl::flat_hash_set<const XPlane*> planes_set(planes.begin(), planes.end());
216   RemoveIf(space->mutable_planes(), [&planes_set](const XPlane* plane) {
217     return planes_set.contains(plane);
218   });
219 }
220 
RemoveLine(XPlane * plane,const XLine * line)221 void RemoveLine(XPlane* plane, const XLine* line) {
222   DCHECK(line != nullptr);
223   Remove(plane->mutable_lines(), line);
224 }
225 
RemoveEvents(XLine * line,const absl::flat_hash_set<const XEvent * > & events)226 void RemoveEvents(XLine* line,
227                   const absl::flat_hash_set<const XEvent*>& events) {
228   RemoveIf(line->mutable_events(),
229            [&](const XEvent* event) { return events.contains(event); });
230 }
231 
RemoveEmptyPlanes(XSpace * space)232 void RemoveEmptyPlanes(XSpace* space) {
233   RemoveIf(space->mutable_planes(),
234            [&](const XPlane* plane) { return plane->lines().empty(); });
235 }
236 
RemoveEmptyLines(XPlane * plane)237 void RemoveEmptyLines(XPlane* plane) {
238   RemoveIf(plane->mutable_lines(),
239            [&](const XLine* line) { return line->events().empty(); });
240 }
241 
operator ()(const XEvent * a,const XEvent * b) const242 bool XEventsComparator::operator()(const XEvent* a, const XEvent* b) const {
243   return XEventTimespan(*a) < XEventTimespan(*b);
244 }
245 
SortXPlane(XPlane * plane)246 void SortXPlane(XPlane* plane) {
247   for (XLine& line : *plane->mutable_lines()) {
248     auto& events = *line.mutable_events();
249     std::sort(events.pointer_begin(), events.pointer_end(),
250               XEventsComparator());
251   }
252 }
253 
SortXSpace(XSpace * space)254 void SortXSpace(XSpace* space) {
255   for (XPlane& plane : *space->mutable_planes()) SortXPlane(&plane);
256 }
257 
258 // Normalize the line's timestamp in this XPlane.
259 // NOTE: This can be called multiple times on the same plane. Only the first
260 // call will do the normalization, subsequent calls will do nothing.
261 // The assumption is that both line's timestamp_ns and start_time_ns are
262 // nano-seconds from epoch time, the different of these values is much
263 // smaller than these value.
NormalizeTimestamps(XPlane * plane,uint64 start_time_ns)264 void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns) {
265   for (XLine& line : *plane->mutable_lines()) {
266     if (line.timestamp_ns() >= static_cast<int64_t>(start_time_ns)) {
267       line.set_timestamp_ns(line.timestamp_ns() - start_time_ns);
268     }
269   }
270 }
271 
NormalizeTimestamps(XSpace * space,uint64 start_time_ns)272 void NormalizeTimestamps(XSpace* space, uint64 start_time_ns) {
273   for (XPlane& plane : *space->mutable_planes()) {
274     NormalizeTimestamps(&plane, start_time_ns);
275   }
276 }
277 
MergePlanes(const XPlane & src_plane,XPlane * dst_plane)278 void MergePlanes(const XPlane& src_plane, XPlane* dst_plane) {
279   RemoveEmptyLines(dst_plane);
280   XPlaneVisitor src(&src_plane);
281   XPlaneBuilder dst(dst_plane);
282   src.ForEachStat([&](const tensorflow::profiler::XStatVisitor& stat) {
283     XStatMetadata* stat_metadata = dst.GetOrCreateStatMetadata(stat.Name());
284     // Use SetOrAddStat to avoid duplicating stats in dst_plane.
285     dst.SetOrAddStat(*stat_metadata, stat.RawStat(), src_plane);
286   });
287   src.ForEachLine([&](const XLineVisitor& line) {
288     XLineBuilder dst_line = dst.GetOrCreateLine(line.Id());
289     int64_t time_offset_ps = 0LL;
290     if (dst_line.NumEvents() == 0) {
291       // Since we RemoveEmptyLines above, this could only mean that current
292       // line only exist in src plane.
293       dst_line.SetTimestampNs(line.TimestampNs());
294       dst_line.SetName(line.Name());
295       dst_line.SetDisplayNameIfEmpty(line.DisplayName());
296     } else {
297       if (line.TimestampNs() <= dst_line.TimestampNs()) {
298         dst_line.SetTimestampNsAndAdjustEventOffsets(line.TimestampNs());
299       } else {
300         time_offset_ps =
301             NanoToPico(line.TimestampNs() - dst_line.TimestampNs());
302       }
303       dst_line.SetNameIfEmpty(line.Name());
304       // Don't override dst_line's display name because if both lines have name,
305       // but no display name, line's name will became display name of dst_line.
306     }
307 
308     line.ForEachEvent([&](const XEventVisitor& event) {
309       XEventMetadata* dst_event_metadata =
310           dst.GetOrCreateEventMetadata(event.Name());
311       CopyEventMetadata(*event.metadata(), src, *dst_event_metadata, dst);
312       XEventBuilder dst_event = dst_line.AddEvent(*dst_event_metadata);
313       dst_event.SetOffsetPs(event.OffsetPs() + time_offset_ps);
314       dst_event.SetDurationPs(event.DurationPs());
315       if (event.NumOccurrences()) {
316         dst_event.SetNumOccurrences(event.NumOccurrences());
317       }
318       event.ForEachStat([&](const XStatVisitor& stat) {
319         // Here we can call AddStat instead of SetOrAddStat because dst_event
320         // was just added.
321         dst_event.AddStat(*dst.GetOrCreateStatMetadata(stat.Name()),
322                           stat.RawStat(), src_plane);
323       });
324     });
325   });
326 }
327 
MergePlanes(const std::vector<const XPlane * > & src_planes,XPlane * dst_plane)328 void MergePlanes(const std::vector<const XPlane*>& src_planes,
329                  XPlane* dst_plane) {
330   for (const XPlane* src_plane : src_planes) {
331     MergePlanes(*src_plane, dst_plane);
332   }
333 }
334 
GetStartTimestampNs(const XPlane & plane)335 int64_t GetStartTimestampNs(const XPlane& plane) {
336   int64_t plane_timestamp = 0;
337   for (const auto& line : plane.lines()) {
338     plane_timestamp = std::min(plane_timestamp, line.timestamp_ns());
339   }
340   return plane_timestamp;
341 }
342 
IsEmpty(const XSpace & space)343 bool IsEmpty(const XSpace& space) {
344   for (const auto& plane : space.planes()) {
345     for (const auto& line : plane.lines()) {
346       if (!line.events().empty()) {
347         return false;
348       }
349     }
350   }
351   return true;
352 }
353 
AddFlowsToXplane(int32_t host_id,bool is_host_plane,bool connect_traceme,XPlane * xplane)354 void AddFlowsToXplane(int32_t host_id, bool is_host_plane, bool connect_traceme,
355                       XPlane* xplane) {
356   if (!xplane) return;
357   XPlaneBuilder plane(xplane);
358   XStatMetadata* correlation_id_stats_metadata =
359       plane.GetStatMetadata(GetStatTypeStr(StatType::kCorrelationId));
360   XStatMetadata* producer_type_stats_metadata =
361       plane.GetStatMetadata(GetStatTypeStr(StatType::kProducerType));
362   XStatMetadata* consumer_type_stats_metadata =
363       plane.GetStatMetadata(GetStatTypeStr(StatType::kConsumerType));
364   XStatMetadata* producer_id_stats_metadata =
365       plane.GetStatMetadata(GetStatTypeStr(StatType::kProducerId));
366   XStatMetadata* consumer_id_stats_metadata =
367       plane.GetStatMetadata(GetStatTypeStr(StatType::kConsumerId));
368   XStatMetadata* flow_stats_metadata =
369       plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kFlow));
370   XFlow::FlowDirection direction = is_host_plane
371                                        ? XFlow::FlowDirection::kFlowOut
372                                        : XFlow::FlowDirection::kFlowIn;
373 
374   plane.ForEachLine([&](XLineBuilder line) {
375     line.ForEachEvent([&](XEventBuilder event) {
376       absl::optional<uint64_t> correlation_id;
377       absl::optional<uint64_t> producer_type;
378       absl::optional<uint64_t> consumer_type;
379       absl::optional<uint64_t> producer_id;
380       absl::optional<uint64_t> consumer_id;
381       event.ForEachStat([&](XStat* stat) {
382         if (correlation_id_stats_metadata &&
383             stat->metadata_id() == correlation_id_stats_metadata->id()) {
384           correlation_id = stat->uint64_value();
385         } else if (connect_traceme) {
386           if (producer_type_stats_metadata &&
387               stat->metadata_id() == producer_type_stats_metadata->id()) {
388             producer_type = XStatsBuilder<XPlane>::IntOrUintValue(*stat);
389           } else if (consumer_type_stats_metadata &&
390                      stat->metadata_id() ==
391                          consumer_type_stats_metadata->id()) {
392             consumer_type = XStatsBuilder<XPlane>::IntOrUintValue(*stat);
393           } else if (producer_id_stats_metadata &&
394                      stat->metadata_id() == producer_id_stats_metadata->id()) {
395             producer_id = XStatsBuilder<XPlane>::IntOrUintValue(*stat);
396           } else if (consumer_id_stats_metadata &&
397                      stat->metadata_id() == consumer_id_stats_metadata->id()) {
398             consumer_id = XStatsBuilder<XPlane>::IntOrUintValue(*stat);
399           }
400         }
401       });
402       if (correlation_id) {
403         XFlow flow(XFlow::GetFlowId(host_id, *correlation_id), direction,
404                    ContextType::kGpuLaunch);
405         event.AddStatValue(*flow_stats_metadata, flow.ToStatValue());
406       }
407       if (connect_traceme) {
408         if (producer_type && producer_id) {
409           auto context_type = GetSafeContextType(*producer_type);
410           XFlow flow(XFlow::GetFlowId(host_id, *producer_id, context_type),
411                      XFlow::FlowDirection::kFlowOut, context_type);
412           event.AddStatValue(*flow_stats_metadata, flow.ToStatValue());
413         }
414         if (consumer_type && consumer_id) {
415           auto context_type = GetSafeContextType(*consumer_type);
416           XFlow flow(XFlow::GetFlowId(host_id, *consumer_id, context_type),
417                      XFlow::FlowDirection::kFlowIn, context_type);
418           event.AddStatValue(*flow_stats_metadata, flow.ToStatValue());
419         }
420       }
421     });
422   });
423 }
424 
GetDevicePlaneFingerprint(const XPlane & plane)425 uint64_t GetDevicePlaneFingerprint(const XPlane& plane) {
426   const XLine* xla_module_line = FindLineWithName(plane, kXlaModuleLineName);
427   if (!xla_module_line) return 0ULL;
428 
429   XPlaneVisitor xplane(&plane);
430   XLineVisitor xline(&xplane, xla_module_line);
431   std::set<uint64_t> ordered_module_fps;
432   xline.ForEachEvent([&](const XEventVisitor& xevent) {
433     ordered_module_fps.insert(Fingerprint64(xevent.Name()));
434   });
435   if (ordered_module_fps.empty()) return 0ULL;
436   uint64_t output = 0ULL;
437   for (const auto& fp : ordered_module_fps) {
438     output = FingerprintCat64(output, fp);
439   }
440   return output;
441 }
442 
GetContainingEvent(const Timespan & event)443 std::optional<XEventVisitor> XEventContextTracker::GetContainingEvent(
444     const Timespan& event) {
445   if (!line_) return std::nullopt;
446   if (current_index_ != -1) {
447     XEventVisitor current_event(plane_, line_, &line_->events(current_index_));
448     if (current_event.GetTimespan().Includes(event)) {
449       return current_event;
450     }
451   }
452   for (int i = current_index_ + 1; i < line_->events_size(); ++i) {
453     XEventVisitor current_event(plane_, line_, &line_->events(i));
454     if (current_event.TimestampPs() > event.end_ps()) break;
455     if (current_event.EndTimestampPs() < event.begin_ps()) continue;
456     current_index_ = i;
457     if (current_event.GetTimespan().Includes(event)) {
458       return current_event;
459     }
460     break;  // overlapping
461   }
462   return std::nullopt;
463 }
464 
GetOverlappingEvent(const Timespan & event)465 std::optional<XEventVisitor> XEventContextTracker::GetOverlappingEvent(
466     const Timespan& event) {
467   if (!line_) return std::nullopt;
468   if (current_index_ != -1) {
469     XEventVisitor current_event(plane_, line_, &line_->events(current_index_));
470     if (current_event.GetTimespan().Overlaps(event)) {
471       return current_event;
472     }
473   }
474   for (int i = current_index_ + 1; i < line_->events_size(); ++i) {
475     XEventVisitor current_event(plane_, line_, &line_->events(i));
476     if (current_event.TimestampPs() > event.end_ps()) break;
477     if (current_event.EndTimestampPs() < event.begin_ps()) continue;
478     current_index_ = i;
479     if (current_event.GetTimespan().Overlaps(event)) {
480       return current_event;
481     }
482     break;  // overlapping
483   }
484   return std::nullopt;
485 }
486 
AggregateXPlane(const XPlane & full_trace,XPlane & aggregated_trace)487 void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) {
488   struct EventStat {
489     Stat<int64_t> stat;
490     int64_t children_duration;
491   };
492   using StatByEvent = absl::flat_hash_map<int64_t /*event_id*/, EventStat>;
493 
494   absl::flat_hash_map<int64_t /*line_id*/, StatByEvent> stats;
495 
496   XPlaneVisitor plane(&full_trace);
497   XPlaneBuilder aggregated_plane(&aggregated_trace);
498 
499   plane.ForEachLine([&](const XLineVisitor& line) {
500     if (!IsOpLineName(line.Name())) return;
501     XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line.Id());
502     aggregated_line.SetName(line.Name());
503     std::vector<XEventVisitor> event_stack;
504     line.ForEachEvent([&](XEventVisitor event) {
505       StatByEvent& line_stats = stats[line.Id()];
506       line_stats[event.Id()].stat.UpdateStat(event.DurationPs());
507       DCHECK(event_stack.empty() || !(event < event_stack.back()));
508       while (!event_stack.empty() &&
509              !event_stack.back().GetTimespan().Includes(event.GetTimespan())) {
510         event_stack.pop_back();
511       }
512       if (!event_stack.empty()) {
513         line_stats[event_stack.back().Id()].children_duration +=
514             event.DurationPs();
515       }
516       event_stack.push_back(std::move(event));
517     });
518   });
519 
520   // TODO(b/238349654): Remove when XPlane better XPlane Comparison mechanism
521   // exists.
522   aggregated_plane.GetOrCreateStatMetadata(
523       GetStatTypeStr(StatType::kMinDurationPs));
524   aggregated_plane.GetOrCreateStatMetadata(
525       GetStatTypeStr(StatType::kSelfDurationPs));
526 
527   for (const auto& [line_id, stat_by_event] : stats) {
528     XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line_id);
529     for (const auto& [event_id, event_stat] : stat_by_event) {
530       XEventMetadata& event_metadata =
531           *aggregated_plane.GetOrCreateEventMetadata(event_id);
532       CopyEventMetadata(*plane.GetEventMetadata(event_id), plane,
533                         event_metadata, aggregated_plane);
534       XEventBuilder aggregated_event = aggregated_line.AddEvent(event_metadata);
535       aggregated_event.SetNumOccurrences(event_stat.stat.count());
536       aggregated_event.SetDurationPs(event_stat.stat.sum());
537       if (event_stat.stat.count() > 1) {
538         aggregated_event.AddStatValue(
539             *aggregated_plane.GetOrCreateStatMetadata(
540                 GetStatTypeStr(StatType::kMinDurationPs)),
541             event_stat.stat.min());
542       }
543       if (event_stat.children_duration != 0) {
544         aggregated_event.AddStatValue(
545             *aggregated_plane.GetOrCreateStatMetadata(
546                 GetStatTypeStr(StatType::kSelfDurationPs)),
547             event_stat.stat.sum() - event_stat.children_duration);
548       }
549     }
550   }
551 }
552 
553 }  // namespace profiler
554 }  // namespace tensorflow
555