1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/framework/allocator.h"
29 #include "tensorflow/core/framework/attr_value.pb.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/grappler/graph_topology_view.h"
35 #include "tensorflow/core/grappler/grappler_item.h"
36 #include "tensorflow/core/grappler/mutable_graph_view.h"
37 #include "tensorflow/core/grappler/op_types.h"
38 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
39 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
40 #include "tensorflow/core/grappler/utils/frame.h"
41 #include "tensorflow/core/grappler/utils/traversal.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/inlined_vector.h"
45 #include "tensorflow/core/lib/strings/strcat.h"
46 #include "tensorflow/core/platform/tensor_coding.h"
47 #include "tensorflow/core/public/version.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 #include "tensorflow/core/util/saved_tensor_slice_util.h"
50
51 using tensorflow::strings::StrCat;
52
53 namespace tensorflow {
54 namespace grappler {
55 namespace {
56
57 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
58
// Hoists loop-invariant nodes (nodes whose value does not change across loop
// iterations) out of their frame and into the enclosing frame, so they are
// computed once instead of once per iteration.
class LoopInvariantNodeMotionOptimizer {
 public:
  explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
      : optimized_graph_(optimized_graph) {}
  virtual ~LoopInvariantNodeMotionOptimizer() = default;
  // Runs loop-invariant node motion on the whole graph, processing frames
  // from the innermost outward.
  Status Optimize();

 private:
  // Collects invariant nodes reachable forward from `node` into
  // invariant_nodes_.
  Status FindInvariantNodes(NodeDef* node);
  // Drops candidates with control edges into variant nodes (moving those
  // would change execution semantics), together with their dependents.
  Status RevertInvariantNodes();
  // Rewires every surviving invariant node out of frame `frame_id`.
  Status MoveInvariantNodes(const int frame_id);
  Status HandleInvariantNode(NodeDef* node, const int num_outputs,
                             const int frame_id);
  Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
  Status HandleInvariantEnter(NodeDef* node, const int num_outputs);

  GraphDef* optimized_graph_;  // Not owned.
  std::unique_ptr<NodeMap> node_map_;
  // Candidate invariant nodes, mapped to the number of their outputs that
  // still feed variant (in-loop) consumers.
  std::map<NodeDef*, int> invariant_nodes_;
  std::set<int> empty_set_;
  // Child frame ids for each frame id; frames with no children are leaves.
  std::vector<std::set<int>> frame_children_;
  // Parent frame id for each frame id, or -1 for top-level frames.
  std::vector<int> frame_parent_;
  // The LoopCond node of each frame, keyed by frame id.
  std::map<int, const NodeDef*> loop_cond_;
  // Constant (is_constant=true) Enter nodes of each frame, keyed by frame id.
  std::map<int, std::vector<NodeDef*>> invariant_enters_;
  // Counter used to generate unique names for newly created Enter nodes.
  int new_enter_id_;
};
85
// Bypasses an invariant Enter node: invariant consumers are rewired to read
// directly from the Enter's data input (which lives in the parent frame), and
// the Enter's control inputs are forwarded to those consumers so ordering
// constraints are preserved.
Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
    NodeDef* node, const int num_outputs) {
  auto consumers = node_map_->GetOutputs(node->name());
  // Split the Enter's inputs into its single data input and any control
  // inputs.
  std::vector<string> enter_control_inputs;
  string enter_input;
  for (auto& input : node->input()) {
    if (IsControlInput(input)) {
      enter_control_inputs.push_back(input);
    } else {
      enter_input = input;
    }
  }
  for (auto* consumer : consumers) {
    if (invariant_nodes_.count(consumer)) {
      // Replace every reference to the Enter with its data input.
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (NodeName(consumer->input(i)) == node->name()) {
          consumer->set_input(i, enter_input);
          node_map_->AddOutput(NodeName(enter_input), consumer->name());
          node_map_->RemoveOutput(node->name(), consumer->name());
        }
      }
      // Carry the Enter's control dependencies over to the consumer.
      for (auto& control_input : enter_control_inputs) {
        consumer->add_input(control_input);
        node_map_->AddOutput(NodeName(control_input), consumer->name());
      }
    }
  }
  return OkStatus();
}
115
// Hoists an invariant Const node out of frame `frame_id`. If every consumer
// is invariant the node itself is moved (its in-frame inputs are dropped);
// otherwise a copy is created in the parent frame and only the invariant
// consumers are rewired to it. The hoisted constant is then anchored to the
// parent loop via a control dependency so it is only evaluated while the
// parent frame is alive.
Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
                                                     const int num_outputs,
                                                     const int frame_id) {
  NodeDef* const_node = nullptr;
  if (num_outputs == 0) {
    // all successor nodes are invariant
    // Remove the control inputs from this frame to the const node,
    // when moving it out of the frame (in parent frame)
    const_node = node;
    node_map_->RemoveInputs(node->name());
    node->clear_input();
  } else {
    // some successor nodes are variant
    // Have to keep the const node in the frame,
    // so create a new one outside the frame (in parent frame)
    const string const_node_name =
        AddPrefixToNodeName(node->name(), kLoopOptimizer);
    const_node = node_map_->GetNode(const_node_name);
    if (const_node == nullptr) {
      const_node = optimized_graph_->add_node();
      const_node->set_name(const_node_name);
      const_node->set_op("Const");
      const_node->set_device(node->device());
      *const_node->mutable_attr() = node->attr();
      node_map_->AddNode(const_node->name(), const_node);
    }
    // Point invariant consumers at the out-of-frame copy; variant consumers
    // keep reading the original in-frame constant.
    auto consumers = node_map_->GetOutputs(node->name());
    for (auto* consumer : consumers) {
      if (invariant_nodes_.count(consumer)) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          if (NodeName(consumer->input(i)) == node->name()) {
            if (IsControlInput(consumer->input(i))) {
              *consumer->mutable_input(i) = AsControlDependency(*const_node);
            } else {
              *consumer->mutable_input(i) = const_node->name();
            }
            node_map_->AddOutput(const_node->name(), consumer->name());
            node_map_->RemoveOutput(node->name(), consumer->name());
          }
        }
      }
    }
  }
  // add a control input from the parent frame
  if (frame_parent_[frame_id] != -1) {
    int parent_id = frame_parent_[frame_id];
    auto loop_cond_it = loop_cond_.find(parent_id);
    if (loop_cond_it == loop_cond_.end()) {
      // NOTE(review): the lookup above is for `parent_id`, but the message
      // reports `frame_id` — consider reporting the parent frame instead.
      return errors::InvalidArgument("Frame ", frame_id,
                                     " doesn't have a LoopCond node");
    }
    auto& loop_cond_name = loop_cond_it->second->name();
    NodeDef* switch_node = nullptr;
    for (auto* node : node_map_->GetOutputs(loop_cond_name)) {
      if (node->op() == "Switch") {
        switch_node = node;
        break;
      }
    }
    if (!switch_node) {
      return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
                                     " doesn't connect to any Switch node");
    }
    // Port 1 of a Switch is the "loop continues" branch; depending on it
    // keeps the hoisted constant inside the parent loop's execution.
    string switch_output = StrCat(switch_node->name(), ":1");
    const string ctrl_dep = ConstantFolding::AddControlDependency(
        switch_output, optimized_graph_, node_map_.get());
    const_node->add_input(ctrl_dep);
    node_map_->AddOutput(NodeName(ctrl_dep), const_node->name());
  }
  return OkStatus();
}
187
HandleInvariantNode(NodeDef * node,const int num_outputs,const int frame_id)188 Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
189 NodeDef* node, const int num_outputs, const int frame_id) {
190 // have to remove control inputs to the invariant node from the same frame
191 // when moving this node out of this frame
192 for (int i = 0; i < node->input_size(); ++i) {
193 if (IsControlInput(node->input(i))) {
194 node->mutable_input()->SwapElements(i, node->input_size() - 1);
195 node->mutable_input()->RemoveLast();
196 }
197 }
198 if (num_outputs == 0) {
199 return OkStatus();
200 }
201
202 DataTypeVector input_types;
203 DataTypeVector output_types;
204 OpRegistryInterface* op_registry = OpRegistry::Global();
205 const OpRegistrationData* op_reg_data = nullptr;
206 TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
207 TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
208 &output_types));
209
210 auto consumers = node_map_->GetOutputs(node->name());
211 string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
212 int piterations =
213 invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
214 for (auto* consumer : consumers) {
215 if (!invariant_nodes_.count(consumer)) {
216 for (int i = 0; i < consumer->input_size(); ++i) {
217 int port;
218 string node_name = ParseNodeName(consumer->input(i), &port);
219 if (node_name != node->name()) {
220 continue;
221 }
222 if (port < 0) {
223 return errors::InvalidArgument(
224 "Invariant node should not have control outputs "
225 "to variant node");
226 }
227 DataType output_type = output_types[port];
228 NodeDef* new_enter = optimized_graph_->add_node();
229 new_enter->set_op("Enter");
230 new_enter->set_device(node->device());
231 new_enter->set_name(AddPrefixToNodeName(
232 StrCat(fname, "_enter_", new_enter_id_++), kLoopOptimizer));
233 AttrValue data_type;
234 data_type.set_type(output_type);
235 new_enter->mutable_attr()->insert({"T", data_type});
236 AttrValue frame_name;
237 frame_name.set_s(fname);
238 new_enter->mutable_attr()->insert({"frame_name", frame_name});
239 AttrValue is_const;
240 is_const.set_b(true);
241 new_enter->mutable_attr()->insert({"is_constant", is_const});
242 AttrValue parallel_iterations;
243 parallel_iterations.set_i(piterations);
244 new_enter->mutable_attr()->insert(
245 {"parallel_iterations", parallel_iterations});
246 new_enter->add_input(consumer->input(i));
247 *consumer->mutable_input(i) = new_enter->name();
248 node_map_->AddNode(new_enter->name(), new_enter);
249 node_map_->AddOutput(node->name(), new_enter->name());
250 node_map_->AddOutput(new_enter->name(), consumer->name());
251 }
252 }
253 }
254 return OkStatus();
255 }
256
MoveInvariantNodes(const int frame_id)257 Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
258 const int frame_id) {
259 for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
260 ++iter) {
261 auto* invariant_node = iter->first;
262 const int num_outputs = iter->second;
263 if (IsEnter(*invariant_node)) {
264 TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
265 } else if (IsConstant(*invariant_node)) {
266 TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
267 } else {
268 TF_RETURN_IF_ERROR(
269 HandleInvariantNode(invariant_node, num_outputs, frame_id));
270 }
271 }
272 return OkStatus();
273 }
274
// Removes from the invariant set every non-Const, non-Enter node that drives
// a control edge into a variant node (hoisting such a node would change when
// that control edge fires). Removal is transitive: consumers of a reverted
// node are reverted too, and the variant-output counts of its surviving
// invariant producers are restored.
Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
  std::deque<const NodeDef*> reverted_nodes;
  // Pass 1: find direct offenders — invariant nodes with a control output
  // into a variant consumer.
  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
    bool erased = false;
    const auto* node = iter->first;
    if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
      auto& consumers = node_map_->GetOutputs(node->name());
      for (auto* consumer : consumers) {
        if (!invariant_nodes_.count(consumer)) {
          for (const auto& input : consumer->input()) {
            if (IsControlInput(input) && NodeName(input) == node->name()) {
              reverted_nodes.push_back(node);
              invariant_nodes_.erase(iter++);
              erased = true;
              break;
            }
          }
          if (erased) break;
        }
      }
    }
    if (!erased) ++iter;
  }
  // Pass 2: propagate the reversion through producers and consumers.
  while (!reverted_nodes.empty()) {
    const auto* node = reverted_nodes.front();
    reverted_nodes.pop_front();
    std::set<NodeDef*> producers;
    for (const auto& input : node->input()) {
      auto* producer = node_map_->GetNode(input);
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        if (IsControlInput(input) && !IsConstant(*producer) &&
            !IsEnter(*producer)) {
          // This producer now drives a control edge into a variant node
          // (the one just reverted), so it must be reverted as well.
          reverted_nodes.push_back(producer);
          invariant_nodes_.erase(iter);
        } else {
          producers.insert(producer);
        }
      }
    }
    // The reverted node is variant again, so each surviving invariant
    // producer has one more variant output to account for.
    for (auto* producer : producers) {
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        ++iter->second;
      }
    }
    // Anything consuming a reverted node cannot stay invariant either.
    for (auto* consumer : node_map_->GetOutputs(node->name())) {
      auto iter = invariant_nodes_.find(consumer);
      if (iter != invariant_nodes_.end()) {
        reverted_nodes.push_back(consumer);
        invariant_nodes_.erase(iter);
      }
    }
  }
  return OkStatus();
}
331
// Starting from an invariant Enter node, walks forward through the graph
// collecting invariant nodes: a consumer is invariant when every non-control
// input comes from an already-known invariant node (Const producers are added
// to the set on the fly). Each map value counts how many of the node's
// outputs feed variant consumers; it is decremented as consumers themselves
// become invariant.
Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(
    NodeDef* start_node) {
  std::vector<NodeDef*> stack;
  stack.reserve(32);
  stack.push_back(start_node);
  while (!stack.empty()) {
    NodeDef* node = stack.back();
    stack.pop_back();
    auto consumers = node_map_->GetOutputs(node->name());
    // Initially assume all of this node's outputs feed variant consumers.
    invariant_nodes_.emplace(node, consumers.size());
    for (auto* consumer : consumers) {
      // Frame-modifying ops (Enter/Exit/NextIteration/...) must stay put.
      if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
        continue;
      }
      bool is_invariant = true;
      for (const auto& input : consumer->input()) {
        if (!IsControlInput(input)) {
          const string name = NodeName(input);
          auto* producer = node_map_->GetNode(name);
          if (!invariant_nodes_.count(producer)) {
            if (IsConstant(*producer)) {
              // Constants are invariant by definition.
              invariant_nodes_.insert(
                  std::make_pair(producer, node_map_->GetOutputs(name).size()));
            } else {
              is_invariant = false;
              break;
            }
          }
        }
      }
      if (is_invariant) {
        // De-duplicate inputs so each producer is decremented at most once.
        std::set<NodeDef*> producers;
        for (const auto& input : consumer->input()) {
          auto* producer = node_map_->GetNode(input);
          producers.insert(producer);
        }
        for (auto* producer : producers) {
          auto iter = invariant_nodes_.find(producer);
          if (iter != invariant_nodes_.end()) {
            // One fewer variant consumer for this producer.
            --iter->second;
          }
        }
        stack.push_back(consumer);
      }
    }
  }
  return OkStatus();
}
380
// Runs loop-invariant node motion over the whole graph. Reconstructs the
// frame (loop-nesting) tree, then processes frames bottom-up — innermost
// first: for every frame that has at least one constant Enter, finds
// invariant nodes, reverts the unsafe ones, and moves the rest into the
// parent frame.
Status LoopInvariantNodeMotionOptimizer::Optimize() {
  node_map_.reset(new NodeMap(optimized_graph_));
  FrameView frame_view;
  // TODO(ezhulenev): Use GraphView when migrated from NodeMap.
  TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*optimized_graph_));

  frame_parent_.resize(frame_view.num_frames(), -1);
  frame_children_.resize(frame_view.num_frames());
  std::deque<int> worklist;
  // Rebuild the parent/child relation between frames from the per-node frame
  // id stacks (`frame_ids` is ordered outermost-first).
  for (const NodeDef& node : optimized_graph_->node()) {
    const std::vector<int>& frame_ids = frame_view.Frames(node);

    if (frame_ids.size() >= 3) {
      for (unsigned int i = 1; i < frame_ids.size() - 1; ++i) {
        frame_parent_[frame_ids[i]] = frame_ids[i - 1];
        frame_children_[frame_ids[i]].insert(frame_ids[i + 1]);
      }
    }
    if (frame_ids.size() >= 2) {
      frame_children_[frame_ids[0]].insert(frame_ids[1]);
      frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
    }
    if (!frame_ids.empty()) {
      // NOTE(review): this unconditionally clears the child set of the
      // node's innermost frame; if a previously visited node registered
      // children for that same frame they are discarded here — confirm this
      // is intended.
      frame_children_[frame_ids.back()] = empty_set_;
      if (node.op() == "LoopCond") {
        // A well-formed frame has exactly one LoopCond node.
        if (loop_cond_.count(frame_ids.back())) {
          return errors::InvalidArgument(
              "Loop ", frame_ids.back(),
              " has more than one LoopCond node: ", node.name(), " and ",
              loop_cond_[frame_ids.back()]->name());
        }
        loop_cond_[frame_ids.back()] = &node;
      }
      // Constant Enter nodes seed the invariant-node discovery below.
      if (IsEnter(node) && node.attr().at("is_constant").b()) {
        invariant_enters_[frame_ids.back()].push_back(
            const_cast<NodeDef*>(&node));
      }
    }
  }

  // Seed the worklist with leaf frames (frames with no children).
  for (size_t i = 0; i < frame_children_.size(); i++) {
    if (frame_children_[i].empty()) {
      worklist.push_back(i);
    }
  }

  // Process frames innermost-out; a parent becomes ready once all of its
  // children have been handled.
  while (!worklist.empty()) {
    int frame_id = worklist.front();
    new_enter_id_ = 0;
    worklist.pop_front();
    if (frame_parent_[frame_id] != -1) {
      int parent_id = frame_parent_[frame_id];
      frame_children_[parent_id].erase(frame_id);
      if (frame_children_[parent_id].empty()) {
        worklist.push_back(parent_id);
      }
    }

    if (invariant_enters_[frame_id].empty()) {
      continue;
    }
    invariant_nodes_.clear();
    for (auto* enter : invariant_enters_[frame_id]) {
      TF_RETURN_IF_ERROR(FindInvariantNodes(enter));
    }

    // revert invariant nodes that have control outputs to variant nodes
    TF_RETURN_IF_ERROR(RevertInvariantNodes());

    TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id));
  }
  return OkStatus();
}
454
// Returns the indices of StackPush nodes in the fanout of the stack node at
// `stack_node_idx` that can be converted to Identity because no StackPop ever
// observes their values. Returns an empty vector if any pop with consumers —
// or any unexpected op — is found in the stack's fanout.
std::vector<int> GetStackPushNodesToConvert(
    const GraphTopologyView& graph_view,
    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();

  // Ops through which the stack handle may flow without being consumed.
  const std::unordered_set<string> op_types_to_traverse(
      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
       "_SwitchN", "Identity", "RefIdentity"});
  const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
    return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
  };

  std::vector<int> nodes_to_convert;
  std::vector<int> fanouts;

  // Collect every node reachable from the stack through handle-passing ops.
  DfsTraversal(graph_view, {graph_view.GetNode(stack_node_idx)},
               TraversalDirection::kFollowOutputs,
               DfsPredicates::Advance(is_op_to_traverse),
               DfsCallbacks::PreOrder([&](const NodeDef* node) {
                 const absl::optional<int> idx = graph_view.GetNodeIndex(*node);
                 fanouts.push_back(idx.value());
               }));

  for (int fanout_idx : fanouts) {
    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
    if (IsStackPushOp(fanout_node)) {
      // Check that the stack itself is not a node we want to preserve. This can
      // happen when the graph we have contains only the forward pass for a loop
      // (as when the forward and backward passes are split across different
      // functions).
      if (graph_view.HasNode(fanout_node.input(0))) {
        const NodeDef* stack_node = graph_view.GetNode(fanout_node.input(0));
        // Walk back through handle-forwarding ops to the actual Stack node.
        while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
               stack_node->input_size() > 0 &&
               graph_view.HasNode(stack_node->input(0))) {
          stack_node = graph_view.GetNode(stack_node->input(0));
        }
        if (nodes_to_preserve.find(stack_node->name()) ==
            nodes_to_preserve.end()) {
          nodes_to_convert.push_back(fanout_idx);
        }
      } else {
        nodes_to_convert.push_back(fanout_idx);
      }
    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
               op_types_to_traverse.find(fanout_node.op()) !=
                   op_types_to_traverse.end()) {
      // Harmless stack-related or pass-through ops: keep scanning.
      continue;
    } else if (!IsStackPopOp(fanout_node) ||
               (!graph_view.GetFanout(fanout_idx).empty() ||
                nodes_to_preserve.find(fanout_node.name()) !=
                    nodes_to_preserve.end())) {
      // The node is either a stack pop with consumers or something unexpected
      // so we leave the graph alone.
      nodes_to_convert.clear();
      break;
    }
  }

  return nodes_to_convert;
}
517
RemoveStackOps(const std::unordered_set<string> & nodes_to_preserve,GraphDef * optimized_graph)518 Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
519 GraphDef* optimized_graph) {
520 NodeMap node_map(optimized_graph);
521 GraphTopologyView graph_view;
522 TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(*optimized_graph));
523
524 for (int node_idx = 0; node_idx < optimized_graph->node_size(); ++node_idx) {
525 if (IsStackOp(optimized_graph->node(node_idx))) {
526 for (int push_node_idx : GetStackPushNodesToConvert(
527 graph_view, nodes_to_preserve, node_idx)) {
528 // We found push nodes without corresponding pops. Convert them to
529 // Identity passing the data through and add a control dependency from
530 // the op supplying the stack handle.
531 NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
532 VLOG(1) << "Converting " << push_node_idx << " : "
533 << push_node->DebugString();
534 if (push_node->attr().count("swap_memory") != 0) {
535 push_node->mutable_attr()->erase("swap_memory");
536 }
537 push_node->set_op("Identity");
538 push_node->mutable_input()->SwapElements(0, 1);
539 const string ctrl_dep = ConstantFolding::AddControlDependency(
540 push_node->input(1), optimized_graph, &node_map);
541 push_node->set_input(1, ctrl_dep);
542 VLOG(1) << "After converting: " << push_node->DebugString();
543 }
544 }
545 }
546 return OkStatus();
547 }
548
IsSimpleBinaryOperator(const NodeDef & node)549 bool IsSimpleBinaryOperator(const NodeDef& node) {
550 return (IsLess(node) || IsLessEqual(node) || IsGreater(node) ||
551 IsGreaterEqual(node) || IsEqual(node));
552 }
553
EvaluateBoolOpForConstantOperands(const NodeDef & op_node,const NodeDef & constant_operand_0,const NodeDef & constant_operand_1,DeviceBase * cpu_device,ResourceMgr * resource_mgr,bool * value)554 Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
555 const NodeDef& constant_operand_0,
556 const NodeDef& constant_operand_1,
557 DeviceBase* cpu_device,
558 ResourceMgr* resource_mgr,
559 bool* value) {
560 VLOG(4) << "Evaluate bool op: op_node=" << op_node.name()
561 << " input0=" << constant_operand_0.name()
562 << " input1=" << constant_operand_1.name();
563 TensorVector inputs;
564
565 const TensorProto& raw_val_0 = constant_operand_0.attr().at("value").tensor();
566 Tensor value_0(raw_val_0.dtype(), raw_val_0.tensor_shape());
567 CHECK(value_0.FromProto(raw_val_0));
568 inputs.emplace_back(&value_0);
569 const TensorProto& raw_val_1 = constant_operand_1.attr().at("value").tensor();
570 Tensor value_1(raw_val_1.dtype(), raw_val_1.tensor_shape());
571 CHECK(value_1.FromProto(raw_val_1));
572 inputs.emplace_back(&value_1);
573
574 TensorVector outputs;
575 TF_RETURN_IF_ERROR(
576 EvaluateNode(op_node, inputs, cpu_device, resource_mgr, &outputs));
577
578 if (outputs.size() != 1 || outputs[0].tensor == nullptr) {
579 return Status(error::INVALID_ARGUMENT, "Expected one output.");
580 }
581 *value = outputs[0].tensor->scalar<bool>()();
582 delete outputs[0].tensor;
583
584 return OkStatus();
585 }
586
587 // TODO(lyandy): Consolidate with ConstantFolding implementation.
IsReallyConstant(const NodeDef & node,const absl::flat_hash_set<string> & feed_nodes)588 bool IsReallyConstant(const NodeDef& node,
589 const absl::flat_hash_set<string>& feed_nodes) {
590 if (!IsConstant(node)) {
591 return false;
592 }
593 // If the node is fed it's not constant anymore.
594 return feed_nodes.find(node.name()) == feed_nodes.end();
595 }
596
// Determines whether one output port of `switch_node` is statically dead.
// Two cases are detected:
//   1. The Switch predicate is a constant, so one branch can never fire.
//   2. The Switch belongs to a while loop whose condition is a simple binary
//      comparison between a constant and the loop's constant initialization
//      value, and that comparison evaluates to false — a zero-iteration loop.
// On success `*has_dead_fanout` says whether a dead output exists and
// `*dead_fanout` gives its port index.
Status CheckForDeadFanout(const MutableGraphView& view,
                          const NodeDef& switch_node, const NodeMap& node_map,
                          const absl::flat_hash_set<string>& feed_nodes,
                          DeviceBase* cpu_device, ResourceMgr* resource_mgr,
                          bool* has_dead_fanout, int* dead_fanout) {
  *has_dead_fanout = false;
  // Input 1 of a Switch is its predicate.
  GraphView::InputPort switch_loopcond_port(&switch_node, 1);
  const NodeDef* switch_predicate =
      view.GetRegularFanin(switch_loopcond_port).node;

  // CASE 1: Control is a constant.
  if (IsReallyConstant(*switch_predicate, feed_nodes)) {
    VLOG(3) << "Found switch node with constant predicate:"
            << " switch_node=" << switch_node.name()
            << " switch_predicate=" << switch_predicate->name();
    Tensor selector;
    CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
    *has_dead_fanout = true;
    // Port 0 carries the false branch, port 1 the true branch; the branch
    // opposite the constant predicate is dead.
    *dead_fanout = selector.scalar<bool>()() ? 0 : 1;
    return OkStatus();
  }

  GraphView::InputPort switch_input_port(&switch_node, 0);
  const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;

  // CASE 2: Zero-iteration while loop.
  // We check if its a while loop such that the condition is a simple binary
  // operator which returns false for the initialization value.
  // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
  if (!IsMerge(*switch_input) || !IsLoopCond(*switch_predicate)) {
    return OkStatus();
  }

  VLOG(4) << "Try to find a zero iteration while loop:"
          << " switch_node=" << switch_node.name();

  // Find the boolean predicate from a LoopCond node (e.g. Greater).
  NodeDef* switch_ctrl_node = view.GetRegularFanin({switch_predicate, 0}).node;
  if (!switch_ctrl_node || !IsSimpleBinaryOperator(*switch_ctrl_node)) {
    return OkStatus();
  }

  // Find the Merge node & the Constant Operand to the condition node, if
  // available.
  NodeDef* merge_node = nullptr;
  NodeDef* constant_ctrl_input = nullptr;
  int constant_index = 0;
  for (int i = 0; i < switch_ctrl_node->input().size(); ++i) {
    const string& input = switch_ctrl_node->input(i);
    if (IsControlInput(input)) continue;

    NodeDef* node = view.GetNode(switch_ctrl_node->input(i));
    if (IsMerge(*node)) {
      merge_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_ctrl_input = node;
      constant_index = i;
    }
  }
  if (merge_node == nullptr || constant_ctrl_input == nullptr) {
    return OkStatus();
  }

  // Find the initialization constant (via Enter, if one exists).
  NodeDef* enter_node = nullptr;
  NodeDef* constant_init_node = nullptr;
  for (const auto& input : merge_node->input()) {
    NodeDef* node = node_map.GetNode(input);
    if (IsEnter(*node)) {
      enter_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_init_node = node;
    }
  }
  if (enter_node != nullptr) {
    // Ambiguous: both a direct constant and an Enter feed the Merge — bail.
    if (constant_init_node != nullptr) return OkStatus();
    for (const auto& input : enter_node->input()) {
      NodeDef* node = node_map.GetNode(input);
      if (IsReallyConstant(*node, feed_nodes)) {
        constant_init_node = node;
      }
    }
  }
  if (constant_init_node == nullptr) {
    return OkStatus();
  }

  VLOG(4) << "Check if loop will be 0 iterations:"
          << "\n|- switch_node          : " << switch_node.name()
          << "\n|- switch_ctrl_node     : " << switch_ctrl_node->name()
          << "\n|- merge_node           : " << merge_node->name()
          << "\n|- constant_ctrl_input  : " << constant_ctrl_input->name()
          << "\n|- enter_node           : "
          << (enter_node ? enter_node->name() : "<n/a>")
          << "\n|- constant_init_node   : " << constant_init_node->name();

  // Check if there will be 0 iterations. This will only happen if the condition
  // evaluates to false with respect to the initialization value.
  // Keep the operands in the same positional order the comparison expects.
  NodeDef* operand_0 =
      constant_index ? constant_init_node : constant_ctrl_input;
  NodeDef* operand_1 =
      constant_index ? constant_ctrl_input : constant_init_node;
  bool constant_switch_value;
  TF_RETURN_IF_ERROR(EvaluateBoolOpForConstantOperands(
      *switch_ctrl_node, *operand_0, *operand_1, cpu_device, resource_mgr,
      &constant_switch_value));

  if (constant_switch_value == false) {
    VLOG(3) << "Remove 0 iteration while loop:"
            << " switch_node=" << switch_node.name();
    *has_dead_fanout = true;
    // Port 1 (loop body) is dead when the loop never iterates.
    *dead_fanout = 1;
  } else {
    VLOG(4) << "Was not able to prove that loop has 0 iterations.";
  }
  return OkStatus();
}
716
717 } // namespace
718
// Default-constructs the optimizer with aggressive ("ON") settings and no
// CPU device for constant predicate evaluation.
LoopOptimizer::LoopOptimizer()
    : opt_level_(RewriterConfig::ON),
      cpu_device_(nullptr),
      options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
