# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

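# Operator-support predicates used by the demo partitioners below. Each one decides,
# per FX node, whether that node can be handed to the backend; the partitioners pass
# them (sometimes combined with any_chain) to the pattern-based partition generator.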
class AllOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function"


class AddOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            exir_ops.edge.aten.add.Tensor,
        ]


class MatmulOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            exir_ops.edge.aten.mm.default,
        ]


@final
class AddMulPartitionerDemo(Partitioner):
    """
    Partitions all add and mm (matmul) nodes, regardless of order
    """

    def __init__(self) -> None:
        self.op_support = any_chain(AddOperatorSupport(), MatmulOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [CompileSpec("max_value", bytes([4]))],
        )

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

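        # Recurse into control-flow submodules (e.g. the branches of cond and the
        # body of map) so that nodes inside them are tagged as well.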
        for _, submodule, _ in get_control_flow_submodules(graph_module):
            ret_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(ret_partition_tags)

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

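# A minimal usage sketch, not part of the original demo: it assumes the standard
# ExecuTorch flow of torch.export.export -> exir.to_edge -> to_backend. With this
# partitioner, every add/mm node in the toy module below should be tagged and then
# replaced by a call_delegate node. The helper and module names are illustrative only.
def _example_lower_with_add_mul_partitioner() -> None:
    # Imported locally to keep the sketch self-contained and avoid import-time cycles.
    from executorch.exir import to_edge

    class _AddMm(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.mm(x, y) + x

    exported = torch.export.export(_AddMm(), (torch.randn(2, 2), torch.randn(2, 2)))
    edge = to_edge(exported)
    lowered = edge.to_backend(AddMulPartitionerDemo())
    print(lowered.exported_program().graph)
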

@final
class AddAttributePartitionerDemo(Partitioner):
    """
    Partitions all add nodes together with their constant inputs (get_attr nodes,
    parameters, buffers and lifted tensor constants)
    """

    def __init__(self) -> None:
        self.op_support = AddOperatorSupport()

        self.delegation_spec = DelegationSpec(BackendWithCompilerDemo.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = {}
        partition_list = generate_pattern_op_partitions(
            edge_exported_program.graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                partition_tags[delegation_tag] = self.delegation_spec

                # Tag the add nodes
                node.meta["delegation_tag"] = delegation_tag

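                # Also tag the node's constant inputs (get_attr nodes, parameters,
                # buffers and lifted tensor constants) so they end up in the same
                # delegate as the add node that consumes them.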
                for arg_node in node.args:
                    if not isinstance(arg_node, torch.fx.Node):
                        continue

                    is_get_attr = arg_node.op == "get_attr"
                    is_param_buffer = arg_node.op == "placeholder" and (
                        is_param(edge_exported_program, arg_node)
                        or is_buffer(edge_exported_program, arg_node)
                        or is_lifted_tensor_constant(edge_exported_program, arg_node)
                    )
                    if is_get_attr or is_param_buffer:
                        # Tagging the input adds it to the partitioned (delegated) nodes.
                        arg_node.meta["delegation_tag"] = delegation_tag

        return PartitionResult(
            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
        )


@final
class AllNodesPartitionerDemo(Partitioner):
    """
    Partitions all call_function nodes
    """

    def __init__(self) -> None:
        self.op_support = AllOperatorSupport()
        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = {}
        partition_list = generate_pattern_op_partitions(
            edge_exported_program.graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                partition_tags[delegation_tag] = self.delegation_spec

                # Tag the node
                node.meta["delegation_tag"] = delegation_tag

                # Also tag constant inputs so they are lowered with their consumer.
                for arg_node in node.args:
                    if not isinstance(arg_node, torch.fx.Node):
                        continue

                    is_get_attr = arg_node.op == "get_attr"
                    is_param_buffer = arg_node.op == "placeholder" and (
                        is_param(edge_exported_program, arg_node)
                        or is_buffer(edge_exported_program, arg_node)
                        or is_lifted_tensor_constant(edge_exported_program, arg_node)
                    )
                    if is_get_attr or is_param_buffer:
                        # Tagging the input adds it to the partitioned (delegated) nodes.
                        arg_node.meta["delegation_tag"] = delegation_tag

        return PartitionResult(
            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
        )


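# `ops_not_to_decompose` lists the torch.ops.aten overloads to keep un-decomposed
# during lowering; `edge_ops_non_decomposed` lists the corresponding edge-dialect ops
# that the partitioner matches once the program is in the edge dialect.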
ops_not_to_decompose = [
    torch.ops.aten.linear.default,
    torch.ops.aten.scaled_dot_product_attention.default,
    torch.ops.aten.upsample_nearest2d.vec,
]

edge_ops_non_decomposed = [
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.scaled_dot_product_attention.default,
    exir_ops.edge.aten.upsample_nearest2d.vec,
]


class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in edge_ops_non_decomposed


@final
class NonDecompTestPartitioner(Partitioner):
    """
    Non-decomposition test partitioner: preserves selected aten ops from decomposition
    so the delegate can consume them directly. Ensures that each of the non-decomposed
    edge ops ends up in its own delegate.
    """

    def __init__(self) -> None:
        self.supported_non_decomposed_edge_ops = edge_ops_non_decomposed
        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [CompileSpec("max_value", bytes([4]))],
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        def filter_ops(node: torch.fx.Node) -> bool:
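            # Returning True asks for the node to be kept un-decomposed; returning
            # False lets the default decomposition run for that node.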
            if node.op == "call_function" and node.target in ops_not_to_decompose:
                if len(node.args) == 3:
                    # A third argument means the linear has a bias; biased linear is
                    # the only linear variant this demo partitioner supports.
                    return True
                else:
                    return False

            return True

        return (ops_not_to_decompose, filter_ops)

    def _generate_single_node_partition(
        self, gm: torch.fx.GraphModule
    ) -> List[Partition]:
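        # Give every matched op its own single-node partition so that each
        # non-decomposed op is lowered into a separate delegate.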
        partitions = []
        partition_id = itertools.count()
        nodes_seen = set()
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in self.supported_non_decomposed_edge_ops
                and node not in nodes_seen
            ):
                partitions.append(Partition(nodes=[node], id=next(partition_id)))
                nodes_seen.add(node)

        return partitions

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = self._generate_single_node_partition(graph_module)
        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        for _, submodule, _ in get_control_flow_submodules(graph_module):
            ret_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(ret_partition_tags)
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

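
# A minimal usage sketch, not part of the original demo: it assumes the
# exir.to_edge_transform_and_lower entry point, which consults ops_to_not_decompose()
# before running decompositions, so a biased aten.linear survives into the edge
# dialect and is delegated as a single node. The helper and module names are
# illustrative only.
def _example_lower_with_non_decomp_partitioner() -> None:
    # Imported locally to keep the sketch self-contained and avoid import-time cycles.
    from executorch.exir import to_edge_transform_and_lower

    class _Linear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    exported = torch.export.export(_Linear(), (torch.randn(2, 4),))
    lowered = to_edge_transform_and_lower(
        exported, partitioner=[NonDecompTestPartitioner()]
    )
    print(lowered.exported_program().graph)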