# Copyright © 2023 Apple Inc. All rights reserved.
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import logging
from typing import Dict, List, Optional

import coremltools as ct

import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

logger = logging.getLogger(__name__)
# NOTE: pinning this logger to WARNING means the ``logger.info`` calls in this
# module are filtered out unless a caller lowers the level afterwards.
logger.setLevel(logging.WARNING)


class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
    """Decide which FX nodes may be delegated to the Core ML backend.

    A node is supported when it is a ``get_attr`` node, or a ``call_function``
    node whose target is not in the user's skip list and which coremltools
    reports as convertible. All other node kinds are rejected.
    """

    def __init__(
        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
    ) -> None:
        """
        Args:
            skip_ops_for_coreml_delegation: names of ops (compared against the
                node target's ``__name__``, case-insensitively) that must never
                be delegated to Core ML. ``None`` means skip nothing.
        """
        if skip_ops_for_coreml_delegation is None:
            skip_ops_for_coreml_delegation = []
        super().__init__()
        self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Return True if ``node`` can be placed in a Core ML partition.

        Args:
            submodules: required by ``OperatorSupportBase``; not used here.
            node: the FX node to classify.
        """
        # get_attr nodes can always be supported on any backend
        if node.op == "get_attr":
            return True
        # check whether the PyTorch op being called is supported in Core ML
        if node.op == "call_function":
            # Skip ops if specified by the user. The node's target name is
            # lower-cased, so the user's entries must be lower-cased as well;
            # otherwise an entry such as "add.Tensor" could never match.
            node_target_name = getattr(node.target, "__name__", "").lower()
            skip_names = {
                str(name).lower()
                for name in (self.skip_ops_for_coreml_delegation or [])
            }
            if node_target_name in skip_names:
                return False
            # query coremltools to see if the node is supported
            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
        # Cowardly refuse to support all other types of node:
        # 1. placeholder / output nodes should not be tagged
        #    reference: https://github.com/pytorch/executorch/pull/1398
        # 2. call_module / call_method should have been replaced with call_function?
        return False


class CoreMLPartitioner(Partitioner):
    """Partitioner that tags Core ML-supported subgraphs for ``CoreMLBackend``."""

    def __init__(
        self,
        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
        compile_specs: Optional[List[CompileSpec]] = None,
        take_over_mutable_buffer: Optional[bool] = True,
    ) -> None:
        """
        Args:
            skip_ops_for_coreml_delegation: op names to keep out of Core ML
                delegation; forwarded to ``OperatorsSupportedForCoreMLBackend``.
            compile_specs: compile specs attached to the ``DelegationSpec`` for
                ``CoreMLBackend``; ``None`` means an empty list.
            take_over_mutable_buffer: when truthy, ``partition`` also tags
                mutated buffers via ``tag_mutated_buffer``. Per the log message
                in ``partition``, models with mutable buffers then need
                MacOS15+/iOS18+ to execute.
        """
        if skip_ops_for_coreml_delegation is None:
            skip_ops_for_coreml_delegation = []
        self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
        self.delegation_spec = DelegationSpec(
            backend_id=CoreMLBackend.__name__,
            compile_specs=compile_specs if compile_specs is not None else [],
        )
        self.take_over_mutable_buffer = take_over_mutable_buffer

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the delegable nodes of ``exported_program`` in place.

        Every node in a proposed partition gets ``meta["delegation_tag"]`` set
        to ``"tag<partition id>"``, and each such tag maps to this
        partitioner's ``DelegationSpec``. Constant data (and, optionally,
        mutated buffers) are then tagged as well.

        Returns:
            A ``PartitionResult`` wrapping the same (mutated) exported program
            and the tag -> ``DelegationSpec`` mapping.
        """
        logger.info("CoreMLPartitioner::partition")
        partition_tags: Dict[str, DelegationSpec] = {}

        # Run the CapabilityBasedPartitioner to return the largest possible
        # subgraphs made of supported nodes
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()
        for partition in partition_list:
            # Only partitions that actually contain nodes get a registered tag.
            if not partition.nodes:
                continue
            # The tag depends only on the partition, so build it once.
            tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = tag
            partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)
        if self.take_over_mutable_buffer:
            logger.info(
                "Core ML partitioner will take over torch mutable buffer as Core ML state, "
                "so if your model contains mutable buffer, "
                "then you will need MacOS15+/iOS18+ to execute. "
                "If you want your mutable buffer model to be compatible with older OS, "
                "then please set `take_over_mutable_buffer=False`"
            )
            tag_mutated_buffer(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )