# mypy: allow-untyped-defs
import copy
from queue import SimpleQueue
from typing import List, Dict, Tuple

import torch.fx
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
from torch.fx.passes.utils import lift_subgraph_as_module
from torch.fx._compatibility import compatibility

@compatibility(is_backward_compatible=False)
def topo_sort(nodes: NodeList) -> NodeList:
    # Sort nodes in topological order (Kahn's algorithm): repeatedly emit a
    # node once all of its producers within `nodes` have been emitted.
    indegree_map = dict.fromkeys(nodes, 0)
    candidates: SimpleQueue = SimpleQueue()

    for node in nodes:
        for n in node.all_input_nodes:
            if n in indegree_map:
                indegree_map[node] += 1
        if indegree_map[node] == 0:
            candidates.put(node)

    sorted_nodes: NodeList = []
    while not candidates.empty():
        node = candidates.get()
        sorted_nodes.append(node)

        for n in node.users:
            if n in indegree_map:
                indegree_map[n] -= 1
                if indegree_map[n] == 0:
                    candidates.put(n)

    assert len(nodes) == len(sorted_nodes), "topologically sorted nodes don't have the same length as the input nodes"

    return sorted_nodes

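# A minimal sketch (not part of the upstream file) of how `topo_sort` might be
# exercised; the toy module and the use of `symbolic_trace` here are
# illustrative assumptions, not something this pass requires.
def _example_topo_sort():
    class Toy(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = a * 2
            return b - 3

    gm = torch.fx.symbolic_trace(Toy())
    ops = [n for n in gm.graph.nodes if n.op == "call_function"]
    # Feeding the nodes in reverse still recovers the original topo order.
    assert topo_sort(list(reversed(ops))) == ops
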

@compatibility(is_backward_compatible=False)
def validate_partition(partition: NodeList) -> bool:
    # Verify that fusing the partition wouldn't form a dependency cycle in the
    # original graph. Returns True for a valid partition, False for invalid.

    partition_set = set(partition)

    outputs: NodeList = []
    for node in partition_set:
        for user_node in node.users:
            if user_node not in partition_set:
                # external user node: a root for the cycle check below
                outputs.append(user_node)

    # Perform BFS on the partition outputs.
    # If it reaches a node within the partition, then it found a cycle.
    # This function takes ownership of `root_nodes` and may modify it.
    def bfs_find_cycle(root_nodes: NodeList) -> bool:
        # Set used to exclude nodes that have already been visited.
        # If a node has been visited, that node and all its children have
        # been checked for cycles.
        visited: NodeSet = set()

        # Start with `root_nodes` and traverse through (toward child nodes)
        # their connected sub-graph. Nodes in `visited` won't be added
        # to `queue` again.
        queue: NodeList = root_nodes
        while queue:
            current = queue.pop()
            visited.add(current)
            if current in partition_set:
                # Started from the partition's external users and reached a
                # node inside the partition again. Cycle!
                return True
            for user_node in current.users:
                if user_node in visited:
                    continue
                queue.append(user_node)
        # `root_nodes` don't cause a cycle.
        return False

    # Use all output nodes as roots to traverse
    # the graph to check cycles.
    if bfs_find_cycle(outputs):
        return False

    return True

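# A hedged sketch of why the check matters; the toy graph below is an
# illustrative assumption. Fusing {add1, add2} but not `mul` would create a
# cycle: the fused module feeds `mul` and also consumes `mul`'s result.
def _example_validate_partition():
    class Toy(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = a * 2  # external to the second partition below
            return a + b

    gm = torch.fx.symbolic_trace(Toy())
    add1, mul, add2 = (n for n in gm.graph.nodes if n.op == "call_function")
    assert validate_partition([add1, mul, add2])  # whole chain: no cycle
    assert not validate_partition([add1, add2])   # skips over `mul`: cycle
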
@compatibility(is_backward_compatible=False)
def fuse_as_graphmodule(gm: GraphModule,
                        nodes: NodeList,
                        module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:

    """
    Fuse nodes in graph_module into a GraphModule.

    Args:
        gm (GraphModule): target graph_module

        nodes (List[Node]): list of nodes in `gm` to fuse; must be given in topological order

        module_name: class name for the fused GraphModule

    Returns:
        fused_gm (GraphModule): fused graph module whose nodes are copies of `nodes` in `gm`

        original_inputs (Tuple[Node, ...]): nodes in the original `gm` that produce the inputs of the partition

        original_outputs (Tuple[Node, ...]): partition nodes in the original `gm` whose outputs are consumed outside the partition

    """

    # assumption: nodes are already sorted in topological order

    for node in nodes:
        assert node.graph.owning_module is gm, f"{node} doesn't belong to the passed-in graph module {gm._get_name()}"
        assert not node._erased, f"{node} has been removed from owning graph"
        assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}"

    # validate that the partition doesn't introduce dependency cycles in the graph
    assert validate_partition(nodes), "Invalid partition, found dependency cycles"

    subgraph = Graph()

    node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
    node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph

    # handles inputs through graph.node_copy's arg_transform functions
    def remap_inputs(x):
        if x.op == "get_attr":
            # TODO: do we really need to copy the get_attr node into the subgraph?
            # do something here
            pass

        if x in nodes:
            # x is inside the subgraph; return the copied node.
            # The node should have been copied already, since we copy nodes in topological order.
            return node_map[x]

        if x not in node_to_placeholder:
            # x is not in the subgraph; create a new placeholder for the subgraph
            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
            # copy all meta fields, even if some fields might be irrelevant for the placeholder node
            placeholder_node.meta = copy.copy(x.meta)
            node_to_placeholder[x] = placeholder_node

        return node_to_placeholder[x]

    # copy nodes in topological order
    for node in nodes:
        new_node = subgraph.node_copy(node, remap_inputs)
        node_map[node] = new_node

    # handles outputs
    output_mapping: Dict[Node, Node] = {}  # old output node -> new (copied) output node

    for node in nodes:
        for user_node in node.users:
            if user_node not in nodes:
                # external user node, need to expose as an output
                output_mapping[node] = node_map[node]

    # outs contains nodes in the new subgraph
    outs = tuple(output_mapping.values())

    # Take care of the args of the FX output node. For a single output the
    # output node's args is `(out,)`; for multiple outputs it is
    # `((out_0, out_1, ...),)`.
    subgraph.output(outs[0] if len(outs) == 1 else outs)

    # lint to ensure correctness
    subgraph.lint()
    fused_gm: GraphModule
    fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name)

    # sub_gm's input nodes in the original module
    original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())

    # sub_gm's output nodes in the original module
    original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())

    return fused_gm, original_inputs, original_outputs

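# Hedged usage sketch (toy module; the class name "FusedAddMul" is an
# illustrative assumption): `fuse_as_graphmodule` only *builds* the fused
# submodule; `gm` itself is not rewritten until `insert_subgm`/`erase_nodes`.
def _example_fuse_as_graphmodule():
    class Toy(torch.nn.Module):
        def forward(self, x):
            return (x + 1) * 2

    gm = torch.fx.symbolic_trace(Toy())
    partition = [n for n in gm.graph.nodes if n.op == "call_function"]
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, partition, "FusedAddMul")
    assert int(sub_gm(torch.tensor(3))) == 8  # the fused (x + 1) * 2
    assert [n.op for n in orig_inputs] == ["placeholder"]  # the original x
    assert orig_outputs == (partition[-1],)  # the mul node feeding `output`
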
@compatibility(is_backward_compatible=False)
def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
    # add sub_gm into gm
    submodule_name = sub_gm.__class__.__name__
    gm.add_submodule(submodule_name, sub_gm)

    # Create a call_module node in the main graph.
    module_node = gm.graph.call_module(
        submodule_name,
        args=orig_inputs,
        kwargs=None)

    if len(orig_outputs) == 1:
        orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
    else:
        for i, orig_output in enumerate(orig_outputs):
            # Use Proxy to record getitem access.
            proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
            orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)

        # stash the per-output metadata on the call_module node as a tuple
        module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs)
    return gm

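# Hedged sketch (illustrative names) of the multi-output path above: when two
# values escape the partition, `insert_subgm` records `getitem` nodes via
# `torch.fx.Proxy` so each external consumer reads its own slot of the tuple.
def _example_insert_subgm_multi_output():
    class Toy(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = x * 2
            return a / b  # both `a` and `b` escape the partition below

    gm = torch.fx.symbolic_trace(Toy())
    add, mul, _div = (n for n in gm.graph.nodes if n.op == "call_function")
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, [add, mul], "Fused0")
    insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    erase_nodes(gm, [add, mul])
    legalize_graph(gm)
    gm.recompile()
    assert float(gm(torch.tensor(4.0))) == 0.625  # (4 + 1) / (4 * 2)
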
@compatibility(is_backward_compatible=False)
def erase_nodes(gm: GraphModule, nodes: NodeList):

    # Erase the original nodes in reverse topological order, so that each
    # node is user-free by the time it is erased.
    for node in reversed(nodes):
        gm.graph.erase_node(node)

@compatibility(is_backward_compatible=False)
def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList], prefix: str = "fused_") -> GraphModule:
    for partition_id, nodes in enumerate(partitions):
        sorted_nodes = topo_sort(nodes)

        submodule_name = prefix + str(partition_id)
        sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)

        insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)

        erase_nodes(gm, sorted_nodes)

    # Re-sort the original gm: the newly created call_module nodes were
    # appended at the end of the graph, so it needs a topological sort.
    legalize_graph(gm)

    return gm

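# End-to-end hedged sketch (the toy module and partition choice are
# illustrative assumptions): fuse the two arithmetic nodes into a single
# submodule, which gets the default name "fused_0", and run the result.
def _example_fuse_by_partitions():
    class Toy(torch.nn.Module):
        def forward(self, x):
            return (x + 1) * 2

    gm = torch.fx.symbolic_trace(Toy())
    partition = [n for n in gm.graph.nodes if n.op == "call_function"]
    fused = fuse_by_partitions(gm, [partition])
    fused.recompile()
    assert any(n.op == "call_module" for n in fused.graph.nodes)
    assert int(fused(torch.tensor(3))) == 8  # still (3 + 1) * 2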