# mypy: allow-untyped-defs
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
import collections
import itertools
import logging

from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set

from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase


logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

class Partition:
    """A group of fx Nodes intended to be fused into a single submodule.

    Nodes are stored as the keys of a dict (values are always ``None``) so the
    collection behaves like an ordered set: membership tests are O(1) and
    iteration order is insertion order, which keeps partitioning deterministic.
    """

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        # Integer partition id assigned by the partitioner; may be None.
        self.id = id
        # Ordered-set of nodes; dict.fromkeys preserves insertion order.
        self.nodes = dict.fromkeys(nodes) if nodes is not None else {}

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node):
        """Add ``node`` to the partition (no-op if already present)."""
        self.nodes.update({node: None})

    def remove_node(self, node: Node):
        """Remove ``node`` from the partition; raises KeyError if absent."""
        del self.nodes[node]

    def size(self):
        """Return the number of nodes currently in the partition."""
        return len(self.nodes)

class _DependencyViewer:
    """Precomputes transitive upstream/downstream node sets for a graph.

    Built once per graph: a single forward pass accumulates each node's
    transitive producers (upstreams), and a single backward pass accumulates
    each node's transitive consumers (downstreams). Both passes rely on
    ``graph.nodes`` being topologically ordered.
    """

    def __init__(self, graph_module: GraphModule):
        # node -> set of all transitive input nodes
        self.upstreams = collections.defaultdict(set)
        # node -> set of all transitive user nodes
        self.downstreams = collections.defaultdict(set)

        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                # add input_node and input_node's upstream dependency
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # add output_node and output_node's downstream dependency
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

    def downstreams_of(self, node: Node) -> Set[Node]:
        """Return all nodes transitively reachable from ``node``'s outputs."""
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> Set[Node]:
        """Return all nodes ``node`` transitively depends on."""
        return self.upstreams[node]

