1# mypy: allow-untyped-defs 2import copy 3from queue import SimpleQueue 4from typing import List, Dict, Tuple 5 6import torch.fx 7from torch.fx.graph_module import GraphModule 8from torch.fx.graph import Graph 9from torch.fx.node import Node 10from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph 11from torch.fx.passes.utils import lift_subgraph_as_module 12from torch.fx._compatibility import compatibility 13 14@compatibility(is_backward_compatible=False) 15def topo_sort(nodes: NodeList) -> NodeList: 16 # sort nodes according to the topological order 17 indegree_map = dict.fromkeys(nodes, 0) 18 candidates: SimpleQueue = SimpleQueue() 19 20 for node in nodes: 21 for n in node.all_input_nodes: 22 if n in indegree_map: 23 indegree_map[node] += 1 24 if indegree_map[node] == 0: 25 candidates.put(node) 26 27 sorted_nodes: NodeList = [] 28 while not candidates.empty(): 29 node = candidates.get() 30 sorted_nodes.append(node) 31 32 for n in node.users: 33 if n in indegree_map: 34 indegree_map[n] -= 1 35 if indegree_map[n] == 0: 36 candidates.put(n) 37 38 assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" 39 40 return sorted_nodes 41 42 43@compatibility(is_backward_compatible=False) 44def validate_partition(partition: NodeList) -> bool: 45 # verify the partition does't form a dependency cycle in the original graph 46 # returns True for valid partition, False for invalid 47 48 partition_set = set(partition) 49 50 outputs: NodeList = [] 51 for node in partition_set: 52 for user_node in node.users: 53 if user_node not in partition_set: 54 # external user node, need to expose as an output 55 outputs.append(user_node) 56 57 # Perform BFS on the partition outputs. 58 # If it reaches a node within the partition, then it found a cycle. 59 # This function takes the ownership of `root_nodes` and may modify it. 60 def bfs_find_cycle(root_nodes: NodeList) -> bool: 61 # Set used to exclude nodes that have already been visited. 62 # If a node has been visited, that node and all its children have 63 # been checked for cycles. 64 visited: NodeSet = set() 65 66 # Start with `root_nodes` and traverse through (toward child nodes) 67 # their connected sub-graph. Nodes in `visited` won't be added 68 # to `queue` again. 69 queue: NodeList = root_nodes 70 while queue: 71 current = queue.pop() 72 visited.add(current) 73 if current in partition_set: 74 # Started from partition's `output` nodes, and reached 75 # another node in partition. Cycle! 76 return True 77 for user_node in current.users: 78 if user_node in visited: 79 continue 80 queue.append(user_node) 81 # `root_nodes` don't cause cycle. 82 return False 83 84 # Use all output nodes as roots to traverse 85 # the graph to check cycles. 86 if bfs_find_cycle(outputs): 87 return False 88 89 return True 90 91 92@compatibility(is_backward_compatible=False) 93def fuse_as_graphmodule(gm: GraphModule, 94 nodes: NodeList, 95 module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: 96 97 """ 98 Fuse nodes in graph_module into a GraphModule. 99 100 Args: 101 gm (GraphModule): target graph_module 102 103 nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted 104 105 module_name: class name for the fused GraphModule 106 107 Returns: 108 fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` 109 110 original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` 111 112 original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` 113 114 """ 115 116 # assumption: nodes are already sorted in topo order 117 118 for node in nodes: 119 assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" 120 assert not node._erased, f"{node} has been removed from owning graph" 121 assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" 122 123 # validates partition doesn't introduce dependency circles in the graph 124 assert validate_partition(nodes), "Invalid partition, found dependency cycles" 125 126 subgraph = Graph() 127 128 node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph 129 node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph 130 131 # handles inputs through graph.node_copy's arg_transform functions 132 def remap_inputs(x): 133 if x.op == "get_attr": 134 # TODO: do we really need copy the get_attr node into the graph? 135 # do something here 136 pass 137 138 if x in nodes: 139 # x is inside subgraph, return the copied node 140 # the node should have been copied aleady, as we are copying graph in the topological order 141 return node_map[x] 142 143 if x not in node_to_placeholder: 144 # x is not in subgraph, create a new placeholder for subgraph 145 placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) 146 # copy all meta fields, even if some fields might be irrelvant for the placeholder node 147 placeholder_node.meta = copy.copy(x.meta) 148 node_to_placeholder[x] = placeholder_node 149 150 return node_to_placeholder[x] 151 152 # copy nodes in topological order 153 for node in nodes: 154 new_node = subgraph.node_copy(node, remap_inputs) 155 node_map[node] = new_node 156 157 # handles outputs 158 output_mapping: Dict[Node, Node] = {} # mapping from old output to new outputs 159 160 for node in nodes: 161 for user_node in node.users: 162 if user_node not in nodes: 163 # external user node, need to expose as an output 164 output_mapping[node] = node_map[node] 165 166 # outs contain nodes in the new subgraph 167 outs = tuple(output_mapping.values()) 168 169 # Take care of the args of FX output node. If there's a single 170 # output then the output node args is like (output_single), else 171 # if there're multiple outputs then the output node args is like 172 # ((output_0, output_1, ...)). 173 subgraph.output(outs[0] if len(outs) == 1 else outs) 174 175 # lint to ensure correctness 176 subgraph.lint() 177 fused_gm: GraphModule 178 fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name) 179 180 # sub_gm's input nodes in the original module 181 original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) 182 183 # sub_gm's outputs node in the original module 184 original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys()) 185 186 return fused_gm, original_inputs, original_outputs 187 188 189@compatibility(is_backward_compatible=False) 190def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): 191 # add sub_gm into gm 192 submodule_name = sub_gm.__class__.__name__ 193 gm.add_submodule(submodule_name, sub_gm) 194 195 # Create a call_module node in main graph. 196 module_node = gm.graph.call_module( 197 submodule_name, 198 args=orig_inputs, 199 kwargs=None) 200 201 if len(orig_outputs) == 1: 202 # main_remapping[comp.orig_outputs[0]] = module_node 203 orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) 204 else: 205 for i, orig_output in enumerate(orig_outputs): 206 # Use Proxy to record getitem access. 207 proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] 208 orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) 209 210 module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs) 211 return gm 212 213@compatibility(is_backward_compatible=False) 214def erase_nodes(gm: GraphModule, nodes: NodeList): 215 216 # erase original nodes in inversed topological order 217 for node in reversed(nodes): 218 gm.graph.erase_node(node) 219 220 221@compatibility(is_backward_compatible=False) 222def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule: 223 for partition_id, nodes in enumerate(partitions): 224 sorted_nodes = topo_sort(nodes) 225 226 submodule_name = prefix + str(partition_id) 227 sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) 228 229 insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) 230 231 erase_nodes(gm, sorted_nodes) 232 233 # topological sort original gm with newly created sub_gm 234 legalize_graph(gm) 235 236 return gm 237