# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import MappingProxyType
from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class DelegationSpec(NamedTuple):
    backend_id: str
    compile_specs: List[CompileSpec]


@dataclass
class PartitionResult:
    """
    tagged_exported_program: The exported program in which the nodes intended for delegation
        are tagged via node.meta["delegation_tag"].
    partition_tags: A dictionary that tracks each tag and its corresponding DelegationSpec.
        Tags are defined by users and stored in node.meta.
    """

    tagged_exported_program: ExportedProgram
    partition_tags: Dict[str, DelegationSpec]


class Partitioner(ABC):
    """
    Defines a callable interface for partitioning an exported program for
    backend delegation.
    A partitioner implementation receives an exported program, determines which portions of
    it can be delegated to a given backend (a partitioner can also target multiple
    backends), and returns a PartitionResult containing:
    - the same input program, with specific nodes in its graph tagged for delegation
    - the "partition_tags" dictionary, which maps each tag to its DelegationSpec.

    The nodes intended for delegation must be tagged (by setting
    node.meta["delegation_tag"]), and each tag must appear in the
    `partition_tags` dictionary, mapped to an instance of
    DelegationSpec(backend_id, compile_specs). Each tag must represent
    a distinct submodule that we intend to lower, and that submodule should be fully contained.

    For details on compile_specs, see the to_backend API, as these objects follow
    the same format.

    Args:
        exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.
    """

    def __init__(
        self,
        spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
    ):
        self._spec = spec

    def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
        return self.partition(exported_program)

    @property
    def spec(self) -> Mapping[Union[str, int, float, bool], object]:
        return self._spec

    @enforcedmethod
    @abstractmethod
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Returns a PartitionResult holding the input exported program, with specific
        portions of it "tagged" for delegation.

        The specific implementation is free to decide how existing computation in the
        input exported program should be delegated to one or more specific
        backends.

        The contract is stringent in that:
        * Each node that is intended to be delegated must be tagged.
        * No change to the original input exported program (ExportedProgram) representation
          may take place other than adding sub-Modules that encapsulate existing portions of
          the input exported program and the associated metadata for tagging.

        Args:
            exported_program: An ExportedProgram in Edge dialect to be partitioned for backend delegation.

        Returns:
            PartitionResult: includes the tagged graph and the partition_tags dictionary, which indicates the backend_id and compile_specs used for each tag created by the backend developers.
        """
        pass

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Returns a list of operators that should not be decomposed. When these ops are
        registered and `to_backend` is invoked through to_edge_transform_and_lower, it is
        guaranteed that the program the backend receives will not have any of these ops
        decomposed.

        Returns:
            List[torch._ops.OpOverload]: a list of operators (OpOverload objects) that should not be decomposed.
            Optional[Callable[[torch.fx.Node], bool]]: an optional filter callable that is invoked
                for each node in the graph. Users can use it to select nodes that should still be
                decomposed even though their op is in the list returned by ops_to_not_decompose.
        """
        return ([], None)