1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/graph_view.h"
17
18 #include <type_traits>
19
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/substitute.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/core/framework/function_testlib.h"
25 #include "tensorflow/core/framework/types.pb.h"
26 #include "tensorflow/core/graph/benchmark_testlib.h"
27 #include "tensorflow/core/grappler/utils/grappler_test.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/platform/test_benchmark.h"
32
33 namespace tensorflow {
34 namespace grappler {
35 namespace utils {
36 namespace {
37
38 using ::tensorflow::test::function::GDef;
39 using ::tensorflow::test::function::NDef;
40
41 constexpr char kNoOp[] = "NoOp";
42
SimpleTestGraph()43 GraphDef SimpleTestGraph() {
44 return GDef({NDef("a", kNoOp, {"b:2", "d:3", "b:2", "d:3", "^c"}),
45 NDef("b", kNoOp, {"d:2", "c:5", "^c"}),
46 NDef("c", kNoOp, {"^d", "^d"}), NDef("d", kNoOp, {})},
47 /*funcs=*/{});
48 }
49
50 template <typename T>
GetGraphViewTypeAsString()51 const string GetGraphViewTypeAsString() {
52 return std::is_same<T, class GraphView>::value ? "GraphView"
53 : "MutableGraphView";
54 }
55
56 using GraphViewTypes = ::testing::Types<GraphView, MutableGraphView>;
57
58 template <typename T>
59 class TypedGraphViewTest : public ::testing::Test {};
60 TYPED_TEST_SUITE(TypedGraphViewTest, GraphViewTypes);
61
TYPED_TEST(TypedGraphViewTest, GraphWithDuplicateNodeNames) {
  // Two nodes share the name "a"; constructing a view must report an error.
  GraphDef graph =
      GDef({NDef("a", kNoOp, {}), NDef("a", kNoOp, {})}, /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  EXPECT_FALSE(status.ok());

  const string expected_message = absl::Substitute(
      "$0::$0 error: graph has multiple nodes with the name 'a'.",
      GetGraphViewTypeAsString<TypeParam>());
  EXPECT_EQ(status.error_message(), expected_message);
}
74
TYPED_TEST(TypedGraphViewTest, GraphWithMissingFanins) {
  // "a" reads from node "b", which does not exist in the graph.
  GraphDef graph = GDef({NDef("a", kNoOp, {"b:3"})}, /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  EXPECT_FALSE(status.ok());

  const string expected_message =
      absl::Substitute("$0::$0 error: node 'a' has missing fanin 'b:3'.",
                       GetGraphViewTypeAsString<TypeParam>());
  EXPECT_EQ(status.error_message(), expected_message);
}
85
TYPED_TEST(TypedGraphViewTest, GraphWithSelfCycles) {
  // "a" consumes one of its own outputs, which must be rejected.
  GraphDef graph = GDef({NDef("a", kNoOp, {"a:4"})}, /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  EXPECT_FALSE(status.ok());

  const string expected_message =
      absl::Substitute("$0::$0 error: node 'a' has self cycle fanin 'a:4'.",
                       GetGraphViewTypeAsString<TypeParam>());
  EXPECT_EQ(status.error_message(), expected_message);
}
97
TYPED_TEST(TypedGraphViewTest, GraphWithMisorderedFanins) {
  // Regular inputs must precede control inputs; here "b:4" follows "^b".
  GraphDef graph = GDef({NDef("a", kNoOp, {"^b", "b:4"}), NDef("b", kNoOp, {})},
                        /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  EXPECT_FALSE(status.ok());

  const string expected_message =
      absl::Substitute("$0::$0 error: node 'a' has regular fanin 'b:4' "
                       "after controlling fanins.",
                       GetGraphViewTypeAsString<TypeParam>());
  EXPECT_EQ(status.error_message(), expected_message);
}
110
TYPED_TEST(TypedGraphViewTest, GetNodeWithIndex) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // The i-th node view wraps the i-th NodeDef of the graph.
  const int node_count = view.NumNodes();
  ASSERT_EQ(view.NumNodes(), graph.node_size());
  for (int index = 0; index < node_count; ++index) {
    const auto* node_view = view.GetNode(index);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->node(), graph.mutable_node(index));
  }

  // Indices outside [0, NumNodes()) yield no node.
  for (const int out_of_range : {-1, node_count}) {
    ASSERT_EQ(view.GetNode(out_of_range), nullptr);
  }
}
131
TYPED_TEST(TypedGraphViewTest, GetNodeWithName) {
  GraphDef graph = SimpleTestGraph();

  Status s;
  TypeParam graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  // SimpleTestGraph() declares its nodes in this order, so the i-th name maps
  // to the i-th NodeDef.
  const std::vector<string> node_names = {"a", "b", "c", "d"};
  // Cache the size as int so the loop does not compare signed and unsigned.
  const int num_names = node_names.size();
  for (int i = 0; i < num_names; ++i) {
    const string& node_name = node_names[i];
    const auto* node = graph_view.GetNode(node_name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(node->node(), graph.mutable_node(i));
  }

  // Missing node.
  const auto* bad_node = graph_view.GetNode("e");
  ASSERT_EQ(bad_node, nullptr);
}
151
TYPED_TEST(TypedGraphViewTest, GetNodes) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // One node view per NodeDef, in GraphDef order.
  const auto& node_views = view.GetNodes();
  const int count = node_views.size();
  EXPECT_EQ(count, 4);
  ASSERT_EQ(count, graph.node_size());

  int index = 0;
  for (const auto& node_view : node_views) {
    EXPECT_EQ(node_view.node(), graph.mutable_node(index));
    ++index;
  }
}
168
TYPED_TEST(TypedGraphViewTest, HasNode) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // Every node declared by SimpleTestGraph() is known to the view...
  for (const char* name : {"a", "b", "c", "d"}) {
    EXPECT_TRUE(view.HasNode(name));
  }
  // ...while a name absent from the graph is not.
  EXPECT_FALSE(view.HasNode("e"));
}
183
TYPED_TEST(TypedGraphViewTest, NumNodes) {
  // SimpleTestGraph() contains exactly the nodes a, b, c and d.
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  EXPECT_EQ(view.NumNodes(), 4);
}
192
TYPED_TEST(TypedGraphViewTest, NumNodesEmptyGraph) {
  // A default-constructed GraphDef builds a valid, empty view.
  GraphDef empty_graph;

  Status status;
  TypeParam view(&empty_graph, &status);
  TF_ASSERT_OK(status);

  EXPECT_EQ(view.NumNodes(), 0);
}
201
TEST(MutableGraphViewTest, DedupControlDependencies) {
  // "d" repeats the control inputs ^c (three times), ^a (twice) and ^b.
  GraphDef graph = GDef(
      {NDef("a", kNoOp, {}), NDef("b", kNoOp, {}), NDef("c", kNoOp, {}),
       NDef("d", kNoOp, {"a:2", "b:1", "^c", "^c", "^a", "^a", "^b", "^c"})},
      /*funcs=*/{});

  Status status;
  MutableGraphView view(&graph, &status);
  TF_ASSERT_OK(status);
  EXPECT_EQ(view.NumNodes(), 4);

  for (const char* name : {"a", "b", "c"}) {
    ASSERT_NE(view.GetNode(name), nullptr);
  }
  const auto* d = view.GetNode("d");
  ASSERT_NE(d, nullptr);

  // Regular inputs are untouched; each control input survives exactly once,
  // leaving the NodeDef inputs in the order pinned below.
  EXPECT_EQ(d->NumRegularFanins(), 2);
  const auto* d_def = d->node();
  ASSERT_NE(d_def, nullptr);
  const char* const kExpectedInputs[] = {"a:2", "b:1", "^c", "^b", "^a"};
  ASSERT_EQ(d_def->input_size(), 5);
  for (int i = 0; i < 5; ++i) {
    EXPECT_EQ(d_def->input(i), kExpectedInputs[i]);
  }

  // The controlling fanin views follow the same c, b, a order.
  ASSERT_EQ(d->NumControllingFanins(), 3);
  const auto& controls = d->GetControllingFanins();
  ASSERT_EQ(controls.size(), 3);
  const char* const kExpectedControls[] = {"c", "b", "a"};
  for (int i = 0; i < 3; ++i) {
    ASSERT_NE(controls[i].node_view(), nullptr);
    EXPECT_EQ(controls[i].node_view()->GetName(), kExpectedControls[i]);
  }
}
240
241 template <typename T>
242 class TypedNodeViewTest : public ::testing::Test {};
243 TYPED_TEST_SUITE(TypedNodeViewTest, GraphViewTypes);
244
TYPED_TEST(TypedNodeViewTest, GetName) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // A node view's name matches both the name it was looked up by and the
  // name stored in the NodeDef it wraps.
  for (int i = 0; i < graph.node_size(); ++i) {
    const NodeDef& node_def = graph.node(i);
    const auto* node_view = view.GetNode(node_def.name());
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->GetName(), node_def.name());
    EXPECT_EQ(node_view->GetName(), node_view->node()->name());
  }
}
259
TYPED_TEST(TypedNodeViewTest, GetOp) {
  GraphDef graph = GDef({NDef("a", "op_a", {}), NDef("b", "op_b", {}),
                         NDef("c", "op_c", {}), NDef("d", "op_d", {})},
                        /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // (node name, expected op): GetOp() must agree with the wrapped NodeDef.
  const std::pair<const char*, const char*> kCases[] = {
      {"a", "op_a"}, {"b", "op_b"}, {"c", "op_c"}, {"d", "op_d"}};
  for (const auto& test_case : kCases) {
    const auto* node_view = view.GetNode(test_case.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->GetOp(), test_case.second);
    EXPECT_EQ(node_view->node()->op(), test_case.second);
  }
}
286
TYPED_TEST(TypedNodeViewTest, GetDevice) {
  // Nodes a-c are pinned to a device; d is left without one.
  GraphDef graph = GDef(
      {NDef("a", "", {}, {}, "device_a"), NDef("b", "", {}, {}, "device_b"),
       NDef("c", "", {}, {}, "device_c"), NDef("d", "", {}, {})},
      /*funcs=*/{});

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // (node name, expected device): an unset device reads back as "".
  const std::pair<const char*, const char*> kCases[] = {
      {"a", "device_a"}, {"b", "device_b"}, {"c", "device_c"}, {"d", ""}};
  for (const auto& test_case : kCases) {
    const auto* node_view = view.GetNode(test_case.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->GetDevice(), test_case.second);
    EXPECT_EQ(node_view->node()->device(), test_case.second);
  }
}
314
315 template <typename T>
316 class TypedFaninTest : public ::testing::Test {};
317 using FaninTypes =
318 ::testing::Types<std::pair<FanoutView, GraphView>,
319 std::pair<MutableFanoutView, MutableGraphView>>;
320 TYPED_TEST_SUITE(TypedFaninTest, FaninTypes);
321
TYPED_TEST(TypedFaninTest, GetRegularFanins) {
  using FanoutViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status status;
  GraphViewType view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // "a" reads b:2, d:3, b:2, d:3; duplicate regular fanins are all reported.
  const FanoutViewType b_port_2(&view, b_node->node_index(), 2);
  const FanoutViewType d_port_3(&view, d_node->node_index(), 3);
  const auto& a_fanins = a_node->GetRegularFanins();
  ASSERT_EQ(a_fanins.size(), 4);
  EXPECT_EQ(a_fanins[0], b_port_2);
  EXPECT_EQ(a_fanins[1], d_port_3);
  EXPECT_EQ(a_fanins[2], b_port_2);
  EXPECT_EQ(a_fanins[3], d_port_3);

  // "d" has no inputs.
  EXPECT_EQ(d_node->GetRegularFanins().size(), 0);
}
349
TYPED_TEST(TypedFaninTest, GetRegularFanin) {
  using FanoutViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status status;
  GraphViewType view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // In-range input positions return the fanin stored at that position.
  const int b_index = b_node->node_index();
  const int d_index = d_node->node_index();
  EXPECT_EQ(a_node->GetRegularFanin(0), FanoutViewType(&view, b_index, 2));
  EXPECT_EQ(a_node->GetRegularFanin(1), FanoutViewType(&view, d_index, 3));
  EXPECT_EQ(a_node->GetRegularFanin(2), FanoutViewType(&view, b_index, 2));
  EXPECT_EQ(a_node->GetRegularFanin(3), FanoutViewType(&view, d_index, 3));

  // A default-constructed view acts as the "missing fanin" sentinel...
  const FanoutViewType missing_fanin;
  EXPECT_EQ(missing_fanin, FanoutViewType(nullptr, -1, -2));
  EXPECT_EQ(missing_fanin.node_view(), nullptr);
  // ...and it is what out-of-range, control and negative positions return.
  for (const int position : {4, 5, Graph::kControlSlot, -2}) {
    EXPECT_EQ(a_node->GetRegularFanin(position), missing_fanin);
  }
}
389
TYPED_TEST(TypedFaninTest, GetControllingFanins) {
  using FanoutViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status status;
  GraphViewType view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* c_node = view.GetNode("c");
  ASSERT_NE(c_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // "a" has the single control input ^c.
  const auto& a_controls = a_node->GetControllingFanins();
  ASSERT_EQ(a_controls.size(), 1);
  EXPECT_EQ(a_controls[0], FanoutViewType(&view, c_node->node_index(),
                                          Graph::kControlSlot));

  // "c" lists ^d twice: GraphView keeps both copies, while MutableGraphView
  // dedups control dependencies down to one.
  const size_t expected_c_controls =
      std::is_same<GraphViewType, GraphView>::value ? 2 : 1;
  const FanoutViewType d_control_fanin(&view, d_node->node_index(),
                                       Graph::kControlSlot);
  const auto& c_controls = c_node->GetControllingFanins();
  ASSERT_EQ(c_controls.size(), expected_c_controls);
  for (const auto& control : c_controls) {
    EXPECT_EQ(control, d_control_fanin);
  }

  // "d" has no control inputs.
  EXPECT_EQ(d_node->GetControllingFanins().size(), 0);
}
427
428 template <typename T>
429 class TypedFanoutTest : public ::testing::Test {};
430 using FanoutTypes =
431 ::testing::Types<std::pair<FaninView, GraphView>,
432 std::pair<MutableFaninView, MutableGraphView>>;
433 TYPED_TEST_SUITE(TypedFanoutTest, FanoutTypes);
434
TYPED_TEST(TypedFanoutTest, GetRegularFanouts) {
  using FaninViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status s;
  GraphViewType graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  auto* a_node = graph_view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = graph_view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* d_node = graph_view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // "d" is read at port 2 (by b) and port 3 (twice, by a). Its fanouts are
  // expected to span ports 0..3, with only ports 2 and 3 having consumers.
  const auto& d_fanouts = d_node->GetRegularFanouts();
  ASSERT_EQ(d_fanouts.size(), 4);
  // Cache the size as int so the loop does not compare signed and unsigned.
  const int num_ports = d_fanouts.size();
  for (int i = 0; i < num_ports; ++i) {
    if (i == 2) {
      ASSERT_EQ(d_fanouts[i].size(), 1);
      EXPECT_EQ(d_fanouts[i][0],
                FaninViewType(&graph_view, b_node->node_index(), 0));
    } else if (i == 3) {
      // Compared as a set, so the order of a's two consumers is irrelevant.
      ASSERT_EQ(d_fanouts[i].size(), 2);
      absl::flat_hash_set<FaninViewType> fanouts(d_fanouts[i].begin(),
                                                 d_fanouts[i].end());
      EXPECT_TRUE(fanouts.contains(
          FaninViewType(&graph_view, a_node->node_index(), 1)));
      EXPECT_TRUE(fanouts.contains(
          FaninViewType(&graph_view, a_node->node_index(), 3)));
    } else {
      EXPECT_EQ(d_fanouts[i].size(), 0);
    }
  }

  // No node reads from "a".
  const auto& a_fanouts = a_node->GetRegularFanouts();
  EXPECT_EQ(a_fanouts.size(), 0);
}
475
TYPED_TEST(TypedFanoutTest, GetRegularFanout) {
  using FaninViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status status;
  GraphViewType view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // Port 2 of "d" feeds only input 0 of "b".
  const auto& d_port_2 = d_node->GetRegularFanout(2);
  ASSERT_EQ(d_port_2.size(), 1);
  EXPECT_EQ(d_port_2.at(0), FaninViewType(&view, b_node->node_index(), 0));

  // Port 3 of "d" feeds inputs 1 and 3 of "a" (checked as a set).
  const auto& d_port_3 = d_node->GetRegularFanout(3);
  EXPECT_EQ(d_port_3.size(), 2);
  const absl::flat_hash_set<FaninViewType> d_port_3_consumers(d_port_3.begin(),
                                                              d_port_3.end());
  for (const int input : {1, 3}) {
    EXPECT_TRUE(d_port_3_consumers.contains(
        FaninViewType(&view, a_node->node_index(), input)));
  }

  // Invalid, control, unused and out-of-range ports have no regular fanouts.
  const std::vector<FaninViewType> no_fanouts;
  for (const int port : {-2, Graph::kControlSlot, 0, 1, 4, 5}) {
    EXPECT_EQ(d_node->GetRegularFanout(port), no_fanouts);
  }
}
516
TYPED_TEST(TypedFanoutTest, GetControlledFanouts) {
  using FaninViewType = typename TypeParam::first_type;
  using GraphViewType = typename TypeParam::second_type;

  GraphDef graph = SimpleTestGraph();

  Status status;
  GraphViewType view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* c_node = view.GetNode("c");
  ASSERT_NE(c_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  // "c" is a control input of both "b" and "a" (checked as a set).
  const auto& c_fanouts = c_node->GetControlledFanouts();
  EXPECT_EQ(c_fanouts.size(), 2);
  const absl::flat_hash_set<FaninViewType> c_fanouts_set(c_fanouts.begin(),
                                                         c_fanouts.end());
  for (const auto* consumer : {b_node, a_node}) {
    EXPECT_TRUE(c_fanouts_set.contains(FaninViewType(
        &view, consumer->node_index(), Graph::kControlSlot)));
  }

  // "c" lists ^d twice: GraphView reports both edges, while MutableGraphView
  // dedups control dependencies down to one.
  const size_t expected_d_fanouts =
      std::is_same<GraphViewType, GraphView>::value ? 2 : 1;
  const FaninViewType c_control_fanout(&view, c_node->node_index(),
                                       Graph::kControlSlot);
  const auto& d_fanouts = d_node->GetControlledFanouts();
  ASSERT_EQ(d_fanouts.size(), expected_d_fanouts);
  for (const auto& fanout : d_fanouts) {
    EXPECT_EQ(fanout, c_control_fanout);
  }

  // Nothing depends on "a" via a control edge.
  EXPECT_EQ(a_node->GetControlledFanouts().size(), 0);
}
560
TYPED_TEST(TypedNodeViewTest, NumRegularFanins) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // (node name, number of regular inputs); duplicates count separately.
  const std::pair<const char*, int> expectations[] = {
      {"a", 4}, {"b", 2}, {"c", 0}, {"d", 0}};
  for (const auto& expectation : expectations) {
    const auto* node_view = view.GetNode(expectation.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->NumRegularFanins(), expectation.second);
  }
}
582
TYPED_TEST(TypedNodeViewTest, NumControllingFanins) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // "c" lists ^d twice; only GraphView keeps the duplicate.
  const int c_controls = std::is_same<TypeParam, GraphView>::value ? 2 : 1;
  const std::pair<const char*, int> expectations[] = {
      {"a", 1}, {"b", 1}, {"c", c_controls}, {"d", 0}};
  for (const auto& expectation : expectations) {
    const auto* node_view = view.GetNode(expectation.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->NumControllingFanins(), expectation.second);
  }
}
608
TYPED_TEST(TypedNodeViewTest, NumRegularFanouts) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // (node name, number of regular consumer edges): b is read twice by a,
  // c once by b, and d three times (once by b, twice by a).
  const std::pair<const char*, int> expectations[] = {
      {"a", 0}, {"b", 2}, {"c", 1}, {"d", 3}};
  for (const auto& expectation : expectations) {
    const auto* node_view = view.GetNode(expectation.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->NumRegularFanouts(), expectation.second);
  }
}
630
TYPED_TEST(TypedNodeViewTest, NumControlledFanouts) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // "c" depends on ^d twice; only GraphView keeps the duplicate edge.
  const int d_controlled = std::is_same<TypeParam, GraphView>::value ? 2 : 1;
  const std::pair<const char*, int> expectations[] = {
      {"a", 0}, {"b", 0}, {"c", 2}, {"d", d_controlled}};
  for (const auto& expectation : expectations) {
    const auto* node_view = view.GetNode(expectation.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->NumControlledFanouts(), expectation.second);
  }
}
656
TYPED_TEST(TypedNodeViewTest, HasFanin) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* c_node = view.GetNode("c");
  ASSERT_NE(c_node, nullptr);

  const int a_index = a_node->node_index();
  const int b_index = b_node->node_index();
  const int c_index = c_node->node_index();

  // Regular fanins: "a" reads b:2 but never c:4.
  EXPECT_TRUE(a_node->HasFanin({&view, b_index, 2}));
  EXPECT_FALSE(a_node->HasFanin({&view, c_index, 4}));
  // Controlling fanins: "a" depends on ^c but not on ^b.
  EXPECT_TRUE(a_node->HasFanin({&view, c_index, Graph::kControlSlot}));
  EXPECT_FALSE(a_node->HasFanin({&view, b_index, Graph::kControlSlot}));
  // A self-edge and the missing-slot sentinel are never fanins.
  EXPECT_FALSE(a_node->HasFanin({&view, a_index, 0}));
  EXPECT_FALSE(a_node->HasFanin({&view, b_index, internal::kMissingSlot}));
}
686
TYPED_TEST(TypedNodeViewTest, HasFanout) {
  GraphDef graph = SimpleTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* a_node = view.GetNode("a");
  ASSERT_NE(a_node, nullptr);
  auto* b_node = view.GetNode("b");
  ASSERT_NE(b_node, nullptr);
  auto* c_node = view.GetNode("c");
  ASSERT_NE(c_node, nullptr);
  auto* d_node = view.GetNode("d");
  ASSERT_NE(d_node, nullptr);

  const int a_index = a_node->node_index();
  const int b_index = b_node->node_index();
  const int c_index = c_node->node_index();
  const int d_index = d_node->node_index();

  // Regular fanouts: b feeds input 2 of "a"; input 1 of "a" comes from d.
  EXPECT_TRUE(b_node->HasFanout({&view, a_index, 2}));
  EXPECT_FALSE(b_node->HasFanout({&view, a_index, 1}));
  // Controlled fanouts: "d" controls "c" but not "a".
  EXPECT_TRUE(d_node->HasFanout({&view, c_index, Graph::kControlSlot}));
  EXPECT_FALSE(d_node->HasFanout({&view, a_index, Graph::kControlSlot}));
  // Self-edges, absent edges, an out-of-range node index and the
  // missing-slot sentinel are never fanouts.
  EXPECT_FALSE(d_node->HasFanout({&view, d_index, 0}));
  EXPECT_FALSE(a_node->HasFanout({&view, b_index, 0}));
  EXPECT_FALSE(a_node->HasFanout({&view, 4, 0}));
  EXPECT_FALSE(d_node->HasFanout({&view, b_index, internal::kMissingSlot}));
}
720
SimpleAttrTestGraph()721 GraphDef SimpleAttrTestGraph() {
722 return GDef({NDef("a", kNoOp, {}), NDef("b", kNoOp, {}, {{"attr", 1}}),
723 NDef("c", kNoOp, {}, {{"attr_1", "a"}, {"attr_2", 2.0f}})},
724 /*funcs=*/{});
725 }
726
TYPED_TEST(TypedNodeViewTest, GetAttr) {
  GraphDef graph = SimpleAttrTestGraph();

  Status s;
  TypeParam graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  auto* c_node = graph_view.GetNode("c");
  ASSERT_NE(c_node, nullptr);

  // GetAttr() hands back a pointer; presumably nullptr when the attribute is
  // absent. Assert on it before dereferencing so a regression fails the test
  // instead of crashing the whole test binary.
  const auto* attr_1 = c_node->GetAttr("attr_1");
  ASSERT_NE(attr_1, nullptr);
  EXPECT_EQ(attr_1->s(), "a");
}
739
TYPED_TEST(TypedNodeViewTest, GetAttrs) {
  GraphDef graph = SimpleAttrTestGraph();

  Status s;
  TypeParam graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  auto* c_node = graph_view.GetNode("c");
  ASSERT_NE(c_node, nullptr);

  const auto& actual_attrs = c_node->GetAttrs();
  EXPECT_EQ(actual_attrs.size(), 2);
  // Find() may return nullptr. Use ASSERT (fatal) rather than EXPECT so a
  // missing attribute aborts the test before the pointer is dereferenced.
  const auto* attr_1 = actual_attrs.Find("attr_1");
  ASSERT_NE(attr_1, nullptr);
  EXPECT_EQ(attr_1->s(), "a");
  const auto* attr_2 = actual_attrs.Find("attr_2");
  ASSERT_NE(attr_2, nullptr);
  EXPECT_EQ(attr_2->f(), 2.0f);
}
759
TYPED_TEST(TypedNodeViewTest, NumAttrs) {
  GraphDef graph = SimpleAttrTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  // (node name, attribute count) as declared by SimpleAttrTestGraph().
  const std::pair<const char*, int> expectations[] = {
      {"a", 0}, {"b", 1}, {"c", 2}};
  for (const auto& expectation : expectations) {
    const auto* node_view = view.GetNode(expectation.first);
    ASSERT_NE(node_view, nullptr);
    EXPECT_EQ(node_view->NumAttrs(), expectation.second);
  }
}
778
TYPED_TEST(TypedNodeViewTest, HasAttr) {
  GraphDef graph = SimpleAttrTestGraph();

  Status status;
  TypeParam view(&graph, &status);
  TF_ASSERT_OK(status);

  auto* c_node = view.GetNode("c");
  ASSERT_NE(c_node, nullptr);

  // "c" carries "attr_1" but not "attr" (which only "b" defines).
  EXPECT_TRUE(c_node->HasAttr("attr_1"));
  EXPECT_FALSE(c_node->HasAttr("attr"));
}
792
793 class CompareGraphTest : public GrapplerTest {
794 public:
CompareGraphViewWithGraph(MutableGraphView * graph_view,const GraphDef & expected_graph)795 void CompareGraphViewWithGraph(MutableGraphView* graph_view,
796 const GraphDef& expected_graph) {
797 Status s;
798 GraphView expected_graph_view(&expected_graph, &s);
799 TF_ASSERT_OK(s);
800
801 EXPECT_EQ(graph_view->NumNodes(), expected_graph_view.NumNodes());
802
803 for (const NodeView& expected_node_view : expected_graph_view.GetNodes()) {
804 const string& node_name = expected_node_view.GetName();
805 MutableNodeView* node_view = graph_view->GetNode(node_name);
806 ASSERT_NE(node_view, nullptr);
807
808 EXPECT_EQ(node_view->GetName(), expected_node_view.GetName());
809
810 EXPECT_EQ(node_view->GetOp(), expected_node_view.GetOp());
811
812 EXPECT_EQ(node_view->GetDevice(), expected_node_view.GetDevice());
813
814 const int actual_num_fanins = node_view->node()->input_size();
815 EXPECT_EQ(actual_num_fanins, expected_node_view.node()->input_size());
816
817 const int expected_num_regular_fanins =
818 expected_node_view.NumRegularFanins();
819 bool same_num_regular_fanins =
820 node_view->NumRegularFanins() == expected_num_regular_fanins;
821 EXPECT_TRUE(same_num_regular_fanins);
822 for (int i = 0; i < expected_num_regular_fanins; ++i) {
823 const auto& expected_fanin = expected_node_view.GetRegularFanin(i);
824
825 auto* actual_fanin_node =
826 graph_view->GetNode(expected_fanin.node_view()->GetName());
827 ASSERT_NE(actual_fanin_node, nullptr);
828 EXPECT_TRUE(
829 node_view->HasFanin({actual_fanin_node, expected_fanin.index()}));
830 if (i < node_view->NumRegularFanins()) {
831 auto& actual_fanin = node_view->GetRegularFanin(i);
832 EXPECT_EQ(actual_fanin, MutableFanoutView(actual_fanin_node,
833 expected_fanin.index()));
834 EXPECT_EQ(actual_fanin.node_index(),
835 actual_fanin.node_view()->node_index());
836 }
837 }
838
839 if (same_num_regular_fanins) {
840 for (int i = 0; i < expected_num_regular_fanins; ++i) {
841 const auto& fanin = node_view->GetRegularFanin(i);
842 EXPECT_EQ(ParseTensorName(node_view->node()->input(i)),
843 TensorId(fanin.node_view()->GetName(), fanin.index()));
844 }
845 }
846
847 const int expected_num_controlling_fanins =
848 expected_node_view.NumControllingFanins();
849 bool same_num_controlling_fanins =
850 node_view->NumControllingFanins() == expected_num_controlling_fanins;
851 EXPECT_TRUE(same_num_controlling_fanins);
852 for (int i = 0; i < expected_num_controlling_fanins; ++i) {
853 auto& expected_fanin = expected_node_view.GetControllingFanins()[i];
854
855 auto* actual_fanin_node =
856 graph_view->GetNode(expected_fanin.node_view()->GetName());
857 ASSERT_NE(actual_fanin_node, nullptr);
858 MutableFanoutView actual_fanin(actual_fanin_node,
859 expected_fanin.index());
860 EXPECT_TRUE(node_view->HasFanin(actual_fanin));
861
862 int found = 0;
863 for (const auto& actual_fanin : node_view->GetControllingFanins()) {
864 if (actual_fanin.index() == expected_fanin.index() &&
865 actual_fanin.node_view()->GetName() ==
866 expected_fanin.node_view()->GetName()) {
867 EXPECT_EQ(actual_fanin.node_index(),
868 actual_fanin.node_view()->node_index());
869 ++found;
870 }
871 }
872 EXPECT_EQ(found, 1);
873 }
874
875 if (same_num_controlling_fanins && same_num_regular_fanins) {
876 for (int i = 0; i < expected_num_controlling_fanins; ++i) {
877 const auto& fanin = node_view->GetControllingFanins()[i];
878 EXPECT_EQ(ParseTensorName(node_view->node()->input(
879 i + expected_num_regular_fanins)),
880 TensorId(fanin.node_view()->GetName(), fanin.index()));
881 }
882 }
883
884 EXPECT_EQ(node_view->NumRegularFanouts(),
885 expected_node_view.NumRegularFanouts());
886 const int num_output_ports =
887 expected_node_view.GetRegularFanouts().size();
888 ASSERT_EQ(node_view->GetRegularFanouts().size(), num_output_ports);
889 for (int i = 0; i < num_output_ports; ++i) {
890 auto& expected_fanouts_at_port_i = node_view->GetRegularFanouts()[i];
891 const int num_fanouts_at_port = expected_fanouts_at_port_i.size();
892
893 auto& actual_fanouts_at_port_i = node_view->GetRegularFanouts()[i];
894 EXPECT_EQ(actual_fanouts_at_port_i.size(), num_fanouts_at_port);
895
896 for (int j = 0; j < num_fanouts_at_port; ++j) {
897 auto& expected_fanout = expected_fanouts_at_port_i[j];
898
899 auto* actual_fanout_node =
900 graph_view->GetNode(expected_fanout.node_view()->GetName());
901
902 ASSERT_NE(actual_fanout_node, nullptr);
903 MutableFaninView actual_fanout(actual_fanout_node,
904 expected_fanout.index());
905 EXPECT_TRUE(node_view->HasFanout(actual_fanout));
906
907 int found = 0;
908 for (const auto& fanout : actual_fanouts_at_port_i) {
909 if (fanout.index() == expected_fanout.index() &&
910 fanout.node_view()->GetName() ==
911 expected_fanout.node_view()->GetName()) {
912 EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index());
913 ++found;
914 }
915 }
916 EXPECT_EQ(found, 1);
917 }
918 }
919
920 const int num_controlled_fanouts =
921 expected_node_view.NumControlledFanouts();
922 EXPECT_EQ(node_view->NumControlledFanouts(), num_controlled_fanouts);
923 for (int i = 0; i < num_controlled_fanouts; ++i) {
924 const auto& expected_fanout =
925 expected_node_view.GetControlledFanouts()[i];
926
927 auto* actual_fanout_node =
928 graph_view->GetNode(expected_fanout.node_view()->GetName());
929 ASSERT_NE(actual_fanout_node, nullptr);
930 MutableFaninView actual_fanout(actual_fanout_node,
931 expected_fanout.index());
932 EXPECT_TRUE(node_view->HasFanout(actual_fanout));
933
934 int found = 0;
935 for (const auto& fanout : node_view->GetControlledFanouts()) {
936 if (fanout.index() == expected_fanout.index() &&
937 fanout.node_view()->GetName() ==
938 expected_fanout.node_view()->GetName()) {
939 EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index());
940 ++found;
941 }
942 }
943 EXPECT_EQ(found, 1);
944 }
945
946 EXPECT_EQ(node_view->NumAttrs(), expected_node_view.NumAttrs());
947 for (const auto& expected_attr : expected_node_view.GetAttrs()) {
948 auto* attr = node_view->GetAttr(expected_attr.first);
949 EXPECT_TRUE(AreAttrValuesEqual(*attr, expected_attr.second));
950 }
951 }
952 CompareGraphs(*graph_view->graph(), expected_graph);
953 }
954 };
955
956 class MutationTest : public CompareGraphTest {};
957
958 constexpr char kDeviceCPU0[] = "/device:CPU:0";
959 constexpr char kDeviceGPU0[] = "/device:GPU:0";
960
SimpleTestGraphForMutation()961 GraphDef SimpleTestGraphForMutation() {
962 return GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
963 NDef("b", kNoOp, {}, {}, kDeviceCPU0),
964 NDef("c", kNoOp, {}, {}, kDeviceCPU0),
965 NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
966 {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
967 /*funcs=*/{});
968 }
969
// Checks the validation Mutation::AddNode performs on new nodes:
// - An empty NodeDef and a well-formed node are accepted.
// - AddNode overwrites a previously failed status with OK.
// - A regular fanin after a controlling fanin is rejected.
// - A self-cycle fanin is rejected.
// Apply() is never called, so the graph must stay unchanged.
TEST_F(MutationTest, AddNewNode) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  builder->AddNode(NodeDef(), &status);
  TF_EXPECT_OK(status);

  // Poison the status so the next check proves AddNode resets it.
  status = errors::Internal("error");
  builder->AddNode(
      NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo"), &status);
  TF_EXPECT_OK(status);

  builder->AddNode(
      NDef("bad", "IdentityN", {"^b", "a:1"}, {{"N", 1}}, "foo"), &status);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::AddNode error: node 'bad' has regular fanin 'a:1' after "
            "controlling fanins.");

  builder->AddNode(NDef("bad", "IdentityN", {"bad:1"}, {}, "foo"), &status);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::AddNode error: node 'bad' has self cycle fanin "
            "'bad:1'.");

  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1006
