# mypy: allow-untyped-defs
from enum import Enum
from typing import NamedTuple, Dict, List, Set

from torch.fx.node import Node, map_arg


# Ops that only bring values into the graph (graph inputs and get_attr
# constants). Partitions pull these in together with the nodes that use them.
_INPUT_LIKE_OPS = {"placeholder", "get_attr"}


def _get_input_nodes(node: Node) -> Dict[Node, None]:
    """Return the distinct nodes referenced by ``node``'s args and kwargs.

    A dict with ``None`` values is used as an insertion-ordered set, so the
    result lists each input once, in first-seen order.
    """
    input_nodes: Dict[Node, None] = {}
    map_arg(node.args, input_nodes.setdefault)
    map_arg(node.kwargs, input_nodes.setdefault)
    return input_nodes


class Partition:
    """Partition class contains all the information about an individual partition.
    It also provides necessary methods for manipulating the partition.
    """

    def __init__(self, partition_id: int) -> None:
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        self.parents: Set[Partition] = set()
        self.children: Set[Partition] = set()
        self.bfs_level: int = -1
        self.used_mem_bytes: int = 0
        self.logical_device_ids: List[int] = []

    def __str__(self):
        return str(self.partition_id)

    def recalculate_mem_size(self):
        """Recompute ``used_mem_bytes`` as the sum of ``get_extra_size_of``
        over every node of this partition."""
        self.used_mem_bytes = 0
        for node in self.nodes:
            self.used_mem_bytes += get_extra_size_of(node, self.nodes)

    def add_node(self, node):
        """Add ``node`` and recompute the partition's memory size.

        The node's placeholder and get_attr inputs are added as well, so the
        partition holds the graph inputs / constants it reads.
        """
        for n in _get_input_nodes(node):
            if n.op in _INPUT_LIKE_OPS:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node):
        """Remove ``node`` if it is in this partition, then drop any of its
        placeholder / get_attr inputs that no remaining node here uses, and
        recompute the memory size. Does nothing if ``node`` is not present.
        """
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        for input_node in _get_input_nodes(node):
            # An input placeholder/get_attr is only kept while some node in
            # this partition still uses it
            if input_node.op in _INPUT_LIKE_OPS and all(
                n not in self.nodes for n in input_node.users
            ):
                # discard(): the input may never have been added to this
                # partition (e.g. if nodes were inserted directly)
                self.nodes.discard(input_node)
        self.recalculate_mem_size()


class Device(NamedTuple):
    name: str
    available_mem_bytes: int
    logical_id: int


class NodeLatency(NamedTuple):
    # Latency due to the memory bandwidth
    mem_latency_sec: float
    # Latency due to the computation
    computer_latency_sec: float


class PartitionLatency(NamedTuple):
    # Sum of all nodes' memory latency on the critical path
    mem_latency_sec: float
    # Sum of all nodes' compute latency on the critical path
    computer_latency_sec: float
    # Latency of the critical path
    overall_latency_sec: float


class PartitionMode(Enum):
    size_based = 0
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4


class PartitionerConfig(NamedTuple):
    devices: List[Device]
    mode: PartitionMode = PartitionMode.size_based
    transfer_rate_bytes_per_sec: float = 0.0
    # NOTE(review): NamedTuple defaults are evaluated once, so the three dict
    # defaults below are shared by every config built without its own dicts.
    # Pass fresh dicts instead of mutating these defaults in place.
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False


