/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <functional>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

namespace {

Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  if (id.output_index() < 0) {
    return errors::InvalidArgument(
        "TensorId output_index must be non-negative");
  }
  return OkStatus();
}

Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
  }
  return OkStatus();
}

Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
  // since that will cause a collision in codegen symbols.
  for (const string& name : names) {
    const string name_data(name + "_data");
    if (names.find(name_data) != names.end()) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                     " and ", name_data);
    }
  }
  return OkStatus();
}

// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          break;
      }
    }
  }
  return OkStatus();
}

// Replaces the single edge feeding into {dst,dst_input} with a new
// src/src_output specified by {with,with_output}.
StatusOr<Node*> ReplaceEdge(Graph* g, Node* dst, int dst_input, Node* with,
                            int with_output) {
  NodeDef replace_def = dst->def();
  *replace_def.mutable_input(dst_input) = with->name();
  TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, replace_def));
  const Edge* usage_edge;
  TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &usage_edge));
  g->RemoveEdge(usage_edge);
  g->AddEdge(with, with_output, replace_node, dst_input);
  return replace_node;
}

// Replaces usages of the given `src_output` index of the given `src` node with
// the given `replacement` node (assumes the :0 output of `replacement`).
Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output,
                                     Node* replacement) {
  VLOG(1) << "Replace usages of output " << src_output << " of node "
          << (VLOG_IS_ON(3) ? src->DebugString() : src->name()) << " with "
          << (VLOG_IS_ON(3) ? replacement->DebugString() : replacement->name());
  // Collect all usages of the specified src output (src->out_edges() iterator
  // will not be stable under modifications).
  struct OutEdgeInfo {
    int dst_node_id, dst_input;
  };
  std::vector<OutEdgeInfo> usages;
  for (const Edge* e : src->out_edges()) {
    if (e->IsControlEdge() || e->src_output() != src_output) {
      continue;
    }
    usages.push_back({e->dst()->id(), e->dst_input()});
  }

  // Now, replace each usage.
  for (int i = 0, end = usages.size(); i < end; i++) {
    // Make a copy of `usage_node`, and change its input to const node.
    Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
    VLOG(2) << " Replace usage by " << usage_node->DebugString();
    // Note: Replacement output index is presumed to be 0.
    TF_ASSIGN_OR_RETURN(
        Node * replace_node,
        ReplaceEdge(g, usage_node, usages[i].dst_input, replacement, 0));
    // Later entries in `usages` might have `usage_node` as dst node, but
    // `usage_node` is removed. Replace such entries with `replace_node`.
    for (int j = i + 1, end = usages.size(); j < end; j++) {
      if (usages[j].dst_node_id == usages[i].dst_node_id) {
        usages[j].dst_node_id = replace_node->id();
      }
    }
  }
  return OkStatus();
}

// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes.
  absl::flat_hash_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    VLOG(2) << "Replace usages of _Arg " << arg_index;
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    TF_ASSIGN_OR_RETURN(Node * const_node, g->AddNode(const_def));
    Node* arg_node = arg_nodes[arg_index];
    TF_RETURN_IF_ERROR(
        ReplaceSrcOutputUsageWithNode(g, arg_node, 0, const_node));
  }
  return OkStatus();
}

// Replaces the single input to _Retval nodes with an index in the keys of
// const_input_index_to_node with the single output of the corresponding _Arg
// node.
Status ReplaceRetvalInputWithArg(
    Graph* g,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node) {
  absl::flat_hash_map<int, Node*> arg_nodes;
  absl::flat_hash_map<int, Node*> ret_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsRetval() || n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      if (n->IsRetval()) {
        ret_nodes[index] = n;
      } else {
        arg_nodes[index] = n;
      }
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    VLOG(2) << "Bind _Retval " << arg_index << " to _Arg " << arg_index;
    TF_RETURN_IF_ERROR(
        ReplaceEdge(g, ret_nodes[arg_index], 0, arg_nodes[arg_index], 0)
            .status());
  }
  return OkStatus();
}

// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
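// If `passthrough_arg_to_retval` is true, the _Retval nodes for those same
// indices are also rewired to read directly from the corresponding _Arg nodes
// (used for While bodies, where the loop-invariant state passes through
// unchanged).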
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const absl::flat_hash_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld,
    bool passthrough_arg_to_retval = false) {
  VLOG(1) << "Propagate const into " << attr_name << " of node " << n->name();
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
  if (passthrough_arg_to_retval) {
    TF_RETURN_IF_ERROR(
        ReplaceRetvalInputWithArg(func_graph, const_input_index_to_node));
  }

  // Save rewritten function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(
      replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));

  VLOG(1) << "replace func " << func_attr.name() << " with " << new_func_name;
  // Change the node to use rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return OkStatus();
}

// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Notice that first input for If node is predicate; other inputs are
  // function inputs.
  absl::flat_hash_map<int, const Node*> const_input_index_to_node;
  for (int i = 1; i < if_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
    if (input_node->type_string() == "Const") {
      const_input_index_to_node[i - 1] = input_node;
    }
  }
  if (const_input_index_to_node.empty()) {
    return OkStatus();
  }

  // Rewrite "then_branch" and "else_branch" function, replace usage of those
  // _Arg nodes with corresponding const node.
  for (const auto& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }

  return OkStatus();
}

using GraphCache = absl::flat_hash_map<string, std::unique_ptr<FunctionBody>>;

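// Returns the FunctionBody for `body_attr` from `cache`, instantiating and
// caching it on first use. The body's FunctionDef is looked up in `lookup_fld`
// first and, if absent there, in `fallback_fld`.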
StatusOr<FunctionBody*> FindOrInsert(
    GraphCache* cache, const NameAttrList& body_attr,
    const FunctionLibraryDefinition* lookup_fld,
    const FunctionLibraryDefinition* fallback_fld) {
  const string name = body_attr.name();
  std::unique_ptr<FunctionBody>& value = (*cache)[name];
  if (!value) {
    const FunctionDef* body_func = lookup_fld->Find(name);
    if (!body_func && fallback_fld != nullptr) {
      body_func = fallback_fld->Find(name);
    }
    if (!body_func) {
      return errors::Internal("Traverse: Cannot find body function ", name);
    }
    std::unique_ptr<FunctionBody> fbody;
    Status s = FunctionDefToBodyHelper(
        *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody);
    if (!s.ok() && fallback_fld != nullptr) {
      TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
          *body_func, AttrSlice(&body_attr.attr()), fallback_fld, &fbody));
    }
    value = std::move(fbody);
  }
  return value.get();
}

// Determines whether a loop body is invariant for the given argument index.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld,
                               const FunctionLibraryDefinition* fallback_fld,
                               GraphCache* cache);

// Traces backward through non-modifying ops such as Identity and
// loop-invariant While, to find a preceding source edge.
StatusOr<const Edge*> TraverseUnmodifiedPathBackward(
    const Edge* src, const FunctionLibraryDefinition* lookup_fld,
    const FunctionLibraryDefinition* fallback_fld, GraphCache* cache) {
  const Edge* e = src;
  VLOG(2) << "Traverse: Begin at " << e->DebugString();
  // TODO(b/184727356): Also traverse If/Case nodes.
  // Begin walking back from the output node.
  while (IsConstTraversableOpType(e->src())) {
    VLOG(3) << e->DebugString();

    if (e->src()->IsWhileNode()) {
      NameAttrList body_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(e->src()->def(), "body", &body_attr));
      TF_ASSIGN_OR_RETURN(
          FunctionBody * fbody,
          FindOrInsert(cache, body_attr, lookup_fld, fallback_fld));
      TF_ASSIGN_OR_RETURN(bool is_loop_invariant,
                          IsLoopInvariant(fbody, e->src_output(), lookup_fld,
                                          fallback_fld, cache));
      if (!is_loop_invariant) {
        VLOG(2) << "Non-loop-invariant: index " << e->src_output() << " of "
                << body_attr.name();
        break;
      }
    }  // if While|StatelessWhile
    // Proceed backward to the src's input corresponding with the output index.
    TF_RETURN_IF_ERROR(e->src()->input_edge(e->src_output(), &e));
  }
  VLOG(2) << "Traverse: Finish at " << e->DebugString();

  return e;
}

// Determines whether a loop body is invariant for the given argument index.
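// Concretely, argument `index` is invariant if the body's `index`-th _Retval
// is fed, possibly through Identity and other const-traversable ops, by the
// body's `index`-th _Arg.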
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld,
                               const FunctionLibraryDefinition* fallback_fld,
                               GraphCache* cache) {
  const Edge* e;
  TF_RETURN_IF_ERROR(loop_body->ret_nodes[index]->input_edge(0, &e));
  TF_ASSIGN_OR_RETURN(
      const Edge* reachable,
      TraverseUnmodifiedPathBackward(e, lookup_fld, fallback_fld, cache));
  if (reachable->src()->id() == loop_body->arg_nodes[index]->id()) {
    VLOG(2) << "Index " << index << " is loop invariant.";
    return true;
  }
  VLOG(2) << "Index " << index << " not loop invariant: "
          << "walk backward from " << e->src()->DebugString() << " to "
          << reachable->src()->DebugString() << " did not reach "
          << loop_body->arg_nodes[index]->DebugString();
  return false;
}

// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs. Then,
// propagate these Const to consumers of the relevant outputs of the while loop.
Status PropagateConstIntoAndAroundWhileNode(
    Graph* g, Node* while_node, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  VLOG(1) << "Propagate const into " << while_node->name();

  // For "While" node, we should only replace _Arg nodes which are loop
  // invariants. For such _Arg nodes, the return value's input will come
  // directly from the corresponding arg.
  absl::flat_hash_map<int, const Node*> const_input_index_to_node;
  absl::flat_hash_map<int, Node*> const_input_index_to_mutable_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const string fn_name = body_attr.name();
  const FunctionDef* body_func = lookup_fld->Find(fn_name);
  if (!body_func) {
    return errors::Internal("Propagate: Cannot find body function ", fn_name,
                            " for While node ", while_node->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody));
  GraphCache cache;
  for (int i = 0; i < while_node->num_inputs(); i++) {
    // Check if i-th retval's input comes from i-th arg directly.
    // For resource variable inputs of While nodes, the TF2XLA convention is to
    // place them at the end of all inputs (after all data inputs), and *not*
    // return them. So the number of While node inputs might be larger than the
    // number of its outputs.
    if (i >= body_func->signature().output_arg_size()) {
      break;
    }

    const Edge* input_edge;
    TF_RETURN_IF_ERROR(while_node->input_edge(i, &input_edge));
    TF_ASSIGN_OR_RETURN(input_edge, TraverseUnmodifiedPathBackward(
                                        input_edge, lookup_fld, fld, &cache));
    if (!input_edge->src()->IsConstant()) {
      VLOG(2) << "Input " << i << " is not Const; is "
              << input_edge->src()->type_string();
      continue;
    }

    TF_ASSIGN_OR_RETURN(
        bool is_loop_invariant,
        IsLoopInvariant(fbody.get(), i, lookup_fld, fld, &cache));
    if (!is_loop_invariant) {
      VLOG(2) << "While state not loop-invariant; not propagating Const " << i;
      continue;
    }
    VLOG(2) << "While state is loop-invariant; propagating Const " << i;

    const_input_index_to_mutable_node[i] = input_edge->src();
    const_input_index_to_node[i] = input_edge->src();
  }
  if (const_input_index_to_node.empty()) {
    return OkStatus();
  }

  // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
  // corresponding const node.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld,
        /*passthrough_arg_to_retval=*/attr_name == "body"));
  }

  // Rewrite usages of the output edges corresponding to loop-invariant const
  // inputs to refer instead to the Const node.
  for (const auto& it : const_input_index_to_mutable_node) {
    TF_RETURN_IF_ERROR(
        ReplaceSrcOutputUsageWithNode(g, while_node, it.first, it.second));
  }
  return OkStatus();
}

}  // namespace

StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld) {
  GraphCache cache;
  return IsLoopInvariant(loop_body, index, lookup_fld,
                         /*fallback_fld=*/nullptr, &cache);
}

Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return OkStatus();
}

Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // Points to the Feed in `config`.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
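    // For example, a feed for tensor "x:1" gets the placeholder name
    // "aot_feed_1/x".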
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      TF_ASSIGN_OR_RETURN(Node * feed_node, g.AddNode(gd.node(0)));

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("Placeholder");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return OkStatus();
}

Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over, preserving order of original and only nodes that are reachable
  // from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return OkStatus();
}

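// For example, a TensorId with node_name "x" and output_index 1 is rendered as
// "x:1".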
string TensorIdToString(const tf2xla::TensorId& id) {
  return absl::StrCat(id.node_name(), ":", id.output_index());
}

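// Copies the device assignment of the best-matching neighbor onto `n`: among
// `n`'s data-edge neighbors (outputs if `out_edges` is true, inputs otherwise)
// that carry a MAXIMAL sharding annotation, the one with the lowest core
// number wins.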
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        std::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
            /*add_metadata=*/false));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return OkStatus();
}

void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
    if (constraint.name() == name) {
      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
    }
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by TF executor.
  // Another option is to use stateless random ops. They have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even after an
  // overflow. When seeded with zero, some XLA backends can return all zeros
  // instead of random numbers.
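  // (An odd value plus two stays odd modulo 2^32, so the counter can never
  // wrap around to zero.)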
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
  std::srand(seed);
  return std::rand() | 1;
}

// TODO(b/77601805): add tests for associated function related stuff.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld) {
  if (fld->Contains(node_def.op())) {
    return true;
  }

  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
    // Gradient op has "f" attr, which is set to the function we are getting
    // gradient for. We need to functionalize the gradient function.
    return true;
  }

  if (node_def.op() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
    return false;
  }

  for (const auto& iter : node_def.attr()) {
    if (iter.second.has_func()) {
      return true;
    }
  }

  return false;
}

std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld) {
  std::vector<AssociatedFunctionInfo> results;
  const string& op = node.type_string();
  if (fld->Contains(op)) {
    // This is a function call node.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
    // This is a SymbolicGradient op.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
  } else if (node.type_string() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
  } else {
    // Collect all function attrs for the node.
    for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
  }
  return results;
}

Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      for (const auto& attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      TF_ASSIGN_OR_RETURN(Node * new_node, graph->AddNode(node_def));
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change function attr to rewritten functions.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      // Save the original function name in case TFRT falls back to using the
      // TPUPartitionedCall op at runtime.
      if (node->type_string() == "TPUPartitionedCall") {
        node->AddAttr("_orig_f", func.name());
      }
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return OkStatus();
}

Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
  string canonicalized_name = Canonicalize(func_name, attrs);
  auto iter = handles_.find(canonicalized_name);
  if (iter != handles_.end()) {
    *handle = iter->second;
    return OkStatus();
  }

  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[canonicalized_name] = *handle;
  return OkStatus();
}

Status CachedFunctionHandles::ReleaseAllHandles() {
  Status result;
  for (const auto& iter : handles_) {
    result.Update(flr_->ReleaseHandle(iter.second));
  }
  handles_.clear();
  return result;
}

StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  TF_ASSIGN_OR_RETURN(Node * new_node, g->AddNode(node_def));

  // Record original node's output edges and remove them first. This is to avoid
  // multiple producers for dst nodes' input.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}

StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
                                  DataType dtype, const Node* input,
                                  std::optional<string> requested_device) {
  // Create identity node.
  NodeDef ndef;
  ndef.set_name(node_name);
  ndef.set_op("Identity");
  if (input) {
    ndef.add_input(input->name());
  }
  if (requested_device) {
    ndef.set_device(*requested_device);
  }
  AddNodeAttr("T", dtype, &ndef);
  TF_ASSIGN_OR_RETURN(Node * id_node, graph->AddNode(ndef));
  return id_node;
}

Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  absl::flat_hash_set<int> done_node_ids;

  // Because we may propagate Const around a while node as well as into it,
  // we restart the op_nodes() iterator after each pass and keep track of which
  // nodes we've already dealt with.
  bool should_continue = true;
  while (should_continue) {
    should_continue = false;
    for (Node* n : g->op_nodes()) {
      if (!done_node_ids.contains(n->id())) {
        if (n->IsIfNode()) {
          VLOG(1) << "PropagateConstIntoIfNode: " << n->name();
          TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
          done_node_ids.emplace(n->id());
          VLOG(1) << "Done PropagateConstIntoIfNode: " << n->name();
        } else if (n->IsWhileNode()) {
          VLOG(1) << "PropagateConstIntoWhileNode: " << n->name();
          TF_RETURN_IF_ERROR(
              PropagateConstIntoAndAroundWhileNode(g, n, lookup_fld, fld));
          done_node_ids.emplace(n->id());
          should_continue = true;
          VLOG(1) << "Done PropagateConstIntoWhileNode: " << n->name();
          break;
        }
      }
    }
  }
  return OkStatus();
}

Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef graph_def;
  g.ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_functions =
      fld->ReachableDefinitions(graph_def);
  for (const string& func_name : fld->ListFunctionNames()) {
    if (!reachable_functions.Find(func_name)) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
    }
  }
  return OkStatus();
}

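// Pattern targeted below: an EmptyTensorList feeds a forward While whose body
// pushes a Const via TensorListPushBack; the paired backward While's
// TensorListPopBack results are then rewritten to read that Const directly.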
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }

    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into forward While body function and check if TensorListPushBack op
    // has a Const input.
    NameAttrList fwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack.
      continue;
    }

    // Get input for the TensorListPushBack node.
    Node* input_node;
    TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // Input for the TensorList is not Const node.
      continue;
    }

    NodeDef const_input_nodedef = input_node->def();

    // Rewrite backward While body function, replace usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack.
      continue;
    }

    // Replace TensorListPopBack usages with Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    TF_ASSIGN_OR_RETURN(Node * const_node,
                        bwd_fbody->graph->AddNode(const_input_nodedef));
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return OkStatus();
}

}  // namespace tensorflow