xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/loop_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/allocator.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/grappler/graph_topology_view.h"
35 #include "tensorflow/core/grappler/grappler_item.h"
36 #include "tensorflow/core/grappler/mutable_graph_view.h"
37 #include "tensorflow/core/grappler/op_types.h"
38 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
39 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
40 #include "tensorflow/core/grappler/utils/frame.h"
41 #include "tensorflow/core/grappler/utils/traversal.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/platform/tensor_coding.h"
47 #include "tensorflow/core/public/version.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #include "tensorflow/core/util/saved_tensor_slice_util.h"
50 
51 using tensorflow::strings::StrCat;
52 
53 namespace tensorflow {
54 namespace grappler {
55 namespace {
56 
57 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
58 
// Hoists loop-invariant subgraphs out of while-loop frames so they execute
// once, in the parent frame, instead of on every loop iteration.
class LoopInvariantNodeMotionOptimizer {
 public:
  explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
      : optimized_graph_(optimized_graph) {}
  virtual ~LoopInvariantNodeMotionOptimizer() = default;
  // Runs the optimization over the whole graph in place, visiting frames
  // innermost-first.
  Status Optimize();

 private:
  // Discovers nodes reachable from `node` whose non-control inputs are all
  // invariant, recording them in invariant_nodes_.
  Status FindInvariantNodes(NodeDef* node);
  // Evicts candidates whose control outputs reach variant nodes.
  Status RevertInvariantNodes();
  // Rewires every surviving candidate out of frame `frame_id`.
  Status MoveInvariantNodes(const int frame_id);
  // Per-category handlers; `num_outputs` is the candidate's count of
  // consumers that were not proven invariant.
  Status HandleInvariantNode(NodeDef* node, const int num_outputs,
                             const int frame_id);
  Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
  Status HandleInvariantEnter(NodeDef* node, const int num_outputs);

  GraphDef* optimized_graph_;  // Not owned.
  std::unique_ptr<NodeMap> node_map_;
  // Invariant candidate -> number of consumers not yet proven invariant.
  std::map<NodeDef*, int> invariant_nodes_;
  std::set<int> empty_set_;
  // Frame tree: children set and parent id of each frame (-1 = top level).
  std::vector<std::set<int>> frame_children_;
  std::vector<int> frame_parent_;
  // Per-frame LoopCond node and constant ("is_constant") Enter nodes.
  std::map<int, const NodeDef*> loop_cond_;
  std::map<int, std::vector<NodeDef*>> invariant_enters_;
  int new_enter_id_;
};
85 
// Bypasses an invariant Enter node for its invariant consumers: they are
// rewired to read the Enter's data input (the value from the parent frame)
// directly, and they inherit the Enter's control inputs so ordering
// constraints are preserved. The Enter itself stays in place for any
// remaining variant consumers. `num_outputs` is unused here.
Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
    NodeDef* node, const int num_outputs) {
  auto consumers = node_map_->GetOutputs(node->name());
  // Split the Enter's inputs into its single data input and control deps.
  std::vector<string> enter_control_inputs;
  string enter_input;
  for (auto& input : node->input()) {
    if (IsControlInput(input)) {
      enter_control_inputs.push_back(input);
    } else {
      enter_input = input;
    }
  }
  for (auto* consumer : consumers) {
    if (invariant_nodes_.count(consumer)) {
      // Redirect the consumer from the Enter to the Enter's producer,
      // keeping the node map in sync.
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (NodeName(consumer->input(i)) == node->name()) {
          consumer->set_input(i, enter_input);
          node_map_->AddOutput(NodeName(enter_input), consumer->name());
          node_map_->RemoveOutput(node->name(), consumer->name());
        }
      }
      // Carry over the Enter's control dependencies to the consumer.
      for (auto& control_input : enter_control_inputs) {
        consumer->add_input(control_input);
        node_map_->AddOutput(NodeName(control_input), consumer->name());
      }
    }
  }
  return OkStatus();
}
115 
// Moves a loop-invariant Const out of frame `frame_id`. If every consumer is
// invariant (num_outputs == 0) the node itself is hoisted by stripping its
// in-frame control inputs; otherwise a copy of the constant is created for
// the parent frame and only the invariant consumers are rewired to it. In
// either case the hoisted constant receives a control dependency on the
// parent frame's loop-condition Switch (if a parent frame exists) so it only
// executes when the parent loop is live.
Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
                                                     const int num_outputs,
                                                     const int frame_id) {
  NodeDef* const_node = nullptr;
  if (num_outputs == 0) {
    // all successor nodes are invariant
    // Remove the control inputs from this frame to the const node,
    // when moving it out of the frame (in parent frame)
    const_node = node;
    node_map_->RemoveInputs(node->name());
    node->clear_input();
  } else {
    // some successor nodes are variant
    // Have to keep the const node in the frame,
    // so create a new one outside the frame (in parent frame)
    const string const_node_name =
        AddPrefixToNodeName(node->name(), kLoopOptimizer);
    // Reuse a previously created copy if one already exists.
    const_node = node_map_->GetNode(const_node_name);
    if (const_node == nullptr) {
      const_node = optimized_graph_->add_node();
      const_node->set_name(const_node_name);
      const_node->set_op("Const");
      const_node->set_device(node->device());
      *const_node->mutable_attr() = node->attr();
      node_map_->AddNode(const_node->name(), const_node);
    }
    auto consumers = node_map_->GetOutputs(node->name());
    for (auto* consumer : consumers) {
      if (invariant_nodes_.count(consumer)) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          if (NodeName(consumer->input(i)) == node->name()) {
            // Preserve the control-vs-data nature of the original edge.
            if (IsControlInput(consumer->input(i))) {
              *consumer->mutable_input(i) = AsControlDependency(*const_node);
            } else {
              *consumer->mutable_input(i) = const_node->name();
            }
            node_map_->AddOutput(const_node->name(), consumer->name());
            node_map_->RemoveOutput(node->name(), consumer->name());
          }
        }
      }
    }
  }
  // add a control input from the parent frame
  if (frame_parent_[frame_id] != -1) {
    int parent_id = frame_parent_[frame_id];
    auto loop_cond_it = loop_cond_.find(parent_id);
    if (loop_cond_it == loop_cond_.end()) {
      return errors::InvalidArgument("Frame ", frame_id,
                                     " doesn't have a LoopCond node");
    }
    auto& loop_cond_name = loop_cond_it->second->name();
    NodeDef* switch_node = nullptr;
    for (auto* node : node_map_->GetOutputs(loop_cond_name)) {
      if (node->op() == "Switch") {
        switch_node = node;
        break;
      }
    }
    if (!switch_node) {
      return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
                                     " doesn't connect to any Switch node");
    }
    // Port 1 of a loop-condition Switch is the "loop continues" output.
    string switch_output = StrCat(switch_node->name(), ":1");
    const string ctrl_dep = ConstantFolding::AddControlDependency(
        switch_output, optimized_graph_, node_map_.get());
    const_node->add_input(ctrl_dep);
    node_map_->AddOutput(NodeName(ctrl_dep), const_node->name());
  }
  return OkStatus();
}
187 
HandleInvariantNode(NodeDef * node,const int num_outputs,const int frame_id)188 Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
189     NodeDef* node, const int num_outputs, const int frame_id) {
190   // have to remove control inputs to the invariant node from the same frame
191   // when moving this node out of this frame
192   for (int i = 0; i < node->input_size(); ++i) {
193     if (IsControlInput(node->input(i))) {
194       node->mutable_input()->SwapElements(i, node->input_size() - 1);
195       node->mutable_input()->RemoveLast();
196     }
197   }
198   if (num_outputs == 0) {
199     return OkStatus();
200   }
201 
202   DataTypeVector input_types;
203   DataTypeVector output_types;
204   OpRegistryInterface* op_registry = OpRegistry::Global();
205   const OpRegistrationData* op_reg_data = nullptr;
206   TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
207   TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
208                                        &output_types));
209 
210   auto consumers = node_map_->GetOutputs(node->name());
211   string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
212   int piterations =
213       invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
214   for (auto* consumer : consumers) {
215     if (!invariant_nodes_.count(consumer)) {
216       for (int i = 0; i < consumer->input_size(); ++i) {
217         int port;
218         string node_name = ParseNodeName(consumer->input(i), &port);
219         if (node_name != node->name()) {
220           continue;
221         }
222         if (port < 0) {
223           return errors::InvalidArgument(
224               "Invariant node should not have control outputs "
225               "to variant node");
226         }
227         DataType output_type = output_types[port];
228         NodeDef* new_enter = optimized_graph_->add_node();
229         new_enter->set_op("Enter");
230         new_enter->set_device(node->device());
231         new_enter->set_name(AddPrefixToNodeName(
232             StrCat(fname, "_enter_", new_enter_id_++), kLoopOptimizer));
233         AttrValue data_type;
234         data_type.set_type(output_type);
235         new_enter->mutable_attr()->insert({"T", data_type});
236         AttrValue frame_name;
237         frame_name.set_s(fname);
238         new_enter->mutable_attr()->insert({"frame_name", frame_name});
239         AttrValue is_const;
240         is_const.set_b(true);
241         new_enter->mutable_attr()->insert({"is_constant", is_const});
242         AttrValue parallel_iterations;
243         parallel_iterations.set_i(piterations);
244         new_enter->mutable_attr()->insert(
245             {"parallel_iterations", parallel_iterations});
246         new_enter->add_input(consumer->input(i));
247         *consumer->mutable_input(i) = new_enter->name();
248         node_map_->AddNode(new_enter->name(), new_enter);
249         node_map_->AddOutput(node->name(), new_enter->name());
250         node_map_->AddOutput(new_enter->name(), consumer->name());
251       }
252     }
253   }
254   return OkStatus();
255 }
256 
MoveInvariantNodes(const int frame_id)257 Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
258     const int frame_id) {
259   for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
260        ++iter) {
261     auto* invariant_node = iter->first;
262     const int num_outputs = iter->second;
263     if (IsEnter(*invariant_node)) {
264       TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
265     } else if (IsConstant(*invariant_node)) {
266       TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
267     } else {
268       TF_RETURN_IF_ERROR(
269           HandleInvariantNode(invariant_node, num_outputs, frame_id));
270     }
271   }
272   return OkStatus();
273 }
274 
// Evicts from invariant_nodes_ any non-Const, non-Enter candidate that still
// has a control output to a variant node: hoisting such a node would drop
// that control dependency. Eviction then propagates transitively, undoing
// the bookkeeping FindInvariantNodes performed for each evicted node.
Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
  std::deque<const NodeDef*> reverted_nodes;
  // Pass 1: collect candidates that control-feed a variant consumer.
  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
    bool erased = false;
    const auto* node = iter->first;
    if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
      auto& consumers = node_map_->GetOutputs(node->name());
      for (auto* consumer : consumers) {
        if (!invariant_nodes_.count(consumer)) {
          for (const auto& input : consumer->input()) {
            if (IsControlInput(input) && NodeName(input) == node->name()) {
              reverted_nodes.push_back(node);
              invariant_nodes_.erase(iter++);
              erased = true;
              break;
            }
          }
          if (erased) break;
        }
      }
    }
    if (!erased) ++iter;
  }
  // Pass 2: propagate the eviction through producers and consumers.
  while (!reverted_nodes.empty()) {
    const auto* node = reverted_nodes.front();
    reverted_nodes.pop_front();
    std::set<NodeDef*> producers;
    for (const auto& input : node->input()) {
      auto* producer = node_map_->GetNode(input);
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        if (IsControlInput(input) && !IsConstant(*producer) &&
            !IsEnter(*producer)) {
          // This producer now control-feeds a variant node (the one being
          // reverted), so it must be reverted as well.
          reverted_nodes.push_back(producer);
          invariant_nodes_.erase(iter);
        } else {
          producers.insert(producer);
        }
      }
    }
    // The reverted node counts as a variant consumer again, so restore one
    // pending-consumer count on each of its surviving invariant producers.
    for (auto* producer : producers) {
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        ++iter->second;
      }
    }
    // Invariant consumers of a reverted node lose an invariant input, so
    // they can no longer be hoisted either.
    for (auto* consumer : node_map_->GetOutputs(node->name())) {
      auto iter = invariant_nodes_.find(consumer);
      if (iter != invariant_nodes_.end()) {
        reverted_nodes.push_back(consumer);
        invariant_nodes_.erase(iter);
      }
    }
  }
  return OkStatus();
}
331 
// Starting from an invariant Enter node, walks the graph forward and records
// every node whose non-control inputs are all invariant (or Const). The
// value stored for each node counts its consumers not yet known invariant;
// it is decremented as consumers are proven invariant and later drives
// MoveInvariantNodes' rewiring decisions.
Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(
    NodeDef* start_node) {
  std::vector<NodeDef*> stack;
  stack.reserve(32);
  stack.push_back(start_node);
  while (!stack.empty()) {
    NodeDef* node = stack.back();
    stack.pop_back();
    auto consumers = node_map_->GetOutputs(node->name());
    // Initially assume no consumer is invariant.
    invariant_nodes_.emplace(node, consumers.size());
    for (auto* consumer : consumers) {
      // Skip already-classified nodes and frame-modifying ops (those must
      // stay inside the frame).
      if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
        continue;
      }
      bool is_invariant = true;
      for (const auto& input : consumer->input()) {
        if (!IsControlInput(input)) {
          const string name = NodeName(input);
          auto* producer = node_map_->GetNode(name);
          if (!invariant_nodes_.count(producer)) {
            if (IsConstant(*producer)) {
              // Constants are invariant by definition; record lazily.
              invariant_nodes_.insert(
                  std::make_pair(producer, node_map_->GetOutputs(name).size()));
            } else {
              is_invariant = false;
              break;
            }
          }
        }
      }
      if (is_invariant) {
        // De-duplicate producers so each is decremented at most once for
        // this consumer, even with multiple edges.
        std::set<NodeDef*> producers;
        for (const auto& input : consumer->input()) {
          auto* producer = node_map_->GetNode(input);
          producers.insert(producer);
        }
        for (auto* producer : producers) {
          auto iter = invariant_nodes_.find(producer);
          if (iter != invariant_nodes_.end()) {
            --iter->second;
          }
        }
        stack.push_back(consumer);
      }
    }
  }
  return OkStatus();
}
380 
// Entry point: builds the frame (loop-nesting) tree for the whole graph,
// then processes frames bottom-up — innermost first — hoisting invariant
// nodes one frame at a time.
Status LoopInvariantNodeMotionOptimizer::Optimize() {
  node_map_.reset(new NodeMap(optimized_graph_));
  FrameView frame_view;
  // TODO(ezhulenev): Use GraphView when migrated from NodeMap.
  TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*optimized_graph_));

  frame_parent_.resize(frame_view.num_frames(), -1);
  frame_children_.resize(frame_view.num_frames());
  std::deque<int> worklist;
  for (const NodeDef& node : optimized_graph_->node()) {
    // frame_ids lists the frames containing this node, outermost first.
    const std::vector<int>& frame_ids = frame_view.Frames(node);

    if (frame_ids.size() >= 3) {
      for (unsigned int i = 1; i < frame_ids.size() - 1; ++i) {
        frame_parent_[frame_ids[i]] = frame_ids[i - 1];
        frame_children_[frame_ids[i]].insert(frame_ids[i + 1]);
      }
    }
    if (frame_ids.size() >= 2) {
      frame_children_[frame_ids[0]].insert(frame_ids[1]);
      frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
    }
    if (!frame_ids.empty()) {
      // The innermost frame of this node has no children (so far).
      frame_children_[frame_ids.back()] = empty_set_;
      if (node.op() == "LoopCond") {
        // A well-formed frame has exactly one LoopCond node.
        if (loop_cond_.count(frame_ids.back())) {
          return errors::InvalidArgument(
              "Loop ", frame_ids.back(),
              " has more than one LoopCond node: ", node.name(), " and ",
              loop_cond_[frame_ids.back()]->name());
        }
        loop_cond_[frame_ids.back()] = &node;
      }
      // Constant Enter nodes seed invariant-node discovery for their frame.
      if (IsEnter(node) && node.attr().at("is_constant").b()) {
        invariant_enters_[frame_ids.back()].push_back(
            const_cast<NodeDef*>(&node));
      }
    }
  }

  // Seed the worklist with leaf frames (frames containing no child loops).
  for (size_t i = 0; i < frame_children_.size(); i++) {
    if (frame_children_[i].empty()) {
      worklist.push_back(i);
    }
  }

  while (!worklist.empty()) {
    int frame_id = worklist.front();
    new_enter_id_ = 0;
    worklist.pop_front();
    // A parent frame becomes processable once all its children are done.
    if (frame_parent_[frame_id] != -1) {
      int parent_id = frame_parent_[frame_id];
      frame_children_[parent_id].erase(frame_id);
      if (frame_children_[parent_id].empty()) {
        worklist.push_back(parent_id);
      }
    }

    if (invariant_enters_[frame_id].empty()) {
      continue;
    }
    invariant_nodes_.clear();
    for (auto* enter : invariant_enters_[frame_id]) {
      TF_RETURN_IF_ERROR(FindInvariantNodes(enter));
    }

    // revert invariant nodes that have control outputs to variant nodes
    TF_RETURN_IF_ERROR(RevertInvariantNodes());

    TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id));
  }
  return OkStatus();
}
454 
// Returns the indices of StackPush nodes (reachable from the stack at
// `stack_node_idx`) whose pushed values are never popped and whose stack is
// not preserved, so they can be converted to Identity. Returns an empty
// vector if any fanout is a pop with consumers or an unexpected op — in that
// case the graph must be left alone.
std::vector<int> GetStackPushNodesToConvert(
    const GraphTopologyView& graph_view,
    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();

  // Ops through which a stack handle may flow without being consumed.
  const std::unordered_set<string> op_types_to_traverse(
      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
       "_SwitchN", "Identity", "RefIdentity"});
  const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
    return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
  };

  std::vector<int> nodes_to_convert;
  std::vector<int> fanouts;

  // Collect every node reachable from the stack via handle-forwarding ops.
  DfsTraversal(graph_view, {graph_view.GetNode(stack_node_idx)},
               TraversalDirection::kFollowOutputs,
               DfsPredicates::Advance(is_op_to_traverse),
               DfsCallbacks::PreOrder([&](const NodeDef* node) {
                 const absl::optional<int> idx = graph_view.GetNodeIndex(*node);
                 fanouts.push_back(idx.value());
               }));

  for (int fanout_idx : fanouts) {
    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
    if (IsStackPushOp(fanout_node)) {
      // Check that the stack itself is not a node we want to preserve. This can
      // happen when the graph we have contains only the forward pass for a loop
      // (as when the forward and backward passes are split across different
      // functions).
      if (graph_view.HasNode(fanout_node.input(0))) {
        const NodeDef* stack_node = graph_view.GetNode(fanout_node.input(0));
        // Walk input(0) back through forwarding ops to the actual Stack op.
        while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
               stack_node->input_size() > 0 &&
               graph_view.HasNode(stack_node->input(0))) {
          stack_node = graph_view.GetNode(stack_node->input(0));
        }
        if (nodes_to_preserve.find(stack_node->name()) ==
            nodes_to_preserve.end()) {
          nodes_to_convert.push_back(fanout_idx);
        }
      } else {
        nodes_to_convert.push_back(fanout_idx);
      }
    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
               op_types_to_traverse.find(fanout_node.op()) !=
                   op_types_to_traverse.end()) {
      // Harmless fanouts: other stack ops or pure handle forwarders.
      continue;
    } else if (!IsStackPopOp(fanout_node) ||
               (!graph_view.GetFanout(fanout_idx).empty() ||
                nodes_to_preserve.find(fanout_node.name()) !=
                    nodes_to_preserve.end())) {
      // The node is either a stack pop with consumers or something unexpected
      // so we leave the graph alone.
      nodes_to_convert.clear();
      break;
    }
  }

  return nodes_to_convert;
}
517 
// Converts StackPush nodes that have no matching pops into Identity nodes,
// preserving execution ordering via a control dependency on the op that
// supplied the stack handle.
Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
                      GraphDef* optimized_graph) {
  NodeMap node_map(optimized_graph);
  GraphTopologyView graph_view;
  TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(*optimized_graph));

  for (int node_idx = 0; node_idx < optimized_graph->node_size(); ++node_idx) {
    if (IsStackOp(optimized_graph->node(node_idx))) {
      for (int push_node_idx : GetStackPushNodesToConvert(
               graph_view, nodes_to_preserve, node_idx)) {
        // We found push nodes without corresponding pops. Convert them to
        // Identity passing the data through and add a control dependency from
        // the op supplying the stack handle.
        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
        VLOG(1) << "Converting " << push_node_idx << " : "
                << push_node->DebugString();
        // Identity does not accept the push-specific "swap_memory" attr.
        if (push_node->attr().count("swap_memory") != 0) {
          push_node->mutable_attr()->erase("swap_memory");
        }
        push_node->set_op("Identity");
        // After the swap, input 0 is the data tensor; the stack handle (now
        // input 1) is demoted to a control dependency.
        push_node->mutable_input()->SwapElements(0, 1);
        const string ctrl_dep = ConstantFolding::AddControlDependency(
            push_node->input(1), optimized_graph, &node_map);
        push_node->set_input(1, ctrl_dep);
        VLOG(1) << "After converting: " << push_node->DebugString();
      }
    }
  }
  return OkStatus();
}
548 
IsSimpleBinaryOperator(const NodeDef & node)549 bool IsSimpleBinaryOperator(const NodeDef& node) {
550   return (IsLess(node) || IsLessEqual(node) || IsGreater(node) ||
551           IsGreaterEqual(node) || IsEqual(node));
552 }
553 
EvaluateBoolOpForConstantOperands(const NodeDef & op_node,const NodeDef & constant_operand_0,const NodeDef & constant_operand_1,DeviceBase * cpu_device,ResourceMgr * resource_mgr,bool * value)554 Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
555                                          const NodeDef& constant_operand_0,
556                                          const NodeDef& constant_operand_1,
557                                          DeviceBase* cpu_device,
558                                          ResourceMgr* resource_mgr,
559                                          bool* value) {
560   VLOG(4) << "Evaluate bool op: op_node=" << op_node.name()
561           << " input0=" << constant_operand_0.name()
562           << " input1=" << constant_operand_1.name();
563   TensorVector inputs;
564 
565   const TensorProto& raw_val_0 = constant_operand_0.attr().at("value").tensor();
566   Tensor value_0(raw_val_0.dtype(), raw_val_0.tensor_shape());
567   CHECK(value_0.FromProto(raw_val_0));
568   inputs.emplace_back(&value_0);
569   const TensorProto& raw_val_1 = constant_operand_1.attr().at("value").tensor();
570   Tensor value_1(raw_val_1.dtype(), raw_val_1.tensor_shape());
571   CHECK(value_1.FromProto(raw_val_1));
572   inputs.emplace_back(&value_1);
573 
574   TensorVector outputs;
575   TF_RETURN_IF_ERROR(
576       EvaluateNode(op_node, inputs, cpu_device, resource_mgr, &outputs));
577 
578   if (outputs.size() != 1 || outputs[0].tensor == nullptr) {
579     return Status(error::INVALID_ARGUMENT, "Expected one output.");
580   }
581   *value = outputs[0].tensor->scalar<bool>()();
582   delete outputs[0].tensor;
583 
584   return OkStatus();
585 }
586 
587 // TODO(lyandy): Consolidate with ConstantFolding implementation.
IsReallyConstant(const NodeDef & node,const absl::flat_hash_set<string> & feed_nodes)588 bool IsReallyConstant(const NodeDef& node,
589                       const absl::flat_hash_set<string>& feed_nodes) {
590   if (!IsConstant(node)) {
591     return false;
592   }
593   // If the node is fed it's not constant anymore.
594   return feed_nodes.find(node.name()) == feed_nodes.end();
595 }
596 
// Determines whether one output branch of `switch_node` can never execute.
// Two cases are detected: (1) the switch predicate is a constant, in which
// case the untaken port is dead; (2) the switch belongs to a while loop whose
// condition is a simple comparison that evaluates to false on the loop's
// constant initialization value (a zero-iteration loop), in which case port 1
// (the loop body) is dead. On success, *has_dead_fanout and *dead_fanout
// report the result.
Status CheckForDeadFanout(const MutableGraphView& view,
                          const NodeDef& switch_node, const NodeMap& node_map,
                          const absl::flat_hash_set<string>& feed_nodes,
                          DeviceBase* cpu_device, ResourceMgr* resource_mgr,
                          bool* has_dead_fanout, int* dead_fanout) {
  *has_dead_fanout = false;
  GraphView::InputPort switch_loopcond_port(&switch_node, 1);
  const NodeDef* switch_predicate =
      view.GetRegularFanin(switch_loopcond_port).node;

  // CASE 1: Control is a constant.
  if (IsReallyConstant(*switch_predicate, feed_nodes)) {
    VLOG(3) << "Found switch node with constant predicate:"
            << " switch_node=" << switch_node.name()
            << " switch_predicate=" << switch_predicate->name();
    Tensor selector;
    CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
    *has_dead_fanout = true;
    // Port 1 is taken when the predicate is true, port 0 when false.
    *dead_fanout = selector.scalar<bool>()() ? 0 : 1;
    return OkStatus();
  }

  GraphView::InputPort switch_input_port(&switch_node, 0);
  const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;

  // CASE 2: Zero-iteration while loop.
  // We check if its a while loop such that the condition is a simple binary
  // operator which returns false for the initialization value.
  // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
  if (!IsMerge(*switch_input) || !IsLoopCond(*switch_predicate)) {
    return OkStatus();
  }

  VLOG(4) << "Try to find a zero iteration while loop:"
          << " switch_node=" << switch_node.name();

  // Find the boolean predicate from a LoopCond node (e.g. Greater).
  NodeDef* switch_ctrl_node = view.GetRegularFanin({switch_predicate, 0}).node;
  if (!switch_ctrl_node || !IsSimpleBinaryOperator(*switch_ctrl_node)) {
    return OkStatus();
  }

  // Find the Merge node & the Constant Operand to the condition node, if
  // available.
  NodeDef* merge_node = nullptr;
  NodeDef* constant_ctrl_input = nullptr;
  int constant_index = 0;
  for (int i = 0; i < switch_ctrl_node->input().size(); ++i) {
    const string& input = switch_ctrl_node->input(i);
    if (IsControlInput(input)) continue;

    NodeDef* node = view.GetNode(switch_ctrl_node->input(i));
    if (IsMerge(*node)) {
      merge_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_ctrl_input = node;
      // Remember which side of the comparison the constant sits on.
      constant_index = i;
    }
  }
  if (merge_node == nullptr || constant_ctrl_input == nullptr) {
    return OkStatus();
  }

  // Find the initialization constant (via Enter, if one exists).
  NodeDef* enter_node = nullptr;
  NodeDef* constant_init_node = nullptr;
  for (const auto& input : merge_node->input()) {
    NodeDef* node = node_map.GetNode(input);
    if (IsEnter(*node)) {
      enter_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_init_node = node;
    }
  }
  if (enter_node != nullptr) {
    // Ambiguous: both a direct constant and an Enter feed the Merge.
    if (constant_init_node != nullptr) return OkStatus();
    for (const auto& input : enter_node->input()) {
      NodeDef* node = node_map.GetNode(input);
      if (IsReallyConstant(*node, feed_nodes)) {
        constant_init_node = node;
      }
    }
  }
  if (constant_init_node == nullptr) {
    return OkStatus();
  }

  VLOG(4) << "Check if loop will be 0 iterations:"
          << "\n|  switch_node        : " << switch_node.name()
          << "\n|  switch_ctrl_node   : " << switch_ctrl_node->name()
          << "\n|  merge_node         : " << merge_node->name()
          << "\n|  constant_ctrl_input: " << constant_ctrl_input->name()
          << "\n|  enter_node         : "
          << (enter_node ? enter_node->name() : "<n/a>")
          << "\n|  constant_init_node : " << constant_init_node->name();

  // Check if there will be 0 iterations. This will only happen if the condition
  // evaluates to false with respect to the initialization value.
  // Order the operands to match the comparison's original argument order.
  NodeDef* operand_0 =
      constant_index ? constant_init_node : constant_ctrl_input;
  NodeDef* operand_1 =
      constant_index ? constant_ctrl_input : constant_init_node;
  bool constant_switch_value;
  TF_RETURN_IF_ERROR(EvaluateBoolOpForConstantOperands(
      *switch_ctrl_node, *operand_0, *operand_1, cpu_device, resource_mgr,
      &constant_switch_value));

  if (constant_switch_value == false) {
    VLOG(3) << "Remove 0 iteration while loop:"
            << " switch_node=" << switch_node.name();
    *has_dead_fanout = true;
    *dead_fanout = 1;
  } else {
    VLOG(4) << "Was not able to prove that loop has 0 iterations.";
  }
  return OkStatus();
}
716 
717 }  // namespace
718 
// Default construction enables all rewrites (RewriterConfig::ON) but supplies
// no CPU device or resource manager; presumably rewrites that must evaluate
// constant predicates on a device are limited in this mode — confirm against
// RemoveDeadBranches' use of cpu_device_.
LoopOptimizer::LoopOptimizer()
    : opt_level_(RewriterConfig::ON),
      cpu_device_(nullptr),
      options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
