xref: /aosp_15_r20/external/pytorch/torch/fx/passes/infra/partitioner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
3import collections
4import itertools
5import logging
6
7from copy import copy
8from typing import Dict, Iterable, List, Optional, Sequence, Set
9
10from torch.fx.graph_module import GraphModule
11from torch.fx.node import Node, _get_qualified_name
12from torch.fx.passes.operator_support import OperatorSupportBase
13
14
# Module-level logger for partition tracing. Default to WARNING so the
# logger.debug(...) calls in this file stay silent unless a caller
# explicitly lowers the level for debugging.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
17
class Partition:
    """A group of graph nodes destined to be fused together.

    Internally the node collection is a ``dict`` whose values are always
    ``None`` — a dict is used (rather than a set) so insertion order is
    preserved, giving deterministic iteration over the partition's nodes.
    """

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        # Ordered-set semantics: keys are the member nodes, values unused.
        self.nodes: Dict[Node, None] = {} if nodes is None else dict.fromkeys(nodes)

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node) -> None:
        """Append *node* to the partition (no-op if already present)."""
        self.nodes[node] = None

    def remove_node(self, node: Node) -> None:
        """Remove *node*; raises KeyError if it is not in the partition."""
        del self.nodes[node]

    def size(self) -> int:
        """Return the number of nodes currently in the partition."""
        return len(self.nodes)
34
class _DependencyViewer:
    """Precomputed transitive reachability for every node of a graph.

    ``upstreams[n]`` is the set of all ancestors of ``n`` (nodes from which
    ``n`` is reachable); ``downstreams[n]`` is the set of all descendants.
    Both closures are built in a single pass each, relying on the FX graph's
    node list being topologically ordered: by the time a node is visited, the
    sets of all its inputs (forward pass) or users (backward pass) are final.
    """

    def __init__(self, graph_module: GraphModule):
        self.upstreams = collections.defaultdict(set)
        self.downstreams = collections.defaultdict(set)

        # Forward pass over the topological order: a node's ancestors are
        # its direct producers plus each producer's (already complete)
        # ancestor set.
        for current in graph_module.graph.nodes:
            for producer in current.all_input_nodes:
                self.upstreams[current].add(producer)
                self.upstreams[current] |= self.upstreams[producer]

        # Backward pass over the reverse topological order: a node's
        # descendants are its direct consumers plus each consumer's
        # (already complete) descendant set.
        for current in reversed(graph_module.graph.nodes):
            for consumer in current.users:
                self.downstreams[current].add(consumer)
                self.downstreams[current] |= self.downstreams[consumer]

    def downstreams_of(self, node: Node) -> Set[Node]:
        """All nodes reachable from *node* by following user edges."""
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> Set[Node]:
        """All nodes from which *node* is reachable via input edges."""
        return self.upstreams[node]
