# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import MappingProxyType
from typing import Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, Union

import torch

from executorch.exir.backend.backend_details import enforcedmethod
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import ExportedProgram


class DelegationSpec(NamedTuple):
    """Pairs a backend identifier with the compile specs to use for it."""

    backend_id: str
    compile_specs: List[CompileSpec]


@dataclass
class PartitionResult:
    """
    Outcome of running a partitioner over an exported program.

    tagged_exported_program: the graph in which every node that is intended to be
        delegated carries "DelegationSpec" metadata.
    partition_tags: dictionary used to keep track of the tags and their
        corresponding DelegationSpec. The tag is defined by users and is the value
        stored in node.meta.
    """

    tagged_exported_program: ExportedProgram
    partition_tags: Dict[str, DelegationSpec]


class Partitioner(ABC):
    """
    Callable interface for partitioning an exported program for backend delegation.

    An implementation receives an exported program, determines which portions of it
    can be delegated to a certain backend (a single partitioner can target multiple
    backends as well), and returns a PartitionResult including:
        - the same input module, with specific nodes in the input graph tagged for
          delegation
        - the "partition_tags" that indicate how each tag maps to a DelegationSpec.

    Nodes that are intended to be delegated must be tagged by setting
    node.meta["delegation_tag"], and that tag must be present in the
    `partition_tags` dictionary, mapped to an instance of
    DelegationSpec(backend_id, method_compilation_spec). Each tag must represent a
    distinct submodule that we intend on lowering, and that submodule should be
    fully contained.

    The method_compilation_spec objects follow the same format as in the to_backend
    API; see that API for details.

    Calling an instance with an ExportedProgram (in Edge dialect) forwards to
    `partition` and returns its result.

    Args:
        spec: Mapping stored as-is on the instance and returned by the `spec`
            property. Defaults to an empty, immutable mapping.
    """

    def __init__(
        self,
        spec: Mapping[Union[str, int, float, bool], object] = MappingProxyType({}),
    ):
        # The default is an immutable proxy, so sharing it between instances is safe.
        self._spec = spec

    def __call__(self, exported_program: ExportedProgram) -> PartitionResult:
        """Partition `exported_program`; shorthand for `self.partition(...)`."""
        partition_result = self.partition(exported_program)
        return partition_result

    @property
    def spec(self) -> Mapping[Union[str, int, float, bool], object]:
        """The mapping that was handed to the constructor."""
        return self._spec

    @enforcedmethod
    @abstractmethod
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Return the input exported program with newly created sub-Modules that
        encapsulate the portions of the input "tagged" for delegation.

        An implementation is free to decide how the existing computation in the
        input exported program is delegated to one, or more than one, backend.

        The contract is stringent in that:
            * Each node that is intended to be delegated must be tagged.
            * The original input exported program (ExportedProgram) representation
              must not change, other than by adding sub-Modules that encapsulate
              existing portions of the input exported program and the associated
              metadata for tagging.

        Args:
            exported_program: An ExportedProgram in Edge dialect to be partitioned
                for backend delegation.

        Returns:
            PartitionResult: includes the tagged graph and the delegation spec that
            indicates which backend_id and compile_spec is used for each node, along
            with the tag created by the backend developers.
        """
        ...

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return the operators that should not be decomposed.

        When these ops are registered and `to_backend` is invoked through
        to_edge_transform_and_lower, the program that the backend receives is
        guaranteed not to have any of these ops decomposed.

        The base implementation ignores `ep` and returns an empty list with no
        filter.

        Args:
            ep: The ExportedProgram under consideration.

        Returns:
            List[torch._ops.OpOverload]: the operators that should not be decomposed.
            Optional[Callable[[torch.fx.Node], bool]]: an optional callable, acting
                as a filter, that is called for each node in the graph. Users can
                use it to pick out certain nodes that should continue to be
                decomposed even though the op they correspond to is in the list
                returned by ops_to_not_decompose.
        """
        preserved_ops: List[torch._ops.OpOverload] = []
        node_filter: Optional[Callable[[torch.fx.Node], bool]] = None
        return preserved_ops, node_filter