xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/utils/tf_op_utils.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/ascii.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/numbers.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/strings/strip.h"
28 #include "tensorflow/core/platform/regexp.h"
29 
30 namespace tensorflow {
31 namespace profiler {
32 namespace {
33 
34 const absl::string_view kIterator = "Iterator";
35 const absl::string_view kSeparator = "::";
36 constexpr char kNameScopeSeparator = '/';
37 constexpr char kOpNameSuffixSeparator = '_';
38 
IsInteger(absl::string_view str)39 bool IsInteger(absl::string_view str) {
40   int64_t unused;
41   return absl::SimpleAtoi(str, &unused);
42 }
43 
44 // Returns an op type derived from an op name.
DeriveOpType(absl::string_view full_op_name)45 absl::string_view DeriveOpType(absl::string_view full_op_name) {
46   // Use the op name without name scopes and suffix as an op type. A full op
47   // name consists of name scopes, an op type, and optionally a numeric suffix
48   // (e.g., model/layer/MatMul_1).
49   std::vector<absl::string_view> name_scopes_and_op_name =
50       absl::StrSplit(full_op_name, kNameScopeSeparator);
51   absl::string_view op_name = name_scopes_and_op_name.back();
52   std::vector<absl::string_view> op_type_and_maybe_suffix =
53       absl::StrSplit(op_name, kOpNameSuffixSeparator);
54   absl::string_view maybe_suffix = op_type_and_maybe_suffix.back();
55   absl::string_view op_type = op_name;
56   if (IsInteger(maybe_suffix)) {
57     // NOTE: assuming a numeric suffix is not part of an op type while
58     // technically it is allowed.
59     op_type = op_name.substr(0, op_name.size() - maybe_suffix.size() - 1);
60   }
61   return op_type;
62 }
63 
64 }  // namespace
65 
66 const absl::string_view kUnknownOp = "";  // op types are non-empty strings
67 const absl::string_view kDatasetOp = "Dataset";
68 const absl::string_view kMemcpyHToDOp = "MemcpyHToD";
69 const absl::string_view kMemcpyDToHOp = "MemcpyDToH";
70 const absl::string_view kMemcpyDToDOp = "MemcpyDToD";
71 const absl::string_view kMemcpyHToHOp = "MemcpyHToH";
72 
IsTfOpName(absl::string_view op_name)73 bool IsTfOpName(absl::string_view op_name) {
74   // TODO(b/177602927): Confirm the naming convention with the TF team.
75   static const LazyRE2 kTfOpNameRegEx = {"[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*"};
76   return RE2::FullMatch(op_name, *kTfOpNameRegEx);
77 }
78 
IsTfOpType(absl::string_view op_type)79 bool IsTfOpType(absl::string_view op_type) {
80   static const LazyRE2 kTfOpTypeRegEx = {"[A-Z_][a-zA-Z0-9_]*"};
81   return RE2::FullMatch(op_type, *kTfOpTypeRegEx);
82 }
83 
IsJaxOpType(absl::string_view op_type)84 bool IsJaxOpType(absl::string_view op_type) {
85   static const LazyRE2 kJaxOpTypeRegEx = {"[a-z_][a-z0-9_]*"};
86   return RE2::FullMatch(op_type, *kJaxOpTypeRegEx);
87 }
88 
IsJaxOpNameAndType(absl::string_view op_name,absl::string_view op_type)89 bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type) {
90   if (op_name.empty() || !IsJaxOpType(op_type)) return false;
91   std::vector<absl::string_view> split_result =
92       absl::StrSplit(op_name, kNameScopeSeparator);
93   return absl::StrContains(split_result.back(), op_type);
94 }
95 
ParseTfOpFullname(absl::string_view tf_op_fullname)96 TfOp ParseTfOpFullname(absl::string_view tf_op_fullname) {
97   // TF Op names have the format "name:type".
98   TfOp tf_op = {Category::kUnknown, tf_op_fullname, kUnknownOp};
99   std::vector<absl::string_view> parts =
100       absl::StrSplit(tf_op_fullname, absl::MaxSplits(':', 1));
101   if (parts.size() != 2) {
102     // GPU-related Ops that need to be tracked.
103     if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToD")) {
104       tf_op.category = Category::kMemcpyHToD;
105       tf_op.type = kMemcpyHToDOp;
106     } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToH")) {
107       tf_op.category = Category::kMemcpyDToH;
108       tf_op.type = kMemcpyDToHOp;
109     } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYDToD")) {
110       tf_op.category = Category::kMemcpyDToD;
111       tf_op.type = kMemcpyDToDOp;
112     } else if (absl::StartsWithIgnoreCase(tf_op_fullname, "MEMCPYHToH")) {
113       tf_op.category = Category::kMemcpyHToH;
114       tf_op.type = kMemcpyHToHOp;
115     }
116     // TODO(ckluk): Include the corresponding Ops on TPU.
117   } else if (parts[0] == kIterator) {
118     // Dataset Op names (e.g., Iterator::Batch::Map::TFRecord) do not follow the
119     // format of TF Op names. But we still want to capture them for
120     // input-pipeline analysis.
121     tf_op.category = Category::kTfData;
122     tf_op.type = kDatasetOp;
123   } else if (IsTfOpType(parts[1]) && IsTfOpName(parts[0])) {
124     tf_op = {Category::kTensorFlow, parts[0], parts[1]};
125   } else if (IsJaxOpType(parts[1])) {
126     tf_op = {Category::kJax, parts[0], parts[1]};
127   } else if (parts[1].empty()) {
128     tf_op = {Category::kTensorFlow, parts[0], DeriveOpType(parts[0])};
129   }
130   return tf_op;
131 }
132 
ParseTfNameScopes(absl::string_view tf_op_name)133 std::vector<absl::string_view> ParseTfNameScopes(absl::string_view tf_op_name) {
134   std::vector<absl::string_view> name_scopes =
135       absl::StrSplit(tf_op_name, kNameScopeSeparator);
136   // The last element is an op name not TF name scope.
137   if (!name_scopes.empty()) name_scopes.pop_back();
138   return name_scopes;
139 }
140 
ParseTfNameScopes(const TfOp & tf_op)141 std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op) {
142   return ParseTfNameScopes(tf_op.name);
143 }
144 
TfOpEventName(const TfOp & tf_op)145 std::string TfOpEventName(const TfOp& tf_op) {
146   std::string event_name;
147   if (tf_op.category == Category::kUnknown) {
148     // Some TraceMe names contain trailing whitespace, remove it.
149     event_name = std::string(absl::StripTrailingAsciiWhitespace(tf_op.name));
150   } else if (tf_op.category == Category::kTfData) {
151     event_name = DatasetOpEventName(tf_op.name);
152   } else {
153     event_name = std::string(tf_op.type);
154   }
155   return event_name;
156 }
157 
TfOpEventName(absl::string_view tf_op_fullname)158 std::string TfOpEventName(absl::string_view tf_op_fullname) {
159   return TfOpEventName(ParseTfOpFullname(tf_op_fullname));
160 }
161 
DatasetOpEventName(absl::string_view full_name)162 std::string DatasetOpEventName(absl::string_view full_name) {
163   std::vector<absl::string_view> split_result =
164       absl::StrSplit(full_name, kSeparator);
165   return absl::StrCat(kIterator, kSeparator, split_result.back());
166 }
167 
IteratorName(absl::string_view full_name)168 std::string IteratorName(absl::string_view full_name) {
169   std::vector<absl::string_view> split_result =
170       absl::StrSplit(full_name, kSeparator);
171   return std::string(split_result.back());
172 }
173 
ParseTensorShapes(absl::string_view tensor_shapes)174 std::vector<absl::string_view> ParseTensorShapes(
175     absl::string_view tensor_shapes) {
176   absl::ConsumePrefix(&tensor_shapes, "(");
177   absl::ConsumeSuffix(&tensor_shapes, ")");
178   return absl::StrSplit(tensor_shapes, ';');
179 }
180 
181 }  // namespace profiler
182 }  // namespace tensorflow
183