class CapabilityBasedPartitioner:
    """Groups supported nodes of a GraphModule into fusible partitions.

    Uses an ``OperatorSupportBase`` oracle to decide per-node support, then
    greedily merges supported nodes into the largest partitions that do not
    introduce a cyclic dependency in the fused graph.
    """

    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False,
                 non_compute_ops: Optional[Sequence[str]] = None,
                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
                 ) -> None:
        self.graph_module = graph_module
        self.operator_support = operator_support
        # If False, partitions with <= 1 compute node are dropped.
        self.allows_single_node_partition = allows_single_node_partition
        # Qualified op names that do not count as "compute" when sizing partitions.
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        # Ops that justify keeping a partition alive even as a single node.
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        self.dependency_viewer = _DependencyViewer(graph_module)

    def __is_node_supported(self, node: Node) -> bool:
        """Ask the operator-support oracle whether ``node`` can be fused."""
        return (
            self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
        )

    def propose_partitions(self) -> List[Partition]:
        """Compute and return a list of non-empty proposed partitions."""
        # partition_map is a mapping from partition id to a set of partition id's.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map : Dict[int, Set] = collections.defaultdict(set)

        # assumptions: nodes in candidate list is sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        # try to merge partition other_id into partition self_id
        # merge only happens if the end graph doesn't contain cyclic dependency
        # returns `True` when merge happens, `False` otherwise.
        def maybe_merge_partition(self_id: int, other_id: int):
            """Merge partition ``other_id`` into ``self_id`` if cycle-free.

            Returns True on success; returns False (and leaves both partitions
            untouched) when the merge would create a cyclic dependency.
            """
            # merged_nodes is the union of nodes in two partition to-be-merged
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)

            def dfs_iter_find_cycle(all_user_nodes: Set[Node]):
                # True iff some external user of the merged set can reach back
                # into the merged set (directly or via another partition).
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # If any of the nodes in the dfs path of this node are in the merged_nodes
                        # list then there is a cycle in the graph.
                        if path_node in merged_nodes:
                            return True

                        # If any of the nodes in the dfs path of this node are in the assignment
                        # map then we have to make sure that the partitions that these nodes belong
                        # to do not form a cycle with the current partitions being merged. This means
                        # iterating through all the nodes in all the parititons that are traversed in
                        # the dfs path and checking if they are in the merged_nodes list.
                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # If the partition id has already been visited then we know that it doesn't
                            # form a cycle with the current partitions being merged.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # check if merge would create cyclic dependency.
            all_user_nodes = set()
            for node in merged_nodes:
                for user_node in node.users:
                    if user_node not in merged_nodes:
                        all_user_nodes.add(user_node)

            if dfs_iter_find_cycle(all_user_nodes):
                # return false indicating cyclic dependency found and
                # merge is aborted
                return False

            # no cyclic dependency found, move forward with the merge
            # updating partition nodes
            partitions_by_id[self_id].nodes = merged_nodes
            # updating assignment map
            for node in partitions_by_id[other_id].nodes:
                assignment[node] = self_id
            # delete other partition
            del partitions_by_id[other_id]

            # fold other_id's reachability set into self_id's
            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
            del partition_map[other_id]

            return True

        def merge_single_node(node: Node, id: Optional[int]):
            """Move ``node`` into partition ``id`` (creating it if needed).

            ``id is None`` removes the node from any partition entirely.
            """
            def _update_partition_map(node: Node, id: int):
                # Iterate through all the downstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of this node to the target
                # partition id.
                downstream_nodes = self.dependency_viewer.downstreams_of(node)
                for curr_node in downstream_nodes:
                    target_id = assignment.get(curr_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)

                # Iterate through all the upstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of the upstream node to the
                # current node's partition id.
                upstream_nodes = self.dependency_viewer.upstreams_of(node)
                for curr_node in upstream_nodes:
                    source_id = assignment.get(curr_node, None)
                    if source_id is not None:
                        partition_map[source_id].add(id)

            # Detach from the node's previous partition first, if any.
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)
                _update_partition_map(node, id)

        logger.debug("Proposing partitions...")

        # Walk nodes in reverse topological order, seeding a partition per
        # supported node and greedily merging candidates as we go.
        for node in reversed(self.graph_module.graph.nodes):
            # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
            merge_candidates: Dict[int, None] = {}

            # Note a limited horizontal fusion is enabled:
            # when `node` is not supported, the code below attempts to fuse consumer of `node`.
            #
            # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
            # the fusion by adding an `else` block here to skip horizontal fusion.
            if self.__is_node_supported(node) and node not in assignment:
                partition_id = next(new_partition_id)
                merge_single_node(node, partition_id)
                merge_candidates[partition_id] = None

            # merge all possible partitions
            for node in assignment:
                merge_candidates[assignment[node]] = None

            merge_candidates_list = list(merge_candidates.keys())
            if len(merge_candidates_list) > 1:
                self_id = merge_candidates_list[0]
                for other_id in merge_candidates_list[1:]:
                    # note: merge partition `other_id` into partition `self_id` if
                    # it doesn't create cyclic dependency in the graph, otherwise,
                    # this is a no-op
                    maybe_merge_partition(self_id, other_id)

        # post processing to re-assign "getitem" nodes into upstream partition
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            # A node is treated as tuple-output iff every user is _operator.getitem.
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                   _get_qualified_name(user.target) != "_operator.getitem":     # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs, re-assign all following getitem node into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)     # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != id:    # type: ignore[arg-type]
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)

        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        # Allowed single-node ops count double so that a partition
                        # containing one of them survives the <= 1 filter below.
                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])

        return [partition for partition in partitions_by_id.values() if partition.size() > 0]

    def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
        """Fuse each partition into a call_module node named ``{prefix}{i}``.

        Returns the rewritten GraphModule produced by ``fuse_by_partitions``.
        """
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
        return fuse_by_partitions(
            self.graph_module,
            [list(partition.nodes) for partition in partitions],
            prefix=prefix,
        )

    # remove non-compute-ops that sits at the boundary of a partition.
    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
        """Strip non-compute ops that sit at a partition's input/output boundary.

        Mutates each partition in place: a non-compute node is dropped when it
        is "transparent" from the input side (only fed by non-compute nodes /
        placeholders / nodes outside the partition) or from the output side
        (only feeding such nodes).
        """
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node):
            return node.op == "call_function" and \
                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]

        # cache transparent nodes
        # NOTE: these memo caches are shared across all partitions in the loop
        # below, so results computed for one partition are reused for later ones.
        transparent_input_nodes: Dict[Node, bool] = {}
        transparent_output_nodes: Dict[Node, bool] = {}

        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            # Transparent from the input side: placeholder, outside the
            # partition, already removed, or a non-compute node whose inputs
            # are all themselves input-transparent (checked recursively).
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            # Mirror of is_transparent_input_node, walking users instead of inputs.
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(output_n, partition, removed_nodes):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False

        for partition in partitions:
            # Note it's ok to use `set` here, since we are only query if a node
            # has been removed. We are NEVER going to iterate on nodes inside
            # the set.
            remove_node: Set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and \
                    (is_transparent_input_node(node, set(partition.nodes), remove_node) or
                     is_transparent_output_node(node, set(partition.nodes), remove_node)):
                    remove_node.add(node)

            if len(remove_node) != 0:
                for node in remove_node:
                    # pop with default: tolerate nodes already absent.
                    partition.nodes.pop(node, None)

    def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
        """Convenience wrapper: propose partitions, then fuse them.

        Returns the fused GraphModule produced by ``fuse_partitions``.
        """
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions, prefix=prefix)
        return fused_gm