xref: /aosp_15_r20/external/pytorch/c10/test/util/NetworkFlow_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/test/util/Macros.h>
#include <c10/util/NetworkFlow.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdlib>
#include <sstream>
#include <vector>
5 
6 namespace {
7 
// Returns true when at least one entry of `vec` compares equal (operator==)
// to `element`. An empty `vec` always yields false.
template <typename T>
bool vector_contains(const std::vector<T>& vec, const T& element) {
  return std::find(vec.begin(), vec.end(), element) != vec.end();
}
17 
18 template <typename T>
expect_vector_contains_subset(const std::vector<T> & vec,const std::vector<T> & subset)19 void expect_vector_contains_subset(
20     const std::vector<T>& vec,
21     const std::vector<T>& subset) {
22   for (auto& element : subset) {
23     if (!vector_contains(vec, element)) {
24       std::stringstream ss;
25       ss << "Failed: checking whether {";
26       for (auto& e : subset) {
27         ss << e << ", ";
28       }
29       ss << "} is a subset of {";
30       for (auto& e : vec) {
31         ss << e << ", ";
32       }
33       ss << "}, but couldn't find " << element;
34       FAIL() << ss.str();
35     }
36   }
37 }
38 
39 namespace test_network_flow {
40 
TEST(NetworkFlowTest, basic) {
  /*
   *     3    1       2
   *      -->b--  ->e--
   *     /  1|  \/     \
   *    / 2  v 2/\   2  \
   *   a---->c-/  ->f---->h
   *    \      2\/      /
   *     \3    1/\    2/
   *      -->d--  ->g--
   *
   * One blocking flow is made up of the following augmenting paths:
   * a -> d -> f -> h (capacity 1), saturates d->f
   * a -> c -> g -> h (capacity 2), saturates a->c, c->g, g->h
   * a -> b -> c -> e -> h (capacity 1), saturates b->c
   * a -> b -> f -> h (capacity 1), saturates b->f, f->h
   */
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 3); // flow: 2
  graph.add_edge("a", "c", 2); // flow: 2
  graph.add_edge("a", "d", 3); // flow: 1
  graph.add_edge("b", "f", 1); // flow: 1
  graph.add_edge("c", "e", 2); // flow: 1
  graph.add_edge("c", "g", 2); // flow: 2
  graph.add_edge("d", "f", 1); // flow: 1
  graph.add_edge("b", "c", 1); // flow: 1
  graph.add_edge("e", "h", 2); // flow: 1
  graph.add_edge("f", "h", 2); // flow: 2
  graph.add_edge("g", "h", 2); // flow: 2

  const auto cut = graph.minimum_cut("a", "h");
  EXPECT_EQ(cut.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(cut.max_flow, 5);

  // Why these vertices are "reached" when starting from "h":
  // h -> e: the e->h edge still has residual capacity
  // e -> c: the c->e edge still has residual capacity
  // c -> g: the c->g edge carries flow, so the g->c edge has residual
  // capacity
  expect_vector_contains_subset(cut.unreachable, {"h", "e", "c", "g"});
  expect_vector_contains_subset(cut.reachable, {"a", "b", "d", "f"});
}
82 
TEST(NetworkFlowTest, loop) {
  /*                         1
   *                 -------------------
   *                /                   \
   *       1       /    1          1     \    1
   *  a --------> b --------> c -------> d --------> e
   *
   * The long upper edge is added below as d -> b, closing the cycle
   * b -> c -> d -> b.
   */
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 1); // flow: 1
  graph.add_edge("b", "c", 1); // flow: 1
  graph.add_edge("c", "d", 1); // flow: 1
  graph.add_edge("d", "e", 1); // flow: 1
  graph.add_edge("d", "b", 1); // flow: 0

  const auto cut = graph.minimum_cut("a", "e");
  EXPECT_EQ(cut.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(cut.max_flow, 1);

  expect_vector_contains_subset(cut.unreachable, {"e"});
  expect_vector_contains_subset(cut.reachable, {"a", "b", "c", "d"});
}
103 
TEST(NetworkFlowTest, disconnected_vertices) {
  /*
   *        1
   *  c --------> d
   *
   *       1
   *  a --------> b
   *
   * The c -> d component shares no edge with the a -> b component.
   */
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 1); // flow: 1
  graph.add_edge("c", "d", 1); // flow: 0

  const auto cut = graph.minimum_cut("a", "b");
  EXPECT_EQ(cut.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(cut.max_flow, 1);

  expect_vector_contains_subset(cut.unreachable, {"b"});
  // Counter-intuitively, "c" and "d" are reported as reachable; this matches
  // what networkx does.
  expect_vector_contains_subset(cut.reachable, {"a", "c", "d"});
}
124 
TEST(NetworkFlowTest, invalid_endpoints) {
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 1);

  // "c" never appears in any edge of the graph; using it as the second
  // endpoint must be reported as INVALID.
  const auto unknown_second = graph.minimum_cut("a", "c");
  EXPECT_EQ(unknown_second.status, c10::MinCutStatus::INVALID);

  // The same holds when "c" is used as the first endpoint.
  const auto unknown_first = graph.minimum_cut("c", "b");
  EXPECT_EQ(unknown_first.status, c10::MinCutStatus::INVALID);
}
134 
TEST(NetworkFlowTest, unbounded) {
  // A single a -> b edge whose capacity is INF: the cut between "a" and "b"
  // is expected to be reported as UNBOUNDED.
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", c10::NetworkFlowGraph::INF);

  const auto cut = graph.minimum_cut("a", "b");
  EXPECT_EQ(cut.status, c10::MinCutStatus::UNBOUNDED);
}
141 
TEST(NetworkFlowTest, overflow) {
  // Two parallel a -> b edges whose capacities add up to exactly INF; the
  // result is expected to be reported as OVERFLOW_INF.
  auto first_capacity = c10::NetworkFlowGraph::INF / 2;
  auto second_capacity = c10::NetworkFlowGraph::INF - first_capacity;

  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", first_capacity);
  graph.add_edge("a", "b", second_capacity);

  const auto cut = graph.minimum_cut("a", "b");
  EXPECT_EQ(cut.status, c10::MinCutStatus::OVERFLOW_INF);
}
151 
TEST(NetworkFlowTest, reverse_edge) {
  /*
   *              100
   *    ----------------------
   *   /                      \
   *  v     1            1     \
   *  a ---------> b ---------> c
   *
   * The capacity-100 edge is added below as c -> a, i.e. it points from the
   * second endpoint of the cut query back to the first.
   */
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 1);
  graph.add_edge("b", "c", 1);
  graph.add_edge("c", "a", 100);

  const auto cut = graph.minimum_cut("a", "c");
  EXPECT_EQ(cut.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(cut.max_flow, 1);

  expect_vector_contains_subset(cut.unreachable, {"c"});
  expect_vector_contains_subset(cut.reachable, {"a", "b"});
}
172 
173 } // namespace test_network_flow
174 
175 } // namespace
176