xref: /aosp_15_r20/external/executorch/exir/backend/test/test_backends_nested.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, final, List

import executorch.exir as exir

import torch

from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_pattern_op_partitions,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

from executorch.exir.backend.test.op_partitioner_demo import (
    AddOperatorSupport,
    MatmulOperatorSupport,
)
from executorch.exir.delegate import executorch_call_delegate

from executorch.exir.graph_module import _get_submodule, get_control_flow_submodules
from executorch.exir.lowered_backend_module import get_lowered_submodules
from functorch.experimental import control_flow
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class M(torch.nn.Module):
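    """
    Toy module with nested control flow: the top-level cond's true branch
    itself contains another cond, so lowering has to recurse through several
    levels of control flow submodules.
    """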
    def __init__(self):
        super().__init__()

    def forward(self, x, pred1, pred2, y):
        def true_fn(x, pred2):
            def true_nested(y):
                y = y + y
                y = torch.mm(y, y)
                return y

            def false_nested(y):
                return torch.mm(y, y)

            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _pred2):
            return torch.mm(x, x)

        x = x.cos()
        x = x + y
        y = control_flow.cond(pred1, true_fn, false_fn, [x, pred2])
        return y.sin()

    def get_example_inputs(self):
        return (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )


@final
class Backend2Demo(BackendDetails):
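    """
    Toy backend whose "compilation" simply concatenates the name of every
    call_function node, e.g. b"Backend2::aten.add.Tensor;aten.mm.default;".
    """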
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        processed_bytes = "Backend2::"
        for node in edge_program.graph.nodes:
            if node.op == "call_function":
                processed_bytes += f"{node.target.__name__};"
        return PreprocessResult(
            processed_bytes=bytes(processed_bytes, encoding="utf8"),
        )


@final
class Backend2PartitionerDemo(Partitioner):
    """
    Partitions all add/matmul nodes regardless of order for Backend2
    """

    def __init__(self) -> None:
        self.op_support = any_chain(AddOperatorSupport(), MatmulOperatorSupport())
        self.delegation_spec = DelegationSpec("Backend2Demo", [])
        self.partition_tags = {}

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            edge_graph_module, op_support=self.op_support
        )

        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            submodule_partition_tags = self._partition_graph_module(submodule)
            partition_tags.update(submodule_partition_tags)

        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"backend2_tag{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )


@final
class Backend1Demo(BackendDetails):
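    """
    Toy backend that first re-delegates the add/matmul subgraphs of its
    partition to Backend2, then serializes the remaining graph, recursing
    into cond branches and already-lowered Backend2 payloads.
    """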
    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        assert isinstance(edge_program, ExportedProgram)
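        # Lower the add/matmul subgraphs inside this partition to Backend2
        # before serializing the rest of the graph.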
        partitioned_module = to_backend(edge_program, Backend2PartitionerDemo())

        def process(gm):
            processed_bytes = ""
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    if node.target is torch.ops.higher_order.cond:
                        _, true_gm, _ = _get_submodule(gm, node, 1)
                        _, false_gm, _ = _get_submodule(gm, node, 2)
                        processed_bytes += f"{node.target.__name__}({process(true_gm)},{process(false_gm)});"
                    elif node.target is operator.getitem:
                        continue
                    elif node.target is executorch_call_delegate:
                        _, lowered, _ = _get_submodule(gm, node, 0)
                        processed_bytes += f"call_delegate({lowered.processed_bytes});"
                    else:
                        processed_bytes += f"{node.target.__name__};"
            return processed_bytes

        processed_bytes = f"Backend1::({process(partitioned_module.graph_module)})"
        return PreprocessResult(
            processed_bytes=bytes(processed_bytes, encoding="utf8"),
        )


class CondOperatorSupport(OperatorSupportBase):
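    """
    Marks torch.ops.higher_order.cond nodes as supported so that entire cond
    nodes (including their branch submodules) can be claimed by a partition.
    """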
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target is torch.ops.higher_order.cond


@final
class Backend1PartitionerDemo(Partitioner):
    """
    Partitions all add/matmul/cond nodes regardless of order. Since we're
    partitioning the cond ops, we do not need to go into those submodules.
    """

    def __init__(self) -> None:
        self.op_support = any_chain(
            AddOperatorSupport(), MatmulOperatorSupport(), CondOperatorSupport()
        )
        self.delegation_spec = DelegationSpec("Backend1Demo", [])

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            edge_graph_module, op_support=self.op_support
        )

        for _, submodule, node in get_control_flow_submodules(edge_graph_module):
            # Don't partition the cond submodules because we are lowering the
            # entire cond node, including its submodules.
            if node.target is not control_flow.cond:
                self._partition_graph_module(submodule)

        for partition in partition_list:
            for node in partition.nodes:
                delegation_tag = f"backend1_tag{partition.id}"
                if (
                    node.op == "call_function"
                    and node.target is torch.ops.higher_order.cond
                ):
                    # Also tag the submodule arguments passed to the cond node
                    node.args[1].meta["delegation_tag"] = delegation_tag
                    node.args[2].meta["delegation_tag"] = delegation_tag
                node.meta["delegation_tag"] = delegation_tag
                partition_tags[delegation_tag] = self.delegation_spec
        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )


class TestNestedBackends(unittest.TestCase):
    def test(self) -> None:
        """
        Partitions the cond ops into the delegate
        """

        m = M()
        orig_res = m(*m.get_example_inputs())
        orig = exir.capture(
            m,
            m.get_example_inputs(),
            exir.CaptureConfig(),
        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))

        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, Backend1PartitionerDemo()
        )

        new_res = partitioned(*m.get_example_inputs())[0]
        self.assertTrue(torch.allclose(orig_res, new_res))

        # The toplevel module should have lowered the cond and add ops
        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        toplevel_lowered = toplevel_lowered[0][1]
        self.maxDiff = None
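        # The serialized payload should mirror the original control flow: a
        # delegated add, then the outer cond whose true branch holds the
        # nested cond plus a delegated add, and whose false branch is a
        # delegated matmul.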
        self.assertEqual(
            str(toplevel_lowered.processed_bytes),
            (
                'b"Backend1::('
                + "call_delegate(b'Backend2::aten.add.Tensor;');"
                + "cond("
                # True function of toplevel cond (nested cond)
                + "cond(call_delegate(b'Backend2::aten.add.Tensor;aten.mm.default;');,call_delegate(b'Backend2::aten.mm.default;'););"
                # True function of toplevel cond (delegated add)
                + "call_delegate(b'Backend2::aten.add.Tensor;');,"
                # False function of toplevel cond
                + "call_delegate(b'Backend2::aten.mm.default;'););)\""
            ),
        )