/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/resource_util.h"

#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace tensorflow {
namespace {

using stream_executor::port::StatusOr;

const char kIdentityNOp[] = "IdentityN";
const char kIfOp[] = "If";
const char kWhileOp[] = "While";
const char kArgOp[] = "_Arg";
const char kRetvalOp[] = "_Retval";

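// Maximum depth of nested function calls that the analysis will follow before
// reporting an error.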
const int kMaxCallDepth = 100;

Status AnalyzeResourceUsage(
    const Graph* graph, const std::optional<std::string>& function_name,
    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
    FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path);

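// Returns true if `n` is a control flow v1 node (Enter/Exit/Switch/Merge/
// NextIteration), which this analysis does not support.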
bool IsControlFlowV1Node(const Node* n) {
  return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
          n->IsNextIteration());
}

// TODO(ycao): Add this as a TensorFlow Node method.
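// Returns all output edges of `n` that originate from output slot `idx`, or an
// InvalidArgument error if `idx` is not a valid output index of `n`.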
StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node& n,
                                                                 int idx) {
  absl::InlinedVector<const Edge*, 1> res;
  if (idx >= n.num_outputs()) {
    return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
                                   n.name(), " only has ", n.num_outputs(),
                                   " outputs.");
  }

  for (const Edge* o : n.out_edges()) {
    if (o->src_output() == idx) res.emplace_back(o);
  }
  return res;
}

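// Returns true if `n` creates a Stack or TensorArray resource, i.e. it is a
// resource op of kind kStack or kTensorArray whose first output carries a
// DT_RESOURCE handle.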
bool IsStackOrTensorArraySource(const Node& n) {
  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());

  if (!op_info) return false;
  if (op_info->resource_kind() != XlaResourceKind::kStack &&
      op_info->resource_kind() != XlaResourceKind::kTensorArray)
    return false;
  return n.num_outputs() > 0 && n.output_type(0) == DataType::DT_RESOURCE;
}

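// Records `n` as the resource source for every non-control output edge of `n`
// that carries a DT_RESOURCE value.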
void PropagateFromStackOrTensorArraySourceOp(
    const Node& n, const std::optional<std::string>& function_name,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
                                                n.type_string());
  for (const Edge* o : n.out_edges()) {
    if (o->IsControlEdge()) continue;
    if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
      continue;
    }
    (*user_to_source)[o] = src_node_info;
  }
}

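// If the index of the _Arg node `n` is listed in `resource_arg_indices`,
// records `n` as the resource source for all of its non-control DT_RESOURCE
// output edges.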
Status PropagateFromArgOp(
    const Node& n, const std::optional<std::string>& function_name,
    const absl::flat_hash_set<int>& resource_arg_indices,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  TF_RET_CHECK(n.type_string() == kArgOp);

  int index;
  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index));
  if (!resource_arg_indices.contains(index)) return OkStatus();

  TF_RET_CHECK(function_name.has_value())
      << "ResourceUsageAnalysis does not support analyzing _Arg nodes "
         "carrying Stack/TensorArray resource in given graph unless they "
         "are in function calls.";

  const ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
                                                      n.type_string());

  for (const Edge* o : n.out_edges()) {
    if (o->IsControlEdge()) continue;
    if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
      continue;
    }
    (*user_to_source)[o] = src_node_info;
  }

  return OkStatus();
}

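// Merges the analysis result of a called function back into the caller:
// sources that are _Arg nodes are replaced by the caller-side edges feeding
// the call node at the same index, users that are _Retval nodes are replaced
// by the call node's output edges at the corresponding index, and all other
// users are added to `caller_source_to_path`.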
Status UpdateResourceUsageFromFunctionBodyAnalysis(
    const Node& call_node,
    const std::optional<absl::string_view>& caller_function_name,
    const FunctionBody& fbody,
    const absl::flat_hash_map<
        ResourceUsageAnalysis::NodeInfo,
        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>&
        called_function_source_to_path,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        caller_source_to_path) {
  std::unordered_map<std::string, Node*> node_name_index =
      fbody.graph->BuildNodeNameIndex();
  for (const auto& it : called_function_source_to_path) {
    ResourceUsageAnalysis::NodeInfo src_node_info = it.first;

    // If the source is an _Arg, then the true source is actually the
    // corresponding edge that feeds into the function call node at the same
    // index.
    if (src_node_info.op_ == kArgOp) {
      const Node* arg_src = node_name_index[src_node_info.node_name_];
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));

      const Edge* e;
      // TODO(ycao): Allow overriding input_edge to _Arg index mapping. This is
      // needed for cond function of while nodes.
      TF_RETURN_IF_ERROR(call_node.input_edge(index, &e));
      src_node_info = (*user_to_source)[e];
    }

    for (const auto& dst_node_info : it.second) {
      // If the user is a _Retval, then the true users are actually the output
      // edges of the call node at that _Retval's index.
      if (dst_node_info.op_ == kRetvalOp) {
        const Node* ret_user = node_name_index[dst_node_info.node_name_];
        int index;
        TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));

        absl::InlinedVector<const Edge*, 1> outs;
        // TODO(ycao): Allow overriding _Retval index to call node output edge
        // mapping. This is needed for cond function of while nodes.
        TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(call_node, index));
        for (const Edge* o : outs) (*user_to_source)[o] = src_node_info;
      } else {
        (*caller_source_to_path)[src_node_info].emplace(dst_node_info);
      }
    }
  }

  return OkStatus();
}

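// Instantiates the function called by `n`, recursively analyzes its body with
// the resource-carrying argument indices derived from `user_to_source`, and
// merges the result back into the caller's `user_to_source` and
// `source_to_path` maps.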
Status PropagateThroughCallOp(
    const Node& n, const std::optional<std::string>& function_name,
    const int call_depth, FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path) {
  if (call_depth > kMaxCallDepth) {
    return errors::InvalidArgument(
        "Function call stack in given graph is too deep, last function ",
        "name is: ", function_name.value());
  }
  // resource_arg_indices contains all indices of the input
  // arguments that carry Stack/TensorArray resource handles.
  absl::flat_hash_set<int> resource_arg_indices;
  for (const Edge* e : n.in_edges()) {
    if (user_to_source->contains(e)) {
      resource_arg_indices.emplace(e->dst_input());
    }
  }

  // Instantiate associated function to get function body.
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(InstantiateFunctionCall(n.def(), lib_runtime, &handle));
  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);

  // Recursively analyze called function for resource sources and users.
  absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                      absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
      called_function_source_to_path;
  TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
      fbody->graph, n.type_string(), call_depth + 1, resource_arg_indices,
      lib_runtime, &called_function_source_to_path));

  TF_RETURN_IF_ERROR(UpdateResourceUsageFromFunctionBodyAnalysis(
      n, function_name, *fbody, called_function_source_to_path, user_to_source,
      source_to_path));
  return OkStatus();
}

// Propagates resource usage through pass-through Identity and IdentityN ops.
Status PropagateThroughIdentityOp(
    const Node& n,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  TF_RET_CHECK(n.IsIdentity() || n.type_string() == kIdentityNOp);
  if (n.IsIdentity()) {
    for (const Edge* o : n.out_edges()) {
      if (o->IsControlEdge()) continue;
      const Edge* in;
      TF_RETURN_IF_ERROR(n.input_edge(0, &in));
      if (!user_to_source->contains(in)) continue;
      user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
    }
  } else {
    for (const Edge* o : n.out_edges()) {
      if (o->IsControlEdge()) continue;
      const Edge* in;
      TF_RETURN_IF_ERROR(n.input_edge(o->src_output(), &in));
      if (!user_to_source->contains(in)) continue;
      user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
    }
  }

  return OkStatus();
}

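// Traverses `graph` in reverse post order and, for every Stack/TensorArray
// source node (or resource-carrying _Arg whose index is listed in
// `resource_arg_indices`), collects the nodes that use the created resource
// into `source_to_path`, following Identity/IdentityN ops and function call
// boundaries. Control flow v1 nodes and functional If/While nodes are rejected
// with an error.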
Status AnalyzeResourceUsage(
    const Graph* graph, const std::optional<std::string>& function_name,
    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
    FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path) {
  source_to_path->clear();

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(*graph, &reverse_post_order, NodeComparatorName{});

  // user_to_source maps from an edge carrying a Stack or TensorArray resource
  // to the node that created this resource.
  absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>
      user_to_source;
  for (const Node* n : reverse_post_order) {
    if (IsControlFlowV1Node(n)) {
      return errors::InvalidArgument(
          "AnalyzeResourceUsage does not support control flow v1 node: ",
          n->DebugString());
    }

    // TODO(ycao): Support pass-through functional while/if nodes.
    if (n->type_string() == kIfOp || n->type_string() == kWhileOp) {
      return errors::InvalidArgument(
          "AnalyzeResourceUsage does not yet support control flow v2 "
          "node: ",
          n->DebugString());
    }

    // Record a resource source edge.
    if (IsStackOrTensorArraySource(*n)) {
      PropagateFromStackOrTensorArraySourceOp(*n, function_name,
                                              &user_to_source);
      continue;
    }

    // Arguments listed in resource_arg_indices are also considered resource
    // sources.
    if (n->IsArg()) {
      TF_RETURN_IF_ERROR(PropagateFromArgOp(
          *n, function_name, resource_arg_indices, &user_to_source));
      continue;
    }

    // Recursively analyze function call ops.
    if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(PropagateThroughCallOp(*n, function_name, call_depth,
                                                lib_runtime, &user_to_source,
                                                source_to_path));
      continue;
    }

    if (n->IsIdentity() || n->type_string() == kIdentityNOp) {
      TF_RETURN_IF_ERROR(PropagateThroughIdentityOp(*n, &user_to_source));
    }
  }

  for (const auto& it : user_to_source) {
    (*source_to_path)[it.second].emplace(function_name, it.first->dst()->name(),
                                         it.first->dst()->type_string());
  }

  return OkStatus();
}

}  // anonymous namespace

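// Example usage (a sketch; assumes the caller owns a valid `Graph* graph` and
// `FunctionLibraryRuntime* lib_runtime`):
//
//   absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
//                       absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
//       source_to_path;
//   TF_RETURN_IF_ERROR(
//       ResourceUsageAnalysis::Analyze(graph, lib_runtime, &source_to_path));
//   // source_to_path now maps every Stack/TensorArray-creating node to the
//   // set of nodes that consume the resource it creates.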
/*static*/ Status ResourceUsageAnalysis::Analyze(
    const Graph* graph, FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
        source_to_path) {
  return AnalyzeResourceUsage(
      graph, /*function_name=*/{}, /*call_depth=*/0,
      /*resource_arg_indices=*/absl::flat_hash_set<int>(), lib_runtime,
      source_to_path);
}

}  // namespace tensorflow