1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
34 #include "tensorflow/core/profiler/convert/op_stack.h"
35 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
36 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
37 #include "tensorflow/core/profiler/utils/cost_utils.h"
38 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
39 #include "tensorflow/core/profiler/utils/op_utils.h"
40 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
41 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
42 #include "tensorflow/core/profiler/utils/timespan.h"
43 #include "tensorflow/core/profiler/utils/trace_utils.h"
44 #include "tensorflow/core/profiler/utils/xplane_schema.h"
45 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
46
47 namespace tensorflow {
48 namespace profiler {
49 namespace {
50
// Symbol id treated as the root symbol: device events carrying this symbol id
// are not turned into per-op metrics by ConvertTpuDeviceTraceXPlaneToOpMetricsDb.
constexpr uint64_t kRootSymbolId{0};
52
// Says whether a TfActivity marks the start or the finish of a TF op.
enum TfActivityType {
  kTfOpBegin,  // The op starts at the activity's timestamp.
  kTfOpEnd,    // The op finishes at the activity's timestamp.
};
55
56 // Instant activity representing the begin or end of a host-side TF Op.
57 struct TfActivity {
58 // The timestamp in picoseconds when this activity happened.
59 uint64 timestamp_ps;
60 // The ID of this Op.
61 uint32 tf_op_id;
62 // Type of this activity.
63 TfActivityType activity_type;
64 // Full TF op name and type of this activity (backed by XEvent::name).
65 TfOp tf_op;
66 // Whether it is eagerly executed.
67 bool is_eager;
68 };
69
// Per-op state kept on the OpStack while a TF op is open: when it started and
// how much time its nested (child) ops have accumulated so far.
struct TfOpInfo {
  explicit TfOpInfo(uint64_t start_ps) : start_timestamp_ps(start_ps) {}

  uint64_t start_timestamp_ps;        // Start time, in picoseconds.
  uint64_t children_duration_ps = 0;  // Summed child-op time, in picoseconds.
};
79
80 // Processes a TF-activity on particular core.
ProcessOneTfActivity(const TfActivity & activity,OpStack<TfOpInfo> * tf_op_stack,TfMetricsDbData * tf_metrics_data)81 void ProcessOneTfActivity(const TfActivity& activity,
82 OpStack<TfOpInfo>* tf_op_stack,
83 TfMetricsDbData* tf_metrics_data) {
84 uint32 tf_op_id = activity.tf_op_id;
85 switch (activity.activity_type) {
86 case kTfOpBegin: {
87 tf_op_stack->Push(tf_op_id,
88 absl::make_unique<TfOpInfo>(activity.timestamp_ps));
89 break;
90 }
91 case kTfOpEnd: {
92 std::unique_ptr<TfOpInfo> info = tf_op_stack->Pop(tf_op_id);
93 if (info == nullptr) {
94 // This happens if TraceMes overlap.
95 VLOG(1) << "No begin event found for TF activity id=" << tf_op_id
96 << " name=" << activity.tf_op.name
97 << " type=" << activity.tf_op.type;
98 break;
99 }
100 Timespan tf_op_span =
101 PicoSpan(info->start_timestamp_ps, activity.timestamp_ps);
102 tf_metrics_data->tf_metrics_db_builder.EnterOp(
103 activity.tf_op.name, activity.tf_op.type, activity.is_eager,
104 tf_op_span.duration_ps(), info->children_duration_ps);
105 TfOpInfo* parent_info = tf_op_stack->Top();
106 if (parent_info != nullptr) {
107 parent_info->children_duration_ps += tf_op_span.duration_ps();
108 }
109 if (IsInfeedEnqueueOp(activity.tf_op.type)) {
110 tf_metrics_data->tf_metrics_db_builder.EnterHostInfeedEnqueue(
111 tf_op_span);
112 }
113 break;
114 }
115 }
116 }
117
118 // Processes all TF-activities on the given core.
ProcessTfActivities(std::vector<TfActivity> * tf_activities,TfMetricsDbData * tf_metrics_db_data)119 void ProcessTfActivities(std::vector<TfActivity>* tf_activities,
120 TfMetricsDbData* tf_metrics_db_data) {
121 if (tf_activities->empty()) return;
122 absl::c_stable_sort(*tf_activities,
123 [](const TfActivity& a, const TfActivity& b) {
124 return a.timestamp_ps < b.timestamp_ps;
125 });
126 OpStack<TfOpInfo> tf_op_stack;
127 for (const auto& tf_activity : *tf_activities) {
128 ProcessOneTfActivity(tf_activity, &tf_op_stack, tf_metrics_db_data);
129 }
130 SetTotalTimePs(
131 tf_metrics_db_data->tf_metrics_db,
132 tf_activities->back().timestamp_ps - tf_activities->front().timestamp_ps);
133 }
134
CollectTfActivities(const XLineVisitor & line,const absl::flat_hash_map<int64_t,TfOp> & tf_ops,std::vector<TfActivity> * tf_activities)135 void CollectTfActivities(const XLineVisitor& line,
136 const absl::flat_hash_map<int64_t, TfOp>& tf_ops,
137 std::vector<TfActivity>* tf_activities) {
138 uint32 tf_op_id = 0;
139 tf_activities->reserve(line.NumEvents() * 2);
140 line.ForEachEvent([&tf_ops, &tf_op_id,
141 &tf_activities](const XEventVisitor& event) {
142 const TfOp* tf_op = gtl::FindOrNull(tf_ops, event.Id());
143 if (tf_op != nullptr) {
144 ++tf_op_id;
145 bool is_eager = false;
146 if (absl::optional<XStatVisitor> stat =
147 event.GetStat(StatType::kIsEager)) {
148 is_eager = stat->IntValue();
149 }
150 Timespan span = event.GetTimespan();
151 tf_activities->push_back(
152 {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
153 tf_activities->push_back(
154 {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager});
155 }
156 });
157 }
158
// Identifies an HLO op by program id and symbol id. A field is left empty when
// the corresponding stat is missing from the event metadata.
struct OpKey {
  std::optional<uint64_t> program_id = std::nullopt;
  std::optional<uint64_t> symbol_id = std::nullopt;
};
GetOpKeyFromHloEventMetadata(const XEventMetadataVisitor & hlo_event_metadata)163 OpKey GetOpKeyFromHloEventMetadata(
164 const XEventMetadataVisitor& hlo_event_metadata) {
165 OpKey op_key;
166 hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) {
167 if (stat.Type().has_value()) {
168 switch (static_cast<StatType>(*stat.Type())) {
169 case StatType::kProgramId:
170 op_key.program_id = stat.IntOrUintValue();
171 break;
172 case StatType::kSymbolId:
173 op_key.symbol_id = stat.IntOrUintValue();
174 break;
175 default:
176 break;
177 }
178 }
179 });
180 return op_key;
181 }
182
SetOpMetadataFromHloEventMetadata(const XEventMetadataVisitor & hlo_event_metadata,OpMetrics * op_metrics)183 void SetOpMetadataFromHloEventMetadata(
184 const XEventMetadataVisitor& hlo_event_metadata, OpMetrics* op_metrics) {
185 if (hlo_event_metadata.HasDisplayName()) {
186 op_metrics->set_name(std::string(hlo_event_metadata.DisplayName()));
187 op_metrics->set_long_name(std::string(hlo_event_metadata.Name()));
188 } else {
189 op_metrics->set_name(std::string(hlo_event_metadata.Name()));
190 }
191 hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) {
192 if (stat.Type().has_value()) {
193 switch (static_cast<StatType>(*stat.Type())) {
194 case StatType::kHloCategory:
195 op_metrics->set_category(std::string(stat.StrOrRefValue()));
196 break;
197 case StatType::kTfOpName:
198 op_metrics->set_provenance(std::string(stat.StrOrRefValue()));
199 break;
200 case StatType::kFlops:
201 op_metrics->set_flops(stat.IntOrUintValue());
202 break;
203 case StatType::kBytesAccessed:
204 op_metrics->set_bytes_accessed(stat.IntOrUintValue());
205 break;
206 default:
207 break;
208 }
209 }
210 });
211 hlo_event_metadata.ForEachChild(
212 [&](const XEventMetadataVisitor& child_hlo_event_metadata) {
213 OpMetrics* child = op_metrics->mutable_children()->add_metrics_db();
214 child->set_occurrences(1);
215 SetOpMetadataFromHloEventMetadata(child_hlo_event_metadata, child);
216 });
217 }
218
SetOpMetricsFromHloEvent(const XEventVisitor & hlo_event,OpMetrics * op_metrics)219 void SetOpMetricsFromHloEvent(const XEventVisitor& hlo_event,
220 OpMetrics* op_metrics) {
221 uint64_t duration_ps = hlo_event.DurationPs();
222 uint64_t min_duration_ps = duration_ps;
223 uint64_t self_duration_ps = duration_ps;
224 uint64_t dma_stall_ps = 0;
225 hlo_event.ForEachStat([&](const XStatVisitor& stat) {
226 if (!stat.Type()) return;
227 switch (static_cast<StatType>(*stat.Type())) {
228 case StatType::kMinDurationPs:
229 min_duration_ps = stat.IntValue();
230 break;
231 case StatType::kSelfDurationPs:
232 self_duration_ps = stat.IntValue();
233 break;
234 case StatType::kDmaStallDurationPs:
235 dma_stall_ps = stat.IntValue();
236 break;
237 default:
238 break;
239 }
240 });
241 if (op_metrics->occurrences() == 0) {
242 SetOpMetadataFromHloEventMetadata(hlo_event.Metadata(), op_metrics);
243 op_metrics->set_occurrences(hlo_event.NumOccurrences());
244 op_metrics->set_time_ps(duration_ps);
245 op_metrics->set_min_time_ps(min_duration_ps);
246 op_metrics->set_self_time_ps(self_duration_ps);
247 op_metrics->set_dma_stall_ps(dma_stall_ps);
248 } else {
249 op_metrics->set_occurrences(op_metrics->occurrences() +
250 hlo_event.NumOccurrences());
251 op_metrics->set_time_ps(op_metrics->time_ps() + duration_ps);
252 op_metrics->set_min_time_ps(
253 std::min<uint64_t>(op_metrics->min_time_ps(), min_duration_ps));
254 op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_duration_ps);
255 op_metrics->set_dma_stall_ps(op_metrics->dma_stall_ps() + dma_stall_ps);
256 }
257 }
258
259 } // namespace
260
CollectTfOpsFromHostThreadsXPlane(const XPlane & host_trace)261 absl::flat_hash_map<int64_t, TfOp> CollectTfOpsFromHostThreadsXPlane(
262 const XPlane& host_trace) {
263 absl::flat_hash_map<int64_t, TfOp> tf_ops;
264 for (const auto& id_metadata : host_trace.event_metadata()) {
265 const XEventMetadata& metadata = id_metadata.second;
266 // On the host, we have added some user-specified TraceMe's in addition to
267 // the TraceMe's added to every TensorFlow op by the system. These
268 // user-inserted TraceMe's have "unknown" type. We don't count them in
269 // Tf-stats.
270 TfOp tf_op = ParseTfOpFullname(metadata.name());
271 if (tf_op.category != Category::kUnknown) {
272 tf_ops.try_emplace(metadata.id(), tf_op);
273 }
274 }
275 return tf_ops;
276 }
277
ConvertHostThreadsXLineToTfMetricsDbData(const XLineVisitor & line,const absl::flat_hash_map<int64_t,TfOp> & tf_ops)278 TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData(
279 const XLineVisitor& line,
280 const absl::flat_hash_map<int64_t, TfOp>& tf_ops) {
281 TfMetricsDbData tf_metrics_db_data;
282 if (!tf_ops.empty()) {
283 std::vector<TfActivity> tf_activities;
284 CollectTfActivities(line, tf_ops, &tf_activities);
285 ProcessTfActivities(&tf_activities, &tf_metrics_db_data);
286 }
287 return tf_metrics_db_data;
288 }
289
ConsumeTfMetricsDbData(TfMetricsDbData src,OpMetricsDbCombiner * dst)290 void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) {
291 AddIdleOp(src.tf_metrics_db);
292 // Host OpMetricsDb does not need to update the number of cores a certain op
293 // occurs.
294 dst->Combine(src.tf_metrics_db, /*update_num_cores=*/false);
295 src.tf_metrics_db.Clear();
296 }
297
ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane & host_trace)298 OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) {
299 absl::flat_hash_map<int64_t, TfOp> tf_ops =
300 CollectTfOpsFromHostThreadsXPlane(host_trace);
301 OpMetricsDb result;
302 OpMetricsDbCombiner combiner(&result);
303 XPlaneVisitor plane = CreateTfXPlaneVisitor(&host_trace);
304 plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) {
305 ConsumeTfMetricsDbData(
306 ConvertHostThreadsXLineToTfMetricsDbData(line, tf_ops), &combiner);
307 });
308 return result;
309 }
310
ConvertTpuDeviceTraceXPlaneToOpMetricsDb(const XPlane & device_trace)311 OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb(
312 const XPlane& device_trace) {
313 OpMetricsDb result;
314 XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
315 using OpMetricBySymbol = absl::flat_hash_map<int64_t, OpMetrics>;
316 absl::flat_hash_map<int64_t, OpMetricBySymbol> flat_op_metric;
317 plane.ForEachLine([&](const XLineVisitor& line) {
318 line.ForEachEvent([&](const XEventVisitor& event) {
319 OpKey key = GetOpKeyFromHloEventMetadata(event.Metadata());
320 if (!key.program_id.has_value() || !key.symbol_id.has_value()) return;
321 OpMetricBySymbol& op_metric_by_symbol =
322 flat_op_metric[key.program_id.value()];
323 if (key.symbol_id != kRootSymbolId) {
324 OpMetrics& op_metrics = op_metric_by_symbol[key.symbol_id.value()];
325 SetOpMetricsFromHloEvent(event, &op_metrics);
326 }
327 });
328 });
329
330 for (auto& [program_id, op_metric_by_symbol] : flat_op_metric) {
331 for (auto& [symbol_id, op_metrics] : op_metric_by_symbol) {
332 result.add_metrics_db()->Swap(&op_metrics);
333 }
334 }
335 AddIdleOp(result);
336 return result;
337 }
338
ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane & device_trace)339 OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) {
340 OpMetricsDb result;
341 DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result);
342
343 int64_t first_op_offset_ps = kint64max;
344 int64_t last_op_offset_ps = 0;
345
346 TfOpRoofLineCostEstimator op_level_cost_estimator;
347 XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
348 plane.ForEachLine([&](const XLineVisitor& line) {
349 if (IsDerivedThreadId(line.Id())) return;
350 line.ForEachEvent([&](const XEventVisitor& event) {
351 first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs());
352 last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs());
353
354 absl::string_view tf_op_full_name;
355 bool is_eager = false;
356 event.ForEachStat([&](const XStatVisitor& stat) {
357 if (stat.Type() == StatType::kTfOp) {
358 tf_op_full_name = stat.StrOrRefValue();
359 } else if (stat.Type() == StatType::kIsEager) {
360 is_eager = stat.IntValue();
361 }
362 });
363 if (tf_op_full_name.empty()) return;
364 TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
365 TfOpRoofLineCostEstimator::OpRoofLineStats costs;
366 if (tf_op.category != Category::kUnknown) {
367 costs = op_level_cost_estimator.Predict(event);
368 }
369 device_op_metrics_db_builder.EnterOp(
370 /*program_id=*/0, absl::StrCat(tf_op.name, "/", event.Name()),
371 tf_op.type, tf_op_full_name, is_eager,
372 /*occurrences=*/1, event.DurationPs(),
373 /*children_time_ps=*/0, costs.flops, costs.bytes_accessed);
374 });
375 });
376 SetTotalTimePs(
377 result, last_op_offset_ps ? last_op_offset_ps - first_op_offset_ps : 0);
378 AddIdleOp(result);
379 return result;
380 }
381
382 } // namespace profiler
383 } // namespace tensorflow
384