// A node that AddNode accepts can still become ill-formed through later fanin
// mutations. Here the new node 'valid' has its regular fanin 1 set to
// 'valid:2', i.e. a fanin from itself. Apply() must then fail and leave the
// graph unchanged.
TEST_F(MutationTest, NewNodeBadFaninsAfterAdd) {
  GraphDef graph = SimpleTestGraphForMutation();

  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  Mutation* mutation = graph_view.GetMutationBuilder();

  NodeDef valid_node =
      NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo");
  MutationNewNode new_node = mutation->AddNode(std::move(valid_node), &s);
  // Check AddNode's status explicitly, as every other AddNode call site does.
  // Otherwise a rejected AddNode would only show up indirectly as an Apply()
  // failure.
  TF_EXPECT_OK(s);

  mutation->AddOrUpdateRegularFanin(new_node, 1, {"valid", 2});
  s = mutation->Apply();
  EXPECT_FALSE(s.ok());
  string expected_error_msg =
      "Mutation::Apply error: new node 'valid' is ill-formed.";
  EXPECT_EQ(s.error_message(), expected_error_msg);
  CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation());
}
1028
// Adds two new nodes that are both named 'a'; the base graph already has an
// 'a'. AddNode accepts each one individually. Apply() must then fail with a
// name-conflict error and leave the graph unchanged.
TEST_F(MutationTest, NewNodesConflictingNames) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  for (int copy = 0; copy < 2; ++copy) {
    builder->AddNode(NDef("a", "", {}), &status);
    TF_EXPECT_OK(status);
  }

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  const string want =
      "Mutation::Apply error: multiple nodes with the name: 'a' exists in "
      "Mutation.";
  EXPECT_EQ(status.error_message(), want);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1054
// Adding a controlling fanin from 'd' to itself makes the in-place updated
// node ill-formed. Apply() must fail and the graph must stay as it was.
TEST_F(MutationTest, UpdateNodeAndAddSelfLoop) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->AddControllingFanin(node_d, "d");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: inplace updated node 'd' is ill-formed.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1075
// Renames 'd' to 'e' and also gives it a controlling fanin on 'e' (its own new
// name). Apply() must reject the renamed node as ill-formed and leave the
// graph unchanged.
TEST_F(MutationTest, RenameNodeAndAddSelfLoop) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->UpdateNodeName(node_d, "e");
  builder->AddControllingFanin(node_d, "e");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  const string want =
      "Mutation::Apply error: renamed updated node 'e' ('d') is ill-formed.";
  EXPECT_EQ(status.error_message(), want);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1097
// Renames existing node 'a' to 'b' while the existing 'b' is also updated
// (its op changes). Apply() must report a name conflict on 'b' and leave the
// graph unchanged.
TEST_F(MutationTest, ExistingNodesConflictingNames) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_a = view.GetNode("a");
  ASSERT_NE(node_a, nullptr);
  builder->UpdateNodeName(node_a, "b");

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  builder->UpdateNodeOp(node_b, "Identity");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: multiple nodes with the name: 'b' exists in "
            "Mutation.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1123
