# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Demo ``Partitioner`` implementations and ``OperatorSupportBase`` predicates."""

import itertools
from typing import Callable, Dict, final, List, Optional, Tuple

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class AllOperatorSupport(OperatorSupportBase):
    """Supports every ``call_function`` node, whatever its target."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function"


class AddOperatorSupport(OperatorSupportBase):
    """Supports only edge-dialect ``aten.add.Tensor`` calls."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            exir_ops.edge.aten.add.Tensor,
        ]


class MatmulOperatorSupport(OperatorSupportBase):
    """Supports only edge-dialect ``aten.mm.default`` calls."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in [
            exir_ops.edge.aten.mm.default,
        ]


@final
class AddMulPartitionerDemo(Partitioner):
    """
    Partitions all add and matmul (``aten.mm``) nodes regardless of order.

    Partitions are formed by ``generate_pattern_op_partitions``. Each one gets
    a ``tag<partition id>`` delegation tag and is lowered to
    ``BackendWithCompilerDemo`` with a ``max_value`` compile spec of 4.
    Control-flow submodules are partitioned recursively.
    """

    def __init__(self) -> None:
        self.op_support = any_chain(AddOperatorSupport(), MatmulOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [CompileSpec("max_value", bytes([4]))],
        )

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        """
        Tag supported nodes in ``graph_module`` and its control-flow submodules.

        Sets ``node.meta["delegation_tag"]`` on every partitioned node and
        returns a mapping from each tag to the ``DelegationSpec`` to use.
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        # NOTE(review): partition ids are presumably numbered per graph module,
        # so a submodule tag may coincide with a top-level tag; both map to the
        # same spec here, so the update below does not lose information.
        for _, submodule, _ in get_control_flow_submodules(graph_module):
            ret_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(ret_partition_tags)

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )


def _is_constant_input(
    exported_program: ExportedProgram, node: torch.fx.Node
) -> bool:
    """
    Return True if ``node`` is a ``get_attr`` node, or a placeholder that
    ``exported_program`` identifies as a parameter, buffer, or lifted
    tensor constant.
    """
    if node.op == "get_attr":
        return True
    return node.op == "placeholder" and (
        is_param(exported_program, node)
        or is_buffer(exported_program, node)
        or is_lifted_tensor_constant(exported_program, node)
    )


def _tag_partitions_and_constant_args(
    exported_program: ExportedProgram,
    op_support: OperatorSupportBase,
    delegation_spec: DelegationSpec,
) -> Dict[str, DelegationSpec]:
    """
    Tag nodes accepted by ``op_support``, plus the constants they consume.

    Only the top-level graph module of ``exported_program`` is partitioned.
    Every node in a partition, and every constant input (see
    ``_is_constant_input``) it receives directly as a positional argument,
    gets ``meta["delegation_tag"] = "tag<partition id>"``.

    Returns a mapping from each tag to ``delegation_spec``.
    """
    partition_tags: Dict[str, DelegationSpec] = {}
    partition_list = generate_pattern_op_partitions(
        exported_program.graph_module, op_support=op_support
    )
    for partition in partition_list:
        delegation_tag = f"tag{partition.id}"
        for node in partition.nodes:
            partition_tags[delegation_tag] = delegation_spec
            node.meta["delegation_tag"] = delegation_tag

            # Pull constants consumed by this node into the same delegate.
            # Only top-level positional args are inspected: inputs nested in
            # lists or passed via kwargs are left untagged. A constant shared
            # by several partitions keeps the tag of the last one visited.
            for arg_node in node.args:
                if isinstance(arg_node, torch.fx.Node) and _is_constant_input(
                    exported_program, arg_node
                ):
                    arg_node.meta["delegation_tag"] = delegation_tag

    return partition_tags


