xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/resource_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/resource_util.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/tf2xla/resource_operation_table.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/graph.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31 
32 namespace tensorflow {
33 namespace {
34 
35 using stream_executor::port::StatusOr;
36 
// Op type names that the analysis below dispatches on.
constexpr char kIdentityNOp[] = "IdentityN";
constexpr char kIfOp[] = "If";
constexpr char kWhileOp[] = "While";
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";

// Function-call nesting depth beyond which analysis fails with InvalidArgument.
constexpr int kMaxCallDepth = 100;
44 
45 Status AnalyzeResourceUsage(
46     const Graph* graph, const std::optional<std::string>& function_name,
47     const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
48     FunctionLibraryRuntime* lib_runtime,
49     absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
50                         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
51         source_to_path);
52 
IsControlFlowV1Node(const Node * n)53 bool IsControlFlowV1Node(const Node* n) {
54   return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
55           n->IsNextIteration());
56 }
57 
58 // TODO(ycao): Add this as Tensorflow Node method.
OutputEdgesByIndex(const Node & n,int idx)59 StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node& n,
60                                                                  int idx) {
61   absl::InlinedVector<const Edge*, 1> res;
62   if (idx >= n.num_outputs()) {
63     return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
64                                    n.name(), " only has ", n.num_outputs(),
65                                    " outputs.");
66   }
67 
68   for (const Edge* o : n.out_edges()) {
69     if (o->src_output() == idx) res.emplace_back(o);
70   }
71   return res;
72 }
73 
IsStackOrTensorArraySource(const Node & n)74 bool IsStackOrTensorArraySource(const Node& n) {
75   const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
76 
77   if (!op_info) return false;
78   if (op_info->resource_kind() != XlaResourceKind::kStack &&
79       op_info->resource_kind() != XlaResourceKind::kTensorArray)
80     return false;
81   return n.num_outputs() > 0 && n.output_type(0) == DataType::DT_RESOURCE;
82 }
83 
PropagateFromStackOrTensorArraySourceOp(const Node & n,const std::optional<std::string> & function_name,absl::flat_hash_map<const Edge *,ResourceUsageAnalysis::NodeInfo> * user_to_source)84 void PropagateFromStackOrTensorArraySourceOp(
85     const Node& n, const std::optional<std::string>& function_name,
86     absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
87         user_to_source) {
88   ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
89                                                 n.type_string());
90   for (const Edge* o : n.out_edges()) {
91     if (o->IsControlEdge()) continue;
92     if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
93       continue;
94     }
95     (*user_to_source)[o] = src_node_info;
96   }
97 }
98 
PropagateFromArgOp(const Node & n,const std::optional<std::string> & function_name,const absl::flat_hash_set<int> & resource_arg_indices,absl::flat_hash_map<const Edge *,ResourceUsageAnalysis::NodeInfo> * user_to_source)99 Status PropagateFromArgOp(
100     const Node& n, const std::optional<std::string>& function_name,
101     const absl::flat_hash_set<int>& resource_arg_indices,
102     absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
103         user_to_source) {
104   TF_RET_CHECK(n.type_string() == kArgOp);
105 
106   int index;
107   TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &index));
108   if (!resource_arg_indices.contains(index)) return OkStatus();
109 
110   TF_RET_CHECK(function_name.has_value())
111       << "ResourceUsageAnalysis does not support analyzing _Arg nodes "
112          "carrying Stack/TensorArray resource in given graph unless they "
113          "are in function calls.";
114 
115   const ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n.name(),
116                                                       n.type_string());
117 
118   for (const Edge* o : n.out_edges()) {
119     if (o->IsControlEdge()) continue;
120     if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
121       continue;
122     }
123     (*user_to_source)[o] = src_node_info;
124   }
125 
126   return OkStatus();
127 }
128 
UpdateResourceUsageFromFunctionBodyAnalysis(const Node & call_node,const std::optional<absl::string_view> & caller_function_name,const FunctionBody & fbody,const absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>> & called_function_source_to_path,absl::flat_hash_map<const Edge *,ResourceUsageAnalysis::NodeInfo> * user_to_source,absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>> * caller_source_to_path)129 Status UpdateResourceUsageFromFunctionBodyAnalysis(
130     const Node& call_node,
131     const std::optional<absl::string_view>& caller_function_name,
132     const FunctionBody& fbody,
133     const absl::flat_hash_map<
134         ResourceUsageAnalysis::NodeInfo,
135         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>&
136         called_function_source_to_path,
137     absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
138         user_to_source,
139     absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
140                         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
141         caller_source_to_path) {
142   std::unordered_map<std::string, Node*> node_name_index =
143       fbody.graph->BuildNodeNameIndex();
144   for (const auto& it : called_function_source_to_path) {
145     ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
146 
147     // If source is an _Arg, then the true source is actually corresponding
148     // edge that feeds into function call node with the same index.
149     if (src_node_info.op_ == kArgOp) {
150       const Node* arg_src = node_name_index[src_node_info.node_name_];
151       int index;
152       TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
153 
154       const Edge* e;
155       // TODO(ycao): Allow overriding input_edge to _Arg index mapping. This is
156       // needed for cond function of while nodes.
157       TF_RETURN_IF_ERROR(call_node.input_edge(index, &e));
158       src_node_info = (*user_to_source)[e];
159     }
160 
161     for (const auto& dst_node_info : it.second) {
162       // If user is an _Retval, then the true user is actually corresponding
163       // edge of that _Retval.
164       if (dst_node_info.op_ == kRetvalOp) {
165         const Node* ret_user = node_name_index[dst_node_info.node_name_];
166         int index;
167         TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
168 
169         absl::InlinedVector<const Edge*, 1> outs;
170         // TODO(ycao): Allow overriding _Retval index to call node output edge
171         // mapping. This is needed for cond function of while nodes.
172         TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(call_node, index));
173         for (const Edge* o : outs) (*user_to_source)[o] = src_node_info;
174       } else {
175         (*caller_source_to_path)[src_node_info].emplace(dst_node_info);
176       }
177     }
178   }
179 
180   return OkStatus();
181 }
182 
PropagateThroughCallOp(const Node & n,const std::optional<std::string> & function_name,const int call_depth,FunctionLibraryRuntime * lib_runtime,absl::flat_hash_map<const Edge *,ResourceUsageAnalysis::NodeInfo> * user_to_source,absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>> * source_to_path)183 Status PropagateThroughCallOp(
184     const Node& n, const std::optional<std::string>& function_name,
185     const int call_depth, FunctionLibraryRuntime* lib_runtime,
186     absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
187         user_to_source,
188     absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
189                         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
190         source_to_path) {
191   if (call_depth > kMaxCallDepth) {
192     return errors::InvalidArgument(
193         "Function call stack in given graph is too deep, last function ",
194         "name is: ", function_name.value());
195   }
196   // resource_arg_indices contains all indices of the input
197   // arguments that carry Stack/TensorArray resource handles.
198   absl::flat_hash_set<int> resource_arg_indices;
199   for (const Edge* e : n.in_edges()) {
200     if (user_to_source->contains(e)) {
201       resource_arg_indices.emplace(e->dst_input());
202     }
203   }
204 
205   // Instantiate associated function to get function body.
206   FunctionLibraryRuntime::Handle handle;
207   TF_RETURN_IF_ERROR(InstantiateFunctionCall(n.def(), lib_runtime, &handle));
208   auto release_handle_on_return = gtl::MakeCleanup(
209       [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
210   const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
211 
212   // Recursively analyze called function for resource sources and users.
213   absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
214                       absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
215       called_function_source_to_path;
216   TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
217       fbody->graph, n.type_string(), call_depth + 1, resource_arg_indices,
218       lib_runtime, &called_function_source_to_path));
219 
220   TF_RETURN_IF_ERROR(UpdateResourceUsageFromFunctionBodyAnalysis(
221       n, function_name, *fbody, called_function_source_to_path, user_to_source,
222       source_to_path));
223   return OkStatus();
224 }
225 
226 // Analyzes pass through values for Identity and IdentityN ops.
PropagateThroughIdentityOp(const Node & n,absl::flat_hash_map<const Edge *,ResourceUsageAnalysis::NodeInfo> * user_to_source)227 Status PropagateThroughIdentityOp(
228     const Node& n,
229     absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
230         user_to_source) {
231   TF_RET_CHECK(n.IsIdentity() || n.type_string() == kIdentityNOp);
232   if (n.IsIdentity()) {
233     for (const Edge* o : n.out_edges()) {
234       if (o->IsControlEdge()) continue;
235       const Edge* in;
236       TF_RETURN_IF_ERROR(n.input_edge(0, &in));
237       if (!user_to_source->contains(in)) continue;
238       user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
239     }
240   } else {
241     for (const Edge* o : n.out_edges()) {
242       if (o->IsControlEdge()) continue;
243       const Edge* in;
244       TF_RETURN_IF_ERROR(n.input_edge(o->src_output(), &in));
245       if (!user_to_source->contains(in)) continue;
246       user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
247     }
248   }
249 
250   return OkStatus();
251 }
252 
AnalyzeResourceUsage(const Graph * graph,const std::optional<std::string> & function_name,const int call_depth,const absl::flat_hash_set<int> & resource_arg_indices,FunctionLibraryRuntime * lib_runtime,absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>> * source_to_path)253 Status AnalyzeResourceUsage(
254     const Graph* graph, const std::optional<std::string>& function_name,
255     const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
256     FunctionLibraryRuntime* lib_runtime,
257     absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
258                         absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
259         source_to_path) {
260   source_to_path->clear();
261 
262   std::vector<Node*> reverse_post_order;
263   GetReversePostOrder(*graph, &reverse_post_order, NodeComparatorName{});
264 
265   // user_to_source maps from an edge carrying a Stack or TensorArray resource
266   // to the node that created this resource.
267   absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>
268       user_to_source;
269   for (const Node* n : reverse_post_order) {
270     if (IsControlFlowV1Node(n)) {
271       return errors::InvalidArgument(
272           "AnalyzeResourceUsage does not support control flow v1 node: ",
273           n->DebugString());
274     }
275 
276     // TODO(ycao): Support pass-through functional while/if nodes.
277     if (n->type_string() == kIfOp || n->type_string() == kWhileOp) {
278       return errors::InvalidArgument(
279           "AnalyzeResourceUsage does not yet support control flow v2 "
280           "node: ",
281           n->DebugString());
282     }
283 
284     // Record a resource source edge.
285     if (IsStackOrTensorArraySource(*n)) {
286       PropagateFromStackOrTensorArraySourceOp(*n, function_name,
287                                               &user_to_source);
288       continue;
289     }
290 
291     // Arguments that are listed in resource_arg_indices are also considered as
292     // resource sources.
293     if (n->IsArg()) {
294       TF_RETURN_IF_ERROR(PropagateFromArgOp(
295           *n, function_name, resource_arg_indices, &user_to_source));
296       continue;
297     }
298 
299     // Recursively analyze function call ops.
300     if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
301       TF_RETURN_IF_ERROR(PropagateThroughCallOp(*n, function_name, call_depth,
302                                                 lib_runtime, &user_to_source,
303                                                 source_to_path));
304       continue;
305     }
306 
307     if (n->IsIdentity() || n->type_string() == kIdentityNOp) {
308       TF_RETURN_IF_ERROR(PropagateThroughIdentityOp(*n, &user_to_source));
309     }
310   }
311 
312   for (const auto& it : user_to_source) {
313     (*source_to_path)[it.second].emplace(function_name, it.first->dst()->name(),
314                                          it.first->dst()->type_string());
315   }
316 
317   return OkStatus();
318 }
319 
320 }  // anonymous namespace
321 
Analyze(const Graph * graph,FunctionLibraryRuntime * lib_runtime,absl::flat_hash_map<NodeInfo,absl::flat_hash_set<NodeInfo>> * source_to_path)322 /*Static*/ Status ResourceUsageAnalysis::Analyze(
323     const Graph* graph, FunctionLibraryRuntime* lib_runtime,
324     absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
325         source_to_path) {
326   return AnalyzeResourceUsage(
327       graph, /*function_name=*/{}, /*call_depth=*/0,
328       /*resource_arg_indices=*/absl::flat_hash_set<int>(), lib_runtime,
329       source_to_path);
330 }
331 
332 }  // namespace tensorflow
333