723 
LoopOptimizer(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device)724 LoopOptimizer::LoopOptimizer(RewriterConfig::Toggle opt_level,
725                              DeviceBase* cpu_device)
726     : opt_level_(opt_level),
727       cpu_device_(cpu_device),
728       options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {
729   resource_mgr_.reset(new ResourceMgr());
730 }
731 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)732 Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
733                                GraphDef* optimized_graph) {
734   if (!options_.enable_loop_invariant_node_motion &&
735       !options_.enable_stack_push_removal &&
736       !options_.enable_dead_branch_removal) {
737     return errors::Aborted("Nothing to do.");
738   }
739   *optimized_graph = item.graph;
740   // Set up helper data structures.
741   if (options_.enable_loop_invariant_node_motion) {
742     LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
743     TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
744   }
745   if (options_.enable_stack_push_removal) {
746     TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
747   }
748   if (options_.enable_dead_branch_removal) {
749     NodeMap node_map(optimized_graph);
750     absl::flat_hash_set<string> feed_nodes;
751     for (const auto& feed : item.feed) {
752       feed_nodes.insert(NodeName(feed.first));
753     }
754     TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
755                                           feed_nodes, optimized_graph));
756   }
757 
758   return OkStatus();
759 }
760 
// Removes branches that can be proven dead at graph-construction time.
// For every loop-control Switch node whose taken branch can be statically
// determined (via CheckForDeadFanout), flood-fills the dead output's fanout,
// deletes the nodes that are provably dead, rewrites Merge nodes that lost a
// dead data input into Identity, and finally rewrites the Switch itself into
// an Identity when its entire dead fanout could be removed. Nodes in
// `nodes_to_preserve` (and anything reachable only through them) are kept.
Status LoopOptimizer::RemoveDeadBranches(
    const std::unordered_set<string>& nodes_to_preserve, NodeMap& node_map,
    const absl::flat_hash_set<string>& feed_nodes, GraphDef* optimized_graph) {
  // Committed results across all processed Switch nodes.
  std::unordered_set<const NodeDef*> dead_nodes;
  // For each surviving Merge node: the set of its data-input ports proven dead.
  std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
  // Dead output ports of Switch nodes whose fanout was fully removed; these
  // switches are rewritten to Identity in the final loop below.
  absl::flat_hash_set<GraphView::OutputPort> identity_switches;

  MutableGraphView view(optimized_graph);
  for (const NodeDef& node : optimized_graph->node()) {
    if (!IsSwitch(node)) {
      continue;
    }
    if (node.op() == "_SwitchN") {  // _SwitchN not used in loop control flow.
      continue;
    }
    if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
      continue;
    }

    // Ask the analysis which output port (if any) of this Switch is dead.
    int dead_fanout;
    bool has_dead_fanout;
    TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
                                          cpu_device_, resource_mgr_.get(),
                                          &has_dead_fanout, &dead_fanout));
    if (!has_dead_fanout) {
      continue;
    }
    GraphView::OutputPort dead(&node, dead_fanout);

    // Worklist of input ports fed (transitively) by the dead Switch output.
    SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
        zombie_inputs;
    for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
      if (dead_nodes.find(port.node) == dead_nodes.end()) {
        zombie_inputs.PushBack(port);
      }
    }
    // If we encounter a single node that must be preserved in the fanout of the
    // switch node we need to preserve the entire switch fanout: we therefore
    // work on a local copy that only gets committed to the master copy once the
    // whole fanout has been explored.
    std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
    std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
        dead_merge_inputs;
    bool found_node_to_preserve = false;
    while (!found_node_to_preserve && !zombie_inputs.Empty()) {
      MutableGraphView::InputPort dead = zombie_inputs.PopBack();
      if (nodes_to_preserve.find(dead.node->name()) !=
          nodes_to_preserve.end()) {
        // Hit a fetch/preserved node: abandon this Switch entirely.
        found_node_to_preserve = true;
        break;
      }

      if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
        continue;
      }

      if (IsMerge(*dead.node)) {
        const int num_data_inputs = dead.node->attr().at("N").i();
        if (num_data_inputs > 2) {
          // This can happen with _SwitchN/Merge (Case lowering). We skip these
          // to simplify the code for now.
          found_node_to_preserve = true;
          break;
        }
        // Port 1 of Merge is "value_index": which data input was propagated.
        MutableGraphView::OutputPort value_index(dead.node, 1);
        const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
            view.GetFanout(value_index);
        if (!index_fanout.empty()) {
          // The 2nd output (that indicates which input is propagated) is
          // connected. This never happens in practice, so we'll just skip this
          // case to simplify the code for now.
          found_node_to_preserve = true;
          break;
        }

        bool fully_dead = false;
        // Merge node can become real dead only if all data inputs are dead.
        // Merge always waits for all control edges, but they do not
        // change the node deadness.
        if (dead.port_id >= 0) {
          local_dead_merge_inputs[dead.node].insert(dead.port_id);
          if (local_dead_merge_inputs[dead.node].size() == num_data_inputs) {
            fully_dead = true;
          }
        } else {
          // Keep track of all Merge nodes, even if they do not have dead data
          // inputs. We'll need to cleanup dead control edges for them later.
          local_dead_merge_inputs.insert({dead.node, {}});
        }
        if (fully_dead) {
          // All data inputs dead: the Merge itself dies, and its fanout
          // (including control edges) joins the worklist.
          local_dead_merge_inputs.erase(dead.node);
          local_dead_nodes.insert(dead.node);
          for (const MutableGraphView::InputPort& port :
               view.GetFanouts(*dead.node, true)) {
            zombie_inputs.PushBack(port);
          }
        }
      } else if (dead.node->op() == "ControlTrigger") {
        // Control trigger have different semantic, so don't touch them
        found_node_to_preserve = true;
        break;
      } else {
        // Any non-Merge node with a dead input is dead; propagate to its
        // fanout (data and control edges alike).
        if (local_dead_nodes.insert(dead.node).second) {
          for (const MutableGraphView::InputPort& dead_fanout :
               view.GetFanouts(*dead.node, true)) {
            zombie_inputs.PushBack(dead_fanout);
          }
        }
      }
    }
    if (!found_node_to_preserve) {
      // Commit the local exploration into the master sets.
      std::swap(dead_nodes, local_dead_nodes);
      std::swap(dead_merge_inputs, local_dead_merge_inputs);
      // Found no nodes to preserve in fanout of this switch node. This switch
      // node can be replaced with Identity node, collect here to process later
      identity_switches.insert(dead);
      VLOG(3) << "Found no nodes to preserve in fanout of switch node: "
              << node.name() << ", fanout port: " << dead_fanout;
    }
  }

  // Collect graph indices of dead nodes; actual erasure is deferred to the
  // end so NodeDef pointers held above stay valid meanwhile.
  std::vector<int> nodes_idx_to_delete;
  nodes_idx_to_delete.reserve(dead_nodes.size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    if (dead_nodes.count(&optimized_graph->node(i)))
      nodes_idx_to_delete.push_back(i);
  }

  // Names of the nodes that were removed from the graph.
  // NOTE(review): dead_node_names is populated but never read in the rest of
  // this function — it appears to be leftover from an earlier cleanup pass;
  // confirm and consider removing.
  absl::flat_hash_set<absl::string_view> dead_node_names;
  dead_node_names.reserve(dead_nodes.size());
  for (const NodeDef* dead_node : dead_nodes) {
    dead_node_names.insert(dead_node->name());
  }

  // Check that the merge nodes are valid. This loop only validates — the
  // actual rewrite happens in the next loop. Bailing out with OkStatus leaves
  // the graph untouched by the Merge rewrite but keeps optimization going.
  for (const auto& itr : dead_merge_inputs) {
    NodeDef* merge_node = itr.first;
    if (dead_nodes.find(merge_node) != dead_nodes.end()) {
      // The node will be pruned since all its inputs are dead.
      continue;
    }
    const std::set<int>& dead_inputs = itr.second;
    const int num_data_inputs = merge_node->attr().at("N").i();
    if (merge_node->input_size() != num_data_inputs) {
      LOG(WARNING)
          << "Skipping loop optimization for Merge node with control input: "
          << merge_node->name();
      return OkStatus();
    } else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
      LOG(WARNING) << "Skipping loop optimization for Merge node ("
                   << merge_node->name()
                   << ") with unexpected dead_inputs.size() ("
                   << dead_inputs.size() << " or  num_data_inputs"
                   << num_data_inputs;
      return OkStatus();
    }
  }

  // Remove dead inputs from Merge nodes that will not be
  // pruned from the graph.
  for (const auto& itr : dead_merge_inputs) {
    NodeDef* merge_node = itr.first;
    if (dead_nodes.find(merge_node) != dead_nodes.end()) {
      // The node will be pruned since all its inputs are dead.
      continue;
    }
    VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
    // Remove dead data input: swap it to the last position (validation above
    // guarantees exactly 2 data inputs and 1 dead input), drop it, and turn
    // the now single-input Merge into an Identity.
    const std::set<int>& dead_inputs = itr.second;
    int index = *dead_inputs.begin();
    auto* inputs = merge_node->mutable_input();
    inputs->SwapElements(1, index);
    inputs->SwapElements(1, merge_node->input_size() - 1);
    inputs->RemoveLast();
    merge_node->set_op("Identity");
    merge_node->mutable_attr()->erase("N");

    VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
  }

  // Rewrite Switch nodes whose dead fanout was fully removed into Identity.
  for (const auto& id_switch : identity_switches) {
    NodeDef* sw_node = const_cast<NodeDef*>((id_switch.node));
    int dead_port_id = id_switch.port_id;

    // Switch node where pred is not a constant, is not optimized.
    // TODO(intel-tf): For that case, enable optimization only if safe.
    // TODO(intel-tf): Need to check for RefSwitch and replace RefSwitch with
    // RefIdentity
    NodeDef* pred = node_map.GetNode(sw_node->input(1));
    if (IsReallyConstant(*pred, feed_nodes) && sw_node->op() == "Switch") {
      // From the dead_port_id, get the live port id, so we can correct
      // input names of consumers. When switch will be replaced with Identity,
      // it will have only 1 output versus 2 outputs of a Switch node
      int live_port_id = (dead_port_id + 1) % 2;
      string live_output_name = sw_node->name();
      if (live_port_id == 1) {
        live_output_name = StrCat(sw_node->name(), ":1");
      }

      // Get consumers of live port and update the input names: consumers of
      // "name:1" must now read plain "name" (Identity's single output).
      auto consumers = node_map.GetOutputs(sw_node->name());
      for (auto* consumer : consumers) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          if (consumer->input(i) == live_output_name) {
            consumer->set_input(i, sw_node->name());
            node_map.UpdateInput(consumer->name(), live_output_name,
                                 sw_node->name());
          }
        }
      }

      VLOG(3) << "Switch node before cleanup: " << sw_node->DebugString();

      // Change node from Switch to Identity and add a control dependency to
      // this Identity op (the predicate input becomes a control edge, since
      // Identity takes only one data input).
      const string ctrl_dep = ConstantFolding::AddControlDependency(
          pred->name(), optimized_graph, &node_map);
      node_map.UpdateInput(sw_node->name(), pred->name(), ctrl_dep);
      sw_node->set_input(1, ctrl_dep);
      sw_node->set_op("Identity");
      VLOG(3) << "Switch node after cleanup: " << sw_node->DebugString();
    }
  }
  // Physically erase the dead nodes collected earlier.
  EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);

  return OkStatus();
}
990 
991 }  // end namespace grappler
992 }  // end namespace tensorflow
993