xref: /aosp_15_r20/external/executorch/exir/backend/test/demos/test_xnnpack_qnnpack.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import executorch.exir as exir
10
11import torch
12from executorch.backends.fb.qnnpack.partition.qnnpack_partitioner import (
13    QnnpackPartitioner,
14)
15from executorch.backends.fb.qnnpack.qnnpack_preprocess import QnnpackBackend
16from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
17    XnnpackFloatingPointPartitioner,
18)
19
20# import the xnnpack backend implementation
21from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
22
23from executorch.exir import CaptureConfig
24from executorch.exir.backend.backend_api import to_backend, validation_disabled
25from executorch.exir.passes.spec_prop_pass import SpecPropPass
26
27from executorch.extension.pybindings.portable_lib import (  # @manual
28    _load_for_executorch_from_buffer,
29)
30from executorch.extension.pytree import tree_flatten
31from torch.ao.quantization.backend_config.executorch import (
32    get_executorch_backend_config,
33)
34from torch.ao.quantization.observer import (
35    default_dynamic_quant_observer,
36    default_per_channel_weight_observer,
37)
38from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping
39from torch.ao.quantization.quantize_fx import (
40    _convert_to_reference_decomposed_fx,
41    prepare_fx,
42)
43
44
class TestXnnQnnBackends(unittest.TestCase):
    """Demo test that lowers one model to both the QNNPACK and XNNPACK delegates."""

    def test_add_xnnpack_and_dqlinear_qnn(self):
        """Lower a dynamically quantized linear followed by an add to two backends.

        The module is quantized with FX graph mode quantization, captured to
        the edge dialect, partitioned with ``QnnpackPartitioner`` and then with
        ``XnnpackFloatingPointPartitioner``, serialized, and executed through
        the ExecuTorch pybindings. The test asserts the order of the delegates
        in the serialized program and that the runtime output is close to the
        output obtained by calling the lowered program object from Python.
        """
        num_rows = 1
        in_features = 3
        out_features = 4

        class LinearAndAdd(torch.nn.Module):
            """Computes ``linear(x) + y``."""

            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features)

            def forward(self, x, y):
                return self.linear(x) + y

        # Step 1: FX quantization. Only torch.nn.Linear gets a qconfig:
        # dynamic activation observer with per-channel weight observer.
        dynamic_linear_qconfig = QConfig(
            activation=default_dynamic_quant_observer,
            weight=default_per_channel_weight_observer,
        )
        qconfig_mapping = QConfigMapping().set_object_type(
            torch.nn.Linear, dynamic_linear_qconfig
        )

        float_module = LinearAndAdd()
        linear_input = torch.ones(num_rows, in_features, dtype=torch.float)
        addend = torch.ones(num_rows, out_features, dtype=torch.float)
        example_inputs = (linear_input, addend)

        observed_module = prepare_fx(
            float_module,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        quantized_reference: torch.fx.GraphModule = (
            _convert_to_reference_decomposed_fx(observed_module)
        )

        # Step 2: EXIR capturing, then conversion to the edge dialect with the
        # IR validity check turned off.
        traced = exir.capture(
            quantized_reference,
            example_inputs,
            config=CaptureConfig(enable_dynamic_shape=False),
        )
        edge_program = traced.to_edge(
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        # Both lowering steps replace `exported_program` on the same
        # `edge_program` object, so the second step sees the result of the first.

        # Step 3.1: Lower dynamic quant linear to qnnpack
        with validation_disabled():
            edge_program.exported_program = to_backend(
                edge_program.exported_program, QnnpackPartitioner()
            )

        # Step 3.2: Lower add to xnnpack
        with validation_disabled():
            edge_program.exported_program = to_backend(
                edge_program.exported_program,
                XnnpackFloatingPointPartitioner(),
            )

        # Step 4: Emit the ExecuTorch program and check the delegate order.
        program_with_delegates = edge_program.to_executorch(
            exir.ExecutorchBackendConfig(passes=[SpecPropPass()]),
        )
        delegates = program_with_delegates.program.execution_plan[0].delegates
        # The first delegate backend is Qnnpack
        self.assertEqual(delegates[0].id, QnnpackBackend.__name__)
        # The second delegate backend is Xnnpack
        self.assertEqual(delegates[1].id, XnnpackBackend.__name__)

        # Step 5: Run the serialized program through the runtime bindings.
        executorch_module = _load_for_executorch_from_buffer(
            program_with_delegates.buffer
        )
        flat_inputs, _ = tree_flatten(example_inputs)
        model_output = executorch_module.run_method("forward", tuple(flat_inputs))

        # NOTE(review): `edge_program` was mutated in place by both lowering
        # steps above, so this reference comes from the lowered program object
        # called from Python, not from the original float module -- confirm
        # this is the intended reference.
        ref_output = edge_program(*example_inputs)

        # Compare the result from the executor and the Python-side call directly
        outputs_match = torch.allclose(
            model_output[0], ref_output, atol=1e-03, rtol=1e-03
        )
        self.assertTrue(outputs_match)
135