#include <c10/test/util/Macros.h>
#include <c10/util/NetworkFlow.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>
5
6 namespace {
7
/// Returns true if `element` compares equal (via operator==) to any entry of
/// `vec`; false for an empty vector.
template <typename T>
bool vector_contains(const std::vector<T>& vec, const T& element) {
  return std::find(vec.begin(), vec.end(), element) != vec.end();
}
17
18 template <typename T>
expect_vector_contains_subset(const std::vector<T> & vec,const std::vector<T> & subset)19 void expect_vector_contains_subset(
20 const std::vector<T>& vec,
21 const std::vector<T>& subset) {
22 for (auto& element : subset) {
23 if (!vector_contains(vec, element)) {
24 std::stringstream ss;
25 ss << "Failed: checking whether {";
26 for (auto& e : subset) {
27 ss << e << ", ";
28 }
29 ss << "} is a subset of {";
30 for (auto& e : vec) {
31 ss << e << ", ";
32 }
33 ss << "}, but couldn't find " << element;
34 FAIL() << ss.str();
35 }
36 }
37 }
38
39 namespace test_network_flow {
40
TEST(NetworkFlowTest, basic) {
  /*
   * Edge labels are capacities.
   *
   *     3    1      2
   *     -->b--  ->e--
   *    /  1|  \/     \
   *   / 2  v 2/\    2 \
   *  a---->c-/  ->f---->h
   *   \      2\/      /
   *    \3    1/\    2/
   *     -->d--  ->g--
   *
   * The maximum flow is 5. Every a-h path passes through c or f. c can
   * receive at most 2 + 1 = 3 (from a and b), and f at most 1 + 1 = 2
   * (from b and d).
   *
   * Consider these augmenting paths that constitute a blocking flow:
   *   a -> d -> f -> h (capacity 1), saturates d->f
   *   a -> c -> g -> h (capacity 2), saturates a->c, c->g, g->h
   *   a -> b -> c -> e -> h (capacity 1), saturates b->c
   *   a -> b -> f -> h (capacity 1), saturates b->f, f->h
   */
  c10::NetworkFlowGraph g;
  g.add_edge("a", "b", 3); // flow: 2
  g.add_edge("a", "c", 2); // flow: 2
  g.add_edge("a", "d", 3); // flow: 1
  g.add_edge("b", "f", 1); // flow: 1
  g.add_edge("c", "e", 2); // flow: 1
  g.add_edge("c", "g", 2); // flow: 2
  g.add_edge("d", "f", 1); // flow: 1
  g.add_edge("b", "c", 1); // flow: 1
  g.add_edge("e", "h", 2); // flow: 1
  g.add_edge("f", "h", 2); // flow: 2
  g.add_edge("g", "h", 2); // flow: 2
  auto res = g.minimum_cut("a", "h");
  // Abort on failure: the checks below assume a successful cut and would
  // only produce noise otherwise.
  ASSERT_EQ(res.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(res.max_flow, 5);

  // The sink side ("unreachable") is found by walking backwards from "h"
  // over edges that still have residual capacity:
  //   h -> e: the e->h edge has residual capacity
  //   e -> c: the c->e edge has residual capacity
  //   c -> g: the c->g edge carries flow, so its reverse edge g->c has
  //           residual capacity
  // The remaining vertices a, b, d, f are expected on the source side.
  expect_vector_contains_subset(res.unreachable, {"h", "e", "c", "g"});
  expect_vector_contains_subset(res.reachable, {"a", "b", "d", "f"});
}
82
TEST(NetworkFlowTest, loop) {
  /*
   * Edge labels are capacities. The top edge runs from d back to b, so
   * b -> c -> d -> b forms a cycle. The cycle cannot carry any extra a-e
   * flow, because every a-e path starts with a->b (capacity 1).
   *
   *                         1
   *                 --------<---------
   *                /                  \
   *        1      /    1          1    \      1
   *  a --------> b --------> c -------> d --------> e
   */
  c10::NetworkFlowGraph g;
  g.add_edge("a", "b", 1); // flow: 1
  g.add_edge("b", "c", 1); // flow: 1
  g.add_edge("c", "d", 1); // flow: 1
  g.add_edge("d", "e", 1); // flow: 1
  g.add_edge("d", "b", 1); // flow: 0
  auto res = g.minimum_cut("a", "e");
  // Abort on failure: the checks below assume a successful cut.
  ASSERT_EQ(res.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(res.max_flow, 1);

  // e's only incoming edge d->e is saturated, so e is expected to be the
  // only vertex on the sink side.
  expect_vector_contains_subset(res.unreachable, {"e"});
  expect_vector_contains_subset(res.reachable, {"a", "b", "c", "d"});
}
103
TEST(NetworkFlowTest, disconnected_vertices) {
  /*
   * Edge labels are capacities. The c->d component shares no vertex with
   * the a->b component used for the cut.
   *
   *      1
   *  c -----> d
   *
   *      1
   *  a -----> b
   */
  c10::NetworkFlowGraph g;
  g.add_edge("a", "b", 1); // flow: 1
  g.add_edge("c", "d", 1); // flow: 0
  auto res = g.minimum_cut("a", "b");
  // Abort on failure: the checks below assume a successful cut.
  ASSERT_EQ(res.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(res.max_flow, 1);

  expect_vector_contains_subset(res.unreachable, {"b"});
  // Unintuitively, "c" and "d" get marked as reachable; this mirrors
  // networkx behavior (presumably because neither can reach "b" in the
  // residual graph, so they fall on the source side).
  expect_vector_contains_subset(res.reachable, {"a", "c", "d"});
}
124
TEST(NetworkFlowTest, invalid_endpoints) {
  // The graph only knows about "a" and "b". Naming the unknown vertex "c"
  // as either the sink or the source must be reported as INVALID.
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", 1);

  const auto unknown_sink = graph.minimum_cut("a", "c");
  EXPECT_EQ(unknown_sink.status, c10::MinCutStatus::INVALID);

  const auto unknown_source = graph.minimum_cut("c", "b");
  EXPECT_EQ(unknown_source.status, c10::MinCutStatus::INVALID);
}
134
TEST(NetworkFlowTest, unbounded) {
  // The only a->b edge has capacity INF, so the test expects the cut to be
  // reported as UNBOUNDED.
  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", c10::NetworkFlowGraph::INF);
  EXPECT_EQ(
      graph.minimum_cut("a", "b").status, c10::MinCutStatus::UNBOUNDED);
}
141
TEST(NetworkFlowTest, overflow) {
  // Two parallel a->b edges whose capacities together add up to INF.
  // Summing them is expected to be reported as OVERFLOW_INF.
  const auto half = c10::NetworkFlowGraph::INF / 2;
  const auto remainder = c10::NetworkFlowGraph::INF - half;

  c10::NetworkFlowGraph graph;
  graph.add_edge("a", "b", half);
  graph.add_edge("a", "b", remainder);

  const auto result = graph.minimum_cut("a", "b");
  EXPECT_EQ(result.status, c10::MinCutStatus::OVERFLOW_INF);
}
151
TEST(NetworkFlowTest, reverse_edge) {
  /*
   * Edge labels are capacities. The large c->a edge points back at the
   * source, so it cannot carry any a->c flow. The only a->c path is
   * a -> b -> c, with capacity 1.
   *
   *              100
   *     ---------------------
   *    /                     \
   *   v    1            1     \
   *  a ---------> b ---------> c
   */
  c10::NetworkFlowGraph g;
  g.add_edge("a", "b", 1);
  g.add_edge("b", "c", 1);
  g.add_edge("c", "a", 100);
  auto res = g.minimum_cut("a", "c");
  // Abort on failure: the checks below assume a successful cut.
  ASSERT_EQ(res.status, c10::MinCutStatus::SUCCESS);
  EXPECT_EQ(res.max_flow, 1);

  // b->c is saturated and c->a carries no flow, so no other vertex is
  // expected to reach c in the residual graph.
  expect_vector_contains_subset(res.unreachable, {"c"});
  expect_vector_contains_subset(res.reachable, {"a", "b"});
}
172
173 } // namespace test_network_flow
174
175 } // namespace
176