xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/tf2xla_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17 
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/tf2xla/sharding_util.h"
27 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/function_body.h"
31 #include "tensorflow/core/framework/graph.pb.h"
32 #include "tensorflow/core/framework/graph_def_util.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def.pb.h"
35 #include "tensorflow/core/framework/node_def_builder.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op_def_builder.h"
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/tensor_shape.pb.h"
40 #include "tensorflow/core/framework/versions.pb.h"
41 #include "tensorflow/core/graph/tensor_id.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/platform/errors.h"
45 
46 namespace tensorflow {
47 
48 namespace {
49 
50 Status ValidateTensorId(const tf2xla::TensorId& id) {
51   if (id.node_name().empty()) {
52     return errors::InvalidArgument("TensorId node_name must be non-empty");
53   }
54   if (id.output_index() < 0) {
55     return errors::InvalidArgument("TensorId output_index must be non-negative");
56   }
57   return OkStatus();
58 }
59 
60 Status CheckNameDuplicates(const string& kind, const string& name,
61                            std::set<string>* names) {
62   if (!name.empty()) {
63     if (!names->insert(name).second) {
64       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
65     }
66   }
67   return OkStatus();
68 }
69 
70 Status CheckFeedFetchNameConflicts(const string& kind,
71                                    const std::set<string>& names) {
72   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
73   // since that will cause a collision in codegen symbols.
74   for (const string& name : names) {
75     const string name_data(name + "_data");
76     if (names.find(name_data) != names.end()) {
77       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
78                                      " and ", name_data);
79     }
80   }
81   return OkStatus();
82 }
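
// A minimal sketch of the collision this check rejects (the names "foo" and
// "foo_data" are hypothetical):
//
//   std::set<string> names = {"foo", "foo_data"};
//   Status s = CheckFeedFetchNameConflicts("feed", names);
//   // s is InvalidArgument: "conflicting feed name: foo and foo_data", since
//   // codegen would also derive a "foo_data" symbol from the feed named "foo".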
83 
84 // For graph `g`, copy all function call nodes' FunctionDefs from `lookup_fld`
85 // to `fld`. This ensures that `fld` can instantiate the functions called by `g`.
86 Status CopyAssociatedFunctions(Graph* g,
87                                const FunctionLibraryDefinition* lookup_fld,
88                                FunctionLibraryDefinition* fld) {
89   for (Node* n : g->op_nodes()) {
90     for (const auto& associated_function :
91          GetAssociatedFunctions(*n, lookup_fld)) {
92       switch (associated_function.type()) {
93         case AssociatedFunctionInfo::kFunctionCallNode: {
94           const FunctionDef* fdef =
95               lookup_fld->Find(associated_function.func_name());
96           if (!fdef) {
97             return errors::Internal(
98                 "Cannot find function ", associated_function.func_name(),
99                 " for function call node ", n->DebugString());
100           }
101           TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
102           break;
103         }
104         case AssociatedFunctionInfo::kSymbolicGradient:
105         case AssociatedFunctionInfo::kFunctionAttr:
106           break;
107       }
108     }
109   }
110   return OkStatus();
111 }
112 
113 // Replaces the single edge feeding into {dst,dst_input} with a new
114 // src/src_output specified by {with,with_output}.
115 StatusOr<Node*> ReplaceEdge(Graph* g, Node* dst, int dst_input, Node* with,
116                             int with_output) {
117   NodeDef replace_def = dst->def();
118   *replace_def.mutable_input(dst_input) = with->name();
119   TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, replace_def));
120   const Edge* usage_edge;
121   TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &usage_edge));
122   g->RemoveEdge(usage_edge);
123   g->AddEdge(with, with_output, replace_node, dst_input);
124   return replace_node;
125 }
126 
127 // Replaces usages of the given `src_output` index of the given `src` node with
128 // the given `replacement` node (assumes the :0 output of `replacement`).
129 Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output,
130                                      Node* replacement) {
131   VLOG(1) << "Replace usages of output " << src_output << " of node "
132           << (VLOG_IS_ON(3) ? src->DebugString() : src->name()) << " with "
133           << (VLOG_IS_ON(3) ? replacement->DebugString() : replacement->name());
134   // Collect all usages of the specified src output (src->out_edges() iterator
135   // will not be stable under modifications).
136   struct OutEdgeInfo {
137     int dst_node_id, dst_input;
138   };
139   std::vector<OutEdgeInfo> usages;
140   for (const Edge* e : src->out_edges()) {
141     if (e->IsControlEdge() || e->src_output() != src_output) {
142       continue;
143     }
144     usages.push_back({e->dst()->id(), e->dst_input()});
145   }
146 
147   // Now, replace each usage.
148   for (int i = 0, end = usages.size(); i < end; i++) {
149     // Make a copy of `usage_node`, and change its input to the replacement node.
150     Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
151     VLOG(2) << "  Replace usage by " << usage_node->DebugString();
152     // Note: Replacement output index is presumed to be 0.
153     TF_ASSIGN_OR_RETURN(
154         Node * replace_node,
155         ReplaceEdge(g, usage_node, usages[i].dst_input, replacement, 0));
156     // Later entries in `usages` might have `usage_node` as dst node, but
157     // `usage_node` is removed. Replace such entries with `replace_node`.
158     for (int j = i + 1, end = usages.size(); j < end; j++) {
159       if (usages[j].dst_node_id == usages[i].dst_node_id) {
160         usages[j].dst_node_id = replace_node->id();
161       }
162     }
163   }
164   return OkStatus();
165 }
166 
167 // For graph `g`, replaces _Arg nodes whose "index" attribute is in
168 // `const_input_index_to_node` with Const nodes.
169 Status ReplaceArgUsageWithConstNode(
170     Graph* g,
171     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
172   // Collect all _Arg nodes.
173   absl::flat_hash_map<int, Node*> arg_nodes;
174   for (Node* n : g->op_nodes()) {
175     if (n->IsArg()) {
176       int index;
177       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
178       arg_nodes[index] = n;
179     }
180   }
181 
182   for (const auto& iter : const_input_index_to_node) {
183     int arg_index = iter.first;
184     VLOG(2) << "Replace usages of _Arg " << arg_index;
185     NodeDef const_def = iter.second->def();
186     const_def.set_name(g->NewName(const_def.name()));
187     TF_ASSIGN_OR_RETURN(Node * const_node, g->AddNode(const_def));
188     Node* arg_node = arg_nodes[arg_index];
189     TF_RETURN_IF_ERROR(
190         ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node));
191   }
192   return OkStatus();
193 }
194 
195 // For each _Retval node whose "index" attribute is a key of
196 // const_input_index_to_node, replaces its single input with the single output
197 // of the corresponding _Arg node.
198 Status ReplaceRetvalInputWithArg(
199     Graph* g,
200     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
201   absl::flat_hash_map<int, Node*> arg_nodes;
202   absl::flat_hash_map<int, Node*> ret_nodes;
203   for (Node* n : g->op_nodes()) {
204     if (n->IsRetval() || n->IsArg()) {
205       int index;
206       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
207       if (n->IsRetval()) {
208         ret_nodes[index] = n;
209       } else {
210         arg_nodes[index] = n;
211       }
212     }
213   }
214 
215   for (const auto& iter : const_input_index_to_node) {
216     int arg_index = iter.first;
217     VLOG(2) << "Bind _Retval " << arg_index << " to _Arg " << arg_index;
218     TF_RETURN_IF_ERROR(
219         ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0)
220             .status());
221   }
222   return OkStatus();
223 }
224 
225 // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
226 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
227 // inputs.
228 Status PropagateConstIntoFuncAttr(
229     Node* n, const string& attr_name,
230     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node,
231     const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld,
232     bool passthrough_arg_to_retval = false) {
233   VLOG(1) << "Propagate const into " << attr_name << " of node " << n->name();
234   // Instantiate the function.
235   NameAttrList func_attr;
236   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
237   const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
238   if (!fdef) {
239     return errors::Internal("Cannot find function ", func_attr.name(),
240                             " for node ", n->name());
241   }
242   std::unique_ptr<FunctionBody> fbody;
243   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
244       *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));
245 
246   // Rewrite _Arg usages with Const node.
247   Graph* func_graph = fbody->graph;
248   TF_RETURN_IF_ERROR(
249       ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
250   if (passthrough_arg_to_retval) {
251     TF_RETURN_IF_ERROR(
252         ReplaceRetvalInputWithArg(func_graph, const_input_index_to_node));
253   }
254 
255   // Save rewritten function.
256   FunctionDef replace_fdef;
257   string new_func_name =
258       fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
259   TF_RETURN_IF_ERROR(
260       GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
261   TF_RETURN_IF_ERROR(fld->AddFunctionDef(
262       replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));
263 
264   VLOG(1) << "replace func " << func_attr.name() << " with " << new_func_name;
265   // Change the node to use rewritten function.
266   func_attr.set_name(new_func_name);
267   n->ClearAttr(attr_name);
268   n->AddAttr(attr_name, func_attr);
269 
270   TF_RETURN_IF_ERROR(fld->AddFunctionDef(
271       replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));
272 
273   // Copy associated functions.
274   TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
275 
276   return OkStatus();
277 }
278 
279 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
280 // then/else branch function to replace _Arg nodes with those Const inputs.
281 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
282                                 const FunctionLibraryDefinition* lookup_fld,
283                                 FunctionLibraryDefinition* fld) {
284   // Note that the first input of an If node is the predicate; the remaining
285   // inputs are the function inputs.
286   absl::flat_hash_map<int, const Node*> const_input_index_to_node;
287   for (int i = 1; i < if_node->num_inputs(); i++) {
288     const Node* input_node;
289     TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
290     if (input_node->type_string() == "Const") {
291       const_input_index_to_node[i - 1] = input_node;
292     }
293   }
294   if (const_input_index_to_node.empty()) {
295     return OkStatus();
296   }
297 
298   // Rewrite the "then_branch" and "else_branch" functions, replacing usages of
299   // those _Arg nodes with the corresponding Const nodes.
300   for (const auto& attr_name :
301        std::vector<string>{"then_branch", "else_branch"}) {
302     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
303         if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
304   }
305 
306   return OkStatus();
307 }
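
// Sketch of the effect on a hypothetical graph: for
//
//   c = Const(...); if = If(pred, c, x, then_branch=f, else_branch=g)
//
// the Const input `c` corresponds to _Arg 0 of `f` and `g`, so both branches
// are cloned into new functions (named roughly "<f>_const_..."/"<g>_const_...")
// in which that _Arg's usages are replaced by a copy of the Const, and the If
// node's then_branch/else_branch attrs are updated to the rewritten functions.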
308 
309 using GraphCache = absl::flat_hash_map<string, std::unique_ptr<FunctionBody>>;
310 
311 StatusOr<FunctionBody*> FindOrInsert(
312     GraphCache* cache, const NameAttrList& body_attr,
313     const FunctionLibraryDefinition* lookup_fld,
314     const FunctionLibraryDefinition* fallback_fld) {
315   const string name = body_attr.name();
316   std::unique_ptr<FunctionBody>& value = (*cache)[name];
317   if (!value) {
318     const FunctionDef* body_func = lookup_fld->Find(name);
319     if (!body_func && fallback_fld != nullptr) {
320       body_func = fallback_fld->Find(name);
321     }
322     if (!body_func) {
323       return errors::Internal("Traverse: Cannot find body function ", name);
324     }
325     std::unique_ptr<FunctionBody> fbody;
326     Status s = FunctionDefToBodyHelper(*body_func, AttrSlice(&body_attr.attr()),
327                                        lookup_fld, &fbody);
328     if (!s.ok() && fallback_fld != nullptr) {
329       TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
330           *body_func, AttrSlice(&body_attr.attr()), fallback_fld, &fbody));
331     }
332     value = std::move(fbody);
333   }
334   return value.get();
335 }
336 // Determines whether a loop body is invariant for the given argument index.
337 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
338                                const FunctionLibraryDefinition* lookup_fld,
339                                const FunctionLibraryDefinition* fallback_fld,
340                                GraphCache* cache);
341 
342 // Traces backward through non-modifying ops such as Identity and loop-invariant
343 // While, to find a preceding source edge.
344 StatusOr<const Edge*> TraverseUnmodifiedPathBackward(
345     const Edge* src, const FunctionLibraryDefinition* lookup_fld,
346     const FunctionLibraryDefinition* fallback_fld, GraphCache* cache) {
347   const Edge* e = src;
348   VLOG(2) << "Traverse: Begin at " << e->DebugString();
349   // TODO(b/184727356): Also traverse If/Case nodes.
350   // Begin walking back from the output node.
351   while (IsConstTraversableOpType(e->src())) {
352     VLOG(3) << e->DebugString();
353 
354     if (e->src()->IsWhileNode()) {
355       NameAttrList body_attr;
356       TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "body", &body_attr));
357       TF_ASSIGN_OR_RETURN(
358           FunctionBody * fbody,
359           FindOrInsert(cache, body_attr, lookup_fld, fallback_fld));
360       TF_ASSIGN_OR_RETURN(bool is_loop_invariant,
361                           IsLoopInvariant(fbody, e->src_output(), lookup_fld,
362                                           fallback_fld, cache));
363       if (!is_loop_invariant) {
364         VLOG(2) << "Non-loop-invariant: index " << e->src_output() << " of "
365                 << body_attr.name();
366         break;
367       }
368     }  // if While|StatelessWhile
369     // Proceed backward to the src's input corresponding with the output index.
370     TF_RETURN_IF_ERROR(e->src()->input_edge(e->src_output(), &e));
371   }
372   VLOG(2) << "Traverse: Finish at " << e->DebugString();
373 
374   return e;
375 }
376 
377 // Determines whether a loop body is invariant for the given argument index.
378 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
379                                const FunctionLibraryDefinition* lookup_fld,
380                                const FunctionLibraryDefinition* fallback_fld,
381                                GraphCache* cache) {
382   const Edge* e;
383   TF_RETURN_IF_ERROR(loop_body->ret_nodes[index]->input_edge(0, &e));
384   TF_ASSIGN_OR_RETURN(
385       const Edge* reachable,
386       TraverseUnmodifiedPathBackward(e, lookup_fld, fallback_fld, cache));
387   if (reachable->src()->id() == loop_body->arg_nodes[index]->id()) {
388     VLOG(2) << "Index " << index << " is loop invariant.";
389     return true;
390   }
391   VLOG(2) << "Index " << index << " not loop invariant: "
392           << "walk backward from " << e->src()->DebugString() << " to "
393           << reachable->src()->DebugString() << " did not reach "
394           << loop_body->arg_nodes[index]->DebugString();
395   return false;
396 }
397 
398 // For a "While" node in graph `g`, if it has Const node inputs, rewrite its
399 // cond/body functions to replace _Arg nodes with those Const inputs. Then,
400 // propagate those Consts to consumers of the relevant outputs of the while loop.
401 Status PropagateConstIntoAndAroundWhileNode(
402     Graph* g, Node* while_node, const FunctionLibraryDefinition* lookup_fld,
403     FunctionLibraryDefinition* fld) {
404   VLOG(1) << "Propagate const into " << while_node->name();
405 
406   // For a "While" node, we should only replace _Arg nodes that are loop
407   // invariant. For such _Arg nodes, the corresponding retval's input comes
408   // directly from the corresponding arg.
409   absl::flat_hash_map<int, const Node*> const_input_index_to_node;
410   absl::flat_hash_map<int, Node*> const_input_index_to_mutable_node;
411   NameAttrList body_attr;
412   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
413   const string fn_name = body_attr.name();
414   const FunctionDef* body_func = lookup_fld->Find(fn_name);
415   if (!body_func) {
416     return errors::Internal("Propagate: Cannot find body function ", fn_name,
417                             " for While node ", while_node->name());
418   }
419   std::unique_ptr<FunctionBody> fbody;
420   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
421       *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody));
422   GraphCache cache;
423   for (int i = 0; i < while_node->num_inputs(); i++) {
424     // Check if the i-th retval's input comes directly from the i-th arg.
425     // For resource variable inputs of While nodes, the TF2XLA convention is to
426     // place them at the end of all inputs (after all data inputs) and *not*
427     // return them, so the number of While node inputs might be larger than the
428     // number of its outputs.
429     if (i >= body_func->signature().output_arg_size()) {
430       break;
431     }
432 
433     const Edge* input_edge;
434     TF_RETURN_IF_ERROR(while_node->input_edge(i, &input_edge));
435     TF_ASSIGN_OR_RETURN(input_edge, TraverseUnmodifiedPathBackward(
436                                         input_edge, lookup_fld, fld, &cache));
437     if (!input_edge->src()->IsConstant()) {
438       VLOG(2) << "Input " << i << " is not Const; is "
439               << input_edge->src()->type_string();
440       continue;
441     }
442 
443     TF_ASSIGN_OR_RETURN(
444         bool is_loop_invariant,
445         IsLoopInvariant(fbody.get(), i, lookup_fld, fld, &cache));
446     if (!is_loop_invariant) {
447       VLOG(2) << "While state not loop-invariant; not propagating Const " << i;
448       continue;
449     }
450     VLOG(2) << "While state is loop-invariant; propagating Const " << i;
451 
452     const_input_index_to_mutable_node[i] = input_edge->src();
453     const_input_index_to_node[i] = input_edge->src();
454   }
455   if (const_input_index_to_node.empty()) {
456     return OkStatus();
457   }
458 
459   // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
460   // nodes with the corresponding Const nodes.
461   for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
462     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
463         while_node, attr_name, const_input_index_to_node, lookup_fld, fld,
464         /*passthrough_arg_to_retval=*/attr_name == "body"));
465   }
466 
467   // Rewrite usages of the output edges corresponding to loop-invariant const
468   // inputs to refer instead to the Const node.
469   for (const auto& it : const_input_index_to_mutable_node) {
470     TF_RETURN_IF_ERROR(
471         ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second));
472   }
473   return OkStatus();
474 }
475 
476 }  // namespace
477 
478 StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
479                                const FunctionLibraryDefinition* lookup_fld) {
480   GraphCache cache;
481   return IsLoopInvariant(loop_body, index, lookup_fld,
482                          /*fallback_fld=*/nullptr, &cache);
483 }
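
// Usage sketch (assumes `fbody` is the FunctionBody of a While loop's body
// function and `fld` is the matching FunctionLibraryDefinition):
//
//   TF_ASSIGN_OR_RETURN(bool invariant, IsLoopInvariant(fbody.get(), 0, fld));
//   // `invariant` is true iff retval 0 of the body traces back, through
//   // Identity-like ops and nested loop-invariant While loops, to arg 0.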
484 
485 Status ValidateConfig(const tf2xla::Config& config) {
486   std::set<string> names;
487   for (const tf2xla::Feed& feed : config.feed()) {
488     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
489     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
490     TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
491   }
492   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
493   names.clear();
494   for (const tf2xla::Fetch& fetch : config.fetch()) {
495     TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
496     TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
497   }
498   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
499   if (config.fetch().empty()) {
500     return errors::InvalidArgument("fetches must be specified");
501   }
502   return OkStatus();
503 }
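
// A minimal sketch of a config this function accepts (the node names "x" and
// "y" are hypothetical):
//
//   tf2xla::Config config;
//   tf2xla::Feed* feed = config.add_feed();
//   feed->mutable_id()->set_node_name("x");
//   feed->mutable_id()->set_output_index(0);
//   tf2xla::Fetch* fetch = config.add_fetch();
//   fetch->mutable_id()->set_node_name("y");
//   TF_RETURN_IF_ERROR(ValidateConfig(config));  // OK: unique names, fetch set.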
504 
505 Status AddPlaceholdersForFeeds(
506     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
507     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
508   struct PlaceholderInfo {
509     const tf2xla::Feed* feed = nullptr;  // points to a Feed in <config>.
510     string placeholder_name;
511     DataType data_type = DT_INVALID;
512   };
513 
514   // Put each fed tensor into a map by name:port. A map is used for determinism
515   // when creating placeholders (genrules want deterministic output).
516   std::map<string, PlaceholderInfo> placeholder_info;
517   for (int i = 0; i < config.feed_size(); ++i) {
518     const tf2xla::Feed* feed = &config.feed(i);
519     const string name_port = TensorIdToString(feed->id());
520     PlaceholderInfo& info = placeholder_info[name_port];
521     info.feed = feed;
522     info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
523                                          "/", feed->id().node_name());
524     (*feed_remapping)[name_port] = info.placeholder_name;
525   }
526 
527   // Verify node exists and determine data type.
528   std::unordered_map<string, const NodeDef*> name_to_node;
529   for (int i = 0; i < graph_def->node_size(); ++i) {
530     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
531   }
532   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
533     PlaceholderInfo& info = it->second;
534     const tf2xla::TensorId& feed_id = info.feed->id();
535 
536     // Find the existing node and determine data type.
537     auto node_it = name_to_node.find(feed_id.node_name());
538     if (node_it == name_to_node.end()) {
539       return errors::NotFound("Can't find feed node: ",
540                               TensorIdToString(feed_id));
541     }
542     const NodeDef* existing = node_it->second;
543 
544     if (info.feed->type() != DT_INVALID) {
545       info.data_type = info.feed->type();
546     } else {
547       // Build the node in order to infer its type.
548 
549       // Must first add default attrs as well, so do this in a copied GraphDef.
550       GraphDef gd;
551       *gd.mutable_versions() = graph_def->versions();
552       *gd.add_node() = *existing;
553       MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
554       TF_RETURN_IF_ERROR(
555           AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
556 
557       // Now build the node from the copied node def.
558       Graph g(op_registry);
559       g.set_versions(graph_def->versions());
560       TF_ASSIGN_OR_RETURN(Node * feed_node, g.AddNode(gd.node(0)));
561 
562       if (info.feed->id().output_index() < feed_node->num_outputs()) {
563         info.data_type =
564             BaseType(feed_node->output_type(info.feed->id().output_index()));
565       } else {
566         return errors::InvalidArgument(
567             "Invalid output_index ", info.feed->id().output_index(),
568             " for feed node ", info.feed->id().node_name());
569       }
570     }
571   }
572 
573   // Create placeholders. Note that we could avoid creating a placeholder for
574   // feeds which are already placeholders, but we omit that to avoid more cases
575   // in this code.
576   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
577     const PlaceholderInfo& info = it->second;
578     // TODO(shikharagarwal): Add original node information.
579     NodeDef* d = graph_def->add_node();
580     d->set_name(info.placeholder_name);
581     d->set_op("Placeholder");
582     auto& attr_map = *d->mutable_attr();
583     attr_map["dtype"].set_type(info.data_type);
584     *attr_map["shape"].mutable_shape() = info.feed->shape();
585   }
586 
587   // Rewrite references to the fed tensors to refer to the placeholder.
588   for (int i = 0; i < graph_def->node_size(); ++i) {
589     NodeDef* node_def = graph_def->mutable_node(i);
590     for (int j = 0; j < node_def->input_size(); ++j) {
591       auto id = ParseTensorName(node_def->input(j));
592       auto it = placeholder_info.find(id.ToString());
593       if (it != placeholder_info.end()) {
594         node_def->set_input(j, it->second.placeholder_name);
595       }
596     }
597   }
598 
599   return OkStatus();
600 }
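
// Usage sketch (assumes `config` already passed ValidateConfig and `graph_def`
// contains the fed node "x"):
//
//   std::unordered_map<string, string> feed_remapping;
//   TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, OpRegistry::Global(),
//                                              &feed_remapping, &graph_def));
//   // feed_remapping now maps each fed "node:port" to the placeholder that
//   // replaced it, e.g. "x:0" -> "aot_feed_0/x".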
601 
602 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
603                          GraphDef* out) {
604   *out = in;
605   out->clear_node();
606 
607   // Tensors needed for feeding.
608   std::set<std::pair<string, int>> feed_tensors;
609   for (const tf2xla::Feed& feed : config.feed()) {
610     feed_tensors.insert(
611         std::make_pair(feed.id().node_name(), feed.id().output_index()));
612   }
613 
614   // Maps node name to reachability.
615   std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
616   for (const NodeDef& node : in.node()) {
617     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
618   }
619 
620   // Traverse.
621   std::queue<string> name_queue;
622   for (int i = 0; i < config.fetch_size(); ++i) {
623     name_queue.push(config.fetch(i).id().node_name());
624   }
625   while (!name_queue.empty()) {
626     const string name = name_queue.front();
627     name_queue.pop();
628 
629     auto find_it = node_by_name.find(name);
630     if (find_it == node_by_name.end()) {
631       return errors::InvalidArgument("While pruning graph, node ", name,
632                                      " needed but not found in the graph.");
633     }
634     auto& map_entry = find_it->second;
635     if (map_entry.first) {
636       continue;
637     }
638     map_entry.first = true;
639 
640     // Push input nodes of the currently visited node to name_queue.
641     for (const string& in_edge : map_entry.second->input()) {
642       auto id = ParseTensorName(in_edge);
643       const string node_name = string(id.first);
644       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
645           feed_tensors.end()) {
646         name_queue.push(node_name);
647       } else {
648         // The input tensor is from an edge that is being fed. Therefore,
649         // we skip recursing down that edge, to avoid requiring nodes that
650         // may not be needed (note that the input node may still be added
651         // to name_queue later if one of its output edges is not being fed).
652       }
653     }
654   }
655 
656   // Copy over, preserving the original node order and keeping only nodes that
657   // are reachable from the fetches.
658   out->mutable_node()->Reserve(in.node_size());
659   for (const NodeDef& node : in.node()) {
660     if (node_by_name[node.name()].first) {
661       *out->add_node() = node;
662     }
663   }
664   return OkStatus();
665 }
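
// Usage sketch: keep only what the fetches need, stopping the backward
// traversal at fed tensors.
//
//   GraphDef pruned;
//   TF_RETURN_IF_ERROR(PruneGraphDefInto(config, graph_def, &pruned));
//   // `pruned` contains, in the original order, only nodes reachable from a
//   // fetch without walking through a fed tensor.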
666 
667 string TensorIdToString(const tf2xla::TensorId& id) {
668   return absl::StrCat(id.node_name(), ":", id.output_index());
669 }
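
// For example, a TensorId with node_name "x" and output_index 2 (hypothetical
// values) is rendered as "x:2".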
670 
671 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
672   int core = -1;
673   const Node* matching_node = nullptr;
674   for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
675     if (edge->IsControlEdge()) continue;
676     const Node* possible_match = out_edges ? edge->dst() : edge->src();
677     TF_ASSIGN_OR_RETURN(
678         std::optional<xla::OpSharding> sharding,
679         ParseShardingFromDevice(
680             *possible_match,
681             /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
682             /*add_metadata=*/false));
683     if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
684       const int core_annotation = sharding.value().tile_assignment_devices(0);
685       if (core == -1 || core > core_annotation) {
686         core = core_annotation;
687         matching_node = possible_match;
688       }
689     }
690   }
691   if (matching_node != nullptr) {
692     n->set_assigned_device_name(matching_node->assigned_device_name());
693     n->set_requested_device(matching_node->requested_device());
694   }
695   return OkStatus();
696 }
697 
698 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
699                                    KernelDef* kdef) {
700   for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
701     if (constraint.name() == name) {
702       constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
703     }
704   }
705 }
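
// Usage sketch (assumes `kdef` is a KernelDef with a type constraint named "T"):
//
//   AddDtypeToKernelDefConstraint("T", DT_BFLOAT16, &kdef);
//   // DT_BFLOAT16 is appended to the allowed_values list of the "T" constraint;
//   // constraints with other names are left unchanged.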
706 
707 namespace {
708 uint32 InitialRandomSeed() {
709   // Support for plumbing the TF seed through to XLA is being worked on.
710   // If a user wants deterministic behavior, their best option
711   // is to start with a known checkpoint. This also handles issues when
712   // multiple random calls can be invoked in any order by the TF executor.
713   // Another option is to use stateless random ops. They have much cleaner
714   // semantics.
715   // If a user really wants to set a deterministic seed for XLA-based
716   // devices, this is the place to do it.
717   std::random_device rd;
718   // Make the starting value odd.
719   return rd() | 1;
720 }
721 }  // namespace
722 
723 uint32 GetXLARandomSeed() {
724   // We initialize the counter with an odd number and increment it by two
725   // every time. This ensures that it will never be zero, even
726   // after an overflow. When seeded with zero, some XLA backends
727   // can return all zeros instead of random numbers.
728   static std::atomic<uint32> counter(InitialRandomSeed());
729   uint32 seed = counter.fetch_add(2);
730   std::srand(seed);
731   return std::rand() | 1;
732 }
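
// Why the seed can never become zero: the counter starts odd and fetch_add(2)
// preserves parity modulo 2^32; e.g. 0xFFFFFFFF + 2 wraps to 0x00000001, which
// is still odd.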
733 
734 // TODO(b/77601805): add tests for associated function related stuff.
735 bool HasAssociatedFunction(const NodeDef& node_def,
736                            const FunctionLibraryDefinition* fld) {
737   if (fld->Contains(node_def.op())) {
738     return true;
739   }
740 
741   if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
742     // The Gradient op has an "f" attr, which is set to the function we are
743     // taking the gradient of. We need to functionalize the gradient function.
744     return true;
745   }
746 
747   if (node_def.op() == "XlaHostCompute") {
748     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
749     // related to graph execution.
750     return false;
751   }
752 
753   for (const auto& iter : node_def.attr()) {
754     if (iter.second.has_func()) {
755       return true;
756     }
757   }
758 
759   return false;
760 }
761 
762 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
763     const Node& node, const FunctionLibraryDefinition* fld) {
764   std::vector<AssociatedFunctionInfo> results;
765   const string& op = node.type_string();
766   if (fld->Contains(op)) {
767     // This is a function call node.
768     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
769     results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
770   } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
771     // This is a SymbolicGradient op.
772     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
773     results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
774   } else if (node.type_string() == "XlaHostCompute") {
775     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
776     // related to graph execution.
777   } else {
778     // Collect all function attrs for the node.
779     for (auto& iter : node.attrs()) {
780       if (iter.second.has_func()) {
781         VLOG(2) << "Found function attr for node " << node.name() << ": "
782                 << iter.first << " = " << iter.second.func().name();
783         results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
784             iter.second.func().name(), iter.second.func().attr(), iter.first));
785       }
786     }
787   }
788   return results;
789 }
790 
791 Status RewriteAssociatedFunction(
792     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
793     const AssociatedFunctionInfo& associated_function,
794     const string& rewritten_function_name) {
795   switch (associated_function.type()) {
796     case AssociatedFunctionInfo::kFunctionCallNode: {
797       // Change this node to call the new function.
798       NodeDebugInfo debug_info(*node);
799       NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
800                              &debug_info);
801       for (const auto& attr : node->attrs()) {
802         builder.Attr(attr.first, attr.second);
803       }
804       for (int i = 0; i < node->num_inputs(); i++) {
805         Node* input_node;
806         TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
807         builder.Input(input_node->name(), i, node->input_type(i));
808       }
809       builder.Device(node->assigned_device_name().empty()
810                          ? node->requested_device()
811                          : node->assigned_device_name());
812       NodeDef node_def;
813       TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
814       TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(node_def));
815       for (auto edge : node->in_edges()) {
816         graph->AddEdge(edge->src(), edge->src_output(), new_node,
817                        edge->dst_input());
818       }
819       for (auto edge : node->out_edges()) {
820         graph->AddEdge(new_node, edge->src_output(), edge->dst(),
821                        edge->dst_input());
822       }
823       graph->RemoveNode(node);
824       break;
825     }
826     case AssociatedFunctionInfo::kSymbolicGradient: {
827       NameAttrList func;
828       TF_RETURN_IF_ERROR(GetNodeAttr(
829           node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
830       GradientDef gradient_def;
831       gradient_def.set_function_name(func.name());
832       gradient_def.set_gradient_func(rewritten_function_name);
833       string original_grad_func = fld->FindGradient(func.name());
834       if (original_grad_func.empty()) {
835         TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
836       } else if (original_grad_func != rewritten_function_name) {
837         TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
838       }
839       break;
840     }
841     case AssociatedFunctionInfo::kFunctionAttr: {
842       // Change function attr to rewritten functions.
843       NameAttrList func;
844       TF_RETURN_IF_ERROR(
845           GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
846       // Save the original function name in case TFRT falls back to using the
847       // TPUPartitionedCall op at runtime.
848       if (node->type_string() == "TPUPartitionedCall") {
849         node->AddAttr("_orig_f", func.name());
850       }
851       node->ClearAttr(associated_function.attr_name());
852       func.set_name(rewritten_function_name);
853       node->AddAttr(associated_function.attr_name(), func);
854       break;
855     }
856   }
857 
858   return OkStatus();
859 }
860 
861 Status CachedFunctionHandles::GetOrInstantiate(
862     const string& func_name, AttrSlice attrs,
863     FunctionLibraryRuntime::Handle* handle) {
864   string canonicalized_name = Canonicalize(func_name, attrs);
865   auto iter = handles_.find(canonicalized_name);
866   if (iter != handles_.end()) {
867     *handle = iter->second;
868     return OkStatus();
869   }
870 
871   TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
872   handles_[canonicalized_name] = *handle;
873   return OkStatus();
874 }
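
// Usage sketch (assumes `flr` is a FunctionLibraryRuntime* and "my_func" is a
// hypothetical function in its library):
//
//   CachedFunctionHandles cached_handles(flr);
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(
//       cached_handles.GetOrInstantiate("my_func", AttrSlice(), &handle));
//   // A second call with the same name and attrs returns the cached handle;
//   // ReleaseAllHandles() releases every cached handle at the end.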
875 
876 Status CachedFunctionHandles::ReleaseAllHandles() {
877   Status result;
878   for (const auto& iter : handles_) {
879     result.Update(flr_->ReleaseHandle(iter.second));
880   }
881   handles_.clear();
882   return result;
883 }
884 
885 StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
886   // Create the replacement node.
887   TF_ASSIGN_OR_RETURN(Node * new_node, g->AddNode(node_def));
888 
889   // Record the original node's output edges and remove them first, to avoid
890   // multiple producers for a dst node's input.
891   std::vector<OutEdgeInfo> out_edge_info;
892   std::vector<const Edge*> out_edges;
893   for (const Edge* edge : n->out_edges()) {
894     out_edges.push_back(edge);
895     out_edge_info.push_back(
896         {edge->dst(), edge->src_output(), edge->dst_input()});
897   }
898   for (const Edge* edge : out_edges) {
899     g->RemoveEdge(edge);
900   }
901 
902   // Add original node's input and output edges to the replacement node.
903   for (const Edge* in_edge : n->in_edges()) {
904     g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
905                in_edge->dst_input());
906   }
907   for (const OutEdgeInfo& out_edge : out_edge_info) {
908     g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
909   }
910 
911   // Remove the original node.
912   g->RemoveNode(n);
913 
914   return new_node;
915 }
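
// Usage sketch: swap in an updated NodeDef while keeping all existing edges
// (the attr name "value" is hypothetical):
//
//   NodeDef new_def = n->def();
//   // ... mutate new_def, e.g. change its "value" attr ...
//   TF_ASSIGN_OR_RETURN(Node * new_n, ReplaceNode(g, n, new_def));
//   // `n` has been removed from the graph; `new_n` carries its in/out edges.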
916 
917 StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
918                                   DataType dtype, const Node* input,
919                                   std::optional<string> requested_device) {
920   // Create identity node.
921   NodeDef ndef;
922   ndef.set_name(node_name);
923   ndef.set_op("Identity");
924   if (input) {
925     ndef.add_input(input->name());
926   }
927   if (requested_device) {
928     ndef.set_device(*requested_device);
929   }
930   AddNodeAttr("T", dtype, &ndef);
931   TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));
932   return id_node;
933 }
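
// Usage sketch (the node name is hypothetical and `some_node` stands in for an
// existing node in `graph`):
//
//   TF_ASSIGN_OR_RETURN(
//       Node * identity,
//       BuildIdentityNode(graph, "my_identity", DT_FLOAT, /*input=*/some_node,
//                         /*requested_device=*/std::nullopt));
//   // Only the NodeDef's input string is populated; the caller still adds the
//   // data edge from `some_node` to `identity` if it needs one in the Graph.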
934 
935 Status PropagateConstIntoFunctionalNodes(
936     Graph* g, const FunctionLibraryDefinition* lookup_fld,
937     FunctionLibraryDefinition* fld) {
938   absl::flat_hash_set<int> done_node_ids;
939 
940   // Because we may propagate Const around a while node as well as into it,
941   // we restart the op_nodes() iterator after each pass and keep track of which
942   // nodes we've already dealt with.
943   bool should_continue = true;
944   while (should_continue) {
945     should_continue = false;
946     for (Node* n : g->op_nodes()) {
947       if (!done_node_ids.contains(n->id())) {
948         if (n->IsIfNode()) {
949           VLOG(1) << "PropagateConstIntoIfNode: " << n->name();
950           TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
951           done_node_ids.emplace(n->id());
952           VLOG(1) << "Done PropagateConstIntoIfNode: " << n->name();
953         } else if (n->IsWhileNode()) {
954           VLOG(1) << "PropagateConstIntoWhileNode: " << n->name();
955           TF_RETURN_IF_ERROR(
956               PropagateConstIntoAndAroundWhileNode(g, n, lookup_fld, fld));
957           done_node_ids.emplace(n->id());
958           should_continue = true;
959           VLOG(1) << "Done PropagateConstIntoWhileNode: " << n->name();
960           break;
961         }
962       }
963     }
964   }
965   return OkStatus();
966 }
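
// Usage sketch: a call site can pass the same library as both the lookup and
// destination library, so rewritten functions land next to the originals:
//
//   FunctionLibraryDefinition* fld = ...;  // library holding g's functions
//   TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(g, fld, fld));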
967 
968 Status PruneUnreachableFunctionsFromGraph(const Graph& g,
969                                           FunctionLibraryDefinition* fld) {
970   GraphDef graph_def;
971   g.ToGraphDef(&graph_def);
972   FunctionLibraryDefinition reachable_functions =
973       fld->ReachableDefinitions(graph_def);
974   for (const string& func_name : fld->ListFunctionNames()) {
975     if (!reachable_functions.Find(func_name)) {
976       TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
977     }
978   }
979   return OkStatus();
980 }
981 
982 Status RewriteTensorListWithConstElement(Graph* g,
983                                          FunctionLibraryDefinition* fld) {
984   for (Node* n : g->nodes()) {
985     if (n->type_string() != "EmptyTensorList") {
986       continue;
987     }
988 
989     // Find the forward While op.
990     std::vector<const Edge*> fwd_while_edges;
991     for (const Edge* e : n->out_edges()) {
992       if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
993         fwd_while_edges.push_back(e);
994       }
995     }
996     if (fwd_while_edges.size() != 1) {
997       // No forward While op found, or multiple forward While ops.
998       continue;
999     }
1000 
1001     // Find the backward While op.
1002     Node* fwd_while = fwd_while_edges[0]->dst();
1003     int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
1004     std::vector<const Edge*> bwd_while_edges;
1005     for (const Edge* e : fwd_while->out_edges()) {
1006       if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
1007         bwd_while_edges.push_back(e);
1008       }
1009     }
1010     if (bwd_while_edges.size() != 1) {
1011       // No backward While op found, or multiple backward While ops.
1012       continue;
1013     }
1014 
1015     Node* bwd_while = bwd_while_edges[0]->dst();
1016     int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
1017 
1018     // Look into the forward While body function and check whether the
1019     // TensorListPushBack op has a Const input.
1020     NameAttrList fwd_body_attr;
1021     TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
1022     const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
1023     if (!fwd_body) {
1024       return errors::InvalidArgument("Cannot find function ",
1025                                      fwd_body_attr.name(), " for While node ",
1026                                      fwd_while->DebugString());
1027     }
1028     std::unique_ptr<FunctionBody> fwd_fbody;
1029     TF_CHECK_OK(FunctionDefToBodyHelper(
1030         *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
1031 
1032     // Find the TensorListPushBack node; it's one of fwd_arg's successors.
1033     Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
1034     std::vector<Node*> tl_push_nodes;
1035     for (const Edge* out_edge : fwd_arg->out_edges()) {
1036       if (out_edge->dst()->type_string() == "TensorListPushBack") {
1037         tl_push_nodes.push_back(out_edge->dst());
1038       }
1039     }
1040     if (tl_push_nodes.size() != 1) {
1041       // No TensorListPushBack found, or multiple TensorListPushBack.
1042       continue;
1043     }
1044 
1045     // Get input for the TensorListPushBack node.
1046     Node* input_node;
1047     TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
1048     if (input_node->type_string() != "Const") {
1049       // The input for the TensorList is not a Const node.
1050       continue;
1051     }
1052 
1053     NodeDef const_input_nodedef = input_node->def();
1054 
1055     // Rewrite the backward While body function, replacing usages of
1056     // TensorListPopBack with the Const node.
1057     NameAttrList bwd_body_attr;
1058     TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
1059     const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
1060     if (!bwd_body) {
1061       return errors::InvalidArgument("Cannot find function ",
1062                                      bwd_body_attr.name(), " for While node ",
1063                                      bwd_while->DebugString());
1064     }
1065     std::unique_ptr<FunctionBody> bwd_fbody;
1066     TF_CHECK_OK(FunctionDefToBodyHelper(
1067         *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
1068 
1069     // Find the TensorListPopBack node; it's one of bwd_arg's successors.
1070     Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
1071     std::vector<Node*> tl_pop_nodes;
1072     for (const Edge* out_edge : bwd_arg->out_edges()) {
1073       if (out_edge->dst()->type_string() == "TensorListPopBack") {
1074         tl_pop_nodes.push_back(out_edge->dst());
1075       }
1076     }
1077     if (tl_pop_nodes.size() != 1) {
1078       // No TensorListPopBack found, or multiple TensorListPopBack.
1079       continue;
1080     }
1081 
1082     // Replace TensorListPopBack usages with Const node.
1083     std::vector<const Edge*> edges_to_replace;
1084     for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
1085       if (e->src_output() == 1) {
1086         edges_to_replace.push_back(e);
1087       }
1088     }
1089     if (edges_to_replace.empty()) {
1090       continue;
1091     }
1092     const_input_nodedef.set_name(
1093         bwd_fbody->graph->NewName(const_input_nodedef.name()));
1094     TF_ASSIGN_OR_RETURN(Node * const_node,
1095                         bwd_fbody->graph->AddNode(const_input_nodedef));
1096     for (const Edge* e : edges_to_replace) {
1097       Node* dst = e->dst();
1098       int dst_input = e->dst_input();
1099       bwd_fbody->graph->RemoveEdge(e);
1100       bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
1101     }
1102 
1103     // Add rewritten backward While body function.
1104     FunctionDef new_fdef;
1105     string new_name = fld->UniqueFunctionName(
1106         absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
1107     TF_RETURN_IF_ERROR(
1108         GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
1109     TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
1110 
1111     // Change backward While op to use the new body function.
1112     bwd_body_attr.set_name(new_name);
1113     bwd_while->ClearAttr("body");
1114     bwd_while->AddAttr("body", bwd_body_attr);
1115   }
1116   return OkStatus();
1117 }
1118 
1119 }  // namespace tensorflow
1120