xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/graphcycles/graphcycles_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A test for the GraphCycles interface.
17 
18 #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"
19 
20 #include <optional>
21 #include <random>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/platform/test_benchmark.h"
28 
// The tests below model a GraphCycles object with a plain node vector and an
// edge vector, then check that both representations agree.

// Node ids currently present in the model graph.
using Nodes = std::vector<int>;

// A directed edge `from -> to` in the model graph.
struct Edge {
  int from;
  int to;
};

// Edges currently present in the model graph.
using Edges = std::vector<Edge>;
38 
39 // Return whether "to" is reachable from "from".
IsReachable(Edges * edges,int from,int to,absl::flat_hash_set<int> * seen)40 static bool IsReachable(Edges *edges, int from, int to,
41                         absl::flat_hash_set<int> *seen) {
42   seen->insert(from);  // we are investigating "from"; don't do it again
43   if (from == to) return true;
44   for (int i = 0; i != edges->size(); i++) {
45     Edge *edge = &(*edges)[i];
46     if (edge->from == from) {
47       if (edge->to == to) {  // success via edge directly
48         return true;
49       } else if (seen->find(edge->to) == seen->end() &&  // success via edge
50                  IsReachable(edges, edge->to, to, seen)) {
51         return true;
52       }
53     }
54   }
55   return false;
56 }
57 
PrintNodes(Nodes * nodes)58 static void PrintNodes(Nodes *nodes) {
59   LOG(INFO) << "NODES (" << nodes->size() << ")";
60   for (int i = 0; i != nodes->size(); i++) {
61     LOG(INFO) << (*nodes)[i];
62   }
63 }
64 
PrintEdges(Edges * edges)65 static void PrintEdges(Edges *edges) {
66   LOG(INFO) << "EDGES (" << edges->size() << ")";
67   for (int i = 0; i != edges->size(); i++) {
68     int a = (*edges)[i].from;
69     int b = (*edges)[i].to;
70     LOG(INFO) << a << " " << b;
71   }
72   LOG(INFO) << "---";
73 }
74 
PrintGCEdges(Nodes * nodes,tensorflow::GraphCycles * gc)75 static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) {
76   LOG(INFO) << "GC EDGES";
77   for (int i = 0; i != nodes->size(); i++) {
78     for (int j = 0; j != nodes->size(); j++) {
79       int a = (*nodes)[i];
80       int b = (*nodes)[j];
81       if (gc->HasEdge(a, b)) {
82         LOG(INFO) << a << " " << b;
83       }
84     }
85   }
86   LOG(INFO) << "---";
87 }
88 
PrintTransitiveClosure(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)89 static void PrintTransitiveClosure(Nodes *nodes, Edges *edges,
90                                    tensorflow::GraphCycles *gc) {
91   LOG(INFO) << "Transitive closure";
92   for (int i = 0; i != nodes->size(); i++) {
93     for (int j = 0; j != nodes->size(); j++) {
94       int a = (*nodes)[i];
95       int b = (*nodes)[j];
96       absl::flat_hash_set<int> seen;
97       if (IsReachable(edges, a, b, &seen)) {
98         LOG(INFO) << a << " " << b;
99       }
100     }
101   }
102   LOG(INFO) << "---";
103 }
104 
PrintGCTransitiveClosure(Nodes * nodes,tensorflow::GraphCycles * gc)105 static void PrintGCTransitiveClosure(Nodes *nodes,
106                                      tensorflow::GraphCycles *gc) {
107   LOG(INFO) << "GC Transitive closure";
108   for (int i = 0; i != nodes->size(); i++) {
109     for (int j = 0; j != nodes->size(); j++) {
110       int a = (*nodes)[i];
111       int b = (*nodes)[j];
112       if (gc->IsReachable(a, b)) {
113         LOG(INFO) << a << " " << b;
114       }
115     }
116   }
117   LOG(INFO) << "---";
118 }
119 
CheckTransitiveClosure(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)120 static void CheckTransitiveClosure(Nodes *nodes, Edges *edges,
121                                    tensorflow::GraphCycles *gc) {
122   absl::flat_hash_set<int> seen;
123   for (int i = 0; i != nodes->size(); i++) {
124     for (int j = 0; j != nodes->size(); j++) {
125       seen.clear();
126       int a = (*nodes)[i];
127       int b = (*nodes)[j];
128       bool gc_reachable = gc->IsReachable(a, b);
129       CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b));
130       bool reachable = IsReachable(edges, a, b, &seen);
131       if (gc_reachable != reachable) {
132         PrintEdges(edges);
133         PrintGCEdges(nodes, gc);
134         PrintTransitiveClosure(nodes, edges, gc);
135         PrintGCTransitiveClosure(nodes, gc);
136         LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable "
137                    << reachable << " a " << a << " b " << b;
138       }
139     }
140   }
141 }
142 
CheckEdges(Nodes * nodes,Edges * edges,tensorflow::GraphCycles * gc)143 static void CheckEdges(Nodes *nodes, Edges *edges,
144                        tensorflow::GraphCycles *gc) {
145   int count = 0;
146   for (int i = 0; i != edges->size(); i++) {
147     int a = (*edges)[i].from;
148     int b = (*edges)[i].to;
149     if (!gc->HasEdge(a, b)) {
150       PrintEdges(edges);
151       PrintGCEdges(nodes, gc);
152       LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")";
153     }
154   }
155   for (int i = 0; i != nodes->size(); i++) {
156     for (int j = 0; j != nodes->size(); j++) {
157       int a = (*nodes)[i];
158       int b = (*nodes)[j];
159       if (gc->HasEdge(a, b)) {
160         count++;
161       }
162     }
163   }
164   if (count != edges->size()) {
165     PrintEdges(edges);
166     PrintGCEdges(nodes, gc);
167     LOG(FATAL) << "edges->size() " << edges->size() << "  count " << count;
168   }
169 }
170 
171 // Returns the index of a randomly chosen node in *nodes.
172 // Requires *nodes be non-empty.
RandomNode(std::mt19937 * rnd,Nodes * nodes)173 static int RandomNode(std::mt19937 *rnd, Nodes *nodes) {
174   std::uniform_int_distribution<int> distribution(0, nodes->size() - 1);
175   return distribution(*rnd);
176 }
177 
178 // Returns the index of a randomly chosen edge in *edges.
179 // Requires *edges be non-empty.
RandomEdge(std::mt19937 * rnd,Edges * edges)180 static int RandomEdge(std::mt19937 *rnd, Edges *edges) {
181   std::uniform_int_distribution<int> distribution(0, edges->size() - 1);
182   return distribution(*rnd);
183 }
184 
185 // Returns the index of edge (from, to) in *edges or -1 if it is not in *edges.
EdgeIndex(Edges * edges,int from,int to)186 static int EdgeIndex(Edges *edges, int from, int to) {
187   int i = 0;
188   while (i != edges->size() &&
189          ((*edges)[i].from != from || (*edges)[i].to != to)) {
190     i++;
191   }
192   return i == edges->size() ? -1 : i;
193 }
194 
TEST(GraphCycles,RandomizedTest)195 TEST(GraphCycles, RandomizedTest) {
196   Nodes nodes;
197   Edges edges;  // from, to
198   tensorflow::GraphCycles graph_cycles;
199   static const int kMaxNodes = 7;     // use <= 7 nodes to keep test short
200   static const int kDataOffset = 17;  // an offset to the node-specific data
201   int n = 100000;
202   int op = 0;
203   std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1);
204 
205   for (int iter = 0; iter != n; iter++) {
206     if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n;
207 
208     if (VLOG_IS_ON(3)) {
209       LOG(INFO) << "===============";
210       LOG(INFO) << "last op " << op;
211       PrintNodes(&nodes);
212       PrintEdges(&edges);
213       PrintGCEdges(&nodes, &graph_cycles);
214     }
215     for (int i = 0; i != nodes.size(); i++) {
216       ASSERT_EQ(reinterpret_cast<intptr_t>(graph_cycles.GetNodeData(i)),
217                 i + kDataOffset)
218           << " node " << i;
219     }
220     CheckEdges(&nodes, &edges, &graph_cycles);
221     CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
222     std::uniform_int_distribution<int> distribution(0, 5);
223     op = distribution(rnd);
224     switch (op) {
225       case 0:  // Add a node
226         if (nodes.size() < kMaxNodes) {
227           int new_node = graph_cycles.NewNode();
228           ASSERT_NE(-1, new_node);
229           VLOG(1) << "adding node " << new_node;
230           ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
231           graph_cycles.SetNodeData(
232               new_node, reinterpret_cast<void *>(
233                             static_cast<intptr_t>(new_node + kDataOffset)));
234           ASSERT_GE(new_node, 0);
235           for (int i = 0; i != nodes.size(); i++) {
236             ASSERT_NE(nodes[i], new_node);
237           }
238           nodes.push_back(new_node);
239         }
240         break;
241 
242       case 1:  // Remove a node
243         if (!nodes.empty()) {
244           int node_index = RandomNode(&rnd, &nodes);
245           int node = nodes[node_index];
246           nodes[node_index] = nodes.back();
247           nodes.pop_back();
248           VLOG(1) << "removing node " << node;
249           graph_cycles.RemoveNode(node);
250           int i = 0;
251           while (i != edges.size()) {
252             if (edges[i].from == node || edges[i].to == node) {
253               edges[i] = edges.back();
254               edges.pop_back();
255             } else {
256               i++;
257             }
258           }
259         }
260         break;
261 
262       case 2:  // Add an edge
263         if (!nodes.empty()) {
264           int from = RandomNode(&rnd, &nodes);
265           int to = RandomNode(&rnd, &nodes);
266           if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
267             if (graph_cycles.InsertEdge(nodes[from], nodes[to])) {
268               Edge new_edge;
269               new_edge.from = nodes[from];
270               new_edge.to = nodes[to];
271               edges.push_back(new_edge);
272             } else {
273               absl::flat_hash_set<int> seen;
274               ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen))
275                   << "Edge " << nodes[to] << "->" << nodes[from];
276             }
277           }
278         }
279         break;
280 
281       case 3:  // Remove an edge
282         if (!edges.empty()) {
283           int i = RandomEdge(&rnd, &edges);
284           int from = edges[i].from;
285           int to = edges[i].to;
286           ASSERT_EQ(i, EdgeIndex(&edges, from, to));
287           edges[i] = edges.back();
288           edges.pop_back();
289           ASSERT_EQ(-1, EdgeIndex(&edges, from, to));
290           VLOG(1) << "removing edge " << from << " " << to;
291           graph_cycles.RemoveEdge(from, to);
292         }
293         break;
294 
295       case 4:  // Check a path
296         if (!nodes.empty()) {
297           int from = RandomNode(&rnd, &nodes);
298           int to = RandomNode(&rnd, &nodes);
299           int32_t path[2 * kMaxNodes];
300           int path_len = graph_cycles.FindPath(nodes[from], nodes[to],
301                                                2 * kMaxNodes, path);
302           absl::flat_hash_set<int> seen;
303           bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen);
304           bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]);
305           ASSERT_EQ(gc_reachable,
306                     graph_cycles.IsReachableNonConst(nodes[from], nodes[to]));
307           ASSERT_EQ(path_len != 0, reachable);
308           ASSERT_EQ(path_len != 0, gc_reachable);
309           // In the following line, we add one because a node can appear
310           // twice, if the path is from that node to itself, perhaps via
311           // every other node.
312           ASSERT_LE(path_len, kMaxNodes + 1);
313           if (path_len != 0) {
314             ASSERT_EQ(nodes[from], path[0]);
315             ASSERT_EQ(nodes[to], path[path_len - 1]);
316             for (int i = 1; i < path_len; i++) {
317               ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i]));
318               ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i]));
319             }
320           }
321         }
322         break;
323 
324       case 5:  // Check invariants
325         CHECK(graph_cycles.CheckInvariants());
326         break;
327 
328       default:
329         LOG(FATAL);
330     }
331 
332     // Very rarely, test graph expansion by adding then removing many nodes.
333     std::bernoulli_distribution rarely(1.0 / 1024.0);
334     if (rarely(rnd)) {
335       VLOG(3) << "Graph expansion";
336       CheckEdges(&nodes, &edges, &graph_cycles);
337       CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
338       for (int i = 0; i != 256; i++) {
339         int new_node = graph_cycles.NewNode();
340         ASSERT_NE(-1, new_node);
341         VLOG(1) << "adding node " << new_node;
342         ASSERT_GE(new_node, 0);
343         ASSERT_EQ(nullptr, graph_cycles.GetNodeData(new_node));
344         graph_cycles.SetNodeData(
345             new_node, reinterpret_cast<void *>(
346                           static_cast<intptr_t>(new_node + kDataOffset)));
347         for (int j = 0; j != nodes.size(); j++) {
348           ASSERT_NE(nodes[j], new_node);
349         }
350         nodes.push_back(new_node);
351       }
352       for (int i = 0; i != 256; i++) {
353         ASSERT_GT(nodes.size(), 0);
354         int node_index = RandomNode(&rnd, &nodes);
355         int node = nodes[node_index];
356         nodes[node_index] = nodes.back();
357         nodes.pop_back();
358         VLOG(1) << "removing node " << node;
359         graph_cycles.RemoveNode(node);
360         int j = 0;
361         while (j != edges.size()) {
362           if (edges[j].from == node || edges[j].to == node) {
363             edges[j] = edges.back();
364             edges.pop_back();
365           } else {
366             j++;
367           }
368         }
369       }
370       CHECK(graph_cycles.CheckInvariants());
371     }
372   }
373 }
374 
375 class GraphCyclesTest : public ::testing::Test {
376  public:
377   tensorflow::GraphCycles g_;
378 
379   // Test relies on ith NewNode() call returning Node numbered i
GraphCyclesTest()380   GraphCyclesTest() {
381     for (int i = 0; i < 100; i++) {
382       CHECK_EQ(i, g_.NewNode());
383     }
384     CHECK(g_.CheckInvariants());
385   }
386 
AddEdge(int x,int y)387   bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
388 
AddMultiples()389   void AddMultiples() {
390     // For every node x > 0: add edge to 2*x, 3*x
391     for (int x = 1; x < 25; x++) {
392       EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
393       EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
394     }
395     CHECK(g_.CheckInvariants());
396   }
397 
Path(int x,int y)398   std::string Path(int x, int y) {
399     static const int kPathSize = 5;
400     int32_t path[kPathSize];
401     int np = g_.FindPath(x, y, kPathSize, path);
402     std::string result;
403     for (int i = 0; i < np; i++) {
404       if (i >= kPathSize) {
405         result += " ...";
406         break;
407       }
408       if (!result.empty()) result.push_back(' ');
409       char buf[20];
410       snprintf(buf, sizeof(buf), "%d", path[i]);
411       result += buf;
412     }
413     return result;
414   }
415 };
416 
// The x -> 2x / x -> 3x edges from AddMultiples() always point to larger ids,
// so they must all be accepted and leave the structure consistent.
TEST_F(GraphCyclesTest, NoCycle) {
  AddMultiples();
  CHECK(g_.CheckInvariants());
}
421 
// AddMultiples() adds 4 -> 8 (as 4 -> 2*4), so the back edge 8 -> 4 would
// close a cycle and must be rejected; the path 4 -> 8 is the direct edge.
TEST_F(GraphCyclesTest, SimpleCycle) {
  AddMultiples();

  EXPECT_FALSE(AddEdge(8, 4));
  EXPECT_EQ("4 8", Path(4, 8));

  CHECK(g_.CheckInvariants());
}
428 
// Once 16 -> 9 is added, 9 is reachable from 2 along 2 -> 4 -> 8 -> 16 -> 9,
// so the back edge 9 -> 2 must be rejected and Path() reports that chain.
TEST_F(GraphCyclesTest, IndirectCycle) {
  AddMultiples();

  EXPECT_TRUE(AddEdge(16, 9));
  CHECK(g_.CheckInvariants());

  EXPECT_FALSE(AddEdge(9, 2));
  EXPECT_EQ("2 4 8 16 9", Path(2, 9));
  CHECK(g_.CheckInvariants());
}
437 
// Builds the chain 2 -> 4 -> 6 -> 8 -> 10 -> 12. Closing it with 12 -> 2 must
// fail, and the six-node path is truncated by Path() after five ids.
TEST_F(GraphCyclesTest, LongPath) {
  for (int from = 2; from < 12; from += 2) {
    ASSERT_TRUE(AddEdge(from, from + 2)) << from;
  }
  ASSERT_FALSE(AddEdge(12, 2));
  EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12));
  CHECK(g_.CheckInvariants());
}
448 
// Builds the chain 1 -> 2 -> 3 -> 4 -> 5, then removes node 3. With the chain
// broken, 5 -> 1 no longer closes a cycle and must be accepted.
TEST_F(GraphCyclesTest, RemoveNode) {
  for (int from = 1; from < 5; ++from) {
    ASSERT_TRUE(AddEdge(from, from + 1)) << from;
  }
  g_.RemoveNode(3);
  ASSERT_TRUE(AddEdge(5, 1));
}
457 
// Connects every node i in [0, 50) to each of i+1 .. i+49, giving a dense
// forward-only graph whose largest target is 98.
TEST_F(GraphCyclesTest, ManyEdges) {
  const int kN = 50;
  for (int from = 0; from < kN; ++from) {
    for (int offset = 1; offset < kN; ++offset) {
      ASSERT_TRUE(AddEdge(from, from + offset));
    }
  }
  CHECK(g_.CheckInvariants());

  // Node 2N-1 (99) has no edges yet, so 99 -> 0 cannot close a cycle.
  ASSERT_TRUE(AddEdge(2 * kN - 1, 0));
  CHECK(g_.CheckInvariants());

  // 9 -> 10 already exists, so 10 -> 9 would close a cycle.
  ASSERT_FALSE(AddEdge(10, 9));
  CHECK(g_.CheckInvariants());
}
471 
// Builds the DAG 1->2, 1->3, 2->3, 2->4, 3->4 and contracts edges one by one,
// checking which contractions are refused and which node id survives.
TEST_F(GraphCyclesTest, ContractEdge) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(1, 3));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(3, 4));

  // 1 -> 3 also has the longer path 1 -> 2 -> 3; this contraction must be
  // refused and must leave the edge in place.
  EXPECT_FALSE(g_.ContractEdge(1, 3).has_value());
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(1, 3));

  // Node (2) has more edges, so it is the id that survives.
  EXPECT_EQ(g_.ContractEdge(1, 2).value(), 2);
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(2, 3));
  EXPECT_TRUE(g_.HasEdge(2, 4));
  EXPECT_TRUE(g_.HasEdge(3, 4));

  // Node (2) has more edges, so it is the id that survives.
  EXPECT_EQ(g_.ContractEdge(2, 3).value(), 2);
  CHECK(g_.CheckInvariants());
  EXPECT_TRUE(g_.HasEdge(2, 4));
}
495 
// On the DAG 1->2, 1->3, 2->3, 2->4, 3->4, the edges 1->3 and 2->4 (each of
// which also has a longer path between its endpoints) must not be
// contractible, while 1->2, 2->3 and 3->4 must be.
TEST_F(GraphCyclesTest, CanContractEdge) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(1, 3));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(3, 4));

  // Edges with an alternative route.
  EXPECT_FALSE(g_.CanContractEdge(1, 3));
  EXPECT_FALSE(g_.CanContractEdge(2, 4));

  // Edges that are the only route between their endpoints.
  EXPECT_TRUE(g_.CanContractEdge(1, 2));
  EXPECT_TRUE(g_.CanContractEdge(2, 3));
  EXPECT_TRUE(g_.CanContractEdge(3, 4));
}
509 
BM_StressTest(::testing::benchmark::State & state)510 static void BM_StressTest(::testing::benchmark::State &state) {
511   const int num_nodes = state.range(0);
512 
513   for (auto s : state) {
514     tensorflow::GraphCycles g;
515     int32_t *nodes = new int32_t[num_nodes];
516     for (int i = 0; i < num_nodes; i++) {
517       nodes[i] = g.NewNode();
518     }
519     for (int i = 0; i < num_nodes; i++) {
520       int end = std::min(num_nodes, i + 5);
521       for (int j = i + 1; j < end; j++) {
522         if (nodes[i] >= 0 && nodes[j] >= 0) {
523           CHECK(g.InsertEdge(nodes[i], nodes[j]));
524         }
525       }
526     }
527     delete[] nodes;
528   }
529 }
530 BENCHMARK(BM_StressTest)->Range(2048, 1048576);
531 
BM_ContractEdge(::testing::benchmark::State & state)532 static void BM_ContractEdge(::testing::benchmark::State &state) {
533   const int num_nodes = state.range(0);
534 
535   for (auto s : state) {
536     state.PauseTiming();
537     tensorflow::GraphCycles g;
538     std::vector<int32_t> nodes;
539     nodes.reserve(num_nodes);
540     for (int i = 0; i < num_nodes; i++) {
541       nodes.push_back(g.NewNode());
542     }
543     // All edges point toward the last one.
544     for (int i = 0; i < num_nodes - 1; ++i) {
545       g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
546     }
547 
548     state.ResumeTiming();
549     int node = num_nodes - 1;
550     for (int i = 0; i < num_nodes - 1; ++i) {
551       node = g.ContractEdge(nodes[i], node).value();
552     }
553   }
554 }
555 BENCHMARK(BM_ContractEdge)->Arg(1000)->Arg(10000);
556