// xref: /aosp_15_r20/external/stg/scc_test.cc (revision 9e3b08ae94a55201065475453d799e8b1378bea6)
1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2 // -*- mode: C++ -*-
3 //
4 // Copyright 2020-2022 Google LLC
5 //
6 // Licensed under the Apache License v2.0 with LLVM Exceptions (the
7 // "License"); you may not use this file except in compliance with the
8 // License.  You may obtain a copy of the License at
9 //
10 //     https://llvm.org/LICENSE.txt
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 // Author: Giuliano Procida
19 
20 #include "scc.h"
21 
22 #include <algorithm>
23 #include <cstddef>
24 #include <cstdint>
25 #include <iostream>
26 #include <map>
27 #include <optional>
28 #include <random>
29 #include <set>
30 #include <sstream>
31 #include <utility>
32 #include <vector>
33 
34 #include <catch2/catch.hpp>
35 
36 namespace Test {
37 
38 using stg::SCC;
39 
// A directed graph over the nodes [0, N): element i is the set of nodes that
// node i has an edge to (its out-edges).
using Graph = std::vector<std::set<size_t>>;
42 
// Builds a pseudo-random directed graph on the nodes [0, n).
//
// Every ordered pair (from, to), self-loops included, is visited in row-major
// order and one fair coin toss drawn from `gen` decides whether the edge
// from -> to is present. The result therefore depends only on n and on the
// state of `gen` at the time of the call.
template <typename G>
Graph invent(size_t n, G& gen) {
  std::uniform_int_distribution<int> coin(0, 1);
  Graph result(n);
  for (size_t from = 0; from < n; ++from) {
    std::set<size_t>& successors = result[from];
    for (size_t to = 0; to < n; ++to) {
      if (coin(gen) != 0) {
        successors.insert(to);
      }
    }
  }
  return result;
}
56 
// Returns a graph g' on the same nodes as g in which i -> j iff i and j are
// strongly connected in g, i.e. there is a path from i to j and a path from j
// to i. Every node counts as connected to itself.
Graph symmetric_subset_of_reflexive_transitive_closure(Graph g) {
  const size_t size = g.size();

  // Reflexive part: give every node an edge to itself.
  for (size_t node = 0; node < size; ++node) {
    g[node].insert(node);
  }

  // Transitive part, in the manner of Floyd-Warshall. Intermediate nodes are
  // admitted one at a time, in order; once `via` has been processed, g holds
  // an edge for every path whose interior nodes are all <= via. Hence it is
  // enough to splice together the two-step paths from -> via -> to.
  for (size_t via = 0; via < size; ++via) {
    const std::set<size_t>& onward = g[via];
    for (size_t from = 0; from < size; ++from) {
      // Skipping from == via loses nothing (a set merged into itself is
      // unchanged) and keeps the source range distinct from the destination.
      if (from == via || g[from].count(via) == 0) {
        continue;
      }
      // Copy the targets of `via` that are valid nodes, i.e. those < size.
      g[from].insert(onward.begin(), onward.lower_bound(size));
    }
  }

  // g now has from -> to iff a path exists. Keep an edge only when its
  // reverse is present too, examining each unordered pair once.
  for (size_t a = 0; a < size; ++a) {
    for (size_t b = a + 1; b < size; ++b) {
      const bool forward = g[a].count(b) != 0;
      const bool backward = g[b].count(a) != 0;
      if (forward && !backward) {
        g[a].erase(b);
      } else if (backward && !forward) {
        g[b].erase(a);
      }
    }
  }

  // g now has a -> b iff paths exist in both directions.
  return g;
}
99 
// Generate a graph where i -> j iff i and j are in the same SCC.
//
// The graph has one node per value in [0, m], where m is the largest node
// mentioned in `sccs` (it is empty when `sccs` mentions no nodes). The edge
// set of a node is a copy of the SCC containing it; if a node occurs in
// several SCCs, the last one listed wins. A node in range that occurs in no
// SCC gets an empty edge set, so a malformed list of SCCs produces a graph
// that compares unequal to the expected one instead of dereferencing a null
// pointer.
Graph scc_strong_connectivity(const std::vector<std::set<size_t>>& sccs) {
  size_t n = 0;
  // Maps each node to the SCC that contains it.
  std::map<size_t, const std::set<size_t>*> edges;
  for (const auto& scc : sccs) {
    for (auto o : scc) {
      if (o >= n) {
        n = o + 1;
      }
      edges[o] = &scc;
    }
  }
  Graph g(n);
  for (size_t o = 0; o < n; ++o) {
    // Use find rather than operator[]: the latter would insert a null pointer
    // for a node that belongs to no SCC.
    const auto it = edges.find(o);
    if (it != edges.end()) {
      g[o] = *it->second;
    }
  }
  return g;
}
118 
dfs(std::set<size_t> & visited,SCC<size_t> & scc,const Graph & g,size_t node,std::vector<std::set<size_t>> & sccs)119 void dfs(std::set<size_t>& visited, SCC<size_t>& scc, const Graph& g,
120          size_t node, std::vector<std::set<size_t>>& sccs) {
121   if (visited.contains(node)) {
122     return;
123   }
124   auto handle = scc.Open(node);
125   if (!handle) {
126     return;
127   }
128   for (auto o : g[node]) {
129     dfs(visited, scc, g, o, sccs);
130   }
131   auto nodes = scc.Close(*handle);
132   if (!nodes.empty()) {
133     std::set<size_t> scc_set;
134     for (auto o : nodes) {
135       CHECK(visited.insert(o).second);
136       CHECK(scc_set.insert(o).second);
137     }
138     sccs.push_back(scc_set);
139   }
140 }
141 
process(const Graph & g)142 void process(const Graph& g) {
143   const size_t n = g.size();
144 
145   // find SCCs
146   std::set<size_t> visited;
147   std::vector<std::set<size_t>> sccs;
148   for (size_t o = 0; o < n; ++o) {
149     // could reuse a single SCC finder but assert stronger invariants this way
150     SCC<size_t> scc;
151     dfs(visited, scc, g, o, sccs);
152   }
153 
154   // check partition and topological order properties
155   std::set<size_t> seen;
156   for (const auto& nodes : sccs) {
157     CHECK(!nodes.empty());
158     for (auto node : nodes) {
159       // value in range [0, n)
160       CHECK(node < n);
161       // value seen at most once
162       CHECK(seen.insert(node).second);
163     }
164     for (auto node : nodes) {
165       for (auto o : g[node]) {
166         // edges point to nodes in this or earlier SCCs
167         CHECK(seen.contains(o));
168       }
169     }
170   }
171   // exactly n values seen
172   CHECK(seen.size() == n);
173 
174   // check strong connectivity
175   auto g_scc_closure = scc_strong_connectivity(sccs);
176   auto g_closure = symmetric_subset_of_reflexive_transitive_closure(g);
177   CHECK(g_scc_closure == g_closure);
178 }
179 
180 TEST_CASE("randomly-generated graphs") {
181   std::ranlux48 gen;
182   auto seed = gen();
183   // NOTES:
184   //   Graphs of size 6 are plenty big enough to shake out bugs.
185   //   There are O(2^k^2) possible directed graphs of size k.
186   //   Testing costs are O(k^3) so we restrict accordingly.
187   const uint64_t budget = 10000;
188   for (size_t k = 0; k < 7; ++k) {
189     const uint64_t count = std::min(static_cast<uint64_t>(1) << (k * k),
190                                     budget / (k ? k * k * k : 1));
191     INFO("testing with " << count << " graphs of size " << k);
192     for (uint64_t n = 0; n < count; ++n, ++seed) {
193       gen.seed(seed);
194       const Graph g = invent(k, gen);
195       std::ostringstream os;
196       os << "a graph of " << k << " nodes generated using seed " << seed;
197       GIVEN(os.str()) {
198         process(g);
199       }
200     }
201   }
202 }
203 
204 }  // namespace Test
205