xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/derived_timeline.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/utils/derived_timeline.h"
16 
17 #include <algorithm>
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/profiler/convert/xla_op_utils.h"
31 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
32 #include "tensorflow/core/profiler/utils/gpu_event_stats.h"
33 #include "tensorflow/core/profiler/utils/group_events.h"
34 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
35 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
36 #include "tensorflow/core/profiler/utils/timespan.h"
37 #include "tensorflow/core/profiler/utils/tpu_xplane_utils.h"
38 #include "tensorflow/core/profiler/utils/trace_utils.h"
39 #include "tensorflow/core/profiler/utils/xplane_builder.h"
40 #include "tensorflow/core/profiler/utils/xplane_schema.h"
41 #include "tensorflow/core/profiler/utils/xplane_utils.h"
42 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
43 #include "tensorflow/core/util/stats_calculator.h"
44 
45 namespace tensorflow {
46 namespace profiler {
47 namespace {
48 
HloModuleEventName(const GpuEventStats & stats)49 inline std::string HloModuleEventName(const GpuEventStats& stats) {
50   return stats.program_id ? HloModuleNameWithProgramId(stats.hlo_module_name,
51                                                        *stats.program_id)
52                           : std::string(stats.hlo_module_name);
53 }
54 
55 // Returns a prefix that uniquely identifies the HLO module.
HloOpEventPrefix(const GpuEventStats & stats)56 inline std::string HloOpEventPrefix(const GpuEventStats& stats) {
57   return stats.program_id ? absl::StrCat(*stats.program_id, "/")
58                           : absl::StrCat(stats.hlo_module_name, "/");
59 }
60 
GetOrCreateHloOpEventsMetadata(XPlaneBuilder & plane_builder,const GpuEventStats & stats)61 std::vector<XEventMetadata*> GetOrCreateHloOpEventsMetadata(
62     XPlaneBuilder& plane_builder, const GpuEventStats& stats) {
63   DCHECK(stats.IsXlaOp());
64   DCHECK(!stats.hlo_module_name.empty());
65   std::vector<XEventMetadata*> hlo_op_events_metadata;
66   hlo_op_events_metadata.reserve(stats.hlo_op_names.size());
67   // Prepend an HLO module identifier so HLO operators with the same name but in
68   // different modules have different metadata.
69   std::string hlo_op_event_prefix = HloOpEventPrefix(stats);
70   for (absl::string_view hlo_op_name : stats.hlo_op_names) {
71     XEventMetadata* hlo_op_event_metadata =
72         plane_builder.GetOrCreateEventMetadata(
73             absl::StrCat(hlo_op_event_prefix, hlo_op_name));
74     // Display the HLO name without the module name in tools.
75     if (hlo_op_event_metadata->display_name().empty()) {
76       hlo_op_event_metadata->set_display_name(std::string(hlo_op_name));
77     }
78     hlo_op_events_metadata.push_back(hlo_op_event_metadata);
79   }
80   return hlo_op_events_metadata;
81 }
82 
83 }  // namespace
84 
ProcessTfOpEvent(absl::string_view tf_op_full_name,Timespan event_span,std::optional<int64_t> group_id,XPlaneBuilder & plane_builder,DerivedXLineBuilder & tf_name_scope_line_builder,DerivedXLineBuilder & tf_op_line_builder)85 void ProcessTfOpEvent(absl::string_view tf_op_full_name, Timespan event_span,
86                       std::optional<int64_t> group_id,
87                       XPlaneBuilder& plane_builder,
88                       DerivedXLineBuilder& tf_name_scope_line_builder,
89                       DerivedXLineBuilder& tf_op_line_builder) {
90   TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
91   Category category = tf_op.category;
92   if (category == Category::kTensorFlow || category == Category::kJax) {
93     tf_name_scope_line_builder.ExpandOrAddEvents(
94         plane_builder.GetOrCreateEventsMetadata(ParseTfNameScopes(tf_op)),
95         event_span, group_id);
96   }
97   XEventMetadata* tf_op_event_metadata =
98       plane_builder.GetOrCreateEventMetadata(tf_op_full_name);
99   // Set the display name to op_type so that the events of the same op_type have
100   // the same color in the trace viewer.
101   if (tf_op_event_metadata->display_name().empty()) {
102     tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
103   }
104   tf_op_line_builder.ExpandOrAddEvent(*tf_op_event_metadata, event_span,
105                                       group_id);
106 }
107 
DerivedXEventBuilder(XEventBuilder event,std::optional<int64_t> group_id)108 DerivedXEventBuilder::DerivedXEventBuilder(XEventBuilder event,
109                                            std::optional<int64_t> group_id)
110     : event_(std::move(event)), group_id_(group_id) {}
111 
ShouldExpand(const XEventMetadata & event_metadata,std::optional<int64_t> group_id) const112 bool DerivedXEventBuilder::ShouldExpand(const XEventMetadata& event_metadata,
113                                         std::optional<int64_t> group_id) const {
114   return event_.MetadataId() == event_metadata.id() && group_id_ == group_id;
115 }
116 
Expand(Timespan event_span)117 void DerivedXEventBuilder::Expand(Timespan event_span) {
118   Timespan timespan = event_.GetTimespan();
119   DCHECK_LE(timespan.begin_ps(), event_span.begin_ps());
120   timespan.ExpandToInclude(event_span);
121   event_.SetTimespan(timespan);
122 }
123 
DerivedXLineBuilder(XPlaneBuilder * plane,int64_t line_id,absl::string_view name,int64_t timestamp_ns,std::vector<DerivedXLineBuilder * > dependent_lines)124 DerivedXLineBuilder::DerivedXLineBuilder(
125     XPlaneBuilder* plane, int64_t line_id, absl::string_view name,
126     int64_t timestamp_ns, std::vector<DerivedXLineBuilder*> dependent_lines)
127     : group_id_stat_metadata_(
128           plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))),
129       level_stat_metadata_(plane->GetOrCreateStatMetadata("l")),
130       line_(plane->GetOrCreateLine(line_id)),
131       dependent_lines_(std::move(dependent_lines)) {
132   line_.SetName(name);
133   line_.SetTimestampNs(timestamp_ns);
134 }
135 
ExpandOrAddEvent(const XEventMetadata & event_metadata,Timespan event_span,std::optional<int64_t> group_id)136 void DerivedXLineBuilder::ExpandOrAddEvent(const XEventMetadata& event_metadata,
137                                            Timespan event_span,
138                                            std::optional<int64_t> group_id) {
139   ExpandOrAddLevelEvent(event_metadata, event_span, group_id,
140                         /*level=*/0);
141 }
142 
ExpandOrAddEvents(const std::vector<XEventMetadata * > & events_metadata_per_level,Timespan event_span,std::optional<int64_t> group_id)143 void DerivedXLineBuilder::ExpandOrAddEvents(
144     const std::vector<XEventMetadata*>& events_metadata_per_level,
145     Timespan event_span, std::optional<int64_t> group_id) {
146   if (events_metadata_per_level.empty()) return;
147   size_t current_nested_level = events_metadata_per_level.size();
148   for (size_t level = 0; level < current_nested_level; ++level) {
149     ExpandOrAddLevelEvent(*events_metadata_per_level[level], event_span,
150                           group_id, level);
151   }
152   ResetLastEvents(current_nested_level);
153 }
154 
ExpandOrAddLevelEvent(const XEventMetadata & event_metadata,Timespan event_span,std::optional<int64_t> group_id,int level)155 void DerivedXLineBuilder::ExpandOrAddLevelEvent(
156     const XEventMetadata& event_metadata, Timespan event_span,
157     std::optional<int64_t> group_id, int level) {
158   auto& last_event = last_event_by_level_[level];
159   if (last_event && last_event->ShouldExpand(event_metadata, group_id)) {
160     // Expand the last event to cover the given event.
161     last_event->Expand(event_span);
162   } else {
163     // Otherwise, reset the last events lower than or equal to the given level.
164     ResetLastEvents(level);
165     // And create a new event for the given level.
166     XEventBuilder event = line_.AddEvent(event_metadata);
167     event.SetTimespan(event_span);
168     if (group_id.has_value()) {
169       event.AddStatValue(*group_id_stat_metadata_, *group_id);
170     }
171     event.AddStatValue(*level_stat_metadata_, level);
172     last_event.emplace(std::move(event), group_id);
173   }
174 }
175 
ResetLastEvents(int level)176 void DerivedXLineBuilder::ResetLastEvents(int level) {
177   for (int i = level, end = last_event_by_level_.size(); i < end; ++i) {
178     last_event_by_level_[i].reset();
179   }
180   if (level == 0) {
181     for (DerivedXLineBuilder* line : dependent_lines_) {
182       line->ResetLastEvents(0);
183     }
184   }
185 }
186 
AddGroupMetadataToStepEvents(const GroupMetadataMap & group_metadata_map,XLineBuilder & line)187 void AddGroupMetadataToStepEvents(const GroupMetadataMap& group_metadata_map,
188                                   XLineBuilder& line) {
189   if (group_metadata_map.empty()) return;
190   XPlaneBuilder* plane = line.Plane();
191   const XStatMetadata* group_id_stat_metadata =
192       plane->GetStatMetadata(GetStatTypeStr(StatType::kGroupId));
193   if (group_id_stat_metadata == nullptr) return;
194   const XStatMetadata* step_name_stat_metadata =
195       plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
196   line.ForEachEvent([&](XEventBuilder event) {
197     const XStat* group_id_stat = event.GetStat(*group_id_stat_metadata);
198     if (group_id_stat != nullptr) {
199       int64_t group_id = group_id_stat->int64_value();
200       if (const GroupMetadata* group_metadata =
201               gtl::FindOrNull(group_metadata_map, group_id)) {
202         // TODO(b/160255693): Change the event name directly.
203         event.AddStatValue(*step_name_stat_metadata, group_metadata->name);
204       }
205     }
206   });
207 }
208 
DeriveStepEventsFromGroups(const GroupMetadataMap & group_metadata_map,XPlane * device_trace)209 void DeriveStepEventsFromGroups(const GroupMetadataMap& group_metadata_map,
210                                 XPlane* device_trace) {
211   XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
212   const XStatMetadata* group_id_stat_metadata =
213       plane_visitor.GetStatMetadataByType(StatType::kGroupId);
214   if (group_id_stat_metadata == nullptr) return;
215   XPlaneBuilder plane_builder(device_trace);
216   int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
217   DerivedXLineBuilder steps(&plane_builder, kThreadIdStepInfo, kStepLineName,
218                             start_timestamp_ns, {});
219   for (const XEventVisitor& event_visitor :
220        GetSortedEvents<XEventVisitor>(plane_visitor)) {
221     std::optional<XStatVisitor> group_id_stat =
222         event_visitor.GetStat(StatType::kGroupId, *group_id_stat_metadata);
223     if (group_id_stat.has_value()) {
224       int64_t group_id = group_id_stat->IntValue();
225       steps.ExpandOrAddEvent(
226           *plane_builder.GetOrCreateEventMetadata(absl::StrCat(group_id)),
227           event_visitor.GetTimespan(), group_id);
228     }
229   }
230   AddGroupMetadataToStepEvents(group_metadata_map, steps.Line());
231 }
232 
DeriveEventsFromAnnotations(const SymbolResolver & symbol_resolver,XPlane * device_trace)233 void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
234                                  XPlane* device_trace) {
235   XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
236   XPlaneBuilder plane_builder(device_trace);
237   int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
238   DerivedXLineBuilder tf_ops(&plane_builder, kThreadIdTfOp,
239                              kTensorFlowOpLineName, start_timestamp_ns, {});
240   DerivedXLineBuilder tf_name_scope(&plane_builder, kThreadIdTfNameScope,
241                                     kTensorFlowNameScopeLineName,
242                                     start_timestamp_ns, {&tf_ops});
243   DerivedXLineBuilder hlo_ops(&plane_builder, kThreadIdHloOp, kXlaOpLineName,
244                               start_timestamp_ns, {});
245   DerivedXLineBuilder hlo_modules(&plane_builder, kThreadIdHloModule,
246                                   kXlaModuleLineName, start_timestamp_ns,
247                                   {&tf_name_scope, &hlo_ops});
248   DerivedXLineBuilder source(&plane_builder, kThreadIdSource, kSourceLineName,
249                              start_timestamp_ns, {});
250 
251   for (const XEventVisitor& event :
252        GetSortedEvents<XEventVisitor>(plane_visitor)) {
253     GpuEventStats stats(&event);
254     // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or
255     // allocation events).
256     if (!stats.IsKernel()) continue;
257     Timespan event_span = event.GetTimespan();
258 
259     if (!stats.hlo_module_name.empty()) {
260       hlo_modules.ExpandOrAddEvent(
261           *plane_builder.GetOrCreateEventMetadata(HloModuleEventName(stats)),
262           event_span, stats.group_id);
263     }
264 
265     if (stats.IsXlaOp()) {
266       hlo_ops.ExpandOrAddEvents(
267           GetOrCreateHloOpEventsMetadata(plane_builder, stats), event_span,
268           stats.group_id);
269       auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name,
270                                     stats.hlo_op_names.back());
271       if (!symbol.tf_op_name.empty()) {
272         ProcessTfOpEvent(symbol.tf_op_name,
273                          event_span, stats.group_id, plane_builder,
274                          tf_name_scope, tf_ops);
275       }
276       if (!symbol.source_info.empty()) {
277         source.ExpandOrAddEvent(
278             *plane_builder.GetOrCreateEventMetadata(symbol.source_info),
279             event_span, stats.group_id);
280       }
281     } else if (stats.IsTfOp()) {
282       ProcessTfOpEvent(stats.tf_op_fullname,
283                        event_span, stats.group_id, plane_builder, tf_name_scope,
284                        tf_ops);
285     }
286   }
287   RemoveEmptyLines(device_trace);
288 }
289 
DeriveEventsFromHostTrace(const XPlane * host_trace,const GroupMetadataMap & group_metadata_map,std::vector<XPlane * > device_traces)290 void DeriveEventsFromHostTrace(const XPlane* host_trace,
291                                const GroupMetadataMap& group_metadata_map,
292                                std::vector<XPlane*> device_traces) {
293   struct GroupLaunchInfo {  // "Group" normally means step.
294     Timespan timespan;
295     Stat<uint64_t> stat;
296 
297     void AddEventTimespan(Timespan event_span) {
298       if (stat.count() == 0) {
299         timespan = event_span;
300       } else {
301         timespan.ExpandToInclude(event_span);
302       }
303       stat.UpdateStat(event_span.duration_ps());
304     }
305   };
306   using DeviceLaunchInfo =
307       absl::flat_hash_map<int64_t /*group_id*/, GroupLaunchInfo>;
308 
309   const int num_devices = device_traces.size();
310   std::vector<DeviceLaunchInfo> per_device_launch_info(num_devices);
311 
312   XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host_trace);
313   host_plane.ForEachLine([&](const XLineVisitor& line) {
314     if (IsDerivedThreadId(line.Id())) return;
315     line.ForEachEvent([&](const XEventVisitor& event) {
316       // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize
317       // etc for now. TODO: find a better way to filter out only the memcpy and
318       // kernel launch events.
319       if (absl::StartsWith(event.Name(), "cu")) return;
320       LaunchEventStats stats(&event);
321       if (stats.group_id.has_value() && stats.IsLaunch() &&
322           0 <= *stats.device_id && *stats.device_id < num_devices) {
323         // This is a launch event on a known device.
324         GroupLaunchInfo& group_launch_info =
325             per_device_launch_info[*stats.device_id][*stats.group_id];
326         group_launch_info.AddEventTimespan(event.GetTimespan());
327       }
328     });
329   });
330 
331   int64_t host_plane_start = GetStartTimestampNs(*host_trace);
332   for (int i = 0; i < num_devices; ++i) {
333     if (per_device_launch_info[i].empty()) continue;
334     int64_t device_plane_start = GetStartTimestampNs(*device_traces[i]);
335 
336     XPlaneBuilder device_plane(device_traces[i]);
337     const XStatMetadata& group_id_stat_metadata =
338         *device_plane.GetOrCreateStatMetadata(
339             GetStatTypeStr(StatType::kGroupId));
340     const XStatMetadata& num_launches_stat_metadata =
341         *device_plane.GetOrCreateStatMetadata("num_launches");
342     const XStatMetadata& max_launch_time_us_stat_metadata =
343         *device_plane.GetOrCreateStatMetadata("max_launch_time_us");
344     const XStatMetadata& avg_launch_time_us_stat_metadata =
345         *device_plane.GetOrCreateStatMetadata("avg_launch_time_us");
346 
347     XLineBuilder launch_line =
348         device_plane.GetOrCreateLine(kThreadIdKernelLaunch);
349     launch_line.SetName(kKernelLaunchLineName);
350     launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start));
351     for (const auto& kv : per_device_launch_info[i]) {
352       int64_t group_id = kv.first;
353       const GroupLaunchInfo& group_info = kv.second;
354       if (const GroupMetadata* group_metadata =
355               gtl::FindOrNull(group_metadata_map, group_id)) {
356         XEventBuilder device_event =
357             launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata(
358                 absl::StrCat("Launch Stats for ", group_metadata->name)));
359         device_event.SetTimespan(group_info.timespan);
360         device_event.AddStatValue(group_id_stat_metadata, group_id);
361         device_event.AddStatValue(num_launches_stat_metadata,
362                                   group_info.stat.count());
363         device_event.AddStatValue(max_launch_time_us_stat_metadata,
364                                   PicoToMicro(group_info.stat.max()));
365         device_event.AddStatValue(avg_launch_time_us_stat_metadata,
366                                   PicoToMicro(group_info.stat.avg()));
367       }
368     }
369   }
370 }
371 
GenerateDerivedTimeLines(const GroupMetadataMap & group_metadata_map,XSpace * space)372 void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map,
373                               XSpace* space) {
374   // TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that
375   // to look up tensorflow op name from hlo_module/hlo_op.
376   auto dummy_symbol_resolver =
377       [](absl::optional<uint64_t> program_id, absl::string_view hlo_module,
378          absl::string_view hlo_op) { return Symbol(); };
379   for (XPlane* plane : FindMutablePlanesWithPrefix(space, kGpuPlanePrefix)) {
380     DeriveStepEventsFromGroups(group_metadata_map, plane);
381     DeriveEventsFromAnnotations(dummy_symbol_resolver, plane);
382   }
383   for (XPlane* plane : FindMutableTensorCorePlanes(space)) {
384     DeriveLinesFromStats(plane);
385     SortXPlane(plane);
386   }
387 }
388 
DeriveLinesFromStats(XPlane * device_trace)389 void DeriveLinesFromStats(XPlane* device_trace) {
390   XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
391   XPlaneBuilder plane_builder(device_trace);
392   int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
393   DerivedXLineBuilder tf_ops(
394       &plane_builder, tensorflow::profiler::kThreadIdTfOp,
395       tensorflow::profiler::kTensorFlowOpLineName, start_timestamp_ns, {});
396   DerivedXLineBuilder tf_name_scope(
397       &plane_builder, tensorflow::profiler::kThreadIdTfNameScope,
398       tensorflow::profiler::kTensorFlowNameScopeLineName, start_timestamp_ns,
399       {&tf_ops});
400   DerivedXLineBuilder source(
401       &plane_builder, tensorflow::profiler::kThreadIdSource,
402       tensorflow::profiler::kSourceLineName, start_timestamp_ns, {});
403 
404   for (const XEventVisitor& event :
405        GetSortedEvents<XEventVisitor>(plane_visitor, true)) {
406     Timespan event_span = event.GetTimespan();
407     std::optional<absl::string_view> tf_op_name;
408     std::optional<absl::string_view> source_info;
409     std::optional<uint64_t> group_id;
410     auto for_each_stat = [&](const XStatVisitor& stat) {
411       if (stat.Type() == StatType::kTfOp) {
412         tf_op_name = stat.StrOrRefValue();
413       } else if (stat.Type() == StatType::kGroupId) {
414         group_id = stat.IntOrUintValue();
415       } else if (stat.Type() == StatType::kSourceInfo) {
416         source_info = stat.StrOrRefValue();
417       }
418     };
419     event.Metadata().ForEachStat(for_each_stat);
420     event.ForEachStat(for_each_stat);
421 
422     if (tf_op_name && !tf_op_name->empty()) {
423       ProcessTfOpEvent(*tf_op_name, event_span, group_id, plane_builder,
424                        tf_name_scope, tf_ops);
425     }
426     if (source_info && !source_info->empty()) {
427       source.ExpandOrAddEvent(
428           *plane_builder.GetOrCreateEventMetadata(*source_info), event_span,
429           group_id);
430     }
431   }
432 
433   RemoveEmptyLines(device_trace);
434 }
435 
436 }  // namespace profiler
437 }  // namespace tensorflow
438