/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"

#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"

namespace tensorflow {
namespace profiler {

namespace {

// The maximum number of kernels displayed on the Kernel Stats page.
const int kMaxNumOfKernels = 1000;

// A list of patterns used to determine whether a kernel uses Tensor Cores.
// A kernel uses Tensor Cores if its name contains any of these patterns.
// Example kernel names: volta_h884gemm, turing_fp16_s1688cudnn_fp16.
constexpr absl::string_view kTensorCoreKernelNamePatterns[] = {
    "16816",
    "c1688",
    "conv1x1",
    "conv2d_c1_k1",
    "dgrad_1x1_stride_2x2",
    "direct_group",
    "first_layer_wgrad_kernel",
    "h1688",
    "h884",
    "hmma",
    "i16832",
    "i8816",
    "s884",
    "s1688",
    "xmma_gemm",
    "xmma_implicit_gemm",
    "xmma_sparse_conv",
    "xmma_sparse_gemm",
    "xmma_warp_specialized_implicit_gemm"};

}  // namespace

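// Parses a kernel launch description in <xstat_kernel_details> into <kernel>.
// The input is a whitespace-separated list of "key:value" tokens, e.g.
// (values illustrative):
//   regs:32 static_shared:0 dynamic_shared:16384 grid:2000,1,1 block:256,1,1 occ_pct:50.0
// Unrecognized tokens are skipped; grid and block dims default to 1,1,1.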
void ParseKernelLaunchParams(absl::string_view xstat_kernel_details,
                             KernelReport* kernel) {
  const std::vector<absl::string_view> params =
      absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(" \n"));

  constexpr uint32 kNumDimensions = 3;
  for (uint32 dim = 0; dim < kNumDimensions; ++dim) {
    kernel->add_block_dim(1);
    kernel->add_grid_dim(1);
  }

  // Process tokens.
  for (const auto& param : params) {
    const std::vector<absl::string_view> key_value = absl::StrSplit(param, ':');
    if (key_value.size() != 2) {
      // Unrecognized token.
      continue;
    }
    absl::string_view key = key_value[0];
    absl::string_view value_str = key_value[1];
    uint32 value = 0;
    double pct = 0.0;
    // Cases that consume a pair of tokens "key:value".
    if (key == "regs" && absl::SimpleAtoi(value_str, &value)) {
      kernel->set_registers_per_thread(value);
    } else if (key == "static_shared" && absl::SimpleAtoi(value_str, &value)) {
      kernel->set_static_shmem_bytes(value);
    } else if (key == "dynamic_shared" && absl::SimpleAtoi(value_str, &value)) {
      kernel->set_dynamic_shmem_bytes(value);
    } else if (key == "block") {
      const std::vector<absl::string_view>& block =
          absl::StrSplit(value_str, ',');
      uint32 tmp[3];
      if (block.size() == 3 && absl::SimpleAtoi(block[0], &tmp[0]) &&
          absl::SimpleAtoi(block[1], &tmp[1]) &&
          absl::SimpleAtoi(block[2], &tmp[2])) {
        std::copy_n(tmp, 3, kernel->mutable_block_dim()->begin());
      }
    } else if (key == "grid") {
      const std::vector<absl::string_view>& grid =
          absl::StrSplit(value_str, ',');
      uint32 tmp[3];
      if (grid.size() == 3 && absl::SimpleAtoi(grid[0], &tmp[0]) &&
          absl::SimpleAtoi(grid[1], &tmp[1]) &&
          absl::SimpleAtoi(grid[2], &tmp[2])) {
        std::copy_n(tmp, 3, kernel->mutable_grid_dim()->begin());
      }
    } else if (key == "occ_pct" && absl::SimpleAtod(value_str, &pct)) {
      kernel->set_occupancy_pct(pct);
    }
  }
}

bool IsKernelUsingTensorCore(absl::string_view kernel_name) {
  VLOG(1) << "kernel name: " << kernel_name;
  for (absl::string_view pattern : kTensorCoreKernelNamePatterns) {
    if (absl::StrContains(kernel_name, pattern)) {
      return true;
    }
  }
  return false;
}

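// Returns true if the TF op is, in principle, able to use Tensor Cores,
// judging by its name. For example (op names illustrative),
// "model/dense_1/MatMul" matches the "/MatMul" suffix rule below, while an
// op named "MatMulGrad" matches none of the rules.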
// This list is not exhaustive.
bool IsOpTensorCoreEligible(absl::string_view tf_op_name) {
  // Disable formatting to keep inline comments vertically aligned.
  // clang-format off
  return false
      // Using EndsWith to match Fused operations.
      || absl::EndsWith(tf_op_name, "Conv2D")
      || absl::EndsWith(tf_op_name, "Conv2DBackpropFilter")
      || absl::EndsWith(tf_op_name, "Conv2DBackpropInput")
      || absl::EndsWith(tf_op_name, "Conv3D")
      || absl::EndsWith(tf_op_name, "DepthwiseConv2dNative")
      || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropFilter")
      || absl::EndsWith(tf_op_name, "DepthwiseConv2dNativeBackpropInput")
      // Using Contains to match V2/V3 suffixes.
      || absl::StrContains(tf_op_name, "BatchMatMul")
      // MatMul requires exact matching.
      || absl::EndsWith(tf_op_name, "/MatMul")
      || absl::EndsWith(tf_op_name, "FusedMatMul")
      // cuDNN operations.
      || absl::EndsWith(tf_op_name, "/CudnnRNN")
      || absl::StrContains(tf_op_name, "CudnnRNNV")
      || absl::StrContains(tf_op_name, "CudnnRNNForward")
      || absl::StrContains(tf_op_name, "CudnnRNNBackprop")
      // Special cases.
      || absl::EndsWith(tf_op_name, "XlaDot")
      || absl::EndsWith(tf_op_name, "XlaDotV2");
  // clang-format on
}

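// Returns true if the einsum equation describes a single contraction of two
// inputs into one output, the shape that can map onto a GEMM. For example,
// "ab,bc->ac" is eligible, while "ab->ba" (one input) and "ab,bc,cd->ad"
// (three inputs) are not.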
bool IsEinsumTensorCoreEligible(absl::string_view equation) {
  if (equation.empty()) {
    return false;
  }
  const std::vector<absl::string_view> input_output =
      absl::StrSplit(equation, "->");
  if (input_output.size() != 2) {
    return false;
  }
  const std::vector<absl::string_view> lhs_rhs =
      absl::StrSplit(input_output[0], ',');
  return lhs_rhs.size() == 2;
}

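// Compares the same fields as KernelReportEqualToComparator below, so sorting
// with this comparator places reports that compare equal next to each other.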
bool KernelReportLessThanComparator::operator()(const KernelReport& lhs,
                                                const KernelReport& rhs) const {
  // Disable formatting to keep vertical alignment for better readability,
  // and make it easier to reorder columns.
  // clang-format off
  auto lhs_tuple = std::make_tuple(
      lhs.name(),
      lhs.grid_dim(0),
      lhs.grid_dim(1),
      lhs.grid_dim(2),
      lhs.block_dim(0),
      lhs.block_dim(1),
      lhs.block_dim(2),
      lhs.registers_per_thread(),
      lhs.static_shmem_bytes(),
      lhs.dynamic_shmem_bytes(),
      lhs.is_kernel_using_tensor_core(),
      lhs.is_op_tensor_core_eligible(),
      lhs.op_name());

  auto rhs_tuple = std::make_tuple(
      rhs.name(),
      rhs.grid_dim(0),
      rhs.grid_dim(1),
      rhs.grid_dim(2),
      rhs.block_dim(0),
      rhs.block_dim(1),
      rhs.block_dim(2),
      rhs.registers_per_thread(),
      rhs.static_shmem_bytes(),
      rhs.dynamic_shmem_bytes(),
      rhs.is_kernel_using_tensor_core(),
      rhs.is_op_tensor_core_eligible(),
      rhs.op_name());
  // clang-format on
  return lhs_tuple < rhs_tuple;
}

bool KernelReportEqualToComparator::operator()(const KernelReport& lhs,
                                               const KernelReport& rhs) const {
  // Disable formatting to keep vertical alignment for better readability,
  // and make it easier to reorder columns.
  // clang-format off
  // Put the most expensive string comparisons last.
  return (
      lhs.is_kernel_using_tensor_core() == rhs.is_kernel_using_tensor_core() &&
      lhs.is_op_tensor_core_eligible() == rhs.is_op_tensor_core_eligible() &&
      lhs.block_dim(0) == rhs.block_dim(0) &&
      lhs.block_dim(1) == rhs.block_dim(1) &&
      lhs.block_dim(2) == rhs.block_dim(2) &&
      lhs.grid_dim(0) == rhs.grid_dim(0) &&
      lhs.grid_dim(1) == rhs.grid_dim(1) &&
      lhs.grid_dim(2) == rhs.grid_dim(2) &&
      lhs.registers_per_thread() == rhs.registers_per_thread() &&
      lhs.static_shmem_bytes() == rhs.static_shmem_bytes() &&
      lhs.dynamic_shmem_bytes() == rhs.dynamic_shmem_bytes() &&
      lhs.name() == rhs.name() &&
      lhs.op_name() == rhs.op_name());
  // clang-format on
}

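// Sorts the reports in <kernel_stats_db> by decreasing total duration,
// breaking ties with KernelReportLessThanComparator, and truncates the
// database to the top kMaxNumOfKernels entries. std::partial_sort is used so
// that only the kept prefix needs to be fully ordered when the database is
// large.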
void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db) {
  auto comp = [](const KernelReport& lhs, const KernelReport& rhs) {
    return lhs.total_duration_ns() > rhs.total_duration_ns() ||
           (lhs.total_duration_ns() == rhs.total_duration_ns() &&
            KernelReportLessThanComparator()(lhs, rhs));
  };

  // Sort and keep at most <kMaxNumOfKernels> kernel reports.
  if (kernel_stats_db->reports_size() > kMaxNumOfKernels) {
    std::partial_sort(
        kernel_stats_db->mutable_reports()->begin(),
        kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels,
        kernel_stats_db->mutable_reports()->end(), comp);
    kernel_stats_db->mutable_reports()->erase(
        kernel_stats_db->mutable_reports()->begin() + kMaxNumOfKernels,
        kernel_stats_db->mutable_reports()->end());
  } else {
    std::sort(kernel_stats_db->mutable_reports()->begin(),
              kernel_stats_db->mutable_reports()->end(), comp);
  }
}

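// Copies the kMaxNumOfKernels reports with the largest total duration from
// <reports> into <dst>, filling in each copied report's aggregated
// occurrences and min/max/total durations from its KernelReportValue.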
void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports,
                                       KernelStatsDb* dst) {
  std::vector<std::pair<const KernelReport*, const KernelReportValue*>>
      kernels_to_sort;
  kernels_to_sort.reserve(reports.size());
  for (const auto& report_value : reports) {
    kernels_to_sort.push_back(
        std::make_pair(&report_value.first, &report_value.second));
  }

  auto comp =
      [](const std::pair<const KernelReport*, const KernelReportValue*>& lhs,
         const std::pair<const KernelReport*, const KernelReportValue*>& rhs) {
        return lhs.second->total_duration_ns > rhs.second->total_duration_ns ||
               (lhs.second->total_duration_ns ==
                    rhs.second->total_duration_ns &&
                KernelReportLessThanComparator()(*lhs.first, *rhs.first));
      };

  // Sort and copy at most <kMaxNumOfKernels> kernels to <dst>.
  if (kernels_to_sort.size() > kMaxNumOfKernels) {
    absl::c_partial_sort(kernels_to_sort,
                         kernels_to_sort.begin() + kMaxNumOfKernels, comp);
  } else {
    absl::c_sort(kernels_to_sort, comp);
  }

  int copy_size =
      std::min(kMaxNumOfKernels, static_cast<int>(kernels_to_sort.size()));
  for (int i = 0; i < copy_size; i++) {
    KernelReport* report = dst->add_reports();
    *report = *kernels_to_sort[i].first;
    const KernelReportValue& kernel_value = *kernels_to_sort[i].second;
    // Set value using KernelReportValue.
    report->set_occurrences(kernel_value.occurrences);
    report->set_min_duration_ns(kernel_value.min_duration_ns);
    report->set_max_duration_ns(kernel_value.max_duration_ns);
    report->set_total_duration_ns(kernel_value.total_duration_ns);
  }
}

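// Accumulates <value> into the entry for <kernel> in <dst>: the first
// insertion copies <value> verbatim, later ones add durations and update the
// min/max. A minimal aggregation sketch, with placeholder names (report, acc,
// db) and illustrative values:
//
//   KernelReportMap acc;
//   KernelReportValue v;
//   v.occurrences = 1;
//   v.total_duration_ns = v.min_duration_ns = v.max_duration_ns = 120;  // ns
//   InsertOrUpdateKernelReport(report, v, &acc);  // once per kernel launch
//   CopyTopKDurationKernelReportsToDb(acc, &db);  // emit the final top-K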
void InsertOrUpdateKernelReport(const KernelReport& kernel,
                                const KernelReportValue& value,
                                KernelReportMap* dst) {
  KernelReportValue& element = (*dst)[kernel];
  if (element.occurrences == 0) {
    element = value;
  } else {
    element.total_duration_ns += value.total_duration_ns;
    element.min_duration_ns =
        std::min(element.min_duration_ns, value.min_duration_ns);
    element.max_duration_ns =
        std::max(element.max_duration_ns, value.max_duration_ns);
    element.occurrences += value.occurrences;
  }
}

void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst) {
  for (const auto& kernel_value : reports) {
    InsertOrUpdateKernelReport(kernel_value.first, kernel_value.second, dst);
  }
}

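// Groups kernel reports by the name of the op that launched them, summing
// each op's total and Tensor Core durations. A consumer (e.g. the Kernel
// Stats page) could then derive a per-op Tensor Core utilization as
// tensor_core_duration_ns / total_duration_ns; that computation is not
// performed here.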
KernelStatsByOpName GroupKernelReportsByOpName(
    const KernelStatsDb& kernel_stats_db) {
  KernelStatsByOpName op_level_kernel_stats;
  for (const KernelReport& kernel_report : kernel_stats_db.reports()) {
    auto ret = op_level_kernel_stats.emplace(kernel_report.op_name(),
                                             OpLevelKernelStats());
    if (ret.second) {
      // Inserted. Add a new op to <op_level_kernel_stats>.
      OpLevelKernelStats& stats = ret.first->second;
      stats.is_op_tensor_core_eligible =
          kernel_report.is_op_tensor_core_eligible();
      stats.total_duration_ns += kernel_report.total_duration_ns();
      if (kernel_report.is_kernel_using_tensor_core()) {
        stats.tensor_core_duration_ns += kernel_report.total_duration_ns();
      }
    } else {
      // Not inserted. Aggregate kernel stats to the op level.
      OpLevelKernelStats& stats = ret.first->second;
      // Verifies that operations with the same name have the same Tensor Core
      // eligibility.
      DCHECK_EQ(stats.is_op_tensor_core_eligible,
                kernel_report.is_op_tensor_core_eligible());
      stats.total_duration_ns += kernel_report.total_duration_ns();
      if (kernel_report.is_kernel_using_tensor_core()) {
        stats.tensor_core_duration_ns += kernel_report.total_duration_ns();
      }
    }
  }
  return op_level_kernel_stats;
}

}  // namespace profiler
}  // namespace tensorflow