1# 2# Copyright (c) 2023 Apple Inc. All rights reserved. 3# Provided subject to the LICENSE file in the top level directory. 4# 5 6import logging 7from typing import Any, cast, Dict, List, Union 8 9import torch 10from executorch.backends.apple.mps import MPSBackend 11from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors 12from executorch.backends.transforms import get_shape 13from executorch.exir.backend.backend_details import CompileSpec 14from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( 15 generate_partitions_from_list_of_nodes, 16) 17from executorch.exir.backend.partitioner import ( 18 DelegationSpec, 19 Partitioner, 20 PartitionResult, 21) 22from executorch.exir.backend.utils import tag_constant_data 23from executorch.exir.dialects._ops import ops as exir_ops 24from torch.export.exported_program import ExportedProgram 25from torch.fx.passes.infra.partitioner import Partition 26from torch.fx.passes.operator_support import OperatorSupportBase 27 28FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" 29logging.basicConfig(level=logging.INFO, format=FORMAT) 30 31# ops implemented as Metal kernels. 
METAL_KERNELS = [
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
]

# Edge ops whose MPS lowering depends on how many index tensors they carry;
# see MPSPartitioner.mps_graph_advanced_indexing_support.
_ADVANCED_INDEXING_OPS = (
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
)

logger = logging.getLogger(__name__)


class MPSOperatorSupport(OperatorSupportBase):
    """Decide which FX nodes the MPS backend can lower.

    A node is supported when it is a ``call_function`` node whose target name
    has a registered MPS node visitor for ``edge_program``.
    """

    def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
        # ``compiler_specs`` is accepted for interface compatibility; the
        # support check below does not use it.
        self.node_visitors = get_node_visitors(edge_program)
        self.edge_program = edge_program

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """Return True if ``node`` can be handled by an MPS node visitor."""
        if node.op != "call_function":
            return False

        # A callable target is not guaranteed to expose ``__name__``; treat
        # such targets as unsupported instead of raising AttributeError.
        target_name = getattr(node.target, "__name__", None)
        if target_name is None or target_name not in self.node_visitors:
            logger.debug(
                "[UNSUPPORTED] Node %s not supported", target_name or node.target
            )
            return False

        return True


class MPSPartitioner(Partitioner):
    """Partition an edge program into subgraphs delegated to ``MPSBackend``.

    Supported nodes are grouped by the capability-based partitioner; each
    group is then re-tagged so that indexing nodes which need a Metal kernel
    end up in a delegate of their own.
    """

    def __init__(self, compile_specs: List[CompileSpec]) -> None:
        self.compile_specs = compile_specs
        self.delegation_spec = DelegationSpec(MPSBackend.__name__, compile_specs)
        # delegation tag -> spec; rebuilt on every call to ``partition``.
        self.partition_tags: Dict[str, DelegationSpec] = {}

    def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
        """Group the MPS-supported nodes of ``edge_program`` into partitions."""
        self.supported_ops = MPSOperatorSupport(
            edge_program=edge_program, compiler_specs=self.delegation_spec.compile_specs
        )
        return generate_partitions_from_list_of_nodes(
            edge_program.graph_module,
            op_support=self.supported_ops,
        )

    def mps_graph_advanced_indexing_support(self, node: torch.fx.Node) -> bool:
        """Return True if an index / index_put node can be lowered to MPSGraph.

        ``node.args[0]`` is the indexed tensor node and ``node.args[1]`` the
        list of index entries; ``None`` entries are not counted.
        """
        input_node = cast(torch.fx.Node, node.args[0])
        indices = cast(List[torch.fx.Node], node.args[1])
        num_indices = sum(1 for idx in indices if idx is not None)
        # Can dispatch to MPSGraph if the number of indices equals the rank of
        # the indexed tensor, or only one index is present. All other cases
        # fall back to a Metal kernel.
        return num_indices == len(get_shape(input_node)) or num_indices == 1

    def use_metal_kernel(self, node: torch.fx.Node) -> bool:
        """Return True if ``node`` must be lowered as a standalone Metal kernel."""
        if node.target not in METAL_KERNELS:
            return False
        if node.target in _ADVANCED_INDEXING_OPS:
            return not self.mps_graph_advanced_indexing_support(node)
        return False

    def tag_nodes(self, partitions: List[Partition]) -> None:
        """Write a ``delegation_tag`` into the meta of every partitioned node.

        Nodes of partition ``P`` are tagged ``mps_P_<k>``. Each node that needs
        a Metal kernel gets its own ``mps_P_metal_kernel_<k>`` tag, and the
        counter is advanced so nodes before and after it land in different
        ``mps_P_<k>`` groups. Every tag is registered in ``self.partition_tags``.
        """
        # Node -> position in its graph, computed once per graph.
        positions: Dict[torch.fx.Graph, Dict[torch.fx.Node, int]] = {}
        for partition in partitions:
            nodes = list(partition.nodes)
            if not nodes:
                continue
            graph = nodes[0].graph
            if graph not in positions:
                positions[graph] = {n: i for i, n in enumerate(graph.nodes)}
            # Walk nodes in graph (topological) order. Partition node
            # containers are not guaranteed to be ordered, and splitting an
            # unordered walk around a Metal kernel can produce groups that
            # depend on each other through the kernel (a cycle once fused).
            nodes.sort(key=positions[graph].__getitem__)

            base_tag = f"mps_{partition.id}"
            crt_partition_counter = 0
            for node in nodes:
                if self.use_metal_kernel(node):
                    logger.warning("[WARNING] Using Metal kernel for op %s!", node.name)
                    # Isolate the Metal kernel in its own partition and start a
                    # fresh group for the nodes that follow it.
                    crt_partition_counter += 1
                    delegation_tag = f"{base_tag}_metal_kernel_{crt_partition_counter}"
                    crt_partition_counter += 1
                else:
                    delegation_tag = f"{base_tag}_{crt_partition_counter}"

                node.meta["delegation_tag"] = delegation_tag
                self.partition_tags[delegation_tag] = self.delegation_spec

    @staticmethod
    def check_partitions(partitions: Union[dict, list]) -> bool:
        """Log how many partitions were found; return True if there is any."""
        pl = len(partitions)
        if pl == 0:
            logger.warning("Nothing can be partitioned!")
        else:
            logger.info("Found %d subgraphs to be partitioned.", pl)
        return pl != 0

    # override
    def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        """Tag ``edge_program`` for MPS delegation and return the result.

        Tags from a previous call are discarded first. A reused partitioner
        therefore neither reports stale tags nor mutates the ``partition_tags``
        of an earlier ``PartitionResult``.
        """
        self.partition_tags = {}
        partitions = self.generate_partitions(edge_program=edge_program)
        if self.check_partitions(partitions):
            self.tag_nodes(partitions)
            # Tag constant data that are used by the supported ops in MPS backend.
            tag_constant_data(edge_program)
        return PartitionResult(
            tagged_exported_program=edge_program, partition_tags=self.partition_tags
        )