1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40
41 namespace tensorflow {
42 namespace {
43
44 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
45 // starting at nodes in vector `stack`.
46 // `node_map` is a vector indexed by source node ID to dest nodes.
47 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
48 // before the traversal clients can cut the graph. If a frame is provided (frame
49 // != nullptr), then this functions will return an error if the
50 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
51 // cut the graph and prevent the traversal from escaping.
52 //
53 // `squash_src_outputs` contains a bool for each source node ID. If true, then
54 // the source output on that node will be replaced by zero when copied. This is
55 // used when replacing a Switch node with an _Arg node. The output we are
56 // taking from the Switch node was not necessarily the first output, but _Arg
57 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
58 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)59 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
60 std::vector<Node*> stack,
61 const std::vector<bool>& squash_src_outputs,
62 std::vector<Node*>* node_map, Graph* output) {
63 VLOG(3) << "Stack: " << NodesToString(stack);
64 std::vector<bool> visited(graph.num_node_ids(), false);
65 while (!stack.empty()) {
66 Node* n = stack.back();
67 stack.pop_back();
68
69 VLOG(5) << "Copying node " << n->name();
70
71 if (visited[n->id()]) continue;
72 visited[n->id()] = true;
73
74 // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
75 // order.
76 std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
77 n->in_edges().end());
78 std::sort(sorted_edges.begin(), sorted_edges.end(),
79 [](const Edge* a, const Edge* b) {
80 int a_src_output = a->src_output(),
81 b_src_output = b->src_output();
82 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
83 return std::tie(a_src_output, a_name) <
84 std::tie(b_src_output, b_name);
85 });
86 for (const Edge* e : sorted_edges) {
87 Node* src = e->src();
88 if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
89 // We traversed out of the loop frame, without encountering a cut node.
90 return errors::Internal("Graph traversal of loop frame ", frame->name,
91 " escaped frame at ", src->name(),
92 " without encountering an argument node.");
93 }
94 if ((*node_map)[src->id()] == nullptr) {
95 (*node_map)[src->id()] = output->CopyNode(src);
96 stack.push_back(src);
97 }
98 Node* src_copy = (*node_map)[e->src()->id()];
99 int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
100 ? 0
101 : e->src_output();
102 Node* dst_copy = (*node_map)[e->dst()->id()];
103 output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
104 }
105 }
106 return OkStatus();
107 }
108
BuildArgNode(Graph * graph,DataType type,int index)109 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
110 const char* const kArgOp = "_Arg";
111 NodeDef arg_def;
112 NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
113 builder.Attr("T", type);
114 builder.Attr("index", index);
115 TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
116 return graph->AddNode(arg_def);
117 }
118
119 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)120 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
121 std::unique_ptr<Graph>* cond_output) {
122 VLOG(2) << "Building loop condition for " << frame->name;
123 *cond_output = std::make_unique<Graph>(graph.op_registry());
124 Graph* output = cond_output->get();
125
126 // Map from nodes in the original graph to the condition graph.
127 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
128 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
129
130 // Build one _Arg node for each Enter node.
131 for (int i = 0, end = frame->args.size(); i < end; ++i) {
132 const WhileLoopArg& arg = frame->args[i];
133
134 TF_ASSIGN_OR_RETURN(Node * arg_node,
135 BuildArgNode(output, arg.enter->input_type(0), i));
136 if (arg.is_loop_invariant) {
137 node_map[arg.enter->id()] = arg_node;
138 } else {
139 node_map[arg.merge->id()] = arg_node;
140 }
141 }
142
143 // Build a Retval node for the loop condition. The LoopCond nodes are always
144 // boolean because of the type constraints on the LoopCond op.
145 TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
146 BuildRetvalNode(output, DT_BOOL, 0));
147
148 // Performs a reverse DFS, copying nodes and edges to the output graph.
149 // The _Arg and _Retval nodes were added unconditionally above, so we are
150 // guaranteed to get the correct function signature.
151 return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
152 &node_map, output);
153 }
154
155 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)156 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
157 DataTypeVector* arg_types,
158 std::unique_ptr<Graph>* body_output) {
159 VLOG(2) << "Building loop body for " << frame->name;
160 *body_output = std::make_unique<Graph>(graph.op_registry());
161 Graph* output = body_output->get();
162
163 // Map from nodes in the original graph to the body graph.
164 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
165 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
166
167 // Build one _Arg node for each Enter node.
168 std::vector<Node*> next_iterations;
169 next_iterations.reserve(frame->args.size());
170 arg_types->reserve(frame->args.size());
171 for (int i = 0, end = frame->args.size(); i < end; ++i) {
172 const WhileLoopArg& arg = frame->args[i];
173
174 DataType dtype = arg.enter->input_type(0);
175 arg_types->push_back(dtype);
176
177 TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
178 TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
179 if (arg.is_loop_invariant) {
180 // Argument is loop-invariant. Forward it from the Arg to the Retval.
181 node_map[arg.enter->id()] = arg_node;
182 output->AddEdge(arg_node, 0, retval_node, 0);
183 } else {
184 // Argument is loop-varying.
185 if (dtype == DT_RESOURCE) {
186 // DT_RESOURCE arguments should always be loop-invariant in the graphs
187 // generated from TF.
188 return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
189 arg.enter->name(), " is currently not",
190 " supported.");
191 }
192 node_map[arg.switch_node->id()] = arg_node;
193 // The Switch node has two outputs, but _Arg only has one. This tells
194 // the CopySubgraph function to rewrite the output number of edges from
195 // the _Arg node to be 0 rather than copying the output number from the
196 // Switch node.
197 squash_src_outputs[arg.switch_node->id()] = true;
198 node_map[arg.next_iteration->id()] = retval_node;
199 next_iterations.push_back(arg.next_iteration);
200 }
201 }
202
203 // Performs a reverse DFS, copying nodes and edges to the output graph.
204 // The _Arg and _Retval nodes were added unconditionally above, so we are
205 // guaranteed to get the correct function signature.
206 TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
207 squash_src_outputs, &node_map, output));
208
209 return OkStatus();
210 }
211
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)212 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
213 FunctionLibraryDefinition* library,
214 const NodeFilter& node_filter) {
215 if (node_filter && !frame->should_be_functionalized) {
216 VLOG(2) << "Skipping functionalization for frame " << frame->name
217 << " because it has control flow nodes that are filtered out by "
218 "the specified node filter.";
219 return OkStatus();
220 }
221 VLOG(2) << "Frame " << frame->name << " before: "
222 << DumpGraphToFile("functionalize_before", *graph, library);
223
224 // Split loop-varying Enter nodes with multiple successors. If the same
225 // Tensor is fed as input to multiple loop arguments, we may end up with a
226 // shared Enter node. We clone Enter nodes with multiple successors to
227 // maintain the invariant of a unique Enter node per argument of the final
228 // loop.
229 std::vector<WhileLoopArg> args;
230 args.reserve(frame->args.size());
231 for (const WhileLoopArg& arg : frame->args) {
232 if (arg.is_loop_invariant) {
233 args.push_back(arg);
234 } else {
235 std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
236 arg.enter->out_edges().end());
237 for (int i = 0, end = edges.size(); i < end; ++i) {
238 if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
239 continue;
240 }
241 TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
242 WhileLoopArg new_arg;
243 new_arg.is_loop_invariant = false;
244 if (i == 0) {
245 new_arg.enter = arg.enter;
246 } else {
247 new_arg.enter = graph->CopyNode(arg.enter);
248 frame->nodes.insert(new_arg.enter);
249 for (Edge const* e : arg.enter->in_edges()) {
250 graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
251 e->IsControlEdge() ? Graph::kControlSlot : 0);
252 }
253 Node* dst = edges[i]->dst();
254 int dst_input = edges[i]->dst_input();
255 graph->RemoveEdge(edges[i]);
256 graph->AddEdge(new_arg.enter, 0, dst, dst_input);
257 }
258 args.push_back(new_arg);
259 }
260 }
261 }
262 frame->args = std::move(args);
263
264 std::sort(frame->args.begin(), frame->args.end(),
265 [](const WhileLoopArg& a, const WhileLoopArg& b) {
266 return NodeCmpByNameResourcesLast()(a.enter, b.enter);
267 });
268
269 if (frame->loop_cond == nullptr) {
270 return errors::InvalidArgument("Loop ", frame->name,
271 " has no LoopCond node");
272 }
273
274 // Find the set of Switch nodes that are successors of the LoopCond.
275 std::unordered_set<Node*> switches;
276 for (const Edge* edge : frame->loop_cond->out_edges()) {
277 if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
278 edge->dst_input() == 1) {
279 switches.insert(edge->dst());
280 }
281 }
282
283 // For each non-constant argument, looks for the following pattern of nodes:
284 // Enter ----> Merge --------> Switch --> Exit
285 // ^ ^
286 // | |
287 // NextIteration LoopCond
288 // ^ ^
289 // | |
290 // ... ...
291 for (WhileLoopArg& arg : frame->args) {
292 if (!arg.is_loop_invariant) {
293 // Follow the edge from the Enter to Merge.
294 const Edge* enter_merge = nullptr;
295 for (const Edge* e : arg.enter->out_edges()) {
296 // Ignore control-edges to the sink node. These are allowed by the
297 // graph invariants, although probably they should have been stripped
298 // off earlier.
299 if (e->IsControlEdge() && e->dst()->IsSink()) {
300 continue;
301 }
302 if (enter_merge != nullptr) {
303 return errors::Internal("Enter node for loop-varying argument ",
304 FormatNodeForError(*arg.enter),
305 " has multiple successors: ",
306 FormatNodeForError(*enter_merge->dst()),
307 " and ", FormatNodeForError(*e->dst()));
308 }
309 enter_merge = e;
310 }
311 if (enter_merge == nullptr) {
312 return errors::Internal("Enter node for loop-varying argument ",
313 FormatNodeForError(*arg.enter),
314 " has zero successors");
315 }
316 arg.merge = enter_merge->dst();
317 if (!IsMerge(arg.merge)) {
318 return errors::InvalidArgument(
319 "Successor of Enter node for loop-varying argument ",
320 FormatNodeForError(*arg.merge),
321 " is not a Merge node; got: ", arg.merge->type_string());
322 }
323
324 // Find the NextIteration from the merge. There should be two inputs to
325 // the Merge and the NextIteration should be the other input.
326 if (arg.merge->input_types().size() != 2) {
327 return errors::InvalidArgument(
328 "Unexpected number of inputs to Merge node for loop-varying "
329 "argument ",
330 FormatNodeForError(*arg.merge), "; expected 2, got ",
331 arg.merge->input_types().size());
332 }
333 TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
334 &arg.next_iteration));
335 if (!IsNextIteration(arg.next_iteration)) {
336 return errors::InvalidArgument(
337 "Expected NextIteration node as input to Merge node; got node ",
338 FormatNodeForError(*arg.next_iteration), " with kind ",
339 arg.next_iteration->type_string());
340 }
341
342 // Find the Switch successor of the Merge. There should be exactly one
343 // Switch node that is a successor of both the Merge and the LoopCond.
344 for (const Edge* edge : arg.merge->out_edges()) {
345 if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
346 switches.find(edge->dst()) != switches.end()) {
347 if (arg.switch_node != nullptr) {
348 return errors::InvalidArgument("Duplicate Switch successors to ",
349 FormatNodeForError(*arg.merge));
350 }
351 arg.switch_node = edge->dst();
352 }
353 }
354 if (arg.switch_node == nullptr) {
355 return errors::InvalidArgument("Missing Switch successor to ",
356 FormatNodeForError(*arg.merge));
357 }
358 // Loop over the switch node's output to:
359 // - Find the Exit successor.
360 // - Set the sharding on all Identity outputs of the switch. These
361 // identity nodes are values used by the loop body or condition.
362 // The Identity node may have the wrong device so copy the device from
363 // one of its outputs instead.
364 std::deque<const Edge*> possible_exit;
365 for (const Edge* edge : arg.switch_node->out_edges()) {
366 if (edge->src_output() == 0) {
367 possible_exit.push_back(edge);
368 }
369 if (IsIdentity(edge->dst())) {
370 TF_RETURN_IF_ERROR(
371 SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
372 }
373 }
374 // TODO(b/67425339): Allow general graph between switch and exit.
375 while (!possible_exit.empty()) {
376 const Edge* edge = possible_exit.front();
377 possible_exit.pop_front();
378 if (IsExit(edge->dst())) {
379 if (arg.exit != nullptr) {
380 return errors::InvalidArgument(
381 "Duplicate Exit successors to ",
382 FormatNodeForError(*arg.switch_node));
383 }
384 arg.exit = edge->dst();
385 } else {
386 if (!IsIdentity(edge->dst())) {
387 return errors::Unimplemented("General graph between switch (",
388 FormatNodeForError(*arg.switch_node),
389 ") and exit node of frame ",
390 frame->name, " not supported yet.");
391 }
392 for (const Edge* out : edge->dst()->out_edges()) {
393 possible_exit.push_back(out);
394 }
395 }
396 }
397 }
398 }
399
400 // Builds the condition and body functions. Notice that we call
401 // FunctionalizeCond() on cond_graph and body_graph because we might have
402 // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
403 // before they are encapsulated in FunctionDef.
404 std::unique_ptr<Graph> cond_graph;
405 TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
406 FixupSourceAndSinkEdges(cond_graph.get());
407 TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
408 DataTypeVector arg_types;
409 std::unique_ptr<Graph> body_graph;
410 TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
411 FixupSourceAndSinkEdges(body_graph.get());
412 TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
413
414 VLOG(2) << "Frame " << frame->name << " condition: "
415 << DumpGraphToFile("loop_condition", *cond_graph, library)
416 << " body: " << DumpGraphToFile("loop_body", *body_graph);
417
418 NameAttrList cond_name;
419 cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
420 NameAttrList body_name;
421 body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
422 FunctionDef cond_fdef;
423 TF_RETURN_IF_ERROR(
424 GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
425 FunctionDef body_fdef;
426 TF_RETURN_IF_ERROR(
427 GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
428
429 TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
430 TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
431
432 // Builds a While operator.
433 NodeDef while_def;
434 NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
435 builder.Attr("T", arg_types);
436 builder.Attr("cond", cond_name);
437 builder.Attr("body", body_name);
438 // Add some internal attributes which need to be propagated.
439 for (absl::string_view attr_name : kAttrsToPropagate) {
440 string attr_val;
441 if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
442 builder.Attr(attr_name, attr_val);
443 }
444 }
445 std::vector<NodeDefBuilder::NodeOut> inputs;
446 for (int i = 0, end = frame->args.size(); i < end; ++i) {
447 const WhileLoopArg& arg = frame->args[i];
448 const Edge* in_edge;
449 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
450 if (in_edge->IsControlEdge()) {
451 builder.ControlInput(in_edge->src()->name());
452 } else {
453 inputs.push_back(NodeDefBuilder::NodeOut(
454 in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
455 }
456 }
457 builder.Input(inputs);
458 TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
459 TF_ASSIGN_OR_RETURN(Node * while_node, graph->AddNode(while_def));
460
461 // Copies edges to the Enter nodes and from the Exit nodes onto the While.
462 for (int i = 0, end = frame->args.size(); i < end; ++i) {
463 const WhileLoopArg& arg = frame->args[i];
464 const Edge* in_edge;
465 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
466 if (in_edge->IsControlEdge()) {
467 graph->AddControlEdge(in_edge->src(), while_node);
468 } else {
469 graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
470 }
471
472 if (!arg.is_loop_invariant) {
473 // Add output edges if the output of the loop is consumed.
474 if (arg.exit != nullptr) {
475 std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
476 arg.exit->out_edges().end());
477 for (const Edge* edge : edges) {
478 Node* dst = edge->dst();
479 int dst_input = edge->dst_input();
480 graph->RemoveEdge(edge);
481
482 if (dst_input == Graph::kControlSlot) {
483 graph->AddControlEdge(while_node, dst);
484 } else {
485 graph->AddEdge(while_node, i, dst, dst_input);
486 }
487 }
488 }
489 }
490 }
491
492 // Remove the old nodes from the graph, and add the while node to the parent
493 // frame.
494 for (Node* node : frame->nodes) {
495 VLOG(2) << "Removing obsolete node " << node->name();
496 graph->RemoveNode(node);
497 }
498 frame->nodes.clear();
499 frame->parent->nodes.insert(while_node);
500
501 VLOG(2) << "Frame " << frame->name << " after: "
502 << DumpGraphToFile("functionalize_after", *graph, library);
503
504 return OkStatus();
505 }
506 } // namespace
507
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)508 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
509 const NodeFilter& node_filter) {
510 // Note: BuildControlFlowInfo() requires that the graph's source node is
511 // connected to all source nodes in the graph. Many graphs violate this
512 // invariant.
513 std::vector<ControlFlowInfo> cf_info;
514 std::vector<string> unreachable_nodes;
515 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
516 if (!unreachable_nodes.empty()) {
517 return errors::InvalidArgument(
518 "The following nodes are unreachable from the source in the graph: ",
519 errors::FormatNodeNamesForError(unreachable_nodes));
520 }
521
522 // Builds Frames, indexed by name.
523 std::unordered_map<string, WhileLoopFrame> frames;
524 TF_RETURN_IF_ERROR(
525 ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
526
527 // Adds frames with no children (i.e., the innermost frames) to a worklist.
528 std::deque<WhileLoopFrame*> worklist;
529 for (auto& frame : frames) {
530 if (frame.second.num_children == 0) {
531 worklist.push_back(&frame.second);
532 }
533 }
534
535 // Eliminate loops from innermost to outermost. Note that the precondition for
536 // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
537 // works.
538 while (!worklist.empty()) {
539 WhileLoopFrame* frame = worklist.front();
540 worklist.pop_front();
541 if (frame->parent == frame) {
542 // Skip the root frame.
543 continue;
544 }
545
546 TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
547
548 // If the parent has no remaining children, add it to the worklist.
549 --frame->parent->num_children;
550 if (frame->parent->num_children == 0) {
551 worklist.push_back(frame->parent);
552 }
553 }
554
555 if (!node_filter) {
556 // There should be no cycle at this point, since while loops have been
557 // removed from graph. Check that the newly added While nodes don't feed
558 // into themselves.
559 for (const Node* node : graph->op_nodes()) {
560 if (node->def().op() == "While") {
561 TF_RETURN_WITH_CONTEXT_IF_ERROR(
562 CheckNodeNotInCycle(node, graph->num_node_ids()),
563 "Functionalizing loop failed.");
564 }
565 }
566 }
567
568 return OkStatus();
569 }
570
571 } // namespace tensorflow
572