1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/profiler/utils/derived_timeline.h"
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/match.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/profiler/convert/xla_op_utils.h"
31 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
32 #include "tensorflow/core/profiler/utils/gpu_event_stats.h"
33 #include "tensorflow/core/profiler/utils/group_events.h"
34 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
35 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
36 #include "tensorflow/core/profiler/utils/timespan.h"
37 #include "tensorflow/core/profiler/utils/tpu_xplane_utils.h"
38 #include "tensorflow/core/profiler/utils/trace_utils.h"
39 #include "tensorflow/core/profiler/utils/xplane_builder.h"
40 #include "tensorflow/core/profiler/utils/xplane_schema.h"
41 #include "tensorflow/core/profiler/utils/xplane_utils.h"
42 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
43 #include "tensorflow/core/util/stats_calculator.h"
44
45 namespace tensorflow {
46 namespace profiler {
47 namespace {
48
HloModuleEventName(const GpuEventStats & stats)49 inline std::string HloModuleEventName(const GpuEventStats& stats) {
50 return stats.program_id ? HloModuleNameWithProgramId(stats.hlo_module_name,
51 *stats.program_id)
52 : std::string(stats.hlo_module_name);
53 }
54
55 // Returns a prefix that uniquely identifies the HLO module.
HloOpEventPrefix(const GpuEventStats & stats)56 inline std::string HloOpEventPrefix(const GpuEventStats& stats) {
57 return stats.program_id ? absl::StrCat(*stats.program_id, "/")
58 : absl::StrCat(stats.hlo_module_name, "/");
59 }
60
GetOrCreateHloOpEventsMetadata(XPlaneBuilder & plane_builder,const GpuEventStats & stats)61 std::vector<XEventMetadata*> GetOrCreateHloOpEventsMetadata(
62 XPlaneBuilder& plane_builder, const GpuEventStats& stats) {
63 DCHECK(stats.IsXlaOp());
64 DCHECK(!stats.hlo_module_name.empty());
65 std::vector<XEventMetadata*> hlo_op_events_metadata;
66 hlo_op_events_metadata.reserve(stats.hlo_op_names.size());
67 // Prepend an HLO module identifier so HLO operators with the same name but in
68 // different modules have different metadata.
69 std::string hlo_op_event_prefix = HloOpEventPrefix(stats);
70 for (absl::string_view hlo_op_name : stats.hlo_op_names) {
71 XEventMetadata* hlo_op_event_metadata =
72 plane_builder.GetOrCreateEventMetadata(
73 absl::StrCat(hlo_op_event_prefix, hlo_op_name));
74 // Display the HLO name without the module name in tools.
75 if (hlo_op_event_metadata->display_name().empty()) {
76 hlo_op_event_metadata->set_display_name(std::string(hlo_op_name));
77 }
78 hlo_op_events_metadata.push_back(hlo_op_event_metadata);
79 }
80 return hlo_op_events_metadata;
81 }
82
83 } // namespace
84
ProcessTfOpEvent(absl::string_view tf_op_full_name,Timespan event_span,std::optional<int64_t> group_id,XPlaneBuilder & plane_builder,DerivedXLineBuilder & tf_name_scope_line_builder,DerivedXLineBuilder & tf_op_line_builder)85 void ProcessTfOpEvent(absl::string_view tf_op_full_name, Timespan event_span,
86 std::optional<int64_t> group_id,
87 XPlaneBuilder& plane_builder,
88 DerivedXLineBuilder& tf_name_scope_line_builder,
89 DerivedXLineBuilder& tf_op_line_builder) {
90 TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
91 Category category = tf_op.category;
92 if (category == Category::kTensorFlow || category == Category::kJax) {
93 tf_name_scope_line_builder.ExpandOrAddEvents(
94 plane_builder.GetOrCreateEventsMetadata(ParseTfNameScopes(tf_op)),
95 event_span, group_id);
96 }
97 XEventMetadata* tf_op_event_metadata =
98 plane_builder.GetOrCreateEventMetadata(tf_op_full_name);
99 // Set the display name to op_type so that the events of the same op_type have
100 // the same color in the trace viewer.
101 if (tf_op_event_metadata->display_name().empty()) {
102 tf_op_event_metadata->set_display_name(TfOpEventName(tf_op));
103 }
104 tf_op_line_builder.ExpandOrAddEvent(*tf_op_event_metadata, event_span,
105 group_id);
106 }
107
DerivedXEventBuilder(XEventBuilder event,std::optional<int64_t> group_id)108 DerivedXEventBuilder::DerivedXEventBuilder(XEventBuilder event,
109 std::optional<int64_t> group_id)
110 : event_(std::move(event)), group_id_(group_id) {}
111
ShouldExpand(const XEventMetadata & event_metadata,std::optional<int64_t> group_id) const112 bool DerivedXEventBuilder::ShouldExpand(const XEventMetadata& event_metadata,
113 std::optional<int64_t> group_id) const {
114 return event_.MetadataId() == event_metadata.id() && group_id_ == group_id;
115 }
116
Expand(Timespan event_span)117 void DerivedXEventBuilder::Expand(Timespan event_span) {
118 Timespan timespan = event_.GetTimespan();
119 DCHECK_LE(timespan.begin_ps(), event_span.begin_ps());
120 timespan.ExpandToInclude(event_span);
121 event_.SetTimespan(timespan);
122 }
123
DerivedXLineBuilder(XPlaneBuilder * plane,int64_t line_id,absl::string_view name,int64_t timestamp_ns,std::vector<DerivedXLineBuilder * > dependent_lines)124 DerivedXLineBuilder::DerivedXLineBuilder(
125 XPlaneBuilder* plane, int64_t line_id, absl::string_view name,
126 int64_t timestamp_ns, std::vector<DerivedXLineBuilder*> dependent_lines)
127 : group_id_stat_metadata_(
128 plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))),
129 level_stat_metadata_(plane->GetOrCreateStatMetadata("l")),
130 line_(plane->GetOrCreateLine(line_id)),
131 dependent_lines_(std::move(dependent_lines)) {
132 line_.SetName(name);
133 line_.SetTimestampNs(timestamp_ns);
134 }
135
ExpandOrAddEvent(const XEventMetadata & event_metadata,Timespan event_span,std::optional<int64_t> group_id)136 void DerivedXLineBuilder::ExpandOrAddEvent(const XEventMetadata& event_metadata,
137 Timespan event_span,
138 std::optional<int64_t> group_id) {
139 ExpandOrAddLevelEvent(event_metadata, event_span, group_id,
140 /*level=*/0);
141 }
142
ExpandOrAddEvents(const std::vector<XEventMetadata * > & events_metadata_per_level,Timespan event_span,std::optional<int64_t> group_id)143 void DerivedXLineBuilder::ExpandOrAddEvents(
144 const std::vector<XEventMetadata*>& events_metadata_per_level,
145 Timespan event_span, std::optional<int64_t> group_id) {
146 if (events_metadata_per_level.empty()) return;
147 size_t current_nested_level = events_metadata_per_level.size();
148 for (size_t level = 0; level < current_nested_level; ++level) {
149 ExpandOrAddLevelEvent(*events_metadata_per_level[level], event_span,
150 group_id, level);
151 }
152 ResetLastEvents(current_nested_level);
153 }
154
ExpandOrAddLevelEvent(const XEventMetadata & event_metadata,Timespan event_span,std::optional<int64_t> group_id,int level)155 void DerivedXLineBuilder::ExpandOrAddLevelEvent(
156 const XEventMetadata& event_metadata, Timespan event_span,
157 std::optional<int64_t> group_id, int level) {
158 auto& last_event = last_event_by_level_[level];
159 if (last_event && last_event->ShouldExpand(event_metadata, group_id)) {
160 // Expand the last event to cover the given event.
161 last_event->Expand(event_span);
162 } else {
163 // Otherwise, reset the last events lower than or equal to the given level.
164 ResetLastEvents(level);
165 // And create a new event for the given level.
166 XEventBuilder event = line_.AddEvent(event_metadata);
167 event.SetTimespan(event_span);
168 if (group_id.has_value()) {
169 event.AddStatValue(*group_id_stat_metadata_, *group_id);
170 }
171 event.AddStatValue(*level_stat_metadata_, level);
172 last_event.emplace(std::move(event), group_id);
173 }
174 }
175
ResetLastEvents(int level)176 void DerivedXLineBuilder::ResetLastEvents(int level) {
177 for (int i = level, end = last_event_by_level_.size(); i < end; ++i) {
178 last_event_by_level_[i].reset();
179 }
180 if (level == 0) {
181 for (DerivedXLineBuilder* line : dependent_lines_) {
182 line->ResetLastEvents(0);
183 }
184 }
185 }
186
AddGroupMetadataToStepEvents(const GroupMetadataMap & group_metadata_map,XLineBuilder & line)187 void AddGroupMetadataToStepEvents(const GroupMetadataMap& group_metadata_map,
188 XLineBuilder& line) {
189 if (group_metadata_map.empty()) return;
190 XPlaneBuilder* plane = line.Plane();
191 const XStatMetadata* group_id_stat_metadata =
192 plane->GetStatMetadata(GetStatTypeStr(StatType::kGroupId));
193 if (group_id_stat_metadata == nullptr) return;
194 const XStatMetadata* step_name_stat_metadata =
195 plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
196 line.ForEachEvent([&](XEventBuilder event) {
197 const XStat* group_id_stat = event.GetStat(*group_id_stat_metadata);
198 if (group_id_stat != nullptr) {
199 int64_t group_id = group_id_stat->int64_value();
200 if (const GroupMetadata* group_metadata =
201 gtl::FindOrNull(group_metadata_map, group_id)) {
202 // TODO(b/160255693): Change the event name directly.
203 event.AddStatValue(*step_name_stat_metadata, group_metadata->name);
204 }
205 }
206 });
207 }
208
DeriveStepEventsFromGroups(const GroupMetadataMap & group_metadata_map,XPlane * device_trace)209 void DeriveStepEventsFromGroups(const GroupMetadataMap& group_metadata_map,
210 XPlane* device_trace) {
211 XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
212 const XStatMetadata* group_id_stat_metadata =
213 plane_visitor.GetStatMetadataByType(StatType::kGroupId);
214 if (group_id_stat_metadata == nullptr) return;
215 XPlaneBuilder plane_builder(device_trace);
216 int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
217 DerivedXLineBuilder steps(&plane_builder, kThreadIdStepInfo, kStepLineName,
218 start_timestamp_ns, {});
219 for (const XEventVisitor& event_visitor :
220 GetSortedEvents<XEventVisitor>(plane_visitor)) {
221 std::optional<XStatVisitor> group_id_stat =
222 event_visitor.GetStat(StatType::kGroupId, *group_id_stat_metadata);
223 if (group_id_stat.has_value()) {
224 int64_t group_id = group_id_stat->IntValue();
225 steps.ExpandOrAddEvent(
226 *plane_builder.GetOrCreateEventMetadata(absl::StrCat(group_id)),
227 event_visitor.GetTimespan(), group_id);
228 }
229 }
230 AddGroupMetadataToStepEvents(group_metadata_map, steps.Line());
231 }
232
DeriveEventsFromAnnotations(const SymbolResolver & symbol_resolver,XPlane * device_trace)233 void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
234 XPlane* device_trace) {
235 XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
236 XPlaneBuilder plane_builder(device_trace);
237 int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
238 DerivedXLineBuilder tf_ops(&plane_builder, kThreadIdTfOp,
239 kTensorFlowOpLineName, start_timestamp_ns, {});
240 DerivedXLineBuilder tf_name_scope(&plane_builder, kThreadIdTfNameScope,
241 kTensorFlowNameScopeLineName,
242 start_timestamp_ns, {&tf_ops});
243 DerivedXLineBuilder hlo_ops(&plane_builder, kThreadIdHloOp, kXlaOpLineName,
244 start_timestamp_ns, {});
245 DerivedXLineBuilder hlo_modules(&plane_builder, kThreadIdHloModule,
246 kXlaModuleLineName, start_timestamp_ns,
247 {&tf_name_scope, &hlo_ops});
248 DerivedXLineBuilder source(&plane_builder, kThreadIdSource, kSourceLineName,
249 start_timestamp_ns, {});
250
251 for (const XEventVisitor& event :
252 GetSortedEvents<XEventVisitor>(plane_visitor)) {
253 GpuEventStats stats(&event);
254 // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or
255 // allocation events).
256 if (!stats.IsKernel()) continue;
257 Timespan event_span = event.GetTimespan();
258
259 if (!stats.hlo_module_name.empty()) {
260 hlo_modules.ExpandOrAddEvent(
261 *plane_builder.GetOrCreateEventMetadata(HloModuleEventName(stats)),
262 event_span, stats.group_id);
263 }
264
265 if (stats.IsXlaOp()) {
266 hlo_ops.ExpandOrAddEvents(
267 GetOrCreateHloOpEventsMetadata(plane_builder, stats), event_span,
268 stats.group_id);
269 auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name,
270 stats.hlo_op_names.back());
271 if (!symbol.tf_op_name.empty()) {
272 ProcessTfOpEvent(symbol.tf_op_name,
273 event_span, stats.group_id, plane_builder,
274 tf_name_scope, tf_ops);
275 }
276 if (!symbol.source_info.empty()) {
277 source.ExpandOrAddEvent(
278 *plane_builder.GetOrCreateEventMetadata(symbol.source_info),
279 event_span, stats.group_id);
280 }
281 } else if (stats.IsTfOp()) {
282 ProcessTfOpEvent(stats.tf_op_fullname,
283 event_span, stats.group_id, plane_builder, tf_name_scope,
284 tf_ops);
285 }
286 }
287 RemoveEmptyLines(device_trace);
288 }
289
DeriveEventsFromHostTrace(const XPlane * host_trace,const GroupMetadataMap & group_metadata_map,std::vector<XPlane * > device_traces)290 void DeriveEventsFromHostTrace(const XPlane* host_trace,
291 const GroupMetadataMap& group_metadata_map,
292 std::vector<XPlane*> device_traces) {
293 struct GroupLaunchInfo { // "Group" normally means step.
294 Timespan timespan;
295 Stat<uint64_t> stat;
296
297 void AddEventTimespan(Timespan event_span) {
298 if (stat.count() == 0) {
299 timespan = event_span;
300 } else {
301 timespan.ExpandToInclude(event_span);
302 }
303 stat.UpdateStat(event_span.duration_ps());
304 }
305 };
306 using DeviceLaunchInfo =
307 absl::flat_hash_map<int64_t /*group_id*/, GroupLaunchInfo>;
308
309 const int num_devices = device_traces.size();
310 std::vector<DeviceLaunchInfo> per_device_launch_info(num_devices);
311
312 XPlaneVisitor host_plane = CreateTfXPlaneVisitor(host_trace);
313 host_plane.ForEachLine([&](const XLineVisitor& line) {
314 if (IsDerivedThreadId(line.Id())) return;
315 line.ForEachEvent([&](const XEventVisitor& event) {
316 // Filter out API calls for cuEventRecord/cuEventQuery/cuCtxSynchronize
317 // etc for now. TODO: find a better way to filter out only the memcpy and
318 // kernel launch events.
319 if (absl::StartsWith(event.Name(), "cu")) return;
320 LaunchEventStats stats(&event);
321 if (stats.group_id.has_value() && stats.IsLaunch() &&
322 0 <= *stats.device_id && *stats.device_id < num_devices) {
323 // This is a launch event on a known device.
324 GroupLaunchInfo& group_launch_info =
325 per_device_launch_info[*stats.device_id][*stats.group_id];
326 group_launch_info.AddEventTimespan(event.GetTimespan());
327 }
328 });
329 });
330
331 int64_t host_plane_start = GetStartTimestampNs(*host_trace);
332 for (int i = 0; i < num_devices; ++i) {
333 if (per_device_launch_info[i].empty()) continue;
334 int64_t device_plane_start = GetStartTimestampNs(*device_traces[i]);
335
336 XPlaneBuilder device_plane(device_traces[i]);
337 const XStatMetadata& group_id_stat_metadata =
338 *device_plane.GetOrCreateStatMetadata(
339 GetStatTypeStr(StatType::kGroupId));
340 const XStatMetadata& num_launches_stat_metadata =
341 *device_plane.GetOrCreateStatMetadata("num_launches");
342 const XStatMetadata& max_launch_time_us_stat_metadata =
343 *device_plane.GetOrCreateStatMetadata("max_launch_time_us");
344 const XStatMetadata& avg_launch_time_us_stat_metadata =
345 *device_plane.GetOrCreateStatMetadata("avg_launch_time_us");
346
347 XLineBuilder launch_line =
348 device_plane.GetOrCreateLine(kThreadIdKernelLaunch);
349 launch_line.SetName(kKernelLaunchLineName);
350 launch_line.SetTimestampNs(std::min(device_plane_start, host_plane_start));
351 for (const auto& kv : per_device_launch_info[i]) {
352 int64_t group_id = kv.first;
353 const GroupLaunchInfo& group_info = kv.second;
354 if (const GroupMetadata* group_metadata =
355 gtl::FindOrNull(group_metadata_map, group_id)) {
356 XEventBuilder device_event =
357 launch_line.AddEvent(*device_plane.GetOrCreateEventMetadata(
358 absl::StrCat("Launch Stats for ", group_metadata->name)));
359 device_event.SetTimespan(group_info.timespan);
360 device_event.AddStatValue(group_id_stat_metadata, group_id);
361 device_event.AddStatValue(num_launches_stat_metadata,
362 group_info.stat.count());
363 device_event.AddStatValue(max_launch_time_us_stat_metadata,
364 PicoToMicro(group_info.stat.max()));
365 device_event.AddStatValue(avg_launch_time_us_stat_metadata,
366 PicoToMicro(group_info.stat.avg()));
367 }
368 }
369 }
370 }
371
GenerateDerivedTimeLines(const GroupMetadataMap & group_metadata_map,XSpace * space)372 void GenerateDerivedTimeLines(const GroupMetadataMap& group_metadata_map,
373 XSpace* space) {
374 // TODO(profiler): Once we capture HLO protos for xla/gpu, we should use that
375 // to look up tensorflow op name from hlo_module/hlo_op.
376 auto dummy_symbol_resolver =
377 [](absl::optional<uint64_t> program_id, absl::string_view hlo_module,
378 absl::string_view hlo_op) { return Symbol(); };
379 for (XPlane* plane : FindMutablePlanesWithPrefix(space, kGpuPlanePrefix)) {
380 DeriveStepEventsFromGroups(group_metadata_map, plane);
381 DeriveEventsFromAnnotations(dummy_symbol_resolver, plane);
382 }
383 for (XPlane* plane : FindMutableTensorCorePlanes(space)) {
384 DeriveLinesFromStats(plane);
385 SortXPlane(plane);
386 }
387 }
388
DeriveLinesFromStats(XPlane * device_trace)389 void DeriveLinesFromStats(XPlane* device_trace) {
390 XPlaneVisitor plane_visitor = CreateTfXPlaneVisitor(device_trace);
391 XPlaneBuilder plane_builder(device_trace);
392 int64_t start_timestamp_ns = GetStartTimestampNs(*device_trace);
393 DerivedXLineBuilder tf_ops(
394 &plane_builder, tensorflow::profiler::kThreadIdTfOp,
395 tensorflow::profiler::kTensorFlowOpLineName, start_timestamp_ns, {});
396 DerivedXLineBuilder tf_name_scope(
397 &plane_builder, tensorflow::profiler::kThreadIdTfNameScope,
398 tensorflow::profiler::kTensorFlowNameScopeLineName, start_timestamp_ns,
399 {&tf_ops});
400 DerivedXLineBuilder source(
401 &plane_builder, tensorflow::profiler::kThreadIdSource,
402 tensorflow::profiler::kSourceLineName, start_timestamp_ns, {});
403
404 for (const XEventVisitor& event :
405 GetSortedEvents<XEventVisitor>(plane_visitor, true)) {
406 Timespan event_span = event.GetTimespan();
407 std::optional<absl::string_view> tf_op_name;
408 std::optional<absl::string_view> source_info;
409 std::optional<uint64_t> group_id;
410 auto for_each_stat = [&](const XStatVisitor& stat) {
411 if (stat.Type() == StatType::kTfOp) {
412 tf_op_name = stat.StrOrRefValue();
413 } else if (stat.Type() == StatType::kGroupId) {
414 group_id = stat.IntOrUintValue();
415 } else if (stat.Type() == StatType::kSourceInfo) {
416 source_info = stat.StrOrRefValue();
417 }
418 };
419 event.Metadata().ForEachStat(for_each_stat);
420 event.ForEachStat(for_each_stat);
421
422 if (tf_op_name && !tf_op_name->empty()) {
423 ProcessTfOpEvent(*tf_op_name, event_span, group_id, plane_builder,
424 tf_name_scope, tf_ops);
425 }
426 if (source_info && !source_info->empty()) {
427 source.ExpandOrAddEvent(
428 *plane_builder.GetOrCreateEventMetadata(*source_info), event_span,
429 group_id);
430 }
431 }
432
433 RemoveEmptyLines(device_trace);
434 }
435
436 } // namespace profiler
437 } // namespace tensorflow
438