def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Given a node and a set of nodes,
    return the extra size needed if this node is included in this set.

    The extra size is the ``size_bytes.output_size`` of every input of
    ``node`` that is not already in ``nodes``, plus ``node``'s own
    ``size_bytes.total_size``.

    Raises:
        RuntimeError: if ``node``, or one of its inputs outside ``nodes``,
            has no ``size_bytes`` attribute.
    """
    total_size_of_input_nodes = 0
    for n in _get_input_nodes(node):
        # Inputs already in the set cost nothing extra
        if n in nodes:
            continue
        size_bytes = getattr(n, "size_bytes", None)
        if size_bytes is None:
            raise RuntimeError("node has no size_bytes attr")
        total_size_of_input_nodes += size_bytes.output_size
    # Don't forget the op node itself
    size_bytes = getattr(node, "size_bytes", None)
    if size_bytes is None:
        raise RuntimeError("node has no size_bytes attr")
    return total_size_of_input_nodes + size_bytes.total_size


def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Given a partition and its nodes' latency, return a PartitionLatency for this partition.

    The result describes the critical path: the path through the partition
    with the largest accumulated ``overall_latency_sec``, where each node
    contributes ``max(computer_latency_sec, mem_latency_sec)``. Every path from
    every top node is walked (no memoization), and an empty partition yields
    all-zero latencies.
    """

    def get_top_nodes(partition: Partition) -> List[Node]:
        """Given a partition, return a list of nodes on the top bfs level"""
        top_nodes: List[Node] = []
        for node in partition.nodes:
            # Skip placeholder and get_attr nodes
            if node.op in _INPUT_LIKE_OPS:
                continue
            # If a node has no input nodes in this partition,
            # or its input nodes in this partition are placeholders and get_attrs,
            # this node is on the top bfs level in this partition
            if not any(
                n in partition.nodes and n.op not in _INPUT_LIKE_OPS
                for n in _get_input_nodes(node)
            ):
                top_nodes.append(node)
        return top_nodes

    def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
        """Return the latency of the longest path from ``node`` to the bottom
        of the partition, accumulated on top of ``partition_latency`` (the
        latency of the path that led to ``node``).
        """
        node_latency = node_to_latency_mapping[node]
        # Calculate the current overall latency of the partition
        overall_latency_sec = partition_latency.overall_latency_sec + max(
            node_latency.computer_latency_sec, node_latency.mem_latency_sec
        )
        # Update the mem latency of this path
        mem_latency_sec = (
            partition_latency.mem_latency_sec + node_latency.mem_latency_sec
        )
        # Update the compute latency of this path
        computer_latency_sec = (
            partition_latency.computer_latency_sec + node_latency.computer_latency_sec
        )
        # Users of this node inside this partition, kept in node.users order so
        # that ties between equally long paths are resolved deterministically
        users = [n for n in node.users if n in partition.nodes]
        if users:
            max_latency = PartitionLatency(
                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
            )
            for n in users:
                # Get new partition latency recursively
                new_partition_latency = dfs_helper(
                    n,
                    PartitionLatency(
                        mem_latency_sec, computer_latency_sec, overall_latency_sec
                    ),
                )
                if (
                    new_partition_latency.overall_latency_sec
                    > max_latency.overall_latency_sec
                ):
                    max_latency = new_partition_latency
            return max_latency
        # If there is no user, the node is at bottom of the partition
        return PartitionLatency(
            mem_latency_sec, computer_latency_sec, overall_latency_sec
        )

    # Main part starts
    # Get all top level nodes of this partition
    top_nodes = get_top_nodes(partition)
    critical_path_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )
    # Go through all top nodes and find the largest latency (critical path latency)
    for node in top_nodes:
        partition_latency = dfs_helper(
            node,
            PartitionLatency(
                mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
            ),
        )
        if (
            partition_latency.overall_latency_sec
            > critical_path_latency.overall_latency_sec
        ):
            critical_path_latency = partition_latency
    return critical_path_latency


def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Given all the partitions and node_to_latency_mapping dictionary,
    return a mapping dictionary of each partition to its overall latency
    (as computed by ``get_latency_of_one_partition``).
    """
    return {
        partition: get_latency_of_one_partition(partition, node_to_latency_mapping)
        for partition in partitions
    }


def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
):
    """Given two partitions (parent and child),
    calculate the communication latency between the two.

    The communication size is the summed ``size_bytes.output_size`` of the
    distinct parent-partition nodes consumed by child-partition nodes (nodes
    without ``size_bytes`` count as 0). The latency is that size divided by
    ``transfer_rate_bytes_per_sec``.

    Returns 0.0 when both partitions have the same non-empty
    ``logical_device_ids`` or when no data crosses the boundary.

    Raises:
        ZeroDivisionError: if data crosses the boundary and
            ``transfer_rate_bytes_per_sec`` is 0.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids != []
        and child_partition.logical_device_ids != []
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    # Keep tracking the communication size between parent and child
    comm_size = 0
    # Parent nodes already counted, so each is transferred only once
    visited_nodes: Set[Node] = set()
    for node in child_partition.nodes:
        for n in _get_input_nodes(node):
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing to transfer means no latency. This also avoids 0 / 0.0 when the
    # rate is left at PartitionerConfig's default of 0.0.
    if comm_size == 0:
        return 0.0
    return comm_size / transfer_rate_bytes_per_sec


def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
):
    """Given all partitions in a graph, find the critical path among all partitions
    and return its latency as the latency of the whole graph.

    A path starts at a partition without parents and follows ``children``. Its
    latency is the sum of each partition's ``overall_latency_sec`` plus the
    communication latency between consecutive partitions. Every path is
    walked (no memoization).
    """

    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """Return the largest latency of any path starting at ``partition``,
        accumulated on top of ``latency_so_far_sec``."""
        latency_so_far_sec += partition_to_latency_mapping[
            partition
        ].overall_latency_sec
        children = partition.children
        if not children:
            return latency_so_far_sec
        max_latency_sec = 0.0
        for child in children:
            # Calculate latency between parent and child
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            new_latency_sec = dfs_helper(child, latency_so_far_sec + comm_latency_sec)
            if new_latency_sec > max_latency_sec:
                max_latency_sec = new_latency_sec
        return max_latency_sec

    # Partitions without parents are the starting points of all paths
    top_partitions = [partition for partition in partitions if not partition.parents]
    critical_path_latency_sec = 0.0
    for partition in top_partitions:
        latency_sec = dfs_helper(partition, 0.0)
        if latency_sec > critical_path_latency_sec:
            critical_path_latency_sec = latency_sec
    return critical_path_latency_sec