xref: /aosp_15_r20/external/executorch/exir/tests/test_delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport unittest
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.tests.models as models
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import EdgeCompileConfig, to_edge
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.lowered_backend_module import (
15*523fa7a6SAndroid Build Coastguard Worker    create_submodule_from_nodes,
16*523fa7a6SAndroid Build Coastguard Worker    LoweredBackendModule,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import (
19*523fa7a6SAndroid Build Coastguard Worker    BackendDelegate,
20*523fa7a6SAndroid Build Coastguard Worker    BackendDelegateDataReference,
21*523fa7a6SAndroid Build Coastguard Worker    DataLocation,
22*523fa7a6SAndroid Build Coastguard Worker    DelegateCall,
23*523fa7a6SAndroid Build Coastguard Worker)
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.common import register_additional_test_aten_ops
25*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import export
26*523fa7a6SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
27*523fa7a6SAndroid Build Coastguard Worker
28*523fa7a6SAndroid Build Coastguard Worker
class WrapperModule(torch.nn.Module):
    """Expose a plain callable as a ``torch.nn.Module``.

    The tests in this file pass instances of this wrapper to ``export``.
    Every positional and keyword argument given to ``forward`` is handed to
    the wrapped callable unchanged, and its result is returned as-is.
    """

    def __init__(self, fn):
        super().__init__()
        # Stored as a plain attribute; in this file ``fn`` is a local function.
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        result = wrapped(*args, **kwargs)
        return result
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker
class TestDelegate(unittest.TestCase):
    """Tests for lowered (delegated) backend modules in the exir pipeline.

    Covers calling ``executorch_call_delegate`` directly, emitting a program
    that contains a delegate, immutability of ``LoweredBackendModule``, and
    partitioning graph nodes into a submodule via ``create_submodule_from_nodes``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    @staticmethod
    def _call_function_nodes(gm: "torch.fx.GraphModule", targets: tuple) -> list:
        """Return ``gm``'s ``call_function`` nodes whose target is in ``targets``, in graph order."""
        return [
            node
            for node in gm.graph.nodes
            if node.op == "call_function" and node.target in targets
        ]

    def _assert_flat_output(self, sub_gm: "torch.fx.GraphModule", expected_len: int) -> None:
        """Assert ``sub_gm`` has one output node returning a list of ``expected_len`` values."""
        output_nodes = [node for node in sub_gm.graph.nodes if node.op == "output"]
        self.assertEqual(len(output_nodes), 1)
        (output_node,) = output_nodes
        self.assertEqual(len(output_node.args), 1)
        self.assertIsInstance(output_node.args[0], list)
        self.assertEqual(len(output_node.args[0]), expected_len)

    def test_call_delegate(self) -> None:
        """Exporting a direct ``executorch_call_delegate`` call keeps the delegate."""

        def g(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        inputs = (torch.ones(1, 3), torch.ones(1, 3))
        edge_ir_m = to_edge(export(WrapperModule(g), inputs))
        lowered_module: LoweredBackendModule = LoweredBackendModule(
            edge_ir_m.exported_program(), "BackendWithCompilerDemo", b"moo", []
        )

        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.higher_order.executorch_call_delegate(lowered_module, x, y)

        orig_res = f(*inputs)
        gm = export(
            WrapperModule(f),
            inputs,
        )
        # The lowered module must be referenced as "lowered_module_0", and that
        # reference must appear before the call_delegate op in the generated code.
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(gm.graph_module.code)
        self.assertTrue(torch.allclose(orig_res, gm.module()(*inputs)))

    def test_to_backend(self) -> None:
        """Check if we have patched a lowered module correctly (for delegation)"""

        m = models.CompositeDelegateModule()

        exec_prog = to_edge(
            export(m, m.get_random_inputs()),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch()  # TODO(larryliu): fix split_copy.Tensor
        graph_module = exec_prog.exported_program().graph_module
        program = exec_prog._emitter_output.program

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("lowered_module_0").check(
            "torch.ops.higher_order.executorch_call_delegate"
        ).run(graph_module.code)

        # Check that there does not exist an add node (from the non-delegated
        # BasicModuleAdd.forward function)
        # NOTE(review): to_executorch() presumably rewrites remaining ops to
        # out-variants, so matching only the edge `add.default` overload may be
        # weaker than intended -- confirm which target should be excluded here.
        self.assertNotIn(
            exir_ops.edge.aten.add.default,
            {node.target for node in graph_module.graph.nodes},
        )

        delegate_calls = self._call_function_nodes(
            graph_module, (torch.ops.higher_order.executorch_call_delegate,)
        )
        # Without this guard, the per-node checks below would pass vacuously
        # if no delegate call were matched.
        self.assertGreater(len(delegate_calls), 0)
        for node in delegate_calls:
            # Check that the first argument is the lowered backend module
            # (which we got from a getattr)
            self.assertEqual(node.args[0].op, "get_attr")
            get_attr_backend = getattr(graph_module, node.args[0].target)
            self.assertEqual(
                get_attr_backend._backend_id, m.lowered_module._backend_id
            )
            self.assertEqual(
                get_attr_backend._processed_bytes, m.lowered_module._processed_bytes
            )
            self.assertEqual(
                get_attr_backend._compile_specs, m.lowered_module._compile_specs
            )

        # Check the BackendDelegate object itself
        delegate: BackendDelegate = program.execution_plan[0].delegates[0]
        self.assertEqual(delegate.id, "backend_demo")
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, b"basic_module_add"
        )

        # Check the delegate instruction
        self.assertIsInstance(
            program.execution_plan[0].chains[0].instructions[0].instr_args,
            DelegateCall,
        )

    def test_cannot_assign_attr(self) -> None:
        """Assigning ``backend_id`` on a LoweredBackendModule raises AttributeError."""
        deleg = LoweredBackendModule(None, "", b"", [])  # pyre-ignore
        with self.assertRaises(AttributeError):
            deleg.backend_id = "123"  # pyre-ignore

    def test_create_submodule_single_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(mul_tensor, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [mul_tensor]  # Output is pytree.flatten-ed

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        node_list = self._call_function_nodes(
            gm, (exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor)
        )
        self.assertTrue(node_list, "no add/mul nodes found to partition")

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        self._assert_flat_output(sub_gm, 1)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_multiple_return(self) -> None:
        """
        Original graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            sub_tensor = sub(add_tensor, mul_tensor)
            div_tensor = div(sub_tensor, mul_tensor)
            return [div_tensor]

        Partitioned graph:
            add_tensor = add(x, y)
            mul_tensor = mul(add_tensor, y)
            return [add_tensor, mul_tensor]

        Final graph:
            partitioned_res = partitioned_graph(x, y)
            getitem_0 = partitioned_res[0]
            getitem_1 = partitioned_res[1]
            sub_tensor = sub(getitem_0, getitem_1)
            div_tensor = div(sub_tensor, getitem_1)
            return [div_tensor]
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                y = x * y
                x = x - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        node_list = self._call_function_nodes(
            gm, (exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.mul.Tensor)
        )
        self.assertTrue(node_list, "no add/mul nodes found to partition")

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        sub_gm.recompile()
        gm.recompile()

        self._assert_flat_output(sub_gm, 2)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))

    def test_create_submodule_list_return(self) -> None:
        """
        Original graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]

        Partitioned graph:
            split_tensor = split(x, 5)
            getitem_0 = split_tensor[0]
            getitem_1 = split_tensor[1]
            return [getitem_0, getitem_1]  # List output is "opened"

        Final graph:
            partitioned_res = partitioned_graph(x)
            getitem_0 = partitioned_res[0]
            sub_tensor = sub(getitem_0, y)
            div_tensor = div(sub_tensor, y)
            return [div_tensor]
        """
        inputs = (torch.randn(10), torch.randn(5))

        class Model(torch.nn.Module):
            def forward(self, x, y):
                x = torch.split(x, 5)
                x = x[0] - y
                x = x / y
                return x

        orig_res = Model()(*inputs)
        prog = to_edge(export(Model(), inputs))
        gm = prog.exported_program().graph_module

        # TODO(ssjia): split.Tensor now gets decomposed to split_with_sizes. Due to how executorch uses a pinned Pytorch
        # nightly, the CI may not catch the changes to Pytorch's core decomposition table. As a temporary workaround,
        # make the test backwards compatible with the old decomposition table. Remove the split_copy.Tensor entry once
        # Pytorch nightly has been updated.
        node_list = self._call_function_nodes(
            gm,
            (
                exir_ops.edge.aten.split_with_sizes_copy.default,
                exir_ops.edge.aten.split_copy.Tensor,
            ),
        )
        self.assertTrue(node_list, "no split node found to partition")

        sub_gm, _ = create_submodule_from_nodes(gm, node_list, "tag")
        # Recompile like the sibling tests so generated code matches the graph.
        sub_gm.recompile()
        gm.recompile()

        self._assert_flat_output(sub_gm, 2)

        new_res = prog.exported_program().module()(*inputs)
        self.assertTrue(torch.allclose(new_res, orig_res))
312