# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import typing
from typing import final

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class AnyOperatorSupport(OperatorSupportBase):
    """Operator support that accepts every ``call_function`` node."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function"


class AnyDelegateSupport(OperatorSupportBase):
    """Operator support that accepts delegate calls lowered to BackendWithCompilerDemo.

    A node is supported when it is a ``call_method`` node whose first argument
    is an fx Node naming an entry in ``submodules``, and that submodule's
    ``backend_id`` equals ``BackendWithCompilerDemo.__name__``.
    """

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        if node.op != "call_method":
            return False
        lowered_node = node.args[0]
        assert isinstance(
            lowered_node, torch.fx.Node
        ), "the first argument is not an fx Node, it's not a valid graph with delgates"
        lowered_module = submodules[lowered_node.name]
        # Compare by value: `is` on strings tests object identity, which only
        # happens to hold when both strings are the same interned object.
        return lowered_module.backend_id == BackendWithCompilerDemo.__name__


@final
class ExecutorBackendPartitioner(Partitioner):
    """
    Partitions every ``call_function`` node, plus ``call_method`` delegate
    calls whose lowered module targets ``BackendWithCompilerDemo``, and tags
    the resulting partitions for delegation to ``ExecutorBackend``.
    """

    def __init__(self) -> None:
        # any_chain: a node is supported if either checker accepts it.
        self.op_support = any_chain(AnyOperatorSupport(), AnyDelegateSupport())
        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        """Tag supported nodes with a per-partition delegation tag.

        Args:
            edge_exported_program: Program to partition. Its graph nodes are
                tagged in place through ``node.meta["delegation_tag"]``.

        Returns:
            A PartitionResult wrapping the same (now tagged) program and a map
            from each delegation tag to ``self.delegation_spec``.
        """
        partition_tags: typing.Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            edge_exported_program.graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            # The tag depends only on the partition, not on the node.
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

                # Tag the delegate submodules: a get_attr node passed as the
                # first argument joins the same partition. Guard against nodes
                # with no positional args or a non-Node first arg (e.g. a
                # literal shape list), which would otherwise raise
                # IndexError / AttributeError.
                first_arg = node.args[0] if node.args else None
                if isinstance(first_arg, torch.fx.Node) and first_arg.op == "get_attr":
                    first_arg.meta["delegation_tag"] = delegation_tag

        return PartitionResult(
            tagged_exported_program=edge_exported_program,
            partition_tags=partition_tags,
        )