xref: /aosp_15_r20/external/executorch/exir/backend/test/demos/rpc/executor_backend_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import typing
8from typing import final
9
10import torch
11from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
12    generate_pattern_op_partitions,
13)
14from executorch.exir.backend.partitioner import (
15    DelegationSpec,
16    Partitioner,
17    PartitionResult,
18)
19from executorch.exir.backend.test.backend_with_compiler_demo import (
20    BackendWithCompilerDemo,
21)
22from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
23    ExecutorBackend,
24)
25from torch.export import ExportedProgram
26from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
27
28
class AnyOperatorSupport(OperatorSupportBase):
    """Operator-support policy that accepts every ``call_function`` node."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # The decision depends only on the node's opcode; `submodules` is unused.
        is_function_call = node.op == "call_function"
        return is_function_call
33
class AnyDelegateSupport(OperatorSupportBase):
    """
    Operator-support policy that accepts ``call_method`` nodes invoking a
    submodule whose ``backend_id`` is ``BackendWithCompilerDemo``.
    """

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """
        Return whether ``node`` calls into a BackendWithCompilerDemo submodule.

        Args:
            submodules: Mapping looked up by the name of ``node.args[0]`` to find
                the lowered submodule being called.
            node: Graph node to check.

        Returns:
            True if ``node`` is a ``call_method`` whose target submodule's
            ``backend_id`` equals ``BackendWithCompilerDemo.__name__``;
            False for every other node.

        Raises:
            AssertionError: if a ``call_method`` node's first argument is not an
                fx Node.
        """
        if node.op != "call_method":
            return False
        assert isinstance(
            node.args[0], torch.fx.Node
        ), "the first argument is not an fx Node, it's not a valid graph with delegates"
        lowered_name = typing.cast(torch.fx.Node, node.args[0]).name
        lowered_module = submodules[lowered_name]
        # Compare strings by value; `is` checks object identity, which only
        # matches when both strings happen to be the same interned object.
        return lowered_module.backend_id == BackendWithCompilerDemo.__name__
44
45
@final
class ExecutorBackendPartitioner(Partitioner):
    """
    Partitions every ``call_function`` node, plus any ``call_method`` node that
    calls a submodule lowered to ``BackendWithCompilerDemo``, and delegates
    each resulting partition to ``ExecutorBackend``.
    """

    def __init__(self) -> None:
        # A node is partitionable if either support policy accepts it.
        self.op_support = any_chain(AnyOperatorSupport(), AnyDelegateSupport())
        # Every partition is delegated to ExecutorBackend with an empty spec list.
        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])

    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag the nodes of ``edge_exported_program`` for delegation.

        Every node in each partition produced by ``generate_pattern_op_partitions``
        gets ``node.meta["delegation_tag"] = f"tag{partition.id}"``. If a tagged
        node's first positional argument is a ``get_attr`` node (e.g. a delegate
        submodule), that argument receives the same tag.

        Args:
            edge_exported_program: Program whose graph module is tagged in place.

        Returns:
            A PartitionResult holding the same (now tagged) program and a mapping
            from each tag to this partitioner's delegation spec.
        """
        partition_tags = {}
        partition_list = generate_pattern_op_partitions(
            edge_exported_program.graph_module, op_support=self.op_support
        )
        for partition in partition_list:
            # The tag depends only on the partition, so build it once.
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec

                # Tag the delegate submodules. AnyOperatorSupport accepts any
                # call_function, so a node may have no positional args or a
                # non-Node first arg (list, scalar); guard before reading `.op`.
                first_arg = node.args[0] if node.args else None
                if isinstance(first_arg, torch.fx.Node) and first_arg.op == "get_attr":
                    first_arg.meta["delegation_tag"] = delegation_tag

        return PartitionResult(
            tagged_exported_program=edge_exported_program,
            partition_tags=partition_tags,
        )
75