xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/lower_while_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/common_runtime/lower_while_op.h"

#include <limits>

#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 
28 using NodeOut = NodeBuilder::NodeOut;
29 
30 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
31     LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
32 
33 // Helper to convert a functional While op to its lowered form.
34 //
35 // Example:
36 //
37 // Input graph:
38 //
39 // loop_var -> WhileOp<cond_func, body_func> -> consumer
40 //
41 // Output graph(top to down flow):
42 //
43 //                   loop_var
44 //                      |
45 //                    Enter
46 //                      |
47 //  cond_func ---<--- Merge  ---<--- NextIteration
48 //      |               |                |
49 //      V               V                ^
50 //      |               |                |
51 //  LoopCond  --->--- Switch --->--- body_func
52 //                      |
53 //                     Exit
54 //                      |
55 //                   consumer
56 //
57 // DT_RESOURCE tensors are handled specially:
58 //
59 // resource_loop_var -> Enter[is_constant=True] -> cond_func and body_func
60 //      |
61 //      V
62 //   consumer
63 class LowerWhileHelper {
64  public:
Run(Node * while_op,const NameAttrList & cond_fn,const NameAttrList & body_fn,int parallel_iterations,Graph * graph,const FunctionLibraryDefinition * flib_def,bool keep_node_fetchable)65   static Status Run(Node* while_op, const NameAttrList& cond_fn,
66                     const NameAttrList& body_fn, int parallel_iterations,
67                     Graph* graph, const FunctionLibraryDefinition* flib_def,
68                     bool keep_node_fetchable) {
69     LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
70                             graph, flib_def, keep_node_fetchable);
71     return helper.RunInternal();
72   }
73 
74  private:
75   // Create a LowerWhileHelper to create the lowering of While op that has cond
76   // and body functions named `cond_fn_name` and `body_fn_name` respectively in
77   // the given graph.
78   LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
79                    const NameAttrList& body_fn, int parallel_iterations,
80                    Graph* graph, const FunctionLibraryDefinition* flib_def,
81                    bool keep_node_fetchable);
82 
83   Status RunInternal();
84 
85   void InitializeInputOutputToLoweredNodeMap();
86 
87   // Creates an Enter node for each `while_op_` input and adds them to
88   // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
89   // `src` node we add a control edge from `src` to each Enter node.
90   Status CreateEnterNodes();
91 
92   // Creates a Merge node for each Enter node and adds to `merge_nodes_`.
93   // Initially now both inputs of a Merge node are the Enter node. Input at
94   // index 1 is later updated to the output of NextIteration node in
95   // `UpdateMergeNodes`.
96   Status CreateMergeNodes();
97 
98   // Creates the call node for cond func and stores in `cond_call_node_`.
99   Status CreateCondFuncCallNode();
100 
101   // Creates a Switch node for each loop var and adds to `switch_nodes_`.
102   // Output at index 1(true) of a Switch node is fed into the loop body.
103   // Output at index 0(false) of a Switch node is fed into the Exit nodes.
104   Status CreateSwitchNodes();
105 
106   // Creates the call node for body func and stores in `body_call_node_`.
107   Status CreateBodyFuncCallNode();
108 
109   // Creates an Exit node for each loop var and adds to `exit_nodes_`. These
110   // are fed into the consumers of the `while_op_`.
111   Status CreateExitNodes();
112 
113   // Creates an NextIteration node for each loop var and adds to
114   // `next_iteration_nodes_`.
115   Status CreateNextIterationNodes();
116 
117   // Updates input at index 1 of each merge node created in `CreateMergeNodes`
118   // to use the output of NextIteration node created in
119   // `CreateNextIterationNodes` instead.
120   Status UpdateMergeNodes();
121 
122   // Updates consumers of the original `while_op_` to instead use the outputs
123   // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
124   // edges to depend on `lowered_while_executed_` instead.
125   Status UpdateConsumers();
126 
127   // Returns unique name containing the name of the While op being rewritten
128   // (name_), infix and a suffix to ensure it is unique within the graph.
129   string NewName(const string& infix);
130 
131   // Returns whether the While op's input/output at `index` is a `DT_RESOURCE`.
132   bool IsResource(int index);
133 
134   // The original While op.
135   Node* while_op_;
136   // The call node for the cond branch.
137   Node* cond_call_node_;
138   // The LoopCond node specifying the loop termination condition.
139   Node* loop_cond_node_;
140   // The call node for the body branch.
141   Node* body_call_node_;
142   // The node with the same name as the original While op:
143   //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'.
144   //   (b) NoOp node with control edge from 'lowered_while_executed_' otherwise.
145   Node* lowered_while_output_;
146   // The NoOp node with control edges from all Exit nodes. This node will be
147   // used as a source of outgoing control edges from lowered While node.
148   Node* lowered_while_executed_;
149   Graph* graph_;
150   const FunctionLibraryDefinition* flib_def_;
151   // Name of the `while_op_`.
152   string name_;
153   // Max number of parallel_iterations for the while loop.
154   const int parallel_iterations_;
155   bool keep_node_fetchable_;
156 
157   NodeDebugInfo debug_info_;
158   NodeBuilder cond_call_builder_;
159   NodeBuilder body_call_builder_;
160 
161   // `Enter` nodes, one per loop input/output.
162   // Note: `Enter` nodes with type `DT_RESOURCE` have attr `is_constant=True`.
163   std::vector<Node*> enter_nodes_;
164 
165   // Merge/Switch/NextIteration/Exit nodes, one per non-resource loop
166   // input/output.
167   std::vector<Node*> merge_nodes_;
168   std::vector<Node*> switch_nodes_;
169   std::vector<Node*> exit_nodes_;
170   std::vector<Node*> next_iterations_nodes_;
171   // Maps from the loop input/output indices to their corresponding
172   // Merge/Switch/NextIteration/Exit node indices. For inputs/outputs of
173   // `DT_RESOURCE` type there are no Merge/Switch/NextIteration/Exit nodes
174   // in which case the mapping contains -1.
175   std::vector<int> op_input_output_to_lowered_node_;
176 
177   size_t num_loop_inputs_;
178 };
179 
LowerWhileHelper(Node * while_op,const NameAttrList & cond_fn,const NameAttrList & body_fn,int parallel_iterations,Graph * graph,const FunctionLibraryDefinition * flib_def,bool keep_node_fetchable)180 LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
181                                    const NameAttrList& body_fn,
182                                    int parallel_iterations, Graph* graph,
183                                    const FunctionLibraryDefinition* flib_def,
184                                    bool keep_node_fetchable)
185     : while_op_(while_op),
186       graph_(graph),
187       flib_def_(flib_def),
188       name_(while_op->name()),
189       parallel_iterations_(parallel_iterations),
190       keep_node_fetchable_(keep_node_fetchable),
191       debug_info_(*while_op_),
192       cond_call_builder_(NewName("cond"), cond_fn.name(), flib_def,
193                          &debug_info_),
194       body_call_builder_(NewName("body"), body_fn.name(), flib_def,
195                          &debug_info_),
196       num_loop_inputs_(while_op_->num_inputs()) {
197   cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
198   for (const auto& i : cond_fn.attr()) {
199     cond_call_builder_.Attr(i.first, i.second);
200   }
201   body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
202   for (const auto& i : body_fn.attr()) {
203     body_call_builder_.Attr(i.first, i.second);
204   }
205   // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
206   // because we need to set it's elements out of order in `CreateEnterNodes`.
207   enter_nodes_.resize(num_loop_inputs_);
208   merge_nodes_.reserve(num_loop_inputs_);
209   switch_nodes_.reserve(num_loop_inputs_);
210   exit_nodes_.reserve(num_loop_inputs_);
211   next_iterations_nodes_.reserve(num_loop_inputs_);
212   op_input_output_to_lowered_node_.resize(num_loop_inputs_, -1);
213 }
214 
RunInternal()215 Status LowerWhileHelper::RunInternal() {
216   InitializeInputOutputToLoweredNodeMap();
217   TF_RETURN_IF_ERROR(CreateEnterNodes());
218   TF_RETURN_IF_ERROR(CreateMergeNodes());
219   TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
220   TF_RETURN_IF_ERROR(CreateSwitchNodes());
221   TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
222   TF_RETURN_IF_ERROR(CreateExitNodes());
223   TF_RETURN_IF_ERROR(CreateNextIterationNodes());
224   TF_RETURN_IF_ERROR(UpdateMergeNodes());
225   TF_RETURN_IF_ERROR(UpdateConsumers());
226   return OkStatus();
227 }
228 
InitializeInputOutputToLoweredNodeMap()229 void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() {
230   int counter = 0;
231   for (int i = 0; i < num_loop_inputs_; i++) {
232     if (!IsResource(i)) {
233       op_input_output_to_lowered_node_[i] = counter++;
234     }
235   }
236 }
237 
CreateEnterNodes()238 Status LowerWhileHelper::CreateEnterNodes() {
239   // Note: `Node::input_edge` runs in  O(num_inputs) so we use
240   // `Node::input_edges` instead so that below loop runs in O(num_inputs) time
241   // and not O(num_inputs^2).
242   std::vector<const Edge*> edges;
243   TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
244   for (const Edge* edge : edges) {
245     Node* enter_node;
246     NodeBuilder builder =
247         NodeBuilder(NewName("enter"), "Enter", flib_def_, &debug_info_)
248             .Input(NodeOut(edge->src(), edge->src_output()))
249             .Attr("frame_name", name_)
250             .Attr("parallel_iterations", parallel_iterations_)
251             .Device(edge->src()->requested_device())
252             .AssignedDevice(edge->src()->assigned_device_name());
253     if (IsResource(edge->dst_input())) {
254       builder.Attr("is_constant", true);
255     }
256     TF_RETURN_IF_ERROR(builder.Finalize(graph_, &enter_node));
257     enter_nodes_[edge->dst_input()] = enter_node;
258   }
259   // Create a NoOp node that takes incoming control inputs of the original While
260   // op as control inputs and use it as a control input for all Enter nodes.
261   std::vector<Node*> control_inputs;
262   for (const Edge* e : while_op_->in_edges()) {
263     if (e->IsControlEdge()) {
264       control_inputs.push_back(e->src());
265     }
266   }
267   if (!control_inputs.empty()) {
268     Node* incoming_control_node;
269     TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopControlInputs"), "NoOp",
270                                    flib_def_, &debug_info_)
271                            .ControlInputs(control_inputs)
272                            .Device(while_op_->requested_device())
273                            .Finalize(graph_, &incoming_control_node));
274     for (Node* n : enter_nodes_) {
275       graph_->AddControlEdge(incoming_control_node, n);
276     }
277   }
278   return OkStatus();
279 }
280 
CreateMergeNodes()281 Status LowerWhileHelper::CreateMergeNodes() {
282   for (Node* enter_node : enter_nodes_) {
283     if (enter_node->output_type(0) == DT_RESOURCE) {
284       continue;
285     }
286     Node* merge_node;
287     TF_RETURN_IF_ERROR(
288         NodeBuilder(NewName("merge"), "Merge", flib_def_, &debug_info_)
289             .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
290             .Device(enter_node->requested_device())
291             .AssignedDevice(enter_node->assigned_device_name())
292             .Finalize(graph_, &merge_node));
293     merge_nodes_.emplace_back(merge_node);
294   }
295   return OkStatus();
296 }
297 
CreateCondFuncCallNode()298 Status LowerWhileHelper::CreateCondFuncCallNode() {
299   for (int i = 0; i < num_loop_inputs_; i++) {
300     if (IsResource(i)) {
301       cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
302     } else {
303       cond_call_builder_.Input(
304           NodeOut(merge_nodes_[op_input_output_to_lowered_node_[i]], 0));
305     }
306   }
307   cond_call_builder_.Device(while_op_->requested_device());
308   TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
309   // Add a control edge to make sure the Const nodes in the cond function
310   // are in the same frame as the rest of the function, otherwise
311   // `BuildControlFlowInfo` throws an error.
312   graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
313   TF_RETURN_IF_ERROR(
314       NodeBuilder(NewName("LoopCond"), "LoopCond", flib_def_, &debug_info_)
315           .Input(NodeOut(cond_call_node_, 0))
316           .Device(while_op_->requested_device())
317           .Finalize(graph_, &loop_cond_node_));
318   return OkStatus();
319 }
320 
CreateSwitchNodes()321 Status LowerWhileHelper::CreateSwitchNodes() {
322   for (int i = 0; i < num_loop_inputs_; i++) {
323     if (IsResource(i)) {
324       continue;
325     }
326     string op_name;
327     {
328       const Node* input_node;
329       TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
330       op_name = strings::StrCat(input_node->name(), "_switch");
331     }
332     Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
333     Node* switch_node;
334     string op_type = "Switch";
335     if (IsRefType(merge_node->output_type(0))) {
336       op_type = "RefSwitch";
337     }
338     TF_RETURN_IF_ERROR(
339         NodeBuilder(NewName(op_name), op_type, flib_def_, &debug_info_)
340             .Input(NodeOut(merge_node, 0))
341             .Input(NodeOut(loop_cond_node_, 0))
342             .Device(merge_node->requested_device())
343             .AssignedDevice(merge_node->assigned_device_name())
344             .Finalize(graph_, &switch_node));
345     switch_nodes_.emplace_back(switch_node);
346   }
347   return OkStatus();
348 }
349 
CreateBodyFuncCallNode()350 Status LowerWhileHelper::CreateBodyFuncCallNode() {
351   for (int i = 0; i < num_loop_inputs_; i++) {
352     if (IsResource(i)) {
353       body_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
354     } else {
355       body_call_builder_.Input(
356           NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], 1));
357     }
358   }
359   body_call_builder_.Device(while_op_->requested_device());
360   TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
361   // Add a control edge to make sure the Const nodes in the body function
362   // are in the same frame as the rest of the function, otherwise
363   // `BuildControlFlowInfo` throws an error.
364   // TODO(srbs): The choice of input at index 0 seems arbitrary(is it?) however
365   // this is how tf.while_loop does it. Can this affect performance if the 0th
366   // node is not the first one to be ready? Can we speed that case up using some
367   // sort of multi-input Merge?
368   Node* body_control_node_;
369   string op_type = "Identity";
370   if (IsRefType(switch_nodes_[0]->output_type(1))) {
371     op_type = "RefIdentity";
372   }
373   TF_RETURN_IF_ERROR(NodeBuilder(NewName("loop_body_control"), op_type,
374                                  flib_def_, &debug_info_)
375                          .Input(NodeOut(switch_nodes_[0], 1))
376                          .Device(while_op_->requested_device())
377                          .Finalize(graph_, &body_control_node_));
378   graph_->AddControlEdge(body_control_node_, body_call_node_);
379   return OkStatus();
380 }
381 
CreateExitNodes()382 Status LowerWhileHelper::CreateExitNodes() {
383   std::vector<NodeOut> outputs;
384   outputs.reserve(num_loop_inputs_);
385   for (int i = 0; i < num_loop_inputs_; i++) {
386     if (IsResource(i)) {
387       // Note(srbs): A resource output of this While should never be used but we
388       // need this for the IdentityN node below.
389       OutputTensor resource_tensor;
390       TF_RETURN_IF_ERROR(enter_nodes_[i]->input_tensor(0, &resource_tensor));
391       outputs.emplace_back(resource_tensor);
392     } else {
393       Node* exit_node;
394       TF_RETURN_IF_ERROR(
395           NodeBuilder(NewName("exit"), "Exit", flib_def_, &debug_info_)
396               .Input(NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]],
397                              0))
398               .Device(switch_nodes_[op_input_output_to_lowered_node_[i]]
399                           ->requested_device())
400               .AssignedDevice(switch_nodes_[op_input_output_to_lowered_node_[i]]
401                                   ->assigned_device_name())
402               .Finalize(graph_, &exit_node));
403       exit_nodes_.emplace_back(exit_node);
404       outputs.emplace_back(NodeOut(exit_node, 0));
405     }
406   }
407 
408   // We split data and control outputs of lowered while op, because otherwise
409   // after lowering of multi-device loop body we might end up with DT_RESOURCE
410   // inputs from multiple devices coming into IdentityN.
411 
412   // Add a NoOp node that has control edges from all Exit nodes. This node is
413   // used for rewriting control edges with the original while op as src.
414   TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopExecuted"), "NoOp",
415                                  OpRegistry::Global(), &debug_info_)
416                          .ControlInputs(exit_nodes_)
417                          .Device(while_op_->requested_device())
418                          .Finalize(graph_, &lowered_while_executed_));
419 
420   if (keep_node_fetchable_) {
421     // Add an IdentityN node that has the same outputs and same name as the
422     // original functional While op. This is used for fetching the output of the
423     // While node by name in calls to sess.run.
424     TF_RETURN_IF_ERROR(
425         NodeBuilder(name_, "IdentityN", OpRegistry::Global(), &debug_info_)
426             .Input(outputs)
427             .Device(while_op_->requested_device())
428             .Finalize(graph_, &lowered_while_output_));
429   } else {
430     // Even if we don't plan to fetch tensors from the lowered While op, we must
431     // keep it a valid source of control edges, because it might be a part of
432     // function control output set.
433     TF_RETURN_IF_ERROR(
434         NodeBuilder(name_, "NoOp", OpRegistry::Global(), &debug_info_)
435             .ControlInput(lowered_while_executed_)
436             .Device(while_op_->requested_device())
437             .Finalize(graph_, &lowered_while_output_));
438   }
439 
440   return OkStatus();
441 }
442 
CreateNextIterationNodes()443 Status LowerWhileHelper::CreateNextIterationNodes() {
444   for (int i = 0; i < num_loop_inputs_; i++) {
445     Node* next_iteration;
446     if (IsResource(i)) {
447       continue;
448     }
449     Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
450     TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
451                                    flib_def_, &debug_info_)
452                            .Input(NodeOut(body_call_node_, i))
453                            .ControlInput(body_call_node_)
454                            .Device(merge_node->requested_device())
455                            .AssignedDevice(merge_node->assigned_device_name())
456                            .Finalize(graph_, &next_iteration));
457     next_iterations_nodes_.emplace_back(next_iteration);
458   }
459   return OkStatus();
460 }
461 
UpdateMergeNodes()462 Status LowerWhileHelper::UpdateMergeNodes() {
463   for (int i = 0; i < merge_nodes_.size(); i++) {
464     TF_RETURN_IF_ERROR(
465         graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
466   }
467   return OkStatus();
468 }
469 
UpdateConsumers()470 Status LowerWhileHelper::UpdateConsumers() {
471   for (const Edge* e : while_op_->out_edges()) {
472     if (e->IsControlEdge()) {
473       graph_->AddControlEdge(lowered_while_executed_, e->dst());
474     } else {
475       if (IsResource(e->src_output())) {
476         OutputTensor resource;
477         TF_RETURN_IF_ERROR(
478             enter_nodes_[e->src_output()]->input_tensor(0, &resource));
479         graph_->AddEdge(resource.node, resource.index, e->dst(),
480                         e->dst_input());
481       } else {
482         // Feed the outputs directly from the exit nodes so that downstream ops
483         // can start before all the outputs have been computed.
484         int exit_node_index = op_input_output_to_lowered_node_[e->src_output()];
485         if (exit_node_index < 0) {
486           return errors::Internal(
487               "Expecting an Exit node for a Resource tensor.");
488         }
489         graph_->AddEdge(exit_nodes_[exit_node_index], 0, e->dst(),
490                         e->dst_input());
491       }
492     }
493   }
494   return OkStatus();
495 }
496 
string LowerWhileHelper::NewName(const string& infix) {
  // Scope the name under the While op ("<name_>/<infix>") and let the graph
  // append a suffix that makes it unique.
  const string scoped_prefix = strings::StrCat(name_, "/", infix);
  return graph_->NewName(scoped_prefix);
}
500 
IsResource(int index)501 bool LowerWhileHelper::IsResource(int index) {
502   return while_op_->input_type(index) == DT_RESOURCE;
503 }
504 
505 }  // namespace
506 
RewriteWhileNode(Node * n,Graph * g,const FunctionLibraryDefinition * flib_def,bool keep_node_fetchable)507 Status RewriteWhileNode(Node* n, Graph* g,
508                         const FunctionLibraryDefinition* flib_def,
509                         bool keep_node_fetchable) {
510   VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable
511           << "): " << SummarizeNode(*n);
512 
513   const AttrValue* cond_attr = n->attrs().Find("cond");
514   if (cond_attr == nullptr) {
515     return errors::InvalidArgument("While cond function missing");
516   }
517   const AttrValue* body_attr = n->attrs().Find("body");
518   if (body_attr == nullptr) {
519     return errors::InvalidArgument("While body function missing");
520   }
521   const AttrValue* parallel_iterations_attr =
522       n->attrs().Find("parallel_iterations");
523   if (parallel_iterations_attr == nullptr) {
524     return errors::InvalidArgument("parallel_iterations attr missing");
525   }
526   if (parallel_iterations_attr->i() < 1) {
527     return errors::InvalidArgument("parallel_iterations must be > 0");
528   }
529 
530   TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
531       n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g,
532       flib_def, keep_node_fetchable));
533   g->RemoveNode(n);
534 
535   return OkStatus();
536 }
537 
538 }  // namespace tensorflow
539