xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/op_metrics_db_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"
17 
18 #include <algorithm>
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
27 #include "tensorflow/core/profiler/utils/math_utils.h"
28 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
29 
30 namespace tensorflow {
31 namespace profiler {
32 
33 const absl::string_view kIdle = "IDLE";
34 
35 namespace {
36 
37 class DeviceTfOpMetricsDbBuilder : public OpMetricsDbBuilder {
38  public:
DeviceTfOpMetricsDbBuilder(OpMetricsDb * db)39   explicit DeviceTfOpMetricsDbBuilder(OpMetricsDb* db)
40       : OpMetricsDbBuilder(db) {}
41 
UpdateTfOpMetricsWithDeviceOpMetrics(absl::string_view tf_op_name,absl::string_view tf_op_type,const OpMetrics & device_op_metrics)42   void UpdateTfOpMetricsWithDeviceOpMetrics(
43       absl::string_view tf_op_name, absl::string_view tf_op_type,
44       const OpMetrics& device_op_metrics) {
45     OpMetrics* tf_op_metrics = OpMetricsDbBuilder::LookupOrInsertNewOpMetrics(
46         /*hlo_module_id=*/0, tf_op_name);
47     if (tf_op_metrics->category().empty()) {
48       tf_op_metrics->set_category(
49           tf_op_type == kUnknownOp ? "Unknown" : std::string(tf_op_type));
50     }
51     tf_op_metrics->set_is_eager(device_op_metrics.is_eager());
52     // The occurrences of a TF-op is the maximum among the occurrences of all
53     // device ops that it contains.
54     tf_op_metrics->set_occurrences(std::max(tf_op_metrics->occurrences(),
55                                             device_op_metrics.occurrences()));
56     tf_op_metrics->set_time_ps(tf_op_metrics->time_ps() +
57                                device_op_metrics.time_ps());
58     tf_op_metrics->set_self_time_ps(tf_op_metrics->self_time_ps() +
59                                     device_op_metrics.self_time_ps());
60     tf_op_metrics->set_flops(tf_op_metrics->flops() +
61                              device_op_metrics.flops());
62     tf_op_metrics->set_bytes_accessed(tf_op_metrics->bytes_accessed() +
63                                       device_op_metrics.bytes_accessed());
64   }
65 };
66 
67 }  // namespace
68 
OpMetricsDbBuilder(OpMetricsDb * db)69 OpMetricsDbBuilder::OpMetricsDbBuilder(OpMetricsDb* db) : db_(db) {
70   DCHECK_NE(db_, nullptr);
71   DCHECK_EQ(db_->metrics_db_size(), 0);
72 }
73 
LookupOrInsertNewOpMetrics(uint64 hlo_module_id,absl::string_view name)74 OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics(
75     uint64 hlo_module_id, absl::string_view name) {
76   OpMetrics*& op_metrics = op_metrics_map_[hlo_module_id][name];
77   if (op_metrics == nullptr) {
78     op_metrics = db_->add_metrics_db();
79     op_metrics->set_hlo_module_id(hlo_module_id);
80     op_metrics->set_name(name.data(), name.size());
81   }
82   return op_metrics;
83 }
84 
IdleTimeRatio(const OpMetricsDb & db)85 double IdleTimeRatio(const OpMetricsDb& db) {
86   return 1.0 - SafeDivide(db.total_op_time_ps(), db.total_time_ps());
87 }
88 
IdleTimePs(const OpMetricsDb & db)89 uint64 IdleTimePs(const OpMetricsDb& db) {
90   DCHECK_GE(db.total_time_ps(), db.total_op_time_ps());
91   return db.total_time_ps() - db.total_op_time_ps();
92 }
93 
SetIdleOp(uint64_t idle_time_ps,OpMetrics & metrics)94 void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics) {
95   metrics.set_name(std::string(kIdle));
96   metrics.set_category(std::string(kIdle));
97   metrics.set_occurrences(0);
98   metrics.set_time_ps(idle_time_ps);
99   metrics.set_self_time_ps(idle_time_ps);
100 }
101 
AddIdleOp(OpMetricsDb & db)102 void AddIdleOp(OpMetricsDb& db) {
103   uint64 idle_time_ps = IdleTimePs(db);
104   SetIdleOp(idle_time_ps, *db.add_metrics_db());
105 }
106 
HostInfeedEnqueueRatio(const OpMetricsDb & db)107 absl::optional<double> HostInfeedEnqueueRatio(const OpMetricsDb& db) {
108   if (db.total_host_infeed_enq_start_timestamp_ps_diff() > 0) {
109     // We use total_host_infeed_enq_start_timestamp_ps_diff to approximate the
110     // total host time.
111     return SafeDivide(db.total_host_infeed_enq_duration_ps(),
112                       db.total_host_infeed_enq_start_timestamp_ps_diff());
113   }
114   return absl::nullopt;
115 }
116 
// Builds a TF-op-level OpMetricsDb from a device-op-level one.
//
// Each device op is routed to a TF-op entry as follows:
//   - idle ops are merged under kIdle when `with_idle` is true and dropped
//     otherwise;
//   - ops with an empty provenance keep their own name and are typed
//     kUnknownOp;
//   - all other ops use the name/type parsed from their provenance.
//
// The result's total op time is copied from the input. Its total time is the
// input's total time when `with_idle` is true, and the input's total op time
// otherwise.
OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb(
    const OpMetricsDb& device_op_metrics_db, bool with_idle) {
  OpMetricsDb tf_op_metrics_db;
  DeviceTfOpMetricsDbBuilder builder(&tf_op_metrics_db);

  for (const OpMetrics& device_op : device_op_metrics_db.metrics_db()) {
    if (IsIdleOp(device_op)) {
      if (with_idle) {
        builder.UpdateTfOpMetricsWithDeviceOpMetrics(kIdle, kIdle, device_op);
      }
      continue;
    }

    if (device_op.provenance().empty()) {
      // Nothing to parse a TF-op from: fall back to the device op's name.
      builder.UpdateTfOpMetricsWithDeviceOpMetrics(device_op.name(),
                                                   kUnknownOp, device_op);
      continue;
    }

    const TfOp parsed = ParseTfOpFullname(device_op.provenance());
    builder.UpdateTfOpMetricsWithDeviceOpMetrics(parsed.name, parsed.type,
                                                 device_op);
  }

  tf_op_metrics_db.set_total_op_time_ps(
      device_op_metrics_db.total_op_time_ps());

  // Without the idle entry, total time collapses to the op time.
  if (with_idle) {
    tf_op_metrics_db.set_total_time_ps(device_op_metrics_db.total_time_ps());
  } else {
    tf_op_metrics_db.set_total_time_ps(
        device_op_metrics_db.total_op_time_ps());
  }

  return tf_op_metrics_db;
}
145 
146 }  // namespace profiler
147 }  // namespace tensorflow
148