xref: /aosp_15_r20/external/executorch/backends/example/test_example_delegate.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8import unittest
9
10import torch
11from executorch import exir
12from executorch.backends.example.example_partitioner import ExamplePartitioner
13from executorch.backends.example.example_quantizer import ExampleQuantizer
14from executorch.exir import to_edge
15
16from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
17    DuplicateDequantNodePass,
18)
19from executorch.exir.delegate import executorch_call_delegate
20
21from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
22from torch.export import export
23
24from torchvision.models.quantization import mobilenet_v2
25
26
class TestExampleDelegate(unittest.TestCase):
    """End-to-end tests for lowering quantized models to the example backend.

    Each test quantizes an eager model with ``ExampleQuantizer`` through the
    PT2E flow (``prepare_pt2e`` -> calibration -> ``convert_pt2e``), converts
    it to an edge program, lowers it with ``ExamplePartitioner`` and then
    inspects the lowered graph for ``executorch_call_delegate`` nodes.
    """

    @staticmethod
    def _quantize_to_edge(model, example_inputs):
        """Quantize ``model`` with ``ExampleQuantizer`` and convert it to edge IR.

        Args:
            model: Eager ``torch.nn.Module``; it is put into eval mode in place.
            example_inputs: Tuple of example tensors, used both for tracing and
                as the single calibration batch.

        Returns:
            The edge program manager returned by ``to_edge``.
        """
        edge_compile_config = exir.EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
        )

        model.eval()
        # Copies of the inputs are handed to each export call, presumably so
        # tracing cannot affect the tensors reused for calibration.
        prepared = torch.export.export_for_training(
            model, copy.deepcopy(example_inputs)
        ).module()
        prepared = prepare_pt2e(prepared, ExampleQuantizer())
        # Calibration: a single forward pass over the example inputs.
        prepared(*example_inputs)
        quantized_gm = convert_pt2e(prepared)

        return to_edge(
            export(quantized_gm, copy.deepcopy(example_inputs)),
            compile_config=edge_compile_config,
        )

    @staticmethod
    def _delegate_call_nodes(lowered_program):
        """Return the ``executorch_call_delegate`` nodes in the lowered graph."""
        return [
            node
            for node in lowered_program.exported_program().graph.nodes
            if node.target == executorch_call_delegate
        ]

    def test_delegate_linear(self):
        # NOTE: despite its name, this test lowers a single Conv2d layer; the
        # name is kept so existing test IDs remain stable.
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(16, 33, 3)

            def forward(self, arg):
                return self.conv2d(arg)

            @staticmethod
            def get_example_inputs():
                return (torch.randn(20, 16, 50, 100),)

        example_inputs = Conv2dModule.get_example_inputs()
        edge_program = self._quantize_to_edge(Conv2dModule(), example_inputs)
        lowered_program = edge_program.to_backend(ExamplePartitioner())

        print("After lowering to example backend: ")
        lowered_program.exported_program().graph.print_tabular()

        # Without an assertion this test could not detect a regression where
        # the partitioner stops delegating anything.
        self.assertGreaterEqual(len(self._delegate_call_nodes(lowered_program)), 1)

    def test_delegate_mobilenet_v2(self):
        model = mobilenet_v2(num_classes=3)
        example_inputs = (torch.rand(1, 3, 320, 240),)

        edge_program = self._quantize_to_edge(model, example_inputs)
        # DuplicateDequantNodePass runs before partitioning; presumably it gives
        # each consumer of a shared dequantize node its own copy so the
        # partitioner can match patterns independently (verify against the pass).
        lowered_program = edge_program.transform(
            [DuplicateDequantNodePass()]
        ).to_backend(ExamplePartitioner())

        lowered_program.exported_program().graph.print_tabular()

        self.assertEqual(len(self._delegate_call_nodes(lowered_program)), 1)
113