xref: /aosp_15_r20/external/pytorch/c10/util/NetworkFlow.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
#include <c10/macros/Macros.h>

#include <cstdint>
#include <string>
#include <vector>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker /**
9*da0073e9SAndroid Build Coastguard Worker  * This file provides a network flow implementation.
10*da0073e9SAndroid Build Coastguard Worker  * https://en.wikipedia.org/wiki/Flow_network
11*da0073e9SAndroid Build Coastguard Worker  *
12*da0073e9SAndroid Build Coastguard Worker  * It aims to mirror some of the behavior of networkx, which is/was used by
13*da0073e9SAndroid Build Coastguard Worker  * functorch partitioners for splitting the graph into a forward and backward
14*da0073e9SAndroid Build Coastguard Worker  * graph.
15*da0073e9SAndroid Build Coastguard Worker  */
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker namespace c10 {
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker enum class C10_API_ENUM MinCutStatus {
20*da0073e9SAndroid Build Coastguard Worker   SUCCESS = 0,
21*da0073e9SAndroid Build Coastguard Worker   UNBOUNDED = 1,
22*da0073e9SAndroid Build Coastguard Worker   OVERFLOW_INF = 2,
23*da0073e9SAndroid Build Coastguard Worker   INVALID = 3,
24*da0073e9SAndroid Build Coastguard Worker };
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker struct MinCutResult {
27*da0073e9SAndroid Build Coastguard Worker   MinCutStatus status;
28*da0073e9SAndroid Build Coastguard Worker   int64_t max_flow;
29*da0073e9SAndroid Build Coastguard Worker   std::vector<std::string> reachable;
30*da0073e9SAndroid Build Coastguard Worker   std::vector<std::string> unreachable;
31*da0073e9SAndroid Build Coastguard Worker };
32*da0073e9SAndroid Build Coastguard Worker 
33*da0073e9SAndroid Build Coastguard Worker // Modeled after networkx implementation
34*da0073e9SAndroid Build Coastguard Worker class C10_API NetworkFlowGraph {
35*da0073e9SAndroid Build Coastguard Worker  public:
36*da0073e9SAndroid Build Coastguard Worker   // selected such that INF + INF is < INT64_MAX
37*da0073e9SAndroid Build Coastguard Worker   constexpr static int64_t INF = (1LL << 62) - 1;
38*da0073e9SAndroid Build Coastguard Worker 
39*da0073e9SAndroid Build Coastguard Worker   struct Edge {
40*da0073e9SAndroid Build Coastguard Worker     std::string source, dest;
41*da0073e9SAndroid Build Coastguard Worker     int64_t capacity;
42*da0073e9SAndroid Build Coastguard Worker   };
43*da0073e9SAndroid Build Coastguard Worker 
44*da0073e9SAndroid Build Coastguard Worker   MinCutStatus add_edge(
45*da0073e9SAndroid Build Coastguard Worker       const std::string& source,
46*da0073e9SAndroid Build Coastguard Worker       const std::string& dest,
47*da0073e9SAndroid Build Coastguard Worker       int64_t capacity = 1);
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   MinCutResult minimum_cut(const std::string& s, const std::string& t) const;
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker   std::vector<Edge> edges;
52*da0073e9SAndroid Build Coastguard Worker };
53*da0073e9SAndroid Build Coastguard Worker 
54*da0073e9SAndroid Build Coastguard Worker } // namespace c10
55