# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch import exir
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.tests.common import register_additional_test_aten_ops
from torch.ao.quantization import (  # @manual
    float_qparams_weight_only_qconfig,
    get_default_qconfig_mapping,
)
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)

from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export
from torch.nn import functional as F

from torch.testing import FileCheck


class TestQuantFusionPass(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_add(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                # edge case, doesn't work yet, but we can add a fusion
                # pattern to enable it if needed
                # return x + x
                return x + y

        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        m = M().eval()
        # TODO: define qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is
        # separated out here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add
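        # Rough shape of the fused graph, per the checks below: the reference
        # dq -> aten.add -> q pattern is replaced by a single quantized add, i.e.
        #   quantize -> quantized_decomposed.add -> dequantize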
        FileCheck().check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        ).run(
            m.exported_program().graph_module.code
        )
        m = m.to_executorch()
        # check that we are using the out variant of q/dq/add
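        # (out variants write into a preallocated output tensor instead of
        # returning a new one, which is what the ExecuTorch runtime calls into)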
        FileCheck().check("torch.ops.quantized_decomposed.add.out").run(
            m.exported_program().graph_module.code
        )

    def test_reshape(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x.reshape(1, x.numel())
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m(*example_inputs)
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is
        # separated out here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add/reshape
        # make sure we only have two quantize ops and one dequantize op, since
        # the q/dq pair around reshape should be fused away
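        # Rough shape of the fused graph, per the checks below: view_copy runs
        # directly on the quantized tensor, i.e.
        #   quantize, quantize -> quantized_decomposed.add -> aten.view_copy -> dequantize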
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            2,
            exactly=True,
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_aten_view_copy_default"
        ).check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
            1,
            exactly=True,
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False))
        # check that we are using the out variants of add and view_copy
        FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
            "torch.ops.aten.view_copy.out"
        ).run(m.exported_program().graph_module.code)

    def test_slice(self) -> None:
        """We don't proactively quantize slice today, but we fuse the
        dq -> slice -> q pattern into an int8 slice operator; we can revisit
        this later to see whether proactively quantizing slice is needed.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x[1:]
                y = y[1:]
                x = x + y
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is
        # separated out here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add/slice
        # make sure we only have two quantize ops and one dequantize op, since
        # the q/dq pairs around the slices should be fused away
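        # Rough shape of the fused graph, per the checks below: the slice_copy
        # ops run directly on int8 tensors, with no dq/q pair left around them,
        # roughly quantize -> slice_copy -> ... -> quantized_decomposed.add -> dequantize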
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            2,
            exactly=True,
        ).check("executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor").check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        ).check(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch()
        # check that we are using the out variants of add and slice_copy
        FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
            "torch.ops.aten.slice_copy.Tensor_out"
        ).run(m.exported_program().graph_module.code)

    def test_cat(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.cat([x, x], dim=0)
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m(*example_inputs)
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is
        # separated out here so that we can check the graph.
        m = m.transform([QuantFusionPass()])
        # check that we are using the functional variant of q/dq/cat
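        # Rough shape of the fused graph, per the checks below: x is quantized
        # once, cat runs directly on the int8 tensor, and a single dequantize
        # produces the output, i.e. quantize -> aten.cat -> dequantize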
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            1,
            exactly=True,
        ).check("executorch_exir_dialects_edge__ops_aten_cat_default").check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
            1,
            exactly=True,
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch()
        # Note: quantized add is not fused, since the qparams are the same and the
        # current subgraph_rewriter doesn't handle the case where a single graph
        # node maps to two different pattern nodes. One workaround would be to add
        # new patterns for the case where the qparams are the same for the
        # quantized add pattern, but this may not be needed in real use cases; we
        # can add that workaround in another diff if needed.
        FileCheck().check_count(
            "torch.ops.quantized_decomposed.quantize_per_tensor.out", 1, exactly=True
        ).check("torch.ops.aten.cat.out").check_count(
            "torch.ops.quantized_decomposed.dequantize_per_tensor.out", 1, exactly=True
        ).run(
            m.exported_program().graph_module.code
        )

    def test_embedding_byte(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        for qconfig in [float_qparams_weight_only_qconfig]:
            m = M().eval()
            indices = torch.tensor(
                [9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
                 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
            )
            example_inputs = (indices,)
            # TODO: define qconfig_mapping specifically for executorch
            qconfig_mapping = get_default_qconfig_mapping("qnnpack")
            qconfig_mapping = qconfig_mapping.set_object_type(
                torch.nn.Embedding, qconfig
            )
            m = prepare_fx(
                m,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m(*example_inputs)
            m = _convert_to_reference_decomposed_fx(m)
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
                _use_edge_ops=True,
            )
            m = to_edge(export(m, example_inputs), compile_config=compile_config)
            # QuantFusionPass should be part of the to_executorch() config; it is
            # separated out here so that we can check the graph.
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
            # check that we are using the functional variants of
            # quantize_per_channel and embedding_byte
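            # Rough shape of the fused graph, per the checks below: the
            # embedding weight is quantized per channel, and the fused
            # embedding_byte op looks up rows directly in the quantized weight:
            #   quantize_per_channel(weight) -> quantized_decomposed.embedding_byte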
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
            ).run(
                m.exported_program().graph_module.code
            )

            # TODO: enable after the out variant of quantize_per_channel is supported
            # m = m.to_executorch()
            # FileCheck().check(
            #     "torch.ops.quantized_decomposed.quantize_per_channel.out",
            # ).check(
            #     "torch.ops.quantized_decomposed.embedding_byte.out"
            # ).run(
            #     m.dump_graph_module().code
            # )

    def test_embedding_byte_functional(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand((3, 2))

            def forward(self, indices):
                return F.embedding(indices, self.weight)

        for qconfig in [float_qparams_weight_only_qconfig]:
            m = M().eval()
            indices = torch.tensor([0])
            example_inputs = (indices,)

            qconfig_mapping = QConfigMapping().set_object_type(
                F.embedding,
                qconfig,
            )

            m = prepare_fx(
                m,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m(*example_inputs)
            m = _convert_to_reference_decomposed_fx(m)
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
                _use_edge_ops=True,
            )
            m = to_edge(export(m, example_inputs), compile_config=compile_config)
            # QuantFusionPass should be part of the to_executorch() config; it is
            # separated out here so that we can check the graph.
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
            # check that we are using the functional variants of
            # quantize_per_channel and embedding_byte (same pattern as above)
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
            ).run(
                m.exported_program().graph_module.code
            )

            # TODO: enable after the out variant of quantize_per_channel is supported
            # m = m.to_executorch()
            # FileCheck().check(
            #     "torch.ops.quantized_decomposed.quantize_per_channel.out",
            # ).check(
            #     "torch.ops.quantized_decomposed.embedding_byte.out"
            # ).run(
            #     m.dump_graph_module().code
            # )
372