#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import logging
from typing import Any, cast, Dict, List, Union

import torch
from executorch.backends.apple.mps import MPSBackend
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
from executorch.backends.transforms import get_shape
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

# Ops that may fall back to custom Metal kernels when they cannot be
# expressed directly in MPSGraph (see `use_metal_kernel` below).
METAL_KERNELS = [
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
]


class MPSOperatorSupport(OperatorSupportBase):
    """Reports whether a node has a registered MPS node visitor and can
    therefore be delegated to the MPS backend."""

    def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
        self.node_visitors = get_node_visitors(edge_program)
        self.edge_program = edge_program

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Only `call_function` nodes are candidates for delegation.
        if node.op != "call_function":
            return False

        # A node is supported only if a visitor is registered for its target op.
        if node.target.__name__ not in self.node_visitors:
            logging.debug(f"[UNSUPPORTED] Node {node.target.__name__} not supported")
            return False

        return True


class MPSPartitioner(Partitioner):
    """Partitions an exported edge program into subgraphs that can be delegated
    to the MPS backend, tagging each supported node with a delegation tag."""

    def __init__(self, compile_specs: List[CompileSpec]) -> None:
        self.compile_specs = compile_specs
        self.delegation_spec = DelegationSpec(MPSBackend.__name__, compile_specs)
        self.partition_tags: Dict[str, DelegationSpec] = {}

    def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
        self.supported_ops = MPSOperatorSupport(
            edge_program=edge_program, compiler_specs=self.delegation_spec.compile_specs
        )
        return generate_partitions_from_list_of_nodes(
            edge_program.graph_module,
            op_support=self.supported_ops,
        )

    def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
        num_indices = 0
        tensors = cast(List[torch.fx.Node], node.args[1])
        input = cast(torch.fx.Node, node.args[0])
        for t in tensors:
            if t is not None:
                num_indices += 1
        # Can dispatch to MPSGraph if the number of index tensors equals the
        # number of dimensions of the indexed tensor, or if only one index
        # tensor is present. All other cases fall back to a Metal kernel.
        if num_indices == len(get_shape(input)) or num_indices == 1:
            return True

        return False

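    # Illustration (not part of the original file): how
    # `mps_graph_advanced_indexing_support` above maps concrete advanced-indexing
    # patterns, assuming an input `x` of shape (2, 3, 4):
    #   x[i0, i1, i2] -> 3 index tensors == 3 dims -> handled by MPSGraph
    #   x[i0]         -> 1 index tensor            -> handled by MPSGraph
    #   x[:, i1]      -> indices [None, i1], so 1  -> handled by MPSGraph
    #   x[i0, i1]     -> 2 index tensors, 3 dims   -> falls back to a Metal kernel
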
    def use_metal_kernel(self, node: torch.fx.Node):
        if node.target in METAL_KERNELS:
            if (
                node.target == exir_ops.edge.aten.index.Tensor
                or node.target == exir_ops.edge.aten.index_put.default
            ):
                if not self.mps_graph_advanced_indexing_support(node):
                    return True
        return False

    def tag_nodes(self, partitions: List[Partition]) -> None:
        for partition in partitions:
            crt_partition_counter = 0
            for node in partition.nodes:
                delegation_tag = f"mps_{partition.id}"
                if self.use_metal_kernel(node):
                    logging.warning(f"Using Metal kernel for op {node.name}!")
                    # Place the Metal kernel op in its own partition: bump the
                    # counter before building the tag, and again afterwards so
                    # that subsequent nodes do not share this tag.
                    crt_partition_counter += 1
                    delegation_tag = (
                        f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
                    )
                    crt_partition_counter += 1
                else:
                    delegation_tag = f"{delegation_tag}_{crt_partition_counter}"

                node.meta["delegation_tag"] = delegation_tag
                self.partition_tags[delegation_tag] = self.delegation_spec

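    # Illustration (not from the original source): for a partition with id 0
    # whose nodes are visited by `tag_nodes` in the order [add, index.Tensor
    # needing the Metal-kernel fallback, mul], the assigned tags would be
    # "mps_0_0", "mps_0_metal_kernel_1", and "mps_0_2"; the second counter
    # bump keeps nodes after the fallback out of the Metal kernel's tag.
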
    @staticmethod
    def check_partitions(partitions: Union[dict, list]) -> bool:
        pl = len(partitions)
        if pl == 0:
            logging.warning("Nothing can be partitioned!")
        else:
            logging.info(f"Found {pl} subgraphs to be partitioned.")
        return pl != 0

    # override
    def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        partitions = self.generate_partitions(edge_program=edge_program)
        if self.check_partitions(partitions):
            self.tag_nodes(partitions)
            # Tag constant data that is used by the ops supported by the MPS backend.
            tag_constant_data(edge_program)
        return PartitionResult(
            tagged_exported_program=edge_program, partition_tags=self.partition_tags
        )
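
# Usage sketch (illustrative, not part of the original module). The compile
# specs accepted by the MPS backend are an assumption here; "use_fp16" follows
# the ExecuTorch MPS example scripts and may differ in your setup.
#
#   import torch
#   from executorch.exir import to_edge
#   from executorch.exir.backend.backend_details import CompileSpec
#   from executorch.backends.apple.mps.partition.mps_partitioner import MPSPartitioner
#
#   exported = torch.export.export(model, example_inputs)
#   edge = to_edge(exported)
#   edge = edge.to_backend(MPSPartitioner([CompileSpec("use_fp16", bytes([True]))]))
#   executorch_program = edge.to_executorch()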