# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from torch.fx.passes.infra.partitioner import Partition


def format_target_name(target_name: str) -> str:
    """
    Remove the dialect name space from the target name.

    We generally do not care for the op dialect specific name space
    ("aten.", "quantized_decomposed.") but rather the op itself, so we remove
    the dialect-specific name space from the name and return the op name itself.

    Only the leading component is dropped, and only when the name has more than
    two dot-separated components, e.g. "aten.linear.default" becomes
    "linear.default" while "linear.default" is returned unchanged.

    Args:
        target_name: Dot-separated name of an op target
    Returns:
        The target name without its leading name space component
    """
    names = target_name.split(".")
    if len(names) > 2:
        names = names[1:]

    return ".".join(names)


class PartitionerConfig(ABC):
    """
    Class used to represent a PartitionerConfig.

    PartitionerConfig is used by the config-based partitioner to identify
    nodes to be delegated. User overrides the methods:
        - target_name
        - check_constraints
        - get_partition
        - get_original_aten

    The Config-Based Partitioner then uses these overridden methods to find nodes
    which match target_name, check_constraints, and if true, returns the partition
    (list of nodes) which represent the node and its dependencies. get_original_aten
    is used to halt decomposition to edge_dialect if the node can be delegated by
    the specified backend.
    """

    # NOTE(review): wrapping a property in classmethod was deprecated in
    # Python 3.11 and removed in 3.13. The decorator stack is kept as-is so the
    # interface seen by existing subclasses does not change -- confirm against
    # the Python versions this package supports.
    @classmethod
    @property
    @abstractmethod
    def target_name(cls) -> str:
        """
        Target name for this partitioner config. When the Config-Based Partitioner
        encounters a node with a matching target name, it uses this config's methods
        to check the constraints of this node and get all of its dependencies.
        The target name is formatted to remove the dialect-specific name space,
        i.e. linear.default
        """
        pass

    @abstractmethod
    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Takes in a node and returns true if the node is partitionable.

        Args:
            node: Node to be partitioned
            ep: Exported program of the graph module
        Returns:
            True or False whether this node is partitionable
        """
        pass

    @abstractmethod
    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        """
        Returns the original aten dialect op, this is for to_edge_transform_and_lower
        API, so that this config can be used to stop decomposition of this original
        aten op. Returning None means this config does not stop any decomposition.
        """
        pass

    @abstractmethod
    def get_partition(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """
        Returns the partitioned nodes from get_node_and_deps, but also labels them
        with the name of the PartitionerConfig class which returned this set of nodes.

        Returns an empty list if the node and deps do not satisfy the checked
        constraints.

        Args:
            node: Node to be partitioned
            ep: Exported program of the graph module
        """
        pass


class ConfigerationBasedPartitioner(Partitioner):
    """
    Partitioner driven by a collection of PartitionerConfig objects, keyed by
    their target_name.
    """

    def __init__(
        self,
        delegation_spec: DelegationSpec,
        partitioner_configs: Iterable[PartitionerConfig],
    ):
        """
        Configuration based partitioner. We supply the partitioner with a set of
        configurations which describe the node type, constraints, and any
        dependencies required to be partitioned with the node. We use the
        configurations to partition the graph module.

        Args:
            delegation_spec: Spec recorded for every tag produced by partition()
            partitioner_configs: Configs to register, one per target_name
        Raises:
            RuntimeError: if two configs report the same target_name
        """
        super().__init__()
        # Initialize partitioner configs map {"target_name": PartitionerConfig}
        self.target_partitioner_configs: Dict[str, PartitionerConfig] = {}
        for config in partitioner_configs:
            target_name = config.target_name
            if target_name in self.target_partitioner_configs:
                other_config = self.target_partitioner_configs[target_name]
                raise RuntimeError(
                    f"PartitionerConfig: {config} and {other_config} have the same target_name: {target_name}"
                )
            self.target_partitioner_configs[target_name] = config

        self.delegation_spec = delegation_spec

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Collects the original aten ops of the registered configs together with a
        per-node filter.

        Args:
            ep: Exported program handed to each config's check_constraints
        Returns:
            A tuple of (original aten ops reported by the configs, filter function
            that returns True for nodes whose config has an original aten op and
            whose constraints are satisfied)
        """

        def filter_fn(node: torch.fx.Node) -> bool:
            """
            The partitioner configs we initialize with have check_constraints function,
            to determine if this op is indeed partitionable. We grab the check_constraint
            function of this op from the config and use it to filter.
            """
            if node.op != "call_function":
                return False
            target_name = format_target_name(node.target.__name__)  # pyre-ignore

            config = self.target_partitioner_configs.get(target_name)
            # only filter_fn if config has original_aten
            if config is None or config.get_original_aten() is None:
                return False

            return config.check_constraints(node, ep)

        # Get list of original aten targets which we do not want to decomp
        do_not_decomp = []
        for node_config in self.target_partitioner_configs.values():
            original_aten = node_config.get_original_aten()
            if original_aten is not None:
                do_not_decomp.append(original_aten)

        return (do_not_decomp, filter_fn)

    def get_matched_nodes_from_configs(
        self, ep: ExportedProgram
    ) -> List[List[torch.fx.Node]]:
        """
        Walks the graph and gathers the partition of every call_function node
        that has a registered config and passes that config's constraints.

        Args:
            ep: Exported program whose graph module is scanned
        Returns:
            One list of nodes (the result of get_partition) per matched node,
            in graph order
        """
        # gather supported nodes
        matched_nodes: List[List[torch.fx.Node]] = []
        for node in ep.graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            target = format_target_name(node.target.__name__)
            node_config = self.target_partitioner_configs.get(target)
            if node_config is not None and node_config.check_constraints(node, ep):
                matched_nodes.append(node_config.get_partition(node, ep))

        return matched_nodes

    def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
        """
        Builds partitions from the nodes matched by the registered configs.

        Args:
            ep: Exported program to partition
        Returns:
            Partitions produced by generate_partitions_from_list_of_nodes
        """
        matched_nodes = self.get_matched_nodes_from_configs(ep)
        # create partitions
        return generate_partitions_from_list_of_nodes(
            ep.graph_module,
            matched_nodes,
        )

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tags every node of every generated partition with "tag<partition id>" in
        node.meta["delegation_tag"] and maps each tag to this partitioner's
        delegation spec. The nodes of exported_program are modified in place.

        Args:
            exported_program: Exported program to partition and tag
        Returns:
            PartitionResult holding the same exported program and the tag map
        Raises:
            RuntimeError: if a node already carries a "delegation_tag"
        """
        partitions = self.generate_partitions(exported_program)

        # tag nodes
        partition_tags: Dict[str, DelegationSpec] = {}
        for partition in partitions:
            # The tag only depends on the partition, so build it once per partition
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                if "delegation_tag" in node.meta:
                    raise RuntimeError(
                        f"Partitioner Error found node {node} in partition {node.meta['delegation_tag']} and partition {delegation_tag}"
                    )
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )