xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/convert/op_metrics_db_combiner.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
24 
25 namespace tensorflow {
26 namespace profiler {
27 namespace {
28 
29 using OperationType = OpMetrics::MemoryAccessed::OperationType;
30 
CombinePrecisionStats(const PrecisionStats & src,PrecisionStats * dst)31 void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) {
32   dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps());
33   dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps());
34 }
35 
36 }  // namespace
37 
CopyOpMetricsMetadata(const OpMetrics & src,OpMetrics * dst)38 void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) {
39   DCHECK(dst != nullptr);
40   DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id());
41   DCHECK_EQ(src.name(), dst->name());
42   if (dst->long_name().empty()) {
43     dst->set_long_name(src.long_name());
44   }
45   if (dst->category().empty()) {
46     dst->set_category(src.category());
47   }
48   if (dst->provenance().empty()) {
49     dst->set_provenance(src.provenance());
50   }
51   if (dst->deduplicated_name().empty()) {
52     dst->set_deduplicated_name(src.deduplicated_name());
53   }
54   if (!dst->has_layout() && src.has_layout()) {
55     *dst->mutable_layout() = src.layout();
56   }
57   if (!dst->has_children() && src.has_children()) {
58     *dst->mutable_children() = src.children();
59   }
60 }
61 
CombineOpMetrics(const OpMetrics & src,OpMetrics * dst,bool update_num_cores)62 void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst,
63                       bool update_num_cores) {
64   DCHECK(dst != nullptr);
65   if (dst->occurrences() == 0) {
66     dst->set_min_time_ps(src.min_time_ps());
67   } else {
68     dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps()));
69   }
70   dst->set_is_eager(dst->is_eager() || src.is_eager());
71   dst->set_occurrences(src.occurrences() + dst->occurrences());
72   dst->set_time_ps(src.time_ps() + dst->time_ps());
73   dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps());
74   dst->set_flops(src.flops() + dst->flops());
75   dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed());
76   dst->set_autotuned(dst->autotuned() || src.autotuned());
77   if (update_num_cores) {
78     dst->set_num_cores(src.num_cores() + dst->num_cores());
79   }
80   CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
81                                  dst->mutable_memory_accessed_breakdown());
82   dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
83 }
84 
CombineMemoryAccessedBreakdown(const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> & src,protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> * dst)85 void CombineMemoryAccessedBreakdown(
86     const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
87     protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
88   if (src.empty()) return;
89   absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
90                       OpMetrics_MemoryAccessed*>
91       dst_memory_accessed_map;
92   for (auto& dst_memory_accessed : *dst) {
93     dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
94                              dst_memory_accessed.operation_type()}] =
95         &dst_memory_accessed;
96   }
97   for (const auto& src_memory_accessed : src) {
98     uint64 memory_space = src_memory_accessed.memory_space();
99     OperationType operation_type = src_memory_accessed.operation_type();
100     auto*& dst_memory_accessed =
101         dst_memory_accessed_map[{memory_space, operation_type}];
102     if (dst_memory_accessed == nullptr) {
103       dst_memory_accessed = dst->Add();
104       dst_memory_accessed->set_memory_space(memory_space);
105       dst_memory_accessed->set_operation_type(operation_type);
106     }
107     dst_memory_accessed->set_bytes_accessed(
108         src_memory_accessed.bytes_accessed() +
109         dst_memory_accessed->bytes_accessed());
110   }
111 }
112 
Combine(const OpMetricsDb & src,bool update_num_cores)113 void OpMetricsDbCombiner::Combine(const OpMetricsDb& src,
114                                   bool update_num_cores) {
115   OpMetricsDb* dst = db();
116   dst->set_total_host_infeed_enq_duration_ps(
117       src.total_host_infeed_enq_duration_ps() +
118       dst->total_host_infeed_enq_duration_ps());
119   dst->set_total_host_infeed_enq_start_timestamp_ps_diff(
120       src.total_host_infeed_enq_start_timestamp_ps_diff() +
121       dst->total_host_infeed_enq_start_timestamp_ps_diff());
122   dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps());
123   dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps());
124   CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats());
125 
126   for (const auto& src_metrics : src.metrics_db()) {
127     auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(),
128                                                    src_metrics.name());
129     CopyOpMetricsMetadata(src_metrics, dst_metrics);
130     CombineOpMetrics(src_metrics, dst_metrics, update_num_cores);
131   }
132 }
133 
134 }  // namespace profiler
135 }  // namespace tensorflow
136