xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/utils/graph_view_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/graph_view.h"
17 
18 #include <type_traits>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/benchmark_testlib.h"
27 #include "tensorflow/core/grappler/utils/grappler_test.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 namespace utils {
36 namespace {
37 
38 using ::tensorflow::test::function::GDef;
39 using ::tensorflow::test::function::NDef;
40 
41 constexpr char kNoOp[] = "NoOp";
42 
SimpleTestGraph()43 GraphDef SimpleTestGraph() {
44   return GDef({NDef("a", kNoOp, {"b:2", "d:3", "b:2", "d:3", "^c"}),
45                NDef("b", kNoOp, {"d:2", "c:5", "^c"}),
46                NDef("c", kNoOp, {"^d", "^d"}), NDef("d", kNoOp, {})},
47               /*funcs=*/{});
48 }
49 
50 template <typename T>
GetGraphViewTypeAsString()51 const string GetGraphViewTypeAsString() {
52   return std::is_same<T, class GraphView>::value ? "GraphView"
53                                                  : "MutableGraphView";
54 }
55 
56 using GraphViewTypes = ::testing::Types<GraphView, MutableGraphView>;
57 
58 template <typename T>
59 class TypedGraphViewTest : public ::testing::Test {};
60 TYPED_TEST_SUITE(TypedGraphViewTest, GraphViewTypes);
61 
TYPED_TEST(TypedGraphViewTest,GraphWithDuplicateNodeNames)62 TYPED_TEST(TypedGraphViewTest, GraphWithDuplicateNodeNames) {
63   GraphDef graph =
64       GDef({NDef("a", kNoOp, {}), NDef("a", kNoOp, {})}, /*funcs=*/{});
65 
66   Status s;
67   TypeParam graph_view(&graph, &s);
68   EXPECT_FALSE(s.ok());
69   EXPECT_EQ(s.error_message(),
70             absl::Substitute(
71                 "$0::$0 error: graph has multiple nodes with the name 'a'.",
72                 GetGraphViewTypeAsString<TypeParam>()));
73 }
74 
TYPED_TEST(TypedGraphViewTest,GraphWithMissingFanins)75 TYPED_TEST(TypedGraphViewTest, GraphWithMissingFanins) {
76   GraphDef graph = GDef({NDef("a", kNoOp, {"b:3"})}, /*funcs=*/{});
77 
78   Status s;
79   TypeParam graph_view(&graph, &s);
80   EXPECT_FALSE(s.ok());
81   EXPECT_EQ(s.error_message(),
82             absl::Substitute("$0::$0 error: node 'a' has missing fanin 'b:3'.",
83                              GetGraphViewTypeAsString<TypeParam>()));
84 }
85 
TYPED_TEST(TypedGraphViewTest,GraphWithSelfCycles)86 TYPED_TEST(TypedGraphViewTest, GraphWithSelfCycles) {
87   GraphDef graph = GDef({NDef("a", kNoOp, {"a:4"})}, /*funcs=*/{});
88 
89   Status s;
90   TypeParam graph_view(&graph, &s);
91   EXPECT_FALSE(s.ok());
92   EXPECT_EQ(
93       s.error_message(),
94       absl::Substitute("$0::$0 error: node 'a' has self cycle fanin 'a:4'.",
95                        GetGraphViewTypeAsString<TypeParam>()));
96 }
97 
TYPED_TEST(TypedGraphViewTest,GraphWithMisorderedFanins)98 TYPED_TEST(TypedGraphViewTest, GraphWithMisorderedFanins) {
99   GraphDef graph = GDef({NDef("a", kNoOp, {"^b", "b:4"}), NDef("b", kNoOp, {})},
100                         /*funcs=*/{});
101 
102   Status s;
103   TypeParam graph_view(&graph, &s);
104   EXPECT_FALSE(s.ok());
105   EXPECT_EQ(s.error_message(),
106             absl::Substitute("$0::$0 error: node 'a' has regular fanin 'b:4' "
107                              "after controlling fanins.",
108                              GetGraphViewTypeAsString<TypeParam>()));
109 }
110 
TYPED_TEST(TypedGraphViewTest,GetNodeWithIndex)111 TYPED_TEST(TypedGraphViewTest, GetNodeWithIndex) {
112   GraphDef graph = SimpleTestGraph();
113 
114   Status s;
115   TypeParam graph_view(&graph, &s);
116   TF_ASSERT_OK(s);
117 
118   const int num_nodes = graph_view.NumNodes();
119   ASSERT_EQ(graph_view.NumNodes(), graph.node_size());
120   for (int i = 0; i < num_nodes; ++i) {
121     const auto* node = graph_view.GetNode(i);
122     ASSERT_NE(node, nullptr);
123     EXPECT_EQ(node->node(), graph.mutable_node(i));
124   }
125 
126   const auto* bad_node = graph_view.GetNode(-1);
127   ASSERT_EQ(bad_node, nullptr);
128   bad_node = graph_view.GetNode(num_nodes);
129   ASSERT_EQ(bad_node, nullptr);
130 }
131 
TYPED_TEST(TypedGraphViewTest,GetNodeWithName)132 TYPED_TEST(TypedGraphViewTest, GetNodeWithName) {
133   GraphDef graph = SimpleTestGraph();
134 
135   Status s;
136   TypeParam graph_view(&graph, &s);
137   TF_ASSERT_OK(s);
138 
139   std::vector<string> node_names = {"a", "b", "c", "d"};
140   for (int i = 0; i < node_names.size(); ++i) {
141     const string& node_name = node_names[i];
142     const auto* node = graph_view.GetNode(node_name);
143     ASSERT_NE(node, nullptr);
144     EXPECT_EQ(node->node(), graph.mutable_node(i));
145   }
146 
147   // Missing node.
148   const auto* bad_node = graph_view.GetNode("e");
149   ASSERT_EQ(bad_node, nullptr);
150 }
151 
TYPED_TEST(TypedGraphViewTest,GetNodes)152 TYPED_TEST(TypedGraphViewTest, GetNodes) {
153   GraphDef graph = SimpleTestGraph();
154 
155   Status s;
156   TypeParam graph_view(&graph, &s);
157   TF_ASSERT_OK(s);
158 
159   const auto& nodes = graph_view.GetNodes();
160   const int num_nodes = nodes.size();
161   EXPECT_EQ(num_nodes, 4);
162 
163   ASSERT_EQ(num_nodes, graph.node_size());
164   for (int i = 0; i < num_nodes; ++i) {
165     EXPECT_EQ(nodes[i].node(), graph.mutable_node(i));
166   }
167 }
168 
TYPED_TEST(TypedGraphViewTest,HasNode)169 TYPED_TEST(TypedGraphViewTest, HasNode) {
170   GraphDef graph = SimpleTestGraph();
171 
172   Status s;
173   TypeParam graph_view(&graph, &s);
174   TF_ASSERT_OK(s);
175 
176   for (const string& node_name : {"a", "b", "c", "d"}) {
177     EXPECT_TRUE(graph_view.HasNode(node_name));
178   }
179 
180   // Missing node.
181   EXPECT_FALSE(graph_view.HasNode("e"));
182 }
183 
TYPED_TEST(TypedGraphViewTest,NumNodes)184 TYPED_TEST(TypedGraphViewTest, NumNodes) {
185   GraphDef graph = SimpleTestGraph();
186 
187   Status s;
188   TypeParam graph_view(&graph, &s);
189   TF_ASSERT_OK(s);
190   EXPECT_EQ(graph_view.NumNodes(), 4);
191 }
192 
TYPED_TEST(TypedGraphViewTest,NumNodesEmptyGraph)193 TYPED_TEST(TypedGraphViewTest, NumNodesEmptyGraph) {
194   GraphDef graph;
195 
196   Status s;
197   TypeParam graph_view(&graph, &s);
198   TF_ASSERT_OK(s);
199   EXPECT_EQ(graph_view.NumNodes(), 0);
200 }
201 
TEST(MutableGraphViewTest,DedupControlDependencies)202 TEST(MutableGraphViewTest, DedupControlDependencies) {
203   GraphDef graph = GDef(
204       {NDef("a", kNoOp, {}), NDef("b", kNoOp, {}), NDef("c", kNoOp, {}),
205        NDef("d", kNoOp, {"a:2", "b:1", "^c", "^c", "^a", "^a", "^b", "^c"})},
206       /*funcs=*/{});
207 
208   Status s;
209   MutableGraphView graph_view(&graph, &s);
210   TF_ASSERT_OK(s);
211   EXPECT_EQ(graph_view.NumNodes(), 4);
212 
213   const auto* a_node = graph_view.GetNode("a");
214   ASSERT_NE(a_node, nullptr);
215   const auto* b_node = graph_view.GetNode("b");
216   ASSERT_NE(b_node, nullptr);
217   const auto* c_node = graph_view.GetNode("c");
218   ASSERT_NE(c_node, nullptr);
219   const auto* d_node = graph_view.GetNode("d");
220   ASSERT_NE(d_node, nullptr);
221 
222   EXPECT_EQ(d_node->NumRegularFanins(), 2);
223   ASSERT_NE(d_node->node(), nullptr);
224   ASSERT_EQ(d_node->node()->input_size(), 5);
225   EXPECT_EQ(d_node->node()->input(0), "a:2");
226   EXPECT_EQ(d_node->node()->input(1), "b:1");
227   EXPECT_EQ(d_node->node()->input(2), "^c");
228   EXPECT_EQ(d_node->node()->input(3), "^b");
229   EXPECT_EQ(d_node->node()->input(4), "^a");
230   ASSERT_EQ(d_node->NumControllingFanins(), 3);
231   const auto& d_control_fanins = d_node->GetControllingFanins();
232   ASSERT_EQ(d_control_fanins.size(), 3);
233   ASSERT_NE(d_control_fanins[0].node_view(), nullptr);
234   EXPECT_EQ(d_control_fanins[0].node_view()->GetName(), "c");
235   ASSERT_NE(d_control_fanins[1].node_view(), nullptr);
236   EXPECT_EQ(d_control_fanins[1].node_view()->GetName(), "b");
237   ASSERT_NE(d_control_fanins[2].node_view(), nullptr);
238   EXPECT_EQ(d_control_fanins[2].node_view()->GetName(), "a");
239 }
240 
241 template <typename T>
242 class TypedNodeViewTest : public ::testing::Test {};
243 TYPED_TEST_SUITE(TypedNodeViewTest, GraphViewTypes);
244 
TYPED_TEST(TypedNodeViewTest,GetName)245 TYPED_TEST(TypedNodeViewTest, GetName) {
246   GraphDef graph = SimpleTestGraph();
247 
248   Status s;
249   TypeParam graph_view(&graph, &s);
250   TF_ASSERT_OK(s);
251 
252   for (const NodeDef& node : graph.node()) {
253     const auto* node_view = graph_view.GetNode(node.name());
254     ASSERT_NE(node_view, nullptr);
255     EXPECT_EQ(node_view->GetName(), node.name());
256     EXPECT_EQ(node_view->GetName(), node_view->node()->name());
257   }
258 }
259 
TYPED_TEST(TypedNodeViewTest,GetOp)260 TYPED_TEST(TypedNodeViewTest, GetOp) {
261   GraphDef graph = GDef({NDef("a", "op_a", {}), NDef("b", "op_b", {}),
262                          NDef("c", "op_c", {}), NDef("d", "op_d", {})},
263                         /*funcs=*/{});
264 
265   Status s;
266   TypeParam graph_view(&graph, &s);
267   TF_ASSERT_OK(s);
268 
269   const auto* a_node = graph_view.GetNode("a");
270   ASSERT_NE(a_node, nullptr);
271   EXPECT_EQ(a_node->GetOp(), "op_a");
272   EXPECT_EQ(a_node->node()->op(), "op_a");
273   const auto* b_node = graph_view.GetNode("b");
274   ASSERT_NE(b_node, nullptr);
275   EXPECT_EQ(b_node->GetOp(), "op_b");
276   EXPECT_EQ(b_node->node()->op(), "op_b");
277   const auto* c_node = graph_view.GetNode("c");
278   ASSERT_NE(c_node, nullptr);
279   EXPECT_EQ(c_node->GetOp(), "op_c");
280   EXPECT_EQ(c_node->node()->op(), "op_c");
281   const auto* d_node = graph_view.GetNode("d");
282   ASSERT_NE(d_node, nullptr);
283   EXPECT_EQ(d_node->GetOp(), "op_d");
284   EXPECT_EQ(d_node->node()->op(), "op_d");
285 }
286 
TYPED_TEST(TypedNodeViewTest,GetDevice)287 TYPED_TEST(TypedNodeViewTest, GetDevice) {
288   GraphDef graph = GDef(
289       {NDef("a", "", {}, {}, "device_a"), NDef("b", "", {}, {}, "device_b"),
290        NDef("c", "", {}, {}, "device_c"), NDef("d", "", {}, {})},
291       /*funcs=*/{});
292 
293   Status s;
294   TypeParam graph_view(&graph, &s);
295   TF_ASSERT_OK(s);
296 
297   const auto* a_node = graph_view.GetNode("a");
298   ASSERT_NE(a_node, nullptr);
299   EXPECT_EQ(a_node->GetDevice(), "device_a");
300   EXPECT_EQ(a_node->node()->device(), "device_a");
301   const auto* b_node = graph_view.GetNode("b");
302   ASSERT_NE(b_node, nullptr);
303   EXPECT_EQ(b_node->GetDevice(), "device_b");
304   EXPECT_EQ(b_node->node()->device(), "device_b");
305   const auto* c_node = graph_view.GetNode("c");
306   ASSERT_NE(c_node, nullptr);
307   EXPECT_EQ(c_node->GetDevice(), "device_c");
308   EXPECT_EQ(c_node->node()->device(), "device_c");
309   const auto* d_node = graph_view.GetNode("d");
310   ASSERT_NE(d_node, nullptr);
311   EXPECT_EQ(d_node->GetDevice(), "");
312   EXPECT_EQ(d_node->node()->device(), "");
313 }
314 
315 template <typename T>
316 class TypedFaninTest : public ::testing::Test {};
317 using FaninTypes =
318     ::testing::Types<std::pair<FanoutView, GraphView>,
319                      std::pair<MutableFanoutView, MutableGraphView>>;
320 TYPED_TEST_SUITE(TypedFaninTest, FaninTypes);
321 
TYPED_TEST(TypedFaninTest,GetRegularFanins)322 TYPED_TEST(TypedFaninTest, GetRegularFanins) {
323   using FanoutViewType = typename TypeParam::first_type;
324   using GraphViewType = typename TypeParam::second_type;
325 
326   GraphDef graph = SimpleTestGraph();
327 
328   Status s;
329   GraphViewType graph_view(&graph, &s);
330   TF_ASSERT_OK(s);
331 
332   auto* a_node = graph_view.GetNode("a");
333   ASSERT_NE(a_node, nullptr);
334   auto* b_node = graph_view.GetNode("b");
335   ASSERT_NE(b_node, nullptr);
336   auto* d_node = graph_view.GetNode("d");
337   ASSERT_NE(d_node, nullptr);
338 
339   const auto& a_fanins = a_node->GetRegularFanins();
340   ASSERT_EQ(a_fanins.size(), 4);
341   EXPECT_EQ(a_fanins[0], FanoutViewType(&graph_view, b_node->node_index(), 2));
342   EXPECT_EQ(a_fanins[1], FanoutViewType(&graph_view, d_node->node_index(), 3));
343   EXPECT_EQ(a_fanins[2], FanoutViewType(&graph_view, b_node->node_index(), 2));
344   EXPECT_EQ(a_fanins[3], FanoutViewType(&graph_view, d_node->node_index(), 3));
345 
346   const auto& d_fanins = d_node->GetRegularFanins();
347   EXPECT_EQ(d_fanins.size(), 0);
348 }
349 
TYPED_TEST(TypedFaninTest,GetRegularFanin)350 TYPED_TEST(TypedFaninTest, GetRegularFanin) {
351   using FanoutViewType = typename TypeParam::first_type;
352   using GraphViewType = typename TypeParam::second_type;
353 
354   GraphDef graph = SimpleTestGraph();
355 
356   Status s;
357   GraphViewType graph_view(&graph, &s);
358   TF_ASSERT_OK(s);
359 
360   auto* a_node = graph_view.GetNode("a");
361   ASSERT_NE(a_node, nullptr);
362   auto* b_node = graph_view.GetNode("b");
363   ASSERT_NE(b_node, nullptr);
364   auto* d_node = graph_view.GetNode("d");
365   ASSERT_NE(d_node, nullptr);
366 
367   const auto& a_fanin_0 = a_node->GetRegularFanin(0);
368   EXPECT_EQ(a_fanin_0, FanoutViewType(&graph_view, b_node->node_index(), 2));
369   const auto& a_fanin_1 = a_node->GetRegularFanin(1);
370   EXPECT_EQ(a_fanin_1, FanoutViewType(&graph_view, d_node->node_index(), 3));
371   const auto& a_fanin_2 = a_node->GetRegularFanin(2);
372   EXPECT_EQ(a_fanin_2, FanoutViewType(&graph_view, b_node->node_index(), 2));
373   const auto& a_fanin_3 = a_node->GetRegularFanin(3);
374   EXPECT_EQ(a_fanin_3, FanoutViewType(&graph_view, d_node->node_index(), 3));
375 
376   // Out of bounds.
377   const FanoutViewType missing_fanin;
378   EXPECT_EQ(missing_fanin, FanoutViewType(nullptr, -1, -2));
379   EXPECT_EQ(missing_fanin.node_view(), nullptr);
380   const auto& a_fanin_4 = a_node->GetRegularFanin(4);
381   EXPECT_EQ(a_fanin_4, missing_fanin);
382   const auto& a_fanin_5 = a_node->GetRegularFanin(5);
383   EXPECT_EQ(a_fanin_5, missing_fanin);
384   const auto& a_fanin_control = a_node->GetRegularFanin(Graph::kControlSlot);
385   EXPECT_EQ(a_fanin_control, missing_fanin);
386   const auto& a_fanin_bad = a_node->GetRegularFanin(-2);
387   EXPECT_EQ(a_fanin_bad, missing_fanin);
388 }
389 
TYPED_TEST(TypedFaninTest,GetControllingFanins)390 TYPED_TEST(TypedFaninTest, GetControllingFanins) {
391   using FanoutViewType = typename TypeParam::first_type;
392   using GraphViewType = typename TypeParam::second_type;
393 
394   GraphDef graph = SimpleTestGraph();
395 
396   Status s;
397   GraphViewType graph_view(&graph, &s);
398   TF_ASSERT_OK(s);
399 
400   auto* a_node = graph_view.GetNode("a");
401   ASSERT_NE(a_node, nullptr);
402   auto* c_node = graph_view.GetNode("c");
403   ASSERT_NE(c_node, nullptr);
404   auto* d_node = graph_view.GetNode("d");
405   ASSERT_NE(d_node, nullptr);
406 
407   const auto& a_fanins = a_node->GetControllingFanins();
408   ASSERT_EQ(a_fanins.size(), 1);
409   EXPECT_EQ(a_fanins[0], FanoutViewType(&graph_view, c_node->node_index(),
410                                         Graph::kControlSlot));
411 
412   const auto& c_fanins = c_node->GetControllingFanins();
413   FanoutViewType d_control_fanin(&graph_view, d_node->node_index(),
414                                  Graph::kControlSlot);
415   if (std::is_same<GraphViewType, GraphView>::value) {
416     ASSERT_EQ(c_fanins.size(), 2);
417     EXPECT_EQ(c_fanins[0], d_control_fanin);
418     EXPECT_EQ(c_fanins[1], d_control_fanin);
419   } else {  // MutableGraphView will dedup control dependency.
420     ASSERT_EQ(c_fanins.size(), 1);
421     EXPECT_EQ(c_fanins[0], d_control_fanin);
422   }
423 
424   const auto& d_fanins = d_node->GetControllingFanins();
425   EXPECT_EQ(d_fanins.size(), 0);
426 }
427 
428 template <typename T>
429 class TypedFanoutTest : public ::testing::Test {};
430 using FanoutTypes =
431     ::testing::Types<std::pair<FaninView, GraphView>,
432                      std::pair<MutableFaninView, MutableGraphView>>;
433 TYPED_TEST_SUITE(TypedFanoutTest, FanoutTypes);
434 
TYPED_TEST(TypedFanoutTest,GetRegularFanouts)435 TYPED_TEST(TypedFanoutTest, GetRegularFanouts) {
436   using FaninViewType = typename TypeParam::first_type;
437   using GraphViewType = typename TypeParam::second_type;
438 
439   GraphDef graph = SimpleTestGraph();
440 
441   Status s;
442   GraphViewType graph_view(&graph, &s);
443   TF_ASSERT_OK(s);
444 
445   auto* a_node = graph_view.GetNode("a");
446   ASSERT_NE(a_node, nullptr);
447   auto* b_node = graph_view.GetNode("b");
448   ASSERT_NE(b_node, nullptr);
449   auto* d_node = graph_view.GetNode("d");
450   ASSERT_NE(d_node, nullptr);
451 
452   const auto& d_fanouts = d_node->GetRegularFanouts();
453   ASSERT_EQ(d_fanouts.size(), 4);
454   for (int i = 0; i < d_fanouts.size(); ++i) {
455     if (i == 2) {
456       ASSERT_EQ(d_fanouts[i].size(), 1);
457       EXPECT_EQ(d_fanouts[i][0],
458                 FaninViewType(&graph_view, b_node->node_index(), 0));
459     } else if (i == 3) {
460       ASSERT_EQ(d_fanouts[i].size(), 2);
461       absl::flat_hash_set<FaninViewType> fanouts(d_fanouts[i].begin(),
462                                                  d_fanouts[i].end());
463       EXPECT_TRUE(fanouts.contains(
464           FaninViewType(&graph_view, a_node->node_index(), 1)));
465       EXPECT_TRUE(fanouts.contains(
466           FaninViewType(&graph_view, a_node->node_index(), 3)));
467     } else {
468       EXPECT_EQ(d_fanouts[i].size(), 0);
469     }
470   }
471 
472   const auto& a_fanouts = a_node->GetRegularFanouts();
473   EXPECT_EQ(a_fanouts.size(), 0);
474 }
475 
TYPED_TEST(TypedFanoutTest,GetRegularFanout)476 TYPED_TEST(TypedFanoutTest, GetRegularFanout) {
477   using FaninViewType = typename TypeParam::first_type;
478   using GraphViewType = typename TypeParam::second_type;
479 
480   GraphDef graph = SimpleTestGraph();
481 
482   Status s;
483   GraphViewType graph_view(&graph, &s);
484   TF_ASSERT_OK(s);
485 
486   auto* a_node = graph_view.GetNode("a");
487   ASSERT_NE(a_node, nullptr);
488   auto* b_node = graph_view.GetNode("b");
489   ASSERT_NE(b_node, nullptr);
490   auto* d_node = graph_view.GetNode("d");
491   ASSERT_NE(d_node, nullptr);
492 
493   const auto& d_fanouts_2 = d_node->GetRegularFanout(2);
494   ASSERT_EQ(d_fanouts_2.size(), 1);
495   EXPECT_EQ(d_fanouts_2.at(0),
496             FaninViewType(&graph_view, b_node->node_index(), 0));
497 
498   const auto& d_fanouts_3 = d_node->GetRegularFanout(3);
499   EXPECT_EQ(d_fanouts_3.size(), 2);
500   absl::flat_hash_set<FaninViewType> d_fanouts_3_set(d_fanouts_3.begin(),
501                                                      d_fanouts_3.end());
502   EXPECT_TRUE(d_fanouts_3_set.contains(
503       FaninViewType(&graph_view, a_node->node_index(), 1)));
504   EXPECT_TRUE(d_fanouts_3_set.contains(
505       FaninViewType(&graph_view, a_node->node_index(), 3)));
506 
507   // Invalid or empty.
508   const std::vector<FaninViewType> no_fanouts;
509   EXPECT_EQ(d_node->GetRegularFanout(-2), no_fanouts);
510   EXPECT_EQ(d_node->GetRegularFanout(Graph::kControlSlot), no_fanouts);
511   EXPECT_EQ(d_node->GetRegularFanout(0), no_fanouts);
512   EXPECT_EQ(d_node->GetRegularFanout(1), no_fanouts);
513   EXPECT_EQ(d_node->GetRegularFanout(4), no_fanouts);
514   EXPECT_EQ(d_node->GetRegularFanout(5), no_fanouts);
515 }
516 
TYPED_TEST(TypedFanoutTest,GetControlledFanouts)517 TYPED_TEST(TypedFanoutTest, GetControlledFanouts) {
518   using FaninViewType = typename TypeParam::first_type;
519   using GraphViewType = typename TypeParam::second_type;
520 
521   GraphDef graph = SimpleTestGraph();
522 
523   Status s;
524   GraphViewType graph_view(&graph, &s);
525   TF_ASSERT_OK(s);
526 
527   auto* a_node = graph_view.GetNode("a");
528   ASSERT_NE(a_node, nullptr);
529   auto* b_node = graph_view.GetNode("b");
530   ASSERT_NE(b_node, nullptr);
531   auto* c_node = graph_view.GetNode("c");
532   ASSERT_NE(c_node, nullptr);
533   auto* d_node = graph_view.GetNode("d");
534   ASSERT_NE(d_node, nullptr);
535 
536   const auto& c_fanouts = c_node->GetControlledFanouts();
537   EXPECT_EQ(c_fanouts.size(), 2);
538   absl::flat_hash_set<FaninViewType> c_fanouts_set(c_fanouts.begin(),
539                                                    c_fanouts.end());
540   EXPECT_TRUE(c_fanouts_set.contains(
541       FaninViewType(&graph_view, b_node->node_index(), Graph::kControlSlot)));
542   EXPECT_TRUE(c_fanouts_set.contains(
543       FaninViewType(&graph_view, a_node->node_index(), Graph::kControlSlot)));
544 
545   const auto& d_fanouts = d_node->GetControlledFanouts();
546   FaninViewType c_control_fanout(&graph_view, c_node->node_index(),
547                                  Graph::kControlSlot);
548   if (std::is_same<GraphViewType, GraphView>::value) {
549     ASSERT_EQ(d_fanouts.size(), 2);
550     EXPECT_EQ(d_fanouts[0], c_control_fanout);
551     EXPECT_EQ(d_fanouts[1], c_control_fanout);
552   } else {  // MutableGraphView will dedup control dependency.
553     ASSERT_EQ(d_fanouts.size(), 1);
554     EXPECT_EQ(d_fanouts[0], c_control_fanout);
555   }
556 
557   const auto& a_fanouts = a_node->GetControlledFanouts();
558   EXPECT_EQ(a_fanouts.size(), 0);
559 }
560 
TYPED_TEST(TypedNodeViewTest,NumRegularFanins)561 TYPED_TEST(TypedNodeViewTest, NumRegularFanins) {
562   GraphDef graph = SimpleTestGraph();
563 
564   Status s;
565   TypeParam graph_view(&graph, &s);
566   TF_ASSERT_OK(s);
567 
568   auto* a_node = graph_view.GetNode("a");
569   ASSERT_NE(a_node, nullptr);
570   auto* b_node = graph_view.GetNode("b");
571   ASSERT_NE(b_node, nullptr);
572   auto* c_node = graph_view.GetNode("c");
573   ASSERT_NE(c_node, nullptr);
574   auto* d_node = graph_view.GetNode("d");
575   ASSERT_NE(d_node, nullptr);
576 
577   EXPECT_EQ(a_node->NumRegularFanins(), 4);
578   EXPECT_EQ(b_node->NumRegularFanins(), 2);
579   EXPECT_EQ(c_node->NumRegularFanins(), 0);
580   EXPECT_EQ(d_node->NumRegularFanins(), 0);
581 }
582 
TYPED_TEST(TypedNodeViewTest,NumControllingFanins)583 TYPED_TEST(TypedNodeViewTest, NumControllingFanins) {
584   GraphDef graph = SimpleTestGraph();
585 
586   Status s;
587   TypeParam graph_view(&graph, &s);
588   TF_ASSERT_OK(s);
589 
590   auto* a_node = graph_view.GetNode("a");
591   ASSERT_NE(a_node, nullptr);
592   auto* b_node = graph_view.GetNode("b");
593   ASSERT_NE(b_node, nullptr);
594   auto* c_node = graph_view.GetNode("c");
595   ASSERT_NE(c_node, nullptr);
596   auto* d_node = graph_view.GetNode("d");
597   ASSERT_NE(d_node, nullptr);
598 
599   EXPECT_EQ(a_node->NumControllingFanins(), 1);
600   EXPECT_EQ(b_node->NumControllingFanins(), 1);
601   if (std::is_same<TypeParam, GraphView>::value) {
602     EXPECT_EQ(c_node->NumControllingFanins(), 2);
603   } else {
604     EXPECT_EQ(c_node->NumControllingFanins(), 1);
605   }
606   EXPECT_EQ(d_node->NumControllingFanins(), 0);
607 }
608 
TYPED_TEST(TypedNodeViewTest,NumRegularFanouts)609 TYPED_TEST(TypedNodeViewTest, NumRegularFanouts) {
610   GraphDef graph = SimpleTestGraph();
611 
612   Status s;
613   TypeParam graph_view(&graph, &s);
614   TF_ASSERT_OK(s);
615 
616   auto* a_node = graph_view.GetNode("a");
617   ASSERT_NE(a_node, nullptr);
618   auto* b_node = graph_view.GetNode("b");
619   ASSERT_NE(b_node, nullptr);
620   auto* c_node = graph_view.GetNode("c");
621   ASSERT_NE(c_node, nullptr);
622   auto* d_node = graph_view.GetNode("d");
623   ASSERT_NE(d_node, nullptr);
624 
625   EXPECT_EQ(a_node->NumRegularFanouts(), 0);
626   EXPECT_EQ(b_node->NumRegularFanouts(), 2);
627   EXPECT_EQ(c_node->NumRegularFanouts(), 1);
628   EXPECT_EQ(d_node->NumRegularFanouts(), 3);
629 }
630 
TYPED_TEST(TypedNodeViewTest,NumControlledFanouts)631 TYPED_TEST(TypedNodeViewTest, NumControlledFanouts) {
632   GraphDef graph = SimpleTestGraph();
633 
634   Status s;
635   TypeParam graph_view(&graph, &s);
636   TF_ASSERT_OK(s);
637 
638   auto* a_node = graph_view.GetNode("a");
639   ASSERT_NE(a_node, nullptr);
640   auto* b_node = graph_view.GetNode("b");
641   ASSERT_NE(b_node, nullptr);
642   auto* c_node = graph_view.GetNode("c");
643   ASSERT_NE(c_node, nullptr);
644   auto* d_node = graph_view.GetNode("d");
645   ASSERT_NE(d_node, nullptr);
646 
647   EXPECT_EQ(a_node->NumControlledFanouts(), 0);
648   EXPECT_EQ(b_node->NumControlledFanouts(), 0);
649   EXPECT_EQ(c_node->NumControlledFanouts(), 2);
650   if (std::is_same<TypeParam, GraphView>::value) {
651     EXPECT_EQ(d_node->NumControlledFanouts(), 2);
652   } else {
653     EXPECT_EQ(d_node->NumControlledFanouts(), 1);
654   }
655 }
656 
TYPED_TEST(TypedNodeViewTest,HasFanin)657 TYPED_TEST(TypedNodeViewTest, HasFanin) {
658   GraphDef graph = SimpleTestGraph();
659 
660   Status s;
661   TypeParam graph_view(&graph, &s);
662   TF_ASSERT_OK(s);
663 
664   auto* a_node = graph_view.GetNode("a");
665   ASSERT_NE(a_node, nullptr);
666   auto* b_node = graph_view.GetNode("b");
667   ASSERT_NE(b_node, nullptr);
668   auto* c_node = graph_view.GetNode("c");
669   ASSERT_NE(c_node, nullptr);
670 
671   // Existing regular fanin.
672   EXPECT_TRUE(a_node->HasFanin({&graph_view, b_node->node_index(), 2}));
673   // Missing regular fanin.
674   EXPECT_FALSE(a_node->HasFanin({&graph_view, c_node->node_index(), 4}));
675   // Existing controlling fanin.
676   EXPECT_TRUE(a_node->HasFanin(
677       {&graph_view, c_node->node_index(), Graph::kControlSlot}));
678   // Missing controlling fanin.
679   EXPECT_FALSE(a_node->HasFanin(
680       {&graph_view, b_node->node_index(), Graph::kControlSlot}));
681   // Bad fanins.
682   EXPECT_FALSE(a_node->HasFanin({&graph_view, a_node->node_index(), 0}));
683   EXPECT_FALSE(a_node->HasFanin(
684       {&graph_view, b_node->node_index(), internal::kMissingSlot}));
685 }
686 
TYPED_TEST(TypedNodeViewTest,HasFanout)687 TYPED_TEST(TypedNodeViewTest, HasFanout) {
688   GraphDef graph = SimpleTestGraph();
689 
690   Status s;
691   TypeParam graph_view(&graph, &s);
692   TF_ASSERT_OK(s);
693 
694   auto* a_node = graph_view.GetNode("a");
695   ASSERT_NE(a_node, nullptr);
696   auto* b_node = graph_view.GetNode("b");
697   ASSERT_NE(b_node, nullptr);
698   auto* c_node = graph_view.GetNode("c");
699   ASSERT_NE(c_node, nullptr);
700   auto* d_node = graph_view.GetNode("d");
701   ASSERT_NE(d_node, nullptr);
702 
703   // Existing regular fanout.
704   EXPECT_TRUE(b_node->HasFanout({&graph_view, a_node->node_index(), 2}));
705   // Missing regular fanout.
706   EXPECT_FALSE(b_node->HasFanout({&graph_view, a_node->node_index(), 1}));
707   // Existing controlled fanout.
708   EXPECT_TRUE(d_node->HasFanout(
709       {&graph_view, c_node->node_index(), Graph::kControlSlot}));
710   // Missing controlled fanout.
711   EXPECT_FALSE(d_node->HasFanout(
712       {&graph_view, a_node->node_index(), Graph::kControlSlot}));
713   // Bad fanouts.
714   EXPECT_FALSE(d_node->HasFanout({&graph_view, d_node->node_index(), 0}));
715   EXPECT_FALSE(a_node->HasFanout({&graph_view, b_node->node_index(), 0}));
716   EXPECT_FALSE(a_node->HasFanout({&graph_view, 4, 0}));
717   EXPECT_FALSE(d_node->HasFanout(
718       {&graph_view, b_node->node_index(), internal::kMissingSlot}));
719 }
720 
SimpleAttrTestGraph()721 GraphDef SimpleAttrTestGraph() {
722   return GDef({NDef("a", kNoOp, {}), NDef("b", kNoOp, {}, {{"attr", 1}}),
723                NDef("c", kNoOp, {}, {{"attr_1", "a"}, {"attr_2", 2.0f}})},
724               /*funcs=*/{});
725 }
726 
TYPED_TEST(TypedNodeViewTest,GetAttr)727 TYPED_TEST(TypedNodeViewTest, GetAttr) {
728   GraphDef graph = SimpleAttrTestGraph();
729 
730   Status s;
731   TypeParam graph_view(&graph, &s);
732   TF_ASSERT_OK(s);
733 
734   auto* c_node = graph_view.GetNode("c");
735   ASSERT_NE(c_node, nullptr);
736 
737   EXPECT_EQ(c_node->GetAttr("attr_1")->s(), "a");
738 }
739 
TYPED_TEST(TypedNodeViewTest,GetAttrs)740 TYPED_TEST(TypedNodeViewTest, GetAttrs) {
741   GraphDef graph = SimpleAttrTestGraph();
742 
743   Status s;
744   TypeParam graph_view(&graph, &s);
745   TF_ASSERT_OK(s);
746 
747   auto* c_node = graph_view.GetNode("c");
748   ASSERT_NE(c_node, nullptr);
749 
750   const auto& actual_attrs = c_node->GetAttrs();
751   EXPECT_EQ(actual_attrs.size(), 2);
752   const auto* attr_1 = actual_attrs.Find("attr_1");
753   EXPECT_NE(attr_1, nullptr);
754   EXPECT_EQ(attr_1->s(), "a");
755   const auto* attr_2 = actual_attrs.Find("attr_2");
756   EXPECT_NE(attr_2, nullptr);
757   EXPECT_EQ(attr_2->f(), 2.0f);
758 }
759 
TYPED_TEST(TypedNodeViewTest,NumAttrs)760 TYPED_TEST(TypedNodeViewTest, NumAttrs) {
761   GraphDef graph = SimpleAttrTestGraph();
762 
763   Status s;
764   TypeParam graph_view(&graph, &s);
765   TF_ASSERT_OK(s);
766 
767   auto* a_node = graph_view.GetNode("a");
768   ASSERT_NE(a_node, nullptr);
769   auto* b_node = graph_view.GetNode("b");
770   ASSERT_NE(b_node, nullptr);
771   auto* c_node = graph_view.GetNode("c");
772   ASSERT_NE(c_node, nullptr);
773 
774   EXPECT_EQ(a_node->NumAttrs(), 0);
775   EXPECT_EQ(b_node->NumAttrs(), 1);
776   EXPECT_EQ(c_node->NumAttrs(), 2);
777 }
778 
TYPED_TEST(TypedNodeViewTest,HasAttr)779 TYPED_TEST(TypedNodeViewTest, HasAttr) {
780   GraphDef graph = SimpleAttrTestGraph();
781 
782   Status s;
783   TypeParam graph_view(&graph, &s);
784   TF_ASSERT_OK(s);
785 
786   auto* c_node = graph_view.GetNode("c");
787   ASSERT_NE(c_node, nullptr);
788 
789   EXPECT_TRUE(c_node->HasAttr("attr_1"));
790   EXPECT_FALSE(c_node->HasAttr("attr"));
791 }
792 
793 class CompareGraphTest : public GrapplerTest {
794  public:
CompareGraphViewWithGraph(MutableGraphView * graph_view,const GraphDef & expected_graph)795   void CompareGraphViewWithGraph(MutableGraphView* graph_view,
796                                  const GraphDef& expected_graph) {
797     Status s;
798     GraphView expected_graph_view(&expected_graph, &s);
799     TF_ASSERT_OK(s);
800 
801     EXPECT_EQ(graph_view->NumNodes(), expected_graph_view.NumNodes());
802 
803     for (const NodeView& expected_node_view : expected_graph_view.GetNodes()) {
804       const string& node_name = expected_node_view.GetName();
805       MutableNodeView* node_view = graph_view->GetNode(node_name);
806       ASSERT_NE(node_view, nullptr);
807 
808       EXPECT_EQ(node_view->GetName(), expected_node_view.GetName());
809 
810       EXPECT_EQ(node_view->GetOp(), expected_node_view.GetOp());
811 
812       EXPECT_EQ(node_view->GetDevice(), expected_node_view.GetDevice());
813 
814       const int actual_num_fanins = node_view->node()->input_size();
815       EXPECT_EQ(actual_num_fanins, expected_node_view.node()->input_size());
816 
817       const int expected_num_regular_fanins =
818           expected_node_view.NumRegularFanins();
819       bool same_num_regular_fanins =
820           node_view->NumRegularFanins() == expected_num_regular_fanins;
821       EXPECT_TRUE(same_num_regular_fanins);
822       for (int i = 0; i < expected_num_regular_fanins; ++i) {
823         const auto& expected_fanin = expected_node_view.GetRegularFanin(i);
824 
825         auto* actual_fanin_node =
826             graph_view->GetNode(expected_fanin.node_view()->GetName());
827         ASSERT_NE(actual_fanin_node, nullptr);
828         EXPECT_TRUE(
829             node_view->HasFanin({actual_fanin_node, expected_fanin.index()}));
830         if (i < node_view->NumRegularFanins()) {
831           auto& actual_fanin = node_view->GetRegularFanin(i);
832           EXPECT_EQ(actual_fanin, MutableFanoutView(actual_fanin_node,
833                                                     expected_fanin.index()));
834           EXPECT_EQ(actual_fanin.node_index(),
835                     actual_fanin.node_view()->node_index());
836         }
837       }
838 
839       if (same_num_regular_fanins) {
840         for (int i = 0; i < expected_num_regular_fanins; ++i) {
841           const auto& fanin = node_view->GetRegularFanin(i);
842           EXPECT_EQ(ParseTensorName(node_view->node()->input(i)),
843                     TensorId(fanin.node_view()->GetName(), fanin.index()));
844         }
845       }
846 
847       const int expected_num_controlling_fanins =
848           expected_node_view.NumControllingFanins();
849       bool same_num_controlling_fanins =
850           node_view->NumControllingFanins() == expected_num_controlling_fanins;
851       EXPECT_TRUE(same_num_controlling_fanins);
852       for (int i = 0; i < expected_num_controlling_fanins; ++i) {
853         auto& expected_fanin = expected_node_view.GetControllingFanins()[i];
854 
855         auto* actual_fanin_node =
856             graph_view->GetNode(expected_fanin.node_view()->GetName());
857         ASSERT_NE(actual_fanin_node, nullptr);
858         MutableFanoutView actual_fanin(actual_fanin_node,
859                                        expected_fanin.index());
860         EXPECT_TRUE(node_view->HasFanin(actual_fanin));
861 
862         int found = 0;
863         for (const auto& actual_fanin : node_view->GetControllingFanins()) {
864           if (actual_fanin.index() == expected_fanin.index() &&
865               actual_fanin.node_view()->GetName() ==
866                   expected_fanin.node_view()->GetName()) {
867             EXPECT_EQ(actual_fanin.node_index(),
868                       actual_fanin.node_view()->node_index());
869             ++found;
870           }
871         }
872         EXPECT_EQ(found, 1);
873       }
874 
875       if (same_num_controlling_fanins && same_num_regular_fanins) {
876         for (int i = 0; i < expected_num_controlling_fanins; ++i) {
877           const auto& fanin = node_view->GetControllingFanins()[i];
878           EXPECT_EQ(ParseTensorName(node_view->node()->input(
879                         i + expected_num_regular_fanins)),
880                     TensorId(fanin.node_view()->GetName(), fanin.index()));
881         }
882       }
883 
884       EXPECT_EQ(node_view->NumRegularFanouts(),
885                 expected_node_view.NumRegularFanouts());
886       const int num_output_ports =
887           expected_node_view.GetRegularFanouts().size();
888       ASSERT_EQ(node_view->GetRegularFanouts().size(), num_output_ports);
889       for (int i = 0; i < num_output_ports; ++i) {
890         auto& expected_fanouts_at_port_i = node_view->GetRegularFanouts()[i];
891         const int num_fanouts_at_port = expected_fanouts_at_port_i.size();
892 
893         auto& actual_fanouts_at_port_i = node_view->GetRegularFanouts()[i];
894         EXPECT_EQ(actual_fanouts_at_port_i.size(), num_fanouts_at_port);
895 
896         for (int j = 0; j < num_fanouts_at_port; ++j) {
897           auto& expected_fanout = expected_fanouts_at_port_i[j];
898 
899           auto* actual_fanout_node =
900               graph_view->GetNode(expected_fanout.node_view()->GetName());
901 
902           ASSERT_NE(actual_fanout_node, nullptr);
903           MutableFaninView actual_fanout(actual_fanout_node,
904                                          expected_fanout.index());
905           EXPECT_TRUE(node_view->HasFanout(actual_fanout));
906 
907           int found = 0;
908           for (const auto& fanout : actual_fanouts_at_port_i) {
909             if (fanout.index() == expected_fanout.index() &&
910                 fanout.node_view()->GetName() ==
911                     expected_fanout.node_view()->GetName()) {
912               EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index());
913               ++found;
914             }
915           }
916           EXPECT_EQ(found, 1);
917         }
918       }
919 
920       const int num_controlled_fanouts =
921           expected_node_view.NumControlledFanouts();
922       EXPECT_EQ(node_view->NumControlledFanouts(), num_controlled_fanouts);
923       for (int i = 0; i < num_controlled_fanouts; ++i) {
924         const auto& expected_fanout =
925             expected_node_view.GetControlledFanouts()[i];
926 
927         auto* actual_fanout_node =
928             graph_view->GetNode(expected_fanout.node_view()->GetName());
929         ASSERT_NE(actual_fanout_node, nullptr);
930         MutableFaninView actual_fanout(actual_fanout_node,
931                                        expected_fanout.index());
932         EXPECT_TRUE(node_view->HasFanout(actual_fanout));
933 
934         int found = 0;
935         for (const auto& fanout : node_view->GetControlledFanouts()) {
936           if (fanout.index() == expected_fanout.index() &&
937               fanout.node_view()->GetName() ==
938                   expected_fanout.node_view()->GetName()) {
939             EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index());
940             ++found;
941           }
942         }
943         EXPECT_EQ(found, 1);
944       }
945 
946       EXPECT_EQ(node_view->NumAttrs(), expected_node_view.NumAttrs());
947       for (const auto& expected_attr : expected_node_view.GetAttrs()) {
948         auto* attr = node_view->GetAttr(expected_attr.first);
949         EXPECT_TRUE(AreAttrValuesEqual(*attr, expected_attr.second));
950       }
951     }
952     CompareGraphs(*graph_view->graph(), expected_graph);
953   }
954 };
955 
956 class MutationTest : public CompareGraphTest {};
957 
958 constexpr char kDeviceCPU0[] = "/device:CPU:0";
959 constexpr char kDeviceGPU0[] = "/device:GPU:0";
960 
SimpleTestGraphForMutation()961 GraphDef SimpleTestGraphForMutation() {
962   return GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
963                NDef("b", kNoOp, {}, {}, kDeviceCPU0),
964                NDef("c", kNoOp, {}, {}, kDeviceCPU0),
965                NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
966                     {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
967               /*funcs=*/{});
968 }
969 
TEST_F(MutationTest,AddNewNode)970 TEST_F(MutationTest, AddNewNode) {
971   GraphDef graph = SimpleTestGraphForMutation();
972 
973   Status s;
974   MutableGraphView graph_view(&graph, &s);
975   TF_ASSERT_OK(s);
976 
977   Mutation* mutation = graph_view.GetMutationBuilder();
978 
979   NodeDef empty_node;
980   mutation->AddNode(std::move(empty_node), &s);
981   TF_EXPECT_OK(s);
982   s = errors::Internal("error");
983 
984   NodeDef valid_node =
985       NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo");
986   mutation->AddNode(std::move(valid_node), &s);
987   TF_EXPECT_OK(s);
988 
989   NodeDef bad_node_1 =
990       NDef("bad", "IdentityN", {"^b", "a:1"}, {{"N", 1}}, "foo");
991   mutation->AddNode(std::move(bad_node_1), &s);
992   EXPECT_FALSE(s.ok());
993   EXPECT_EQ(s.error_message(),
994             "Mutation::AddNode error: node 'bad' has regular fanin 'a:1' after "
995             "controlling fanins.");
996 
997   NodeDef bad_node_2 = NDef("bad", "IdentityN", {"bad:1"}, {}, "foo");
998   mutation->AddNode(std::move(bad_node_2), &s);
999   EXPECT_FALSE(s.ok());
1000   EXPECT_EQ(s.error_message(),
1001             "Mutation::AddNode error: node 'bad' has self cycle fanin "
1002             "'bad:1'.");
1003 
1004   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1005 }
1006 
TEST_F(MutationTest,NewNodeBadFaninsAfterAdd)1007 TEST_F(MutationTest, NewNodeBadFaninsAfterAdd) {
1008   GraphDef graph = SimpleTestGraphForMutation();
1009 
1010   Status s;
1011   MutableGraphView graph_view(&graph, &s);
1012   TF_ASSERT_OK(s);
1013 
1014   Mutation* mutation = graph_view.GetMutationBuilder();
1015 
1016   NodeDef valid_node =
1017       NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo");
1018   MutationNewNode new_node = mutation->AddNode(std::move(valid_node), &s);
1019 
1020   mutation->AddOrUpdateRegularFanin(new_node, 1, {"valid", 2});
1021   s = mutation->Apply();
1022   EXPECT_FALSE(s.ok());
1023   string expected_error_msg =
1024       "Mutation::Apply error: new node 'valid' is ill-formed.";
1025   EXPECT_EQ(s.error_message(), expected_error_msg);
1026   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1027 }
1028 
TEST_F(MutationTest,NewNodesConflictingNames)1029 TEST_F(MutationTest, NewNodesConflictingNames) {
1030   GraphDef graph = SimpleTestGraphForMutation();
1031 
1032   Status s;
1033   MutableGraphView graph_view(&graph, &s);
1034   TF_ASSERT_OK(s);
1035 
1036   Mutation* mutation = graph_view.GetMutationBuilder();
1037 
1038   NodeDef new_node_1 = NDef("a", "", {});
1039   mutation->AddNode(std::move(new_node_1), &s);
1040   TF_EXPECT_OK(s);
1041 
1042   NodeDef new_node_2 = NDef("a", "", {});
1043   mutation->AddNode(std::move(new_node_2), &s);
1044   TF_EXPECT_OK(s);
1045 
1046   s = mutation->Apply();
1047   EXPECT_FALSE(s.ok());
1048   string expected_error_msg =
1049       "Mutation::Apply error: multiple nodes with the name: 'a' exists in "
1050       "Mutation.";
1051   EXPECT_EQ(s.error_message(), expected_error_msg);
1052   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1053 }
1054 
TEST_F(MutationTest,UpdateNodeAndAddSelfLoop)1055 TEST_F(MutationTest, UpdateNodeAndAddSelfLoop) {
1056   GraphDef graph = SimpleTestGraphForMutation();
1057 
1058   Status s;
1059   MutableGraphView graph_view(&graph, &s);
1060   TF_ASSERT_OK(s);
1061 
1062   Mutation* mutation = graph_view.GetMutationBuilder();
1063 
1064   MutableNodeView* d_node = graph_view.GetNode("d");
1065   ASSERT_NE(d_node, nullptr);
1066   mutation->AddControllingFanin(d_node, "d");
1067 
1068   s = mutation->Apply();
1069   EXPECT_FALSE(s.ok());
1070   string expected_error_msg =
1071       "Mutation::Apply error: inplace updated node 'd' is ill-formed.";
1072   EXPECT_EQ(s.error_message(), expected_error_msg);
1073   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1074 }
1075 
TEST_F(MutationTest,RenameNodeAndAddSelfLoop)1076 TEST_F(MutationTest, RenameNodeAndAddSelfLoop) {
1077   GraphDef graph = SimpleTestGraphForMutation();
1078 
1079   Status s;
1080   MutableGraphView graph_view(&graph, &s);
1081   TF_ASSERT_OK(s);
1082 
1083   Mutation* mutation = graph_view.GetMutationBuilder();
1084 
1085   MutableNodeView* d_node = graph_view.GetNode("d");
1086   ASSERT_NE(d_node, nullptr);
1087   mutation->UpdateNodeName(d_node, "e");
1088   mutation->AddControllingFanin(d_node, "e");
1089 
1090   s = mutation->Apply();
1091   EXPECT_FALSE(s.ok());
1092   string expected_error_msg =
1093       "Mutation::Apply error: renamed updated node 'e' ('d') is ill-formed.";
1094   EXPECT_EQ(s.error_message(), expected_error_msg);
1095   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1096 }
1097 
TEST_F(MutationTest,ExistingNodesConflictingNames)1098 TEST_F(MutationTest, ExistingNodesConflictingNames) {
1099   GraphDef graph = SimpleTestGraphForMutation();
1100 
1101   Status s;
1102   MutableGraphView graph_view(&graph, &s);
1103   TF_ASSERT_OK(s);
1104 
1105   Mutation* mutation = graph_view.GetMutationBuilder();
1106 
1107   MutableNodeView* a_node = graph_view.GetNode("a");
1108   ASSERT_NE(a_node, nullptr);
1109   mutation->UpdateNodeName(a_node, "b");
1110 
1111   MutableNodeView* b_node = graph_view.GetNode("b");
1112   ASSERT_NE(b_node, nullptr);
1113   mutation->UpdateNodeOp(b_node, "Identity");
1114 
1115   s = mutation->Apply();
1116   EXPECT_FALSE(s.ok());
1117   string expected_error_msg =
1118       "Mutation::Apply error: multiple nodes with the name: 'b' exists in "
1119       "Mutation.";
1120   EXPECT_EQ(s.error_message(), expected_error_msg);
1121   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1122 }
1123 
TEST_F(MutationTest,NewAndExistingNodesConflictingNames)1124 TEST_F(MutationTest, NewAndExistingNodesConflictingNames) {
1125   GraphDef graph = SimpleTestGraphForMutation();
1126 
1127   Status s;
1128   MutableGraphView graph_view(&graph, &s);
1129   TF_ASSERT_OK(s);
1130 
1131   Mutation* mutation = graph_view.GetMutationBuilder();
1132 
1133   NodeDef new_node = NDef("a", "", {});
1134   mutation->AddNode(std::move(new_node), &s);
1135   TF_EXPECT_OK(s);
1136 
1137   MutableNodeView* a_node = graph_view.GetNode("a");
1138   ASSERT_NE(a_node, nullptr);
1139   mutation->UpdateNodeDevice(a_node, "foo");
1140 
1141   s = mutation->Apply();
1142   EXPECT_FALSE(s.ok());
1143   string expected_error_msg =
1144       "Mutation::Apply error: multiple nodes with the name: 'a' exists in "
1145       "Mutation.";
1146   EXPECT_EQ(s.error_message(), expected_error_msg);
1147   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1148 }
1149 
TEST_F(MutationTest,NewAndExistingRenamedNodesConflictingNames)1150 TEST_F(MutationTest, NewAndExistingRenamedNodesConflictingNames) {
1151   GraphDef graph = SimpleTestGraphForMutation();
1152 
1153   Status s;
1154   MutableGraphView graph_view(&graph, &s);
1155   TF_ASSERT_OK(s);
1156 
1157   Mutation* mutation = graph_view.GetMutationBuilder();
1158 
1159   NodeDef new_node = NDef("e", "", {});
1160   mutation->AddNode(std::move(new_node), &s);
1161   TF_EXPECT_OK(s);
1162 
1163   MutableNodeView* d_node = graph_view.GetNode("d");
1164   ASSERT_NE(d_node, nullptr);
1165   mutation->UpdateNodeName(d_node, "e");
1166 
1167   s = mutation->Apply();
1168   EXPECT_FALSE(s.ok());
1169   string expected_error_msg =
1170       "Mutation::Apply error: multiple nodes with the name: 'e' exists in "
1171       "Mutation.";
1172   EXPECT_EQ(s.error_message(), expected_error_msg);
1173   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1174 }
1175 
TEST_F(MutationTest,RemoveNodesWithFanouts)1176 TEST_F(MutationTest, RemoveNodesWithFanouts) {
1177   GraphDef graph = SimpleTestGraphForMutation();
1178 
1179   Status s;
1180   MutableGraphView graph_view(&graph, &s);
1181   TF_ASSERT_OK(s);
1182 
1183   Mutation* mutation = graph_view.GetMutationBuilder();
1184 
1185   MutableNodeView* b_node = graph_view.GetNode("b");
1186   ASSERT_NE(b_node, nullptr);
1187   mutation->RemoveNode(b_node);
1188 
1189   s = mutation->Apply();
1190   EXPECT_FALSE(s.ok());
1191   string expected_error_msg =
1192       "Mutation::Apply error: fanout 'd' exist for missing node 'b'.";
1193   EXPECT_EQ(s.error_message(), expected_error_msg);
1194   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1195 
1196   MutableNodeView* d_node = graph_view.GetNode("d");
1197   ASSERT_NE(d_node, nullptr);
1198   mutation->RemoveNode(d_node);
1199 
1200   TF_EXPECT_OK(mutation->Apply());
1201   GraphDef expected_graph = GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
1202                                   NDef("c", kNoOp, {}, {}, kDeviceCPU0)},
1203                                  /*funcs=*/{});
1204   CompareGraphViewWithGraph(&graph_view, expected_graph);
1205 }
1206 
TEST_F(MutationTest,SwapNodeNamesWithCycle)1207 TEST_F(MutationTest, SwapNodeNamesWithCycle) {
1208   GraphDef graph = SimpleTestGraphForMutation();
1209 
1210   Status s;
1211   MutableGraphView graph_view(&graph, &s);
1212   TF_ASSERT_OK(s);
1213 
1214   Mutation* mutation = graph_view.GetMutationBuilder();
1215 
1216   MutableNodeView* d_node = graph_view.GetNode("d");
1217   ASSERT_NE(d_node, nullptr);
1218   mutation->UpdateNodeName(d_node, "b");
1219   MutableNodeView* b_node = graph_view.GetNode("b");
1220   ASSERT_NE(b_node, nullptr);
1221   mutation->UpdateNodeName(b_node, "d");
1222 
1223   s = mutation->Apply();
1224   EXPECT_FALSE(s.ok());
1225   string expected_error_msg =
1226       "Mutation::Apply error: renamed updated node 'b' ('d') is ill-formed.";
1227   EXPECT_EQ(s.error_message(), expected_error_msg);
1228   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1229 
1230   mutation->AddOrUpdateRegularFanin(d_node, 1, {"d", 3});
1231   mutation->RemoveControllingFanin(d_node, "b");
1232 
1233   TF_EXPECT_OK(mutation->Apply());
1234   GraphDef expected_graph =
1235       GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
1236             NDef("d", kNoOp, {}, {}, kDeviceCPU0),
1237             NDef("c", kNoOp, {}, {}, kDeviceCPU0),
1238             NDef("b", kNoOp, {"a:2", "d:3", "a:4", "^c"},
1239                  {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
1240            /*funcs=*/{});
1241   CompareGraphViewWithGraph(&graph_view, expected_graph);
1242 }
1243 
TEST_F(MutationTest,RenamedNodeWithFanouts)1244 TEST_F(MutationTest, RenamedNodeWithFanouts) {
1245   GraphDef graph = SimpleTestGraphForMutation();
1246 
1247   Status s;
1248   MutableGraphView graph_view(&graph, &s);
1249   TF_ASSERT_OK(s);
1250 
1251   Mutation* mutation = graph_view.GetMutationBuilder();
1252 
1253   MutableNodeView* a_node = graph_view.GetNode("a");
1254   ASSERT_NE(a_node, nullptr);
1255   mutation->UpdateNodeName(a_node, "b");
1256 
1257   s = mutation->Apply();
1258   EXPECT_FALSE(s.ok());
1259   string expected_error_msg =
1260       "Mutation::Apply error: fanout 'd' exist for missing node 'a'.";
1261   EXPECT_EQ(s.error_message(), expected_error_msg);
1262   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1263 
1264   mutation->UpdateNodeName(a_node, "a");
1265 
1266   MutableNodeView* b_node = graph_view.GetNode("b");
1267   ASSERT_NE(b_node, nullptr);
1268   mutation->UpdateNodeName(b_node, "e");
1269 
1270   s = mutation->Apply();
1271   EXPECT_FALSE(s.ok());
1272   expected_error_msg =
1273       "Mutation::Apply error: fanout 'd' exist for missing "
1274       "node 'b'.";
1275   EXPECT_EQ(s.error_message(), expected_error_msg);
1276   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1277 }
1278 
TEST_F(MutationTest,RemoveExistingNodeAndReplaceWithNewNode)1279 TEST_F(MutationTest, RemoveExistingNodeAndReplaceWithNewNode) {
1280   GraphDef graph = SimpleTestGraphForMutation();
1281 
1282   Status s;
1283   MutableGraphView graph_view(&graph, &s);
1284   TF_ASSERT_OK(s);
1285 
1286   Mutation* mutation = graph_view.GetMutationBuilder();
1287 
1288   MutableNodeView* d_node = graph_view.GetNode("d");
1289   ASSERT_NE(d_node, nullptr);
1290   mutation->RemoveNode(d_node);
1291 
1292   NodeDef new_node = NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0);
1293   mutation->AddNode(std::move(new_node), &s);
1294   TF_EXPECT_OK(s);
1295 
1296   TF_EXPECT_OK(mutation->Apply());
1297   GraphDef expected_graph =
1298       GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
1299             NDef("b", kNoOp, {}, {}, kDeviceCPU0),
1300             NDef("c", kNoOp, {}, {}, kDeviceCPU0),
1301             NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0)},
1302            /*funcs=*/{});
1303   CompareGraphViewWithGraph(&graph_view, expected_graph);
1304 }
1305 
TEST_F(MutationTest,UpdateNodeNameAndRemoveFanins)1306 TEST_F(MutationTest, UpdateNodeNameAndRemoveFanins) {
1307   GraphDef graph = SimpleTestGraphForMutation();
1308 
1309   Status s;
1310   MutableGraphView graph_view(&graph, &s);
1311   TF_ASSERT_OK(s);
1312 
1313   Mutation* mutation = graph_view.GetMutationBuilder();
1314 
1315   MutableNodeView* d_node = graph_view.GetNode("d");
1316   ASSERT_NE(d_node, nullptr);
1317   mutation->UpdateNodeName(d_node, "e");
1318   mutation->RemoveRegularFanin(d_node, 1);
1319   mutation->RemoveRegularFanin(d_node, 2);
1320 
1321   TF_EXPECT_OK(mutation->Apply());
1322   GraphDef expected_graph =
1323       GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
1324             NDef("b", kNoOp, {}, {}, kDeviceCPU0),
1325             NDef("c", kNoOp, {}, {}, kDeviceCPU0),
1326             NDef("e", kNoOp, {"a:2", "^c", "^b"},
1327                  {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
1328            /*funcs=*/{});
1329   CompareGraphViewWithGraph(&graph_view, expected_graph);
1330 }
1331 
TEST_F(MutationTest,UpdateNodeNameAndRemoveRegularFanout)1332 TEST_F(MutationTest, UpdateNodeNameAndRemoveRegularFanout) {
1333   GraphDef graph = SimpleTestGraphForMutation();
1334 
1335   Status s;
1336   MutableGraphView graph_view(&graph, &s);
1337   TF_ASSERT_OK(s);
1338 
1339   Mutation* mutation = graph_view.GetMutationBuilder();
1340 
1341   MutableNodeView* a_node = graph_view.GetNode("a");
1342   ASSERT_NE(a_node, nullptr);
1343   mutation->UpdateNodeName(a_node, "e");
1344 
1345   s = mutation->Apply();
1346   EXPECT_FALSE(s.ok());
1347   string expected_error_msg =
1348       "Mutation::Apply error: fanout 'd' exist for missing node 'a'.";
1349   EXPECT_EQ(s.error_message(), expected_error_msg);
1350   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1351 
1352   MutableNodeView* d_node = graph_view.GetNode("d");
1353   ASSERT_NE(d_node, nullptr);
1354   mutation->RemoveRegularFanin(d_node, 2);
1355 
1356   s = mutation->Apply();
1357   EXPECT_FALSE(s.ok());
1358   expected_error_msg =
1359       "Mutation::Apply error: fanout 'd' exist for missing node 'a'.";
1360   EXPECT_EQ(s.error_message(), expected_error_msg);
1361   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1362 
1363   mutation->AddOrUpdateRegularFanin(d_node, 0, {"b", 1});
1364 
1365   TF_EXPECT_OK(mutation->Apply());
1366   GraphDef expected_graph =
1367       GDef({NDef("e", kNoOp, {}, {}, kDeviceCPU0),
1368             NDef("b", kNoOp, {}, {}, kDeviceCPU0),
1369             NDef("c", kNoOp, {}, {}, kDeviceCPU0),
1370             NDef("d", kNoOp, {"b:1", "b:3", "^c", "^b"},
1371                  {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
1372            /*funcs=*/{});
1373   CompareGraphViewWithGraph(&graph_view, expected_graph);
1374 }
1375 
TEST_F(MutationTest,UpdateNodeNameAndRemoveControlledFanout)1376 TEST_F(MutationTest, UpdateNodeNameAndRemoveControlledFanout) {
1377   GraphDef graph = SimpleTestGraphForMutation();
1378 
1379   Status s;
1380   MutableGraphView graph_view(&graph, &s);
1381   TF_ASSERT_OK(s);
1382 
1383   Mutation* mutation = graph_view.GetMutationBuilder();
1384 
1385   MutableNodeView* c_node = graph_view.GetNode("c");
1386   ASSERT_NE(c_node, nullptr);
1387   mutation->UpdateNodeName(c_node, "e");
1388 
1389   s = mutation->Apply();
1390   EXPECT_FALSE(s.ok());
1391   string expected_error_msg =
1392       "Mutation::Apply error: fanout 'd' exist for missing node 'c'.";
1393   EXPECT_EQ(s.error_message(), expected_error_msg);
1394   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1395 
1396   MutableNodeView* d_node = graph_view.GetNode("d");
1397   ASSERT_NE(d_node, nullptr);
1398   mutation->UpdateNodeDevice(d_node, kDeviceGPU0);
1399 
1400   s = mutation->Apply();
1401   EXPECT_FALSE(s.ok());
1402   expected_error_msg =
1403       "Mutation::Apply error: fanout 'd' exist for missing node 'c'.";
1404   EXPECT_EQ(s.error_message(), expected_error_msg);
1405   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1406 
1407   mutation->RemoveControllingFanin(d_node, "c");
1408 
1409   TF_EXPECT_OK(mutation->Apply());
1410   GraphDef expected_graph =
1411       GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
1412             NDef("b", kNoOp, {}, {}, kDeviceCPU0),
1413             NDef("e", kNoOp, {}, {}, kDeviceCPU0),
1414             NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^b"},
1415                  {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceGPU0)},
1416            /*funcs=*/{});
1417   CompareGraphViewWithGraph(&graph_view, expected_graph);
1418 }
1419 
TEST_F(MutationTest,EmptyMutation)1420 TEST_F(MutationTest, EmptyMutation) {
1421   GraphDef graph = SimpleTestGraphForMutation();
1422 
1423   Status s;
1424   MutableGraphView graph_view(&graph, &s);
1425   TF_ASSERT_OK(s);
1426 
1427   Mutation* mutation = graph_view.GetMutationBuilder();
1428 
1429   TF_EXPECT_OK(mutation->Apply());
1430   CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
1431 }
1432 
1433 constexpr char kIdentity[] = "Identity";
1434 constexpr char kDeviceCPU1[] = "/device:CPU:1";
1435 constexpr char kDeviceGPU1[] = "/device:GPU:1";
1436 
TestGraphForMutation()1437 GraphDef TestGraphForMutation() {
1438   return GDef(
1439       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1440        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1441        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1442        NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1443             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)},
1444       /*funcs=*/{});
1445 }
1446 
TEST_F(MutationTest,SwapNodeNamesWithNoCycle)1447 TEST_F(MutationTest, SwapNodeNamesWithNoCycle) {
1448   GraphDef graph = TestGraphForMutation();
1449 
1450   Status s;
1451   MutableGraphView graph_view(&graph, &s);
1452   TF_ASSERT_OK(s);
1453 
1454   Mutation* mutation = graph_view.GetMutationBuilder();
1455 
1456   MutableNodeView* b_node = graph_view.GetNode("b");
1457   ASSERT_NE(b_node, nullptr);
1458   MutableNodeView* c_node = graph_view.GetNode("c");
1459   ASSERT_NE(c_node, nullptr);
1460 
1461   mutation->UpdateNodeName(b_node, "c");
1462   mutation->UpdateNodeName(c_node, "b");
1463 
1464   TF_EXPECT_OK(mutation->Apply());
1465   GraphDef expected_graph = GDef(
1466       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1467        NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1468        NDef("b", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1469        NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1470             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)},
1471       /*funcs=*/{});
1472   CompareGraphViewWithGraph(&graph_view, expected_graph);
1473 }
1474 
TEST_F(MutationTest,RemoveMultipleDependentNodes)1475 TEST_F(MutationTest, RemoveMultipleDependentNodes) {
1476   GraphDef graph = TestGraphForMutation();
1477 
1478   Status s;
1479   MutableGraphView graph_view(&graph, &s);
1480   TF_ASSERT_OK(s);
1481 
1482   Mutation* mutation = graph_view.GetMutationBuilder();
1483 
1484   MutableNodeView* c_node = graph_view.GetNode("c");
1485   ASSERT_NE(c_node, nullptr);
1486   MutableNodeView* d_node = graph_view.GetNode("d");
1487   ASSERT_NE(d_node, nullptr);
1488 
1489   mutation->RemoveNode(c_node);
1490   mutation->RemoveNode(d_node);
1491 
1492   TF_EXPECT_OK(mutation->Apply());
1493   GraphDef expected_graph = GDef(
1494       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1495        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0)},
1496       /*funcs=*/{});
1497   CompareGraphViewWithGraph(&graph_view, expected_graph);
1498 }
1499 
1500 constexpr char kDeviceGPU2[] = "/device:GPU:2";
1501 
TEST_F(MutationTest,AddSimpleNewNode)1502 TEST_F(MutationTest, AddSimpleNewNode) {
1503   GraphDef graph = TestGraphForMutation();
1504 
1505   Status s;
1506   MutableGraphView graph_view(&graph, &s);
1507   TF_ASSERT_OK(s);
1508 
1509   Mutation* mutation = graph_view.GetMutationBuilder();
1510 
1511   NodeDef new_node =
1512       NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2);
1513   mutation->AddNode(std::move(new_node), &s);
1514   TF_EXPECT_OK(s);
1515 
1516   TF_EXPECT_OK(mutation->Apply());
1517   GraphDef expected_graph = GDef(
1518       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1519        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1520        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1521        NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1522             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1523        NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2)},
1524       /*funcs=*/{});
1525   CompareGraphViewWithGraph(&graph_view, expected_graph);
1526 }
1527 
1528 constexpr char kDeviceGPU3[] = "/device:GPU:3";
1529 
TEST_F(MutationTest,AddAndUpdateNodesWithFanins)1530 TEST_F(MutationTest, AddAndUpdateNodesWithFanins) {
1531   GraphDef graph = TestGraphForMutation();
1532 
1533   Status s;
1534   MutableGraphView graph_view(&graph, &s);
1535   TF_ASSERT_OK(s);
1536 
1537   Mutation* mutation = graph_view.GetMutationBuilder();
1538 
1539   NodeDef new_node_1 = NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"},
1540                             {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2);
1541   mutation->AddNode(std::move(new_node_1), &s);
1542   TF_EXPECT_OK(s);
1543 
1544   NodeDef new_node_2 =
1545       NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"},
1546            {{"new_node_2_attr_1", 9}}, kDeviceGPU3);
1547   mutation->AddNode(std::move(new_node_2), &s);
1548   TF_EXPECT_OK(s);
1549 
1550   MutableNodeView* d_node = graph_view.GetNode("d");
1551   ASSERT_NE(d_node, nullptr);
1552   mutation->AddOrUpdateRegularFanin(d_node, 3, {"c", 6});
1553   mutation->AddOrUpdateRegularFanin(d_node, 1, {"new_node_1", 5});
1554   mutation->AddControllingFanin(d_node, "new_node_2");
1555 
1556   TF_EXPECT_OK(mutation->Apply());
1557   GraphDef expected_graph = GDef(
1558       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1559        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1560        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1561        NDef("d", kNoOp,
1562             {"a:2", "new_node_1:5", "a:4", "c:6", "^c", "^b", "^new_node_2"},
1563             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1564        NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"},
1565             {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2),
1566        NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"},
1567             {{"new_node_2_attr_1", 9}}, kDeviceGPU3)},
1568       /*funcs=*/{});
1569   CompareGraphViewWithGraph(&graph_view, expected_graph);
1570 }
1571 
TEST_F(MutationTest,UpdateNodeNameToReplaceExistingNode)1572 TEST_F(MutationTest, UpdateNodeNameToReplaceExistingNode) {
1573   auto test_graph = []() {
1574     return GDef(
1575         {NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0),
1576          NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1),
1577          NDef("c", kNoOp, {"b:4", "^a"}, {{"attr_c", "test"}}, kDeviceGPU2),
1578          NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"},
1579               {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)},
1580         /*funcs=*/{});
1581   };
1582 
1583   GraphDef graph = test_graph();
1584 
1585   Status s;
1586   MutableGraphView graph_view(&graph, &s);
1587   TF_ASSERT_OK(s);
1588 
1589   Mutation* mutation = graph_view.GetMutationBuilder();
1590 
1591   MutableNodeView* b_node = graph_view.GetNode("b");
1592   ASSERT_NE(b_node, nullptr);
1593 
1594   mutation->UpdateNodeName(b_node, "c");
1595 
1596   TF_EXPECT_OK(mutation->Apply());
1597   GraphDef expected_graph =
1598       GDef({NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0),
1599             NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1),
1600             NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"},
1601                  {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)},
1602            /*funcs=*/{});
1603   CompareGraphViewWithGraph(&graph_view, expected_graph);
1604 }
1605 
TEST_F(MutationTest,NewNodeWithMutations)1606 TEST_F(MutationTest, NewNodeWithMutations) {
1607   GraphDef graph = TestGraphForMutation();
1608 
1609   Status s;
1610   MutableGraphView graph_view(&graph, &s);
1611   TF_ASSERT_OK(s);
1612 
1613   Mutation* mutation = graph_view.GetMutationBuilder();
1614 
1615   NodeDef new_node_def = NDef("node", kNoOp, {"a:2", "b:3", "^c"},
1616                               {{"attr_1", 1}, {"attr_2", 2.0f}}, kDeviceGPU3);
1617   MutationNewNode new_node = mutation->AddNode(std::move(new_node_def), &s);
1618   TF_EXPECT_OK(s);
1619 
1620   mutation->AddControllingFanin(new_node, "a");
1621   mutation->RemoveControllingFanin(new_node, "c");
1622   mutation->AddOrUpdateRegularFanin(new_node, 0, {"b", 6});
1623   mutation->RemoveRegularFanin(new_node, 1);
1624   mutation->UpdateNodeName(new_node, "new_node");
1625   mutation->UpdateNodeOp(new_node, kIdentity);
1626   mutation->UpdateNodeDevice(new_node, kDeviceGPU2);
1627   AttrValue attr_3;
1628   attr_3.set_s("new_node_attr");
1629   mutation->AddOrUpdateNodeAttr(new_node, "attr_3", attr_3);
1630   AttrValue attr_1;
1631   attr_1.set_b(true);
1632   mutation->AddOrUpdateNodeAttr(new_node, "attr_1", attr_1);
1633   mutation->RemoveNodeAttr(new_node, "attr_2");
1634   AttrValue attr_4;
1635   attr_4.set_type(DT_FLOAT);
1636   mutation->AddOrUpdateNodeAttr(new_node, "T", attr_4);
1637 
1638   TF_EXPECT_OK(mutation->Apply());
1639   GraphDef expected_graph = GDef(
1640       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1641        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1642        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1643        NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1644             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1645        NDef("new_node", kIdentity, {"b:6", "^a"},
1646             {{"attr_1", true}, {"attr_3", "new_node_attr"}, {"T", DT_FLOAT}},
1647             kDeviceGPU2)},
1648       /*funcs=*/{});
1649   CompareGraphViewWithGraph(&graph_view, expected_graph);
1650 }
1651 
TEST_F(MutationTest,UpdatedNodeWithNonFaninMutations)1652 TEST_F(MutationTest, UpdatedNodeWithNonFaninMutations) {
1653   GraphDef graph = TestGraphForMutation();
1654 
1655   Status s;
1656   MutableGraphView graph_view(&graph, &s);
1657   TF_ASSERT_OK(s);
1658 
1659   MutableNodeView* d_node = graph_view.GetNode("d");
1660   ASSERT_NE(d_node, nullptr);
1661 
1662   Mutation* mutation = graph_view.GetMutationBuilder();
1663 
1664   mutation->UpdateNodeName(d_node, "e");
1665   mutation->UpdateNodeOp(d_node, kIdentity);
1666   mutation->UpdateNodeDevice(d_node, kDeviceGPU2);
1667   AttrValue attr_d_1;
1668   attr_d_1.set_b(false);
1669   mutation->AddOrUpdateNodeAttr(d_node, "attr_d_1", attr_d_1);
1670   AttrValue attr_e_3;
1671   attr_e_3.set_s("test_string");
1672   mutation->AddOrUpdateNodeAttr(d_node, "attr_e_3", attr_e_3);
1673   mutation->RemoveNodeAttr(d_node, "attr_d_2");
1674   AttrValue attr_e_4;
1675   attr_e_4.set_type(DT_INT64);
1676   mutation->AddOrUpdateNodeAttr(d_node, "T", attr_e_4);
1677 
1678   TF_EXPECT_OK(mutation->Apply());
1679   GraphDef expected_graph = GDef(
1680       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1681        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1682        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1683        NDef("e", kIdentity, {"a:2", "b:3", "a:4", "^c", "^b"},
1684             {{"attr_d_1", false}, {"attr_e_3", "test_string"}, {"T", DT_INT64}},
1685             kDeviceGPU2)},
1686       /*funcs=*/{});
1687   CompareGraphViewWithGraph(&graph_view, expected_graph);
1688 }
1689 
TEST_F(MutationTest,Reset)1690 TEST_F(MutationTest, Reset) {
1691   GraphDef graph = TestGraphForMutation();
1692 
1693   Status s;
1694   MutableGraphView graph_view(&graph, &s);
1695   TF_ASSERT_OK(s);
1696 
1697   MutableNodeView* a_node = graph_view.GetNode("a");
1698   ASSERT_NE(a_node, nullptr);
1699 
1700   Mutation* mutation = graph_view.GetMutationBuilder();
1701 
1702   mutation->UpdateNodeName(a_node, "e");
1703   mutation->AddNode({}, &s);
1704   TF_EXPECT_OK(s);
1705 
1706   s = mutation->Apply();
1707   EXPECT_FALSE(s.ok());
1708   string expected_error_msg =
1709       "Mutation::Apply error: fanout 'b' exist for missing node 'a'.";
1710   EXPECT_EQ(s.error_message(), expected_error_msg);
1711   CompareGraphViewWithGraph(&graph_view, TestGraphForMutation());
1712 
1713   mutation->Reset();
1714   TF_EXPECT_OK(mutation->Apply());
1715   CompareGraphViewWithGraph(&graph_view, TestGraphForMutation());
1716 }
1717 
TEST_F(MutationTest,RenameNodeAndAddNewNodeWithRenamedNodeOldName)1718 TEST_F(MutationTest, RenameNodeAndAddNewNodeWithRenamedNodeOldName) {
1719   GraphDef graph = TestGraphForMutation();
1720 
1721   Status s;
1722   MutableGraphView graph_view(&graph, &s);
1723   TF_ASSERT_OK(s);
1724 
1725   MutableNodeView* b_node = graph_view.GetNode("b");
1726   ASSERT_NE(b_node, nullptr);
1727 
1728   Mutation* mutation = graph_view.GetMutationBuilder();
1729 
1730   mutation->UpdateNodeName(b_node, "e");
1731 
1732   NodeDef new_node =
1733       NDef("b", kIdentity, {"c:2"}, {{"T", DT_INT64}}, kDeviceGPU3);
1734   mutation->AddNode(std::move(new_node), &s);
1735   TF_EXPECT_OK(s);
1736 
1737   TF_EXPECT_OK(mutation->Apply());
1738   GraphDef expected_graph = GDef(
1739       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1740        NDef("e", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1741        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1742        NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1743             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1744        NDef("b", kIdentity, {"c:2"}, {{"T", DT_INT64}}, kDeviceGPU3)},
1745       /*funcs=*/{});
1746   CompareGraphViewWithGraph(&graph_view, expected_graph);
1747 }
1748 
TEST_F(MutationTest,ShiftNodesWithFanouts)1749 TEST_F(MutationTest, ShiftNodesWithFanouts) {
1750   auto test_graph = []() {
1751     return GDef({NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^c", "^b"},
1752                       {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1753                  NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1754                  NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1755                  NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}},
1756                       kDeviceGPU0)},
1757                 /*funcs=*/{});
1758   };
1759 
1760   GraphDef graph = test_graph();
1761 
1762   Status s;
1763   MutableGraphView graph_view(&graph, &s);
1764   TF_ASSERT_OK(s);
1765 
1766   MutableNodeView* c_node = graph_view.GetNode("c");
1767   ASSERT_NE(c_node, nullptr);
1768   MutableNodeView* d_node = graph_view.GetNode("d");
1769   ASSERT_NE(d_node, nullptr);
1770 
1771   Mutation* mutation = graph_view.GetMutationBuilder();
1772 
1773   mutation->RemoveControllingFanin(d_node, "c");
1774   mutation->RemoveNode(c_node);
1775 
1776   TF_EXPECT_OK(mutation->Apply());
1777   GraphDef expected_graph = GDef(
1778       {NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^b"},
1779             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1780        NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1781        NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0)},
1782       /*funcs=*/{});
1783   CompareGraphViewWithGraph(&graph_view, expected_graph);
1784 }
1785 
TEST_F(MutationTest,RemoveFaninFanoutAndShiftFanout)1786 TEST_F(MutationTest, RemoveFaninFanoutAndShiftFanout) {
1787   auto test_graph = []() {
1788     return GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0),
1789                  NDef("b", kNoOp, {"a:2", "a:1"}, {}, kDeviceGPU1),
1790                  NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)},
1791                 /*funcs=*/{});
1792   };
1793 
1794   GraphDef graph = test_graph();
1795 
1796   Status s;
1797   MutableGraphView graph_view(&graph, &s);
1798   TF_ASSERT_OK(s);
1799 
1800   MutableNodeView* b_node = graph_view.GetNode("b");
1801   ASSERT_NE(b_node, nullptr);
1802 
1803   Mutation* mutation = graph_view.GetMutationBuilder();
1804 
1805   mutation->RemoveRegularFanin(b_node, 1);
1806 
1807   TF_EXPECT_OK(mutation->Apply());
1808   GraphDef expected_graph =
1809       GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0),
1810             NDef("b", kNoOp, {"a:2"}, {}, kDeviceGPU1),
1811             NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)},
1812            /*funcs=*/{});
1813   CompareGraphViewWithGraph(&graph_view, expected_graph);
1814 }
1815 
TEST_F(MutationTest,ConsecutiveMutations)1816 TEST_F(MutationTest, ConsecutiveMutations) {
1817   GraphDef graph = TestGraphForMutation();
1818 
1819   Status s;
1820   MutableGraphView graph_view(&graph, &s);
1821   TF_ASSERT_OK(s);
1822 
1823   MutableNodeView* b_node = graph_view.GetNode("b");
1824   ASSERT_NE(b_node, nullptr);
1825   MutableNodeView* d_node = graph_view.GetNode("d");
1826   ASSERT_NE(d_node, nullptr);
1827 
1828   Mutation* mutation = graph_view.GetMutationBuilder();
1829 
1830   mutation->RemoveNode(b_node);
1831   mutation->AddOrUpdateRegularFanin(d_node, 1, {"c", 5});
1832   mutation->RemoveControllingFanin(d_node, "b");
1833 
1834   NodeDef new_node_1 = NDef("new_node_1", kIdentity, {"a:3", "d:5", "^d"},
1835                             {{"T", DT_FLOAT}}, kDeviceGPU2);
1836   MutationNewNode new_node_1_node =
1837       mutation->AddNode(std::move(new_node_1), &s);
1838   TF_EXPECT_OK(s);
1839 
1840   mutation->AddOrUpdateRegularFanin(new_node_1_node, 0, {"c", 5});
1841   mutation->RemoveRegularFanin(new_node_1_node, 1);
1842   mutation->AddOrUpdateRegularFanin(new_node_1_node, 1, {"a", 6});
1843   mutation->AddControllingFanin(new_node_1_node, "a");
1844   mutation->RemoveControllingFanin(new_node_1_node, "d");
1845 
1846   TF_EXPECT_OK(mutation->Apply());
1847   GraphDef expected_graph = GDef(
1848       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1849        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1850        NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^c"},
1851             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1852        NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}},
1853             kDeviceGPU2)},
1854       /*funcs=*/{});
1855   CompareGraphViewWithGraph(&graph_view, expected_graph);
1856 
1857   d_node = graph_view.GetNode("d");
1858   ASSERT_NE(d_node, nullptr);
1859 
1860   mutation->AddOrUpdateRegularFanin(d_node, 3, {"new_node_2", 6});
1861   mutation->AddOrUpdateRegularFanin(d_node, 1, {"new_node_1", 8});
1862   mutation->AddControllingFanin(d_node, "new_node_2");
1863   mutation->AddControllingFanin(d_node, "a");
1864   mutation->RemoveControllingFanin(d_node, "c");
1865 
1866   NodeDef new_node_2 =
1867       NDef("new_node_2", kNoOp, {"c:4", "new_node_1:5", "^d", "^c"});
1868   MutationNewNode new_node_2_node =
1869       mutation->AddNode(std::move(new_node_2), &s);
1870   TF_EXPECT_OK(s);
1871 
1872   mutation->UpdateNodeDevice(new_node_2_node, kDeviceGPU3);
1873   mutation->AddOrUpdateRegularFanin(new_node_2_node, 0, {"new_node_1", 4});
1874   mutation->RemoveRegularFanin(new_node_2_node, 1);
1875   mutation->RemoveControllingFanin(new_node_2_node, "c");
1876   mutation->AddControllingFanin(new_node_2_node, "a");
1877   mutation->AddControllingFanin(new_node_2_node, "new_node_1");
1878 
1879   TF_EXPECT_OK(mutation->Apply());
1880   expected_graph = GDef(
1881       {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1882        NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1883        NDef("d", kNoOp,
1884             {"a:2", "new_node_1:8", "a:4", "new_node_2:6", "^new_node_2", "^a"},
1885             {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1886        NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}},
1887             kDeviceGPU2),
1888        NDef("new_node_2", kNoOp, {"new_node_1:4", "^d", "^a", "^new_node_1"},
1889             {}, kDeviceGPU3)},
1890       /*funcs=*/{});
1891   CompareGraphViewWithGraph(&graph_view, expected_graph);
1892 }
1893 
1894 constexpr char kMatchingFiles[] = "MatchingFiles";
1895 
TEST_F(MutationTest,OpWithUnsupportedDevice)1896 TEST_F(MutationTest, OpWithUnsupportedDevice) {
1897   GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
1898   auto test_graph = []() {
1899     return GDef({NDef("a", kMatchingFiles, {}, {}, kDeviceCPU0)},
1900                 /*funcs=*/{});
1901   };
1902 
1903   GraphDef graph = test_graph();
1904 
1905   Status s;
1906   MutableGraphView graph_view(&graph, &s);
1907   TF_ASSERT_OK(s);
1908 
1909   MutableNodeView* a_node = graph_view.GetNode("a");
1910   ASSERT_NE(a_node, nullptr);
1911 
1912   Mutation* mutation = graph_view.GetMutationBuilder();
1913 
1914   // Unsupported device.
1915   mutation->UpdateNodeDevice(a_node, kDeviceGPU1);
1916 
1917   s = mutation->Apply();
1918   EXPECT_FALSE(s.ok());
1919   CompareGraphViewWithGraph(&graph_view, test_graph());
1920 
1921   mutation->Reset();
1922 
1923   // New node with unsupported device.
1924   NodeDef new_node = NDef("new_node", kMatchingFiles, {}, {}, kDeviceGPU2);
1925   mutation->AddNode(std::move(new_node), &s);
1926   TF_EXPECT_OK(s);
1927 
1928   s = mutation->Apply();
1929   EXPECT_FALSE(s.ok());
1930   CompareGraphViewWithGraph(&graph_view, test_graph());
1931 }
1932 
TEST_F(MutationTest,OpMissingAttribute)1933 TEST_F(MutationTest, OpMissingAttribute) {
1934   GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
1935   auto test_graph = []() {
1936     return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)},
1937                 /*funcs=*/{});
1938   };
1939 
1940   GraphDef graph = test_graph();
1941 
1942   Status s;
1943   MutableGraphView graph_view(&graph, &s);
1944   TF_ASSERT_OK(s);
1945 
1946   MutableNodeView* a_node = graph_view.GetNode("a");
1947   ASSERT_NE(a_node, nullptr);
1948 
1949   Mutation* mutation = graph_view.GetMutationBuilder();
1950 
1951   // Remove necessary attribute.
1952   mutation->RemoveNodeAttr(a_node, "T");
1953 
1954   s = mutation->Apply();
1955   EXPECT_FALSE(s.ok());
1956   CompareGraphViewWithGraph(&graph_view, test_graph());
1957 
1958   mutation->Reset();
1959 
1960   // New node without necessary attribute.
1961   NodeDef new_node = NDef("new_node", kIdentity, {}, {}, kDeviceGPU2);
1962   mutation->AddNode(std::move(new_node), &s);
1963   TF_EXPECT_OK(s);
1964 
1965   s = mutation->Apply();
1966   EXPECT_FALSE(s.ok());
1967   CompareGraphViewWithGraph(&graph_view, test_graph());
1968 }
1969 
TEST_F(MutationTest,EmptyMutationUpdateIndexPersisting)1970 TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) {
1971   auto test_graph = []() {
1972     return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)},
1973                 /*funcs=*/{});
1974   };
1975 
1976   GraphDef graph = test_graph();
1977 
1978   Status s;
1979   MutableGraphView graph_view(&graph, &s);
1980   TF_ASSERT_OK(s);
1981 
1982   MutableNodeView* a_node = graph_view.GetNode("a");
1983   ASSERT_NE(a_node, nullptr);
1984 
1985   Mutation* mutation = graph_view.GetMutationBuilder();
1986 
1987   // Empty MutableNodeViewDiff.
1988   mutation->UpdateNodeName(a_node, "a");
1989 
1990   TF_EXPECT_OK(mutation->Apply());
1991   CompareGraphViewWithGraph(&graph_view, test_graph());
1992 
1993   mutation->Reset();
1994 
1995   // Empty MutableNodeViewDiff, `update_index_` should not persist.
1996   mutation->UpdateNodeName(a_node, "a");
1997 
1998   TF_EXPECT_OK(mutation->Apply());
1999   CompareGraphViewWithGraph(&graph_view, test_graph());
2000 }
2001 
2002 class TopologicalSortTest : public CompareGraphTest {
2003  protected:
CompareGraphOrder(const MutableGraphView & graph_view,absl::Span<const string> node_names)2004   void CompareGraphOrder(const MutableGraphView& graph_view,
2005                          absl::Span<const string> node_names) {
2006     const int num_nodes = graph_view.NumNodes();
2007     ASSERT_EQ(num_nodes, node_names.size());
2008     for (int i = 0; i < num_nodes; ++i) {
2009       EXPECT_EQ(graph_view.GetNode(i)->GetName(), node_names[i]);
2010     }
2011   }
2012 
CompareGraphNodePrecedences(const MutableGraphView & graph_view,absl::Span<const std::pair<string,string>> node_precedences)2013   void CompareGraphNodePrecedences(
2014       const MutableGraphView& graph_view,
2015       absl::Span<const std::pair<string, string>> node_precedences) {
2016     for (const auto& node_precedence : node_precedences) {
2017       auto* parent_node = graph_view.GetNode(node_precedence.first);
2018       ASSERT_NE(parent_node, nullptr);
2019       auto* child_node = graph_view.GetNode(node_precedence.second);
2020       ASSERT_NE(child_node, nullptr);
2021       EXPECT_TRUE(parent_node->node_index() < child_node->node_index());
2022     }
2023   }
2024 };
2025 
TEST_F(TopologicalSortTest,ActiveMutationSort)2026 TEST_F(TopologicalSortTest, ActiveMutationSort) {
2027   auto test_graph = []() {
2028     return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
2029                  NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
2030                 /*funcs=*/{});
2031   };
2032 
2033   GraphDef graph = test_graph();
2034   Status status;
2035   MutableGraphView graph_view(&graph, &status);
2036   TF_ASSERT_OK(status);
2037 
2038   Mutation* mutation = graph_view.GetMutationBuilder();
2039   mutation->AddNode({}, &status);
2040   TF_ASSERT_OK(status);
2041 
2042   for (bool ignore_cycles : {false, true}) {
2043     status = graph_view.SortTopologically(ignore_cycles, {});
2044     EXPECT_FALSE(status.ok());
2045     EXPECT_EQ(
2046         status.error_message(),
2047         "MutableGraphView::SortTopologically error: active mutation exists.");
2048     CompareGraphViewWithGraph(&graph_view, test_graph());
2049     CompareGraphOrder(graph_view, {"a", "b"});
2050   }
2051 }
2052 
TEST_F(TopologicalSortTest,BadExtraDependenciesSort)2053 TEST_F(TopologicalSortTest, BadExtraDependenciesSort) {
2054   auto test_graph = []() {
2055     return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
2056                  NDef("b", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
2057                 /*funcs=*/{});
2058   };
2059 
2060   GraphDef graph_1 = test_graph();
2061   Status status;
2062   MutableGraphView graph_view_1(&graph_1, &status);
2063   TF_ASSERT_OK(status);
2064   MutableNodeView* a_node_1 = graph_view_1.GetNode("a");
2065 
2066   GraphDef graph_2 = test_graph();
2067   MutableGraphView graph_view_2(&graph_2, &status);
2068   TF_ASSERT_OK(status);
2069   MutableNodeView* b_node_2 = graph_view_2.GetNode("b");
2070 
2071   for (bool ignore_cycles : {false, true}) {
2072     status =
2073         graph_view_2.SortTopologically(ignore_cycles, {{a_node_1, b_node_2}});
2074     EXPECT_FALSE(status.ok());
2075     EXPECT_EQ(status.error_message(),
2076               "MutableGraphView::SortTopologically error: invalid extra "
2077               "dependencies.");
2078     CompareGraphViewWithGraph(&graph_view_2, test_graph());
2079     CompareGraphOrder(graph_view_2, {"a", "b"});
2080   }
2081 }
2082 
TEST_F(TopologicalSortTest,NoCyclesAllowed)2083 TEST_F(TopologicalSortTest, NoCyclesAllowed) {
2084   auto test_graph = []() {
2085     return GDef(
2086         {NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
2087          NDef("b", kIdentity, {"a", "c"}, {{"T", DT_FLOAT}}, kDeviceGPU1),
2088          NDef("c", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
2089         /*funcs=*/{});
2090   };
2091 
2092   GraphDef graph = test_graph();
2093   Status status;
2094   MutableGraphView graph_view(&graph, &status);
2095   TF_ASSERT_OK(status);
2096 
2097   status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
2098   EXPECT_FALSE(status.ok());
2099   EXPECT_EQ(status.error_message(),
2100             "MutableGraphView::SortTopologically error: detected edge(s) "
2101             "creating cycle(s) {'c' -> 'b'}.");
2102   CompareGraphViewWithGraph(&graph_view, test_graph());
2103   CompareGraphOrder(graph_view, {"a", "b", "c"});
2104 
2105   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
2106   CompareGraphViewWithGraph(&graph_view, test_graph());
2107   CompareGraphNodePrecedences(graph_view, {{"a", "b"}, {"a", "c"}});
2108 }
2109 
TEST_F(TopologicalSortTest,NoNodesWithZeroFanins)2110 TEST_F(TopologicalSortTest, NoNodesWithZeroFanins) {
2111   auto test_graph = []() {
2112     return GDef({NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
2113                  NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
2114                 /*funcs=*/{});
2115   };
2116 
2117   GraphDef graph = test_graph();
2118   Status status;
2119   MutableGraphView graph_view(&graph, &status);
2120   TF_ASSERT_OK(status);
2121 
2122   status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
2123   EXPECT_FALSE(status.ok());
2124   EXPECT_EQ(status.error_message(),
2125             "MutableGraphView::SortTopologically error: was not able to sort "
2126             "all nodes topologically.");
2127   CompareGraphViewWithGraph(&graph_view, test_graph());
2128   CompareGraphOrder(graph_view, {"a", "b"});
2129 
2130   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
2131   CompareGraphViewWithGraph(&graph_view, test_graph());
2132 }
2133 
TEST_F(TopologicalSortTest,DidNotReachAllNodes)2134 TEST_F(TopologicalSortTest, DidNotReachAllNodes) {
2135   auto test_graph = []() {
2136     return GDef({NDef("c", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU2),
2137                  NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
2138                  NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
2139                 /*funcs=*/{});
2140   };
2141 
2142   GraphDef graph = test_graph();
2143   Status status;
2144   MutableGraphView graph_view(&graph, &status);
2145   TF_ASSERT_OK(status);
2146 
2147   status = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
2148   EXPECT_FALSE(status.ok());
2149   EXPECT_EQ(status.error_message(),
2150             "MutableGraphView::SortTopologically error: was not able to sort "
2151             "all nodes topologically.");
2152   CompareGraphViewWithGraph(&graph_view, test_graph());
2153   CompareGraphOrder(graph_view, {"c", "a", "b"});
2154 
2155   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
2156   CompareGraphViewWithGraph(&graph_view, test_graph());
2157   CompareGraphOrder(graph_view, {"a", "b", "c"});
2158 }
2159 
TEST_F(TopologicalSortTest,NoLoopGraph)2160 TEST_F(TopologicalSortTest, NoLoopGraph) {
2161   auto test_graph = []() {
2162     return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
2163                  NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
2164                  NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
2165                 /*funcs=*/{});
2166   };
2167 
2168   GraphDef graph = test_graph();
2169   Status status;
2170   MutableGraphView graph_view(&graph, &status);
2171   TF_ASSERT_OK(status);
2172 
2173   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2174   CompareGraphViewWithGraph(&graph_view, test_graph());
2175   CompareGraphNodePrecedences(
2176       graph_view,
2177       {{"f", "a"}, {"f", "c"}, {"e", "a"}, {"e", "b"}, {"c", "d"}, {"d", "b"}});
2178 }
2179 
TEST_F(TopologicalSortTest,ValidLoopGraph)2180 TEST_F(TopologicalSortTest, ValidLoopGraph) {
2181   // Control flow loop.
2182   auto test_graph = []() {
2183     return GDef(
2184         {NDef("while/Const_1", "Const", {}),
2185          NDef("while/Enter_2", "Enter", {"while/Const_1"},
2186               {{"frame_name", "while/while_context"}}),
2187          NDef("while/Const", "Const", {}),
2188          NDef("while/Enter_1", "Enter", {"while/Const"},
2189               {{"frame_name", "while/while_context"}}),
2190          NDef("while/iteration_counter", "Const", {}),
2191          NDef("while/Enter", "Enter", {"while/iteration_counter"},
2192               {{"frame_name", "while/while_context"}}),
2193          NDef("while/maximum_iterations", "Const", {}),
2194          NDef("while/Less/Enter", "Enter", {"while/maximum_iterations"},
2195               {{"frame_name", "while/while_context"}}),
2196          NDef("while/Less", "Less", {"while/Merge", "while/Less/Enter"}),
2197          NDef("while/LogicalAnd", "LogicalAnd",
2198               {"while/Less", "while/cond/Merge"}),
2199          NDef("while/LoopCond", "LoopCond", {"while/LogicalAnd"}),
2200          NDef("while/Switch", "Switch", {"while/Merge", "while/LoopCond"},
2201               {{"_class", "loc:@while/Merge"}}),
2202          NDef("while/Identity", "Identity", {"while/Switch:1"}),
2203          NDef("while/add", "Add", {"while/Identity", "while/add/y"}),
2204          NDef("while/NextIteration", "NextIteration", {"while/add"}),
2205          NDef("while/Merge", "Merge", {"while/Enter", "while/NextIteration"}),
2206          NDef("while/Less_1/y", "Const", {"^while/Merge"}),
2207          NDef("while/add/y", "Const", {"^while/Identity"}),
2208          NDef("while/mul/y", "Const", {"^while/Identity"}),
2209          NDef("while/add_2/y", "Const", {"^while/Identity"}),
2210          NDef("while/Switch_1", "Switch", {"while/Merge_1", "while/LoopCond"},
2211               {{"_class", "loc:@while/Merge_1"}}),
2212          NDef("while/Identity_1", "Identity", {"while/Switch_1:1"}),
2213          NDef("while/add_2", "Add", {"while/Identity_1", "while/add_2/y"}),
2214          NDef("while/NextIteration_1", "NextIteration", {"while/add_2"}),
2215          NDef("while/Merge_1", "Merge",
2216               {"while/Enter_1", "while/NextIteration_1"}),
2217          NDef("while/Less_1", "Less", {"while/Merge_1", "while/Less_1/y"}),
2218          NDef("while/cond/Switch", "Switch", {"while/Less_1", "while/Less_1"}),
2219          NDef("while/cond/switch_f", "Identity", {"while/cond/Switch"}),
2220          NDef("while/cond/Const_1", "Const", {"^while/cond/switch_f"}),
2221          NDef("while/cond/switch_t", "Identity", {"while/cond/Switch:1"}),
2222          NDef("while/cond/Const", "Const", {"^while/cond/switch_t"}),
2223          NDef("while/cond/Merge", "Merge",
2224               {"while/cond/Const_1", "while/cond/Const"}),
2225          NDef("TensorArrayUnstack/range/delta", "Const", {}),
2226          NDef("TensorArrayUnstack/range/start", "Const", {}),
2227          NDef("TensorArrayUnstack/strided_slice/stack_2", "Const", {}),
2228          NDef("TensorArrayUnstack/strided_slice/stack_1", "Const", {}),
2229          NDef("TensorArrayUnstack/strided_slice/stack", "Const", {}),
2230          NDef("TensorArrayUnstack/Shape", "Const", {}),
2231          NDef("TensorArrayUnstack/strided_slice", "StridedSlice",
2232               {"TensorArrayUnstack/Shape",
2233                "TensorArrayUnstack/strided_slice/stack",
2234                "TensorArrayUnstack/strided_slice/stack_1",
2235                "TensorArrayUnstack/strided_slice/stack_2"}),
2236          NDef("TensorArrayUnstack/range", "Range",
2237               {"TensorArrayUnstack/range/start",
2238                "TensorArrayUnstack/strided_slice",
2239                "TensorArrayUnstack/range/delta"}),
2240          NDef("TensorArray/size", "Const", {}),
2241          NDef("TensorArray", "TensorArrayV3", {"TensorArray/size"}),
2242          NDef("while/TensorArrayReadV3/Enter", "Enter", {"TensorArray"},
2243               {{"frame_name", "while/while_context"}}),
2244          NDef("Const", "Const", {}),
2245          NDef("TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3",
2246               "TensorArrayScatterV3",
2247               {"TensorArray", "TensorArrayUnstack/range", "Const",
2248                "TensorArray:1"},
2249               {{"_class", "loc@Const"}}),
2250          NDef("while/TensorArrayReadV3/Enter_1", "Enter",
2251               {"TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3"},
2252               {{"frame_name", "while/while_context"}}),
2253          NDef("while/TensorArrayReadV3", "TensorArrayReadV3",
2254               {"while/TensorArrayReadV3/Enter", "while/Identity_1",
2255                "while/TensorArrayReadV3/Enter_1"}),
2256          NDef("while/add_1", "Add", {"while/mul", "while/TensorArrayReadV3"}),
2257          NDef("while/NextIteration_2", "NextIteration", {"while/add_1"}),
2258          NDef("while/Merge_2", "Merge",
2259               {"while/Enter_2", "while/NextIteration_2"}),
2260          NDef("while/Switch_2", "Switch", {"while/Merge_2", "while/LoopCond"},
2261               {{"_class", "loc@while/Merge_2"}}),
2262          NDef("while/Exit_2", "Exit", {"while/Switch_2"}),
2263          NDef("while/Identity_2", "Identity", {"while/Switch_2:1"}),
2264          NDef("while/mul", "Mul", {"while/Identity_2", "while/mul/y"})},
2265         /*funcs=*/{});
2266   };
2267 
2268   GraphDef graph = test_graph();
2269   Status status;
2270   MutableGraphView graph_view(&graph, &status);
2271   TF_ASSERT_OK(status);
2272 
2273   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2274   CompareGraphViewWithGraph(&graph_view, test_graph());
2275 }
2276 
TEST_F(TopologicalSortTest,DuplicateFanins)2277 TEST_F(TopologicalSortTest, DuplicateFanins) {
2278   auto test_graph = []() {
2279     return GDef(
2280         {NDef("b", kIdentity, {"a", "a", "^a"}), NDef("a", "Const", {})},
2281         /*funcs=*/{});
2282   };
2283 
2284   GraphDef graph = test_graph();
2285   Status status;
2286   MutableGraphView graph_view(&graph, &status);
2287   TF_ASSERT_OK(status);
2288 
2289   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2290   CompareGraphViewWithGraph(&graph_view, test_graph());
2291   CompareGraphOrder(graph_view, {"a", "b"});
2292 }
2293 
TEST_F(TopologicalSortTest,DiamondDependencyNotACycle)2294 TEST_F(TopologicalSortTest, DiamondDependencyNotACycle) {
2295   auto test_graph = []() {
2296     return GDef({NDef("e", kIdentity, {"b", "c", "d"}),
2297                  NDef("b", kIdentity, {"a"}), NDef("a", "Const", {}),
2298                  NDef("d", kIdentity, {"a"}), NDef("c", kIdentity, {"a"})},
2299                 /*funcs=*/{});
2300   };
2301 
2302   GraphDef graph = test_graph();
2303   Status status;
2304   MutableGraphView graph_view(&graph, &status);
2305   TF_ASSERT_OK(status);
2306 
2307   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2308   CompareGraphViewWithGraph(&graph_view, test_graph());
2309   CompareGraphNodePrecedences(
2310       graph_view,
2311       {{"a", "b"}, {"a", "c"}, {"a", "d"}, {"b", "e"}, {"c", "e"}, {"d", "e"}});
2312 }
2313 
TEST_F(TopologicalSortTest,ExtraDependencies)2314 TEST_F(TopologicalSortTest, ExtraDependencies) {
2315   auto test_graph = []() {
2316     return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
2317                  NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
2318                  NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
2319                 /*funcs=*/{});
2320   };
2321 
2322   GraphDef graph = test_graph();
2323   Status status;
2324   MutableGraphView graph_view(&graph, &status);
2325   TF_ASSERT_OK(status);
2326 
2327   auto* e_node = graph_view.GetNode("e");
2328   ASSERT_NE(e_node, nullptr);
2329   auto* f_node = graph_view.GetNode("f");
2330   ASSERT_NE(f_node, nullptr);
2331 
2332   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false,
2333                                             {{e_node, f_node}}));
2334   CompareGraphViewWithGraph(&graph_view, test_graph());
2335   CompareGraphNodePrecedences(graph_view, {{"f", "a"},
2336                                            {"f", "c"},
2337                                            {"e", "a"},
2338                                            {"e", "b"},
2339                                            {"c", "d"},
2340                                            {"d", "b"},
2341                                            {"e", "f"}});
2342 }
2343 
TEST_F(TopologicalSortTest,PushVisitedNodes)2344 TEST_F(TopologicalSortTest, PushVisitedNodes) {
2345   auto test_graph = []() {
2346     return GDef({NDef("d", kIdentity, {"c"}), NDef("c", kIdentity, {"b", "a"}),
2347                  NDef("b", kIdentity, {"a"}), NDef("a", kIdentity, {})},
2348                 /*funcs=*/{});
2349   };
2350 
2351   GraphDef graph = test_graph();
2352   Status status;
2353   MutableGraphView graph_view(&graph, &status);
2354   TF_ASSERT_OK(status);
2355 
2356   TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2357   CompareGraphViewWithGraph(&graph_view, test_graph());
2358   CompareGraphNodePrecedences(graph_view,
2359                               {{"a", "b"}, {"a", "c"}, {"b", "c"}, {"c", "d"}});
2360 }
2361 
2362 #define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name) \
2363   BENCHMARK(name)                             \
2364       ->ArgPair(10, 2)                        \
2365       ->ArgPair(100, 2)                       \
2366       ->ArgPair(1000, 2)                      \
2367       ->ArgPair(10000, 2)                     \
2368       ->ArgPair(25000, 2)                     \
2369       ->ArgPair(50000, 2)                     \
2370       ->ArgPair(100000, 2)                    \
2371       ->ArgPair(10, 4)                        \
2372       ->ArgPair(100, 4)                       \
2373       ->ArgPair(1000, 4)                      \
2374       ->ArgPair(10000, 4)                     \
2375       ->ArgPair(25000, 4)                     \
2376       ->ArgPair(50000, 4)                     \
2377       ->ArgPair(100000, 4)                    \
2378       ->ArgPair(10, 8)                        \
2379       ->ArgPair(100, 8)                       \
2380       ->ArgPair(1000, 8)                      \
2381       ->ArgPair(10000, 8)                     \
2382       ->ArgPair(25000, 8)                     \
2383       ->ArgPair(50000, 8)                     \
2384       ->ArgPair(100000, 8)                    \
2385       ->ArgPair(10, 16)                       \
2386       ->ArgPair(100, 16)                      \
2387       ->ArgPair(1000, 16)                     \
2388       ->ArgPair(10000, 16)                    \
2389       ->ArgPair(25000, 16)                    \
2390       ->ArgPair(50000, 16)                    \
2391       ->ArgPair(100000, 16);
2392 
2393 template <typename GraphViewT>
BM_GraphViewTConstruction(::testing::benchmark::State & state)2394 void BM_GraphViewTConstruction(::testing::benchmark::State& state) {
2395   const int num_nodes = state.range(0);
2396   const int num_edges_per_node = state.range(1);
2397 
2398   GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
2399 
2400   for (auto i : state) {
2401     Status s;
2402     GraphViewT graph_view(&graph_def, &s);
2403   }
2404 }
2405 
BM_GraphViewConstruction(::testing::benchmark::State & state)2406 void BM_GraphViewConstruction(::testing::benchmark::State& state) {
2407   BM_GraphViewTConstruction<GraphView>(state);
2408 }
2409 
BM_MutableGraphViewConstruction(::testing::benchmark::State & state)2410 void BM_MutableGraphViewConstruction(::testing::benchmark::State& state) {
2411   BM_GraphViewTConstruction<MutableGraphView>(state);
2412 }
2413 
BM_MutableGraphViewClearAttrs(::testing::benchmark::State & state)2414 void BM_MutableGraphViewClearAttrs(::testing::benchmark::State& state) {
2415   const int num_nodes = state.range(0);
2416   const int num_edges_per_node = state.range(1);
2417 
2418   GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
2419 
2420   Status s;
2421   MutableGraphView graph_view(&graph_def, &s);
2422 
2423   for (auto i : state) {
2424     utils::Mutation* mutation = graph_view.GetMutationBuilder();
2425     for (int j = 0; j < num_nodes; ++j) {
2426       mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
2427     }
2428     s = mutation->Apply();
2429   }
2430 }
2431 
2432 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
2433 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
2434 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
2435 
2436 #define RUN_NUM_NODE_BENCHMARK(name) \
2437   BENCHMARK(name)                    \
2438       ->Arg(10)                      \
2439       ->Arg(100)                     \
2440       ->Arg(1000)                    \
2441       ->Arg(10000)                   \
2442       ->Arg(25000)                   \
2443       ->Arg(50000)                   \
2444       ->Arg(100000);
2445 
2446 template <typename GraphViewT>
BM_GraphViewTConstructionWithControlDependencies(::testing::benchmark::State & state)2447 void BM_GraphViewTConstructionWithControlDependencies(
2448     ::testing::benchmark::State& state) {
2449   const int num_fanins_fanouts = state.range(0);
2450 
2451   GraphDef graph_def =
2452       test::CreateFaninFanoutNodeGraph(num_fanins_fanouts, num_fanins_fanouts,
2453                                        num_fanins_fanouts, num_fanins_fanouts,
2454                                        /*fanout_unique_index=*/true);
2455 
2456   for (auto i : state) {
2457     Status s;
2458     GraphViewT graph_view(&graph_def, &s);
2459   }
2460 }
2461 
BM_GraphViewConstructionWithControlDependencies(::testing::benchmark::State & state)2462 void BM_GraphViewConstructionWithControlDependencies(
2463     ::testing::benchmark::State& state) {
2464   BM_GraphViewTConstructionWithControlDependencies<GraphView>(state);
2465 }
2466 
BM_MutableGraphViewConstructionWithControlDependencies(::testing::benchmark::State & state)2467 void BM_MutableGraphViewConstructionWithControlDependencies(
2468     ::testing::benchmark::State& state) {
2469   BM_GraphViewTConstructionWithControlDependencies<MutableGraphView>(state);
2470 }
2471 
2472 RUN_NUM_NODE_BENCHMARK(BM_GraphViewConstructionWithControlDependencies);
2473 RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewConstructionWithControlDependencies);
2474 
2475 template <typename GraphViewT>
BM_GraphViewTGetNode(::testing::benchmark::State & state)2476 void BM_GraphViewTGetNode(::testing::benchmark::State& state) {
2477   const int num_nodes = state.range(0);
2478 
2479   GraphDef graph_def =
2480       test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
2481   Status s;
2482   GraphViewT graph_view(&graph_def, &s);
2483 
2484   for (auto i : state) {
2485     graph_view.GetNode("out");
2486   }
2487 }
2488 
BM_GraphViewGetNode(::testing::benchmark::State & state)2489 void BM_GraphViewGetNode(::testing::benchmark::State& state) {
2490   BM_GraphViewTGetNode<GraphView>(state);
2491 }
2492 
BM_MutableGraphViewGetNode(::testing::benchmark::State & state)2493 void BM_MutableGraphViewGetNode(::testing::benchmark::State& state) {
2494   BM_GraphViewTGetNode<MutableGraphView>(state);
2495 }
2496 
2497 RUN_NUM_NODE_BENCHMARK(BM_GraphViewGetNode);
2498 RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewGetNode);
2499 
2500 #define RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(name) \
2501   BENCHMARK(name)                                \
2502       ->ArgPair(10, 10)                          \
2503       ->ArgPair(10, 100)                         \
2504       ->ArgPair(10, 1000)                        \
2505       ->ArgPair(10, 10000)                       \
2506       ->ArgPair(10, 100000)                      \
2507       ->ArgPair(100, 10)                         \
2508       ->ArgPair(100, 100)                        \
2509       ->ArgPair(100, 1000)                       \
2510       ->ArgPair(100, 10000)                      \
2511       ->ArgPair(100, 100000)                     \
2512       ->ArgPair(1000, 10)                        \
2513       ->ArgPair(1000, 100)                       \
2514       ->ArgPair(1000, 1000)                      \
2515       ->ArgPair(1000, 10000)                     \
2516       ->ArgPair(1000, 100000)                    \
2517       ->ArgPair(10000, 10)                       \
2518       ->ArgPair(10000, 100)                      \
2519       ->ArgPair(10000, 1000)                     \
2520       ->ArgPair(10000, 10000)                    \
2521       ->ArgPair(10000, 100000)                   \
2522       ->ArgPair(100000, 10)                      \
2523       ->ArgPair(100000, 100)                     \
2524       ->ArgPair(100000, 1000)                    \
2525       ->ArgPair(100000, 10000)                   \
2526       ->ArgPair(100000, 100000);
2527 
2528 template <typename GraphViewT>
BM_GraphViewTGetRegularFanin(::testing::benchmark::State & state)2529 void BM_GraphViewTGetRegularFanin(::testing::benchmark::State& state) {
2530   const int num_fanins = state.range(0);
2531   const int num_fanouts = state.range(1);
2532 
2533   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2534       num_fanins, num_fanouts, num_fanins, num_fanouts,
2535       /*fanout_unique_index=*/true);
2536   Status s;
2537   GraphViewT graph_view(&graph_def, &s);
2538 
2539   for (auto i : state) {
2540     auto* node = graph_view.GetNode("node");
2541     node->GetRegularFanin(0);
2542   }
2543 }
2544 
BM_GraphViewGetRegularFanin(::testing::benchmark::State & state)2545 void BM_GraphViewGetRegularFanin(::testing::benchmark::State& state) {
2546   BM_GraphViewTGetRegularFanin<GraphView>(state);
2547 }
2548 
BM_MutableGraphViewGetRegularFanin(::testing::benchmark::State & state)2549 void BM_MutableGraphViewGetRegularFanin(::testing::benchmark::State& state) {
2550   BM_GraphViewTGetRegularFanin<MutableGraphView>(state);
2551 }
2552 
2553 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
2554 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanin);
2555 
2556 template <typename GraphViewT>
BM_GraphViewTGetRegularFanout(::testing::benchmark::State & state)2557 void BM_GraphViewTGetRegularFanout(::testing::benchmark::State& state) {
2558   const int num_fanins = state.range(0);
2559   const int num_fanouts = state.range(1);
2560 
2561   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2562       num_fanins, num_fanouts, num_fanins, num_fanouts,
2563       /*fanout_unique_index=*/true);
2564   Status s;
2565   GraphViewT graph_view(&graph_def, &s);
2566 
2567   for (auto i : state) {
2568     auto* node = graph_view.GetNode("node");
2569     node->GetRegularFanout(0);
2570   }
2571 }
2572 
BM_GraphViewGetRegularFanout(::testing::benchmark::State & state)2573 void BM_GraphViewGetRegularFanout(::testing::benchmark::State& state) {
2574   BM_GraphViewTGetRegularFanout<GraphView>(state);
2575 }
2576 
BM_MutableGraphViewGetRegularFanout(::testing::benchmark::State & state)2577 void BM_MutableGraphViewGetRegularFanout(::testing::benchmark::State& state) {
2578   BM_GraphViewTGetRegularFanout<MutableGraphView>(state);
2579 }
2580 
2581 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanout);
2582 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanout);
2583 
2584 template <typename GraphViewT>
BM_GraphViewTGetRegularFanins(::testing::benchmark::State & state)2585 void BM_GraphViewTGetRegularFanins(::testing::benchmark::State& state) {
2586   const int num_fanins = state.range(0);
2587   const int num_fanouts = state.range(1);
2588 
2589   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2590       num_fanins, num_fanouts, num_fanins, num_fanouts,
2591       /*fanout_unique_index=*/true);
2592   Status s;
2593   GraphViewT graph_view(&graph_def, &s);
2594 
2595   for (auto i : state) {
2596     auto* node = graph_view.GetNode("node");
2597     node->GetRegularFanins();
2598   }
2599 }
2600 
BM_GraphViewGetRegularFanins(::testing::benchmark::State & state)2601 void BM_GraphViewGetRegularFanins(::testing::benchmark::State& state) {
2602   BM_GraphViewTGetRegularFanins<GraphView>(state);
2603 }
2604 
BM_MutableGraphViewGetRegularFanins(::testing::benchmark::State & state)2605 void BM_MutableGraphViewGetRegularFanins(::testing::benchmark::State& state) {
2606   BM_GraphViewTGetRegularFanins<MutableGraphView>(state);
2607 }
2608 
2609 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanins);
2610 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanins);
2611 
2612 template <typename GraphViewT>
BM_GraphViewTGetRegularFanouts(::testing::benchmark::State & state)2613 void BM_GraphViewTGetRegularFanouts(::testing::benchmark::State& state) {
2614   const int num_fanins = state.range(0);
2615   const int num_fanouts = state.range(1);
2616 
2617   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2618       num_fanins, num_fanouts, num_fanins, num_fanouts,
2619       /*fanout_unique_index=*/true);
2620   Status s;
2621   GraphViewT graph_view(&graph_def, &s);
2622 
2623   for (auto i : state) {
2624     auto* node = graph_view.GetNode("node");
2625     node->GetRegularFanouts();
2626   }
2627 }
2628 
BM_GraphViewGetRegularFanouts(::testing::benchmark::State & state)2629 void BM_GraphViewGetRegularFanouts(::testing::benchmark::State& state) {
2630   BM_GraphViewTGetRegularFanouts<GraphView>(state);
2631 }
2632 
BM_MutableGraphViewGetRegularFanouts(::testing::benchmark::State & state)2633 void BM_MutableGraphViewGetRegularFanouts(::testing::benchmark::State& state) {
2634   BM_GraphViewTGetRegularFanouts<MutableGraphView>(state);
2635 }
2636 
2637 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanouts);
2638 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanouts);
2639 
2640 template <typename GraphViewT>
BM_GraphViewTGetControllingFanins(::testing::benchmark::State & state)2641 void BM_GraphViewTGetControllingFanins(::testing::benchmark::State& state) {
2642   const int num_fanins = state.range(0);
2643   const int num_fanouts = state.range(1);
2644 
2645   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2646       num_fanins, num_fanouts, num_fanins, num_fanouts,
2647       /*fanout_unique_index=*/true);
2648   Status s;
2649   GraphViewT graph_view(&graph_def, &s);
2650 
2651   for (auto i : state) {
2652     auto* node = graph_view.GetNode("node");
2653     node->GetControllingFanins();
2654   }
2655 }
2656 
BM_GraphViewGetControllingFanins(::testing::benchmark::State & state)2657 void BM_GraphViewGetControllingFanins(::testing::benchmark::State& state) {
2658   BM_GraphViewTGetControllingFanins<GraphView>(state);
2659 }
2660 
BM_MutableGraphViewGetControllingFanins(::testing::benchmark::State & state)2661 void BM_MutableGraphViewGetControllingFanins(
2662     ::testing::benchmark::State& state) {
2663   BM_GraphViewTGetControllingFanins<MutableGraphView>(state);
2664 }
2665 
2666 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControllingFanins);
2667 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControllingFanins);
2668 
2669 template <typename GraphViewT>
BM_GraphViewTGetControlledFanouts(::testing::benchmark::State & state)2670 void BM_GraphViewTGetControlledFanouts(::testing::benchmark::State& state) {
2671   const int num_fanins = state.range(0);
2672   const int num_fanouts = state.range(1);
2673 
2674   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2675       num_fanins, num_fanouts, num_fanins, num_fanouts,
2676       /*fanout_unique_index=*/true);
2677   Status s;
2678   GraphViewT graph_view(&graph_def, &s);
2679 
2680   for (auto i : state) {
2681     auto* node = graph_view.GetNode("node");
2682     node->GetControlledFanouts();
2683   }
2684 }
2685 
BM_GraphViewGetControlledFanouts(::testing::benchmark::State & state)2686 void BM_GraphViewGetControlledFanouts(::testing::benchmark::State& state) {
2687   BM_GraphViewTGetControlledFanouts<GraphView>(state);
2688 }
2689 
BM_MutableGraphViewGetControlledFanouts(::testing::benchmark::State & state)2690 void BM_MutableGraphViewGetControlledFanouts(
2691     ::testing::benchmark::State& state) {
2692   BM_GraphViewTGetControlledFanouts<MutableGraphView>(state);
2693 }
2694 
2695 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControlledFanouts);
2696 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControlledFanouts);
2697 
2698 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasRegularFanin(::testing::benchmark::State & state)2699 inline void BM_GraphViewTHasRegularFanin(::testing::benchmark::State& state) {
2700   const int num_fanins = state.range(0);
2701   const int num_fanouts = state.range(1);
2702 
2703   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2704       num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
2705       /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
2706   Status s;
2707   GraphViewT graph_view(&graph_def, &s);
2708   const int index = IsLast ? num_fanouts - 1 : 0;
2709   auto* node = graph_view.GetNode(absl::StrFormat("out%05d", index));
2710   auto* fanin = graph_view.GetNode("node");
2711 
2712   for (auto i : state) {
2713     node->HasFanin({&graph_view, fanin->node_index(), 0});
2714   }
2715 }
2716 
BM_GraphViewHasRegularFaninFirst(::testing::benchmark::State & state)2717 void BM_GraphViewHasRegularFaninFirst(::testing::benchmark::State& state) {
2718   BM_GraphViewTHasRegularFanin<GraphView, false>(state);
2719 }
2720 
BM_GraphViewHasRegularFaninLast(::testing::benchmark::State & state)2721 void BM_GraphViewHasRegularFaninLast(::testing::benchmark::State& state) {
2722   BM_GraphViewTHasRegularFanin<GraphView, true>(state);
2723 }
2724 
BM_MutableGraphViewHasRegularFaninFirst(::testing::benchmark::State & state)2725 void BM_MutableGraphViewHasRegularFaninFirst(
2726     ::testing::benchmark::State& state) {
2727   BM_GraphViewTHasRegularFanin<MutableGraphView, false>(state);
2728 }
2729 
BM_MutableGraphViewHasRegularFaninLast(::testing::benchmark::State & state)2730 void BM_MutableGraphViewHasRegularFaninLast(
2731     ::testing::benchmark::State& state) {
2732   BM_GraphViewTHasRegularFanin<MutableGraphView, true>(state);
2733 }
2734 
2735 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninFirst);
2736 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninLast);
2737 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninFirst);
2738 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninLast);
2739 
2740 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasControllingFanin(::testing::benchmark::State & state)2741 inline void BM_GraphViewTHasControllingFanin(
2742     ::testing::benchmark::State& state) {
2743   const int num_fanins = state.range(0);
2744   const int num_fanouts = state.range(1);
2745 
2746   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2747       num_fanins, num_fanouts, num_fanins, num_fanouts,
2748       /*fanout_unique_index=*/true);
2749   Status s;
2750   GraphViewT graph_view(&graph_def, &s);
2751   const int index = IsLast ? num_fanouts - 1 : 0;
2752   auto* node = graph_view.GetNode(absl::StrFormat("control_out%05d", index));
2753   auto* fanin = graph_view.GetNode("node");
2754 
2755   for (auto i : state) {
2756     node->HasFanin({&graph_view, fanin->node_index(), Graph::kControlSlot});
2757   }
2758 }
2759 
BM_GraphViewHasControllingFaninFirst(::testing::benchmark::State & state)2760 void BM_GraphViewHasControllingFaninFirst(::testing::benchmark::State& state) {
2761   BM_GraphViewTHasControllingFanin<GraphView, false>(state);
2762 }
2763 
BM_GraphViewHasControllingFaninLast(::testing::benchmark::State & state)2764 void BM_GraphViewHasControllingFaninLast(::testing::benchmark::State& state) {
2765   BM_GraphViewTHasControllingFanin<GraphView, true>(state);
2766 }
2767 
BM_MutableGraphViewHasControllingFaninFirst(::testing::benchmark::State & state)2768 void BM_MutableGraphViewHasControllingFaninFirst(
2769     ::testing::benchmark::State& state) {
2770   BM_GraphViewTHasControllingFanin<MutableGraphView, false>(state);
2771 }
2772 
BM_MutableGraphViewHasControllingFaninLast(::testing::benchmark::State & state)2773 void BM_MutableGraphViewHasControllingFaninLast(
2774     ::testing::benchmark::State& state) {
2775   BM_GraphViewTHasControllingFanin<MutableGraphView, true>(state);
2776 }
2777 
2778 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninFirst);
2779 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninLast);
2780 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninFirst);
2781 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninLast);
2782 
2783 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasRegularFanout(::testing::benchmark::State & state)2784 inline void BM_GraphViewTHasRegularFanout(::testing::benchmark::State& state) {
2785   const int num_fanins = state.range(0);
2786   const int num_fanouts = state.range(1);
2787 
2788   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2789       num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
2790       /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
2791   Status s;
2792   GraphViewT graph_view(&graph_def, &s);
2793   const int index = IsLast ? num_fanins - 1 : 0;
2794   auto* node = graph_view.GetNode(absl::StrFormat("in%05d", index));
2795   auto* fanout = graph_view.GetNode("node");
2796 
2797   for (auto i : state) {
2798     node->HasFanout({&graph_view, fanout->node_index(), index});
2799   }
2800 }
2801 
BM_GraphViewHasRegularFanoutFirst(::testing::benchmark::State & state)2802 void BM_GraphViewHasRegularFanoutFirst(::testing::benchmark::State& state) {
2803   BM_GraphViewTHasRegularFanout<GraphView, false>(state);
2804 }
2805 
BM_GraphViewHasRegularFanoutLast(::testing::benchmark::State & state)2806 void BM_GraphViewHasRegularFanoutLast(::testing::benchmark::State& state) {
2807   BM_GraphViewTHasRegularFanout<GraphView, true>(state);
2808 }
2809 
BM_MutableGraphViewHasRegularFanoutFirst(::testing::benchmark::State & state)2810 void BM_MutableGraphViewHasRegularFanoutFirst(
2811     ::testing::benchmark::State& state) {
2812   BM_GraphViewTHasRegularFanout<MutableGraphView, false>(state);
2813 }
2814 
BM_MutableGraphViewHasRegularFanoutLast(::testing::benchmark::State & state)2815 void BM_MutableGraphViewHasRegularFanoutLast(
2816     ::testing::benchmark::State& state) {
2817   BM_GraphViewTHasRegularFanout<MutableGraphView, true>(state);
2818 }
2819 
2820 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutFirst);
2821 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutLast);
2822 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutFirst);
2823 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutLast);
2824 
2825 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasControlledFanout(::testing::benchmark::State & state)2826 inline void BM_GraphViewTHasControlledFanout(
2827     ::testing::benchmark::State& state) {
2828   const int num_fanins = state.range(0);
2829   const int num_fanouts = state.range(1);
2830 
2831   GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2832       num_fanins, num_fanouts, num_fanins, num_fanouts,
2833       /*fanout_unique_index=*/false);
2834   Status s;
2835   GraphViewT graph_view(&graph_def, &s);
2836   const int index = IsLast ? num_fanins - 1 : 0;
2837   auto* node = graph_view.GetNode(absl::StrFormat("control_in%05d", index));
2838   auto* fanout = graph_view.GetNode("node");
2839 
2840   for (auto i : state) {
2841     node->HasFanout({&graph_view, fanout->node_index(), Graph::kControlSlot});
2842   }
2843 }
2844 
BM_GraphViewHasControlledFanoutFirst(::testing::benchmark::State & state)2845 void BM_GraphViewHasControlledFanoutFirst(::testing::benchmark::State& state) {
2846   BM_GraphViewTHasControlledFanout<GraphView, false>(state);
2847 }
2848 
BM_GraphViewHasControlledFanoutLast(::testing::benchmark::State & state)2849 void BM_GraphViewHasControlledFanoutLast(::testing::benchmark::State& state) {
2850   BM_GraphViewTHasControlledFanout<GraphView, true>(state);
2851 }
2852 
BM_MutableGraphViewHasControlledFanoutFirst(::testing::benchmark::State & state)2853 void BM_MutableGraphViewHasControlledFanoutFirst(
2854     ::testing::benchmark::State& state) {
2855   BM_GraphViewTHasControlledFanout<MutableGraphView, false>(state);
2856 }
2857 
BM_MutableGraphViewHasControlledFanoutLast(::testing::benchmark::State & state)2858 void BM_MutableGraphViewHasControlledFanoutLast(
2859     ::testing::benchmark::State& state) {
2860   BM_GraphViewTHasControlledFanout<MutableGraphView, true>(state);
2861 }
2862 
2863 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutFirst);
2864 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast);
2865 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst);
2866 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast);
2867 
BM_SortTopologically(::testing::benchmark::State & state)2868 void BM_SortTopologically(::testing::benchmark::State& state) {
2869   const int size = state.range(0);
2870 
2871   GraphDef graph = test::CreateRandomGraph(size);
2872   Status status;
2873   MutableGraphView graph_view(&graph, &status);
2874   TF_ASSERT_OK(status);
2875 
2876   for (auto i : state) {
2877     TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2878   }
2879 }
2880 
2881 RUN_NUM_NODE_BENCHMARK(BM_SortTopologically);
2882 
2883 }  // namespace
2884 }  // namespace utils
2885 }  // namespace grappler
2886 }  // namespace tensorflow
2887