xref: /aosp_15_r20/external/executorch/exir/backend/partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom abc import ABC, abstractmethod
8*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
9*523fa7a6SAndroid Build Coastguard Workerfrom types import MappingProxyType
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import enforcedmethod
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.compile_spec_schema import CompileSpec
16*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
class DelegationSpec(NamedTuple):
    """
    Immutable (backend_id, compile_specs) pair.

    Used as the value type of `PartitionResult.partition_tags`, where each
    delegation tag maps to one DelegationSpec.
    """

    # Identifier string of the backend.
    backend_id: str
    # CompileSpec entries that accompany `backend_id`.
    compile_specs: List[CompileSpec]
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class PartitionResult:
    """
    Value returned by a partitioner.

    tagged_exported_program: the exported program whose graph nodes that are
        intended to be delegated carry the delegation tag metadata.
    partition_tags: dictionary tracking each tag and its corresponding
        DelegationSpec. Tags are defined by users and are the values stored
        in node.meta.
    """

    # Program with the to-be-delegated nodes tagged.
    tagged_exported_program: ExportedProgram
    # tag -> DelegationSpec for that tag.
    partition_tags: Dict[str, DelegationSpec]
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker
class Partitioner(ABC):
    """
    Callable interface that partitions an exported program for backend delegation.

    An implementation receives an exported program, decides which portions of it
    can be delegated to a backend (a single partitioner may target several
    backends), and returns a PartitionResult holding:
    - the same input module, with the chosen nodes of the input graph tagged
      for delegation
    - the "partition_tags" mapping each tag to its DelegationSpec.

    Every node intended for delegation must be tagged (by setting
    node.meta["delegation_tag"]), and that tag must appear as a key of
    `partition_tags`, mapped to an instance of
    DelegationSpec(backend_id, method_compilation_spec). Each tag must represent
    a distinct, fully contained submodule that is meant to be lowered.

    method_compilation_spec follows the same format as in the to_backend API;
    see that API for details.

    Calling an instance with an exported program is the same as calling
    `partition` on it.

    Args:
        spec: Mapping kept on the instance and exposed through the `spec`
            property. Defaults to an empty read-only mapping.
    """

    def __init__(
        self,
        spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
    ):
        # The default is a read-only proxy, so sharing it between instances is safe.
        self._spec = spec

    def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
        """Forward to `partition` so the instance can be used as a callable."""
        partition_result = self.partition(exported_program)
        return partition_result

    @property
    def spec(self) -> Mapping[Union[str, int, float, bool], object]:
        """The mapping that was passed to the constructor."""
        return self._spec

    @enforcedmethod
    @abstractmethod
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag the portions of the input exported program that should be delegated.

        Implementations decide freely how the existing computation in the input
        exported program is assigned to one backend, or to several.

        The contract is strict:
        * Every node that is intended to be delegated must be tagged.
        * The representation of the original input ExportedProgram must not
          change, except for adding sub-Modules that encapsulate existing
          portions of the program and the metadata used for tagging.

        Args:
            exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.

        Returns:
            PartitionResult: the tagged graph together with the delegation specs,
            indicating which backend_id and compile_spec apply to each node via
            the tag created by the backend developers.
        """

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Report which operators must be kept un-decomposed for this partitioner.

        When these ops are registered and `to_backend` is invoked through
        to_edge_transform_and_lower, the program handed to the backend is
        guaranteed not to have any of them decomposed.

        This base implementation requests nothing: it returns an empty list and
        no filter. Subclasses may override it.

        Args:
            ep: The exported program under consideration (unused by this
                default implementation).

        Returns:
            List[torch._ops.OpOverload]: the operators that should not be decomposed.
            Optional[Callable[[torch.fx.Node], bool]]: an optional filter invoked
            for each node in the graph; it lets users mark certain nodes that
            should still be decomposed even though their op appears in the
            returned list.
        """
        preserved_ops: List[torch._ops.OpOverload] = []
        node_filter: Optional[Callable[[torch.fx.Node], bool]] = None
        return (preserved_ops, node_filter)
115