57
class CapabilityBasedPartitioner:
    """Greedily partitions an FX graph into fusible subgraphs.

    Nodes that ``operator_support`` reports as supported are assigned to
    partitions, and adjacent partitions are merged whenever the merge does
    not introduce a cyclic dependency between partitions. The partition
    bookkeeping in :meth:`propose_partitions` mutates several shared maps in
    a specific order; see the inline comments there.
    """

    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False,
                 non_compute_ops: Optional[Sequence[str]] = None,
                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
                 ) -> None:
        # graph_module: the FX module whose graph will be partitioned.
        # operator_support: oracle deciding which nodes may join a partition.
        # allows_single_node_partition: when False, partitions containing at
        #   most one "compute" node are dropped in propose_partitions().
        # non_compute_ops: extra qualified op names treated as non-compute
        #   (in addition to the built-in defaults used during filtering).
        # allowed_single_node_partition_ops: qualified op names that count as
        #   compute even when filtering single-node partitions.
        self.graph_module = graph_module
        self.operator_support = operator_support
        self.allows_single_node_partition = allows_single_node_partition
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        # Precompute transitive upstream/downstream sets once; reused by the
        # cycle checks during partition merging.
        self.dependency_viewer = _DependencyViewer(graph_module)

    def __is_node_supported(self, node: Node) -> bool:
        """Ask the operator-support oracle whether *node* can be partitioned."""
        return (
            self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
        )

    def propose_partitions(self) -> List[Partition]:
        """Compute and return the list of proposed (non-empty) partitions.

        Maintains three mutually consistent structures:
          - ``assignment``: node -> partition id it currently belongs to;
          - ``partitions_by_id``: partition id -> Partition object;
          - ``partition_map``: partition id -> set of partition ids reachable
            from it (used to reject merges that would create a cycle).
        """
        # partition_map is a mapping from partition id to a set of partition id's.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map : Dict[int, Set] = collections.defaultdict(set)

        # assumptions: nodes in candidate list is sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        # try to merge partition other_id into partition self_id
        # merge only happens if the end graph doesn't contain cyclic dependency
        # returns `True` when merge happens, `False` otherwise.
        def maybe_merge_partition(self_id: int, other_id: int) -> bool:
            # merged_nodes is the union of nodes in two partition to-be-merged
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)

            # Returns True iff some node outside the merged partition both
            # consumes from it and (transitively) feeds back into it — i.e.
            # the merge would create a cycle at the partition level.
            def dfs_iter_find_cycle(all_user_nodes: Set[Node]) -> bool:
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # If any of the nodes in the dfs path of this node are in the merged_nodes
                        # list then there is a cycle in the graph.
                        if path_node in merged_nodes:
                            return True

                        # If any of the nodes in the dfs path of this node are in the assignment
                        # map then we have to make sure that the partitions that these nodes belong
                        # to do not form a cycle with the current partitions being merged. This means
                        # iterating through all the nodes in all the parititons that are traversed in
                        # the dfs path and checking if they are in the merged_nodes list.
                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # If the partition id has already been visited then we know that it doesn't
                            # form a cycle with the current partitions being merged.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # check if merge would create cyclic dependency.
            # all_user_nodes: consumers of the merged partition that live
            # outside it — the only places a cycle back into it can start.
            all_user_nodes = set()
            for node in merged_nodes:
                for user_node in node.users:
                    if user_node not in merged_nodes:
                        all_user_nodes.add(user_node)

            if dfs_iter_find_cycle(all_user_nodes):
                # return false indicating cyclic dependency found and
                # merge is aborted
                return False

            # no cyclic dependency found, move forward with the merge
            # updating partition nodes
            partitions_by_id[self_id].nodes = merged_nodes
            # updating assignment map
            for node in partitions_by_id[other_id].nodes:
                assignment[node] = self_id
            # delete other partition
            del partitions_by_id[other_id]

            # fold other_id's reachability set into self_id's before dropping it
            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
            del partition_map[other_id]

            return True

        # Assign `node` to partition `id`, creating the partition if needed.
        # `id is None` means "unassign the node" (used by getitem reassignment
        # below, where a producer may itself be unassigned).
        def merge_single_node(node: Node, id: Optional[int]) -> None:
            def _update_partition_map(node: Node, id: int) -> None:
                # Iterate through all the downstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of this node to the target
                # partition id.
                downstream_nodes = self.dependency_viewer.downstreams_of(node)
                for curr_node in downstream_nodes:
                    target_id = assignment.get(curr_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)

                # Iterate through all the upstream nodes of this node and update the partition map
                # to indicate that there is a path from the partition id of the upstream node to the
                # current node's partition id.
                upstream_nodes = self.dependency_viewer.upstreams_of(node)
                for curr_node in upstream_nodes:
                    source_id = assignment.get(curr_node, None)
                    if source_id is not None:
                        partition_map[source_id].add(id)

            # If the node already belongs somewhere, detach it first so the
            # old partition stays consistent.
            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)
                _update_partition_map(node, id)

        logger.debug("Proposing partitions...")

        for node in reversed(self.graph_module.graph.nodes):
            # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
            merge_candidates: Dict[int, None] = {}

            # Note a limited horizontal fusion is enabled:
            #   when `node` is not supported, the code below attempts to fuse consumer of `node`.
            #
            # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
            # the fusion by adding an `else` block here to skip horizontal fusion.
            if self.__is_node_supported(node) and node not in assignment:
                partition_id = next(new_partition_id)
                merge_single_node(node, partition_id)
                merge_candidates[partition_id] = None

            # merge all possible partitions
            # NOTE(review): this inner loop reuses (and rebinds) the outer
            # loop variable `node`; harmless because the remainder of the
            # outer iteration does not read `node`, but worth renaming.
            for node in assignment:
                merge_candidates[assignment[node]] = None

            merge_candidates_list = list(merge_candidates.keys())
            if len(merge_candidates_list) > 1:
                self_id = merge_candidates_list[0]
                for other_id in merge_candidates_list[1:]:
                    # note: merge partition `other_id` into partition `self_id` if
                    # it doesn't create cyclic dependency in the graph, otherwise,
                    # this is a no-op
                    maybe_merge_partition(self_id, other_id)

        # post processing to re-assign "getitem" nodes into upstream partition
        logger.debug("Reassigning getitem nodes to its producer node's partition...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            # is_tuple_output: True iff every user of `node` is an
            # _operator.getitem call (i.e. `node` produces a tuple that is
            # only ever unpacked).
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                   _get_qualified_name(user.target) != "_operator.getitem":     # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs, re-assign all following getitem node into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)     # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != id:    # type: ignore[arg-type]
                        # id may be None here, which unassigns the getitem
                        # in merge_single_node below.
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)

        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        # Explicitly allowed ops count as compute even if they
                        # are also in non_compute_ops (may double-count an op
                        # that is in neither set's exclusion — presumably
                        # intentional, since any count > 1 keeps the partition).
                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])

        # Partitions can become empty via remove_node/reassignment; drop them.
        return [partition for partition in partitions_by_id.values() if partition.size() > 0]

    def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule:
        """Fuse each proposed partition into a submodule named ``{prefix}{i}``."""
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
        return fuse_by_partitions(
            self.graph_module,
            [list(partition.nodes) for partition in partitions],
            prefix=prefix,
        )

    # remove non-compute-ops that sits at the boundary of a partition.
    def remove_bookend_non_compute_ops(self, partitions: List[Partition]) -> None:
        """Strip non-compute ops from partition boundaries, in place.

        A node is removed when it is a non-compute op and is "transparent"
        toward either boundary: every path from it to the partition edge
        (inputs or outputs) passes only through other non-compute nodes.
        """
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node) -> bool:
            return node.op == "call_function" and \
                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]

        # cache transparent nodes
        # NOTE(review): these caches are shared across partitions in the loop
        # below — presumably safe because a node belongs to one partition;
        # verify if partitions may ever overlap.
        transparent_input_nodes: Dict[Node, bool] = {}
        transparent_output_nodes: Dict[Node, bool] = {}

        # True when every input-side path from `node` exits the partition
        # through non-compute (or already-removed) nodes only.
        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]) -> bool:
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        # Mirror of the above for the output side (follows users instead of
        # inputs).
        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]) -> bool:
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(output_n, partition, removed_nodes):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False

        for partition in partitions:
            # Note it's ok to use `set` here, since we are only query if a node
            # has been removed. We are NEVER going to iterate on nodes inside
            # the set.
            remove_node: Set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and \
                    (is_transparent_input_node(node, set(partition.nodes), remove_node) or
                     is_transparent_output_node(node, set(partition.nodes), remove_node)):
                    remove_node.add(node)

            if len(remove_node) != 0:
                for node in remove_node:
                    partition.nodes.pop(node, None)

    def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule:
        """Convenience wrapper: propose partitions, then fuse them."""
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions, prefix=prefix)
        return fused_gm
336