xref: /aosp_15_r20/external/executorch/backends/arm/arm_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2023-2024 Arm Limited and/or its affiliates.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
5*523fa7a6SAndroid Build Coastguard Worker
6*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Workerimport logging
9*523fa7a6SAndroid Build Coastguard Workerimport os
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, final, List, Optional, Tuple
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.arm_backend import ArmBackend  # usort: skip
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.operator_support.tosa_supported_operators import (
16*523fa7a6SAndroid Build Coastguard Worker    TOSASupportedOperators,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_specification import TosaSpecification
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import (
21*523fa7a6SAndroid Build Coastguard Worker    DelegationSpec,
22*523fa7a6SAndroid Build Coastguard Worker    Partitioner,
23*523fa7a6SAndroid Build Coastguard Worker    PartitionResult,
24*523fa7a6SAndroid Build Coastguard Worker)
25*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.utils import tag_constant_data
26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import PassManager
27*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram
28*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
29*523fa7a6SAndroid Build Coastguard Worker
# Module-level logger. It stays quiet (WARNING) unless verbose TOSA debugging
# is requested by setting the environment variable TOSA_DBG_VERBOSE=1.
logger = logging.getLogger(__name__)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    # Verbose mode also configures the root logger at INFO level.
    logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO if TOSA_DBG_VERBOSE else logging.WARNING)
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker
@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags TOSA-supported subgraphs for delegation to ``ArmBackend``.

    All partitions found are associated with one ``DelegationSpec`` built from
    the compile specs given to the constructor.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        """Build the ``DelegationSpec`` targeting ``ArmBackend``.

        Args:
            compile_spec: Compile specs stored in the delegation spec. During
                partitioning they are also used to derive the TOSA
                specification and to check the ``quantize_io`` option.
        """
        self.delegation_spec = DelegationSpec(ArmBackend.__name__, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the nodes of ``exported_program`` that can be delegated.

        Runs ``CapabilityBasedPartitioner`` with ``TOSASupportedOperators`` to
        return the largest possible subgraphs of supported nodes. It writes a
        ``delegation_tag`` into the meta of every node in those subgraphs and
        then tags constant data.

        Args:
            exported_program: Program to partition. Its graph module is
                mutated in place (node meta, and optionally ``TagIOQuantPass``).

        Returns:
            A ``PartitionResult`` wrapping the same, now tagged, program and a
            mapping from each delegation tag to this partitioner's
            ``DelegationSpec``.
        """
        logger.info("ArmPartitioner::partition")
        partition_tags = {}

        tosa_spec = TosaSpecification.create_from_compilespecs(
            self.delegation_spec.compile_specs
        )

        logger.info("Partitioning for %s", tosa_spec)

        # Exclude IO quantization from the partition when requested. The pass
        # is run at most once, even if "quantize_io" appears in several specs.
        quantize_io = any(
            spec.key == "quantize_io" and spec.value.decode() == "True"
            for spec in self.delegation_spec.compile_specs
        )
        if quantize_io:
            passes = PassManager(passes=[TagIOQuantPass()])
            passes(exported_program.graph_module)

        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            TOSASupportedOperators(tosa_spec),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()
        for partition in partition_list:
            # The tag depends only on the partition, so build it once rather
            # than once per node.
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
            # Only register tags that were actually applied to at least one node.
            if partition.nodes:
                partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the ops this partitioner asks to keep undecomposed.

        Args:
            ep: The exported program. It is not inspected here; the op list is
                fixed.

        Returns:
            A tuple of the op overloads ``aten.linear.default`` and
            ``aten.upsample_nearest2d.vec``, and ``None``. The ``None`` means no
            additional per-node filter is supplied.
        """
        ops_to_not_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_to_not_decompose, None)
91*523fa7a6SAndroid Build Coastguard Worker        return (ops_to_not_decompose, None)
92