xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/mutable_graph_view.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/mutable_graph_view.h"
17 
18 #include <algorithm>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/string_view.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/grappler/op_types.h"
32 #include "tensorflow/core/grappler/utils.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/stringpiece.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 namespace grappler {
41 
42 namespace {
43 
IsTensorIdPortValid(const TensorId & tensor_id)44 bool IsTensorIdPortValid(const TensorId& tensor_id) {
45   return tensor_id.index() >= Graph::kControlSlot;
46 }
47 
IsTensorIdRegular(const TensorId & tensor_id)48 bool IsTensorIdRegular(const TensorId& tensor_id) {
49   return tensor_id.index() > Graph::kControlSlot;
50 }
51 
IsTensorIdControlling(const TensorId & tensor_id)52 bool IsTensorIdControlling(const TensorId& tensor_id) {
53   return tensor_id.index() == Graph::kControlSlot;
54 }
55 
IsOutputPortControlling(const MutableGraphView::OutputPort & port)56 bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
57   return port.port_id == Graph::kControlSlot;
58 }
59 
60 // Determines if node is an Identity where it's first regular input is a Switch
61 // node.
IsIdentityConsumingSwitch(const MutableGraphView & graph,const NodeDef & node)62 bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
63                                const NodeDef& node) {
64   if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
65       node.input_size() > 0) {
66     TensorId tensor_id = ParseTensorName(node.input(0));
67     if (IsTensorIdControlling(tensor_id)) {
68       return false;
69     }
70 
71     NodeDef* input_node = graph.GetNode(tensor_id.node());
72     if (input_node == nullptr) {
73       return false;
74     }
75     return IsSwitch(*input_node);
76   }
77   return false;
78 }
79 
80 // Determines if node input can be deduped by regular inputs when used as a
81 // control dependency. Specifically, if a node is an Identity that leads to a
82 // Switch node, when used as a control dependency, that control dependency
83 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,const NodeDef & control_node)84 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
85                                      const NodeDef& control_node) {
86   return !IsIdentityConsumingSwitch(graph, control_node);
87 }
88 
89 // Determines if node input can be deduped by regular inputs when used as a
90 // control dependency. Specifically, if a node is an Identity that leads to a
91 // Switch node, when used as a control dependency, that control dependency
92 // should not be deduped even though the same node is used as a regular input.
CanDedupControlWithRegularInput(const MutableGraphView & graph,absl::string_view control_node_name)93 bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
94                                      absl::string_view control_node_name) {
95   NodeDef* control_node = graph.GetNode(control_node_name);
96   if (control_node == nullptr) {
97     return false;
98   }
99   return CanDedupControlWithRegularInput(graph, *control_node);
100 }
101 
HasRegularFaninNode(const MutableGraphView & graph,const NodeDef & node,absl::string_view fanin_node_name)102 bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
103                          absl::string_view fanin_node_name) {
104   const int num_regular_fanins =
105       graph.NumFanins(node, /*include_controlling_nodes=*/false);
106   for (int i = 0; i < num_regular_fanins; ++i) {
107     if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
108       return true;
109     }
110   }
111   return false;
112 }
113 
114 using FanoutsMap =
115     absl::flat_hash_map<MutableGraphView::OutputPort,
116                         absl::flat_hash_set<MutableGraphView::InputPort>>;
117 
SwapControlledFanoutInputs(const MutableGraphView & graph,const FanoutsMap::iterator & control_fanouts,absl::string_view to_node_name)118 void SwapControlledFanoutInputs(const MutableGraphView& graph,
119                                 const FanoutsMap::iterator& control_fanouts,
120                                 absl::string_view to_node_name) {
121   absl::string_view from_node_name(control_fanouts->first.node->name());
122   string control = TensorIdToString({to_node_name, Graph::kControlSlot});
123   for (const auto& control_fanout : control_fanouts->second) {
124     const int start = graph.NumFanins(*control_fanout.node,
125                                       /*include_controlling_nodes=*/false);
126     for (int i = start; i < control_fanout.node->input_size(); ++i) {
127       TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
128       if (tensor_id.node() == from_node_name) {
129         control_fanout.node->set_input(i, control);
130         break;
131       }
132     }
133   }
134 }
135 
SwapRegularFanoutInputs(FanoutsMap * fanouts,NodeDef * from_node,absl::string_view to_node_name,int max_port)136 void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
137                              absl::string_view to_node_name, int max_port) {
138   MutableGraphView::OutputPort port;
139   port.node = from_node;
140   for (int i = 0; i <= max_port; ++i) {
141     port.port_id = i;
142     auto it = fanouts->find(port);
143     if (it == fanouts->end()) {
144       continue;
145     }
146     string input = TensorIdToString({to_node_name, i});
147     for (const auto& fanout : it->second) {
148       fanout.node->set_input(fanout.port_id, input);
149     }
150   }
151 }
152 
153 using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
154 
SwapFanoutInputs(const MutableGraphView & graph,FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)155 void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
156                       MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
157                       NodeDef* to_node) {
158   auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
159   if (from_control_fanouts != fanouts->end()) {
160     SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
161   }
162   auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
163   if (to_control_fanouts != fanouts->end()) {
164     SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
165   }
166   auto from_max_port = max_output_ports->find(from_node);
167   if (from_max_port != max_output_ports->end()) {
168     SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
169                             from_max_port->second);
170   }
171   auto to_max_port = max_output_ports->find(to_node);
172   if (to_max_port != max_output_ports->end()) {
173     SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
174                             to_max_port->second);
175   }
176 }
177 
SwapFanoutsMapValues(FanoutsMap * fanouts,const MutableGraphView::OutputPort & from_port,const FanoutsMap::iterator & from_fanouts,const MutableGraphView::OutputPort & to_port,const FanoutsMap::iterator & to_fanouts)178 void SwapFanoutsMapValues(FanoutsMap* fanouts,
179                           const MutableGraphView::OutputPort& from_port,
180                           const FanoutsMap::iterator& from_fanouts,
181                           const MutableGraphView::OutputPort& to_port,
182                           const FanoutsMap::iterator& to_fanouts) {
183   const bool from_exists = from_fanouts != fanouts->end();
184   const bool to_exists = to_fanouts != fanouts->end();
185 
186   if (from_exists && to_exists) {
187     std::swap(from_fanouts->second, to_fanouts->second);
188   } else if (from_exists) {
189     fanouts->emplace(to_port, std::move(from_fanouts->second));
190     fanouts->erase(from_port);
191   } else if (to_exists) {
192     fanouts->emplace(from_port, std::move(to_fanouts->second));
193     fanouts->erase(to_port);
194   }
195 }
196 
SwapRegularFanoutsAndMaxPortValues(FanoutsMap * fanouts,MaxOutputPortsMap * max_output_ports,NodeDef * from_node,NodeDef * to_node)197 void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
198                                         MaxOutputPortsMap* max_output_ports,
199                                         NodeDef* from_node, NodeDef* to_node) {
200   auto from_max_port = max_output_ports->find(from_node);
201   auto to_max_port = max_output_ports->find(to_node);
202   bool from_exists = from_max_port != max_output_ports->end();
203   bool to_exists = to_max_port != max_output_ports->end();
204 
205   auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
206                                    int end) {
207     for (int i = start; i <= end; ++i) {
208       MutableGraphView::OutputPort from_port(from, i);
209       auto from_fanouts = fanouts->find(from_port);
210       if (from_fanouts != fanouts->end()) {
211         MutableGraphView::OutputPort to_port(to, i);
212         fanouts->emplace(to_port, std::move(from_fanouts->second));
213         fanouts->erase(from_port);
214       }
215     }
216   };
217 
218   if (from_exists && to_exists) {
219     const int from = from_max_port->second;
220     const int to = to_max_port->second;
221     const int shared = std::min(from, to);
222     for (int i = 0; i <= shared; ++i) {
223       MutableGraphView::OutputPort from_port(from_node, i);
224       auto from_fanouts = fanouts->find(from_port);
225       MutableGraphView::OutputPort to_port(to_node, i);
226       auto to_fanouts = fanouts->find(to_port);
227       SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
228                            to_fanouts);
229     }
230     if (to > from) {
231       forward_fanouts(to_node, from_node, shared + 1, to);
232     } else if (from > to) {
233       forward_fanouts(from_node, to_node, shared + 1, from);
234     }
235 
236     std::swap(from_max_port->second, to_max_port->second);
237   } else if (from_exists) {
238     forward_fanouts(from_node, to_node, 0, from_max_port->second);
239 
240     max_output_ports->emplace(to_node, from_max_port->second);
241     max_output_ports->erase(from_node);
242   } else if (to_exists) {
243     forward_fanouts(to_node, from_node, 0, to_max_port->second);
244 
245     max_output_ports->emplace(from_node, to_max_port->second);
246     max_output_ports->erase(to_node);
247   }
248 }
249 
HasFanoutValue(const FanoutsMap & fanouts,const FanoutsMap::iterator & it)250 bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
251   return it != fanouts.end() && !it->second.empty();
252 }
253 
MutationError(absl::string_view function_name,absl::string_view params,absl::string_view msg)254 Status MutationError(absl::string_view function_name, absl::string_view params,
255                      absl::string_view msg) {
256   return errors::InvalidArgument(absl::Substitute(
257       "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
258 }
259 
260 using ErrorHandler = std::function<Status(absl::string_view)>;
261 
UpdateFanoutsError(absl::string_view from_node_name,absl::string_view to_node_name)262 ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
263                                 absl::string_view to_node_name) {
264   return [from_node_name, to_node_name](absl::string_view msg) {
265     string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
266                                      from_node_name, to_node_name);
267     return MutationError("UpdateFanouts", params, msg);
268   };
269 }
270 
CheckFaninIsRegular(const TensorId & fanin,ErrorHandler handler)271 Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
272   if (!IsTensorIdRegular(fanin)) {
273     return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
274                                     fanin.ToString()));
275   }
276   return OkStatus();
277 }
278 
CheckFaninIsValid(const TensorId & fanin,ErrorHandler handler)279 Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
280   if (!IsTensorIdPortValid(fanin)) {
281     return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
282                                     fanin.ToString()));
283   }
284   return OkStatus();
285 }
286 
CheckAddingFaninToSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)287 Status CheckAddingFaninToSelf(absl::string_view node_name,
288                               const TensorId& fanin, ErrorHandler handler) {
289   if (node_name == fanin.node()) {
290     return handler(
291         absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
292   }
293   return OkStatus();
294 }
295 
CheckRemovingFaninFromSelf(absl::string_view node_name,const TensorId & fanin,ErrorHandler handler)296 Status CheckRemovingFaninFromSelf(absl::string_view node_name,
297                                   const TensorId& fanin, ErrorHandler handler) {
298   if (node_name == fanin.node()) {
299     return handler(absl::Substitute("can't remove fanin '$0' from self",
300                                     fanin.ToString()));
301   }
302   return OkStatus();
303 }
304 
NodeMissingErrorMsg(absl::string_view node_name)305 string NodeMissingErrorMsg(absl::string_view node_name) {
306   return absl::Substitute("node '$0' was not found", node_name);
307 }
308 
CheckNodeExists(absl::string_view node_name,NodeDef * node,ErrorHandler handler)309 Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
310                        ErrorHandler handler) {
311   if (node == nullptr) {
312     return handler(NodeMissingErrorMsg(node_name));
313   }
314   return OkStatus();
315 }
316 
CheckPortRange(int port,int min,int max,ErrorHandler handler)317 Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
318   if (port < min || port > max) {
319     if (max < min) {
320       return handler("no available ports as node has no regular fanins");
321     }
322     return handler(
323         absl::Substitute("port must be in range [$0, $1]", min, max));
324   }
325   return OkStatus();
326 }
327 
SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name)328 string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
329   return absl::Substitute(
330       "can't swap node name '$0' as it will become a Switch control dependency",
331       node_name);
332 }
333 
GeneratedNameForIdentityConsumingSwitch(const MutableGraphView::OutputPort & fanin)334 string GeneratedNameForIdentityConsumingSwitch(
335     const MutableGraphView::OutputPort& fanin) {
336   return AddPrefixToNodeName(
337       absl::StrCat(fanin.node->name(), "_", fanin.port_id),
338       kMutableGraphViewCtrl);
339 }
340 
PrintInTextFormat(const protobuf::MessageLite & message)341 string PrintInTextFormat(const protobuf::MessageLite& message) {
342   // Unfortunately proto2::TextFormat::Printer::PrintToString does not have
343   // a overload for MessageLite so here we have to use
344   // MessageLite::ShortDebugString.
345   return message.ShortDebugString();
346 }
347 
PrintInTextFormat(const protobuf::Message & message)348 string PrintInTextFormat(const protobuf::Message& message) {
349   string message_text;
350   ::tensorflow::protobuf::TextFormat::Printer printer;
351   printer.SetSingleLineMode(true);
352   printer.PrintToString(message, &message_text);
353   if (!message_text.empty() && message_text[message_text.size() - 1] == ' ') {
354     message_text.resize(message_text.size() - 1);
355   }
356   return message_text;
357 }
358 
359 }  // namespace
360 
// Records the fanout edges created by the inputs of `node` and removes
// redundant control inputs from `node`.
//
// Inputs are scanned in order:
//  - A control input `^x` is dropped when `x` already appeared earlier among
//    the inputs and either a control dependency on `x` may be deduped against
//    a regular input (CanDedupControlWithRegularInput) or `^x` itself already
//    appeared. Dropped inputs are swapped to the tail and deleted together
//    after the loop, so the order of the remaining trailing inputs may change.
//  - Any other input is added to fanouts(). For a regular input, the
//    producer's max regular output port is raised as needed.
// Finally, if `node` has any regular input, its highest regular input
// position is stored in max_regular_input_port().
void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
  // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
  // exist, and all regular fanins come before controlling fanins.
  absl::flat_hash_set<absl::string_view> fanins;
  absl::flat_hash_set<absl::string_view> controlling_fanins;
  int max_input_port = -1;
  int pos = 0;
  const int last_idx = node->input_size() - 1;
  // Inputs at positions greater than last_pos are marked for deletion.
  int last_pos = last_idx;
  while (pos <= last_pos) {
    TensorId tensor_id = ParseTensorName(node->input(pos));
    absl::string_view input_node_name = tensor_id.node();
    bool is_control_input = IsTensorIdControlling(tensor_id);
    bool can_dedup_control_with_regular_input =
        CanDedupControlWithRegularInput(*this, input_node_name);
    bool can_dedup_control =
        is_control_input && (can_dedup_control_with_regular_input ||
                             controlling_fanins.contains(input_node_name));
    if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
        can_dedup_control) {
      // Redundant control input: move it past last_pos. `pos` is not advanced
      // so the element swapped into `pos` is examined next.
      node->mutable_input()->SwapElements(pos, last_pos);
      --last_pos;
    } else {
      // NOTE(review): if nodes() is a map keyed by name, operator[] inserts a
      // null entry for a fanin node not present in the graph (see TODO above).
      OutputPort output(nodes()[input_node_name], tensor_id.index());

      if (is_control_input) {
        fanouts()[output].emplace(node, Graph::kControlSlot);
      } else {
        max_input_port = pos;
        max_regular_output_port()[output.node] =
            std::max(max_regular_output_port()[output.node], output.port_id);
        fanouts()[output].emplace(node, pos);
      }
      ++pos;
    }
    if (is_control_input) {
      controlling_fanins.insert(input_node_name);
    }
  }

  // Erase the inputs that were swapped past last_pos.
  if (last_pos < last_idx) {
    node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
  }

  if (max_input_port > -1) {
    max_regular_input_port()[node] = max_input_port;
  }
}
409 
UpdateMaxRegularOutputPortForRemovedFanin(const OutputPort & fanin,const absl::flat_hash_set<InputPort> & fanin_fanouts)410 void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
411     const OutputPort& fanin,
412     const absl::flat_hash_set<InputPort>& fanin_fanouts) {
413   int max_port = max_regular_output_port()[fanin.node];
414   if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
415     return;
416   }
417   bool updated_max_port = false;
418   for (int i = fanin.port_id - 1; i >= 0; --i) {
419     OutputPort fanin_port(fanin.node, i);
420     if (!fanouts()[fanin_port].empty()) {
421       max_regular_output_port()[fanin.node] = i;
422       updated_max_port = true;
423       break;
424     }
425   }
426   if (!updated_max_port) {
427     max_regular_output_port().erase(fanin.node);
428   }
429 }
430 
UpdateMaxRegularOutputPortForAddedFanin(const OutputPort & fanin)431 void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
432     const OutputPort& fanin) {
433   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
434     max_regular_output_port()[fanin.node] = fanin.port_id;
435   }
436 }
437 
438 const absl::flat_hash_set<MutableGraphView::InputPort>&
GetFanout(const GraphView::OutputPort & port) const439 MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
440   return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
441                                                 port.port_id));
442 }
443 
GetFanin(const GraphView::InputPort & port) const444 absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
445     const GraphView::InputPort& port) const {
446   return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
447                                               port.port_id));
448 }
449 
GetRegularFanin(const GraphView::InputPort & port) const450 const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
451     const GraphView::InputPort& port) const {
452   return GetRegularFanin(MutableGraphView::InputPort(
453       const_cast<NodeDef*>(port.node), port.port_id));
454 }
455 
AddNode(NodeDef && node)456 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
457   auto* node_in_graph = graph()->add_node();
458   *node_in_graph = std::move(node);
459 
460   AddUniqueNodeOrDie(node_in_graph);
461 
462   AddAndDedupFanouts(node_in_graph);
463   return node_in_graph;
464 }
465 
AddSubgraph(GraphDef && subgraph)466 Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
467   // 1. Add all new functions and check that functions with the same name
468   // have identical definition.
469   const int function_size = subgraph.library().function_size();
470   if (function_size > 0) {
471     absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
472     for (const FunctionDef& fdef : graph()->library().function()) {
473       graph_fdefs.emplace(fdef.signature().name(), &fdef);
474     }
475 
476     for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
477       const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
478 
479       if (graph_fdef == graph_fdefs.end()) {
480         VLOG(3) << "Add new function definition: " << fdef.signature().name();
481         graph()->mutable_library()->add_function()->Swap(&fdef);
482       } else {
483         if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
484           return MutationError(
485               "AddSubgraph",
486               absl::Substitute("function_size=$0", function_size),
487               absl::StrCat(
488                   "Found different function definition with the same name: ",
489                   fdef.signature().name()));
490         }
491       }
492     }
493   }
494 
495   // 2. Add all nodes to the underlying graph.
496   int node_size_before = graph()->node_size();
497 
498   for (NodeDef& node : *subgraph.mutable_node()) {
499     auto* node_in_graph = graph()->add_node();
500     node_in_graph->Swap(&node);
501     TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
502   }
503 
504   // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
505   // fanins actually exists in the graph, and there is already TODO for that.
506 
507   for (int i = node_size_before; i < graph()->node_size(); ++i) {
508     NodeDef* node = graph()->mutable_node(i);
509     AddAndDedupFanouts(node);
510   }
511 
512   return OkStatus();
513 }
514 
// Updates the op, device and attributes of the node named `node_name`.
//
// The device is overwritten only when it differs, and the attributes are
// replaced wholesale by `attrs`. Changing the op to "Switch" is rejected when
// the node currently has control-dependency consumers. If the op actually
// changes and CanDedupControlWithRegularInput allows it for the updated node,
// control dependencies on this node are removed from consumers that already
// read it through a regular input.
//
// Returns an InvalidArgument status (via MutationError) if the node does not
// exist or the Switch restriction is violated.
Status MutableGraphView::UpdateNode(
    absl::string_view node_name, absl::string_view op, absl::string_view device,
    absl::Span<const std::pair<string, AttrValue>> attrs) {
  // Formats all arguments (attrs rendered in text format) into the error.
  auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
    std::vector<string> attr_strs;
    attr_strs.reserve(attrs.size());
    for (const auto& attr : attrs) {
      string attr_str = absl::Substitute("('$0', $1)", attr.first,
                                         PrintInTextFormat(attr.second));
      attr_strs.push_back(attr_str);
    }
    string params =
        absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
                         node_name, op, device, absl::StrJoin(attr_strs, ", "));
    // NOTE(review): the reported function name is "UpdateNodeOp" although this
    // method is UpdateNode; left unchanged since callers/tests may match it.
    return MutationError("UpdateNodeOp", params, msg);
  };

  NodeDef* node = GetNode(node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));

  MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
  // Copied by value rather than bound by reference: the set is iterated below
  // while RemoveControllingFaninInternal (presumably) mutates fanouts().
  auto control_fanouts = GetFanout(control_port);
  if (op == "Switch" && !control_fanouts.empty()) {
    return error_status(
        "can't change node op to Switch when node drives a control dependency "
        "(alternatively, we could add the identity node needed, but it seems "
        "like an unlikely event and probably a mistake)");
  }

  if (node->device() != device) {
    node->set_device(string(device));
  }
  node->mutable_attr()->clear();
  for (const auto& attr : attrs) {
    (*node->mutable_attr())[attr.first] = attr.second;
  }

  // Device and attrs are already applied; nothing else to do if the op is
  // unchanged.
  if (node->op() == op) {
    return OkStatus();
  }

  node->set_op(string(op));

  // With the new op, control edges from this node may now be redundant for
  // consumers that also read it through a regular input.
  if (CanDedupControlWithRegularInput(*this, *node)) {
    for (const auto& control_fanout : control_fanouts) {
      if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
        RemoveControllingFaninInternal(control_fanout.node, node);
      }
    }
  }

  return OkStatus();
}
568 
UpdateNodeName(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)569 Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
570                                         absl::string_view to_node_name,
571                                         bool update_fanouts) {
572   auto error_status = [from_node_name, to_node_name,
573                        update_fanouts](absl::string_view msg) {
574     string params = absl::Substitute(
575         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
576         from_node_name, to_node_name, update_fanouts);
577     return MutationError("UpdateNodeName", params, msg);
578   };
579 
580   NodeDef* node = GetNode(from_node_name);
581   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
582 
583   if (node->name() == to_node_name) {
584     return OkStatus();
585   }
586   if (HasNode(to_node_name)) {
587     return error_status(
588         "can't update node name because new node name is in use");
589   }
590   auto max_output_port = max_regular_output_port().find(node);
591   const bool has_max_output_port =
592       max_output_port != max_regular_output_port().end();
593   auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
594 
595   if (update_fanouts) {
596     SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
597     if (has_max_output_port) {
598       SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
599                               max_output_port->second);
600     }
601   } else if (has_max_output_port ||
602              HasFanoutValue(fanouts(), control_fanouts)) {
603     return error_status("can't update node name because node has fanouts");
604   }
605 
606   nodes().erase(node->name());
607   node->set_name(string(to_node_name));
608   nodes().emplace(node->name(), node);
609   return OkStatus();
610 }
611 
SwapNodeNames(absl::string_view from_node_name,absl::string_view to_node_name,bool update_fanouts)612 Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
613                                        absl::string_view to_node_name,
614                                        bool update_fanouts) {
615   auto error_status = [from_node_name, to_node_name,
616                        update_fanouts](absl::string_view msg) {
617     string params = absl::Substitute(
618         "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
619         from_node_name, to_node_name, update_fanouts);
620     return MutationError("SwapNodeNames", params, msg);
621   };
622 
623   NodeDef* from_node = GetNode(from_node_name);
624   TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
625   if (from_node_name == to_node_name) {
626     return OkStatus();
627   }
628   NodeDef* to_node = GetNode(to_node_name);
629   TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));
630 
631   auto swap_names = [this, from_node, to_node]() {
632     nodes().erase(from_node->name());
633     nodes().erase(to_node->name());
634     std::swap(*from_node->mutable_name(), *to_node->mutable_name());
635     nodes().emplace(from_node->name(), from_node);
636     nodes().emplace(to_node->name(), to_node);
637   };
638 
639   if (update_fanouts) {
640     SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
641                      to_node);
642     swap_names();
643     return OkStatus();
644   }
645 
646   bool from_is_switch = IsSwitch(*from_node);
647   MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
648   auto to_control_fanouts = fanouts().find(to_control);
649   if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
650     return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
651   }
652 
653   bool to_is_switch = IsSwitch(*to_node);
654   MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
655   auto from_control_fanouts = fanouts().find(from_control);
656   if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
657     return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
658   }
659 
660   // Swap node names.
661   swap_names();
662 
663   // Swap controlling fanouts.
664   //
665   // Note: To and from control fanout iterators are still valid as no mutations
666   // has been performed on fanouts().
667   SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
668                        to_control, to_control_fanouts);
669 
670   // Swap regular fanouts.
671   SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
672                                      from_node, to_node);
673 
674   // Update fanins to remove self loops.
675   auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
676     for (int i = 0; i < node->input_size(); ++i) {
677       TensorId tensor_id = ParseTensorName(node->input(i));
678       if (tensor_id.node() == node->name()) {
679         const int idx = tensor_id.index();
680         const int node_idx =
681             IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;
682 
683         MutableGraphView::OutputPort from_fanin(node, idx);
684         absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
685         from_fanouts->erase({node, node_idx});
686         UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);
687 
688         MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
689         fanouts()[to_fanin].insert({node, node_idx});
690         UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
691         node->set_input(i, TensorIdToString({old_node_name, idx}));
692       }
693     }
694   };
695   update_fanins(from_node, to_node->name());
696   update_fanins(to_node, from_node->name());
697 
698   // Dedup control dependencies.
699   auto dedup_control_fanouts =
700       [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
701         if (CanDedupControlWithRegularInput(*this, *node) &&
702             control_fanouts != fanouts().end()) {
703           for (auto it = control_fanouts->second.begin();
704                it != control_fanouts->second.end();) {
705             // Advance `it` before invalidation from removal.
706             const auto& control_fanout = *it++;
707             if (HasRegularFaninNode(*this, *control_fanout.node,
708                                     node->name())) {
709               RemoveControllingFaninInternal(control_fanout.node, node);
710             }
711           }
712         }
713       };
714   auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
715     OutputPort port;
716     port.node = node;
717     const int max_port =
718         gtl::FindWithDefault(max_regular_output_port(), node, -1);
719     for (int i = 0; i <= max_port; ++i) {
720       port.port_id = i;
721       auto it = fanouts().find(port);
722       if (it == fanouts().end()) {
723         continue;
724       }
725       for (const auto& fanout : it->second) {
726         auto fanout_controls =
727             fanouts().find({fanout.node, Graph::kControlSlot});
728         dedup_control_fanouts(fanout.node, fanout_controls);
729       }
730     }
731   };
732 
733   if (!from_is_switch) {
734     if (to_is_switch) {
735       dedup_switch_control(from_node);
736     } else {
737       // Fetch iterator again as the original iterator might have been
738       // invalidated by container rehash triggered due to mutations.
739       auto from_control_fanouts = fanouts().find(from_control);
740       dedup_control_fanouts(from_node, from_control_fanouts);
741     }
742   }
743   if (!to_is_switch) {
744     if (from_is_switch) {
745       dedup_switch_control(to_node);
746     } else {
747       // Fetch iterator again as the original iterator might have been
748       // invalidated by container rehash triggered due to mutations.
749       auto to_control_fanouts = fanouts().find(to_control);
750       dedup_control_fanouts(to_node, to_control_fanouts);
751     }
752   }
753 
754   return OkStatus();
755 }
756 
UpdateFanouts(absl::string_view from_node_name,absl::string_view to_node_name)757 Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
758                                        absl::string_view to_node_name) {
759   NodeDef* from_node = GetNode(from_node_name);
760   TF_RETURN_IF_ERROR(
761       CheckNodeExists(from_node_name, from_node,
762                       UpdateFanoutsError(from_node_name, to_node_name)));
763   NodeDef* to_node = GetNode(to_node_name);
764   TF_RETURN_IF_ERROR(CheckNodeExists(
765       to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
766 
767   return UpdateFanoutsInternal(from_node, to_node);
768 }
769 
UpdateFanoutsInternal(NodeDef * from_node,NodeDef * to_node)770 Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
771                                                NodeDef* to_node) {
772   VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
773                               from_node->name(), to_node->name());
774   if (from_node == to_node) {
775     return OkStatus();
776   }
777 
778   // Update internal state with the new output_port->input_port edge.
779   const auto add_edge = [this](const OutputPort& output_port,
780                                const InputPort& input_port) {
781     fanouts()[output_port].insert(input_port);
782   };
783 
784   // Remove invalidated edge from the internal state.
785   const auto remove_edge = [this](const OutputPort& output_port,
786                                   const InputPort& input_port) {
787     fanouts()[output_port].erase(input_port);
788   };
789 
790   // For the control fanouts we do not know the input index in a NodeDef,
791   // so we have to traverse all control inputs.
792 
793   auto control_fanouts =
794       GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
795 
796   bool to_node_is_switch = IsSwitch(*to_node);
797   for (const InputPort& control_port : control_fanouts) {
798     // Node can't be control dependency of itself.
799     if (control_port.node == to_node) continue;
800 
801     // Can't add Switch node as a control dependency.
802     if (to_node_is_switch) {
803       // Trying to add a Switch as a control dependency, which if allowed will
804       // make the graph invalid.
805       return UpdateFanoutsError(from_node->name(), to_node->name())(
806           absl::Substitute("can't update fanouts to node '$0' as it will "
807                            "become a Switch control dependency",
808                            to_node->name()));
809     }
810 
811     NodeDef* node = control_port.node;
812     RemoveControllingFaninInternal(node, from_node);
813     AddFaninInternal(node, {to_node, Graph::kControlSlot});
814   }
815 
816   // First we update regular fanouts. For the regular fanouts
817   // `input_port:port_id` is the input index in NodeDef.
818 
819   auto regular_edges =
820       GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
821 
822   // Maximum index of the `from_node` output tensor that is still used as an
823   // input to some other node.
824   int keep_max_regular_output_port = -1;
825 
826   for (const Edge& edge : regular_edges) {
827     const OutputPort output_port = edge.src;
828     const InputPort input_port = edge.dst;
829 
830     // If the `to_node` reads from the `from_node`, skip this edge (see
831     // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
832     if (input_port.node == to_node) {
833       keep_max_regular_output_port =
834           std::max(keep_max_regular_output_port, output_port.port_id);
835       continue;
836     }
837 
838     // Update input at destination node.
839     input_port.node->set_input(
840         input_port.port_id,
841         TensorIdToString({to_node->name(), output_port.port_id}));
842 
843     // Remove old edge between the `from_node` and the fanout node.
844     remove_edge(output_port, input_port);
845     // Add an edge between the `to_node` and new fanout node.
846     add_edge(OutputPort(to_node, output_port.port_id), input_port);
847     // Dedup control dependency.
848     if (CanDedupControlWithRegularInput(*this, *to_node)) {
849       RemoveControllingFaninInternal(input_port.node, to_node);
850     }
851   }
852 
853   // Because we update all regular fanouts of `from_node`, we can just copy
854   // the value `num_regular_outputs`.
855   max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
856 
857   // Check if all fanouts were updated to read from the `to_node`.
858   if (keep_max_regular_output_port >= 0) {
859     max_regular_output_port()[from_node] = keep_max_regular_output_port;
860   } else {
861     max_regular_output_port().erase(from_node);
862   }
863 
864   return OkStatus();
865 }
866 
AddFaninInternal(NodeDef * node,const OutputPort & fanin)867 bool MutableGraphView::AddFaninInternal(NodeDef* node,
868                                         const OutputPort& fanin) {
869   int num_regular_fanins =
870       NumFanins(*node, /*include_controlling_nodes=*/false);
871   bool input_is_control = IsOutputPortControlling(fanin);
872   bool can_dedup_control_with_regular_input =
873       CanDedupControlWithRegularInput(*this, *fanin.node);
874   // Don't add duplicate control dependencies.
875   if (input_is_control) {
876     const int start =
877         can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
878     for (int i = start; i < node->input_size(); ++i) {
879       if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
880         return false;
881       }
882     }
883   }
884 
885   InputPort input;
886   input.node = node;
887   input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
888 
889   node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
890   if (!input_is_control) {
891     const int last_node_input = node->input_size() - 1;
892     // If there are control dependencies in node, move newly inserted fanin to
893     // be before such control dependencies.
894     if (num_regular_fanins < last_node_input) {
895       node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
896     }
897   }
898 
899   fanouts()[fanin].insert(input);
900   if (max_regular_output_port()[fanin.node] < fanin.port_id) {
901     max_regular_output_port()[fanin.node] = fanin.port_id;
902   }
903 
904   // Update max input port and dedup control dependencies.
905   if (!input_is_control) {
906     max_regular_input_port()[node] = num_regular_fanins;
907     if (can_dedup_control_with_regular_input) {
908       RemoveControllingFaninInternal(node, fanin.node);
909     }
910   }
911 
912   return true;
913 }
914 
AddRegularFanin(absl::string_view node_name,const TensorId & fanin)915 Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
916                                          const TensorId& fanin) {
917   auto error_status = [node_name, fanin](absl::string_view msg) {
918     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
919                                      fanin.ToString());
920     return MutationError("AddRegularFanin", params, msg);
921   };
922 
923   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
924   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
925   NodeDef* node = GetNode(node_name);
926   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
927   NodeDef* fanin_node = GetNode(fanin.node());
928   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
929 
930   AddFaninInternal(node, {fanin_node, fanin.index()});
931   return OkStatus();
932 }
933 
AddRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)934 Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
935                                                int port,
936                                                const TensorId& fanin) {
937   auto error_status = [node_name, port, fanin](absl::string_view msg) {
938     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
939                                      node_name, port, fanin.ToString());
940     return MutationError("AddRegularFaninByPort", params, msg);
941   };
942 
943   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
944   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
945   NodeDef* node = GetNode(node_name);
946   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
947   const int num_regular_fanins =
948       NumFanins(*node, /*include_controlling_nodes=*/false);
949   TF_RETURN_IF_ERROR(
950       CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
951   NodeDef* fanin_node = GetNode(fanin.node());
952   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
953 
954   const int last_node_input = node->input_size();
955   node->add_input(TensorIdToString(fanin));
956   node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
957   for (int i = num_regular_fanins - 1; i >= port; --i) {
958     TensorId tensor_id = ParseTensorName(node->input(i));
959     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
960     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
961     fanouts_set->erase({node, i});
962     fanouts_set->insert({node, i + 1});
963     node->mutable_input()->SwapElements(i, i + 1);
964   }
965 
966   OutputPort fanin_port(fanin_node, fanin.index());
967   fanouts()[fanin_port].insert({node, port});
968   UpdateMaxRegularOutputPortForAddedFanin(fanin_port);
969 
970   max_regular_input_port()[node] = num_regular_fanins;
971   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
972     RemoveControllingFaninInternal(node, fanin_node);
973   }
974 
975   return OkStatus();
976 }
977 
GetControllingFaninToAdd(absl::string_view node_name,const OutputPort & fanin,string * error_msg)978 NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
979                                                     const OutputPort& fanin,
980                                                     string* error_msg) {
981   if (!IsSwitch(*fanin.node)) {
982     return fanin.node;
983   } else {
984     if (IsOutputPortControlling(fanin)) {
985       // Can't add a Switch node control dependency.
986       TensorId tensor_id(fanin.node->name(), fanin.port_id);
987       *error_msg = absl::Substitute(
988           "can't add fanin '$0' as it will become a Switch control dependency",
989           tensor_id.ToString());
990       return nullptr;
991     }
992     // We can't anchor control dependencies directly on the switch node: unlike
993     // other nodes only one of the outputs of the switch node will be generated
994     // when the switch node is executed, and we need to make sure the control
995     // dependency is only triggered when the corresponding output is triggered.
996     // We start by looking for an identity node connected to the output of the
997     // switch node, and use it to anchor the control dependency.
998     for (const auto& fanout : GetFanout(fanin)) {
999       if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
1000         if (fanout.node->name() == node_name) {
1001           *error_msg =
1002               absl::Substitute("can't add found fanin '$0' to self",
1003                                AsControlDependency(fanout.node->name()));
1004           return nullptr;
1005         }
1006         return fanout.node;
1007       }
1008     }
1009 
1010     // No node found, check if node to be created is itself.
1011     if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
1012       *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
1013                                     AsControlDependency(string(node_name)));
1014     }
1015   }
1016   return nullptr;
1017 }
1018 
GetOrCreateIdentityConsumingSwitch(const OutputPort & fanin)1019 NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
1020     const OutputPort& fanin) {
1021   // We haven't found an existing node where we can anchor the control
1022   // dependency: add a new identity node.
1023   string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
1024   NodeDef* identity_node = GetNode(identity_name);
1025   if (identity_node == nullptr) {
1026     NodeDef new_node;
1027     new_node.set_name(identity_name);
1028     new_node.set_op("Identity");
1029     new_node.set_device(fanin.node->device());
1030     (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1031     new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1032     identity_node = AddNode(std::move(new_node));
1033   }
1034   return identity_node;
1035 }
1036 
AddControllingFanin(absl::string_view node_name,const TensorId & fanin)1037 Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1038                                              const TensorId& fanin) {
1039   auto error_status = [node_name, fanin](absl::string_view msg) {
1040     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1041                                      fanin.ToString());
1042     return MutationError("AddControllingFanin", params, msg);
1043   };
1044 
1045   TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1046   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1047   NodeDef* node = GetNode(node_name);
1048   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1049   NodeDef* fanin_node = GetNode(fanin.node());
1050   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1051 
1052   OutputPort fanin_port(fanin_node, fanin.index());
1053 
1054   string error_msg = "";
1055   NodeDef* control_node = GetControllingFaninToAdd(
1056       node_name, {fanin_node, fanin.index()}, &error_msg);
1057   if (!error_msg.empty()) {
1058     return error_status(error_msg);
1059   }
1060   if (control_node == nullptr) {
1061     control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1062   }
1063   AddFaninInternal(node, {control_node, Graph::kControlSlot});
1064 
1065   return OkStatus();
1066 }
1067 
RemoveRegularFaninInternal(NodeDef * node,const OutputPort & fanin)1068 bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
1069                                                   const OutputPort& fanin) {
1070   auto remove_input = [this, node](const OutputPort& fanin_port,
1071                                    int node_input_port, bool update_max_port) {
1072     InputPort input(node, node_input_port);
1073 
1074     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1075     fanouts_set->erase(input);
1076     if (update_max_port) {
1077       UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
1078     }
1079     return fanouts_set;
1080   };
1081 
1082   auto mutable_inputs = node->mutable_input();
1083   bool modified = false;
1084   const int num_regular_fanins =
1085       NumFanins(*node, /*include_controlling_nodes=*/false);
1086   int i;
1087   int curr_pos = 0;
1088   for (i = 0; i < num_regular_fanins; ++i) {
1089     TensorId tensor_id = ParseTensorName(node->input(i));
1090     if (tensor_id.node() == fanin.node->name() &&
1091         tensor_id.index() == fanin.port_id) {
1092       remove_input(fanin, i, /*update_max_port=*/true);
1093       modified = true;
1094     } else if (modified) {
1095       // Regular inputs will need to have their ports updated.
1096       OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1097       auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
1098       fanouts_set->insert({node, curr_pos});
1099       // Shift inputs to be retained.
1100       mutable_inputs->SwapElements(i, curr_pos);
1101       ++curr_pos;
1102     } else {
1103       // Skip inputs to be retained until first modification.
1104       ++curr_pos;
1105     }
1106   }
1107 
1108   if (modified) {
1109     const int last_regular_input_port = curr_pos - 1;
1110     if (last_regular_input_port < 0) {
1111       max_regular_input_port().erase(node);
1112     } else {
1113       max_regular_input_port()[node] = last_regular_input_port;
1114     }
1115     if (curr_pos < i) {
1116       // Remove fanins from node inputs.
1117       mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
1118     }
1119   }
1120 
1121   return modified;
1122 }
1123 
RemoveRegularFanin(absl::string_view node_name,const TensorId & fanin)1124 Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1125                                             const TensorId& fanin) {
1126   auto error_status = [node_name, fanin](absl::string_view msg) {
1127     string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1128                                      fanin.ToString());
1129     return MutationError("RemoveRegularFanin", params, msg);
1130   };
1131 
1132   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1133   TF_RETURN_IF_ERROR(
1134       CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1135   NodeDef* node = GetNode(node_name);
1136   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1137   NodeDef* fanin_node = GetNode(fanin.node());
1138   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1139 
1140   RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1141   return OkStatus();
1142 }
1143 
RemoveRegularFaninByPort(absl::string_view node_name,int port)1144 Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1145                                                   int port) {
1146   auto error_status = [node_name, port](absl::string_view msg) {
1147     string params =
1148         absl::Substitute("node_name='$0', port=$1", node_name, port);
1149     return MutationError("RemoveRegularFaninByPort", params, msg);
1150   };
1151 
1152   NodeDef* node = GetNode(node_name);
1153   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1154   const int last_regular_fanin_port =
1155       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1156   TF_RETURN_IF_ERROR(
1157       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1158 
1159   TensorId tensor_id = ParseTensorName(node->input(port));
1160   OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1161   fanouts()[fanin_port].erase({node, port});
1162   auto mutable_inputs = node->mutable_input();
1163   for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1164     TensorId tensor_id = ParseTensorName(node->input(i));
1165     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1166     absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1167     fanouts_set->erase({node, i});
1168     fanouts_set->insert({node, i - 1});
1169     mutable_inputs->SwapElements(i - 1, i);
1170   }
1171   const int last_node_input = node->input_size() - 1;
1172   if (last_regular_fanin_port < last_node_input) {
1173     mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1174   }
1175   mutable_inputs->RemoveLast();
1176 
1177   const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1178   if (updated_last_regular_input_port < 0) {
1179     max_regular_input_port().erase(node);
1180   } else {
1181     max_regular_input_port()[node] = updated_last_regular_input_port;
1182   }
1183 
1184   return OkStatus();
1185 }
1186 
RemoveControllingFaninInternal(NodeDef * node,NodeDef * fanin_node)1187 bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1188                                                       NodeDef* fanin_node) {
1189   for (int i = node->input_size() - 1; i >= 0; --i) {
1190     TensorId tensor_id = ParseTensorName(node->input(i));
1191     if (tensor_id.index() > Graph::kControlSlot) {
1192       break;
1193     }
1194     if (tensor_id.node() == fanin_node->name()) {
1195       fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1196           {node, Graph::kControlSlot});
1197       node->mutable_input()->SwapElements(i, node->input_size() - 1);
1198       node->mutable_input()->RemoveLast();
1199       return true;
1200     }
1201   }
1202   return false;
1203 }
1204 
RemoveControllingFanin(absl::string_view node_name,absl::string_view fanin_node_name)1205 Status MutableGraphView::RemoveControllingFanin(
1206     absl::string_view node_name, absl::string_view fanin_node_name) {
1207   auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1208     string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1209                                      node_name, fanin_node_name);
1210     return MutationError("RemoveControllingFanin", params, msg);
1211   };
1212 
1213   TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1214       node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1215   NodeDef* node = GetNode(node_name);
1216   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1217   NodeDef* fanin_node = GetNode(fanin_node_name);
1218   TF_RETURN_IF_ERROR(
1219       CheckNodeExists(fanin_node_name, fanin_node, error_status));
1220 
1221   RemoveControllingFaninInternal(node, fanin_node);
1222   return OkStatus();
1223 }
1224 
RemoveAllFanins(absl::string_view node_name,bool keep_controlling_fanins)1225 Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1226                                          bool keep_controlling_fanins) {
1227   NodeDef* node = GetNode(node_name);
1228   if (node == nullptr) {
1229     string params =
1230         absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1231                          node_name, keep_controlling_fanins);
1232     return MutationError("RemoveAllFanins", params,
1233                          NodeMissingErrorMsg(node_name));
1234   }
1235 
1236   if (node->input().empty()) {
1237     return OkStatus();
1238   }
1239 
1240   const int num_regular_fanins =
1241       NumFanins(*node, /*include_controlling_nodes=*/false);
1242   RemoveFaninsInternal(node, keep_controlling_fanins);
1243   if (keep_controlling_fanins) {
1244     if (num_regular_fanins == 0) {
1245       return OkStatus();
1246     } else if (num_regular_fanins < node->input_size()) {
1247       node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1248     } else {
1249       node->clear_input();
1250     }
1251   } else {
1252     node->clear_input();
1253   }
1254   return OkStatus();
1255 }
1256 
UpdateFanin(absl::string_view node_name,const TensorId & from_fanin,const TensorId & to_fanin)1257 Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1258                                      const TensorId& from_fanin,
1259                                      const TensorId& to_fanin) {
1260   auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1261     string params =
1262         absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1263                          node_name, from_fanin.ToString(), to_fanin.ToString());
1264     return MutationError("UpdateFanin", params, msg);
1265   };
1266 
1267   TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1268   TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1269   NodeDef* node = GetNode(node_name);
1270   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1271   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1272   TF_RETURN_IF_ERROR(
1273       CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1274   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1275   TF_RETURN_IF_ERROR(
1276       CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1277 
1278   // When replacing a non control dependency fanin with a control dependency, or
1279   // vice versa, remove and add, so ports can be updated properly in fanout(s).
1280   bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1281   if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1282     // Can't add Switch node as a control dependency.
1283     return error_status(
1284         absl::Substitute("can't update to fanin '$0' as it will become a "
1285                          "Switch control dependency",
1286                          to_fanin.ToString()));
1287   }
1288   if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1289     return error_status("can't update fanin to or from self");
1290   }
1291 
1292   if (from_fanin == to_fanin) {
1293     return OkStatus();
1294   }
1295 
1296   bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1297   if (from_fanin_is_control || to_fanin_is_control) {
1298     bool modified = false;
1299     if (from_fanin_is_control) {
1300       modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1301     } else {
1302       modified |= RemoveRegularFaninInternal(
1303           node, {from_fanin_node, from_fanin.index()});
1304     }
1305     if (modified) {
1306       AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1307     }
1308     return OkStatus();
1309   }
1310 
1311   // In place mutation of regular fanins, requires no shifting of ports.
1312   string to_fanin_string = TensorIdToString(to_fanin);
1313   const int num_regular_fanins =
1314       NumFanins(*node, /*include_controlling_nodes=*/false);
1315   bool modified = false;
1316   for (int i = 0; i < num_regular_fanins; ++i) {
1317     if (ParseTensorName(node->input(i)) == from_fanin) {
1318       InputPort input(node, i);
1319 
1320       OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1321       fanouts()[from_fanin_port].erase(input);
1322 
1323       OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1324       fanouts()[to_fanin_port].insert(input);
1325 
1326       node->set_input(i, to_fanin_string);
1327       modified = true;
1328     }
1329   }
1330 
1331   // Dedup control dependencies and update max regular output ports.
1332   if (modified) {
1333     OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1334     UpdateMaxRegularOutputPortForRemovedFanin(
1335         {from_fanin_node, from_fanin.index()}, fanouts()[from_fanin_port]);
1336     if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1337       max_regular_output_port()[to_fanin_node] = to_fanin.index();
1338     }
1339     if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1340       RemoveControllingFaninInternal(node, to_fanin_node);
1341     }
1342   }
1343 
1344   return OkStatus();
1345 }
1346 
UpdateRegularFaninByPort(absl::string_view node_name,int port,const TensorId & fanin)1347 Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1348                                                   int port,
1349                                                   const TensorId& fanin) {
1350   auto error_status = [node_name, port, fanin](absl::string_view msg) {
1351     string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1352                                      node_name, port, fanin.ToString());
1353     return MutationError("UpdateRegularFaninByPort", params, msg);
1354   };
1355 
1356   TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1357   TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1358   NodeDef* node = GetNode(node_name);
1359   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1360   const int last_regular_fanin_port =
1361       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1362   TF_RETURN_IF_ERROR(
1363       CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1364   NodeDef* fanin_node = GetNode(fanin.node());
1365   TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1366 
1367   TensorId tensor_id = ParseTensorName(node->input(port));
1368   if (tensor_id == fanin) {
1369     return OkStatus();
1370   }
1371 
1372   InputPort input(node, port);
1373   OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1374   absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1375   from_fanouts->erase(input);
1376   UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1377 
1378   OutputPort to_fanin_port(fanin_node, fanin.index());
1379   fanouts()[to_fanin_port].insert(input);
1380   UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1381 
1382   node->set_input(port, TensorIdToString(fanin));
1383 
1384   if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1385     RemoveControllingFaninInternal(node, fanin_node);
1386   }
1387 
1388   return OkStatus();
1389 }
1390 
SwapRegularFaninsByPorts(absl::string_view node_name,int from_port,int to_port)1391 Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1392                                                   int from_port, int to_port) {
1393   auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1394     string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1395                                      node_name, from_port, to_port);
1396     return MutationError("SwapRegularFaninsByPorts", params, msg);
1397   };
1398 
1399   NodeDef* node = GetNode(node_name);
1400   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1401   const int last_regular_fanin_port =
1402       gtl::FindWithDefault(max_regular_input_port(), node, -1);
1403   TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1404                                     last_regular_fanin_port, error_status));
1405   TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1406                                     error_status));
1407 
1408   if (from_port == to_port) {
1409     return OkStatus();
1410   }
1411   TensorId from_fanin = ParseTensorName(node->input(from_port));
1412   TensorId to_fanin = ParseTensorName(node->input(to_port));
1413   if (from_fanin == to_fanin) {
1414     return OkStatus();
1415   }
1416 
1417   InputPort from_input(node, from_port);
1418   InputPort to_input(node, to_port);
1419   NodeDef* from_fanin_node = GetNode(from_fanin.node());
1420   absl::flat_hash_set<InputPort>* from_fanouts =
1421       &fanouts()[{from_fanin_node, from_fanin.index()}];
1422   from_fanouts->erase(from_input);
1423   from_fanouts->insert(to_input);
1424   NodeDef* to_fanin_node = GetNode(to_fanin.node());
1425   absl::flat_hash_set<InputPort>* to_fanouts =
1426       &fanouts()[{to_fanin_node, to_fanin.index()}];
1427   to_fanouts->erase(to_input);
1428   to_fanouts->insert(from_input);
1429 
1430   node->mutable_input()->SwapElements(from_port, to_port);
1431 
1432   return OkStatus();
1433 }
1434 
UpdateAllRegularFaninsToControlling(absl::string_view node_name)1435 Status MutableGraphView::UpdateAllRegularFaninsToControlling(
1436     absl::string_view node_name) {
1437   auto error_status = [node_name](absl::string_view msg) {
1438     string params = absl::Substitute("node_name='$0'", node_name);
1439     return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
1440   };
1441 
1442   NodeDef* node = GetNode(node_name);
1443   TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1444 
1445   const int num_regular_fanins =
1446       NumFanins(*node, /*include_controlling_nodes=*/false);
1447   std::vector<OutputPort> regular_fanins;
1448   regular_fanins.reserve(num_regular_fanins);
1449   std::vector<NodeDef*> controlling_fanins;
1450   controlling_fanins.reserve(num_regular_fanins);
1451 
1452   // Get all regular fanins and derive controlling fanins.
1453   for (int i = 0; i < num_regular_fanins; ++i) {
1454     TensorId tensor_id = ParseTensorName(node->input(i));
1455     OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1456 
1457     string error_msg = "";
1458     NodeDef* control_node =
1459         GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
1460     if (!error_msg.empty()) {
1461       return error_status(error_msg);
1462     }
1463 
1464     regular_fanins.push_back(fanin_port);
1465     controlling_fanins.push_back(control_node);
1466   }
1467 
1468   // Replace regular fanins with controlling fanins and dedup.
1469   int pos = 0;
1470   InputPort input_port(node, Graph::kControlSlot);
1471   absl::flat_hash_set<absl::string_view> controls;
1472   for (int i = 0; i < num_regular_fanins; ++i) {
1473     OutputPort fanin_port = regular_fanins[i];
1474     NodeDef* control = controlling_fanins[i];
1475     if (control == nullptr) {
1476       control = GetOrCreateIdentityConsumingSwitch(fanin_port);
1477     }
1478     fanouts()[fanin_port].erase({node, i});
1479     if (controls.contains(control->name())) {
1480       continue;
1481     }
1482     controls.insert(control->name());
1483     node->set_input(pos, AsControlDependency(control->name()));
1484     fanouts()[{control, Graph::kControlSlot}].insert(input_port);
1485     ++pos;
1486   }
1487 
1488   // Shift existing controlling fanins and dedup.
1489   for (int i = num_regular_fanins; i < node->input_size(); ++i) {
1490     TensorId tensor_id = ParseTensorName(node->input(i));
1491     if (controls.contains(tensor_id.node())) {
1492       continue;
1493     }
1494     controls.insert(tensor_id.node());
1495     node->mutable_input()->SwapElements(pos, i);
1496     ++pos;
1497   }
1498 
1499   // Remove duplicate controls and leftover regular fanins.
1500   node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
1501   max_regular_input_port().erase(node);
1502 
1503   return OkStatus();
1504 }
1505 
CheckNodesCanBeDeleted(const absl::flat_hash_set<string> & nodes_to_delete)1506 Status MutableGraphView::CheckNodesCanBeDeleted(
1507     const absl::flat_hash_set<string>& nodes_to_delete) {
1508   std::vector<string> missing_nodes;
1509   std::vector<string> nodes_with_fanouts;
1510   for (const string& node_name_to_delete : nodes_to_delete) {
1511     NodeDef* node = GetNode(node_name_to_delete);
1512     if (node == nullptr) {
1513       // Can't delete missing node.
1514       missing_nodes.push_back(node_name_to_delete);
1515       continue;
1516     }
1517     const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1518                                               Graph::kControlSlot);
1519     for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1520       auto it = fanouts().find({node, i});
1521       bool has_retained_fanout = false;
1522       if (it != fanouts().end()) {
1523         for (const auto& fanout : it->second) {
1524           // Check if fanouts are of nodes to be deleted, and if so, they can be
1525           // ignored, as they will be removed also.
1526           if (!nodes_to_delete.contains(fanout.node->name())) {
1527             // Removing node will leave graph in an invalid state.
1528             has_retained_fanout = true;
1529             break;
1530           }
1531         }
1532       }
1533       if (has_retained_fanout) {
1534         nodes_with_fanouts.push_back(node_name_to_delete);
1535         break;
1536       }
1537     }
1538   }
1539 
1540   // Error message can get quite long, so we only show the first 5 node names.
1541   auto sort_and_sample = [](std::vector<string>* s) {
1542     constexpr int kMaxNodeNames = 5;
1543     std::sort(s->begin(), s->end());
1544     if (s->size() > kMaxNodeNames) {
1545       return absl::StrCat(
1546           absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1547     }
1548     return absl::StrJoin(*s, ", ");
1549   };
1550 
1551   if (!missing_nodes.empty()) {
1552     VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1553                                 sort_and_sample(&missing_nodes));
1554   }
1555   if (!nodes_with_fanouts.empty()) {
1556     std::vector<string> input_node_names(nodes_to_delete.begin(),
1557                                          nodes_to_delete.end());
1558     string params = absl::Substitute("nodes_to_delete={$0}",
1559                                      sort_and_sample(&input_node_names));
1560     string error_msg =
1561         absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1562                          sort_and_sample(&nodes_with_fanouts));
1563     return MutationError("DeleteNodes", params, error_msg);
1564   }
1565 
1566   return OkStatus();
1567 }
1568 
DeleteNodes(const absl::flat_hash_set<string> & nodes_to_delete)1569 Status MutableGraphView::DeleteNodes(
1570     const absl::flat_hash_set<string>& nodes_to_delete) {
1571   TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));
1572 
1573   // Find nodes in internal state and delete.
1574   for (const string& node_name_to_delete : nodes_to_delete) {
1575     NodeDef* node = GetNode(node_name_to_delete);
1576     if (node != nullptr) {
1577       RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
1578       RemoveFanoutsInternal(node);
1579     }
1580   }
1581   for (const string& node_name_to_delete : nodes_to_delete) {
1582     nodes().erase(node_name_to_delete);
1583   }
1584 
1585   // Find nodes in graph and delete by partitioning into nodes to retain and
1586   // nodes to delete based on input set of nodes to delete by name.
1587   // TODO(lyandy): Use a node name->idx hashmap if this is a performance
1588   // bottleneck.
1589   int pos = 0;
1590   const int last_idx = graph()->node_size() - 1;
1591   int last_pos = last_idx;
1592   while (pos <= last_pos) {
1593     if (nodes_to_delete.contains(graph()->node(pos).name())) {
1594       graph()->mutable_node()->SwapElements(pos, last_pos);
1595       --last_pos;
1596     } else {
1597       ++pos;
1598     }
1599   }
1600   if (last_pos < last_idx) {
1601     graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
1602   }
1603 
1604   return OkStatus();
1605 }
1606 
RemoveFaninsInternal(NodeDef * deleted_node,bool keep_controlling_fanins)1607 void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1608                                             bool keep_controlling_fanins) {
1609   for (int i = 0; i < deleted_node->input_size(); ++i) {
1610     TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1611     bool is_control = IsTensorIdControlling(tensor_id);
1612     if (keep_controlling_fanins && is_control) {
1613       break;
1614     }
1615     OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1616 
1617     InputPort input;
1618     input.node = deleted_node;
1619     input.port_id = is_control ? Graph::kControlSlot : i;
1620 
1621     auto it = fanouts().find(fanin);
1622     if (it != fanouts().end()) {
1623       absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1624       fanouts_set->erase(input);
1625       UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1626     }
1627   }
1628   max_regular_input_port().erase(deleted_node);
1629 }
1630 
RemoveFanoutsInternal(NodeDef * deleted_node)1631 void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1632   const int max_port =
1633       gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1634   for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1635     fanouts().erase({deleted_node, i});
1636   }
1637   max_regular_output_port().erase(deleted_node);
1638 }
1639 
1640 }  // end namespace grappler
1641 }  // end namespace tensorflow
1642