# mypy: allow-untyped-defs
from collections import deque
from typing import List, Set


class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (edge data is not implemented; the value is just a marker)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node. If the
                node already exists, these are merged into its attributes.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessors nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes."""
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _transitive_closure(self, src, neighbors) -> Set[str]:
        """Breadth-first search from the single node ``src``.

        Args:
            src: the start node (treated as ONE node, never iterated).
            neighbors: callable mapping a node to an iterable of the nodes
                to visit next (e.g. ``self.successors``).

        Returns:
            The set of all visited nodes, including ``src`` itself.
        """
        # Wrap src explicitly: set(src)/deque(src) would split a string
        # node such as "foo.bar" into its individual characters.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction
        (including src).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph rooted at src that
        shows all the paths to dst.

        If ``dst`` is not reachable from ``src``, the dot representation of an
        empty graph is returned.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            # Keep the return type consistent with the reachable case.
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # We don't use backward_transitive_closure for optimization purposes.
        # `visited` guarantees each node is expanded once, so cycles terminate.
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n not in forward_reachable_from_src:
                    continue
                result_graph.add_edge(n, cur)
                # only explore further if it's reachable from src and new
                if n not in visited:
                    visited.add(n)
                    working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly step to the predecessor that was
        inserted into the graph earliest, skipping nodes already on the path
        so that dependency cycles cannot cause an infinite loop.

        Raises:
            KeyError: if ``dst`` is not in the graph.
        """
        path = []
        on_path = set()
        cur = dst
        while True:
            path.append(cur)
            on_path.add(cur)
            candidates = [p for p in self._pred[cur] if p not in on_path]
            if not candidates:
                break
            # Insertion indices are unique, so the minimum is unambiguous.
            cur = min(candidates, key=self._node_order.__getitem__)

        return list(reversed(path))

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Returns:
            A dot representation of the graph.
        """
        edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""