xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/functionalize_while.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
45 // starting at nodes in vector `stack`.
46 // `node_map` is a vector indexed by source node ID to dest nodes.
47 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
48 // before the traversal clients can cut the graph. If a frame is provided (frame
49 // != nullptr), then this functions will return an error if the
50 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
51 // cut the graph and prevent the traversal from escaping.
52 //
53 // `squash_src_outputs` contains a bool for each source node ID. If true, then
54 // the source output on that node will be replaced by zero when copied. This is
55 // used when replacing a Switch node with an _Arg node. The output we are
56 // taking from the Switch node was not necessarily the first output, but _Arg
57 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
58 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)59 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
60                     std::vector<Node*> stack,
61                     const std::vector<bool>& squash_src_outputs,
62                     std::vector<Node*>* node_map, Graph* output) {
63   VLOG(3) << "Stack: " << NodesToString(stack);
64   std::vector<bool> visited(graph.num_node_ids(), false);
65   while (!stack.empty()) {
66     Node* n = stack.back();
67     stack.pop_back();
68 
69     VLOG(5) << "Copying node " << n->name();
70 
71     if (visited[n->id()]) continue;
72     visited[n->id()] = true;
73 
74     // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
75     // order.
76     std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
77                                           n->in_edges().end());
78     std::sort(sorted_edges.begin(), sorted_edges.end(),
79               [](const Edge* a, const Edge* b) {
80                 int a_src_output = a->src_output(),
81                     b_src_output = b->src_output();
82                 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
83                 return std::tie(a_src_output, a_name) <
84                        std::tie(b_src_output, b_name);
85               });
86     for (const Edge* e : sorted_edges) {
87       Node* src = e->src();
88       if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
89         // We traversed out of the loop frame, without encountering a cut node.
90         return errors::Internal("Graph traversal of loop frame ", frame->name,
91                                 " escaped frame at ", src->name(),
92                                 " without encountering an argument node.");
93       }
94       if ((*node_map)[src->id()] == nullptr) {
95         (*node_map)[src->id()] = output->CopyNode(src);
96         stack.push_back(src);
97       }
98       Node* src_copy = (*node_map)[e->src()->id()];
99       int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
100                            ? 0
101                            : e->src_output();
102       Node* dst_copy = (*node_map)[e->dst()->id()];
103       output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
104     }
105   }
106   return OkStatus();
107 }
108 
BuildArgNode(Graph * graph,DataType type,int index)109 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
110   const char* const kArgOp = "_Arg";
111   NodeDef arg_def;
112   NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
113   builder.Attr("T", type);
114   builder.Attr("index", index);
115   TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
116   return graph->AddNode(arg_def);
117 }
118 
119 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)120 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
121                           std::unique_ptr<Graph>* cond_output) {
122   VLOG(2) << "Building loop condition for " << frame->name;
123   *cond_output = std::make_unique<Graph>(graph.op_registry());
124   Graph* output = cond_output->get();
125 
126   // Map from nodes in the original graph to the condition graph.
127   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
128   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
129 
130   // Build one _Arg node for each Enter node.
131   for (int i = 0, end = frame->args.size(); i < end; ++i) {
132     const WhileLoopArg& arg = frame->args[i];
133 
134     TF_ASSIGN_OR_RETURN(Node * arg_node,
135                         BuildArgNode(output, arg.enter->input_type(0), i));
136     if (arg.is_loop_invariant) {
137       node_map[arg.enter->id()] = arg_node;
138     } else {
139       node_map[arg.merge->id()] = arg_node;
140     }
141   }
142 
143   // Build a Retval node for the loop condition. The LoopCond nodes are always
144   // boolean because of the type constraints on the LoopCond op.
145   TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
146                       BuildRetvalNode(output, DT_BOOL, 0));
147 
148   // Performs a reverse DFS, copying nodes and edges to the output graph.
149   // The _Arg and _Retval nodes were added unconditionally above, so we are
150   // guaranteed to get the correct function signature.
151   return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
152                       &node_map, output);
153 }
154 
155 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)156 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
157                      DataTypeVector* arg_types,
158                      std::unique_ptr<Graph>* body_output) {
159   VLOG(2) << "Building loop body for " << frame->name;
160   *body_output = std::make_unique<Graph>(graph.op_registry());
161   Graph* output = body_output->get();
162 
163   // Map from nodes in the original graph to the body graph.
164   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
165   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
166 
167   // Build one _Arg node for each Enter node.
168   std::vector<Node*> next_iterations;
169   next_iterations.reserve(frame->args.size());
170   arg_types->reserve(frame->args.size());
171   for (int i = 0, end = frame->args.size(); i < end; ++i) {
172     const WhileLoopArg& arg = frame->args[i];
173 
174     DataType dtype = arg.enter->input_type(0);
175     arg_types->push_back(dtype);
176 
177     TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
178     TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
179     if (arg.is_loop_invariant) {
180       // Argument is loop-invariant. Forward it from the Arg to the Retval.
181       node_map[arg.enter->id()] = arg_node;
182       output->AddEdge(arg_node, 0, retval_node, 0);
183     } else {
184       // Argument is loop-varying.
185       if (dtype == DT_RESOURCE) {
186         // DT_RESOURCE arguments should always be loop-invariant in the graphs
187         // generated from TF.
188         return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
189                                      arg.enter->name(), " is currently not",
190                                      " supported.");
191       }
192       node_map[arg.switch_node->id()] = arg_node;
193       // The Switch node has two outputs, but _Arg only has one. This tells
194       // the CopySubgraph function to rewrite the output number of edges from
195       // the _Arg node to be 0 rather than copying the output number from the
196       // Switch node.
197       squash_src_outputs[arg.switch_node->id()] = true;
198       node_map[arg.next_iteration->id()] = retval_node;
199       next_iterations.push_back(arg.next_iteration);
200     }
201   }
202 
203   // Performs a reverse DFS, copying nodes and edges to the output graph.
204   // The _Arg and _Retval nodes were added unconditionally above, so we are
205   // guaranteed to get the correct function signature.
206   TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
207                                   squash_src_outputs, &node_map, output));
208 
209   return OkStatus();
210 }
211 
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)212 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
213                          FunctionLibraryDefinition* library,
214                          const NodeFilter& node_filter) {
215   if (node_filter && !frame->should_be_functionalized) {
216     VLOG(2) << "Skipping functionalization for frame " << frame->name
217             << " because it has control flow nodes that are filtered out by "
218                "the specified node filter.";
219     return OkStatus();
220   }
221   VLOG(2) << "Frame " << frame->name << " before: "
222           << DumpGraphToFile("functionalize_before", *graph, library);
223 
224   // Split loop-varying Enter nodes with multiple successors. If the same
225   // Tensor is fed as input to multiple loop arguments, we may end up with a
226   // shared Enter node. We clone Enter nodes with multiple successors to
227   // maintain the invariant of a unique Enter node per argument of the final
228   // loop.
229   std::vector<WhileLoopArg> args;
230   args.reserve(frame->args.size());
231   for (const WhileLoopArg& arg : frame->args) {
232     if (arg.is_loop_invariant) {
233       args.push_back(arg);
234     } else {
235       std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
236                                      arg.enter->out_edges().end());
237       for (int i = 0, end = edges.size(); i < end; ++i) {
238         if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
239           continue;
240         }
241         TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
242         WhileLoopArg new_arg;
243         new_arg.is_loop_invariant = false;
244         if (i == 0) {
245           new_arg.enter = arg.enter;
246         } else {
247           new_arg.enter = graph->CopyNode(arg.enter);
248           frame->nodes.insert(new_arg.enter);
249           for (Edge const* e : arg.enter->in_edges()) {
250             graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
251                            e->IsControlEdge() ? Graph::kControlSlot : 0);
252           }
253           Node* dst = edges[i]->dst();
254           int dst_input = edges[i]->dst_input();
255           graph->RemoveEdge(edges[i]);
256           graph->AddEdge(new_arg.enter, 0, dst, dst_input);
257         }
258         args.push_back(new_arg);
259       }
260     }
261   }
262   frame->args = std::move(args);
263 
264   std::sort(frame->args.begin(), frame->args.end(),
265             [](const WhileLoopArg& a, const WhileLoopArg& b) {
266               return NodeCmpByNameResourcesLast()(a.enter, b.enter);
267             });
268 
269   if (frame->loop_cond == nullptr) {
270     return errors::InvalidArgument("Loop ", frame->name,
271                                    " has no LoopCond node");
272   }
273 
274   // Find the set of Switch nodes that are successors of the LoopCond.
275   std::unordered_set<Node*> switches;
276   for (const Edge* edge : frame->loop_cond->out_edges()) {
277     if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
278         edge->dst_input() == 1) {
279       switches.insert(edge->dst());
280     }
281   }
282 
283   // For each non-constant argument, looks for the following pattern of nodes:
284   // Enter ----> Merge  -------->  Switch  --> Exit
285   //               ^                  ^
286   //               |                  |
287   //         NextIteration         LoopCond
288   //               ^                  ^
289   //               |                  |
290   //              ...                ...
291   for (WhileLoopArg& arg : frame->args) {
292     if (!arg.is_loop_invariant) {
293       // Follow the edge from the Enter to Merge.
294       const Edge* enter_merge = nullptr;
295       for (const Edge* e : arg.enter->out_edges()) {
296         // Ignore control-edges to the sink node. These are allowed by the
297         // graph invariants, although probably they should have been stripped
298         // off earlier.
299         if (e->IsControlEdge() && e->dst()->IsSink()) {
300           continue;
301         }
302         if (enter_merge != nullptr) {
303           return errors::Internal("Enter node for loop-varying argument ",
304                                   FormatNodeForError(*arg.enter),
305                                   " has multiple successors: ",
306                                   FormatNodeForError(*enter_merge->dst()),
307                                   " and ", FormatNodeForError(*e->dst()));
308         }
309         enter_merge = e;
310       }
311       if (enter_merge == nullptr) {
312         return errors::Internal("Enter node for loop-varying argument ",
313                                 FormatNodeForError(*arg.enter),
314                                 " has zero successors");
315       }
316       arg.merge = enter_merge->dst();
317       if (!IsMerge(arg.merge)) {
318         return errors::InvalidArgument(
319             "Successor of Enter node for loop-varying argument ",
320             FormatNodeForError(*arg.merge),
321             " is not a Merge node; got: ", arg.merge->type_string());
322       }
323 
324       // Find the NextIteration from the merge. There should be two inputs to
325       // the Merge and the NextIteration should be the other input.
326       if (arg.merge->input_types().size() != 2) {
327         return errors::InvalidArgument(
328             "Unexpected number of inputs to Merge node for loop-varying "
329             "argument ",
330             FormatNodeForError(*arg.merge), "; expected 2, got ",
331             arg.merge->input_types().size());
332       }
333       TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
334                                                &arg.next_iteration));
335       if (!IsNextIteration(arg.next_iteration)) {
336         return errors::InvalidArgument(
337             "Expected NextIteration node as input to Merge node; got node ",
338             FormatNodeForError(*arg.next_iteration), " with kind ",
339             arg.next_iteration->type_string());
340       }
341 
342       // Find the Switch successor of the Merge. There should be exactly one
343       // Switch node that is a successor of both the Merge and the LoopCond.
344       for (const Edge* edge : arg.merge->out_edges()) {
345         if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
346             switches.find(edge->dst()) != switches.end()) {
347           if (arg.switch_node != nullptr) {
348             return errors::InvalidArgument("Duplicate Switch successors to ",
349                                            FormatNodeForError(*arg.merge));
350           }
351           arg.switch_node = edge->dst();
352         }
353       }
354       if (arg.switch_node == nullptr) {
355         return errors::InvalidArgument("Missing Switch successor to ",
356                                        FormatNodeForError(*arg.merge));
357       }
358       // Loop over the switch node's output to:
359       // - Find the Exit successor.
360       // - Set the sharding on all Identity outputs of the switch. These
361       //   identity nodes are values used by the loop body or condition.
362       //   The Identity node may have the wrong device so copy the device from
363       //   one of its outputs instead.
364       std::deque<const Edge*> possible_exit;
365       for (const Edge* edge : arg.switch_node->out_edges()) {
366         if (edge->src_output() == 0) {
367           possible_exit.push_back(edge);
368         }
369         if (IsIdentity(edge->dst())) {
370           TF_RETURN_IF_ERROR(
371               SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
372         }
373       }
374       // TODO(b/67425339): Allow general graph between switch and exit.
375       while (!possible_exit.empty()) {
376         const Edge* edge = possible_exit.front();
377         possible_exit.pop_front();
378         if (IsExit(edge->dst())) {
379           if (arg.exit != nullptr) {
380             return errors::InvalidArgument(
381                 "Duplicate Exit successors to ",
382                 FormatNodeForError(*arg.switch_node));
383           }
384           arg.exit = edge->dst();
385         } else {
386           if (!IsIdentity(edge->dst())) {
387             return errors::Unimplemented("General graph between switch (",
388                                          FormatNodeForError(*arg.switch_node),
389                                          ") and exit node of frame ",
390                                          frame->name, " not supported yet.");
391           }
392           for (const Edge* out : edge->dst()->out_edges()) {
393             possible_exit.push_back(out);
394           }
395         }
396       }
397     }
398   }
399 
400   // Builds the condition and body functions. Notice that we call
401   // FunctionalizeCond() on cond_graph and body_graph because we might have
402   // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
403   // before they are encapsulated in FunctionDef.
404   std::unique_ptr<Graph> cond_graph;
405   TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
406   FixupSourceAndSinkEdges(cond_graph.get());
407   TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
408   DataTypeVector arg_types;
409   std::unique_ptr<Graph> body_graph;
410   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
411   FixupSourceAndSinkEdges(body_graph.get());
412   TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
413 
414   VLOG(2) << "Frame " << frame->name << " condition: "
415           << DumpGraphToFile("loop_condition", *cond_graph, library)
416           << " body: " << DumpGraphToFile("loop_body", *body_graph);
417 
418   NameAttrList cond_name;
419   cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
420   NameAttrList body_name;
421   body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
422   FunctionDef cond_fdef;
423   TF_RETURN_IF_ERROR(
424       GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
425   FunctionDef body_fdef;
426   TF_RETURN_IF_ERROR(
427       GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
428 
429   TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
430   TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
431 
432   // Builds a While operator.
433   NodeDef while_def;
434   NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
435   builder.Attr("T", arg_types);
436   builder.Attr("cond", cond_name);
437   builder.Attr("body", body_name);
438   // Add some internal attributes which need to be propagated.
439   for (absl::string_view attr_name : kAttrsToPropagate) {
440     string attr_val;
441     if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
442       builder.Attr(attr_name, attr_val);
443     }
444   }
445   std::vector<NodeDefBuilder::NodeOut> inputs;
446   for (int i = 0, end = frame->args.size(); i < end; ++i) {
447     const WhileLoopArg& arg = frame->args[i];
448     const Edge* in_edge;
449     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
450     if (in_edge->IsControlEdge()) {
451       builder.ControlInput(in_edge->src()->name());
452     } else {
453       inputs.push_back(NodeDefBuilder::NodeOut(
454           in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
455     }
456   }
457   builder.Input(inputs);
458   TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
459   TF_ASSIGN_OR_RETURN(Node * while_node, graph->AddNode(while_def));
460 
461   // Copies edges to the Enter nodes and from the Exit nodes onto the While.
462   for (int i = 0, end = frame->args.size(); i < end; ++i) {
463     const WhileLoopArg& arg = frame->args[i];
464     const Edge* in_edge;
465     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
466     if (in_edge->IsControlEdge()) {
467       graph->AddControlEdge(in_edge->src(), while_node);
468     } else {
469       graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
470     }
471 
472     if (!arg.is_loop_invariant) {
473       // Add output edges if the output of the loop is consumed.
474       if (arg.exit != nullptr) {
475         std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
476                                        arg.exit->out_edges().end());
477         for (const Edge* edge : edges) {
478           Node* dst = edge->dst();
479           int dst_input = edge->dst_input();
480           graph->RemoveEdge(edge);
481 
482           if (dst_input == Graph::kControlSlot) {
483             graph->AddControlEdge(while_node, dst);
484           } else {
485             graph->AddEdge(while_node, i, dst, dst_input);
486           }
487         }
488       }
489     }
490   }
491 
492   // Remove the old nodes from the graph, and add the while node to the parent
493   // frame.
494   for (Node* node : frame->nodes) {
495     VLOG(2) << "Removing obsolete node " << node->name();
496     graph->RemoveNode(node);
497   }
498   frame->nodes.clear();
499   frame->parent->nodes.insert(while_node);
500 
501   VLOG(2) << "Frame " << frame->name << " after: "
502           << DumpGraphToFile("functionalize_after", *graph, library);
503 
504   return OkStatus();
505 }
506 }  // namespace
507 
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)508 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
509                               const NodeFilter& node_filter) {
510   // Note: BuildControlFlowInfo() requires that the graph's source node is
511   // connected to all source nodes in the graph. Many graphs violate this
512   // invariant.
513   std::vector<ControlFlowInfo> cf_info;
514   std::vector<string> unreachable_nodes;
515   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
516   if (!unreachable_nodes.empty()) {
517     return errors::InvalidArgument(
518         "The following nodes are unreachable from the source in the graph: ",
519         errors::FormatNodeNamesForError(unreachable_nodes));
520   }
521 
522   // Builds Frames, indexed by name.
523   std::unordered_map<string, WhileLoopFrame> frames;
524   TF_RETURN_IF_ERROR(
525       ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
526 
527   // Adds frames with no children (i.e., the innermost frames) to a worklist.
528   std::deque<WhileLoopFrame*> worklist;
529   for (auto& frame : frames) {
530     if (frame.second.num_children == 0) {
531       worklist.push_back(&frame.second);
532     }
533   }
534 
535   // Eliminate loops from innermost to outermost. Note that the precondition for
536   // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
537   // works.
538   while (!worklist.empty()) {
539     WhileLoopFrame* frame = worklist.front();
540     worklist.pop_front();
541     if (frame->parent == frame) {
542       // Skip the root frame.
543       continue;
544     }
545 
546     TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
547 
548     // If the parent has no remaining children, add it to the worklist.
549     --frame->parent->num_children;
550     if (frame->parent->num_children == 0) {
551       worklist.push_back(frame->parent);
552     }
553   }
554 
555   if (!node_filter) {
556     // There should be no cycle at this point, since while loops have been
557     // removed from graph. Check that the newly added While nodes don't feed
558     // into themselves.
559     for (const Node* node : graph->op_nodes()) {
560       if (node->def().op() == "While") {
561         TF_RETURN_WITH_CONTEXT_IF_ERROR(
562             CheckNodeNotInCycle(node, graph->num_node_ids()),
563             "Functionalizing loop failed.");
564       }
565     }
566   }
567 
568   return OkStatus();
569 }
570 
571 }  // namespace tensorflow
572