xref: /aosp_15_r20/external/executorch/exir/backend/partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from abc import ABC, abstractmethod
8from dataclasses import dataclass
9from types import MappingProxyType
10from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
11
12import torch
13
14from executorch.exir.backend.backend_details import enforcedmethod
15from executorch.exir.backend.compile_spec_schema import CompileSpec
16from torch.export import ExportedProgram
17
18
class DelegationSpec(NamedTuple):
    """
    Describes how one tagged partition is to be delegated.

    Attributes:
        backend_id: identifier of the backend that the partition is delegated to.
        compile_specs: the compile specs to use for the nodes carrying the
            partition's tag (see the `to_backend` API for their format).
    """

    # Field order is part of the tuple interface: (backend_id, compile_specs).
    backend_id: str
    compile_specs: List[CompileSpec]
22
23
@dataclass
class PartitionResult:
    """
    Result of partitioning an exported program for backend delegation.

    Attributes:
        tagged_exported_program: the exported program in which every node that
            is intended to be delegated carries a delegation tag in its
            node.meta.
        partition_tags: a dictionary that maps each tag to its corresponding
            DelegationSpec. The tag is defined by users and stored in node.meta.
    """

    tagged_exported_program: ExportedProgram
    partition_tags: Dict[str, DelegationSpec]
34
35
class Partitioner(ABC):
    """
    Defines a callable interface for partitioning an exported program for
    backend delegation.

    A partitioner implementation receives an exported program and determines
    which portions of it can be delegated to a certain backend. A single
    partitioner can also target multiple backends. It returns a
    PartitionResult containing:
    - the same input program with specific nodes in its graph tagged for
      delegation;
    - the "partition_tags" dictionary, which maps each tag to its
      DelegationSpec.

    The nodes that are intended to be delegated must be tagged (by setting
    node.meta["delegation_tag"]), and every tag must be provided as a key in
    the `partition_tags` dictionary, mapping to an instance of
    DelegationSpec(backend_id, compile_specs). Each tag must represent a
    distinct submodule that we intend on lowering, and that submodule should
    be fully contained.

    For details on compile_specs see the to_backend API, as these objects
    follow the same format.

    Args:
        spec: a mapping stored as-is and exposed through the `spec` property;
            its contents are interpreted by the concrete partitioner
            (presumably backend-specific options). Defaults to an empty,
            read-only mapping.

    Calling a partitioner instance with an ExportedProgram (in Edge dialect)
    is equivalent to calling its `partition` method.
    """

    def __init__(
        self,
        spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
    ):
        # The default is a read-only MappingProxyType rather than a plain dict,
        # so the single default object shared by all instances cannot be
        # mutated through any one of them.
        self._spec = spec

    def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
        """Partition `exported_program`; shorthand for `self.partition(...)`."""
        return self.partition(exported_program)

    @property
    def spec(self) -> Mapping[Union[str, int, float, bool], object]:
        """The mapping that was passed to the constructor."""
        return self._spec

    @enforcedmethod
    @abstractmethod
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Returns a PartitionResult wrapping the input exported program, with the
        portions of it that are intended for delegation "tagged".

        The specific implementation is free to decide how existing computation
        in the input exported program should be delegated to one, or even more
        than one, specific backend.

        The contract is stringent in that:
        * Each node that is intended to be delegated must be tagged.
        * No change to the representation of the original input exported
          program (ExportedProgram) can take place, other than adding
          sub-Modules that encapsulate existing portions of the input exported
          program and the associated metadata for tagging.

        Args:
            exported_program: An ExportedProgram in Edge dialect to be
                partitioned for backend delegation.

        Returns:
            PartitionResult: the tagged exported program, together with
            `partition_tags`. `partition_tags` maps each tag created by the
            backend developers to the DelegationSpec (backend_id and
            compile_specs) used for the nodes carrying that tag.
        """
        pass

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Returns the operators that should not be decomposed, plus an optional
        node filter.

        When these ops are registered and `to_backend` is invoked through
        to_edge_transform_and_lower, it is guaranteed that the program the
        backend receives will not have any of these ops decomposed.

        Args:
            ep: the exported program the ops are queried for. It is not used by
                this default implementation.

        Returns:
            A tuple of:
            - List[torch._ops.OpOverload]: the operator overloads that should
              not be decomposed.
            - Optional[Callable[[torch.fx.Node], bool]]: an optional callable
              that acts as a filter. It is called for each node in the graph,
              and users can use it to select certain nodes that should still be
              decomposed even though their op is in the list above.

            The default implementation returns ([], None), i.e. no op is
            preserved and no filter is applied.
        """
        return ([], None)
115