723
// Constructs the optimizer with the given toggle level and a CPU device used
// to evaluate loop predicates over constant operands.
// NOTE(review): options_ is initialized from RewriterConfig::ON rather than
// the supplied `opt_level` — confirm this is intentional.
LoopOptimizer::LoopOptimizer(RewriterConfig::Toggle opt_level,
                             DeviceBase* cpu_device)
    : opt_level_(opt_level),
      cpu_device_(cpu_device),
      options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {
  resource_mgr_.reset(new ResourceMgr());
}
731
// Entry point of the loop optimizer: copies the input graph and applies the
// enabled passes in order — loop-invariant node motion, stack push removal,
// and dead branch removal. Returns Aborted if no pass is enabled so callers
// know the graph was left untouched.
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  if (!options_.enable_loop_invariant_node_motion &&
      !options_.enable_stack_push_removal &&
      !options_.enable_dead_branch_removal) {
    return errors::Aborted("Nothing to do.");
  }
  *optimized_graph = item.graph;
  // Set up helper data structures.
  if (options_.enable_loop_invariant_node_motion) {
    LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
    TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
  }
  if (options_.enable_stack_push_removal) {
    TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
  }
  if (options_.enable_dead_branch_removal) {
    NodeMap node_map(optimized_graph);
    // Fed nodes are not constant at runtime, so dead-branch analysis must
    // not fold branches that depend on them.
    absl::flat_hash_set<string> feed_nodes;
    for (const auto& feed : item.feed) {
      feed_nodes.insert(NodeName(feed.first));
    }
    TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
                                          feed_nodes, optimized_graph));
  }

  return OkStatus();
}
760
RemoveDeadBranches(const std::unordered_set<string> & nodes_to_preserve,NodeMap & node_map,const absl::flat_hash_set<string> & feed_nodes,GraphDef * optimized_graph)761 Status LoopOptimizer::RemoveDeadBranches(
762 const std::unordered_set<string>& nodes_to_preserve, NodeMap& node_map,
763 const absl::flat_hash_set<string>& feed_nodes, GraphDef* optimized_graph) {
764 std::unordered_set<const NodeDef*> dead_nodes;
765 std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
766 absl::flat_hash_set<GraphView::OutputPort> identity_switches;
767
768 MutableGraphView view(optimized_graph);
769 for (const NodeDef& node : optimized_graph->node()) {
770 if (!IsSwitch(node)) {
771 continue;
772 }
773 if (node.op() == "_SwitchN") { // _SwitchN not used in loop control flow.
774 continue;
775 }
776 if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
777 continue;
778 }
779
780 int dead_fanout;
781 bool has_dead_fanout;
// NOTE(review): This is the tail of a larger member function (presumably
// LoopOptimizer::RemoveDeadBranches — TODO confirm); the enclosing loop over
// graph nodes and the declarations of `node`, `view`, `node_map`,
// `feed_nodes`, `dead_nodes`, `dead_merge_inputs` and `identity_switches`
// begin above this excerpt.
//
// Phase 1 (per switch node): determine whether one output port of the switch
// is provably dead (its predicate is effectively constant).
782 TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
783 cpu_device_, resource_mgr_.get(),
784 &has_dead_fanout, &dead_fanout));
785 if (!has_dead_fanout) {
786 continue;
787 }
788 GraphView::OutputPort dead(&node, dead_fanout);
789
// Worklist of input ports reachable from the dead switch output. SetVector
// gives stable de-duplicated worklist semantics (each port processed once).
790 SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
791 zombie_inputs;
// Seed the worklist with the direct consumers of the dead output that are
// not already known to be dead.
792 for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
793 if (dead_nodes.find(port.node) == dead_nodes.end()) {
794 zombie_inputs.PushBack(port);
795 }
796 }
797 // If we encounter a single node that must be preserved in the fanout of the
798 // switch node we need to preserve the entire switch fanout: we therefore
799 // work on a local copy that only gets committed to the master copy once the
800 // whole fanout has been explored.
801 std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
802 std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
803 dead_merge_inputs;
// Phase 2: propagate deadness through the fanout. The traversal aborts (and
// the local copies are discarded) as soon as any node that must be preserved
// is reached.
804 bool found_node_to_preserve = false;
805 while (!found_node_to_preserve && !zombie_inputs.Empty()) {
// Note: this `dead` (an InputPort) shadows the OutputPort `dead` declared
// above — intentional here, but easy to misread.
806 MutableGraphView::InputPort dead = zombie_inputs.PopBack();
807 if (nodes_to_preserve.find(dead.node->name()) !=
808 nodes_to_preserve.end()) {
809 found_node_to_preserve = true;
810 break;
811 }
812
// Already marked dead on a previous visit; nothing more to propagate.
813 if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
814 continue;
815 }
816
// Merge is the one op that can swallow dead inputs: it only becomes dead
// itself when ALL of its data inputs are dead.
817 if (IsMerge(*dead.node)) {
818 const int num_data_inputs = dead.node->attr().at("N").i();
819 if (num_data_inputs > 2) {
820 // This can happen with _SwitchN/Merge (Case lowering). We skip these
821 // to simplify the code for now.
822 found_node_to_preserve = true;
823 break;
824 }
// Output 1 of Merge is `value_index` (which data input was forwarded).
825 MutableGraphView::OutputPort value_index(dead.node, 1);
826 const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
827 view.GetFanout(value_index);
828 if (!index_fanout.empty()) {
829 // The 2nd output (that indicates which input is propagated) is
830 // connected. This never happens in practice, so we'll just skip this
831 // case to simplify the code for now.
832 found_node_to_preserve = true;
833 break;
834 }
835
836 bool fully_dead = false;
837 // Merge node can become real dead only if all data inputs are dead.
838 // Merge always waits for all control edges, but they do not
839 // change the node deadness.
// port_id < 0 denotes a control edge; only data ports count toward death.
840 if (dead.port_id >= 0) {
841 local_dead_merge_inputs[dead.node].insert(dead.port_id);
// NOTE(review): size() is size_t vs. int num_data_inputs — a
// signed/unsigned comparison. Harmless for the values involved, but
// an explicit cast would silence -Wsign-compare; confirm before
// changing.
842 if (local_dead_merge_inputs[dead.node].size() == num_data_inputs) {
843 fully_dead = true;
844 }
845 } else {
846 // Keep track of all Merge nodes, even if they do not have dead data
847 // inputs. We'll need to cleanup dead control edges for them later.
848 local_dead_merge_inputs.insert({dead.node, {}});
849 }
// Once every data input is dead, the Merge itself dies and its own
// fanout (including control edges) joins the worklist.
850 if (fully_dead) {
851 local_dead_merge_inputs.erase(dead.node);
852 local_dead_nodes.insert(dead.node);
853 for (const MutableGraphView::InputPort& port :
854 view.GetFanouts(*dead.node, true)) {
855 zombie_inputs.PushBack(port);
856 }
857 }
858 } else if (dead.node->op() == "ControlTrigger") {
859 // Control trigger have different semantic, so don't touch them
860 found_node_to_preserve = true;
861 break;
862 } else {
// Any other op with a dead input is dead; enqueue its entire fanout
// (the `true` argument presumably includes control edges — TODO
// confirm against GraphView::GetFanouts).
863 if (local_dead_nodes.insert(dead.node).second) {
864 for (const MutableGraphView::InputPort& dead_fanout :
865 view.GetFanouts(*dead.node, true)) {
866 zombie_inputs.PushBack(dead_fanout);
867 }
868 }
869 }
870 }
// Commit the speculative local state only if the whole fanout was explored
// without hitting a node that must be preserved.
871 if (!found_node_to_preserve) {
872 std::swap(dead_nodes, local_dead_nodes);
873 std::swap(dead_merge_inputs, local_dead_merge_inputs);
874 // Found no nodes to preserve in fanout of this switch node. This switch
875 // node can be replaced with Identity node, collect here to process later
876 identity_switches.insert(dead);
877 VLOG(3) << "Found no nodes to preserve in fanout of switch node: "
878 << node.name() << ", fanout port: " << dead_fanout;
879 }
880 }
881
// Phase 3: collect the graph indices of all dead nodes. Deletion is deferred
// to the end so that the NodeDef pointers held above stay valid meanwhile.
882 std::vector<int> nodes_idx_to_delete;
883 nodes_idx_to_delete.reserve(dead_nodes.size());
884 for (int i = 0; i < optimized_graph->node_size(); ++i) {
885 if (dead_nodes.count(&optimized_graph->node(i)))
886 nodes_idx_to_delete.push_back(i);
887 }
888
889 // Names of the nodes that were removed from the graph.
// NOTE(review): dead_node_names holds string_views into the NodeDefs —
// valid only until EraseNodesFromGraph below. It also appears unused in this
// excerpt; possibly consumed by code outside the visible range — confirm.
890 absl::flat_hash_set<absl::string_view> dead_node_names;
891 dead_node_names.reserve(dead_nodes.size());
892 for (const NodeDef* dead_node : dead_nodes) {
893 dead_node_names.insert(dead_node->name());
894 }
895
896 // Check that the merge nodes are valid.
// Phase 4a: validation pass over surviving Merge nodes. Runs entirely before
// any mutation (phase 4b), so bailing out here leaves the graph unmodified
// except for work done in earlier iterations.
897 for (const auto& itr : dead_merge_inputs) {
898 NodeDef* merge_node = itr.first;
899 if (dead_nodes.find(merge_node) != dead_nodes.end()) {
900 // The node will be pruned since all its inputs are dead.
901 continue;
902 }
903 // Remove dead data input.
904 const std::set<int>& dead_inputs = itr.second;
905 const int num_data_inputs = merge_node->attr().at("N").i();
// input_size() > N means the Merge carries control inputs, which the
// rewrite below does not handle; give up on the whole optimization.
906 if (merge_node->input_size() != num_data_inputs) {
907 LOG(WARNING)
908 << "Skipping loop optimization for Merge node with control input: "
909 << merge_node->name();
910 return OkStatus();
// The Identity rewrite below only supports the canonical 2-input Merge
// with exactly one dead input.
911 } else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
// NOTE(review): this warning message is malformed — "(" after
// dead_inputs.size() is never closed and a space is missing before
// num_data_inputs. Log-text only; fixing it is a behavior (string)
// change, out of scope for a comment-only pass.
912 LOG(WARNING) << "Skipping loop optimization for Merge node ("
913 << merge_node->name()
914 << ") with unexpected dead_inputs.size() ("
915 << dead_inputs.size() << " or num_data_inputs"
916 << num_data_inputs;
917 return OkStatus();
918 }
919 }
920
921 // Remove dead inputs from Merge nodes that will not be not
922 // pruned from the graph.
// Phase 4b: rewrite each surviving 2-input Merge with one dead input into an
// Identity on its sole live input.
923 for (const auto& itr : dead_merge_inputs) {
924 NodeDef* merge_node = itr.first;
925 if (dead_nodes.find(merge_node) != dead_nodes.end()) {
926 // The node will be pruned since all its inputs are dead.
927 continue;
928 }
929 VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
930 // Remove dead data input.
931 const std::set<int>& dead_inputs = itr.second;
// Validation above guarantees dead_inputs.size() == 1 here.
932 int index = *dead_inputs.begin();
// Move the dead input to the last slot via two swaps (the first swap puts
// it at slot 1, the second moves it to the end), then drop it. This keeps
// the live input at slot 0 as Identity's single data input.
933 auto* inputs = merge_node->mutable_input();
934 inputs->SwapElements(1, index);
935 inputs->SwapElements(1, merge_node->input_size() - 1);
936 inputs->RemoveLast();
937 merge_node->set_op("Identity");
// Identity has no "N" attr; erase it so the node validates.
938 merge_node->mutable_attr()->erase("N");
939
940 VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
941 }
942
// Phase 5: rewrite each Switch whose entire dead fanout was pruned into an
// Identity of its data input, keeping the predicate as a control dependency.
943 for (const auto& id_switch : identity_switches) {
// NOTE(review): const_cast works around identity_switches storing const
// pointers; safe only because the underlying graph is mutable here.
944 NodeDef* sw_node = const_cast<NodeDef*>((id_switch.node));
945 int dead_port_id = id_switch.port_id;
946
947 // Switch node where pred is not a constant, is not optimized.
948 // TODO(intel-tf): For that case, enable optimization only if safe.
949 // TODO(intel-tf): Need to check for RefSwitch and replace RefSwitch with
950 // RefIdentity
// Input 1 of Switch is the boolean predicate.
951 NodeDef* pred = node_map.GetNode(sw_node->input(1));
952 if (IsReallyConstant(*pred, feed_nodes) && sw_node->op() == "Switch") {
953 // From the dead_port_id, get the live port id, so we can correct
954 // input names of consumers. When switch will be replaced with Identity,
955 // it will have only 1 output versus 2 outputs of a Switch node
956 int live_port_id = (dead_port_id + 1) % 2;
// Port 0 is addressed by the bare node name; port 1 needs ":1".
957 string live_output_name = sw_node->name();
958 if (live_port_id == 1) {
959 live_output_name = StrCat(sw_node->name(), ":1");
960 }
961
962 // Get consumers of live port and update the input names
// Consumers of the live port must now reference the Identity's single
// output (the bare node name).
963 auto consumers = node_map.GetOutputs(sw_node->name());
964 for (auto* consumer : consumers) {
965 for (int i = 0; i < consumer->input_size(); ++i) {
966 if (consumer->input(i) == live_output_name) {
967 consumer->set_input(i, sw_node->name());
968 node_map.UpdateInput(consumer->name(), live_output_name,
969 sw_node->name());
970 }
971 }
972 }
973
974 VLOG(3) << "Switch node before cleanup: " << sw_node->DebugString();
975
976 // Change node from Switch to Identity and add a control dependency to
977 // this Identity op.
// The predicate input becomes a control edge so frame/execution ordering
// is preserved even though its value is no longer consumed.
978 const string ctrl_dep = ConstantFolding::AddControlDependency(
979 pred->name(), optimized_graph, &node_map);
980 node_map.UpdateInput(sw_node->name(), pred->name(), ctrl_dep);
981 sw_node->set_input(1, ctrl_dep);
982 sw_node->set_op("Identity");
983 VLOG(3) << "Switch node after cleanup: " << sw_node->DebugString();
984 }
985 }
// Phase 6: physically remove the dead nodes, invalidating all NodeDef
// pointers collected above (hence this is last).
986 EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
987
988 return OkStatus();
989 }
990
991 } // end namespace grappler
992 } // end namespace tensorflow
993