xref: /aosp_15_r20/external/executorch/exir/backend/canonical_partitioners/config_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from abc import ABC, abstractmethod
8from typing import Callable, Dict, Iterable, List, Optional, Tuple
9
10import torch
11from executorch.exir.backend.backend_details import ExportedProgram
12from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13    generate_partitions_from_list_of_nodes,
14)
15from executorch.exir.backend.partitioner import (
16    DelegationSpec,
17    Partitioner,
18    PartitionResult,
19)
20from torch.fx.passes.infra.partitioner import Partition
21
22
def format_target_name(target_name: str) -> str:
    """
    Strip the leading dialect name space from an op target name.

    The dialect-specific name space (e.g. "aten.", "quantized_decomposed.") is
    generally not of interest; the op name itself is. Only the first
    dot-separated component is removed, and only when the name has more than
    two components, so "aten.linear.default" becomes "linear.default" while
    "linear.default" is returned unchanged.

    Args:
        target_name: Dotted target name, with or without a dialect prefix
    Returns:
        The target name without its leading dialect component
    """
    # Two components or fewer: there is no dialect prefix to strip.
    if target_name.count(".") < 2:
        return target_name

    # Everything after the first dot is the op name (including its overload).
    _dialect, _dot, op_name = target_name.partition(".")
    return op_name
35
36
37class PartitionerConfig(ABC):
38    """
39    Class used to represent a PartitionerConfig.
40
41    PartitionerConfig is used by config-based partitioner to partition identify
42    nodes to be delegated. User overrides the methods:
43        - target_name
44        - check_constraints
45        - get_partition
46        - get_original_aten
47
48    The Config-Based Partitioner then uses these overridden methods to find nodes
49    which match target_name, check_constraints, and if true, returns the partition
50    (list of nodes) which represent the node and its dependencies. get_original_aten
51    is used to halt decomposition to edge_dialect if the node can be delegated by
52    the specified backend.
53    """
54
55    @classmethod
56    @property
57    @abstractmethod
58    def target_name(cls) -> str:
59        """
60        Target name for this partitioner config. When the Config-Based Partitioner
61        encounters a node with a matching target name, it uses this config's methods to
62        checks the constraints of this node and get all of its dependencies.
63        the target name is formatted to remove the dialect-specific name space.
64        i.e. linear.default
65        """
66        pass
67
68    @abstractmethod
69    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
70        """
71        Takes in a node and returns true if the node is partitionable.
72
73        Args:
74            node: Node to be partitioned
75            ep: Exported program of the graph module
76        Returns:
77            True or False whether this node is partitionable
78        """
79        pass
80
81    @abstractmethod
82    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
83        """
84        Returns the original aten dialect op, this is for to_edge_transform_and_lower
85        API, so that this config can be used to stop decomposition of this original
86        aten op
87        """
88        pass
89
90    @abstractmethod
91    def get_partition(
92        self, node: torch.fx.Node, ep: ExportedProgram
93    ) -> List[torch.fx.Node]:
94        """
95        Returns the partitioned nodes from get_node_and_deps, but also labels them
96        with the name of the PartitionerConfig class which return this set of nodes.
97
98        Returns an empty list of the node and deps do not satisfy the checked constraints
99        """
100        pass
101
102
class ConfigerationBasedPartitioner(Partitioner):
    """
    Partitioner driven by a collection of PartitionerConfig objects, keyed by
    their target names.
    """

    def __init__(
        self,
        delegation_spec: DelegationSpec,
        partitioner_configs: Iterable[PartitionerConfig],
    ):
        """
        Configuration based partitioner. The partitioner is supplied with a set
        of configurations which describe the node type, constraints, and any
        dependencies required to be partitioned with the node. The
        configurations are then used to partition the graph module.

        Args:
            delegation_spec: Spec associated with every tag this partitioner
                produces
            partitioner_configs: Configs to register; each target_name may
                appear only once
        Raises:
            RuntimeError: if two configs share the same target_name
        """
        super().__init__()
        self.delegation_spec = delegation_spec

        # Map of {"target_name": PartitionerConfig} used for node lookup.
        self.target_partitioner_configs: Dict[str, PartitionerConfig] = {}
        for config in partitioner_configs:
            target_name = config.target_name
            if target_name not in self.target_partitioner_configs:
                self.target_partitioner_configs[target_name] = config
                continue

            # A second config for an already registered target is ambiguous.
            other_config = self.target_partitioner_configs[target_name]
            raise RuntimeError(
                f"PartitionerConfig: {config} and {other_config} have the same target_name: {target_name}"
            )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Collect the aten ops that should not be decomposed.

        Args:
            ep: Exported program handed to the configs' check_constraints
        Returns:
            A tuple of (original aten ops of all registered configs that
            provide one, filter function deciding per node whether the
            node satisfies its config's constraints)
        """

        def filter_fn(node: torch.fx.Node) -> bool:
            """
            The registered partitioner configs each carry a check_constraints
            function which determines whether an op is indeed partitionable.
            The config matching this node's target is looked up and its
            check_constraints result is used as the filter.
            """
            if node.op != "call_function":
                return False

            target_name = format_target_name(node.target.__name__)  # pyre-ignore
            if target_name not in self.target_partitioner_configs:
                return False

            config = self.target_partitioner_configs[target_name]
            # Only configs that name an original aten op take part in filtering.
            if not config.get_original_aten():
                return False
            return config.check_constraints(node, ep)

        # Original aten targets which we do not want to decompose; configs
        # without an original aten op contribute nothing.
        do_not_decomp = [
            original_aten
            for original_aten in (
                node_config.get_original_aten()
                for node_config in self.target_partitioner_configs.values()
            )
            if original_aten is not None
        ]

        return (do_not_decomp, filter_fn)

    def get_matched_nodes_from_configs(
        self, ep: ExportedProgram
    ) -> List[List[torch.fx.Node]]:
        """
        Gather the partitions of all supported nodes in the graph.

        Args:
            ep: Exported program whose graph module is scanned
        Returns:
            One list of nodes (as returned by the config's get_partition) per
            call_function node that matches a registered config and passes
            its check_constraints, in graph order
        """
        matched_nodes = []
        for node in ep.graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            target = format_target_name(node.target.__name__)
            if target not in self.target_partitioner_configs:
                continue

            node_config = self.target_partitioner_configs[target]
            if node_config.check_constraints(node, ep):
                matched_nodes.append(node_config.get_partition(node, ep))

        return matched_nodes

    def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
        """
        Build partitions from the nodes matched by the registered configs.

        Args:
            ep: Exported program to partition
        Returns:
            Partitions produced by generate_partitions_from_list_of_nodes for
            the matched node lists
        """
        matched_nodes = self.get_matched_nodes_from_configs(ep)
        return generate_partitions_from_list_of_nodes(
            ep.graph_module,
            matched_nodes,
        )

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag the nodes of every generated partition for delegation.

        Each node of a partition gets meta["delegation_tag"] set to
        "tag<partition id>", and every tag that was assigned is mapped to
        this partitioner's delegation spec.

        Args:
            exported_program: Program whose graph nodes are tagged in place
        Returns:
            PartitionResult holding the same exported program and the
            tag -> delegation spec mapping
        Raises:
            RuntimeError: if a node already carries a delegation tag
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        for partition in self.generate_partitions(exported_program):
            # The tag depends only on the partition, not on the node.
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                # A node must not end up in more than one partition.
                if "delegation_tag" in node.meta:
                    raise RuntimeError(
                        f"Partitioner Erro found node {node} in partition {node.meta['delegation_tag']} and partition {delegation_tag}"
                    )
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
205