1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
17
18 #include <algorithm>
19 #include <utility>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h"
24
25 namespace tensorflow {
26 namespace profiler {
27 namespace {
28
29 using OperationType = OpMetrics::MemoryAccessed::OperationType;
30
CombinePrecisionStats(const PrecisionStats & src,PrecisionStats * dst)31 void CombinePrecisionStats(const PrecisionStats& src, PrecisionStats* dst) {
32 dst->set_compute_16bit_ps(src.compute_16bit_ps() + dst->compute_16bit_ps());
33 dst->set_compute_32bit_ps(src.compute_32bit_ps() + dst->compute_32bit_ps());
34 }
35
36 } // namespace
37
CopyOpMetricsMetadata(const OpMetrics & src,OpMetrics * dst)38 void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst) {
39 DCHECK(dst != nullptr);
40 DCHECK_EQ(src.hlo_module_id(), dst->hlo_module_id());
41 DCHECK_EQ(src.name(), dst->name());
42 if (dst->long_name().empty()) {
43 dst->set_long_name(src.long_name());
44 }
45 if (dst->category().empty()) {
46 dst->set_category(src.category());
47 }
48 if (dst->provenance().empty()) {
49 dst->set_provenance(src.provenance());
50 }
51 if (dst->deduplicated_name().empty()) {
52 dst->set_deduplicated_name(src.deduplicated_name());
53 }
54 if (!dst->has_layout() && src.has_layout()) {
55 *dst->mutable_layout() = src.layout();
56 }
57 if (!dst->has_children() && src.has_children()) {
58 *dst->mutable_children() = src.children();
59 }
60 }
61
CombineOpMetrics(const OpMetrics & src,OpMetrics * dst,bool update_num_cores)62 void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst,
63 bool update_num_cores) {
64 DCHECK(dst != nullptr);
65 if (dst->occurrences() == 0) {
66 dst->set_min_time_ps(src.min_time_ps());
67 } else {
68 dst->set_min_time_ps(std::min(src.min_time_ps(), dst->min_time_ps()));
69 }
70 dst->set_is_eager(dst->is_eager() || src.is_eager());
71 dst->set_occurrences(src.occurrences() + dst->occurrences());
72 dst->set_time_ps(src.time_ps() + dst->time_ps());
73 dst->set_self_time_ps(src.self_time_ps() + dst->self_time_ps());
74 dst->set_flops(src.flops() + dst->flops());
75 dst->set_bytes_accessed(src.bytes_accessed() + dst->bytes_accessed());
76 dst->set_autotuned(dst->autotuned() || src.autotuned());
77 if (update_num_cores) {
78 dst->set_num_cores(src.num_cores() + dst->num_cores());
79 }
80 CombineMemoryAccessedBreakdown(src.memory_accessed_breakdown(),
81 dst->mutable_memory_accessed_breakdown());
82 dst->set_dma_stall_ps(src.dma_stall_ps() + dst->dma_stall_ps());
83 }
84
CombineMemoryAccessedBreakdown(const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> & src,protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed> * dst)85 void CombineMemoryAccessedBreakdown(
86 const protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>& src,
87 protobuf::RepeatedPtrField<OpMetrics_MemoryAccessed>* dst) {
88 if (src.empty()) return;
89 absl::flat_hash_map<std::pair<uint64 /*memory_space*/, OperationType>,
90 OpMetrics_MemoryAccessed*>
91 dst_memory_accessed_map;
92 for (auto& dst_memory_accessed : *dst) {
93 dst_memory_accessed_map[{dst_memory_accessed.memory_space(),
94 dst_memory_accessed.operation_type()}] =
95 &dst_memory_accessed;
96 }
97 for (const auto& src_memory_accessed : src) {
98 uint64 memory_space = src_memory_accessed.memory_space();
99 OperationType operation_type = src_memory_accessed.operation_type();
100 auto*& dst_memory_accessed =
101 dst_memory_accessed_map[{memory_space, operation_type}];
102 if (dst_memory_accessed == nullptr) {
103 dst_memory_accessed = dst->Add();
104 dst_memory_accessed->set_memory_space(memory_space);
105 dst_memory_accessed->set_operation_type(operation_type);
106 }
107 dst_memory_accessed->set_bytes_accessed(
108 src_memory_accessed.bytes_accessed() +
109 dst_memory_accessed->bytes_accessed());
110 }
111 }
112
Combine(const OpMetricsDb & src,bool update_num_cores)113 void OpMetricsDbCombiner::Combine(const OpMetricsDb& src,
114 bool update_num_cores) {
115 OpMetricsDb* dst = db();
116 dst->set_total_host_infeed_enq_duration_ps(
117 src.total_host_infeed_enq_duration_ps() +
118 dst->total_host_infeed_enq_duration_ps());
119 dst->set_total_host_infeed_enq_start_timestamp_ps_diff(
120 src.total_host_infeed_enq_start_timestamp_ps_diff() +
121 dst->total_host_infeed_enq_start_timestamp_ps_diff());
122 dst->set_total_time_ps(src.total_time_ps() + dst->total_time_ps());
123 dst->set_total_op_time_ps(src.total_op_time_ps() + dst->total_op_time_ps());
124 CombinePrecisionStats(src.precision_stats(), dst->mutable_precision_stats());
125
126 for (const auto& src_metrics : src.metrics_db()) {
127 auto* dst_metrics = LookupOrInsertNewOpMetrics(src_metrics.hlo_module_id(),
128 src_metrics.name());
129 CopyOpMetricsMetadata(src_metrics, dst_metrics);
130 CombineOpMetrics(src_metrics, dst_metrics, update_num_cores);
131 }
132 }
133
134 } // namespace profiler
135 } // namespace tensorflow
136