xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/convert/xplane_to_step_events.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/profiler/convert/xplane_to_step_events.h"

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/steps_db.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/event_span.h"
#include "tensorflow/core/profiler/utils/tf_op_utils.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/utils/trace_utils.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
35 
36 namespace tensorflow {
37 namespace profiler {
38 namespace {
39 
IsExplicitHostStepMarker(absl::string_view event_name)40 inline bool IsExplicitHostStepMarker(absl::string_view event_name) {
41   return (absl::StartsWith(event_name, "train") ||
42           absl::StartsWith(event_name, "test") ||
43           absl::StartsWith(event_name, "TraceContext")) &&
44          !absl::StrContains(event_name, "/");
45 }
46 
47 // Returns true if the given event_name should be considered as real computation
48 // on CPU.
IsRealCpuCompute(absl::string_view event_name)49 inline bool IsRealCpuCompute(absl::string_view event_name) {
50   bool not_real = absl::StartsWith(event_name, "EagerExecute") ||
51                   absl::StartsWith(event_name, "EagerLocalExecute") ||
52                   absl::StartsWith(event_name, "EagerKernelExecute") ||
53                   absl::StartsWith(event_name, "FunctionRun") ||
54                   IsExplicitHostStepMarker(event_name);
55   return !not_real;
56 }
57 
ParseNumBytesFromMemcpyDetail(absl::string_view memcpy_detail)58 uint64 ParseNumBytesFromMemcpyDetail(absl::string_view memcpy_detail) {
59   const std::vector<absl::string_view> params =
60       absl::StrSplit(memcpy_detail, absl::ByAnyChar(":\n"));
61 
62   // Processes value pairs.
63   for (uint32 ii = 0; ii < params.size(); ii += 2) {
64     if (params[ii] != "num_bytes") continue;
65     uint64 value = 0;
66     if (absl::SimpleAtoi(params[ii + 1], &value)) return value;
67     break;
68   }
69   return 0ULL;
70 }
71 
ClassifyGpuCompute(absl::string_view event_name,absl::string_view tensor_shapes)72 EventType ClassifyGpuCompute(absl::string_view event_name,
73                              absl::string_view tensor_shapes) {
74   if (tensor_shapes.empty()) {
75     // Deduces the precision from the name.
76     return (absl::StrContains(event_name, "half") ||
77             absl::StrContains(event_name, "fp16"))
78                ? DEVICE_COMPUTE_16
79                : DEVICE_COMPUTE_32;
80   } else {
81     // Deduces the precision from the shapes.
82     return (absl::StrContains(tensor_shapes, "half")) ? DEVICE_COMPUTE_16
83                                                       : DEVICE_COMPUTE_32;
84   }
85 }
86 
ClassifyGpuEvent(absl::string_view event_name,absl::string_view tensor_shapes)87 EventType ClassifyGpuEvent(absl::string_view event_name,
88                            absl::string_view tensor_shapes) {
89   TfOp tf_op = ParseTfOpFullname(event_name);
90   if (IsMemcpyHToDOp(tf_op)) {
91     return HOST_TO_DEVICE;
92   } else if (IsMemcpyDToHOp(tf_op)) {
93     return DEVICE_TO_HOST;
94   } else if (IsMemcpyDToDOp(tf_op)) {
95     return DEVICE_TO_DEVICE;
96   } else if (absl::StartsWithIgnoreCase(event_name, "nccl")) {
97     return DEVICE_COLLECTIVES;
98   } else {
99     return ClassifyGpuCompute(event_name, tensor_shapes);
100   }
101 }
102 
ClassifyCpuEvent(absl::string_view event_name,bool has_device,bool has_correlation_id)103 EventType ClassifyCpuEvent(absl::string_view event_name, bool has_device,
104                            bool has_correlation_id) {
105   TfOp tf_op = ParseTfOpFullname(event_name);
106   if (IsInfeedEnqueueOp(tf_op) || IsMemcpyHToDOp(tf_op)) {
107     return HOST_TO_DEVICE;
108   } else if (IsMemcpyHToHOp(tf_op)) {
109     return HOST_TO_HOST;
110   } else if (has_device && (has_correlation_id ||
111                             absl::StartsWithIgnoreCase(
112                                 event_name, "ExecutorState::Process"))) {
113     // TODO(b/150420972): Separate runtime overhead from actual compute for
114     // CPU-only.
115     return HOST_PREPARE;
116   } else if (absl::StartsWithIgnoreCase(event_name, "IteratorGetNext")) {
117     return HOST_WAIT_INPUT;
118   } else {
119     return HOST_COMPUTE;
120   }
121 }
122 
123 }  // namespace
124 
// Converts one host-thread XLine into per-step events keyed by group_id
// (step number). If `device_step_events` is non-null, host events whose
// group_id does not appear on the device are dropped, so host activity is
// only attributed to steps that actually executed on the device.
StepEvents ConvertHostThreadsXLineToStepEvents(
    const XLineVisitor& line, const StepEvents* device_step_events) {
  StepEvents result;
  line.ForEachEvent([&](const XEventVisitor& event) {
    int64_t correlation_id = -1;  // -1 => no correlation-id stat seen.
    int64_t group_id = -1;        // -1 => event was not grouped into a step.
    absl::string_view step_name;
    // Extracts the stats of interest from this event.
    event.ForEachStat([&](const XStatVisitor& stat) {
      if (!stat.Type().has_value()) return;
      switch (stat.Type().value()) {
        case StatType::kCorrelationId:
          correlation_id = stat.IntValue();
          break;
        case StatType::kGroupId:
          group_id = stat.IntValue();
          break;
        case StatType::kStepName:
          step_name = stat.StrOrRefValue();
          break;
      }
    });
    // Ungrouped events cannot be attributed to any step.
    if (group_id < 0) return;
    // Don't add CPU events when (1) it includes device step events and (2) it
    // doesn't have a device and that the group_id (i.e. step number) already
    // appears on the device. This will filter out all cpu events that do not
    // correspond to any steps executed on the device.
    bool has_device = (device_step_events != nullptr);
    if (has_device && !device_step_events->contains(group_id)) return;
    if (IsExplicitHostStepMarker(event.Name())) {
      result[group_id].AddMarker(
          StepMarker(StepMarkerType::kExplicitHostStepMarker, event.Name(),
                     event.GetTimespan()));
    } else if (!step_name.empty()) {
      // Grouping adds a step_name stat to implicit host step markers.
      result[group_id].AddMarker(
          StepMarker(StepMarkerType::kImplicitHostStepMarker, event.Name(),
                     event.GetTimespan()));
    } else if (IsRealCpuCompute(event.Name())) {
      // Real CPU work: classify it and record it under this step. An event is
      // considered correlated with device work when correlation_id >= 0.
      result[group_id].AddEvent(EventTypeSpan(
          ClassifyCpuEvent(event.Name(), has_device, correlation_id >= 0),
          event.GetTimespan()));
    }
    if (!step_name.empty()) {
      // Propagates the step name reported by grouping onto the step record.
      result[group_id].SetStepName(std::string(step_name));
    }
  });
  return result;
}
173 
// Converts every host-thread line of `host_trace` to step events and merges
// them into a single StepEvents map.
StepEvents ConvertHostThreadsXPlaneToStepEvents(
    const XPlane& host_trace, const StepEvents* device_step_events) {
  StepEvents combined;
  XPlaneVisitor visitor = CreateTfXPlaneVisitor(&host_trace);
  visitor.ForEachLine([&](const XLineVisitor& line) {
    StepEvents per_thread_events =
        ConvertHostThreadsXLineToStepEvents(line, device_step_events);
    CombineStepEvents(per_thread_events, &combined);
  });
  return combined;
}
185 
ConvertDeviceStepInfoToStepMarkers(const XLineVisitor & line)186 StepEvents ConvertDeviceStepInfoToStepMarkers(const XLineVisitor& line) {
187   StepEvents result;
188   line.ForEachEvent([&](const XEventVisitor& event) {
189     if (absl::optional<XStatVisitor> stat = event.GetStat(StatType::kGroupId)) {
190       result[stat->IntValue()].AddMarker(
191           StepMarker(StepMarkerType::kDeviceStepMarker, event.Name(),
192                      event.GetTimespan()));
193     }
194   });
195   return result;
196 }
197 
// Converts one device stream XLine into per-step events. Only events that are
// both correlated with a host launch (correlation_id >= 0) and grouped into a
// step (group_id >= 0) are attributed; collectives and memcpys additionally
// record detail information on the step.
StepEvents ConvertDeviceTraceXLineToStepEvents(const uint64 device_id,
                                               const XLineVisitor& line) {
  StepEvents result;
  line.ForEachEvent([&](const XEventVisitor& event) {
    int64_t correlation_id = -1;  // -1 => no correlation-id stat seen.
    int64_t group_id = -1;        // -1 => event was not grouped into a step.
    absl::string_view tensor_shapes;
    absl::string_view memcpy_details;
    // Extracts the stats of interest from this event.
    event.ForEachStat([&](const XStatVisitor& stat) {
      if (!stat.Type().has_value()) return;
      switch (stat.Type().value()) {
        case StatType::kCorrelationId:
          correlation_id = stat.IntValue();
          break;
        case StatType::kGroupId:
          group_id = stat.IntValue();
          break;
        case StatType::kTensorShapes:
          tensor_shapes = stat.StrOrRefValue();
          break;
        case StatType::kMemcpyDetails:
          memcpy_details = stat.StrOrRefValue();
          break;
      }
    });

    if (correlation_id >= 0 && group_id >= 0) {
      EventType event_type = ClassifyGpuEvent(event.Name(), tensor_shapes);
      EventTypeSpan event_type_span(event_type, event.GetTimespan());
      result[group_id].AddEvent(event_type_span);
      switch (event_type) {
        case DEVICE_COLLECTIVES: {
          AllReduceInfo collective_ops;
          // NOTE(review): start is taken from TimestampPs() but end from
          // EndOffsetPs() — confirm both are in the same time base.
          collective_ops.set_start_time_ps(event.TimestampPs());
          collective_ops.set_end_time_ps(event.EndOffsetPs());
          // TODO(jiesun): figure out how to get size info etc.
          result[group_id].AddCollectiveOpEvent(device_id, collective_ops);
          break;
        }
        case HOST_TO_DEVICE:
        case DEVICE_TO_DEVICE:
        case DEVICE_TO_HOST: {
          // TODO(jiesun): not all memcpy events are grouped, figure out a
          // better way to attribute them to steps.
          uint64 bytes_transferred =
              ParseNumBytesFromMemcpyDetail(memcpy_details);
          result[group_id].AddDeviceMemoryTransferEvent(
              event_type, event.GetTimespan(), bytes_transferred);
          break;
        }
        default:
          // Compute events need no extra detail beyond the EventTypeSpan.
          return;
      }
    }
  });
  return result;
}
255 
ConvertDeviceTraceXPlaneToStepEvents(const XPlane & device_trace)256 StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) {
257   StepEvents device_step_events;
258   XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
259   plane.ForEachLine([&](const XLineVisitor& line) {
260     int64_t line_id = line.Id();
261     if (line_id == kThreadIdStepInfo) {
262       StepEvents step_marker_events = ConvertDeviceStepInfoToStepMarkers(line);
263       CombineStepEvents(step_marker_events, &device_step_events);
264     } else if (IsDerivedThreadId(line_id)) {
265       return;
266     } else {
267       StepEvents stream_step_events =
268           ConvertDeviceTraceXLineToStepEvents(plane.Id(), line);
269       CombineStepEvents(stream_step_events, &device_step_events);
270     }
271   });
272   return device_step_events;
273 }
274 
275 }  // namespace profiler
276 }  // namespace tensorflow
277