xref: /aosp_15_r20/external/executorch/exir/tests/test_quantize_io_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8import unittest
9
10import torch
11from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
12from executorch.exir.passes.quantize_io_pass import (
13    get_config_method_name,
14    QuantizeInputs,
15    QuantizeOutputs,
16)
17from executorch.exir.tensor import get_scalar_type
18from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
19
20from torch.ao.quantization.quantizer.xnnpack_quantizer import (
21    get_symmetric_quantization_config,
22    XNNPACKQuantizer,
23)
24from torch.testing import FileCheck
25
# Names of the edge-dialect per-tensor quantize ("q") and dequantize ("dq")
# ops as they appear in a graph module's generated code; used as FileCheck
# patterns when counting ops.
op_str = dict(
    q="executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
    dq="executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
)
30
31
class TestQuantIOPass(unittest.TestCase):
    """Tests for the QuantizeInputs / QuantizeOutputs edge passes."""

    class Add(torch.nn.Module):
        """Minimal two-input, single-output model."""

        def forward(self, x, y):
            return x + y

    def _quantize(self, mod, example_inputs):
        """Quantize ``mod`` with the symmetric XNNPACK config via PT2E.

        The module is exported for training, prepared, calibrated once on
        ``example_inputs``, converted, and re-exported.

        Returns:
            The ``ExportedProgram`` of the converted (quantized) module.
        """
        quantizer = XNNPACKQuantizer()
        operator_config = get_symmetric_quantization_config()
        quantizer.set_global(operator_config)
        m = torch.export.export_for_training(
            mod, copy.deepcopy(example_inputs)
        ).module()
        m = prepare_pt2e(m, quantizer)
        # Single calibration pass so observers record ranges.
        _ = m(*example_inputs)
        m = convert_pt2e(m)
        exported_program = torch.export.export_for_training(m, example_inputs)
        return exported_program

    def _check_count(self, op, count, epm):
        """Assert that ``op`` occurs exactly ``count`` times in ``epm``'s code."""
        code = epm.exported_program().graph_module.code
        FileCheck().check_count(op, count, exactly=True).run(code)

    def _get_edge_prog_manager(self, mod, example_inputs):
        """Quantize ``mod`` and lower it to an edge program manager.

        Sanity-checks that the lowered graph has three quantize and three
        dequantize ops before any IO pass is applied.
        """
        exported_program = self._quantize(mod, example_inputs)
        edge_program_manager = to_edge_transform_and_lower(
            exported_program,
            transform_passes=[],
            partitioner=None,
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

        self._check_count(op_str["dq"], 3, edge_program_manager)
        self._check_count(op_str["q"], 3, edge_program_manager)
        return edge_program_manager

    @staticmethod
    def _as_tuple(outputs):
        """Normalize a module's return value to a tuple of outputs.

        The exported module mirrors the eager module's return structure, so
        a single-output model such as ``Add`` returns a bare tensor. Indexing
        that tensor with ``[0]`` would silently select its first row instead
        of its first output.
        """
        if isinstance(outputs, torch.Tensor):
            return (outputs,)
        return tuple(outputs)

    @staticmethod
    def _get_qparams(config_methods, method_name, io_kind, index):
        """Read the quantization params recorded for one program input/output.

        Args:
            config_methods: The manager's ``_config_methods`` mapping.
            method_name: Program method name, e.g. ``"forward"``.
            io_kind: ``"input"`` or ``"output"``.
            index: Position of the input/output.

        Returns:
            ``(scale, zero_point, quant_min, quant_max, dtype)``, in the
            positional order expected by the ``quantized_decomposed``
            per-tensor ops.
        """

        def lookup(field):
            return config_methods[
                get_config_method_name(method_name, io_kind, index, field)
            ]

        return (
            lookup("scale"),
            lookup("zp"),
            lookup("quant_min"),
            lookup("quant_max"),
            get_scalar_type(lookup("dtype")),
        )

    def test_add_drop_q_inputs(self) -> None:
        """Running with pre-quantized inputs must match the float reference."""
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        mod = self.Add().eval()
        edge_program_manager = self._get_edge_prog_manager(mod, example_inputs)
        reference_outputs = self._as_tuple(
            edge_program_manager.exported_program().module()(*example_inputs)
        )

        edge_program_manager_qin = edge_program_manager.transform(
            [
                QuantizeInputs(
                    edge_program_manager=edge_program_manager,
                    quantized_inputs_idx=[0, 1],
                    method_name="forward",
                )
            ]
        )
        # Quantizing both inputs removes their two quantize ops (3 -> 1).
        self._check_count(op_str["dq"], 3, edge_program_manager_qin)
        self._check_count(op_str["q"], 1, edge_program_manager_qin)

        config_methods = edge_program_manager_qin._config_methods
        quantized_example_inputs = tuple(
            torch.ops.quantized_decomposed.quantize_per_tensor.default(
                example_input,
                *self._get_qparams(config_methods, "forward", "input", i),
            )
            for i, example_input in enumerate(example_inputs)
        )
        output = self._as_tuple(
            edge_program_manager_qin.exported_program().module()(
                *quantized_example_inputs
            )
        )
        torch.testing.assert_close(reference_outputs, output)

    def test_add_drop_dq_output(self) -> None:
        """Dequantizing the quantized outputs must match the float reference."""
        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        mod = self.Add().eval()
        edge_program_manager = self._get_edge_prog_manager(mod, example_inputs)
        reference_outputs = self._as_tuple(
            edge_program_manager.exported_program().module()(*example_inputs)
        )

        edge_program_manager_dqout = edge_program_manager.transform(
            [
                QuantizeOutputs(
                    edge_program_manager=edge_program_manager,
                    quantized_outputs_idx_list=[0],
                    method_name="forward",
                )
            ]
        )
        # Quantizing the single output removes one dequantize op (3 -> 2).
        self._check_count(op_str["dq"], 2, edge_program_manager_dqout)
        self._check_count(op_str["q"], 3, edge_program_manager_dqout)

        quantized_outputs = self._as_tuple(
            edge_program_manager_dqout.exported_program().module()(*example_inputs)
        )

        config_methods = edge_program_manager_dqout._config_methods
        dequantized_outputs = tuple(
            torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                quantized_output,
                *self._get_qparams(config_methods, "forward", "output", i),
            )
            for i, quantized_output in enumerate(quantized_outputs)
        )

        torch.testing.assert_close(reference_outputs, dequantized_outputs)
157