# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper
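
# A minimal usage sketch for @vary_segments (the test name below is
# illustrative, not part of this suite):
#
#     @vary_segments
#     def test_example(self, extract_delegate_segments: bool):
#         exec_prog = edge_prog.to_executorch(
#             config=exir.ExecutorchBackendConfig(
#                 extract_delegate_segments=extract_delegate_segments
#             )
#         )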


class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
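        """Assert that the lowered module's graph has exactly `input_len`
        placeholder (input) nodes."""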
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
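        """Assert that `delegate` has the expected backend id and that its
        processed blob is stored inline in the program's delegate data,
        matching `expected_processed` byte-for-byte."""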
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
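        # Encode the input's leading dimension as a compile spec. The demo
        # backend appears to use this to bound the input sizes it will accept
        # at load time (see run_model_in_unsupported_backend below).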
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        graph_module = exec_prog.dump_graph_module()

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog.program

        # Check the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
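        # The demo backend appears to approximate sin(x) with the Taylor
        # expansion x - x^3/6, so sin(1.0) comes out as 1 - 1/6 ≈ 0.8333.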
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = exir.capture(
            add_mul_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo", edge_graph_module.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
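        """Helper that lowers a model the demo backend cannot run (the input is
        too large) and then loads it; callers expect
        _load_for_executorch_from_buffer to raise a RuntimeError."""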
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # The backend only accepts shapes <= 4.
        model_inputs = (torch.ones(6),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # This test includes both a delegate and an operator.
        # import the backend implementation
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        graph_module = exec_prog.dump_graph_module()
        program = exec_prog.program
        buff = exec_prog.buffer

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
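        # Each of the two delegate calls computes sin(1.0) ≈ 0.8333 under the
        # demo backend's apparent approximation, and the add doubles it.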
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program, [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend("FakeBackendWithCompilerDemo", edgeir_m.exported_program, [])

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs on a specific backend; in this example, the `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is an nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = exir.capture(
            to_be_lowered, example_input, exir.CaptureConfig()
        ).to_edge()

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program,
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, including the lowerable and non-lowerable parts
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with eager module
        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])

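        # Under the demo backend's apparent sin approximation:
        # a = sin(1.0) ≈ 0.8333, b = sin(a) ≈ 0.7369, ret = a + b + 0.3 ≈ 1.8702.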
        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(composite_m, inputs, exir.CaptureConfig()).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(CompositeModel(3), inputs)
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # After this step, part of the graph will be lowered to the backend,
        # depending on HTAPartitionerMultiplePatternsDemo's rules.
        program_with_delegates = traced
        program_with_delegates.exported_program = to_backend(
            traced.exported_program, HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_pattern(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(
            composite_m,
            inputs,
            exir.CaptureConfig(),
        ).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(
                CompositeModel(3),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # After this step, part of the graph will be lowered to the backend,
        # depending on HTAPartitionerOnePatternDemo's rules.
        traced_with_delegate = traced
        traced_with_delegate.exported_program = to_backend(
            traced.exported_program, HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable
        # Retracing is not needed, but keeping this here to make sure the result
        # of to_backend is retraceable
        # graph_module_with_delegate = exir.capture(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        #     exir.CaptureConfig(),
        # ).to_edge()

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check one sub op remains in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddMulPartitionerDemo()
        )

        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.dump_graph_module().graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        Check that if we tag the getattr nodes, the attributes will be added to
        the lowered submodule rather than being passed into the delegate as
        inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.one = torch.ones(1, 3)

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one(x), self.add_one(y)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddAttributePartitionerDemo()
        )

        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module())
        self.assertEqual(len(lowered_submodules), 2)
        # Attributes should be stored in the lowered module
        self.check_delegate_input(lowered_submodules[0][1], 1)
        self.check_delegate_input(lowered_submodules[1][1], 2)

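        # Accessing .buffer presumably just exercises serialization of the
        # partitioned program; the bytes themselves are not checked here.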
        executorch_prog.buffer

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if a user-provided partitioner modifies
        the graph module.
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # Partitioner should not modify the given graph module
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags={
                        "tag1": DelegationSpec("BackendWithCompilerDemo", [])
                    },
                )

        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        with self.assertRaises(AssertionError):
            _ = to_backend(ep.exported_program, BadPartitioner())

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
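        # prepare_fx inserts observers per the qconfig mapping;
        # _convert_to_reference_decomposed_fx then swaps the observed modules
        # for reference quantized ops decomposed into explicit
        # quantize/dequantize calls, which the FileCheck below relies on.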
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # fails to trace here
        converted_linear_gm = exir.capture(
            converted_linear,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program.graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        def f(x, y):
            x = x + y
            x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
            x = x - y
            return x

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
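        # AddMulPartitionerDemo tags add and mm nodes, including those inside
        # the cond branches, so the toplevel graph and each branch submodule
        # should each end up with a lowered module (checked below).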
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        def f(xs, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
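        # Here the toplevel mm and the add inside map_fn should each be
        # lowered into their own submodule (checked below).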
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        def f(xs, pred1, pred2, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        # Map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # True module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested True module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested False module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        def f(x: List[torch.Tensor]):
            y = x[0] + x[1]
            return y

        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = exir.capture(f, inputs, exir.CaptureConfig()).to_edge()
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program, []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)

    def test_dict_input(self):
        class M(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = exir.to_edge(torch.export.export(M(), inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: Dict[str, torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)
1269