// Adds a new node named 'a' while the existing 'a' is also updated (device
// change). Apply() must report a name conflict on 'a' and leave the graph
// unchanged.
TEST_F(MutationTest, NewAndExistingNodesConflictingNames) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  builder->AddNode(NDef("a", "", {}), &status);
  TF_EXPECT_OK(status);

  MutableNodeView* node_a = view.GetNode("a");
  ASSERT_NE(node_a, nullptr);
  builder->UpdateNodeDevice(node_a, "foo");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  const string want =
      "Mutation::Apply error: multiple nodes with the name: 'a' exists in "
      "Mutation.";
  EXPECT_EQ(status.error_message(), want);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1149
// Adds a new node 'e' and also renames the existing 'd' to 'e'. Apply() must
// report a name conflict on 'e' and leave the graph unchanged.
TEST_F(MutationTest, NewAndExistingRenamedNodesConflictingNames) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  builder->AddNode(NDef("e", "", {}), &status);
  TF_EXPECT_OK(status);

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->UpdateNodeName(node_d, "e");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: multiple nodes with the name: 'e' exists in "
            "Mutation.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1175
// Removing 'b' alone fails because 'd' still consumes it. The graph is left
// unchanged. After 'd' is also removed, Apply() succeeds and only 'a' and 'c'
// remain.
TEST_F(MutationTest, RemoveNodesWithFanouts) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  builder->RemoveNode(node_b);

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: fanout 'd' exist for missing node 'b'.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->RemoveNode(node_d);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
                              NDef("c", kNoOp, {}, {}, kDeviceCPU0)},
                             /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1206
// Swaps the names of 'd' and 'b'. The original 'd' has fanins 'b:3' and '^b',
// so the renamed node is rejected as ill-formed (presumably because those
// fanins now name the node itself). Rewriting them to point at the new 'd'
// lets Apply() succeed.
TEST_F(MutationTest, SwapNodeNamesWithCycle) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->UpdateNodeName(node_d, "b");
  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  builder->UpdateNodeName(node_b, "d");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  const string want_error =
      "Mutation::Apply error: renamed updated node 'b' ('d') is ill-formed.";
  EXPECT_EQ(status.error_message(), want_error);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  builder->AddOrUpdateRegularFanin(node_d, 1, {"d", 3});
  builder->RemoveControllingFanin(node_d, "b");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want_graph =
      GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
            NDef("d", kNoOp, {}, {}, kDeviceCPU0),
            NDef("c", kNoOp, {}, {}, kDeviceCPU0),
            NDef("b", kNoOp, {"a:2", "d:3", "a:4", "^c"},
                 {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want_graph);
}
1243
// Renaming a node that 'd' consumes leaves 'd' with a dangling fanin, so
// Apply() fails. This is checked first for 'a' -> 'b', and then, after 'a' is
// renamed back, for 'b' -> 'e'. The graph must stay unchanged both times.
TEST_F(MutationTest, RenamedNodeWithFanouts) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_a = view.GetNode("a");
  ASSERT_NE(node_a, nullptr);
  builder->UpdateNodeName(node_a, "b");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: fanout 'd' exist for missing node 'a'.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  builder->UpdateNodeName(node_a, "a");

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  builder->UpdateNodeName(node_b, "e");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: fanout 'd' exist for missing node 'b'.");
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1278
// Removes 'd' and, in the same mutation, adds a new node that reuses the name
// 'd' with different fanins. Apply() succeeds and the new definition replaces
// the old one.
TEST_F(MutationTest, RemoveExistingNodeAndReplaceWithNewNode) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* old_d = view.GetNode("d");
  ASSERT_NE(old_d, nullptr);
  builder->RemoveNode(old_d);

  builder->AddNode(NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0), &status);
  TF_EXPECT_OK(status);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want =
      GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
            NDef("b", kNoOp, {}, {}, kDeviceCPU0),
            NDef("c", kNoOp, {}, {}, kDeviceCPU0),
            NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1305
// Renames 'd' to 'e' and removes its regular fanins at indices 1 and 2. Only
// 'a:2' and the controlling fanins remain on the renamed node.
TEST_F(MutationTest, UpdateNodeNameAndRemoveFanins) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->UpdateNodeName(node_d, "e");
  builder->RemoveRegularFanin(node_d, 1);
  builder->RemoveRegularFanin(node_d, 2);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want =
      GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
            NDef("b", kNoOp, {}, {}, kDeviceCPU0),
            NDef("c", kNoOp, {}, {}, kDeviceCPU0),
            NDef("e", kNoOp, {"a:2", "^c", "^b"},
                 {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1331
// Renames 'a' to 'e' while 'd' still consumes 'a'. Apply() keeps failing until
// every regular fanin of 'd' that names 'a' has been removed or redirected.
// Once that is done, the rename applies.
TEST_F(MutationTest, UpdateNodeNameAndRemoveRegularFanout) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();
  const string missing_a_error =
      "Mutation::Apply error: fanout 'd' exist for missing node 'a'.";

  MutableNodeView* node_a = view.GetNode("a");
  ASSERT_NE(node_a, nullptr);
  builder->UpdateNodeName(node_a, "e");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(), missing_a_error);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  // Drop 'a:4'; 'a:2' still refers to 'a', so Apply() must still fail.
  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->RemoveRegularFanin(node_d, 2);

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(), missing_a_error);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  // Redirect the last reference to 'a'.
  builder->AddOrUpdateRegularFanin(node_d, 0, {"b", 1});

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want =
      GDef({NDef("e", kNoOp, {}, {}, kDeviceCPU0),
            NDef("b", kNoOp, {}, {}, kDeviceCPU0),
            NDef("c", kNoOp, {}, {}, kDeviceCPU0),
            NDef("d", kNoOp, {"b:1", "b:3", "^c", "^b"},
                 {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1375
// Renames 'c' to 'e' while 'd' still has a controlling fanin on 'c'. An
// unrelated device change on 'd' does not fix this. Removing the '^c' fanin
// does, and both updates then apply together.
TEST_F(MutationTest, UpdateNodeNameAndRemoveControlledFanout) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();
  const string missing_c_error =
      "Mutation::Apply error: fanout 'd' exist for missing node 'c'.";

  MutableNodeView* node_c = view.GetNode("c");
  ASSERT_NE(node_c, nullptr);
  builder->UpdateNodeName(node_c, "e");

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(), missing_c_error);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);
  builder->UpdateNodeDevice(node_d, kDeviceGPU0);

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(), missing_c_error);
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());

  builder->RemoveControllingFanin(node_d, "c");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want =
      GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0),
            NDef("b", kNoOp, {}, {}, kDeviceCPU0),
            NDef("e", kNoOp, {}, {}, kDeviceCPU0),
            NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^b"},
                 {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceGPU0)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1419
// Applying a mutation with nothing recorded succeeds and leaves the graph
// unchanged.
TEST_F(MutationTest, EmptyMutation) {
  GraphDef graph_def = SimpleTestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  TF_EXPECT_OK(view.GetMutationBuilder()->Apply());
  CompareGraphViewWithGraph(&view, SimpleTestGraphForMutation());
}
1432
1433 constexpr char kIdentity[] = "Identity";
1434 constexpr char kDeviceCPU1[] = "/device:CPU:1";
1435 constexpr char kDeviceGPU1[] = "/device:GPU:1";
1436
TestGraphForMutation()1437 GraphDef TestGraphForMutation() {
1438 return GDef(
1439 {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1440 NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1441 NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1442 NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1443 {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)},
1444 /*funcs=*/{});
1445 }
1446
// Swaps the names of 'b' and 'c', which do not depend on each other. Apply()
// succeeds: each node keeps its own fanins, attributes and device under the
// other's name, and 'd''s fanins are unchanged by name.
TEST_F(MutationTest, SwapNodeNamesWithNoCycle) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  MutableNodeView* node_c = view.GetNode("c");
  ASSERT_NE(node_c, nullptr);

  builder->UpdateNodeName(node_b, "c");
  builder->UpdateNodeName(node_c, "b");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("b", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1474
// Removes 'c' and its only consumer 'd' in one mutation. Apply() succeeds and
// only 'a' and 'b' remain.
TEST_F(MutationTest, RemoveMultipleDependentNodes) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_c = view.GetNode("c");
  ASSERT_NE(node_c, nullptr);
  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);

  builder->RemoveNode(node_c);
  builder->RemoveNode(node_d);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1499
1500 constexpr char kDeviceGPU2[] = "/device:GPU:2";
1501
TEST_F(MutationTest,AddSimpleNewNode)1502 TEST_F(MutationTest, AddSimpleNewNode) {
1503 GraphDef graph = TestGraphForMutation();
1504
1505 Status s;
1506 MutableGraphView graph_view(&graph, &s);
1507 TF_ASSERT_OK(s);
1508
1509 Mutation* mutation = graph_view.GetMutationBuilder();
1510
1511 NodeDef new_node =
1512 NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2);
1513 mutation->AddNode(std::move(new_node), &s);
1514 TF_EXPECT_OK(s);
1515
1516 TF_EXPECT_OK(mutation->Apply());
1517 GraphDef expected_graph = GDef(
1518 {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1519 NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1520 NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1521 NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
1522 {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1523 NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2)},
1524 /*funcs=*/{});
1525 CompareGraphViewWithGraph(&graph_view, expected_graph);
1526 }
1527
1528 constexpr char kDeviceGPU3[] = "/device:GPU:3";
1529
TEST_F(MutationTest,AddAndUpdateNodesWithFanins)1530 TEST_F(MutationTest, AddAndUpdateNodesWithFanins) {
1531 GraphDef graph = TestGraphForMutation();
1532
1533 Status s;
1534 MutableGraphView graph_view(&graph, &s);
1535 TF_ASSERT_OK(s);
1536
1537 Mutation* mutation = graph_view.GetMutationBuilder();
1538
1539 NodeDef new_node_1 = NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"},
1540 {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2);
1541 mutation->AddNode(std::move(new_node_1), &s);
1542 TF_EXPECT_OK(s);
1543
1544 NodeDef new_node_2 =
1545 NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"},
1546 {{"new_node_2_attr_1", 9}}, kDeviceGPU3);
1547 mutation->AddNode(std::move(new_node_2), &s);
1548 TF_EXPECT_OK(s);
1549
1550 MutableNodeView* d_node = graph_view.GetNode("d");
1551 ASSERT_NE(d_node, nullptr);
1552 mutation->AddOrUpdateRegularFanin(d_node, 3, {"c", 6});
1553 mutation->AddOrUpdateRegularFanin(d_node, 1, {"new_node_1", 5});
1554 mutation->AddControllingFanin(d_node, "new_node_2");
1555
1556 TF_EXPECT_OK(mutation->Apply());
1557 GraphDef expected_graph = GDef(
1558 {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
1559 NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
1560 NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
1561 NDef("d", kNoOp,
1562 {"a:2", "new_node_1:5", "a:4", "c:6", "^c", "^b", "^new_node_2"},
1563 {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
1564 NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"},
1565 {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2),
1566 NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"},
1567 {{"new_node_2_attr_1", 9}}, kDeviceGPU3)},
1568 /*funcs=*/{});
1569 CompareGraphViewWithGraph(&graph_view, expected_graph);
1570 }
1571
// Renames 'b' to 'c', the name of an existing node. Apply() succeeds. The
// resulting graph holds 'a', the renamed node under 'c' (keeping b's fanin,
// attribute and device), and 'd'. The original 'c' is no longer present.
TEST_F(MutationTest, UpdateNodeNameToReplaceExistingNode) {
  GraphDef graph_def =
      GDef({NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0),
            NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1),
            NDef("c", kNoOp, {"b:4", "^a"}, {{"attr_c", "test"}}, kDeviceGPU2),
            NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"},
                 {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)},
           /*funcs=*/{});

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);

  builder->UpdateNodeName(node_b, "c");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want =
      GDef({NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0),
            NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1),
            NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"},
                 {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1605
// Adds a new node and then applies every kind of per-node mutation to it:
// fanin edits, name/op/device updates, and attribute add/update/remove. The
// node that lands in the graph reflects all of them.
TEST_F(MutationTest, NewNodeWithMutations) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  Mutation* builder = view.GetMutationBuilder();

  MutationNewNode added = builder->AddNode(
      NDef("node", kNoOp, {"a:2", "b:3", "^c"},
           {{"attr_1", 1}, {"attr_2", 2.0f}}, kDeviceGPU3),
      &status);
  TF_EXPECT_OK(status);

  AttrValue string_attr;
  string_attr.set_s("new_node_attr");
  AttrValue bool_attr;
  bool_attr.set_b(true);
  AttrValue type_attr;
  type_attr.set_type(DT_FLOAT);

  // Fanin edits.
  builder->AddControllingFanin(added, "a");
  builder->RemoveControllingFanin(added, "c");
  builder->AddOrUpdateRegularFanin(added, 0, {"b", 6});
  builder->RemoveRegularFanin(added, 1);
  // Name, op and device edits.
  builder->UpdateNodeName(added, "new_node");
  builder->UpdateNodeOp(added, kIdentity);
  builder->UpdateNodeDevice(added, kDeviceGPU2);
  // Attribute edits.
  builder->AddOrUpdateNodeAttr(added, "attr_3", string_attr);
  builder->AddOrUpdateNodeAttr(added, "attr_1", bool_attr);
  builder->RemoveNodeAttr(added, "attr_2");
  builder->AddOrUpdateNodeAttr(added, "T", type_attr);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       NDef("new_node", kIdentity, {"b:6", "^a"},
            {{"attr_1", true}, {"attr_3", "new_node_attr"}, {"T", DT_FLOAT}},
            kDeviceGPU2)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1651
// Applies only non-fanin mutations to existing node 'd': rename to 'e', op,
// device and attribute changes. Its fanins must come through untouched.
TEST_F(MutationTest, UpdatedNodeWithNonFaninMutations) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_d = view.GetNode("d");
  ASSERT_NE(node_d, nullptr);

  Mutation* builder = view.GetMutationBuilder();

  AttrValue bool_attr;
  bool_attr.set_b(false);
  AttrValue string_attr;
  string_attr.set_s("test_string");
  AttrValue type_attr;
  type_attr.set_type(DT_INT64);

  builder->UpdateNodeName(node_d, "e");
  builder->UpdateNodeOp(node_d, kIdentity);
  builder->UpdateNodeDevice(node_d, kDeviceGPU2);
  builder->AddOrUpdateNodeAttr(node_d, "attr_d_1", bool_attr);
  builder->AddOrUpdateNodeAttr(node_d, "attr_e_3", string_attr);
  builder->RemoveNodeAttr(node_d, "attr_d_2");
  builder->AddOrUpdateNodeAttr(node_d, "T", type_attr);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("e", kIdentity, {"a:2", "b:3", "a:4", "^c", "^b"},
            {{"attr_d_1", false}, {"attr_e_3", "test_string"}, {"T", DT_INT64}},
            kDeviceGPU2)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1689
// A failing mutation (renaming 'a' while 'b' consumes it) leaves the graph
// unchanged. After Mutation::Reset() discards the pending changes, Apply()
// succeeds as a no-op.
TEST_F(MutationTest, Reset) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_a = view.GetNode("a");
  ASSERT_NE(node_a, nullptr);

  Mutation* builder = view.GetMutationBuilder();

  builder->UpdateNodeName(node_a, "e");
  builder->AddNode(NodeDef(), &status);
  TF_EXPECT_OK(status);

  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Mutation::Apply error: fanout 'b' exist for missing node 'a'.");
  CompareGraphViewWithGraph(&view, TestGraphForMutation());

  builder->Reset();
  TF_EXPECT_OK(builder->Apply());
  CompareGraphViewWithGraph(&view, TestGraphForMutation());
}
1717
// Renames 'b' to 'e' and adds a new node under the freed name 'b'. Existing
// fanins that name 'b' (on 'd') then resolve to the new node, so Apply()
// succeeds.
TEST_F(MutationTest, RenameNodeAndAddNewNodeWithRenamedNodeOldName) {
  GraphDef graph_def = TestGraphForMutation();

  Status status;
  MutableGraphView view(&graph_def, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_b = view.GetNode("b");
  ASSERT_NE(node_b, nullptr);

  Mutation* builder = view.GetMutationBuilder();

  builder->UpdateNodeName(node_b, "e");

  const NodeDef replacement =
      NDef("b", kIdentity, {"c:2"}, {{"T", DT_INT64}}, kDeviceGPU3);
  builder->AddNode(NodeDef(replacement), &status);
  TF_EXPECT_OK(status);

  TF_EXPECT_OK(builder->Apply());
  const GraphDef want = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("e", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       replacement},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&view, want);
}
1748
// Removes "c", which is listed between "d" and "b", and checks that the
// remaining nodes keep their fanins afterwards.
TEST_F(MutationTest, ShiftNodesWithFanouts) {
  GraphDef graph = GDef(
      {NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^c", "^b"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}},
            kDeviceGPU0)},
      /*funcs=*/{});

  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_c = graph_view.GetNode("c");
  ASSERT_NE(node_c, nullptr);
  MutableNodeView* node_d = graph_view.GetNode("d");
  ASSERT_NE(node_d, nullptr);

  // "d" is the only consumer of "c" (through a control edge); drop that edge
  // and then remove "c" itself.
  Mutation* builder = graph_view.GetMutationBuilder();
  builder->RemoveControllingFanin(node_d, "c");
  builder->RemoveNode(node_c);
  TF_EXPECT_OK(builder->Apply());

  const GraphDef expected_graph = GDef(
      {NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^b"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0),
       NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}},
            kDeviceGPU0)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&graph_view, expected_graph);
}
1785
// "b" and "c" both consume outputs 1 and 2 of "a" (in opposite order).
// Removing one regular fanin of "b" must leave "c" untouched.
TEST_F(MutationTest, RemoveFaninFanoutAndShiftFanout) {
  GraphDef graph = GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0),
                         NDef("b", kNoOp, {"a:2", "a:1"}, {}, kDeviceGPU1),
                         NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)},
                        /*funcs=*/{});

  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_b = graph_view.GetNode("b");
  ASSERT_NE(node_b, nullptr);

  // Drop b's second regular input ("a:1").
  Mutation* builder = graph_view.GetMutationBuilder();
  builder->RemoveRegularFanin(node_b, 1);
  TF_EXPECT_OK(builder->Apply());

  const GraphDef expected_graph =
      GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0),
            NDef("b", kNoOp, {"a:2"}, {}, kDeviceGPU1),
            NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)},
           /*funcs=*/{});
  CompareGraphViewWithGraph(&graph_view, expected_graph);
}
1815
// Applies two mutations back to back through the same builder. Each one mixes
// edits to existing nodes with edits to a node added in that same mutation.
TEST_F(MutationTest, ConsecutiveMutations) {
  GraphDef graph = TestGraphForMutation();

  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_b = graph_view.GetNode("b");
  ASSERT_NE(node_b, nullptr);
  MutableNodeView* node_d = graph_view.GetNode("d");
  ASSERT_NE(node_d, nullptr);

  Mutation* builder = graph_view.GetMutationBuilder();

  // First mutation: remove "b" and rewire "d" away from it...
  builder->RemoveNode(node_b);
  builder->AddOrUpdateRegularFanin(node_d, 1, {"c", 5});
  builder->RemoveControllingFanin(node_d, "b");

  // ...then add "new_node_1" and edit it before the mutation is applied.
  MutationNewNode first_new_node = builder->AddNode(
      NDef("new_node_1", kIdentity, {"a:3", "d:5", "^d"}, {{"T", DT_FLOAT}},
           kDeviceGPU2),
      &status);
  TF_EXPECT_OK(status);
  builder->AddOrUpdateRegularFanin(first_new_node, 0, {"c", 5});
  builder->RemoveRegularFanin(first_new_node, 1);
  builder->AddOrUpdateRegularFanin(first_new_node, 1, {"a", 6});
  builder->AddControllingFanin(first_new_node, "a");
  builder->RemoveControllingFanin(first_new_node, "d");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef expected_after_first = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^c"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}},
            kDeviceGPU2)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&graph_view, expected_after_first);

  // Second mutation. Look "d" up again after Apply() instead of reusing the
  // earlier view pointer (presumably Apply() may relocate node views).
  MutableNodeView* node_d_after = graph_view.GetNode("d");
  ASSERT_NE(node_d_after, nullptr);

  // "d" may reference "new_node_2" before that node is added below.
  builder->AddOrUpdateRegularFanin(node_d_after, 3, {"new_node_2", 6});
  builder->AddOrUpdateRegularFanin(node_d_after, 1, {"new_node_1", 8});
  builder->AddControllingFanin(node_d_after, "new_node_2");
  builder->AddControllingFanin(node_d_after, "a");
  builder->RemoveControllingFanin(node_d_after, "c");

  MutationNewNode second_new_node = builder->AddNode(
      NDef("new_node_2", kNoOp, {"c:4", "new_node_1:5", "^d", "^c"}),
      &status);
  TF_EXPECT_OK(status);
  builder->UpdateNodeDevice(second_new_node, kDeviceGPU3);
  builder->AddOrUpdateRegularFanin(second_new_node, 0, {"new_node_1", 4});
  builder->RemoveRegularFanin(second_new_node, 1);
  builder->RemoveControllingFanin(second_new_node, "c");
  builder->AddControllingFanin(second_new_node, "a");
  builder->AddControllingFanin(second_new_node, "new_node_1");

  TF_EXPECT_OK(builder->Apply());
  const GraphDef expected_after_second = GDef(
      {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0),
       NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1),
       NDef("d", kNoOp,
            {"a:2", "new_node_1:8", "a:4", "new_node_2:6", "^new_node_2",
             "^a"},
            {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1),
       NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}},
            kDeviceGPU2),
       NDef("new_node_2", kNoOp, {"new_node_1:4", "^d", "^a", "^new_node_1"},
            {}, kDeviceGPU3)},
      /*funcs=*/{});
  CompareGraphViewWithGraph(&graph_view, expected_after_second);
}
1893
1894 constexpr char kMatchingFiles[] = "MatchingFiles";
1895
TEST_F(MutationTest,OpWithUnsupportedDevice)1896 TEST_F(MutationTest, OpWithUnsupportedDevice) {
1897 GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
1898 auto test_graph = []() {
1899 return GDef({NDef("a", kMatchingFiles, {}, {}, kDeviceCPU0)},
1900 /*funcs=*/{});
1901 };
1902
1903 GraphDef graph = test_graph();
1904
1905 Status s;
1906 MutableGraphView graph_view(&graph, &s);
1907 TF_ASSERT_OK(s);
1908
1909 MutableNodeView* a_node = graph_view.GetNode("a");
1910 ASSERT_NE(a_node, nullptr);
1911
1912 Mutation* mutation = graph_view.GetMutationBuilder();
1913
1914 // Unsupported device.
1915 mutation->UpdateNodeDevice(a_node, kDeviceGPU1);
1916
1917 s = mutation->Apply();
1918 EXPECT_FALSE(s.ok());
1919 CompareGraphViewWithGraph(&graph_view, test_graph());
1920
1921 mutation->Reset();
1922
1923 // New node with unsupported device.
1924 NodeDef new_node = NDef("new_node", kMatchingFiles, {}, {}, kDeviceGPU2);
1925 mutation->AddNode(std::move(new_node), &s);
1926 TF_EXPECT_OK(s);
1927
1928 s = mutation->Apply();
1929 EXPECT_FALSE(s.ok());
1930 CompareGraphViewWithGraph(&graph_view, test_graph());
1931 }
1932
// A node lacking the "T" attribute its op needs must make Apply() fail without
// modifying the graph.
TEST_F(MutationTest, OpMissingAttribute) {
  GTEST_SKIP() << "Reenable once offline optimization tests enable CUDA.";
  const auto make_graph = [] {
    return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_a = graph_view.GetNode("a");
  ASSERT_NE(node_a, nullptr);

  Mutation* builder = graph_view.GetMutationBuilder();

  // Case 1: strip "T" from an existing node.
  builder->RemoveNodeAttr(node_a, "T");
  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  CompareGraphViewWithGraph(&graph_view, make_graph());

  builder->Reset();

  // Case 2: add a new node that never had "T".
  builder->AddNode(NDef("new_node", kIdentity, {}, {}, kDeviceGPU2), &status);
  TF_EXPECT_OK(status);
  status = builder->Apply();
  EXPECT_FALSE(status.ok());
  CompareGraphViewWithGraph(&graph_view, make_graph());
}
1969
// Renaming "a" to its current name yields an empty MutableNodeViewDiff. This is
// done in two rounds with a Reset() in between, to check that `update_index_`
// from the first round does not persist into the second.
TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) {
  const auto make_graph = [] {
    return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  MutableNodeView* node_a = graph_view.GetNode("a");
  ASSERT_NE(node_a, nullptr);

  Mutation* builder = graph_view.GetMutationBuilder();
  for (int round = 0; round < 2; ++round) {
    if (round > 0) {
      builder->Reset();
    }
    builder->UpdateNodeName(node_a, "a");
    TF_EXPECT_OK(builder->Apply());
    CompareGraphViewWithGraph(&graph_view, make_graph());
  }
}
2001
2002 class TopologicalSortTest : public CompareGraphTest {
2003 protected:
CompareGraphOrder(const MutableGraphView & graph_view,absl::Span<const string> node_names)2004 void CompareGraphOrder(const MutableGraphView& graph_view,
2005 absl::Span<const string> node_names) {
2006 const int num_nodes = graph_view.NumNodes();
2007 ASSERT_EQ(num_nodes, node_names.size());
2008 for (int i = 0; i < num_nodes; ++i) {
2009 EXPECT_EQ(graph_view.GetNode(i)->GetName(), node_names[i]);
2010 }
2011 }
2012
CompareGraphNodePrecedences(const MutableGraphView & graph_view,absl::Span<const std::pair<string,string>> node_precedences)2013 void CompareGraphNodePrecedences(
2014 const MutableGraphView& graph_view,
2015 absl::Span<const std::pair<string, string>> node_precedences) {
2016 for (const auto& node_precedence : node_precedences) {
2017 auto* parent_node = graph_view.GetNode(node_precedence.first);
2018 ASSERT_NE(parent_node, nullptr);
2019 auto* child_node = graph_view.GetNode(node_precedence.second);
2020 ASSERT_NE(child_node, nullptr);
2021 EXPECT_TRUE(parent_node->node_index() < child_node->node_index());
2022 }
2023 }
2024 };
2025
// SortTopologically must refuse to run while a mutation is pending, whether or
// not cycles are ignored, and must leave the graph as-is.
TEST_F(TopologicalSortTest, ActiveMutationSort) {
  const auto make_graph = [] {
    return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
                 NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  // Leave an unapplied AddNode in the mutation builder.
  graph_view.GetMutationBuilder()->AddNode({}, &s);
  TF_ASSERT_OK(s);

  for (const bool ignore_cycles : {false, true}) {
    s = graph_view.SortTopologically(ignore_cycles, {});
    EXPECT_FALSE(s.ok());
    EXPECT_EQ(
        s.error_message(),
        "MutableGraphView::SortTopologically error: active mutation exists.");
    CompareGraphViewWithGraph(&graph_view, make_graph());
    CompareGraphOrder(graph_view, {"a", "b"});
  }
}
2052
// Extra dependencies that reference a node from a different MutableGraphView
// must be rejected, and the sorted view must be left unchanged.
TEST_F(TopologicalSortTest, BadExtraDependenciesSort) {
  auto test_graph = []() {
    return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
                 NDef("b", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
                /*funcs=*/{});
  };

  GraphDef graph_1 = test_graph();
  Status status;
  MutableGraphView graph_view_1(&graph_1, &status);
  TF_ASSERT_OK(status);
  MutableNodeView* a_node_1 = graph_view_1.GetNode("a");
  // Without this check a null node could make the test pass without ever
  // exercising the cross-graph case.
  ASSERT_NE(a_node_1, nullptr);

  GraphDef graph_2 = test_graph();
  MutableGraphView graph_view_2(&graph_2, &status);
  TF_ASSERT_OK(status);
  MutableNodeView* b_node_2 = graph_view_2.GetNode("b");
  ASSERT_NE(b_node_2, nullptr);

  // `a_node_1` belongs to `graph_view_1`, so it is not a valid dependency for
  // `graph_view_2`.
  for (bool ignore_cycles : {false, true}) {
    status =
        graph_view_2.SortTopologically(ignore_cycles, {{a_node_1, b_node_2}});
    EXPECT_FALSE(status.ok());
    EXPECT_EQ(status.error_message(),
              "MutableGraphView::SortTopologically error: invalid extra "
              "dependencies.");
    CompareGraphViewWithGraph(&graph_view_2, test_graph());
    CompareGraphOrder(graph_view_2, {"a", "b"});
  }
}
2082
// "b" and "c" consume each other, forming a two-node cycle.
TEST_F(TopologicalSortTest, NoCyclesAllowed) {
  const auto make_graph = [] {
    return GDef(
        {NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0),
         NDef("b", kIdentity, {"a", "c"}, {{"T", DT_FLOAT}}, kDeviceGPU1),
         NDef("c", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
        /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  // Strict mode reports the cycle-creating edge and keeps the original order.
  s = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(s.error_message(),
            "MutableGraphView::SortTopologically error: detected edge(s) "
            "creating cycle(s) {'c' -> 'b'}.");
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphOrder(graph_view, {"a", "b", "c"});

  // Lenient mode succeeds, and "a" still precedes both cycle members.
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphNodePrecedences(graph_view, {{"a", "b"}, {"a", "c"}});
}
2109
// "a" and "b" consume each other, so no node is free of fanins.
TEST_F(TopologicalSortTest, NoNodesWithZeroFanins) {
  const auto make_graph = [] {
    return GDef({NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
                 NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  // Strict sorting fails and keeps the original order.
  s = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(s.error_message(),
            "MutableGraphView::SortTopologically error: was not able to sort "
            "all nodes topologically.");
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphOrder(graph_view, {"a", "b"});

  // Ignoring cycles lets the sort succeed without changing graph contents.
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
}
2133
// "c" is a source node, but the "a" <-> "b" cycle is unreachable from it.
TEST_F(TopologicalSortTest, DidNotReachAllNodes) {
  const auto make_graph = [] {
    return GDef({NDef("c", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU2),
                 NDef("a", kIdentity, {"b"}, {{"T", DT_FLOAT}}, kDeviceGPU0),
                 NDef("b", kIdentity, {"a"}, {{"T", DT_FLOAT}}, kDeviceGPU1)},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  // Strict sorting fails and keeps the original order.
  s = graph_view.SortTopologically(/*ignore_cycles=*/false, {});
  EXPECT_FALSE(s.ok());
  EXPECT_EQ(s.error_message(),
            "MutableGraphView::SortTopologically error: was not able to sort "
            "all nodes topologically.");
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphOrder(graph_view, {"c", "a", "b"});

  // Ignoring cycles succeeds and yields the order a, b, c.
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/true, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphOrder(graph_view, {"a", "b", "c"});
}
2159
// An acyclic graph listed out of dependency order; after sorting, every
// producer must precede its consumers.
TEST_F(TopologicalSortTest, NoLoopGraph) {
  const auto make_graph = [] {
    return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
                 NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
                 NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  // One (producer, consumer) pair per edge in the graph.
  CompareGraphNodePrecedences(
      graph_view,
      {{"f", "a"}, {"f", "c"}, {"e", "a"}, {"e", "b"}, {"c", "d"}, {"d", "b"}});
}
2179
// Control flow loop: Enter / Merge / Switch / NextIteration / Exit nodes in
// frame "while/while_context", fed by a TensorArray unstack. The graph has
// cycles, e.g. while/Merge -> while/Switch -> while/Identity -> while/add ->
// while/NextIteration -> while/Merge, yet sorting with ignore_cycles=false is
// expected to succeed. Presumably SortTopologically handles the
// NextIteration -> Merge back edges specially.
// All `_class` colocation values use the "loc:@<node>" form.
TEST_F(TopologicalSortTest, ValidLoopGraph) {
  auto test_graph = []() {
    return GDef(
        {NDef("while/Const_1", "Const", {}),
         NDef("while/Enter_2", "Enter", {"while/Const_1"},
              {{"frame_name", "while/while_context"}}),
         NDef("while/Const", "Const", {}),
         NDef("while/Enter_1", "Enter", {"while/Const"},
              {{"frame_name", "while/while_context"}}),
         NDef("while/iteration_counter", "Const", {}),
         NDef("while/Enter", "Enter", {"while/iteration_counter"},
              {{"frame_name", "while/while_context"}}),
         NDef("while/maximum_iterations", "Const", {}),
         NDef("while/Less/Enter", "Enter", {"while/maximum_iterations"},
              {{"frame_name", "while/while_context"}}),
         NDef("while/Less", "Less", {"while/Merge", "while/Less/Enter"}),
         NDef("while/LogicalAnd", "LogicalAnd",
              {"while/Less", "while/cond/Merge"}),
         NDef("while/LoopCond", "LoopCond", {"while/LogicalAnd"}),
         NDef("while/Switch", "Switch", {"while/Merge", "while/LoopCond"},
              {{"_class", "loc:@while/Merge"}}),
         NDef("while/Identity", "Identity", {"while/Switch:1"}),
         NDef("while/add", "Add", {"while/Identity", "while/add/y"}),
         NDef("while/NextIteration", "NextIteration", {"while/add"}),
         NDef("while/Merge", "Merge", {"while/Enter", "while/NextIteration"}),
         NDef("while/Less_1/y", "Const", {"^while/Merge"}),
         NDef("while/add/y", "Const", {"^while/Identity"}),
         NDef("while/mul/y", "Const", {"^while/Identity"}),
         NDef("while/add_2/y", "Const", {"^while/Identity"}),
         NDef("while/Switch_1", "Switch", {"while/Merge_1", "while/LoopCond"},
              {{"_class", "loc:@while/Merge_1"}}),
         NDef("while/Identity_1", "Identity", {"while/Switch_1:1"}),
         NDef("while/add_2", "Add", {"while/Identity_1", "while/add_2/y"}),
         NDef("while/NextIteration_1", "NextIteration", {"while/add_2"}),
         NDef("while/Merge_1", "Merge",
              {"while/Enter_1", "while/NextIteration_1"}),
         NDef("while/Less_1", "Less", {"while/Merge_1", "while/Less_1/y"}),
         NDef("while/cond/Switch", "Switch", {"while/Less_1", "while/Less_1"}),
         NDef("while/cond/switch_f", "Identity", {"while/cond/Switch"}),
         NDef("while/cond/Const_1", "Const", {"^while/cond/switch_f"}),
         NDef("while/cond/switch_t", "Identity", {"while/cond/Switch:1"}),
         NDef("while/cond/Const", "Const", {"^while/cond/switch_t"}),
         NDef("while/cond/Merge", "Merge",
              {"while/cond/Const_1", "while/cond/Const"}),
         NDef("TensorArrayUnstack/range/delta", "Const", {}),
         NDef("TensorArrayUnstack/range/start", "Const", {}),
         NDef("TensorArrayUnstack/strided_slice/stack_2", "Const", {}),
         NDef("TensorArrayUnstack/strided_slice/stack_1", "Const", {}),
         NDef("TensorArrayUnstack/strided_slice/stack", "Const", {}),
         NDef("TensorArrayUnstack/Shape", "Const", {}),
         NDef("TensorArrayUnstack/strided_slice", "StridedSlice",
              {"TensorArrayUnstack/Shape",
               "TensorArrayUnstack/strided_slice/stack",
               "TensorArrayUnstack/strided_slice/stack_1",
               "TensorArrayUnstack/strided_slice/stack_2"}),
         NDef("TensorArrayUnstack/range", "Range",
              {"TensorArrayUnstack/range/start",
               "TensorArrayUnstack/strided_slice",
               "TensorArrayUnstack/range/delta"}),
         NDef("TensorArray/size", "Const", {}),
         NDef("TensorArray", "TensorArrayV3", {"TensorArray/size"}),
         NDef("while/TensorArrayReadV3/Enter", "Enter", {"TensorArray"},
              {{"frame_name", "while/while_context"}}),
         NDef("Const", "Const", {}),
         NDef("TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3",
              "TensorArrayScatterV3",
              {"TensorArray", "TensorArrayUnstack/range", "Const",
               "TensorArray:1"},
              {{"_class", "loc:@Const"}}),
         NDef("while/TensorArrayReadV3/Enter_1", "Enter",
              {"TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3"},
              {{"frame_name", "while/while_context"}}),
         NDef("while/TensorArrayReadV3", "TensorArrayReadV3",
              {"while/TensorArrayReadV3/Enter", "while/Identity_1",
               "while/TensorArrayReadV3/Enter_1"}),
         NDef("while/add_1", "Add", {"while/mul", "while/TensorArrayReadV3"}),
         NDef("while/NextIteration_2", "NextIteration", {"while/add_1"}),
         NDef("while/Merge_2", "Merge",
              {"while/Enter_2", "while/NextIteration_2"}),
         NDef("while/Switch_2", "Switch", {"while/Merge_2", "while/LoopCond"},
              {{"_class", "loc:@while/Merge_2"}}),
         NDef("while/Exit_2", "Exit", {"while/Switch_2"}),
         NDef("while/Identity_2", "Identity", {"while/Switch_2:1"}),
         NDef("while/mul", "Mul", {"while/Identity_2", "while/mul/y"})},
        /*funcs=*/{});
  };

  GraphDef graph = test_graph();
  Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  CompareGraphViewWithGraph(&graph_view, test_graph());
}
2276
// "b" consumes "a" twice as a regular input and once as a control input. The
// repeated edges must not stop the sort from placing "a" first.
TEST_F(TopologicalSortTest, DuplicateFanins) {
  const auto make_graph = [] {
    return GDef(
        {NDef("b", kIdentity, {"a", "a", "^a"}), NDef("a", "Const", {})},
        /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphOrder(graph_view, {"a", "b"});
}
2293
// "a" fans out to "b", "c" and "d", which all feed "e". The reconverging paths
// must not be mistaken for a cycle.
TEST_F(TopologicalSortTest, DiamondDependencyNotACycle) {
  const auto make_graph = [] {
    return GDef({NDef("e", kIdentity, {"b", "c", "d"}),
                 NDef("b", kIdentity, {"a"}), NDef("a", "Const", {}),
                 NDef("d", kIdentity, {"a"}), NDef("c", kIdentity, {"a"})},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphNodePrecedences(
      graph_view,
      {{"a", "b"}, {"a", "c"}, {"a", "d"}, {"b", "e"}, {"c", "e"}, {"d", "e"}});
}
2313
// Same graph as NoLoopGraph, plus an extra (e -> f) ordering constraint passed
// to SortTopologically that is not an edge in the graph.
TEST_F(TopologicalSortTest, ExtraDependencies) {
  const auto make_graph = [] {
    return GDef({NDef("c", kIdentity, {"f"}), NDef("a", kIdentity, {"f", "e"}),
                 NDef("b", kIdentity, {"e", "d"}), NDef("d", kIdentity, {"c"}),
                 NDef("f", kIdentity, {}), NDef("e", kIdentity, {})},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  auto* node_e = graph_view.GetNode("e");
  ASSERT_NE(node_e, nullptr);
  auto* node_f = graph_view.GetNode("f");
  ASSERT_NE(node_f, nullptr);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false,
                                            {{node_e, node_f}}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  // All graph edges, plus the extra constraint as the last pair.
  CompareGraphNodePrecedences(graph_view, {{"f", "a"},
                                           {"f", "c"},
                                           {"e", "a"},
                                           {"e", "b"},
                                           {"c", "d"},
                                           {"d", "b"},
                                           {"e", "f"}});
}
2343
// "a" reaches "c" both directly and through "b", so "c" is reachable along
// more than one path during the sort.
TEST_F(TopologicalSortTest, PushVisitedNodes) {
  const auto make_graph = [] {
    return GDef({NDef("d", kIdentity, {"c"}), NDef("c", kIdentity, {"b", "a"}),
                 NDef("b", kIdentity, {"a"}), NDef("a", kIdentity, {})},
                /*funcs=*/{});
  };

  GraphDef graph = make_graph();
  Status s;
  MutableGraphView graph_view(&graph, &s);
  TF_ASSERT_OK(s);

  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  CompareGraphViewWithGraph(&graph_view, make_graph());
  CompareGraphNodePrecedences(graph_view,
                              {{"a", "b"}, {"a", "c"}, {"b", "c"}, {"c", "d"}});
}
2361
// Registers benchmark `name` for every (node count, edges per node) pair, with
// node counts {10, 100, 1000, 10000, 25000, 50000, 100000} and edges per node
// {2, 4, 8, 16}. The benchmarks in this file read the pair as state.range(0)
// and state.range(1).
#define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name)                                  \
  BENCHMARK(name)                                                              \
      ->ArgPair(10, 2)->ArgPair(100, 2)->ArgPair(1000, 2)                      \
      ->ArgPair(10000, 2)->ArgPair(25000, 2)->ArgPair(50000, 2)                \
      ->ArgPair(100000, 2)                                                     \
      ->ArgPair(10, 4)->ArgPair(100, 4)->ArgPair(1000, 4)                      \
      ->ArgPair(10000, 4)->ArgPair(25000, 4)->ArgPair(50000, 4)                \
      ->ArgPair(100000, 4)                                                     \
      ->ArgPair(10, 8)->ArgPair(100, 8)->ArgPair(1000, 8)                      \
      ->ArgPair(10000, 8)->ArgPair(25000, 8)->ArgPair(50000, 8)                \
      ->ArgPair(100000, 8)                                                     \
      ->ArgPair(10, 16)->ArgPair(100, 16)->ArgPair(1000, 16)                   \
      ->ArgPair(10000, 16)->ArgPair(25000, 16)->ArgPair(50000, 16)             \
      ->ArgPair(100000, 16);
2392
2393 template <typename GraphViewT>
BM_GraphViewTConstruction(::testing::benchmark::State & state)2394 void BM_GraphViewTConstruction(::testing::benchmark::State& state) {
2395 const int num_nodes = state.range(0);
2396 const int num_edges_per_node = state.range(1);
2397
2398 GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
2399
2400 for (auto i : state) {
2401 Status s;
2402 GraphViewT graph_view(&graph_def, &s);
2403 }
2404 }
2405
BM_GraphViewConstruction(::testing::benchmark::State & state)2406 void BM_GraphViewConstruction(::testing::benchmark::State& state) {
2407 BM_GraphViewTConstruction<GraphView>(state);
2408 }
2409
BM_MutableGraphViewConstruction(::testing::benchmark::State & state)2410 void BM_MutableGraphViewConstruction(::testing::benchmark::State& state) {
2411 BM_GraphViewTConstruction<MutableGraphView>(state);
2412 }
2413
BM_MutableGraphViewClearAttrs(::testing::benchmark::State & state)2414 void BM_MutableGraphViewClearAttrs(::testing::benchmark::State& state) {
2415 const int num_nodes = state.range(0);
2416 const int num_edges_per_node = state.range(1);
2417
2418 GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
2419
2420 Status s;
2421 MutableGraphView graph_view(&graph_def, &s);
2422
2423 for (auto i : state) {
2424 utils::Mutation* mutation = graph_view.GetMutationBuilder();
2425 for (int j = 0; j < num_nodes; ++j) {
2426 mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
2427 }
2428 s = mutation->Apply();
2429 }
2430 }
2431
2432 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
2433 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
2434 RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
2435
// Registers benchmark `name` once per size in
// {10, 100, 1000, 10000, 25000, 50000, 100000}, read back as state.range(0).
#define RUN_NUM_NODE_BENCHMARK(name)                                           \
  BENCHMARK(name)->Arg(10)->Arg(100)->Arg(1000)->Arg(10000)->Arg(25000)        \
      ->Arg(50000)->Arg(100000);
2445
2446 template <typename GraphViewT>
BM_GraphViewTConstructionWithControlDependencies(::testing::benchmark::State & state)2447 void BM_GraphViewTConstructionWithControlDependencies(
2448 ::testing::benchmark::State& state) {
2449 const int num_fanins_fanouts = state.range(0);
2450
2451 GraphDef graph_def =
2452 test::CreateFaninFanoutNodeGraph(num_fanins_fanouts, num_fanins_fanouts,
2453 num_fanins_fanouts, num_fanins_fanouts,
2454 /*fanout_unique_index=*/true);
2455
2456 for (auto i : state) {
2457 Status s;
2458 GraphViewT graph_view(&graph_def, &s);
2459 }
2460 }
2461
BM_GraphViewConstructionWithControlDependencies(::testing::benchmark::State & state)2462 void BM_GraphViewConstructionWithControlDependencies(
2463 ::testing::benchmark::State& state) {
2464 BM_GraphViewTConstructionWithControlDependencies<GraphView>(state);
2465 }
2466
BM_MutableGraphViewConstructionWithControlDependencies(::testing::benchmark::State & state)2467 void BM_MutableGraphViewConstructionWithControlDependencies(
2468 ::testing::benchmark::State& state) {
2469 BM_GraphViewTConstructionWithControlDependencies<MutableGraphView>(state);
2470 }
2471
2472 RUN_NUM_NODE_BENCHMARK(BM_GraphViewConstructionWithControlDependencies);
2473 RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewConstructionWithControlDependencies);
2474
2475 template <typename GraphViewT>
BM_GraphViewTGetNode(::testing::benchmark::State & state)2476 void BM_GraphViewTGetNode(::testing::benchmark::State& state) {
2477 const int num_nodes = state.range(0);
2478
2479 GraphDef graph_def =
2480 test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
2481 Status s;
2482 GraphViewT graph_view(&graph_def, &s);
2483
2484 for (auto i : state) {
2485 graph_view.GetNode("out");
2486 }
2487 }
2488
BM_GraphViewGetNode(::testing::benchmark::State & state)2489 void BM_GraphViewGetNode(::testing::benchmark::State& state) {
2490 BM_GraphViewTGetNode<GraphView>(state);
2491 }
2492
BM_MutableGraphViewGetNode(::testing::benchmark::State & state)2493 void BM_MutableGraphViewGetNode(::testing::benchmark::State& state) {
2494 BM_GraphViewTGetNode<MutableGraphView>(state);
2495 }
2496
2497 RUN_NUM_NODE_BENCHMARK(BM_GraphViewGetNode);
2498 RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewGetNode);
2499
// Registers benchmark `name` for every (fanin count, fanout count) pair drawn
// from {10, 100, 1000, 10000, 100000} x {10, 100, 1000, 10000, 100000}, in
// fanin-major order. The pair is read as state.range(0) and state.range(1).
#define RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(name)                               \
  BENCHMARK(name)                                                              \
      ->ArgPair(10, 10)->ArgPair(10, 100)->ArgPair(10, 1000)                   \
      ->ArgPair(10, 10000)->ArgPair(10, 100000)                                \
      ->ArgPair(100, 10)->ArgPair(100, 100)->ArgPair(100, 1000)                \
      ->ArgPair(100, 10000)->ArgPair(100, 100000)                              \
      ->ArgPair(1000, 10)->ArgPair(1000, 100)->ArgPair(1000, 1000)             \
      ->ArgPair(1000, 10000)->ArgPair(1000, 100000)                            \
      ->ArgPair(10000, 10)->ArgPair(10000, 100)->ArgPair(10000, 1000)          \
      ->ArgPair(10000, 10000)->ArgPair(10000, 100000)                          \
      ->ArgPair(100000, 10)->ArgPair(100000, 100)->ArgPair(100000, 1000)       \
      ->ArgPair(100000, 10000)->ArgPair(100000, 100000);
2527
2528 template <typename GraphViewT>
BM_GraphViewTGetRegularFanin(::testing::benchmark::State & state)2529 void BM_GraphViewTGetRegularFanin(::testing::benchmark::State& state) {
2530 const int num_fanins = state.range(0);
2531 const int num_fanouts = state.range(1);
2532
2533 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2534 num_fanins, num_fanouts, num_fanins, num_fanouts,
2535 /*fanout_unique_index=*/true);
2536 Status s;
2537 GraphViewT graph_view(&graph_def, &s);
2538
2539 for (auto i : state) {
2540 auto* node = graph_view.GetNode("node");
2541 node->GetRegularFanin(0);
2542 }
2543 }
2544
BM_GraphViewGetRegularFanin(::testing::benchmark::State & state)2545 void BM_GraphViewGetRegularFanin(::testing::benchmark::State& state) {
2546 BM_GraphViewTGetRegularFanin<GraphView>(state);
2547 }
2548
BM_MutableGraphViewGetRegularFanin(::testing::benchmark::State & state)2549 void BM_MutableGraphViewGetRegularFanin(::testing::benchmark::State& state) {
2550 BM_GraphViewTGetRegularFanin<MutableGraphView>(state);
2551 }
2552
2553 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
2554 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanin);
2555
2556 template <typename GraphViewT>
BM_GraphViewTGetRegularFanout(::testing::benchmark::State & state)2557 void BM_GraphViewTGetRegularFanout(::testing::benchmark::State& state) {
2558 const int num_fanins = state.range(0);
2559 const int num_fanouts = state.range(1);
2560
2561 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2562 num_fanins, num_fanouts, num_fanins, num_fanouts,
2563 /*fanout_unique_index=*/true);
2564 Status s;
2565 GraphViewT graph_view(&graph_def, &s);
2566
2567 for (auto i : state) {
2568 auto* node = graph_view.GetNode("node");
2569 node->GetRegularFanout(0);
2570 }
2571 }
2572
BM_GraphViewGetRegularFanout(::testing::benchmark::State & state)2573 void BM_GraphViewGetRegularFanout(::testing::benchmark::State& state) {
2574 BM_GraphViewTGetRegularFanout<GraphView>(state);
2575 }
2576
BM_MutableGraphViewGetRegularFanout(::testing::benchmark::State & state)2577 void BM_MutableGraphViewGetRegularFanout(::testing::benchmark::State& state) {
2578 BM_GraphViewTGetRegularFanout<MutableGraphView>(state);
2579 }
2580
2581 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanout);
2582 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanout);
2583
2584 template <typename GraphViewT>
BM_GraphViewTGetRegularFanins(::testing::benchmark::State & state)2585 void BM_GraphViewTGetRegularFanins(::testing::benchmark::State& state) {
2586 const int num_fanins = state.range(0);
2587 const int num_fanouts = state.range(1);
2588
2589 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2590 num_fanins, num_fanouts, num_fanins, num_fanouts,
2591 /*fanout_unique_index=*/true);
2592 Status s;
2593 GraphViewT graph_view(&graph_def, &s);
2594
2595 for (auto i : state) {
2596 auto* node = graph_view.GetNode("node");
2597 node->GetRegularFanins();
2598 }
2599 }
2600
BM_GraphViewGetRegularFanins(::testing::benchmark::State & state)2601 void BM_GraphViewGetRegularFanins(::testing::benchmark::State& state) {
2602 BM_GraphViewTGetRegularFanins<GraphView>(state);
2603 }
2604
BM_MutableGraphViewGetRegularFanins(::testing::benchmark::State & state)2605 void BM_MutableGraphViewGetRegularFanins(::testing::benchmark::State& state) {
2606 BM_GraphViewTGetRegularFanins<MutableGraphView>(state);
2607 }
2608
2609 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanins);
2610 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanins);
2611
2612 template <typename GraphViewT>
BM_GraphViewTGetRegularFanouts(::testing::benchmark::State & state)2613 void BM_GraphViewTGetRegularFanouts(::testing::benchmark::State& state) {
2614 const int num_fanins = state.range(0);
2615 const int num_fanouts = state.range(1);
2616
2617 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2618 num_fanins, num_fanouts, num_fanins, num_fanouts,
2619 /*fanout_unique_index=*/true);
2620 Status s;
2621 GraphViewT graph_view(&graph_def, &s);
2622
2623 for (auto i : state) {
2624 auto* node = graph_view.GetNode("node");
2625 node->GetRegularFanouts();
2626 }
2627 }
2628
BM_GraphViewGetRegularFanouts(::testing::benchmark::State & state)2629 void BM_GraphViewGetRegularFanouts(::testing::benchmark::State& state) {
2630 BM_GraphViewTGetRegularFanouts<GraphView>(state);
2631 }
2632
BM_MutableGraphViewGetRegularFanouts(::testing::benchmark::State & state)2633 void BM_MutableGraphViewGetRegularFanouts(::testing::benchmark::State& state) {
2634 BM_GraphViewTGetRegularFanouts<MutableGraphView>(state);
2635 }
2636
2637 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanouts);
2638 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanouts);
2639
2640 template <typename GraphViewT>
BM_GraphViewTGetControllingFanins(::testing::benchmark::State & state)2641 void BM_GraphViewTGetControllingFanins(::testing::benchmark::State& state) {
2642 const int num_fanins = state.range(0);
2643 const int num_fanouts = state.range(1);
2644
2645 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2646 num_fanins, num_fanouts, num_fanins, num_fanouts,
2647 /*fanout_unique_index=*/true);
2648 Status s;
2649 GraphViewT graph_view(&graph_def, &s);
2650
2651 for (auto i : state) {
2652 auto* node = graph_view.GetNode("node");
2653 node->GetControllingFanins();
2654 }
2655 }
2656
BM_GraphViewGetControllingFanins(::testing::benchmark::State & state)2657 void BM_GraphViewGetControllingFanins(::testing::benchmark::State& state) {
2658 BM_GraphViewTGetControllingFanins<GraphView>(state);
2659 }
2660
BM_MutableGraphViewGetControllingFanins(::testing::benchmark::State & state)2661 void BM_MutableGraphViewGetControllingFanins(
2662 ::testing::benchmark::State& state) {
2663 BM_GraphViewTGetControllingFanins<MutableGraphView>(state);
2664 }
2665
2666 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControllingFanins);
2667 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControllingFanins);
2668
2669 template <typename GraphViewT>
BM_GraphViewTGetControlledFanouts(::testing::benchmark::State & state)2670 void BM_GraphViewTGetControlledFanouts(::testing::benchmark::State& state) {
2671 const int num_fanins = state.range(0);
2672 const int num_fanouts = state.range(1);
2673
2674 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2675 num_fanins, num_fanouts, num_fanins, num_fanouts,
2676 /*fanout_unique_index=*/true);
2677 Status s;
2678 GraphViewT graph_view(&graph_def, &s);
2679
2680 for (auto i : state) {
2681 auto* node = graph_view.GetNode("node");
2682 node->GetControlledFanouts();
2683 }
2684 }
2685
BM_GraphViewGetControlledFanouts(::testing::benchmark::State & state)2686 void BM_GraphViewGetControlledFanouts(::testing::benchmark::State& state) {
2687 BM_GraphViewTGetControlledFanouts<GraphView>(state);
2688 }
2689
BM_MutableGraphViewGetControlledFanouts(::testing::benchmark::State & state)2690 void BM_MutableGraphViewGetControlledFanouts(
2691 ::testing::benchmark::State& state) {
2692 BM_GraphViewTGetControlledFanouts<MutableGraphView>(state);
2693 }
2694
2695 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControlledFanouts);
2696 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControlledFanouts);
2697
2698 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasRegularFanin(::testing::benchmark::State & state)2699 inline void BM_GraphViewTHasRegularFanin(::testing::benchmark::State& state) {
2700 const int num_fanins = state.range(0);
2701 const int num_fanouts = state.range(1);
2702
2703 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2704 num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
2705 /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
2706 Status s;
2707 GraphViewT graph_view(&graph_def, &s);
2708 const int index = IsLast ? num_fanouts - 1 : 0;
2709 auto* node = graph_view.GetNode(absl::StrFormat("out%05d", index));
2710 auto* fanin = graph_view.GetNode("node");
2711
2712 for (auto i : state) {
2713 node->HasFanin({&graph_view, fanin->node_index(), 0});
2714 }
2715 }
2716
BM_GraphViewHasRegularFaninFirst(::testing::benchmark::State & state)2717 void BM_GraphViewHasRegularFaninFirst(::testing::benchmark::State& state) {
2718 BM_GraphViewTHasRegularFanin<GraphView, false>(state);
2719 }
2720
BM_GraphViewHasRegularFaninLast(::testing::benchmark::State & state)2721 void BM_GraphViewHasRegularFaninLast(::testing::benchmark::State& state) {
2722 BM_GraphViewTHasRegularFanin<GraphView, true>(state);
2723 }
2724
BM_MutableGraphViewHasRegularFaninFirst(::testing::benchmark::State & state)2725 void BM_MutableGraphViewHasRegularFaninFirst(
2726 ::testing::benchmark::State& state) {
2727 BM_GraphViewTHasRegularFanin<MutableGraphView, false>(state);
2728 }
2729
BM_MutableGraphViewHasRegularFaninLast(::testing::benchmark::State & state)2730 void BM_MutableGraphViewHasRegularFaninLast(
2731 ::testing::benchmark::State& state) {
2732 BM_GraphViewTHasRegularFanin<MutableGraphView, true>(state);
2733 }
2734
2735 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninFirst);
2736 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninLast);
2737 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninFirst);
2738 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninLast);
2739
2740 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasControllingFanin(::testing::benchmark::State & state)2741 inline void BM_GraphViewTHasControllingFanin(
2742 ::testing::benchmark::State& state) {
2743 const int num_fanins = state.range(0);
2744 const int num_fanouts = state.range(1);
2745
2746 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2747 num_fanins, num_fanouts, num_fanins, num_fanouts,
2748 /*fanout_unique_index=*/true);
2749 Status s;
2750 GraphViewT graph_view(&graph_def, &s);
2751 const int index = IsLast ? num_fanouts - 1 : 0;
2752 auto* node = graph_view.GetNode(absl::StrFormat("control_out%05d", index));
2753 auto* fanin = graph_view.GetNode("node");
2754
2755 for (auto i : state) {
2756 node->HasFanin({&graph_view, fanin->node_index(), Graph::kControlSlot});
2757 }
2758 }
2759
BM_GraphViewHasControllingFaninFirst(::testing::benchmark::State & state)2760 void BM_GraphViewHasControllingFaninFirst(::testing::benchmark::State& state) {
2761 BM_GraphViewTHasControllingFanin<GraphView, false>(state);
2762 }
2763
BM_GraphViewHasControllingFaninLast(::testing::benchmark::State & state)2764 void BM_GraphViewHasControllingFaninLast(::testing::benchmark::State& state) {
2765 BM_GraphViewTHasControllingFanin<GraphView, true>(state);
2766 }
2767
BM_MutableGraphViewHasControllingFaninFirst(::testing::benchmark::State & state)2768 void BM_MutableGraphViewHasControllingFaninFirst(
2769 ::testing::benchmark::State& state) {
2770 BM_GraphViewTHasControllingFanin<MutableGraphView, false>(state);
2771 }
2772
BM_MutableGraphViewHasControllingFaninLast(::testing::benchmark::State & state)2773 void BM_MutableGraphViewHasControllingFaninLast(
2774 ::testing::benchmark::State& state) {
2775 BM_GraphViewTHasControllingFanin<MutableGraphView, true>(state);
2776 }
2777
2778 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninFirst);
2779 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninLast);
2780 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninFirst);
2781 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninLast);
2782
2783 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasRegularFanout(::testing::benchmark::State & state)2784 inline void BM_GraphViewTHasRegularFanout(::testing::benchmark::State& state) {
2785 const int num_fanins = state.range(0);
2786 const int num_fanouts = state.range(1);
2787
2788 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2789 num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
2790 /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
2791 Status s;
2792 GraphViewT graph_view(&graph_def, &s);
2793 const int index = IsLast ? num_fanins - 1 : 0;
2794 auto* node = graph_view.GetNode(absl::StrFormat("in%05d", index));
2795 auto* fanout = graph_view.GetNode("node");
2796
2797 for (auto i : state) {
2798 node->HasFanout({&graph_view, fanout->node_index(), index});
2799 }
2800 }
2801
BM_GraphViewHasRegularFanoutFirst(::testing::benchmark::State & state)2802 void BM_GraphViewHasRegularFanoutFirst(::testing::benchmark::State& state) {
2803 BM_GraphViewTHasRegularFanout<GraphView, false>(state);
2804 }
2805
BM_GraphViewHasRegularFanoutLast(::testing::benchmark::State & state)2806 void BM_GraphViewHasRegularFanoutLast(::testing::benchmark::State& state) {
2807 BM_GraphViewTHasRegularFanout<GraphView, true>(state);
2808 }
2809
BM_MutableGraphViewHasRegularFanoutFirst(::testing::benchmark::State & state)2810 void BM_MutableGraphViewHasRegularFanoutFirst(
2811 ::testing::benchmark::State& state) {
2812 BM_GraphViewTHasRegularFanout<MutableGraphView, false>(state);
2813 }
2814
BM_MutableGraphViewHasRegularFanoutLast(::testing::benchmark::State & state)2815 void BM_MutableGraphViewHasRegularFanoutLast(
2816 ::testing::benchmark::State& state) {
2817 BM_GraphViewTHasRegularFanout<MutableGraphView, true>(state);
2818 }
2819
2820 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutFirst);
2821 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutLast);
2822 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutFirst);
2823 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutLast);
2824
2825 template <typename GraphViewT, bool IsLast>
BM_GraphViewTHasControlledFanout(::testing::benchmark::State & state)2826 inline void BM_GraphViewTHasControlledFanout(
2827 ::testing::benchmark::State& state) {
2828 const int num_fanins = state.range(0);
2829 const int num_fanouts = state.range(1);
2830
2831 GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
2832 num_fanins, num_fanouts, num_fanins, num_fanouts,
2833 /*fanout_unique_index=*/false);
2834 Status s;
2835 GraphViewT graph_view(&graph_def, &s);
2836 const int index = IsLast ? num_fanins - 1 : 0;
2837 auto* node = graph_view.GetNode(absl::StrFormat("control_in%05d", index));
2838 auto* fanout = graph_view.GetNode("node");
2839
2840 for (auto i : state) {
2841 node->HasFanout({&graph_view, fanout->node_index(), Graph::kControlSlot});
2842 }
2843 }
2844
BM_GraphViewHasControlledFanoutFirst(::testing::benchmark::State & state)2845 void BM_GraphViewHasControlledFanoutFirst(::testing::benchmark::State& state) {
2846 BM_GraphViewTHasControlledFanout<GraphView, false>(state);
2847 }
2848
BM_GraphViewHasControlledFanoutLast(::testing::benchmark::State & state)2849 void BM_GraphViewHasControlledFanoutLast(::testing::benchmark::State& state) {
2850 BM_GraphViewTHasControlledFanout<GraphView, true>(state);
2851 }
2852
BM_MutableGraphViewHasControlledFanoutFirst(::testing::benchmark::State & state)2853 void BM_MutableGraphViewHasControlledFanoutFirst(
2854 ::testing::benchmark::State& state) {
2855 BM_GraphViewTHasControlledFanout<MutableGraphView, false>(state);
2856 }
2857
BM_MutableGraphViewHasControlledFanoutLast(::testing::benchmark::State & state)2858 void BM_MutableGraphViewHasControlledFanoutLast(
2859 ::testing::benchmark::State& state) {
2860 BM_GraphViewTHasControlledFanout<MutableGraphView, true>(state);
2861 }
2862
2863 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutFirst);
2864 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast);
2865 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst);
2866 RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast);
2867
BM_SortTopologically(::testing::benchmark::State & state)2868 void BM_SortTopologically(::testing::benchmark::State& state) {
2869 const int size = state.range(0);
2870
2871 GraphDef graph = test::CreateRandomGraph(size);
2872 Status status;
2873 MutableGraphView graph_view(&graph, &status);
2874 TF_ASSERT_OK(status);
2875
2876 for (auto i : state) {
2877 TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
2878 }
2879 }
2880
2881 RUN_NUM_NODE_BENCHMARK(BM_SortTopologically);
2882
2883 } // namespace
2884 } // namespace utils
2885 } // namespace grappler
2886 } // namespace tensorflow
2887