xref: /aosp_15_r20/external/pytorch/torch/package/_digraph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from collections import deque
3from typing import List, Set
4
5
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.

    Nodes may be any object that is a valid dict key; each node carries a dict
    of arbitrary attributes. Edges carry no data.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes.
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (Edge data is not implemented; only the keys are meaningful.)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to the graph
        # (node -> insertion index). ``first_path`` relies on this ordering.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        If ``n`` already exists, its attribute dict is updated with ``kwargs``
        and its insertion order is left unchanged.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node.
        """
        if n in self._node:
            self._node[n].update(kwargs)
            return
        self._node[n] = kwargs
        self._succ[n] = {}
        self._pred[n] = {}
        self._node_order[n] = self._insertion_idx
        self._insertion_idx += 1

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist
        (``u`` first, so it gets the smaller insertion index when both are new).
        Adding an existing edge again is a no-op.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes.

        This is the graph's internal dict, not a copy.
        """
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _reachable(self, src, neighbors) -> Set[str]:
        """Breadth-first search from ``src`` following ``neighbors(node)``.

        The result always includes ``src`` itself. ``neighbors`` is expected to
        be ``self.successors`` or ``self.predecessors`` and therefore raises
        ``ValueError`` when ``src`` is not in the graph.
        """
        # ``src`` is a single node: wrap it rather than iterating it, otherwise
        # a string node would be split into its characters.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction
        (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._reachable(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns a subgraph rooted at src that shows all the paths to dst.

        The subgraph is returned in DOT format (see ``to_dot``). It contains
        every edge ``u -> v`` where ``u`` is ``src`` or reachable from it and
        ``v`` is ``dst`` or can reach it. If ``dst`` is not reachable from
        ``src``, an empty digraph (no edges) is returned.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # We don't use backward_transitive_closure for optimization purposes.
        # ``visited`` ensures each node is expanded once, so cycles terminate.
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # only explore further if it's reachable from src
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Walks backwards from ``dst``, at each step following the predecessor
        with the smallest insertion index. Predecessors already on the path are
        skipped so cycles cannot loop forever. The walk stops at a node with no
        remaining predecessor. The returned list runs from that node to ``dst``.

        Raises:
            KeyError: if ``dst`` is not in the graph.
        """
        path = [dst]
        on_path = {dst}
        cur = dst
        while True:
            candidates = [p for p in self._pred[cur] if p not in on_path]
            if not candidates:
                break
            cur = min(candidates, key=self._node_order.__getitem__)
            path.append(cur)
            on_path.add(cur)
        path.reverse()
        return path

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted, so isolated nodes do not appear. Node names are
        rendered with ``str()`` inside double quotes; any double quote inside a
        name is backslash-escaped so the output stays valid DOT.

        Returns:
            A dot representation of the graph.
        """

        def quote(node) -> str:
            return '"' + str(node).replace('"', '\\"') + '"'

        edges = "\n".join(f"{quote(u)} -> {quote(v)};" for u, v in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""
174"""
175