# xref: /aosp_15_r20/external/executorch/backends/example/example_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7from typing import Dict, final
8
9import torch
10from executorch.backends.example.example_backend import ExampleBackend
11from executorch.backends.example.example_operators.ops import module_to_annotator
12from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
13    generate_partitions_from_list_of_nodes,
14)
15from executorch.exir.backend.partitioner import (
16    DelegationSpec,
17    Partitioner,
18    PartitionResult,
19)
20from executorch.exir.dialects._ops import ops as exir_ops
21from executorch.exir.graph_module import get_control_flow_submodules
22from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
23from torch.export import ExportedProgram
24from torch.fx.passes.operator_support import OperatorSupportBase
25
26
@final
class ExamplePartitioner(Partitioner):
    """
    Partitioner that tags nodes for delegation to ``ExampleBackend``.

    Each key of ``module_to_annotator`` is used as a pattern for
    ``find_sequential_partitions``. The matched node lists, together with an
    extra operator-support predicate that accepts edge-dialect per-tensor
    quantize/dequantize nodes, are handed to
    ``generate_partitions_from_list_of_nodes``. Every resulting partition is
    tagged with ``node.meta["delegation_tag"] = "tag<partition id>"``.
    Control-flow submodules are partitioned recursively.
    """

    def __init__(self) -> None:
        # Live view over the registered pattern keys; each key is passed to
        # find_sequential_partitions unchanged (presumably a sequence of
        # module types/callables -- confirm against module_to_annotator).
        self.patterns = module_to_annotator.keys()
        # Single spec shared by every partition: ExampleBackend with an empty
        # compile-spec list.
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])

        class DequantQuantOperatorSupport(OperatorSupportBase):
            """Accept edge-dialect per-tensor quantize/dequantize call nodes."""

            def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
                # The targets are resolved per call (not at construction) so
                # the quantized_decomposed ops only need to be registered by
                # the time partitioning actually runs.
                return node.op == "call_function" and node.target in (
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                )

        self.dequant_quant_support = DequantQuantOperatorSupport()

    @staticmethod
    def _tag_partition_nodes(nodes, delegation_tag: str) -> None:
        """
        Set ``meta["delegation_tag"]`` on ``nodes`` and on their constant inputs.

        Every node in ``nodes`` is tagged. For each ``call_function`` node, any
        direct positional argument that is a ``get_attr`` node gets the same
        tag.

        NOTE(review): only top-level ``node.args`` are inspected; ``get_attr``
        inputs passed via kwargs or nested inside list/tuple args are not
        tagged here.
        """
        for node in nodes:
            node.meta["delegation_tag"] = delegation_tag
            if node.op != "call_function":
                continue
            for arg_node in node.args:
                if isinstance(arg_node, torch.fx.Node) and arg_node.op == "get_attr":
                    # Presumably so the constant is lowered together with the
                    # partition that consumes it -- confirm with the backend.
                    arg_node.meta["delegation_tag"] = delegation_tag

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """
        Tag partitions in ``edge_graph_module`` and its control-flow submodules.

        Args:
            edge_graph_module: Graph module whose nodes are tagged in place.
                Node ``meta["delegation_tag"]`` entries are written.

        Returns:
            Mapping from every delegation tag created in this graph module or
            in a nested control-flow submodule to ``self.delegation_spec``.
        """
        # Gather the node lists of every sequential match of every pattern.
        partition_nodes = []
        for pattern in self.patterns:
            for fused_partition in find_sequential_partitions(
                edge_graph_module, pattern
            ):
                partition_nodes.extend(
                    source_partition.nodes for source_partition in fused_partition
                )

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, partition_nodes, self.dequant_quant_support
        )

        partition_tags: Dict[str, DelegationSpec] = {}
        for partition in partitions:
            # Skip empty partitions so no tag is registered without any
            # tagged node behind it.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so compute it once.
            delegation_tag = f"tag{partition.id}"
            self._tag_partition_nodes(partition.nodes, delegation_tag)
            partition_tags[delegation_tag] = self.delegation_spec

        # A submodule's tags may repeat tag strings already present here.
        # Every value is self.delegation_spec, so merging cannot change the
        # spec associated with any tag.
        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            partition_tags.update(self._partition_graph_module(submodule))

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag delegable nodes of ``exported_program`` in place.

        Returns:
            A ``PartitionResult`` wrapping the same (now tagged) exported
            program together with the tag -> ``DelegationSpec`` mapping.
        """
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
89