1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A test for the GraphCycles interface.
17
18 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
19
20 #include <optional>
21 #include <random>
22 #include <vector>
23
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28
29 // We emulate a GraphCycles object with a node vector and an edge vector.
30 // We then compare the two implementations.
31
// Reference model types: the test mirrors the GraphCycles contents in plain
// vectors so the two representations can be compared.

// Ids of the nodes currently alive in the model.
using Nodes = std::vector<int>;

// A directed edge between two node ids.
struct Edge {
  int from;  // id of the source node
  int to;    // id of the destination node
};

// All edges currently present in the model.
using Edges = std::vector<Edge>;
38
39 // Return whether "to" is reachable from "from".
IsReachable(Edges * edges,int from,int to,absl::flat_hash_set<int> * seen)40 static bool IsReachable(Edges *edges, int from, int to,
41 absl::flat_hash_set<int> *seen) {
42 seen->insert(from); // we are investigating "from"; don't do it again
43 if (from == to) return true;
44 for (int i = 0; i != edges->size(); i++) {
45 Edge *edge = &(*edges)[i];
46 if (edge->from == from) {
47 if (edge->to == to) { // success via edge directly
48 return true;
49 } else if (seen->find(edge->to) == seen->end() && // success via edge
50 IsReachable(edges, edge->to, to, seen)) {
51 return true;
52 }
53 }
54 }
55 return false;
56 }
57
PrintNodes(Nodes * nodes)58 static void PrintNodes(Nodes *nodes) {
59 LOG(INFO) << "NODES (" << nodes->size() << ")";
60 for (int i = 0; i != nodes->size(); i++) {
61 LOG(INFO) << (*nodes)[i];
62 }
63 }
64
PrintEdges(Edges * edges)65 static void PrintEdges(Edges *edges) {
66 LOG(INFO) << "EDGES (" << edges->size() << ")";
67 for (int i = 0; i != edges->size(); i++) {
68 int a = (*edges)[i].from;
69 int b = (*edges)[i].to;
70 LOG(INFO) << a << " " << b;
71 }
72 LOG(INFO) << "---";
73 }
74
PrintGCEdges(Nodes * nodes,tensorflow::GraphCycles * gc)75 static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) {
76 LOG(INFO) << "GC EDGES";
77 for (int i = 0; i != nodes->size(); i++) {
78 for (int j = 0; j != nodes->size(); j++) {
79 int a = (*nodes)[i];
80 int b = (*nodes)[j];
81 if (gc->HasEdge(a, b)) {
82 LOG(INFO) << a << " " << b;
83 }
84 }
85 }
86 LOG(INFO) << "---";
87 }
88
PrintTransitiveClosure(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)89 static void PrintTransitiveClosure(Nodes *nodes, Edges *edges,
90 tensorflow::GraphCycles *gc) {
91 LOG(INFO) << "Transitive closure";
92 for (int i = 0; i != nodes->size(); i++) {
93 for (int j = 0; j != nodes->size(); j++) {
94 int a = (*nodes)[i];
95 int b = (*nodes)[j];
96 absl::flat_hash_set<int> seen;
97 if (IsReachable(edges, a, b, &seen)) {
98 LOG(INFO) << a << " " << b;
99 }
100 }
101 }
102 LOG(INFO) << "---";
103 }
104
PrintGCTransitiveClosure(Nodes * nodes,tensorflow::GraphCycles * gc)105 static void PrintGCTransitiveClosure(Nodes *nodes,
106 tensorflow::GraphCycles *gc) {
107 LOG(INFO) << "GC Transitive closure";
108 for (int i = 0; i != nodes->size(); i++) {
109 for (int j = 0; j != nodes->size(); j++) {
110 int a = (*nodes)[i];
111 int b = (*nodes)[j];
112 if (gc->IsReachable(a, b)) {
113 LOG(INFO) << a << " " << b;
114 }
115 }
116 }
117 LOG(INFO) << "---";
118 }
119
CheckTransitiveClosure(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)120 static void CheckTransitiveClosure(Nodes *nodes, Edges *edges,
121 tensorflow::GraphCycles *gc) {
122 absl::flat_hash_set<int> seen;
123 for (int i = 0; i != nodes->size(); i++) {
124 for (int j = 0; j != nodes->size(); j++) {
125 seen.clear();
126 int a = (*nodes)[i];
127 int b = (*nodes)[j];
128 bool gc_reachable = gc->IsReachable(a, b);
129 CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b));
130 bool reachable = IsReachable(edges, a, b, &seen);
131 if (gc_reachable != reachable) {
132 PrintEdges(edges);
133 PrintGCEdges(nodes, gc);
134 PrintTransitiveClosure(nodes, edges, gc);
135 PrintGCTransitiveClosure(nodes, gc);
136 LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable "
137 << reachable << " a " << a << " b " << b;
138 }
139 }
140 }
141 }
142
CheckEdges(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)143 static void CheckEdges(Nodes *nodes, Edges *edges,
144 tensorflow::GraphCycles *gc) {
145 int count = 0;
146 for (int i = 0; i != edges->size(); i++) {
147 int a = (*edges)[i].from;
148 int b = (*edges)[i].to;
149 if (!gc->HasEdge(a, b)) {
150 PrintEdges(edges);
151 PrintGCEdges(nodes, gc);
152 LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")";
153 }
154 }
155 for (int i = 0; i != nodes->size(); i++) {
156 for (int j = 0; j != nodes->size(); j++) {
157 int a = (*nodes)[i];
158 int b = (*nodes)[j];
159 if (gc->HasEdge(a, b)) {
160 count++;
161 }
162 }
163 }
164 if (count != edges->size()) {
165 PrintEdges(edges);
166 PrintGCEdges(nodes, gc);
167 LOG(FATAL) << "edges->size() " << edges->size() << " count " << count;
168 }
169 }
170
171 // Returns the index of a randomly chosen node in *nodes.
172 // Requires *nodes be non-empty.
RandomNode(std::mt19937 * rnd,Nodes * nodes)173 static int RandomNode(std::mt19937 *rnd, Nodes *nodes) {
174 std::uniform_int_distribution<int> distribution(0, nodes->size() - 1);
175 return distribution(*rnd);
176 }
177
178 // Returns the index of a randomly chosen edge in *edges.
179 // Requires *edges be non-empty.
RandomEdge(std::mt19937 * rnd,Edges * edges)180 static int RandomEdge(std::mt19937 *rnd, Edges *edges) {
181 std::uniform_int_distribution<int> distribution(0, edges->size() - 1);
182 return distribution(*rnd);
183 }
184
185 // Returns the index of edge (from, to) in *edges or -1 if it is not in *edges.
EdgeIndex(Edges * edges,int from,int to)186 static int EdgeIndex(Edges *edges, int from, int to) {
187 int i = 0;
188 while (i != edges->size() &&
189 ((*edges)[i].from != from || (*edges)[i].to != to)) {
190 i++;
191 }
192 return i == edges->size() ? -1 : i;
193 }
194
// Applies the same random sequence of operations to a GraphCycles instance
// and to the vector-based reference model (nodes/edges), cross-checking the
// two before every operation.  Each node carries the payload
// (node id + kDataOffset) so that node data can be verified as well.
TEST(GraphCycles, RandomizedTest) {
  Nodes nodes;
  Edges edges;  // from, to
  tensorflow::GraphCycles graph_cycles;
  static const int kMaxNodes = 7;     // use <= 7 nodes to keep test short
  static const int kDataOffset = 17;  // an offset to the node-specific data
  int n = 100000;
  int op = 0;
  std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1);

  for (int iter = 0; iter != n; iter++) {
    if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n;

    if (VLOG_IS_ON(3)) {
      LOG(INFO) << "===============";
      LOG(INFO) << "last op " << op;
      PrintNodes(&nodes);
      PrintEdges(&edges);
      PrintGCEdges(&nodes, &graph_cycles);
    }
    // Every live node must still carry the payload stored when it was
    // created.  The lookup has to use the node id nodes[i]; the position i is
    // not a node id once nodes have been removed or the graph has been
    // expanded, so indexing by i would inspect the wrong nodes.
    for (int i = 0; i != nodes.size(); i++) {
      const int node = nodes[i];
      ASSERT_EQ(reinterpret_cast<intptr_t>(graph_cycles.GetNodeData(node)),
                node + kDataOffset)
          << " node " << node;
    }
    CheckEdges(&nodes, &edges, &graph_cycles);
    CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
    std::uniform_int_distribution<int> distribution(0, 5);
    op = distribution(rnd);
    switch (op) {
      case 0:  // Add a node
        if (nodes.size() < kMaxNodes) {
          int new_node = graph_cycles.NewNode();
          ASSERT_NE(-1, new_node);
          VLOG(1) << "adding node " << new_node;
          // A new (possibly recycled) id is expected to start without data.
          ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
          graph_cycles.SetNodeData(
              new_node, reinterpret_cast<void *>(
                            static_cast<intptr_t>(new_node + kDataOffset)));
          ASSERT_GE(new_node, 0);
          // The id must not collide with a node that is still alive.
          for (int i = 0; i != nodes.size(); i++) {
            ASSERT_NE(nodes[i], new_node);
          }
          nodes.push_back(new_node);
        }
        break;

      case 1:  // Remove a node
        if (!nodes.empty()) {
          int node_index = RandomNode(&rnd, &nodes);
          int node = nodes[node_index];
          // Swap-with-last removal; the order of "nodes" is irrelevant.
          nodes[node_index] = nodes.back();
          nodes.pop_back();
          VLOG(1) << "removing node " << node;
          graph_cycles.RemoveNode(node);
          // Drop every model edge that touches the removed node.
          int i = 0;
          while (i != edges.size()) {
            if (edges[i].from == node || edges[i].to == node) {
              edges[i] = edges.back();
              edges.pop_back();
            } else {
              i++;
            }
          }
        }
        break;

      case 2:  // Add an edge
        if (!nodes.empty()) {
          int from = RandomNode(&rnd, &nodes);
          int to = RandomNode(&rnd, &nodes);
          if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
            if (graph_cycles.InsertEdge(nodes[from], nodes[to])) {
              Edge new_edge;
              new_edge.from = nodes[from];
              new_edge.to = nodes[to];
              edges.push_back(new_edge);
            } else {
              // A rejected insertion must be explained by an existing path
              // back from "to" to "from" in the model.
              absl::flat_hash_set<int> seen;
              ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen))
                  << "Edge " << nodes[to] << "->" << nodes[from];
            }
          }
        }
        break;

      case 3:  // Remove an edge
        if (!edges.empty()) {
          int i = RandomEdge(&rnd, &edges);
          int from = edges[i].from;
          int to = edges[i].to;
          // The model must hold each edge exactly once.
          ASSERT_EQ(i, EdgeIndex(&edges, from, to));
          edges[i] = edges.back();
          edges.pop_back();
          ASSERT_EQ(-1, EdgeIndex(&edges, from, to));
          VLOG(1) << "removing edge " << from << " " << to;
          graph_cycles.RemoveEdge(from, to);
        }
        break;

      case 4:  // Check a path
        if (!nodes.empty()) {
          int from = RandomNode(&rnd, &nodes);
          int to = RandomNode(&rnd, &nodes);
          int32_t path[2 * kMaxNodes];
          int path_len = graph_cycles.FindPath(nodes[from], nodes[to],
                                               2 * kMaxNodes, path);
          absl::flat_hash_set<int> seen;
          bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen);
          bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]);
          ASSERT_EQ(gc_reachable,
                    graph_cycles.IsReachableNonConst(nodes[from], nodes[to]));
          ASSERT_EQ(path_len != 0, reachable);
          ASSERT_EQ(path_len != 0, gc_reachable);
          // In the following line, we add one because a node can appear
          // twice, if the path is from that node to itself, perhaps via
          // every other node.
          ASSERT_LE(path_len, kMaxNodes + 1);
          if (path_len != 0) {
            ASSERT_EQ(nodes[from], path[0]);
            ASSERT_EQ(nodes[to], path[path_len - 1]);
            // Each hop of the reported path must be an edge in both graphs.
            for (int i = 1; i < path_len; i++) {
              ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i]));
              ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i]));
            }
          }
        }
        break;

      case 5:  // Check invariants
        CHECK(graph_cycles.CheckInvariants());
        break;

      default:
        LOG(FATAL);
    }

    // Very rarely, test graph expansion by adding then removing many nodes.
    std::bernoulli_distribution rarely(1.0 / 1024.0);
    if (rarely(rnd)) {
      VLOG(3) << "Graph expansion";
      CheckEdges(&nodes, &edges, &graph_cycles);
      CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
      for (int i = 0; i != 256; i++) {
        int new_node = graph_cycles.NewNode();
        ASSERT_NE(-1, new_node);
        VLOG(1) << "adding node " << new_node;
        ASSERT_GE(new_node, 0);
        ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
        graph_cycles.SetNodeData(
            new_node, reinterpret_cast<void *>(
                          static_cast<intptr_t>(new_node + kDataOffset)));
        for (int j = 0; j != nodes.size(); j++) {
          ASSERT_NE(nodes[j], new_node);
        }
        nodes.push_back(new_node);
      }
      // Remove as many nodes as were added, chosen at random, so survivors
      // may have ids well above kMaxNodes afterwards.
      for (int i = 0; i != 256; i++) {
        ASSERT_GT(nodes.size(), 0);
        int node_index = RandomNode(&rnd, &nodes);
        int node = nodes[node_index];
        nodes[node_index] = nodes.back();
        nodes.pop_back();
        VLOG(1) << "removing node " << node;
        graph_cycles.RemoveNode(node);
        int j = 0;
        while (j != edges.size()) {
          if (edges[j].from == node || edges[j].to == node) {
            edges[j] = edges.back();
            edges.pop_back();
          } else {
            j++;
          }
        }
      }
      CHECK(graph_cycles.CheckInvariants());
    }
  }
}
374
375 class GraphCyclesTest : public ::testing::Test {
376 public:
377 tensorflow::GraphCycles g_;
378
379 // Test relies on ith NewNode() call returning Node numbered i
GraphCyclesTest()380 GraphCyclesTest() {
381 for (int i = 0; i < 100; i++) {
382 CHECK_EQ(i, g_.NewNode());
383 }
384 CHECK(g_.CheckInvariants());
385 }
386
AddEdge(int x,int y)387 bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
388
AddMultiples()389 void AddMultiples() {
390 // For every node x > 0: add edge to 2*x, 3*x
391 for (int x = 1; x < 25; x++) {
392 EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
393 EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
394 }
395 CHECK(g_.CheckInvariants());
396 }
397
Path(int x,int y)398 std::string Path(int x, int y) {
399 static const int kPathSize = 5;
400 int32_t path[kPathSize];
401 int np = g_.FindPath(x, y, kPathSize, path);
402 std::string result;
403 for (int i = 0; i < np; i++) {
404 if (i >= kPathSize) {
405 result += " ...";
406 break;
407 }
408 if (!result.empty()) result.push_back(' ');
409 char buf[20];
410 snprintf(buf, sizeof(buf), "%d", path[i]);
411 result += buf;
412 }
413 return result;
414 }
415 };
416
// Building the multiples graph (x -> 2x, x -> 3x) must succeed for every edge
// (checked inside AddMultiples) and leave the invariants intact.
TEST_F(GraphCyclesTest, NoCycle) {
  AddMultiples();
  CHECK(g_.CheckInvariants());
}
421
// AddMultiples() inserts 4 -> 8 (x = 4, 2 * x), so the reverse edge 8 -> 4
// would close a two-node cycle and must be rejected.
TEST_F(GraphCyclesTest, SimpleCycle) {
  AddMultiples();
  EXPECT_FALSE(AddEdge(8, 4));
  // The existing path that blocks the insertion.
  EXPECT_EQ("4 8", Path(4, 8));
  CHECK(g_.CheckInvariants());
}
428
// After AddMultiples() the chain 2 -> 4 -> 8 -> 16 exists.  Adding 16 -> 9 is
// expected to succeed, after which 9 -> 2 would close a cycle through that
// chain and must be rejected.
TEST_F(GraphCyclesTest, IndirectCycle) {
  AddMultiples();
  EXPECT_TRUE(AddEdge(16, 9));
  CHECK(g_.CheckInvariants());
  EXPECT_FALSE(AddEdge(9, 2));
  // The path that blocks the insertion; it has exactly five nodes, so Path()
  // shows it in full.
  EXPECT_EQ("2 4 8 16 9", Path(2, 9));
  CHECK(g_.CheckInvariants());
}
437
// Builds the chain 2 -> 4 -> 6 -> 8 -> 10 -> 12 and checks that closing it
// with 12 -> 2 is rejected.
TEST_F(GraphCyclesTest, LongPath) {
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(4, 6));
  ASSERT_TRUE(AddEdge(6, 8));
  ASSERT_TRUE(AddEdge(8, 10));
  ASSERT_TRUE(AddEdge(10, 12));
  ASSERT_FALSE(AddEdge(12, 2));
  // The chain has six nodes but Path() displays at most five, so the expected
  // string ends in " ..." (FindPath is expected to report the full length).
  EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12));
  CHECK(g_.CheckInvariants());
}
448
// Builds the chain 1 -> 2 -> 3 -> 4 -> 5, removes node 3, and then expects
// 5 -> 1 to be accepted: with node 3 gone the chain is presumably broken, so
// the new edge no longer closes a cycle.
TEST_F(GraphCyclesTest, RemoveNode) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(3, 4));
  ASSERT_TRUE(AddEdge(4, 5));
  g_.RemoveNode(3);
  ASSERT_TRUE(AddEdge(5, 1));
}
457
// Inserts a dense set of forward edges and then probes two back edges.
TEST_F(GraphCyclesTest, ManyEdges) {
  const int N = 50;
  // Every node i in [0, N) gets edges to i+1 .. i+N-1; the largest target is
  // 2*N - 2 = 98, which is within the fixture's 100 nodes.
  for (int i = 0; i < N; i++) {
    for (int j = 1; j < N; j++) {
      ASSERT_TRUE(AddEdge(i, i + j));
    }
  }
  CHECK(g_.CheckInvariants());
  // Node 2*N - 1 = 99 has no edges so far, so an edge from it back to node 0
  // is expected to be accepted.
  ASSERT_TRUE(AddEdge(2 * N - 1, 0));
  CHECK(g_.CheckInvariants());
  // 9 -> 10 was inserted above, so 10 -> 9 would close a cycle.
  ASSERT_FALSE(AddEdge(10, 9));
  CHECK(g_.CheckInvariants());
}
471
// Exercises ContractEdge on the small DAG
//   1 -> 2, 1 -> 3, 2 -> 3, 2 -> 4, 3 -> 4.
TEST_F(GraphCyclesTest, ContractEdge) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(1, 3));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(3, 4));

  // Contracting 1 -> 3 is expected to fail (no value returned) and to leave
  // the edge in place; presumably because the other path 1 -> 2 -> 3 would
  // turn into a cycle.
  EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(1, 3));

  // Contracting 1 -> 2 is expected to yield node 2 as the merged node.
  // Node (2) has more edges.
  EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(2, 3));
  EXPECT_TRUE(g_.HasEdge(2, 4));
  EXPECT_TRUE(g_.HasEdge(3, 4));

  // Contracting 2 -> 3 is again expected to yield node 2.
  // Node (2) has more edges.
  EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(2, 4));
}
495
// Exercises CanContractEdge on the same DAG as the ContractEdge test:
//   1 -> 2, 1 -> 3, 2 -> 3, 2 -> 4, 3 -> 4.
TEST_F(GraphCyclesTest, CanContractEdge) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(1, 3));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(3, 4));

  // 1 -> 3 and 2 -> 4 each have an additional two-step path between their
  // endpoints (1 -> 2 -> 3 and 2 -> 3 -> 4); they are expected to be
  // non-contractible.
  EXPECT_FALSE(g_.CanContractEdge(1, 3));
  EXPECT_FALSE(g_.CanContractEdge(2, 4));
  // The remaining edges are the only connection between their endpoints and
  // are expected to be contractible.
  EXPECT_TRUE(g_.CanContractEdge(1, 2));
  EXPECT_TRUE(g_.CanContractEdge(2, 3));
  EXPECT_TRUE(g_.CanContractEdge(3, 4));
}
509
BM_StressTest(::testing::benchmark::State & state)510 static void BM_StressTest(::testing::benchmark::State &state) {
511 const int num_nodes = state.range(0);
512
513 for (auto s : state) {
514 tensorflow::GraphCycles g;
515 int32_t *nodes = new int32_t[num_nodes];
516 for (int i = 0; i < num_nodes; i++) {
517 nodes[i] = g.NewNode();
518 }
519 for (int i = 0; i < num_nodes; i++) {
520 int end = std::min(num_nodes, i + 5);
521 for (int j = i + 1; j < end; j++) {
522 if (nodes[i] >= 0 && nodes[j] >= 0) {
523 CHECK(g.InsertEdge(nodes[i], nodes[j]));
524 }
525 }
526 }
527 delete[] nodes;
528 }
529 }
530 BENCHMARK(BM_StressTest)->Range(2048, 1048576);
531
BM_ContractEdge(::testing::benchmark::State & state)532 static void BM_ContractEdge(::testing::benchmark::State &state) {
533 const int num_nodes = state.range(0);
534
535 for (auto s : state) {
536 state.PauseTiming();
537 tensorflow::GraphCycles g;
538 std::vector<int32_t> nodes;
539 nodes.reserve(num_nodes);
540 for (int i = 0; i < num_nodes; i++) {
541 nodes.push_back(g.NewNode());
542 }
543 // All edges point toward the last one.
544 for (int i = 0; i < num_nodes - 1; ++i) {
545 g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
546 }
547
548 state.ResumeTiming();
549 int node = num_nodes - 1;
550 for (int i = 0; i < num_nodes - 1; ++i) {
551 node = g.ContractEdge(nodes[i], node).value();
552 }
553 }
554 }
555 BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);
556