xref: /aosp_15_r20/external/executorch/backends/apple/coreml/partition/coreml_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright © 2023 Apple Inc. All rights reserved.
2#
3# Please refer to the license found in the LICENSE file in the root directory of the source tree.
4
5import logging
6from typing import List, Optional
7
8import coremltools as ct
9
10import torch
11
12from executorch.backends.apple.coreml.compiler import CoreMLBackend
13from executorch.exir.backend.compile_spec_schema import CompileSpec
14
15from executorch.exir.backend.partitioner import (
16    DelegationSpec,
17    Partitioner,
18    PartitionResult,
19)
20from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
21from torch.export.exported_program import ExportedProgram
22from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
23from torch.fx.passes.operator_support import OperatorSupportBase
24
25logger = logging.getLogger(__name__)
26logger.setLevel(logging.WARNING)
27
28
class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
    """Answers, per FX node, whether the Core ML backend can delegate it."""

    def __init__(
        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
    ) -> None:
        super().__init__()
        # Never share a mutable default across instances: substitute a
        # fresh empty list when the caller passes nothing.
        self.skip_ops_for_coreml_delegation = (
            [] if skip_ops_for_coreml_delegation is None
            else skip_ops_for_coreml_delegation
        )

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # get_attr node can always be supported on any backend
        if node.op == "get_attr":
            return True
        # Refuse everything that is not a call_function:
        # 1. placeholder / output nodes should not be tagged
        #    reference: https://github.com/pytorch/executorch/pull/1398
        # 2. call_module / call_method should have been replaced with call_function?
        if node.op != "call_function":
            return False
        # User-specified skips take precedence over capability queries.
        target_name = getattr(node.target, "__name__", "").lower()
        if target_name in (self.skip_ops_for_coreml_delegation or []):
            return False
        # Otherwise defer to coremltools' own per-op support query.
        return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
56
57
class CoreMLPartitioner(Partitioner):
    """Tags the largest Core ML-delegable subgraphs of an ExportedProgram."""

    def __init__(
        self,
        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
        compile_specs: Optional[List[CompileSpec]] = None,
        take_over_mutable_buffer: Optional[bool] = True,
    ) -> None:
        """Configure the partitioner.

        Args:
            skip_ops_for_coreml_delegation: op names (lowercased
                ``target.__name__``) that must NOT be delegated to Core ML.
            compile_specs: backend compile specs forwarded to CoreMLBackend.
            take_over_mutable_buffer: when True, mutable buffers are tagged
                for delegation as Core ML state (requires macOS 15+/iOS 18+
                at execution time).
        """
        if skip_ops_for_coreml_delegation is None:
            skip_ops_for_coreml_delegation = []
        self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
        self.delegation_spec = DelegationSpec(
            backend_id=CoreMLBackend.__name__,
            compile_specs=compile_specs if compile_specs is not None else [],
        )
        self.take_over_mutable_buffer = take_over_mutable_buffer

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag delegable nodes in-place and return the partition result.

        Returns:
            PartitionResult holding the (mutated) exported program and the
            tag -> DelegationSpec mapping for every proposed partition.
        """
        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs containing the nodes with the tags
        logger.info("CoreMLPartitioner::partition")
        partition_tags = {}

        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
            allows_single_node_partition=True,
        )
        for partition in capability_partitioner.propose_partitions():
            # The tag depends only on the partition, not the node: compute it
            # and register the delegation spec once per partition instead of
            # once per node (the original re-did both in the inner loop).
            tag = f"tag{partition.id}"
            partition_tags[tag] = self.delegation_spec
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag

        tag_constant_data(exported_program)
        if self.take_over_mutable_buffer:
            logger.info(
                "Core ML partitioner will take over torch mutable buffer as Core ML state, "
                "so if your model contains mutable buffer, "
                "then you will need MacOS15+/iOS18+ to execute. "
                "If you want your mutable buffer model to be compatible with older OS, "
                "then please set `take_over_mutable_buffer=False`"
            )
            tag_mutated_buffer(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
107