xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/strings/str_cat.h"
28 #include "absl/strings/string_view.h"
29 #include "absl/types/optional.h"
30 #include "tensorflow/core/lib/gtl/map_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
34 #include "tensorflow/core/profiler/convert/op_stack.h"
35 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
36 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
37 #include "tensorflow/core/profiler/utils/cost_utils.h"
38 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
39 #include "tensorflow/core/profiler/utils/op_utils.h"
40 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
41 #include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
42 #include "tensorflow/core/profiler/utils/timespan.h"
43 #include "tensorflow/core/profiler/utils/trace_utils.h"
44 #include "tensorflow/core/profiler/utils/xplane_schema.h"
45 #include "tensorflow/core/profiler/utils/xplane_visitor.h"
46 
47 namespace tensorflow {
48 namespace profiler {
49 namespace {
50 
// Symbol id treated as the root of a program. Events carrying this symbol id
// are skipped when per-op device metrics are built.
constexpr uint64_t kRootSymbolId{0};

// Marks whether a TfActivity is the beginning or the end of a host-side TF op.
enum TfActivityType {
  kTfOpBegin,
  kTfOpEnd,
};
55 
56 // Instant activity representing the begin or end of a host-side TF Op.
57 struct TfActivity {
58   // The timestamp in picoseconds when this activity happened.
59   uint64 timestamp_ps;
60   // The ID of this Op.
61   uint32 tf_op_id;
62   // Type of this activity.
63   TfActivityType activity_type;
64   // Full TF op name and type of this activity (backed by XEvent::name).
65   TfOp tf_op;
66   // Whether it is eagerly executed.
67   bool is_eager;
68 };
69 
70 // TF Op metrics stored as element in OpStack.
71 struct TfOpInfo {
TfOpInfotensorflow::profiler::__anon0461001c0111::TfOpInfo72   explicit TfOpInfo(uint64 ts) : start_timestamp_ps(ts) {}
73 
74   // Start timestamp in picoseconds.
75   uint64 start_timestamp_ps;
76   // Children duration in picoseconds.
77   uint64 children_duration_ps = 0;
78 };
79 
80 // Processes a TF-activity on particular core.
ProcessOneTfActivity(const TfActivity & activity,OpStack<TfOpInfo> * tf_op_stack,TfMetricsDbData * tf_metrics_data)81 void ProcessOneTfActivity(const TfActivity& activity,
82                           OpStack<TfOpInfo>* tf_op_stack,
83                           TfMetricsDbData* tf_metrics_data) {
84   uint32 tf_op_id = activity.tf_op_id;
85   switch (activity.activity_type) {
86     case kTfOpBegin: {
87       tf_op_stack->Push(tf_op_id,
88                         absl::make_unique<TfOpInfo>(activity.timestamp_ps));
89       break;
90     }
91     case kTfOpEnd: {
92       std::unique_ptr<TfOpInfo> info = tf_op_stack->Pop(tf_op_id);
93       if (info == nullptr) {
94         // This happens if TraceMes overlap.
95         VLOG(1) << "No begin event found for TF activity id=" << tf_op_id
96                 << " name=" << activity.tf_op.name
97                 << " type=" << activity.tf_op.type;
98         break;
99       }
100       Timespan tf_op_span =
101           PicoSpan(info->start_timestamp_ps, activity.timestamp_ps);
102       tf_metrics_data->tf_metrics_db_builder.EnterOp(
103           activity.tf_op.name, activity.tf_op.type, activity.is_eager,
104           tf_op_span.duration_ps(), info->children_duration_ps);
105       TfOpInfo* parent_info = tf_op_stack->Top();
106       if (parent_info != nullptr) {
107         parent_info->children_duration_ps += tf_op_span.duration_ps();
108       }
109       if (IsInfeedEnqueueOp(activity.tf_op.type)) {
110         tf_metrics_data->tf_metrics_db_builder.EnterHostInfeedEnqueue(
111             tf_op_span);
112       }
113       break;
114     }
115   }
116 }
117 
118 // Processes all TF-activities on the given core.
ProcessTfActivities(std::vector<TfActivity> * tf_activities,TfMetricsDbData * tf_metrics_db_data)119 void ProcessTfActivities(std::vector<TfActivity>* tf_activities,
120                          TfMetricsDbData* tf_metrics_db_data) {
121   if (tf_activities->empty()) return;
122   absl::c_stable_sort(*tf_activities,
123                       [](const TfActivity& a, const TfActivity& b) {
124                         return a.timestamp_ps < b.timestamp_ps;
125                       });
126   OpStack<TfOpInfo> tf_op_stack;
127   for (const auto& tf_activity : *tf_activities) {
128     ProcessOneTfActivity(tf_activity, &tf_op_stack, tf_metrics_db_data);
129   }
130   SetTotalTimePs(
131       tf_metrics_db_data->tf_metrics_db,
132       tf_activities->back().timestamp_ps - tf_activities->front().timestamp_ps);
133 }
134 
CollectTfActivities(const XLineVisitor & line,const absl::flat_hash_map<int64_t,TfOp> & tf_ops,std::vector<TfActivity> * tf_activities)135 void CollectTfActivities(const XLineVisitor& line,
136                          const absl::flat_hash_map<int64_t, TfOp>& tf_ops,
137                          std::vector<TfActivity>* tf_activities) {
138   uint32 tf_op_id = 0;
139   tf_activities->reserve(line.NumEvents() * 2);
140   line.ForEachEvent([&tf_ops, &tf_op_id,
141                      &tf_activities](const XEventVisitor& event) {
142     const TfOp* tf_op = gtl::FindOrNull(tf_ops, event.Id());
143     if (tf_op != nullptr) {
144       ++tf_op_id;
145       bool is_eager = false;
146       if (absl::optional<XStatVisitor> stat =
147               event.GetStat(StatType::kIsEager)) {
148         is_eager = stat->IntValue();
149       }
150       Timespan span = event.GetTimespan();
151       tf_activities->push_back(
152           {span.begin_ps(), tf_op_id, kTfOpBegin, *tf_op, is_eager});
153       tf_activities->push_back(
154           {span.end_ps(), tf_op_id, kTfOpEnd, *tf_op, is_eager});
155     }
156   });
157 }
158 
// Identifies an HLO op in a device trace.
// - program_id: the program the op belongs to.
// - symbol_id: the op's symbol id within that program.
// Either field stays empty when the corresponding stat is absent.
struct OpKey {
  std::optional<uint64_t> program_id{};
  std::optional<uint64_t> symbol_id{};
};
GetOpKeyFromHloEventMetadata(const XEventMetadataVisitor & hlo_event_metadata)163 OpKey GetOpKeyFromHloEventMetadata(
164     const XEventMetadataVisitor& hlo_event_metadata) {
165   OpKey op_key;
166   hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) {
167     if (stat.Type().has_value()) {
168       switch (static_cast<StatType>(*stat.Type())) {
169         case StatType::kProgramId:
170           op_key.program_id = stat.IntOrUintValue();
171           break;
172         case StatType::kSymbolId:
173           op_key.symbol_id = stat.IntOrUintValue();
174           break;
175         default:
176           break;
177       }
178     }
179   });
180   return op_key;
181 }
182 
SetOpMetadataFromHloEventMetadata(const XEventMetadataVisitor & hlo_event_metadata,OpMetrics * op_metrics)183 void SetOpMetadataFromHloEventMetadata(
184     const XEventMetadataVisitor& hlo_event_metadata, OpMetrics* op_metrics) {
185   if (hlo_event_metadata.HasDisplayName()) {
186     op_metrics->set_name(std::string(hlo_event_metadata.DisplayName()));
187     op_metrics->set_long_name(std::string(hlo_event_metadata.Name()));
188   } else {
189     op_metrics->set_name(std::string(hlo_event_metadata.Name()));
190   }
191   hlo_event_metadata.ForEachStat([&](const XStatVisitor& stat) {
192     if (stat.Type().has_value()) {
193       switch (static_cast<StatType>(*stat.Type())) {
194         case StatType::kHloCategory:
195           op_metrics->set_category(std::string(stat.StrOrRefValue()));
196           break;
197         case StatType::kTfOpName:
198           op_metrics->set_provenance(std::string(stat.StrOrRefValue()));
199           break;
200         case StatType::kFlops:
201           op_metrics->set_flops(stat.IntOrUintValue());
202           break;
203         case StatType::kBytesAccessed:
204           op_metrics->set_bytes_accessed(stat.IntOrUintValue());
205           break;
206         default:
207           break;
208       }
209     }
210   });
211   hlo_event_metadata.ForEachChild(
212       [&](const XEventMetadataVisitor& child_hlo_event_metadata) {
213         OpMetrics* child = op_metrics->mutable_children()->add_metrics_db();
214         child->set_occurrences(1);
215         SetOpMetadataFromHloEventMetadata(child_hlo_event_metadata, child);
216       });
217 }
218 
SetOpMetricsFromHloEvent(const XEventVisitor & hlo_event,OpMetrics * op_metrics)219 void SetOpMetricsFromHloEvent(const XEventVisitor& hlo_event,
220                               OpMetrics* op_metrics) {
221   uint64_t duration_ps = hlo_event.DurationPs();
222   uint64_t min_duration_ps = duration_ps;
223   uint64_t self_duration_ps = duration_ps;
224   uint64_t dma_stall_ps = 0;
225   hlo_event.ForEachStat([&](const XStatVisitor& stat) {
226     if (!stat.Type()) return;
227     switch (static_cast<StatType>(*stat.Type())) {
228       case StatType::kMinDurationPs:
229         min_duration_ps = stat.IntValue();
230         break;
231       case StatType::kSelfDurationPs:
232         self_duration_ps = stat.IntValue();
233         break;
234       case StatType::kDmaStallDurationPs:
235         dma_stall_ps = stat.IntValue();
236         break;
237       default:
238         break;
239     }
240   });
241   if (op_metrics->occurrences() == 0) {
242     SetOpMetadataFromHloEventMetadata(hlo_event.Metadata(), op_metrics);
243     op_metrics->set_occurrences(hlo_event.NumOccurrences());
244     op_metrics->set_time_ps(duration_ps);
245     op_metrics->set_min_time_ps(min_duration_ps);
246     op_metrics->set_self_time_ps(self_duration_ps);
247     op_metrics->set_dma_stall_ps(dma_stall_ps);
248   } else {
249     op_metrics->set_occurrences(op_metrics->occurrences() +
250                                 hlo_event.NumOccurrences());
251     op_metrics->set_time_ps(op_metrics->time_ps() + duration_ps);
252     op_metrics->set_min_time_ps(
253         std::min<uint64_t>(op_metrics->min_time_ps(), min_duration_ps));
254     op_metrics->set_self_time_ps(op_metrics->self_time_ps() + self_duration_ps);
255     op_metrics->set_dma_stall_ps(op_metrics->dma_stall_ps() + dma_stall_ps);
256   }
257 }
258 
259 }  // namespace
260 
CollectTfOpsFromHostThreadsXPlane(const XPlane & host_trace)261 absl::flat_hash_map<int64_t, TfOp> CollectTfOpsFromHostThreadsXPlane(
262     const XPlane& host_trace) {
263   absl::flat_hash_map<int64_t, TfOp> tf_ops;
264   for (const auto& id_metadata : host_trace.event_metadata()) {
265     const XEventMetadata& metadata = id_metadata.second;
266     // On the host, we have added some user-specified TraceMe's in addition to
267     // the TraceMe's added to every TensorFlow op by the system. These
268     // user-inserted TraceMe's have "unknown" type. We don't count them in
269     // Tf-stats.
270     TfOp tf_op = ParseTfOpFullname(metadata.name());
271     if (tf_op.category != Category::kUnknown) {
272       tf_ops.try_emplace(metadata.id(), tf_op);
273     }
274   }
275   return tf_ops;
276 }
277 
ConvertHostThreadsXLineToTfMetricsDbData(const XLineVisitor & line,const absl::flat_hash_map<int64_t,TfOp> & tf_ops)278 TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData(
279     const XLineVisitor& line,
280     const absl::flat_hash_map<int64_t, TfOp>& tf_ops) {
281   TfMetricsDbData tf_metrics_db_data;
282   if (!tf_ops.empty()) {
283     std::vector<TfActivity> tf_activities;
284     CollectTfActivities(line, tf_ops, &tf_activities);
285     ProcessTfActivities(&tf_activities, &tf_metrics_db_data);
286   }
287   return tf_metrics_db_data;
288 }
289 
ConsumeTfMetricsDbData(TfMetricsDbData src,OpMetricsDbCombiner * dst)290 void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst) {
291   AddIdleOp(src.tf_metrics_db);
292   // Host OpMetricsDb does not need to update the number of cores a certain op
293   // occurs.
294   dst->Combine(src.tf_metrics_db, /*update_num_cores=*/false);
295   src.tf_metrics_db.Clear();
296 }
297 
ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane & host_trace)298 OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace) {
299   absl::flat_hash_map<int64_t, TfOp> tf_ops =
300       CollectTfOpsFromHostThreadsXPlane(host_trace);
301   OpMetricsDb result;
302   OpMetricsDbCombiner combiner(&result);
303   XPlaneVisitor plane = CreateTfXPlaneVisitor(&host_trace);
304   plane.ForEachLine([&tf_ops, &combiner](const XLineVisitor& line) {
305     ConsumeTfMetricsDbData(
306         ConvertHostThreadsXLineToTfMetricsDbData(line, tf_ops), &combiner);
307   });
308   return result;
309 }
310 
ConvertTpuDeviceTraceXPlaneToOpMetricsDb(const XPlane & device_trace)311 OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb(
312     const XPlane& device_trace) {
313   OpMetricsDb result;
314   XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
315   using OpMetricBySymbol = absl::flat_hash_map<int64_t, OpMetrics>;
316   absl::flat_hash_map<int64_t, OpMetricBySymbol> flat_op_metric;
317   plane.ForEachLine([&](const XLineVisitor& line) {
318     line.ForEachEvent([&](const XEventVisitor& event) {
319       OpKey key = GetOpKeyFromHloEventMetadata(event.Metadata());
320       if (!key.program_id.has_value() || !key.symbol_id.has_value()) return;
321       OpMetricBySymbol& op_metric_by_symbol =
322           flat_op_metric[key.program_id.value()];
323       if (key.symbol_id != kRootSymbolId) {
324         OpMetrics& op_metrics = op_metric_by_symbol[key.symbol_id.value()];
325         SetOpMetricsFromHloEvent(event, &op_metrics);
326       }
327     });
328   });
329 
330   for (auto& [program_id, op_metric_by_symbol] : flat_op_metric) {
331     for (auto& [symbol_id, op_metrics] : op_metric_by_symbol) {
332       result.add_metrics_db()->Swap(&op_metrics);
333     }
334   }
335   AddIdleOp(result);
336   return result;
337 }
338 
ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane & device_trace)339 OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace) {
340   OpMetricsDb result;
341   DeviceOpMetricsDbBuilder device_op_metrics_db_builder(&result);
342 
343   int64_t first_op_offset_ps = kint64max;
344   int64_t last_op_offset_ps = 0;
345 
346   TfOpRoofLineCostEstimator op_level_cost_estimator;
347   XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
348   plane.ForEachLine([&](const XLineVisitor& line) {
349     if (IsDerivedThreadId(line.Id())) return;
350     line.ForEachEvent([&](const XEventVisitor& event) {
351       first_op_offset_ps = std::min(first_op_offset_ps, event.OffsetPs());
352       last_op_offset_ps = std::max(last_op_offset_ps, event.EndOffsetPs());
353 
354       absl::string_view tf_op_full_name;
355       bool is_eager = false;
356       event.ForEachStat([&](const XStatVisitor& stat) {
357         if (stat.Type() == StatType::kTfOp) {
358           tf_op_full_name = stat.StrOrRefValue();
359         } else if (stat.Type() == StatType::kIsEager) {
360           is_eager = stat.IntValue();
361         }
362       });
363       if (tf_op_full_name.empty()) return;
364       TfOp tf_op = ParseTfOpFullname(tf_op_full_name);
365       TfOpRoofLineCostEstimator::OpRoofLineStats costs;
366       if (tf_op.category != Category::kUnknown) {
367         costs = op_level_cost_estimator.Predict(event);
368       }
369       device_op_metrics_db_builder.EnterOp(
370           /*program_id=*/0, absl::StrCat(tf_op.name, "/", event.Name()),
371           tf_op.type, tf_op_full_name, is_eager,
372           /*occurrences=*/1, event.DurationPs(),
373           /*children_time_ps=*/0, costs.flops, costs.bytes_accessed);
374     });
375   });
376   SetTotalTimePs(
377       result, last_op_offset_ps ? last_op_offset_ps - first_op_offset_ps : 0);
378   AddIdleOp(result);
379   return result;
380 }
381 
382 }  // namespace profiler
383 }  // namespace tensorflow
384