@final
class AddAttributePartitionerDemo(Partitioner):
    """
    Partitions all add nodes, together with the get_attr nodes and lifted
    parameter/buffer/constant placeholders they consume directly, and
    delegates them to ``BackendWithCompilerDemo``.
    """

    def __init__(self) -> None:
        self.op_support = AddOperatorSupport()

        self.delegation_spec = DelegationSpec(BackendWithCompilerDemo.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = _tag_partitions_and_constant_args(
            edge_exported_program, self.op_support, self.delegation_spec
        )
        return PartitionResult(
            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
        )


@final
class AllNodesPartitionerDemo(Partitioner):
    """
    Partitions all ``call_function`` nodes, together with the get_attr nodes
    and lifted parameter/buffer/constant placeholders they consume directly,
    and delegates them to ``ExecutorBackend``.
    """

    def __init__(self) -> None:
        self.op_support = AllOperatorSupport()
        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = _tag_partitions_and_constant_args(
            edge_exported_program, self.op_support, self.delegation_spec
        )
        return PartitionResult(
            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
        )


# ATen ops that NonDecompTestPartitioner asks to keep out of decomposition.
ops_not_to_decompose = [
    torch.ops.aten.linear.default,
    torch.ops.aten.scaled_dot_product_attention.default,
    torch.ops.aten.upsample_nearest2d.vec,
]

# Edge-dialect counterparts of ``ops_not_to_decompose``; these are the targets
# NonDecompTestPartitioner places into single-node delegates.
edge_ops_non_decomposed = [
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.scaled_dot_product_attention.default,
    exir_ops.edge.aten.upsample_nearest2d.vec,
]


class OpsToNotDecomposeOperatorSupport(OperatorSupportBase):
    """Supports ``call_function`` nodes whose target is in ``edge_ops_non_decomposed``."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in edge_ops_non_decomposed


@final
class NonDecompTestPartitioner(Partitioner):
    """
    Non Decomp Test Partitioner, preserves aten ops from decomposition for delegate
    consumption. Ensures that non_decomposed_edge_ops are all within their own delegate
    """

    def __init__(self) -> None:
        self.supported_non_decomposed_edge_ops = edge_ops_non_decomposed
        # Not read by the partitioning methods below, which build single-node
        # partitions directly from ``supported_non_decomposed_edge_ops``.
        self.op_support = any_chain(OpsToNotDecomposeOperatorSupport())
        self.delegation_spec = DelegationSpec(
            BackendWithCompilerDemo.__name__,
            [CompileSpec("max_value", bytes([4]))],
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return the ATen ops to preserve, plus a per-node filter.

        The filter returns True for any node whose target is not in
        ``ops_not_to_decompose``. For nodes that do target a listed op, it
        returns True only when the node has exactly three positional args.
        """

        def filter_ops(node: torch.fx.Node) -> bool:
            if node.op == "call_function" and node.target in ops_not_to_decompose:
                # Three positional args means linear was given a bias, which is
                # the only linear variant this demo partitioner supports.
                # NOTE(review): this arg-count check is applied to every op in
                # ops_not_to_decompose (sdpa and upsample too), not only linear.
                return len(node.args) == 3

            return True

        return (ops_not_to_decompose, filter_ops)

    def _generate_single_node_partition(
        self, gm: torch.fx.GraphModule
    ) -> List[Partition]:
        """Wrap each supported non-decomposed op node of ``gm`` in its own Partition.

        Partition ids are assigned 0, 1, 2, ... in graph order.
        """
        partition_id = itertools.count()
        return [
            Partition(nodes=[node], id=next(partition_id))
            for node in gm.graph.nodes
            if node.op == "call_function"
            and node.target in self.supported_non_decomposed_edge_ops
        ]

    def _partition_graph_module(
        self,
        graph_module: torch.fx.GraphModule,
    ) -> Dict[str, DelegationSpec]:
        """
        Tag single-node partitions in ``graph_module`` and its control-flow
        submodules; return a mapping from each tag to the delegation spec.
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = self._generate_single_node_partition(graph_module)
        for partition in partition_list:
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

        for _, submodule, _ in get_control_flow_submodules(graph_module):
            ret_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(ret_partition_tags)
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )