1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
17 #define TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
18
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/match.h"
23 #include "absl/strings/string_view.h"
24 #include "tensorflow/core/platform/macros.h"
25
26 namespace tensorflow {
27 namespace profiler {
28
// Special op types.
// Well-known op type strings whose values are defined out of line (not in this
// header). kDatasetOp and kMemcpyHToDOp are the strings that the string_view
// overloads of IsDatasetOp() and IsMemcpyHToDOp() compare against.
// NOTE(review): TF_CONST_INIT (from platform/macros.h) presumably requests
// constant initialization so these are usable during static init — confirm.
TF_CONST_INIT extern const absl::string_view kUnknownOp;
TF_CONST_INIT extern const absl::string_view kDatasetOp;
TF_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
TF_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
TF_CONST_INIT extern const absl::string_view kMemcpyDToDOp;
TF_CONST_INIT extern const absl::string_view kMemcpyHToHOp;
36
// Coarse classification of an op, stored in TfOp::category and queried by the
// Is*Op(const TfOp&) predicates in this header.
enum class Category {
  kUnknown,     // Default value of TfOp::category.
  kTensorFlow,  // TensorFlow op (checked by IsInfeedEnqueueOp(const TfOp&)).
  kJax,         // Presumably ops matching the JAX pattern (cf. IsJaxOpType).
  kTfData,      // tf.data op (checked by IsDatasetOp(const TfOp&)).
  kMemcpyHToD,  // Host-to-device copy (IsMemcpyHToDOp).
  kMemcpyDToH,  // Device-to-host copy (IsMemcpyDToHOp).
  kMemcpyDToD,  // Device-to-device copy (IsMemcpyDToDOp).
  kMemcpyHToH,  // Host-to-host copy (IsMemcpyHToHOp).
};
47
// Breaks a TensorFlow op fullname into name and type.
// `name` and `type` are non-owning absl::string_view's.
// NOTE(review): they presumably point into the string passed to
// ParseTfOpFullname(), which must then outlive the TfOp — confirm in the .cc.
struct TfOp {
  // Classification of the op; kUnknown unless set otherwise.
  Category category = Category::kUnknown;
  // Name portion of the op fullname.
  absl::string_view name;
  // Type portion of the op fullname.
  absl::string_view type;
};
// Parses `tf_op_fullname` into a TfOp (implementation is out of line).
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
55
// Returns a vector of TF name scopes extracted from a TF op name.
// The results are non-owning string_views; presumably they view into the
// input, which must outlive them — TODO confirm against the implementation.
std::vector<absl::string_view> ParseTfNameScopes(absl::string_view tf_op_name);
// Overload taking an already-parsed TfOp.
std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);

// Trace event name for TF ops is the op type so they have the same color in
// trace viewer.
std::string TfOpEventName(const TfOp& tf_op);
// Overload taking the unparsed op fullname.
std::string TfOpEventName(absl::string_view tf_op_fullname);

// Trace event name for dataset ops.
std::string DatasetOpEventName(absl::string_view full_name);

// Returns the iterator name without prefix and parent iterator names.
std::string IteratorName(absl::string_view full_name);
70
71 // Returns true if the given name is a TensorFlow Dataset Op.
IsDatasetOp(absl::string_view tf_op_type)72 inline bool IsDatasetOp(absl::string_view tf_op_type) {
73 return tf_op_type == kDatasetOp;
74 }
IsDatasetOp(const TfOp & tf_op)75 inline bool IsDatasetOp(const TfOp& tf_op) {
76 return tf_op.category == Category::kTfData;
77 }
78
79 // Returns true if the given name is a TensorFlow Infeed Enqueue Op.
80 // See: tensorflow/core/tpu/kernels/infeed_ops.h
IsInfeedEnqueueOp(absl::string_view tf_op_type)81 inline bool IsInfeedEnqueueOp(absl::string_view tf_op_type) {
82 return absl::StartsWith(tf_op_type, "InfeedEnqueue");
83 }
IsInfeedEnqueueOp(const TfOp & tf_op)84 inline bool IsInfeedEnqueueOp(const TfOp& tf_op) {
85 return tf_op.category == Category::kTensorFlow &&
86 IsInfeedEnqueueOp(tf_op.type);
87 }
88
89 // Returns true if the given op has XlaSendToHost/XlaRecvFromHost in fullname.
IsOutsideCompilationOp(absl::string_view tf_op_fullname)90 inline bool IsOutsideCompilationOp(absl::string_view tf_op_fullname) {
91 if (absl::EndsWith(tf_op_fullname, ":XlaSendToHost")) return true;
92 if (absl::EndsWith(tf_op_fullname, ":XlaRecvFromHost")) return true;
93 return false;
94 }
95
96 // Returns true if the given op is for outside compilation.
IsOutsideCompilationOp(absl::string_view tf_op_fullname,absl::string_view hlo_expression)97 inline bool IsOutsideCompilationOp(absl::string_view tf_op_fullname,
98 absl::string_view hlo_expression) {
99 if (IsOutsideCompilationOp(tf_op_fullname)) return true;
100 if (absl::StrContains(hlo_expression, "send-done") &&
101 absl::StrContains(hlo_expression, "is_host_transfer=true"))
102 return true;
103 return false;
104 }
105
106 // Returns true if the given name is a TensorFlow embedding op.
IsEmbeddingOp(absl::string_view tf_op_fullname)107 inline bool IsEmbeddingOp(absl::string_view tf_op_fullname) {
108 return absl::StrContains(tf_op_fullname, "Embedding");
109 }
110
111 // Returns true if the given op is for copying data from host to device.
IsMemcpyHToDOp(absl::string_view tf_op_type)112 inline bool IsMemcpyHToDOp(absl::string_view tf_op_type) {
113 return tf_op_type == kMemcpyHToDOp;
114 }
IsMemcpyHToDOp(const TfOp & tf_op)115 inline bool IsMemcpyHToDOp(const TfOp& tf_op) {
116 return tf_op.category == Category::kMemcpyHToD;
117 }
118
119 // Returns true if the given op is for copying data from device to host.
IsMemcpyDToHOp(const TfOp & tf_op)120 inline bool IsMemcpyDToHOp(const TfOp& tf_op) {
121 return tf_op.category == Category::kMemcpyDToH;
122 }
123
124 // Returns true if the given op is for copying data from device to device.
IsMemcpyDToDOp(const TfOp & tf_op)125 inline bool IsMemcpyDToDOp(const TfOp& tf_op) {
126 return tf_op.category == Category::kMemcpyDToD;
127 }
128
129 // Returns true if the given op is for copying data from host to host.
IsMemcpyHToHOp(const TfOp & tf_op)130 inline bool IsMemcpyHToHOp(const TfOp& tf_op) {
131 return tf_op.category == Category::kMemcpyHToH;
132 }
133
// Splits a string of tensor shapes in "(shape1;shape2;...)" format, i.e.,
// delimited by '(' and ')' and separated by ';', into the individual shapes.
// The results are non-owning string_views (presumably into `tensor_shapes`,
// which must then outlive them — TODO confirm).
std::vector<absl::string_view> ParseTensorShapes(
    absl::string_view tensor_shapes);

// Returns true if the given string matches the NodeDef.name pattern, i.e. it
// is syntactically a valid TF op (node) name.
bool IsTfOpName(absl::string_view op_name);

// Returns true if the given string matches the OpDef.name pattern, i.e. it is
// syntactically a valid TF op type (OpDef.name holds the op's type).
bool IsTfOpType(absl::string_view op_type);

// Returns true if the given string matches JAX pattern.
bool IsJaxOpType(absl::string_view op_type);

// Returns true if the given strings match JAX pattern.
bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type);
150
151 } // namespace profiler
152 } // namespace tensorflow
153
154 #endif // TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
155