# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.mediatek.preprocess import NeuropilotBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data

from mtk_converter.python.converters.pytorch import importer_v2
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class NeuropilotOperatorsSupport(OperatorSupportBase):
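    """Operator support check for the MediaTek Neuropilot backend.

    A ``call_function`` node is reported as supported unless its operator type
    is listed in ``op_types_to_skip`` or its node name is listed in
    ``op_names_to_skip``; otherwise the decision is delegated to
    ``importer_v2.is_fx_node_supported`` from ``mtk_converter``.
    """
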
    def __init__(
        self,
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        if op_types_to_skip is None:
            op_types_to_skip = set()
        if op_names_to_skip is None:
            op_names_to_skip = set()

        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
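        """Returns True if ``node`` is a delegable ``call_function`` node that is not skipped."""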
        # Handle 'call_function' nodes only, because 'placeholder' and 'output' nodes cannot be tagged.
        # Ref: https://github.com/pytorch/executorch/pull/1398
        if node.op != "call_function":
            return False

        op_type = node.target.__name__
        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
            print(
                f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        return importer_v2.is_fx_node_supported(node)


@final
class NeuropilotPartitioner(Partitioner):
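    """Partitioner that tags regions of an ``ExportedProgram`` for the Neuropilot backend.

    Uses ``CapabilityBasedPartitioner`` with ``NeuropilotOperatorsSupport`` to
    propose partitions, tags every node in each partition, and associates each
    tag with a ``DelegationSpec`` built from the given ``compile_spec``.
    """
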
    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
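        """Returns ATen ops to keep undecomposed; the ``None`` second element means no extra node filter."""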
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
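        """Proposes partitions of supported nodes and tags them for Neuropilot delegation."""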
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()

        partition_tags = {}
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
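
# Illustrative usage (a sketch, not part of this module): lowering an exported
# program with this partitioner through ExecuTorch's export flow. The toy
# module and the empty compile spec list below are placeholders; real compile
# specs come from the Neuropilot tooling, and exact export APIs may differ
# across ExecuTorch versions.
#
#   import torch
#   from executorch.exir import to_edge_transform_and_lower
#
#   class Add(torch.nn.Module):
#       def forward(self, x, y):
#           return x + y
#
#   ep = torch.export.export(Add(), (torch.randn(4), torch.randn(4)))
#   edge = to_edge_transform_and_lower(
#       ep, partitioner=[NeuropilotPartitioner(compile_spec=[])]
#   )
#   et_program = edge.to_executorch()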