xref: /aosp_15_r20/external/executorch/exir/backend/canonical_partitioners/pattern_op_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8from typing import List, Optional
9
10import torch
11from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
12from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
13from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
14
15
16def generate_partitions_from_list_of_nodes(
17    graph_module: torch.fx.GraphModule,
18    pattern_list: Optional[List[List[torch.fx.Node]]] = None,
19    op_support: Optional[OperatorSupportBase] = None,
20) -> List[Partition]:
21    final_op_support: Optional[OperatorSupportBase] = op_support
22
23    if pattern_list is not None:
24        # Tag all the nodes in these patterns
25        for node_list in pattern_list:
26            for node in node_list:
27                node.meta["match"] = True
28
29        class MatchTag(OperatorSupportBase):
30            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
31                return node.meta.get("match", False)
32
33        final_op_support = (
34            MatchTag()
35            if final_op_support is None
36            else any_chain(final_op_support, MatchTag())
37        )
38
39    assert (
40        final_op_support is not None
41    ), "Did not give a pattern or OperatorSupportBase instance to partition with"
42
43    # Run the CapabilityBasedPartitioner to return the largest possible
44    # subgraphs containing the nodes with the tags
45    capability_partitioner = CapabilityBasedPartitioner(
46        graph_module,
47        final_op_support,
48        allows_single_node_partition=True,
49    )
50    partition_list = capability_partitioner.propose_partitions()
51
52    # Remove the metadata field we added
53    for partition in partition_list:
54        for node in partition.nodes:
55            node.meta.pop("match", False)
56    return partition_list
57
58
59def generate_pattern_op_partitions(
60    graph_module: torch.fx.GraphModule,
61    patterns: Optional[List[torch.fx.Graph]] = None,
62    partitions_list: Optional[List[List[torch.fx.Node]]] = None,
63    op_support: Optional[OperatorSupportBase] = None,
64    ignore_literals: bool = False,
65) -> List[Partition]:
66    """
67    Args:
68        graph_module: Module that we want to partition
69        patterns: A list of patterns in the form of torch.fx.Graph. These graphs
70            can be obtained through the `graph` field from a GraphModule obtained by
71            exir.capture (recommended) or symbolic tracing (which might not result
72            in an accurate edge dialect graph), or by manual crafting a graph
73            module.
74        partitions_list: A list of node lists whose nodes are intended to be tagged
75            along with the nodes detected by the pattern matcher.
76        op_support: A OperatorSupportBase that can be created in the following ways:
77            - Subclassing it directly and implementing is_node_supported()
78            - Getting the result of create_op_support()
79            - Getting the result of create_pattern_support()
80            - Multiple OperatorSupportBase classes chained together with chain()
81
82    Returns
83        A list of partitions (largest possible subgraphs) containing nodes are
84        supported by the given OperatorSupportBase object
85    """
86    final_op_support: Optional[OperatorSupportBase] = op_support
87
88    if patterns is not None:
89        # Find all patterns in the graph (even if they're invalid)
90        matches = []
91        for pattern in patterns:
92            logging.debug(f"Finding matches for pattern: {pattern}")
93            subgraph_matcher = SubgraphMatcher(pattern, ignore_literals=ignore_literals)
94            matches.extend(subgraph_matcher.match(graph_module.graph))
95
96        # Tag all the nodes in these patterns
97        for match in matches:
98            for node_in_pattern, node_in_graph in match.nodes_map.items():
99                if (
100                    node_in_pattern.op != "placeholder"
101                    and node_in_graph.op != "placeholder"
102                ):
103                    node_in_graph.meta["match"] = True
104
105    if partitions_list:
106        for node_list in partitions_list:
107            for node in node_list:
108                node.meta["match"] = True
109
110    class MatchTag(OperatorSupportBase):
111        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
112            return node.meta.get("match", False)
113
114    final_op_support = (
115        MatchTag()
116        if final_op_support is None
117        else any_chain(final_op_support, MatchTag())
118    )
119
120    assert (
121        final_op_support is not None
122    ), "Did not give a pattern or OperatorSupportBase instance to partition with"
123
124    # Run the CapabilityBasedPartitioner to return the largest possible
125    # subgraphs containing the nodes with the tags
126    capability_partitioner = CapabilityBasedPartitioner(
127        graph_module,
128        final_op_support,
129        allows_single_node_partition=True,
130    )
131    partition_list = capability_partitioner.propose_partitions()
132
133    # Remove the metadata field we added
134    for partition in partition_list:
135        for node in partition.nodes:
136            node.meta.pop("match", False)
137    return partition_list
138