# xref: /aosp_15_r20/external/executorch/backends/xnnpack/partition/xnnpack_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import itertools
8
9import logging
10from typing import List, Optional, Type, Union
11
12from executorch.backends.xnnpack.partition.config import ALL_PARTITIONER_CONFIGS
13from executorch.backends.xnnpack.partition.config.xnnpack_config import (
14    ConfigPrecisionType,
15    XNNPartitionerConfig,
16)
17
18from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
19from executorch.exir.backend.backend_details import ExportedProgram
20from executorch.exir.backend.canonical_partitioners.config_partitioner import (
21    ConfigerationBasedPartitioner,
22)
23from executorch.exir.backend.partitioner import DelegationSpec
24from torch.fx.passes.infra.partitioner import Partition
25
# NOTE(review): basicConfig runs at import time and configures the *root*
# logger (per the stdlib docs it does nothing if the root logger already has
# handlers). Libraries usually leave this to the application; it is kept here
# only to preserve the existing import-time behavior.
logging.basicConfig(level=logging.WARNING)

# Module-level logger. It has no level of its own until
# XnnpackPartitioner(verbose=True) sets it to DEBUG.
logger = logging.getLogger(name=__name__)
28
29
class XnnpackPartitioner(ConfigerationBasedPartitioner):
    """
    Config-driven partitioner whose delegation spec targets XnnpackBackend.

    Node selection is driven by XNNPartitionerConfig classes, which are
    instantiated in __init__ and handed to ConfigerationBasedPartitioner.
    Optionally runs in "per-op" mode, where matches are never merged.
    """

    def __init__(
        self,
        configs: Optional[List[Type[XNNPartitionerConfig]]] = None,
        config_precisions: Optional[
            Union[ConfigPrecisionType, List[ConfigPrecisionType]]
        ] = None,
        per_op_mode: bool = False,
        verbose: bool = False,
        **kwargs,
    ):
        """
        Instantiate the partitioner configs and initialize the base partitioner.

        Args:
            configs: XNNPartitionerConfig *classes* (not instances) to use. Each
                one is instantiated with ``**kwargs``. When ``configs`` is falsy
                (``None`` or an empty list), ALL_PARTITIONER_CONFIGS is used.
            config_precisions: a single ConfigPrecisionType or a list of them.
                A single value is wrapped in a one-element list. The result
                (possibly ``None``) is passed to every config's
                ``set_enabled_precision_types``. How a config treats ``None``
                is decided by the config itself (presumably "no restriction";
                confirm in XNNPartitionerConfig).
            per_op_mode: if True, generate_partitions produces one partition per
                config match instead of merged partitions. A match that overlaps
                an already-accepted match is dropped.
            verbose: if True, print out more information about the partitioner.
                Default level is WARNING. If verbose is True, level is set to
                DEBUG. The level is set on this module's shared logger and is
                never reverted, so it also affects partitioners created later.
            **kwargs: forwarded unchanged to every config class's constructor.
        """
        if verbose:
            logger.setLevel(logging.DEBUG)
            logger.debug("Verbose logging enabled for XNNPACK partitioner.")

        delegation_spec = DelegationSpec(XnnpackBackend.__name__, [])
        # Note: `or` means an explicit empty list also falls back to all configs.
        configs_to_use = configs or ALL_PARTITIONER_CONFIGS

        # Normalize a single precision type into a list so every config
        # receives the same shape of argument.
        if isinstance(config_precisions, ConfigPrecisionType):
            config_precisions = [config_precisions]

        # Can do logic and have extra args to filter/delete/select
        # certain configs based on user specification.
        initialized_configs = []
        for config in configs_to_use:
            # Config Classes given to XnnpackPartitioner should no longer be abstract
            initialized = config(**kwargs)  #  pyre-ignore
            initialized.set_enabled_precision_types(config_precisions)
            initialized_configs.append(initialized)

        # per_op_mode takes the first match from a partitioner config, any
        # subsequent matches that overlap with the first match are not partitioned
        self.per_op_mode = per_op_mode
        super().__init__(delegation_spec, initialized_configs)

    def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
        """
        Generate partitions for ``ep``, honoring ``per_op_mode``.

        In per-op mode the unmerged partitions from generate_per_op_partitions
        are returned. Otherwise this defers to the base class implementation.

        Args:
            ep: the exported program to partition.

        Returns:
            The list of partitions to delegate.
        """
        if self.per_op_mode:
            return self.generate_per_op_partitions(ep)
        else:
            return super().generate_partitions(ep)

    def generate_per_op_partitions(self, ep: ExportedProgram) -> List[Partition]:
        """
        Build one partition per config match, without merging any of them.

        Matches come from ``get_matched_nodes_from_configs`` (inherited;
        presumably one collection of nodes per config match, i.e. a node plus
        its deps; confirm in the base class). Matches are visited in order.
        A match is accepted only if it shares no node with any previously
        accepted match, so the first match claiming a node wins.

        Args:
            ep: the exported program to partition.

        Returns:
            The accepted partitions, with ids 0, 1, 2, ... in acceptance order.
        """
        partitions = []
        matched_nodes = self.get_matched_nodes_from_configs(ep)
        partition_id = itertools.count()
        nodes_seen = set()
        for match in matched_nodes:
            match_set = set(match)
            # We only create partitions from the first PartitionerConfig match
            # if a subsequent partitioner match contains the same node, we do
            # not create a partition for it
            if match_set.isdisjoint(nodes_seen):
                partitions.append(
                    Partition(
                        id=next(partition_id),
                        nodes=match_set,
                    )
                )
                nodes_seen.update(match_set)
        return partitions
103
104
class XnnpackDynamicallyQuantizedPartitioner(XnnpackPartitioner):
    """XnnpackPartitioner preset: DYNAMIC_QUANT precision, per-op mode on."""

    def __init__(self):
        # Per-op mode: each config match becomes its own partition, and
        # later matches that overlap an accepted one are dropped.
        super().__init__(
            per_op_mode=True,
            config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
        )
110
111
class XnnpackFloatingPointPartitioner(XnnpackPartitioner):
    """XnnpackPartitioner preset: every config's precision types set to [FP32]."""

    def __init__(self):
        fp32_precision = ConfigPrecisionType.FP32
        super().__init__(config_precisions=fp32_precision)
115
116
class XnnpackQuantizedPartitioner(XnnpackPartitioner):
    """XnnpackPartitioner preset: every config's precision types set to [STATIC_QUANT]."""

    def __init__(self):
        static_quant_precision = ConfigPrecisionType.STATIC_QUANT
        super().__init__(config_precisions=static_quant_precision)
120