xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/xplane_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_
16 #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <optional>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
29 #include "tensorflow/core/profiler/utils/timespan.h"
30 #include "tensorflow/core/profiler/utils/trace_utils.h"
31 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
32 
33 namespace tensorflow {
34 namespace profiler {
35 
36 // Returns a Timespan from an XEvent.
37 // WARNING: This should only be used when comparing events from the same XLine.
XEventTimespan(const XEvent & event)38 inline Timespan XEventTimespan(const XEvent& event) {
39   return Timespan(event.offset_ps(), event.duration_ps());
40 }
41 
42 // Returns the planes with the given predicate.
43 template <typename F>
FindPlanes(const XSpace & space,const F & predicate)44 std::vector<const XPlane*> FindPlanes(const XSpace& space, const F& predicate) {
45   std::vector<const XPlane*> result;
46   for (const XPlane& plane : space.planes()) {
47     if (predicate(plane)) {
48       result.push_back(&plane);
49     }
50   }
51   return result;
52 }
53 
54 // Returns mutable planes with the given predicate.
55 template <typename F>
FindMutablePlanes(XSpace * space,const F & predicate)56 std::vector<XPlane*> FindMutablePlanes(XSpace* space, const F& predicate) {
57   std::vector<XPlane*> result;
58   for (XPlane& plane : *space->mutable_planes()) {
59     if (predicate(plane)) {
60       result.push_back(&plane);
61     }
62   }
63   return result;
64 }
65 
66 // Returns the plane with the given name or nullptr if not found.
67 const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name);
68 XPlane* FindMutablePlaneWithName(XSpace* space, absl::string_view name);
69 
70 // Returns the planes with the given names, if found.
71 std::vector<const XPlane*> FindPlanesWithNames(
72     const XSpace& space, const std::vector<absl::string_view>& names);
73 
74 // Returns the plane with the given name in the container. If necessary, adds a
75 // new plane to the container.
76 XPlane* FindOrAddMutablePlaneWithName(XSpace* space, absl::string_view name);
77 
78 // Returns all the planes with a given prefix.
79 std::vector<const XPlane*> FindPlanesWithPrefix(const XSpace& space,
80                                                 absl::string_view prefix);
81 std::vector<XPlane*> FindMutablePlanesWithPrefix(XSpace* space,
82                                                  absl::string_view prefix);
83 
84 // Returns the plane with the given id/name or nullptr if not found.
85 const XLine* FindLineWithId(const XPlane& plane, int64_t id);
86 const XLine* FindLineWithName(const XPlane& plane, absl::string_view name);
87 
88 XStat* FindOrAddMutableStat(const XStatMetadata& stat_metadata, XEvent* event);
89 
90 void RemovePlane(XSpace* space, const XPlane* plane);
91 void RemovePlanes(XSpace* space, const std::vector<const XPlane*>& planes);
92 void RemoveLine(XPlane* plane, const XLine* line);
93 void RemoveEvents(XLine* line,
94                   const absl::flat_hash_set<const XEvent*>& events);
95 
96 void RemoveEmptyPlanes(XSpace* space);
97 void RemoveEmptyLines(XPlane* plane);
98 
99 // Sort lines in plane with a provided comparator.
100 template <class Compare>
SortXLinesBy(XPlane * plane,Compare comp)101 void SortXLinesBy(XPlane* plane, Compare comp) {
102   std::sort(plane->mutable_lines()->pointer_begin(),
103             plane->mutable_lines()->pointer_end(), comp);
104 }
105 
106 class XLinesComparatorByName {
107  public:
operator()108   bool operator()(const XLine* a, const XLine* b) const {
109     auto& line_a = a->display_name().empty() ? a->name() : a->display_name();
110     auto& line_b = b->display_name().empty() ? b->name() : b->display_name();
111     return line_a < line_b;
112   }
113 };
114 
115 // Sorts each XLine's XEvents by offset_ps (ascending) and duration_ps
116 // (descending) so nested events are sorted from outer to innermost.
117 void SortXPlane(XPlane* plane);
118 // Sorts each plane of the XSpace.
119 void SortXSpace(XSpace* space);
120 
121 // Functor that compares XEvents for sorting by timespan.
122 struct XEventsComparator {
123   bool operator()(const XEvent* a, const XEvent* b) const;
124 };
125 
126 // Returns a sorted vector of all XEvents in the given XPlane.
127 // This template can be used with either XPlaneVisitor or XPlaneBuilder.
128 template <typename Event, typename Plane>
129 inline std::vector<Event> GetSortedEvents(Plane& plane,
130                                           bool include_derived_events = false) {
131   std::vector<Event> events;
132   plane.ForEachLine([&events, include_derived_events](auto line) {
133     if (!include_derived_events && IsDerivedThreadId(line.Id())) return;
134     line.ForEachEvent(
135         [&events](auto event) { events.emplace_back(std::move(event)); });
136   });
137   absl::c_sort(events);
138   return events;
139 }
140 
141 // Normalize timestamps by time-shifting to start_time_ns_ as origin.
142 void NormalizeTimestamps(XPlane* plane, uint64 start_time_ns);
143 void NormalizeTimestamps(XSpace* space, uint64 start_time_ns);
144 
145 // Merges src_plane into dst_plane. Both plane level stats, lines, events and
146 // event level stats are merged. If src_plane and dst_plane both have the same
147 // line, which have different start timestamps, we will normalize the events
148 // offset timestamp correspondingly.
149 void MergePlanes(const XPlane& src_plane, XPlane* dst_plane);
150 
151 // Merges each plane with a src_planes, into the dst_plane.
152 void MergePlanes(const std::vector<const XPlane*>& src_planes,
153                  XPlane* dst_plane);
154 
155 // Plane's start timestamp is defined as the minimum of all lines' start
156 // timestamps. If zero line exists, return 0;
157 int64_t GetStartTimestampNs(const XPlane& plane);
158 
159 // Returns true if there are no XEvents.
160 bool IsEmpty(const XSpace& space);
161 
162 // Mutate the XPlane by adding predefined XFlow. e.g. GPU kernel launches =>
163 // GPU kernel events.
164 void AddFlowsToXplane(int32_t host_id, bool is_host_plane, bool connect_traceme,
165                       XPlane* plane);
166 
167 // Get a fingerprint of device plane for deduplicating derived lines in similar
168 // device planes. The fingerprint is a hash of sorted HLO modules name which
169 // were appeared on current plane.
170 // Returns 0 when such "Xla Modules" line don't exist.
171 uint64_t GetDevicePlaneFingerprint(const XPlane& plane);
172 template <typename XPlanePointerIterator>
SortPlanesById(XPlanePointerIterator begin,XPlanePointerIterator end)173 void SortPlanesById(XPlanePointerIterator begin, XPlanePointerIterator end) {
174   std::sort(begin, end, [&](const XPlane* a, const XPlane* b) {
175     return a->id() < b->id();  // ascending order of device xplane id.
176   });
177 }
178 
179 // When certain event context only exists from event from other line, which
180 // "encloses" current event in timeline, we need to find out quickly which
181 // enclosing event is (or if there is one).
182 // To Avoid O(N) search overhead, assume the event are processed in the order
183 // of "XLine default sorting order".
184 class XEventContextTracker {
185  public:
186   // The events on line need to be sorted and disjointed.
XEventContextTracker(const XPlaneVisitor * plane,const XLine * line)187   XEventContextTracker(const XPlaneVisitor* plane, const XLine* line)
188       : plane_(plane), line_(line) {}
189 
190   // Returns the event that encloses/contains the specified input event.
191   // Expects called with events with start timestamps sorted incrementingly.
192   std::optional<XEventVisitor> GetContainingEvent(const Timespan& event);
193 
194   // Returns the event that overlaps the specified input event.
195   // Expects called with events with start timestamps sorted incrementingly.
196   std::optional<XEventVisitor> GetOverlappingEvent(const Timespan& event);
197 
198  private:
199   const XPlaneVisitor* plane_;
200   const XLine* line_;
201   int64_t current_index_ = -1;
202 };
203 
204 // Aggregate traces on full_trace xplane and add them onto the aggregated_trace
205 // xplane.
206 void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace);
207 
208 }  // namespace profiler
209 }  // namespace tensorflow
210 
211 #endif  // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_
212