# Owner(s): ["oncall: quantization"]

from collections import OrderedDict
import contextlib
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.reference as nnqr
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.multiprocessing as mp
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY

# graph mode quantization based on fx
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    convert_fx,
    convert_to_reference_fx,
    _convert_to_reference_decomposed_fx,
    prepare_qat_fx,
    fuse_fx,
)


from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler

from torch.ao.quantization.fx.match_utils import (
    _is_match,
    MatchAllNode,
)

from torch.ao.quantization import (
    QuantType,
)
from torch.ao.quantization.quant_type import _get_quant_type_to_str

from torch.ao.quantization import (
    QuantStub,
    DeQuantStub,
    QuantWrapper,
    default_qconfig,
    default_dynamic_qconfig,
    default_per_channel_qconfig,
    default_qat_qconfig,
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float16_static_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    get_default_qconfig,
    get_default_qat_qconfig,
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    fuse_modules,
    fuse_modules_qat,
    prepare,
    prepare_qat,
    convert,
    quantize_dynamic,
    default_placeholder_observer,
    default_weight_observer,
    PerChannelMinMaxObserver,
    FixedQParamsFakeQuantize,
    FixedQParamsObserver,
    FusedMovingAvgObsFakeQuantize,
    FakeQuantize,
    MovingAverageMinMaxObserver,
    HistogramObserver,
    ReuseInputObserver,
    QConfig,
    default_embedding_qat_qconfig,
)

from torch.ao.quantization.backend_config import (
    get_fbgemm_backend_config,
    get_qnnpack_backend_config,
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType
)
from torch.ao.quantization.backend_config.native import (
    get_test_only_legacy_native_backend_config,
)

from torch.ao.quantization.qconfig_mapping import (
    _get_symmetric_qnnpack_qconfig_mapping,
    _get_symmetric_qnnpack_qat_qconfig_mapping,
    _GLOBAL_DICT_KEY,
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    _OBJECT_TYPE_DICT_KEY,
    QConfigMapping,
)

from torch.ao.quantization.fx.qconfig_mapping_utils import (
    _get_object_type_qconfig,
    _get_module_name_qconfig,
    _get_module_name_regex_qconfig,
    _maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.pattern_utils import (
    _DEFAULT_FUSION_PATTERNS,
    _DEFAULT_QUANTIZATION_PATTERNS,
    _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP,
    _DEFAULT_OUTPUT_OBSERVER_MAP,
    _register_fusion_pattern,
    _register_quant_pattern,
    get_default_output_activation_post_process_map
)

from torch.ao.quantization.fx.custom_config import (
    STANDALONE_MODULE_NAME_DICT_KEY,
    STANDALONE_MODULE_CLASS_DICT_KEY,
    FLOAT_TO_OBSERVED_DICT_KEY,
    OBSERVED_TO_QUANTIZED_DICT_KEY,
    NON_TRACEABLE_MODULE_NAME_DICT_KEY,
    NON_TRACEABLE_MODULE_CLASS_DICT_KEY,
    INPUT_QUANTIZED_INDEXES_DICT_KEY,
    OUTPUT_QUANTIZED_INDEXES_DICT_KEY,
    PRESERVED_ATTRIBUTES_DICT_KEY,
    FuseCustomConfig,
    ConvertCustomConfig,
    PrepareCustomConfig,
    StandaloneModuleConfigEntry,
)
import torch.ao.quantization.fx.lstm_utils

from torch.ao.quantization.fx.utils import (
    _reroute_tuple_getitem_pattern,
    NodeInfo,
)

from torch.ao.quantization.fake_quantize import (
    default_fixed_qparams_range_0to1_fake_quant,
    default_fixed_qparams_range_neg1to1_fake_quant,
)

from torch.ao.quantization.observer import (
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    MinMaxObserver,
    _is_activation_post_process,
)

# test utils
from hypothesis import given, settings
from hypothesis import strategies as st
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_quantization import (
    LinearReluLinearModel,
    LinearReluModel,
    LinearBnLeakyReluModel,
    LinearTanhModel,
    ConvBnAddReluModel,
    QuantizationTestCase,
    skipIfNoFBGEMM,
    skipIfNoQNNPACK,
    skip_if_no_torchvision,
    train_one_epoch,
    run_ddp,
    test_only_eval_fn,
    test_only_train_fn,
    ModelForConvTransposeBNFusion,
    get_supported_device_types,
    skipIfNoONEDNN,
)

from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
)

from torch.testing._internal.common_quantized import (
    supported_qengines,
    override_qengines,
    override_quantized_engine,
)

from torch.testing._internal.common_utils import (
    TemporaryFileName,
    IS_ARM64,
    skipIfTorchDynamo,
)

from torch.testing._internal.common_quantization import NodeSpec as ns

from torch.testing import FileCheck

import copy
import itertools
import operator
import unittest
import io
from typing import Callable, Optional, List, Tuple

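# Typical FX graph mode quantization flow exercised throughout these tests
# (a sketch; individual tests vary qconfigs, backends, and example inputs):
#
#   m = prepare_fx(model.eval(), qconfig_mapping, example_inputs)  # insert observers
#   m(*example_inputs)                                             # calibrate
#   m = convert_fx(m)                                              # lower to quantized ops
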
class BinaryOp(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
        self.is_scalar = is_scalar
        self.op = ibinary_op if ibinary_op and is_inplace else binary_op

    def forward(self, x, y):
        x = self.conv1(x)
        y = 3 if self.is_scalar else self.conv2(y)
        # x = x + y
        x = self.op(x, y)
        # x = y + x
        x = self.op(y, x)
        return x

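# For example, BinaryOp(torch.add, torch.Tensor.add_, is_inplace=False, is_scalar=True)
# models x = conv1(x); x = op(x, 3); x = op(3, x). (Illustrative only; the tests
# supply the actual op pairs they exercise.)
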
class BinaryOpNonQuantizedInput(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.is_scalar = is_scalar
        self.op = ibinary_op if ibinary_op and is_inplace else binary_op

    def forward(self, x, y):
        y = 3 if self.is_scalar else y
        x = self.op(x, y)
        return x

class BinaryOpRelu(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, relu_callable,
                 is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
        self.op = ibinary_op if ibinary_op and is_inplace else binary_op
        self.relu_callable = relu_callable
        self.is_scalar = is_scalar
        if relu_callable is torch.nn.ReLU:
            self.relu = torch.nn.ReLU()
        else:
            self.relu = relu_callable

    def forward(self, x, y):
        x = self.conv1(x)
        y = 3 if self.is_scalar else self.conv2(y)
        x = self.op(x, y)
        x = self.relu(x)
        x = self.op(y, x)
        x = self.relu(x)
        return x

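# torch.fx.wrap makes symbolic tracing keep this helper as a single
# call_function node rather than tracing into its body, so the list() call on
# the outputs of torch.split does not break tracing.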
@torch.fx.wrap
def _user_func_with_complex_return_type(x):
    return list(torch.split(x, 1, 1))

class TestFuseFx(QuantizationTestCase):
    def test_fuse_conv_bn_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.conv1d2 = nn.Conv1d(1, 1, 1)
                self.conv2d2 = nn.Conv2d(1, 1, 1)
                self.conv3d2 = nn.Conv3d(1, 1, 1)
                self.bn1d2 = nn.BatchNorm1d(1)
                self.bn2d2 = nn.BatchNorm2d(1)
                self.bn3d2 = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv1d(x)
                x = self.bn1d(x)
                x = self.conv2d(x)
                x = self.bn2d(x)
                x = self.conv3d(x)
                x = self.bn3d(x)
                x = self.conv1d2(x)
                x = self.bn1d2(x)
                x = self.relu(x)
                x = self.conv2d2(x)
                x = self.bn2d2(x)
                x = self.relu(x)
                x = self.conv3d2(x)
                x = self.bn3d2(x)
                x = self.relu(x)
                return x

        # test train mode
        m = M().train()
        # currently we don't check if the modules are configured with a qconfig before fusion
        # TODO: if we decide to do that in the future, this test needs to
        # be updated
        # train mode fuse_fx is called in prepare_qat_fx
        m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),))
        expected_nodes = [
            ns.call_module(nni.ConvBn1d),
            ns.call_module(nni.ConvBn2d),
            ns.call_module(nni.ConvBn3d),
            ns.call_module(nni.ConvBnReLU1d),
            ns.call_module(nni.ConvBnReLU2d),
            ns.call_module(nni.ConvBnReLU3d),
        ]
        expected_occurrence = {
            ns.call_module(nn.ReLU): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

        # test eval mode
        m = M().eval()
        # fuse_fx is a top level api and only supports eval mode
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nn.Conv1d),
            ns.call_module(nn.Conv2d),
            ns.call_module(nn.Conv3d),
            ns.call_module(nni.ConvReLU1d),
            ns.call_module(nni.ConvReLU2d),
            ns.call_module(nni.ConvReLU3d),
        ]
        # in eval mode, bn is folded into the preceding conv, so no ConvBn* fused modules appear
        expected_occurrence = {
            ns.call_module(nn.ReLU): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

    def test_fuse_linear_bn_eval(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(1, 1)
                self.bn1d = nn.BatchNorm1d(1)

            def forward(self, x):
                x = self.linear(x)
                x = self.bn1d(x)
                return x

        # test eval mode
        m = M().eval()
        # fuse_fx is a top level api and only supports eval mode
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nn.Linear),
        ]
        expected_occurrence = {
            ns.call_module(nn.BatchNorm1d): 0,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

    @skipIfNoONEDNN
    def test_fuse_linear_bn_leaky_relu_onednn(self):
        # linear - bn - leaky_relu is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        expected_nodes = [
            ns.call_module(nni.LinearLeakyReLU),
        ]
        expected_occurrence = {
            ns.call_module(nn.BatchNorm1d): 0,
            ns.call_module(nn.LeakyReLU): 0,
        }

        for with_bn in [True, False]:
            # test eval mode
            m = LinearBnLeakyReluModel(with_bn).eval()
            # fuse_fx is a top level api and only supports eval mode
            m = fuse_fx(m,
                        backend_config=get_onednn_backend_config())
            self.checkGraphModuleNodes(
                m,
                expected_node_list=expected_nodes,
                expected_node_occurrence=expected_occurrence)

    def test_linear_bn_leaky_relu_not_fused_by_default(self):
        # Make sure linear - bn - leaky_relu is not fused by default
        for with_bn in [True, False]:
            # test eval mode
            m = LinearBnLeakyReluModel(with_bn).eval()
            # fuse_fx is a top level api and only supports eval mode
            m = fuse_fx(m)
            expected_nodes = [
                ns.call_module(nn.Linear),
                ns.call_module(nn.LeakyReLU),
            ]
            expected_occurrence = {
                ns.call_module(nni.LinearLeakyReLU): 0,
            }
            self.checkGraphModuleNodes(
                m,
                expected_node_list=expected_nodes,
                expected_node_occurrence=expected_occurrence)

    @skipIfNoONEDNN
    def test_fuse_linear_tanh_for_onednn_backend(self):
        # linear - tanh is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        expected_nodes = [
            ns.call_module(nni.LinearTanh),
        ]
        expected_occurrence = {
            ns.call_module(nn.Linear): 0,
            ns.call_module(nn.Tanh): 0,
        }

        # test eval mode
        m = LinearTanhModel().eval()
        # fuse_fx is a top level api and only supports eval mode
        m = fuse_fx(m,
                    backend_config=get_onednn_backend_config())
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

    def test_linear_tanh_not_fused_by_default(self):
        # Make sure linear - tanh is not fused by default
        # test eval mode
        m = LinearTanhModel().eval()
        # fuse_fx is a top level api and only supports eval mode
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nn.Linear),
            ns.call_module(nn.Tanh),
        ]
        expected_occurrence = {
            ns.call_module(nni.LinearTanh): 0,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

    def test_fuse_conv_bn_add_relu_onednn(self):
        # conv - bn - add - relu is fused for onednn backend only
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        options = itertools.product(
            [True, False],  # with_bn
            [True, False],  # with_relu
            [True, False],  # conv in the left
            [True, False],  # with_two_conv
            [True, False],  # use_torch_add
        )
        for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
            expected_nodes = [
                ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d),
            ]
            expected_occurrence = {
                ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d): 1,
                ns.call_module(nn.BatchNorm2d): 0,
            }

            # test eval mode
            m = ConvBnAddReluModel(
                with_bn=with_bn,
                with_relu=with_relu,
                left_conv=left_conv,
                two_conv=two_conv,
                use_torch_add=use_torch_add).eval()

            m = fuse_fx(m,
                        backend_config=get_onednn_backend_config())
            self.checkGraphModuleNodes(
                m,
                expected_node_list=expected_nodes,
                expected_node_occurrence=expected_occurrence)

    def test_fuse_conv_bn_add_relu_by_default(self):
        options = itertools.product(
            [True, False],  # with_bn
            [True, False],  # with_relu
            [True, False],  # conv in the left
            [True, False],  # with_two_conv
            [True, False],  # use_torch_add
        )
        for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
            # test eval mode
            expected_nodes = [
                ns.call_module(nn.Conv2d),
            ]
            expected_occurrence = {
                ns.call_module(nni.ConvAdd2d): 0,
            }
            m = ConvBnAddReluModel(
                with_bn=with_bn,
                with_relu=with_relu,
                left_conv=left_conv,
                two_conv=two_conv,
                use_torch_add=use_torch_add).eval()
            m = fuse_fx(m)
            self.checkGraphModuleNodes(
                m,
                expected_node_list=expected_nodes,
                expected_node_occurrence=expected_occurrence)

    @skipIfNoONEDNN
    def test_fuse_conv_bn_add_relu_lowering(self):
523        """ Test fusion and lowering of Conv2d - (bn -) ReLU
524            by FX. For onednn backedn only.
525        """
        from torch.ao.quantization.backend_config import get_onednn_backend_config
        qconfig_mapping = get_default_qconfig_mapping('onednn')
        with override_quantized_engine('onednn'):
            options = itertools.product(
                [True, False],  # with_bn
                [True, False],  # with_relu
                [True, False],  # conv in the left
                [True, False],  # two_conv
                [True, False],  # use_torch_add
            )
            for with_bn, with_relu, left_conv, two_conv, use_torch_add in options:
                node_occurrence = {
                    ns.call_function(torch.quantize_per_tensor): 1 if two_conv else 2,
                    ns.call_method("dequantize"): 1,
                    ns.call_module(nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d): 1,
                    ns.call_module(nn.Conv2d): 0,
                    ns.call_module(nn.ReLU): 0,
                }
                node_occurrence_ref = {
                    ns.call_function(torch.quantize_per_tensor): 3,
                    ns.call_method("dequantize"): 3,
                }

                # test eval mode
                m = ConvBnAddReluModel(
                    with_bn=with_bn,
                    with_relu=with_relu,
                    left_conv=left_conv,
                    two_conv=two_conv,
                    use_torch_add=use_torch_add).eval()
                example_x = m.get_example_inputs()
                m = prepare_fx(m, qconfig_mapping,
                               example_inputs=example_x,
                               backend_config=get_onednn_backend_config())
                m_copy = copy.deepcopy(m)
                m = convert_fx(m, backend_config=get_onednn_backend_config())
                m_ref = convert_to_reference_fx(m_copy)
                self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
                self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
                m(*example_x)

    def test_fuse_convtranspose_bn_eval(self):

        m = ModelForConvTransposeBNFusion().eval()
        m = fuse_fx(m)

        expected_nodes = [
            ns.call_module(nn.ConvTranspose1d),
            ns.call_module(nn.ConvTranspose2d),
            ns.call_module(nn.ConvTranspose3d),
        ]
        expected_occurrence = {
            ns.call_module(nn.BatchNorm1d): 0,
            ns.call_module(nn.BatchNorm2d): 0,
            ns.call_module(nn.BatchNorm3d): 0,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)


    def test_fuse_module_relu(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv1d(x)
                x = self.relu(x)
                x = self.conv2d(x)
                x = self.relu(x)
                x = self.conv3d(x)
                x = self.relu(x)
                x = self.bn1d(x)
                x = self.relu(x)
                x = self.bn2d(x)
                x = self.relu(x)
                x = self.bn3d(x)
                x = self.relu(x)
                return x

        m = M().eval()
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nni.ConvReLU1d),
            ns.call_module(nni.ConvReLU2d),
            ns.call_module(nni.ConvReLU3d),
            ns.call_module(nni.BNReLU2d),
            ns.call_module(nni.BNReLU3d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=expected_nodes)

    @skipIfNoFBGEMM
    def test_qconfig_fused_module(self):
        """ TODO: add test for all fused modules
        """
        qconfig_dict = {
            "": None,
            "object_type": [(nn.Linear, default_qconfig),
                            (nn.ReLU, default_qconfig),
                            (F.relu, default_qconfig)]
        }

        linearRelu_node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        tests = [(LinearReluModel, linearRelu_node_list),
                 (LinearReluLinearModel, linearReluLinear_node_list)]

        for M, node_list in tests:
            m = M().eval()
            example_inputs = (torch.rand(5, 5),)
            prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)

            prepared(*example_inputs)
            quantized = convert_fx(prepared)

            self.checkGraphModuleNodes(quantized, expected_node_list=node_list)

    def test_problematic_fuse_example(self):
        class LinearRelu(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(
                    nn.Linear(5, 5),
                    nn.ReLU(),
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin_relu = LinearRelu()
                self.linear = nn.Linear(5, 5)

            def forward(self, x):
                x = self.lin_relu(x)
                x = self.linear(x)
                return x

        model = M().eval()
        # these qconfigs somehow fail equality where default_qconfig does not
        qconfig_dict = {
            "": None,
            "object_type": [
                (torch.nn.Linear, get_default_qconfig('fbgemm')),
                (torch.nn.ReLU, get_default_qconfig('fbgemm')),
            ],
        }
        m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))

        self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.ao.nn.intrinsic.modules.fused.LinearReLU))

    @unittest.skip("Temporarily skipping the test case, will enable after the simple "
                   "pattern format is supported")
    def test_fuse_additional_fuser_method(self):
        class MyConvReLU(torch.nn.Module):
            pass

        def my_conv_relu_fuser(conv, relu):
            return MyConvReLU()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.conv(x))

        m = M().eval()
        m = fuse_fx(m, fuse_custom_config={
            "additional_fuser_method_mapping": {
                (torch.nn.Conv2d, torch.nn.ReLU): my_conv_relu_fuser
            }
        })
        self.checkGraphModuleNodes(m, expected_node=ns.call_module(MyConvReLU))

    def test_fuse_custom_pattern(self):
        class M(torch.nn.Module):
            def __init__(self, use_torch_add=True):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3)
                if use_torch_add:
                    self.add = torch.add
                else:
                    self.add = operator.add

            def forward(self, x):
                y = x
                y = self.maxpool(x)
                x = self.conv(x)
                x = self.bn(x)
                x = self.add(y, x)
                x = self.relu(x)
                return x

        for use_torch_add in [True, False]:
            m = M(use_torch_add).eval()

            def fuse_conv_bn_relu(is_qat, relu, add_pattern):
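                # add_pattern mirrors the reversed pattern format below:
                # (torch.add, <node matched by MatchAllNode>, (bn, conv));
                # the fuser keeps only the conv.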
                _, _, bn_pattern = add_pattern
                bn, conv = bn_pattern
                return conv

            conv_bn_res_relu_config1 = BackendPatternConfig() \
                ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
                .set_fuser_method(fuse_conv_bn_relu)
            conv_bn_res_relu_config2 = BackendPatternConfig() \
                ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \
                .set_fuser_method(fuse_conv_bn_relu)
            backend_config = BackendConfig() \
                .set_backend_pattern_config(conv_bn_res_relu_config1) \
                .set_backend_pattern_config(conv_bn_res_relu_config2)
            m = fuse_fx(m, backend_config=backend_config)
            self.assertEqual(type(m.conv), torch.nn.Conv2d)
            # check bn and relu are gone since we replaced the whole pattern with conv
            self.assertFalse(hasattr(m, "bn"))
            self.assertFalse(hasattr(m, "relu"))

    def test_fusion_pattern_with_multiple_inputs(self):
        """ This test exercises two keys in backend_config: root_node_getter and
        extra_inputs_getter.
        root_node_getter is used to identify the "root" node in the pattern,
        i.e. the node we'll keep after fusion.
        extra_inputs_getter returns a list of nodes that need to be added to the
        fused node as extra inputs.
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3)

            def forward(self, x):
                y = x
                y = self.maxpool(x)
                x = self.conv(x)
                x = self.bn(x)
                x = torch.add(x, y)
                x = self.relu(x)
                return x

        m = M().eval()

        def fuse_conv_bn_relu(is_qat, relu, add_pattern):
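            # add_pattern mirrors (torch.add, (bn, conv), MatchAllNode); keep only the conv.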
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv

        def conv_bn_res_relu_root_node_getter(pattern):
            relu, add_pattern = pattern
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv

        def conv_bn_res_relu_extra_inputs_getter(pattern):
805            """ get inputs pattern for extra inputs, inputs for root node
806            are assumed to be copied over from root node to the fused node
807            """
            relu, add_pattern = pattern
            _, bn_pattern, extra_input = add_pattern
            bn, conv = bn_pattern
            return [extra_input]

        conv_bn_res_relu_config = BackendPatternConfig() \
            ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \
            .set_fuser_method(fuse_conv_bn_relu) \
            ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \
            ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
        backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config)
        m = fuse_fx(m, backend_config=backend_config)
        self.assertEqual(type(m.conv), torch.nn.Conv2d)
        # check bn and relu are gone since we replaced the whole pattern with conv
        self.assertFalse(hasattr(m, "bn"))
        self.assertFalse(hasattr(m, "relu"))

        # check conv module has two inputs
        named_modules = dict(m.named_modules())
        for node in m.graph.nodes:
            if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
                self.assertTrue(len(node.args) == 2, "Expecting the fused op to have two arguments")

    def test_fusion_pattern_with_matchallnode(self):
832        """This test tests that the node matched by MatchAllNode will be regared as an input
833        instead of a module to be fused. For instance, we have two patterns:
834            (nn.ReLU, (torch.add, MatchAllNode, nn.Conv2d))
835            (nn.ReLU, nn.Conv2d)
836        And we wanna fuse the following model
837            Conv2d -> ReLU +
838            Conv2d ------ Add -> ReLU
839        ReLU in the first row is matched as MatchAllNode in the residual pattern. But it won't be
840        fused as part of that pattnern. It needs to be properly fused with the upstream Conv2d.
841        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.relu1 = torch.nn.ReLU()
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu2 = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu1(y)

                x = self.conv2(x)
                x = torch.add(x, y)
                x = self.relu2(x)
                return x

        m = M().eval()

        def fuse_conv_relu(is_qat, conv, relu):
            return conv

        def fuse_conv_res_relu(is_qat, relu, add_pattern):
            _, conv, _ = add_pattern
            return conv

        def conv_res_relu_root_node_getter(pattern):
            relu, (_, conv, _) = pattern
            return conv

        def conv_res_relu_extra_inputs_getter(pattern):
            relu, (_, _, extra_input) = pattern
            return [extra_input]

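        # (nn.Conv2d, nn.ReLU) uses the simple pattern format (ops listed in
        # forward order); the residual pattern below uses the complex,
        # outermost-op-first format.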
        conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \
            .set_fuser_method(fuse_conv_relu)
        conv_res_relu_config = BackendPatternConfig() \
            ._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \
            .set_fuser_method(fuse_conv_res_relu) \
            ._set_root_node_getter(conv_res_relu_root_node_getter) \
            ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter)
        backend_config = BackendConfig() \
            .set_backend_pattern_config(conv_relu_config) \
            .set_backend_pattern_config(conv_res_relu_config)
        m = fuse_fx(m, backend_config=backend_config)
        self.assertEqual(type(m.conv1), torch.nn.Conv2d)
        self.assertEqual(type(m.conv2), torch.nn.Conv2d)
        # check relus are gone since we replaced both patterns with conv
        self.assertFalse(hasattr(m, "relu1"))
        self.assertFalse(hasattr(m, "relu2"))


@skipIfNoFBGEMM
class TestQuantizeFx(QuantizationTestCase):
    def test_pattern_match(self):
        """ test MatchAllNode with
            conv - bn - add - relu pattern
        """
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x, y):
                x = self.conv(x)
                x = self.bn(x)
                x = x + y
                x = self.relu(x)
                return x

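        # Patterns are written outermost-op-first: this one matches
        # relu(add(bn(conv(x)), y)), with MatchAllNode standing in for y.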
        pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
        m = torch.fx.symbolic_trace(M())
        modules = dict(m.named_modules())
        for n in m.graph.nodes:
            if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
                self.assertTrue(_is_match(modules, n, pattern))

    def test_pattern_match_constant(self):
        class M(torch.nn.Module):
            def forward(self, x):
                x, _ = torch.ops.aten.max_pool2d_with_indices.default(x)
                return x

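        # Matches getitem(max_pool2d_with_indices(x), 0), i.e. selecting the
        # first (values) output of the op.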
        pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0)
        m = torch.fx.symbolic_trace(M())
        # eliminate the code that gets the second output of maxpool, so that the pattern
        # can be matched
        m.graph.eliminate_dead_code()
        modules = dict(m.named_modules())
        for n in m.graph.nodes:
            if n.op == "call_function" and n.target == operator.getitem:
                self.assertTrue(_is_match(modules, n, pattern))

    def test_fused_module_qat_swap(self):
        class Tmp(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tmp = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.tmp(x)
                return self.relu(x)


        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(Tmp(), torch.nn.Linear(5, 5))
                self.mods2 = torch.nn.Linear(5, 5)

            def forward(self, x):
                a = self.mods1(x)
                x = torch.add(x, 5)
                x = self.mods2(x)
                x = torch.add(x, 5)
                return a, x


        model = M().train()
        qconfig_dict = {
            "": None,
            "object_type": [
                (torch.nn.Linear, default_qat_qconfig),
                (torch.nn.ReLU, default_qat_qconfig),
            ],
        }
        prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),))
        self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU))

    def _get_conv_linear_test_cases(self, is_reference):
976        """ Returns a list of test cases, with format:
977        is_dynamic, ModuleClass, module_constructor_inputs,
978        inputs, quantized_node, weight_prepack_op
979        """
        class FunctionalConv1d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = 1
                self.padding = 0
                self.dilation = 1
                self.groups = 1

            def forward(self, x):
                return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)


        class Conv1d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv1d(*args)

            def forward(self, x):
                return self.conv(x)

        conv1d_input = torch.rand(1, 3, 224)
        conv1d_weight = torch.rand(3, 3, 3)
        conv1d_module_args = (3, 3, 3)

        class FunctionalConv2d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

        class Conv2d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv2d(*args)

            def forward(self, x):
                return self.conv(x)

        conv2d_input = torch.rand(1, 3, 224, 224)
        conv2d_weight = torch.rand(3, 3, 3, 3)
        conv2d_module_args = (3, 3, 3)

        class FunctionalConv3d(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1, 1)
                self.padding = (0, 0, 0)
                self.dilation = (1, 1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv3d(
                    x,
                    self.weight,
                    None,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class Conv3d(torch.nn.Module):
            def __init__(self, *args):
                super().__init__()
                self.conv = torch.nn.Conv3d(*args)

            def forward(self, x):
                return self.conv(x)

        conv3d_input = torch.rand(1, 3, 32, 224, 224)
        conv3d_weight = torch.rand(3, 3, 3, 3, 3)
        conv3d_module_args = (3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        # is_dynamic, ModuleClass, module_constructor_inputs,
        # inputs, quantized_node, weight_prepack_node
        tests = [
            (
                False,
                FunctionalConv1d,
                (conv1d_weight,),
                (conv1d_input,),
                ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d),
                ns.call_function(torch.ops.quantized.conv1d_prepack),
            ),
            (
                False,
                FunctionalConv2d,
                (conv2d_weight,),
                (conv2d_input,),
                ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d),
                ns.call_function(torch.ops.quantized.conv2d_prepack),
            ),
            (
                False,
                FunctionalConv3d,
                (conv3d_weight,),
                (conv3d_input,),
                ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d),
                ns.call_function(torch.ops.quantized.conv3d_prepack),
            ),
            (
                False,
                Conv1d,
                conv1d_module_args,
                (conv1d_input,),
                ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d),
                None
            ),
            (
                False,
                Conv2d,
                conv2d_module_args,
                (conv2d_input,),
                ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d),
                None
            ),
            (
                False,
                Conv3d,
                conv3d_module_args,
                (conv3d_input,),
                ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d),
                None
            ),
            (
                True,
                Linear,
                (linear_weight,),
                (linear_input,),
                None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            (
                False,
                Linear,
                (linear_weight,),
                (linear_input,),
                ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear),
                ns.call_function(torch.ops.quantized.linear_prepack),
            ),
            (
                True,
                LinearModule,
                (),
                (linear_module_input,),
                ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear),
                None,
            ),
            (
                False,
                LinearModule,
                (),
                (linear_module_input,),
                ns.call_module(nnqr.Linear if is_reference else nnq.Linear),
                None,
            ),
        ]
        return tests

    @skipIfNoFBGEMM
    def test_conv_linear_not_reference(self):
        """ Test quantizing conv and linear
        """
        tests = self._get_conv_linear_test_cases(is_reference=False)
        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = {}
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=False)

    @skipIfNoFBGEMM
    def test_conv_linear_reference(self):
        """ Test quantizing functional conv and linear with reference option
        """
        tests = self._get_conv_linear_test_cases(is_reference=True)

        def _get_keys(prefix, is_dynamic):
            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
            if not is_dynamic:
                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
            return all_keys

        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = {}
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            result_dict = self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                is_reference=True)
            qr = result_dict["quantized_reference"]

            def checkWeightQParams(model):
                for module_name in ("linear", "conv"):
                    if hasattr(model, module_name):
                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme"))
                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale"))
                        self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point"))
                        self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name())

            def checkSerDeser(model, is_dynamic):
                for module_name in ("linear", "conv"):
                    if hasattr(model, module_name):
                        # make sure serialization works
                        state_dict = copy.deepcopy(model.state_dict())
                        all_keys = _get_keys(module_name, is_dynamic)
                        for key in all_keys:
                            self.assertTrue(key in state_dict)
                        # check load_state_dict restores states
                        module = getattr(model, module_name)
                        prev_scale = module.weight_scale
                        module.weight_scale = None
                        model.load_state_dict(state_dict)
                        module = getattr(model, module_name)
                        self.assertTrue(torch.equal(prev_scale, module.weight_scale))


            checkWeightQParams(qr)
            qr = copy.deepcopy(qr)
            # make sure the qparams are preserved after copy
            checkWeightQParams(qr)

            checkSerDeser(qr, is_dynamic)

    def _get_conv_transpose_test_cases(self, use_relu, is_reference):
1245        """ Returns a list of test cases, with format:
1246        is_dynamic, ModuleClass, module_constructor_inputs,
1247        inputs, quantized_node, weight_prepack_op
1248        """
1249        class FunctionalConvTranspose1d(torch.nn.Module):
1250            def __init__(self, weight):
1251                super().__init__()
1252                self.weight = torch.nn.Parameter(weight)
1253                self.stride = 1
1254                self.padding = 0
1255                self.output_padding = 0
1256                self.dilation = 1
1257                self.groups = 1
1258
1259            def forward(self, x):
1260                y = F.conv_transpose1d(
1261                    x,
1262                    self.weight,
1263                    None,
1264                    self.stride,
1265                    self.padding,
1266                    self.output_padding,
1267                    self.groups,
1268                    self.dilation
1269                )
1270                if use_relu:
1271                    y = F.relu(y)
1272                return y
1273
1274        class ConvTranspose1d(torch.nn.Module):
1275            def __init__(self, *args):
1276                super().__init__()
1277                self.deconv = torch.nn.ConvTranspose1d(*args)
1278                self.relu = torch.nn.ReLU()
1279
1280            def forward(self, x):
1281                y = self.deconv(x)
1282                if use_relu:
1283                    y = self.relu(y)
1284                return y
1285
1286        conv_transpose1d_input = torch.rand(1, 3, 224)
1287        conv_transpose1d_weight = torch.rand(3, 3, 3)
1288        conv_transpose1d_module_args = (3, 3, 3)
1289
1290        class FunctionalConvTranspose2d(torch.nn.Module):
1291            def __init__(self, weight):
1292                super().__init__()
1293                self.weight = torch.nn.Parameter(weight)
1294                self.stride = (1, 1)
1295                self.padding = (0, 0)
1296                self.output_padding = (0, 0)
1297                self.dilation = (1, 1)
1298                self.groups = 1
1299
1300            def forward(self, x):
1301                y = F.conv_transpose2d(
1302                    x,
1303                    self.weight,
1304                    None,
1305                    self.stride,
1306                    self.padding,
1307                    self.output_padding,
1308                    self.groups,
1309                    self.dilation
1310                )
1311                if use_relu:
1312                    y = F.relu(y)
1313                return y
1314
1315        class ConvTranspose2d(torch.nn.Module):
1316            def __init__(self, *args):
1317                super().__init__()
1318                self.deconv = torch.nn.ConvTranspose2d(*args)
1319                self.relu = torch.nn.ReLU()
1320
1321            def forward(self, x):
1322                y = self.deconv(x)
1323                if use_relu:
1324                    y = self.relu(y)
1325                return y
1326
1327        conv_transpose2d_input = torch.rand(1, 3, 224, 224)
1328        conv_transpose2d_weight = torch.rand(3, 3, 3, 3)
1329        conv_transpose2d_module_args = (3, 3, 3)
1330
1331        class FunctionalConvTranspose3d(torch.nn.Module):
1332            def __init__(self, weight):
1333                super().__init__()
1334                self.weight = torch.nn.Parameter(weight)
1335                self.stride = (1, 1, 1)
1336                self.padding = (0, 0, 0)
1337                self.output_padding = (0, 0, 0)
1338                self.dilation = (1, 1, 1)
1339                self.groups = 1
1340
1341            def forward(self, x):
1342                y = F.conv_transpose3d(
1343                    x,
1344                    self.weight,
1345                    None,
1346                    self.stride,
1347                    self.padding,
1348                    self.output_padding,
1349                    self.groups,
1350                    self.dilation
1351                )
1352                if use_relu:
1353                    y = F.relu(y)
1354                return y
1355
1356        class ConvTranspose3d(torch.nn.Module):
1357            def __init__(self, *args):
1358                super().__init__()
1359                self.deconv = torch.nn.ConvTranspose3d(*args)
1360                self.relu = torch.nn.ReLU()
1361
1362            def forward(self, x):
1363                y = self.deconv(x)
1364                if use_relu:
1365                    y = self.relu(y)
1366                return y
1367
1368        conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224)
1369        conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3)
1370        conv_transpose3d_module_args = (3, 3, 3)
1371
1372        # is_dynamic, ModuleClass, module_constructor_inputs,
1373        # inputs, quantized_node, weight_prepack_node
1374        tests = [
1375            (
1376                False,
1377                FunctionalConvTranspose1d,
1378                (conv_transpose1d_weight,),
1379                (conv_transpose1d_input,),
1380                ns.call_function(
1381                    torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d
1382                ),
1383                ns.call_function(torch.ops.quantized.conv_transpose1d_prepack),
1384            ),
1385            (
1386                False,
1387                FunctionalConvTranspose2d,
1388                (conv_transpose2d_weight,),
1389                (conv_transpose2d_input,),
1390                ns.call_function(
1391                    torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d
1392                ),
1393                ns.call_function(torch.ops.quantized.conv_transpose2d_prepack),
1394            ),
1395            (
1396                False,
1397                FunctionalConvTranspose3d,
1398                (conv_transpose3d_weight,),
1399                (conv_transpose3d_input,),
1400                ns.call_function(
1401                    torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d),
1402                ns.call_function(torch.ops.quantized.conv_transpose3d_prepack),
1403            ),
1404            (
1405                False,
1406                ConvTranspose1d,
1407                conv_transpose1d_module_args,
1408                (conv_transpose1d_input,),
1409                ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d),
1410                None
1411            ),
1412            (
1413                False,
1414                ConvTranspose2d,
1415                conv_transpose2d_module_args,
1416                (conv_transpose2d_input,),
1417                ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d),
1418                None
1419            ),
1420            (
1421                False,
1422                ConvTranspose3d,
1423                conv_transpose3d_module_args,
1424                (conv_transpose3d_input,),
1425                ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d),
1426                None
1427            ),
1428        ]
1429        return tests
1430
1431    @skipIfNoFBGEMM
1432    def test_conv_transpose_not_reference(self):
1433        """ Test quantizing transposed conv
1434        """
1435        tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False)
1436        for (is_dynamic, ModuleClass, module_constructor_inputs,
1437             inputs, quantized_node, weight_prepack_node) in tests:
1438            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
1439            node_occurrence = {}
1440            if weight_prepack_node:
1441                node_occurrence[weight_prepack_node] = 0
1442            self.checkGraphModeFxOp(
1443                ModuleClass(*module_constructor_inputs),
1444                inputs, quant_type,
1445                expected_node=quantized_node,
1446                expected_node_occurrence=node_occurrence,
1447                is_reference=False)
1448
1449    @skipIfNoFBGEMM
1450    def test_conv_transpose_reference(self):
1451        """ Test quantizing transposed conv with reference option
1452        """
1453        tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True)
1454
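            # reference modules serialize their weight qparams; dynamic quantization
            # omits weight_scale/weight_zero_point since qparams are computed at runtime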
1455        def _get_keys(prefix, is_dynamic):
1456            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
1457            if not is_dynamic:
1458                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
1459            return all_keys
1460
1461        for (is_dynamic, ModuleClass, module_constructor_inputs,
1462             inputs, quantized_node, weight_prepack_node) in tests:
1463            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
1464            node_occurrence = {}
1465            if weight_prepack_node:
1466                node_occurrence[weight_prepack_node] = 0
1467            result_dict = self.checkGraphModeFxOp(
1468                ModuleClass(*module_constructor_inputs),
1469                inputs, quant_type,
1470                expected_node=quantized_node,
1471                expected_node_occurrence=node_occurrence,
1472                is_reference=True)
1473            qr = result_dict["quantized_reference"]
1474
1475            def checkWeightQParams(model):
1476                module_name = "deconv"
1477                if hasattr(model, module_name):
1478                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_qscheme"))
1479                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_scale"))
1480                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_zero_point"))
1481                    self.assertTrue("Reference" in model.get_submodule(module_name)._get_name())
1482
1483            def checkSerDeser(model, is_dynamic):
1484                module_name = "deconv"
1485                if hasattr(model, module_name):
1486                    # make sure serialization works
1487                    state_dict = copy.deepcopy(model.state_dict())
1488                    all_keys = _get_keys(module_name, is_dynamic)
1489                    for key in all_keys:
1490                        self.assertTrue(key in state_dict)
1491                    # check load_state_dict restores states
1492                    module = getattr(model, module_name)
1493                    prev_scale = module.weight_scale
1494                    module.weight_scale = None
1495                    model.load_state_dict(state_dict)
1496                    module = getattr(model, module_name)
1497                    self.assertTrue(torch.equal(prev_scale, module.weight_scale))
1498
1500            checkWeightQParams(qr)
1501            qr = copy.deepcopy(qr)
1502            # make sure the qparams are preserved after copy
1503            checkWeightQParams(qr)
1504
1505            checkSerDeser(qr, is_dynamic)
1506
1507    def test_conv_transpose_relu_not_reference(self):
1508        """ Test quantizing transposed conv + relu
1509            Fusion with relu is not supported.
1510        """
1511        tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False)
1512        for (is_dynamic, ModuleClass, module_constructor_inputs,
1513             inputs, quantized_node, weight_prepack_node) in tests:
1514            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
1515            node_occurrence = {}
1516            if weight_prepack_node:
1517                node_occurrence[weight_prepack_node] = 0
1518            if quantized_node.op == 'call_module':
1519                node_occurrence[ns.call_module(nn.ReLU)] = 1
1520            else:
1521                node_occurrence[ns.call_function(F.relu)] = 1
1522            self.checkGraphModeFxOp(
1523                ModuleClass(*module_constructor_inputs),
1524                inputs, quant_type,
1525                expected_node=quantized_node,
1526                expected_node_occurrence=node_occurrence,
1527                is_reference=False)
1528
1529    @skipIfNoFBGEMM
1530    def test_conv_transpose_relu_reference(self):
1531        """ Test quantizing transposed conv + relu with reference option
1532            Fusion with relu is not supported.
1533        """
1534        tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True)
1535
1536        def _get_keys(prefix, is_dynamic):
1537            all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]]
1538            if not is_dynamic:
1539                all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]])
1540            return all_keys
1541
1542        for (is_dynamic, ModuleClass, module_constructor_inputs,
1543             inputs, quantized_node, weight_prepack_node) in tests:
1544            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
1545            node_occurrence = {}
1546            if weight_prepack_node:
1547                node_occurrence[weight_prepack_node] = 0
1548            if quantized_node.op == 'call_module':
1549                node_occurrence[ns.call_module(nn.ReLU)] = 1
1550            else:
1551                node_occurrence[ns.call_function(F.relu)] = 1
1552            result_dict = self.checkGraphModeFxOp(
1553                ModuleClass(*module_constructor_inputs),
1554                inputs, quant_type,
1555                expected_node=quantized_node,
1556                expected_node_occurrence=node_occurrence,
1557                is_reference=True)
1558            qr = result_dict["quantized_reference"]
1559
1560            def checkWeightQParams(model):
1561                module_name = "deconv"
1562                if hasattr(model, module_name):
1563                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_qscheme"))
1564                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_scale"))
1565                    self.assertTrue(hasattr(model.get_submodule(module_name), "weight_zero_point"))
1566                    self.assertTrue("Reference" in model.get_submodule(module_name)._get_name())
1567
1568            def checkSerDeser(model, is_dynamic):
1569                module_name = "deconv"
1570                if hasattr(model, module_name):
1571                    # make sure serialization works
1572                    state_dict = copy.deepcopy(model.state_dict())
1573                    all_keys = _get_keys(module_name, is_dynamic)
1574                    for key in all_keys:
1575                        self.assertTrue(key in state_dict)
1576                    # check load_state_dict restores states
1577                    module = getattr(model, module_name)
1578                    prev_scale = module.weight_scale
1579                    module.weight_scale = None
1580                    model.load_state_dict(state_dict)
1581                    module = getattr(model, module_name)
1582                    self.assertTrue(torch.equal(prev_scale, module.weight_scale))
1583
1585            checkWeightQParams(qr)
1586            qr = copy.deepcopy(qr)
1587            # make sure the qparams are preserved after copy
1588            checkWeightQParams(qr)
1589
1590            checkSerDeser(qr, is_dynamic)
1591
1592    @skipIfNoFBGEMM
1593    def test_dynamic_quant_weight_observer(self):
1594        """ Test that weight observer is run in convert step
1595        """
1596
1597        class M(torch.nn.Module):
1598            def __init__(self, weight):
1599                super().__init__()
1600                self.weight = torch.nn.Parameter(weight)
1601
1602            def forward(self, x):
1603                return F.linear(x, self.weight)
1604
1605        m = M(torch.rand(1, 1)).eval()
1606        qconfig = default_dynamic_qconfig
1607        qconfig_dict = {'': qconfig}
1608        example_inputs = (torch.rand(1, 1),)
1609        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1610        quantized = convert_to_reference_fx(prepared)
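            # convert runs the weight observer; the resulting qparams are attached
            # to the model as attributes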
1611        qparams = (quantized._scale_0, quantized._zero_point_0)
1612        weight_obs = qconfig.weight()
1613        weight_obs(quantized.weight)
1614        # Extract the scalar values to avoid a tensor size mismatch, torch.Size([]) vs torch.Size([1])
1615        ref_qparams = (weight_obs.calculate_qparams()[0].item(), weight_obs.calculate_qparams()[1].item())
1616        self.assertEqual(qparams, ref_qparams)
1617
1618    def test_conv_bn_relu(self):
1619        """ Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU"
1620        """
1621        convs = {
1622            1: nn.Conv1d,
1623            2: nn.Conv2d,
1624            3: nn.Conv3d,
1625        }
1626        bns = {
1627            1: nn.BatchNorm1d,
1628            2: nn.BatchNorm2d,
1629            3: nn.BatchNorm3d,
1630        }
1631        quantized_convs = {
1632            1: nnq.Conv1d,
1633            2: nnq.Conv2d,
1634            3: nnq.Conv3d,
1635        }
1636        quantized_conv_relus = {
1637            1: nniq.ConvReLU1d,
1638            2: nniq.ConvReLU2d,
1639            3: nniq.ConvReLU3d,
1640        }
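            # lookup tables keyed by spatial dim: fusion is expected to yield ConvReLUNd
            # when relu is present and a plain quantized ConvNd otherwise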
1641
1642        class M(torch.nn.Module):
1643            def __init__(self, dim, has_relu):
1644                super().__init__()
1645                self.conv = convs[dim](3, 3, 3)
1646                self.bn = bns[dim](3)
1647                self.relu = nn.ReLU() if has_relu else nn.Identity()
1648                self.has_relu = has_relu
1649                self.quant = QuantStub()
1650                self.dequant = DeQuantStub()
1651
1652            def forward(self, x):
1653                x = self.quant(x)
1654                x = self.conv(x)
1655                x = self.bn(x)
1656                if self.has_relu:
1657                    x = self.relu(x)
1658                x = self.dequant(x)
1659                return x
1660
1661        options = itertools.product([1, 2, 3], [True, False], self.static_quant_types)
1662        for dim, has_relu, quant_type in options:
1663            expected_node = ns.call_module(
1664                quantized_conv_relus[dim] if has_relu
1665                else quantized_convs[dim])
1666            m = M(dim, has_relu)
1667            m_eager = copy.deepcopy(m)
1668            result_dict = self.checkGraphModeFxOp(
1669                m,
1670                self.img_data_dict[dim],
1671                quant_type,
1672                expected_node=expected_node,
1673            )
1674            result = result_dict["quantized_output"]
1675
1676            # check numerics
1677            qengine = torch.backends.quantized.engine
1678            if quant_type == QuantType.STATIC:
1679                m_eager.eval()
1680                qconfig = get_default_qconfig(qengine)
1681                prepare_fn = prepare
1682                is_qat = False
1683            else:
1684                m_eager.train()
1685                qconfig = get_default_qat_qconfig(qengine)
1686                prepare_fn = prepare_qat
1687                is_qat = True
1688
1689            fuse_list = ["conv", "bn"]
1690            if has_relu:
1691                fuse_list.append("relu")
1692            if is_qat:
1693                fuse_modules_qat(m_eager, fuse_list, inplace=True)
1694            else:
1695                fuse_modules(m_eager, fuse_list, inplace=True)
1696            m_eager.qconfig = qconfig
1697            m_eager = prepare_fn(m_eager)
1698            prepared_fx = result_dict["prepared"]
1699
1700            m_eager(*self.img_data_dict[dim][0])
1701            m_eager = convert(m_eager)
1702            result_eager = m_eager(*self.img_data_dict[dim][0])
1703            self.assertEqual(result, result_eager)
1704
1705    def test_linear_bn(self):
1706        class M(torch.nn.Module):
1707            def __init__(self) -> None:
1708                super().__init__()
1709                self.linear = nn.Linear(4, 4)
1710                self.bn = nn.BatchNorm1d(4)
1711                self.quant = QuantStub()
1712                self.dequant = DeQuantStub()
1713
1714            def forward(self, x):
1715                x = self.quant(x)
1716                x = self.linear(x)
1717                x = self.bn(x)
1718                x = self.dequant(x)
1719                return x
1720
1721        data = (torch.randn(4, 4),)
1722        for quant_type in self.static_quant_types:
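                # linear + bn should fuse, so only a quantized Linear appears in the graph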
1723            expected_node = ns.call_module(nnq.Linear)
1724            m = M()
1725            m_eager = copy.deepcopy(m)
1726            result_dict = self.checkGraphModeFxOp(m, data, quant_type, expected_node=expected_node)
1727            result = result_dict["quantized_output"]
1728
1729            # check numerics vs eager mode
1730            fuse_list = ["linear", "bn"]
1731            qengine = torch.backends.quantized.engine
1732            if quant_type == QuantType.STATIC:
1733                m_eager.eval()
1734                qconfig = get_default_qconfig(qengine)
1735                prepare_fn = prepare
1736                fuse_modules(m_eager, fuse_list, inplace=True)
1737            else:
1738                m_eager.train()
1739                qconfig = get_default_qat_qconfig(qengine)
1740                prepare_fn = prepare_qat
1741                fuse_modules_qat(m_eager, fuse_list, inplace=True)
1742            m_eager.qconfig = qconfig
1743            m_eager = prepare_fn(m_eager)
1744            m_eager(*data)
1745            m_eager = convert(m_eager)
1746            result_eager = m_eager(*data)
1747            self.assertEqual(result, result_eager)
1748
1749    @skipIfNoFBGEMM
1750    def test_dynamic_quant_fp16(self):
1751        with override_quantized_engine('fbgemm'):
1752            class Linear(torch.nn.Module):
1753                def __init__(self, weight):
1754                    super().__init__()
1755                    self.weight = torch.nn.Parameter(weight)
1756
1757                def forward(self, x):
1758                    return F.linear(x, self.weight)
1759
1760            linear_input = torch.rand(8, 5)
1761            linear_weight = torch.rand(10, 5)
1762
1763            class LinearModule(torch.nn.Module):
1764                def __init__(self) -> None:
1765                    super().__init__()
1766                    self.linear = torch.nn.Linear(5, 10)
1767
1768                def forward(self, x):
1769                    return self.linear(x)
1770
1771            linear_module_input = torch.rand(8, 5)
1772
1773            tests = [
1774                (Linear, (linear_weight,), (linear_input,),
1775                 ns.call_function(torch.ops.quantized.linear_dynamic_fp16),
1776                 ns.call_function(torch.ops.quantized.linear_prepack_fp16)),
1777                (LinearModule, (), (linear_module_input,),
1778                 ns.call_module(nnqd.Linear),
1779                 None),
1780            ]
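                # functional linear should lower to quantized.linear_dynamic_fp16 with an
                # fp16 prepacked weight, while the Linear module becomes nnqd.Linear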
1781            for (ModuleClass, module_constructor_inputs,
1782                 inputs, quantized_node, weight_prepack_node) in tests:
1783                for is_reference in [True, False]:
1784                    node_occurrence = {}
1785                    if weight_prepack_node:
1786                        node_occurrence[weight_prepack_node] = 0
1787                    m = ModuleClass(*module_constructor_inputs).eval()
1788                    qconfig_dict = {"": float16_dynamic_qconfig}
1789                    m = prepare_fx(m, qconfig_dict, example_inputs=inputs)
1790                    convert_fn = convert_to_reference_fx if is_reference else convert_fx
1791                    m = convert_fn(m)
1792                    self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
1793
1796    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
1797    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1798    @override_qengines
1799    def test_qat_prepare_device_affinity(self):
1800        """
1801        Tests that FX QAT prepare pass respects device affinity
1802        """
1803        class Model(nn.Module):
1804
1805            def __init__(self) -> None:
1806                super().__init__()
1807                self.conv = nn.Conv2d(1, 1, 1)
1808                self.bn = nn.BatchNorm2d(1)
1809                self.relu = nn.ReLU()
1810
1811            def forward(self, x):
1812                x = self.conv(x)
1813                x = self.bn(x)
1814                x = self.relu(x)
1815                return x
1816
1817        model = Model()
1818        qengine = torch.backends.quantized.engine
1819        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
1820        device = torch.device('cuda:0')
1821        model.to(device)
1822
1823        example_inputs = (torch.randn(4, 1, 4, 4, device=device),)
1824        # QAT prepare
1825        model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
1826
1827        # ensure that running an input on CUDA works without any needed changes
1828        model(*example_inputs)
1829
1830        # ensure all buffers and parameters are on the device we expect
1831        model_devices = {p.device for p in model.parameters()} | \
1832            {p.device for p in model.buffers()}
1833        self.assertEqual(len(model_devices), 1)
1834        model_device = next(iter(model_devices))
1835        self.assertEqual(model_device, device)
1836
1837    @skipIfNoFBGEMM
1838    def test_dict_output(self):
1839        """ Make sure quantization runs for models with dictionary output
1840        """
1841        class M(torch.nn.Module):
1842            def __init__(self) -> None:
1843                super().__init__()
1844                self.conv = torch.nn.Conv2d(1, 1, 1)
1845
1846            def forward(self, x):
1847                return {"output": self.conv(x["input"])}
1848
1849        example_inputs = ({"input": torch.randn(1, 1, 1, 1)},)
1850        m = M().eval()
1851        qconfig_dict = {"": default_qconfig}
1852        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1853        m(*example_inputs)
1854        m = convert_fx(m)
1855        m(*example_inputs)
1856
1857    @override_qengines
1858    def test_attention(self):
1859        """ Make sure quantization runs for a corner case in attention module
1860        """
1861        class M(torch.nn.Module):
1862            def __init__(self) -> None:
1863                super().__init__()
1864                self.conv = torch.nn.Conv2d(1, 1, 1)
1865
1866            def forward(self, x):
1867                x = self.conv(x)
1868                q, k, v = x.chunk(3, dim=0)
1869                q = q.contiguous().view(-1, 1).transpose(0, 1)
1870                k = k.contiguous().view(-1, 1).transpose(0, 1)
1871                v = v.contiguous().view(-1, 1).transpose(0, 1)
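                    # torch._assert is FX-traceable, unlike a plain Python assert on tensor values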
1872                torch._assert(
1873                    k.size(1) == 1, "key size should be equal to 1"
1874                )
1875                r = torch.mm(k, v)
1876                return q * k + r
1877
1878        example_inputs = (torch.randn(3, 1, 1, 1),)
1879        m = M().eval()
1880        qconfig_dict = {
1881            "": None,
1882            "object_type": [
1883                (nn.Conv2d, default_qconfig),
1884            ]
1885        }
1886        # make sure it runs
1887        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
1888        m(*example_inputs)
1889        m = convert_fx(m)
1890        m(*example_inputs)
1891
1892    def _test_standalone_module(
1893            self,
1894            interface_config,
1895            prepare_count_check,
1896            standalone_prepare_count_check,
1897            convert_count_check,
1898            standalone_convert_count_check):
1899        """ Test standalone module with different quantized input/quantized output
1900        configurations
1901        """
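            # interface_config declares which of the standalone module's inputs/outputs
            # are already quantized; the *_count_check dicts assert where observers and
            # quant/dequant nodes end up in the parent and standalone graphs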
1902        class StandaloneModule(torch.nn.Module):
1903            def __init__(self) -> None:
1904                super().__init__()
1905                self.conv = torch.nn.Conv2d(1, 1, 1)
1906
1907            def forward(self, x):
1908                return self.conv(x)
1909
1910        class M(torch.nn.Module):
1911            def __init__(self) -> None:
1912                super().__init__()
1913                self.conv = torch.nn.Conv2d(1, 1, 1)
1914                self.standalone = StandaloneModule()
1915
1916            def forward(self, x):
1917                x = self.conv(x)
1918                x = self.standalone(x)
1919                return x
1920
1921        class RefM(torch.nn.Module):
1922            def __init__(self) -> None:
1923                super().__init__()
1924                self.conv1 = torch.nn.Conv2d(1, 1, 1)
1925                self.conv2 = torch.nn.Conv2d(1, 1, 1)
1926
1927            def forward(self, x):
1928                x = self.conv1(x)
1929                x = self.conv2(x)
1930                return x
1931
1932        example_inputs = (torch.randn(1, 1, 1, 1),)
1933        # instantiate M and RefM and align the parameters
1934        original_m = M().eval()
1935        original_ref_m = RefM().eval()
1936        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
1937        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
1938        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
1939        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())
1940
1941        for is_name in [True, False]:
1942            sm_example_inputs = example_inputs
1943            if is_name:
1944                prepare_config = {
1945                    "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)]
1946                }
1947            else:
1948                prepare_config = {
1949                    "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)]
1950                }
1951
1952            original_m_copy = copy.deepcopy(original_m)
1953            original_ref_m_copy = copy.deepcopy(original_ref_m)
1954
1955            qconfig_dict = {"": default_qconfig}
1956            # check prepared model
1957            m = prepare_fx(
1958                original_m_copy,
1959                qconfig_dict,
1960                example_inputs=example_inputs,
1961                prepare_custom_config=prepare_config)
1962            # calibration
1963            m(*example_inputs)
1964            self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
1965            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)
1966
1967            # check converted/quantized model
1968            m = convert_fx(m)
1969            self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
1970            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
1971            res = m(*example_inputs)
1972
1973            # quantize the reference model
1974            ref_m = prepare_fx(
1975                original_ref_m_copy,
1976                qconfig_dict,
1977                example_inputs=example_inputs,
1978            )
1979            ref_m(*example_inputs)
1980            ref_m = convert_fx(ref_m)
1981            ref_res = ref_m(*example_inputs)
1982            self.assertEqual(res, ref_res)
1983
1984    def test_standalone_module_float_interface(self):
1985        float_interface_config = {
1986            "input_quantized_idxs": [],  # float input
1987            "output_quantized_idxs": [],  # float output
1988        }
1989        interface_config = float_interface_config
1990        # two observers in the parent graph, for the input and output of the first
1991        # conv; observers for the standalone module are inserted inside it
1992        prepare_count_check = {
1993            ns.call_module(torch.ao.quantization.MinMaxObserver): 2
1994        }
1995        # for input and output of conv in the standalone module
1996        standalone_prepare_count_check = {
1997            ns.call_module(torch.ao.quantization.MinMaxObserver): 2
1998        }
1999        convert_count_check = {
2000            ns.call_function(torch.quantize_per_tensor) : 1,
2001            ns.call_module(nnq.Conv2d) : 1,
2002            ns.call_method("dequantize") : 1,
2003        }
2004        standalone_convert_count_check = {
2005            # standalone module will take float as input and output
2006            # so we'll see quantize and dequantize in the module
2007            ns.call_function(torch.quantize_per_tensor) : 1,
2008            ns.call_module(nnq.Conv2d): 1,
2009            ns.call_method("dequantize") : 1,
2010        }
2011        self._test_standalone_module(
2012            interface_config,
2013            prepare_count_check,
2014            standalone_prepare_count_check,
2015            convert_count_check,
2016            standalone_convert_count_check)
2017
2018    def test_standalone_module_quantized_interface(self):
2019        quantized_interface_config = {
2020            "input_quantized_idxs": [0],  # quantized input
2021            "output_quantized_idxs": [0],  # quantized output
2022        }
2023        interface_config = quantized_interface_config
2024        # observer for input and output of first conv
2025        prepare_count_check = {
2026            ns.call_module(torch.ao.quantization.MinMaxObserver): 2
2027        }
2028        # for output of conv in the standalone module
2029        standalone_prepare_count_check = {
2030            ns.call_module(torch.ao.quantization.MinMaxObserver): 1
2031        }
2032        convert_count_check = {
2033            # quantizing input for conv
2034            ns.call_function(torch.quantize_per_tensor) : 1,
2035            ns.call_module(nnq.Conv2d) : 1,
2036            # dequantizing output of standalone module
2037            ns.call_method("dequantize") : 1,
2038        }
2039        standalone_convert_count_check = {
2040            # quantization of input happens in parent module
2041            # quantization of output happens in the quantized conv module
2042            ns.call_function(torch.quantize_per_tensor) : 0,
2043            ns.call_module(nnq.Conv2d): 1,
2044            # dequantization for output happens in parent module
2045            ns.call_method("dequantize") : 0,
2046        }
2047        self._test_standalone_module(
2048            interface_config,
2049            prepare_count_check,
2050            standalone_prepare_count_check,
2051            convert_count_check,
2052            standalone_convert_count_check)
2053
2054    @skipIfNoFBGEMM
2055    def test_qconfig_none(self):
2056        class M(torch.nn.Module):
2057            def __init__(self) -> None:
2058                super().__init__()
2059                self.conv1 = nn.Conv2d(1, 1, 1)
2060                self.conv2 = nn.Conv2d(1, 1, 1)
2061
2062            def forward(self, x):
2063                x = self.conv1(x)
2064                x = self.conv2(x)
2065                return x
2066
2067        m = M().eval()
2068        qconfig_dict = {"": default_qconfig,
2069                        "module_name": [("conv2", None)]}
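            # the module_name entry overrides the global qconfig, so conv2 stays float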
2070        example_inputs = (torch.randn(1, 1, 1, 1),)
2071        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2072        m(*example_inputs)
2073        m = convert_fx(m)
2074        m(*example_inputs)
2075        # first conv is quantized, second conv is not quantized
2076        node_list = [
2077            ns.call_function(torch.quantize_per_tensor),
2078            ns.call_module(nnq.Conv2d),
2079            ns.call_method("dequantize"),
2080            ns.call_module(nn.Conv2d),
2081        ]
2082        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2083
2084    def test_qconfig_module_type(self):
2085        class M(torch.nn.Module):
2086            def __init__(self) -> None:
2087                super().__init__()
2088                self.conv = nn.Conv2d(1, 1, 1)
2089                self.linear = nn.Linear(9, 3)
2090
2091            def forward(self, x):
2092                x = self.conv(x)
2093                x = x.reshape((1, -1))
2094                x = self.linear(x)
2095                return x
2096
2097        m = M().eval()
2098        qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
2099        example_inputs = (torch.randn(1, 1, 3, 3),)
2100        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2101        m(*example_inputs)
2102        m = convert_fx(m)
2103        m(*example_inputs)
2104        # conv is quantized, linear is not quantized
2105        node_list = [
2106            ns.call_function(torch.quantize_per_tensor),
2107            ns.call_module(nnq.Conv2d),
2108            ns.call_method("dequantize"),
2109            ns.call_module(nn.Linear),
2110        ]
2111        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2112
2113    def test_qconfig_qat_module_type(self):
2114        class LinearRelu(nn.Sequential):
2115            def __init__(self) -> None:
2116                super().__init__(
2117                    nn.Linear(5, 5),
2118                    nn.ReLU(),
2119                )
2120
2121        class M(torch.nn.Module):
2122            def __init__(self) -> None:
2123                super().__init__()
2124                self.lin_relu = LinearRelu()
2125                self.linear = nn.Linear(5, 5)
2126
2127            def forward(self, x):
2128                x = self.lin_relu(x)
2129                x = self.linear(x)
2130                return x
2131
2132        model = M().train()
2133
2134        qconfig_dict = {
2135            "": None,
2136            "object_type": [
2137                (torch.nn.Linear, default_qat_qconfig),
2138                (torch.nn.ReLU, default_qat_qconfig),
2139            ],
2140        }
2141        example_inputs = (torch.rand(5, 5),)
2142        m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
2143        m(*example_inputs)
2144        m = convert_fx(m)
2145        m(*example_inputs)
2146        node_list = [
2147            ns.call_function(torch.quantize_per_tensor),
2148            ns.call_module(nniq.LinearReLU),
2149            ns.call_module(nnq.Linear),
2150            ns.call_method("dequantize"),
2151        ]
2152        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2153
2154    def test_qconfig_function(self):
2155        class M(torch.nn.Module):
2156            def forward(self, x, y):
2157                return x + y
2158
2159        m = M().eval()
2160        qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
2161        data = torch.randn(1, 1, 1, 1)
2162        example_inputs = (data, data)
2163        m = prepare_fx(m, qconfig_dict, example_inputs)
2164        m(*example_inputs)
2165        m = convert_fx(m)
2166        m(*example_inputs)
2167        # add is quantized, wrapped by quantize/dequantize at the graph boundaries
2168        node_list = [
2169            ns.call_function(torch.quantize_per_tensor),
2170            ns.call_function(torch.ops.quantized.add),
2171            ns.call_method("dequantize"),
2172        ]
2173        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2174
2175    def test_qconfig_module_name_regex(self):
2176        class M(torch.nn.Module):
2177            def __init__(self) -> None:
2178                super().__init__()
2179                self.conv1 = nn.Conv2d(1, 1, 1)
2180                self.conv2 = nn.Conv2d(1, 1, 1)
2181
2182            def forward(self, x):
2183                x = self.conv1(x)
2184                x = self.conv2(x)
2185                return x
2186
2187        m = M().eval()
2188        qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
2189        example_inputs = (torch.randn(1, 1, 1, 1),)
2190        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2191        m(*example_inputs)
2192        m = convert_fx(m)
2193        m(*example_inputs)
2194        # both convs match the "conv*" regex, so both are quantized
2195        node_list = [
2196            ns.call_function(torch.quantize_per_tensor),
2197            ns.call_module(nnq.Conv2d),
2198            ns.call_module(nnq.Conv2d),
2199            ns.call_method("dequantize"),
2200        ]
2201        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2202
2203    def test_qconfig_precedence(self):
2204        for device in get_supported_device_types():
2205            class M(torch.nn.Module):
2206                def __init__(self) -> None:
2207                    super().__init__()
2208                    self.linear = nn.Linear(1, 1)
2209                    self.conv = nn.Conv2d(1, 1, 1)
2210                    self.module_conv1 = nn.Conv2d(1, 1, 1)
2211                    self.module_conv2 = nn.Conv2d(1, 1, 1)
2212
2213                def forward(self, x):
2214                    # global
2215                    x = self.linear(x)
2216                    # global + object_type --> object_type
2217                    x = self.conv(x)
2218                    # global + object_type + module_name_regex --> module_name_regex
2219                    x = self.module_conv1(x)
2220                    # global + object_type + module_name_regex + module_name --> module_name
2221                    x = self.module_conv2(x)
2222                    return x
2223
2224            m = M().to(device).eval()
2225
2226            global_qconfig = default_qconfig
2227            object_type_qconfig = default_dynamic_qconfig
2228            module_name_regex_qconfig = float16_dynamic_qconfig
2229            module_name_qconfig = default_qat_qconfig
2230            qconfig_dict = {
2231                "": global_qconfig,
2232                "object_type": [(nn.Conv2d, object_type_qconfig)],
2233                "module_name_regex": [("module_conv*", module_name_regex_qconfig)],
2234                "module_name": [("module_conv2", module_name_qconfig)]}
2235            m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),))
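                # compare the underlying observer constructors (the partial's .func) to
                # verify which qconfig won at each call site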
2236            self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func)
2237            self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func)
2238            self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func)
2239            self.assertEqual(m_prep.conv.qconfig.weight.p.func, object_type_qconfig.weight.p.func)
2240            self.assertEqual(m_prep.module_conv1.qconfig.activation.p.func, module_name_regex_qconfig.activation.p.func)
2241            self.assertEqual(m_prep.module_conv1.qconfig.weight.p.func, module_name_regex_qconfig.weight.p.func)
2242            self.assertEqual(m_prep.module_conv2.qconfig.activation.p.func, module_name_qconfig.activation.p.func)
2243            self.assertEqual(m_prep.module_conv2.qconfig.weight.p.func, module_name_qconfig.weight.p.func)
2244
2245    def test_qconfig_module_name_object_type_order(self):
2246        class M1(torch.nn.Module):
2247            def __init__(self) -> None:
2248                super().__init__()
2249                self.fc1 = nn.Linear(1, 1)
2250                self.fc2 = nn.Linear(1, 1)
2251
2252            def forward(self, x):
2253                x = self.fc1(x)
2254                x = self.fc2(x)
2255                x = torch.add(x, x)
2256                x = torch.add(x, x)
2257                return x
2258
2259        class M2(torch.nn.Module):
2260            def __init__(self) -> None:
2261                super().__init__()
2262                self.fc1 = nn.Linear(1, 1)
2263                self.fc2 = nn.Linear(1, 1)
2264                self.m1 = M1()
2265
2266            def forward(self, x):
2267                x = self.fc1(x)
2268                x = self.fc2(x)
2269                x = torch.add(x, x)
2270                x = torch.add(x, x)
2271                x = self.m1(x)
2272                return x
2273
2274        class M3(torch.nn.Module):
2275            def __init__(self) -> None:
2276                super().__init__()
2277                self.fc1 = nn.Linear(1, 1)
2278                self.fc2 = nn.Linear(1, 1)
2279                self.m2 = M2()
2280
2281            def forward(self, x):
2282                x = self.fc1(x)
2283                x = self.fc2(x)
2284                x = torch.add(x, x)
2285                x = torch.add(x, x)
2286                x = self.m2(x)
2287                return x
2288
2289        m = M3().eval()
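            # each entry is (module FQN, object type, index of occurrence within that module, qconfig)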
2290        qconfig_dict = {
2291            "module_name_object_type_order": [
2292                # test various FQNs: global, single child, multiple children
2293                ("", nn.Linear, 0, torch.ao.quantization.default_qconfig),
2294                ("", torch.add, 0, torch.ao.quantization.default_qconfig),
2295                ("m2", nn.Linear, 1, torch.ao.quantization.default_qconfig),
2296                ("m2", torch.add, 1, torch.ao.quantization.default_qconfig),
2297                ("m2.m1", nn.Linear, 0, torch.ao.quantization.default_qconfig),
2298                ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig),
2299            ],
2300        }
2301        example_inputs = (torch.randn(1, 1, 1, 1),)
2302        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2303        m(*example_inputs)
2304        m = convert_fx(m)
2305        m(*example_inputs)
2306
2307        node_list = [
2308            # m3
2309            ns.call_function(torch.quantize_per_tensor),
2310            ns.call_module(nnq.Linear),
2311            ns.call_method("dequantize"),
2312            ns.call_module(nn.Linear),
2313            ns.call_function(torch.quantize_per_tensor),
2314            ns.call_function(torch.ops.quantized.add),
2315            ns.call_method("dequantize"),
2316            ns.call_function(torch.add),
2317            # m2
2318            ns.call_module(nn.Linear),
2319            ns.call_function(torch.quantize_per_tensor),
2320            ns.call_module(nnq.Linear),
2321            ns.call_method("dequantize"),
2322            ns.call_function(torch.add),
2323            ns.call_function(torch.quantize_per_tensor),
2324            ns.call_function(torch.ops.quantized.add),
2325            # m1
2326            ns.call_module(nnq.Linear),
2327            ns.call_method("dequantize"),
2328            ns.call_module(nn.Linear),
2329            ns.call_function(torch.quantize_per_tensor),
2330            ns.call_function(torch.ops.quantized.add),
2331            ns.call_method("dequantize"),
2332            ns.call_function(torch.add),
2333        ]
2334        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2335
2336        # test that a module_name_object_type_order entry of None overrides the global qconfig
2337        class M4(torch.nn.Module):
2338            def __init__(self) -> None:
2339                super().__init__()
2340                self.fc1 = nn.Linear(1, 1)
2341                self.fc2 = nn.Linear(1, 1)
2342
2343            def forward(self, x):
2344                x = self.fc1(x)
2345                x = self.fc2(x)
2346                x = torch.add(x, x)
2347                x = torch.add(x, x)
2348                return x
2349
2350        m = M4().eval()
2351        qconfig_dict = {
2352            "": torch.ao.quantization.default_qconfig,
2353            "module_name_object_type_order": [
2354                ("", nn.Linear, 1, None),
2355                ("", torch.add, 1, None),
2356            ],
2357        }
2358        example_inputs = (torch.randn(1, 1, 1, 1),)
2359        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2360        m(*example_inputs)
2361        m = convert_fx(m)
2362        m(*example_inputs)
2363
2364        node_list = [
2365            ns.call_function(torch.quantize_per_tensor),
2366            ns.call_module(nnq.Linear),
2367            ns.call_method("dequantize"),
2368            ns.call_module(nn.Linear),
2369            ns.call_function(torch.quantize_per_tensor),
2370            ns.call_function(torch.ops.quantized.add),
2371            ns.call_method("dequantize"),
2372            ns.call_function(torch.add),
2373        ]
2374        self.checkGraphModuleNodes(m, expected_node_list=node_list)
2375
2376
2377    @override_qengines
2378    def test_qconfig_dict_with_fused_modules(self):
2379        class LinearReLUModel(torch.nn.Module):
2380            def __init__(self, relu):
2381                super().__init__()
2382                self.linear = torch.nn.Linear(3, 3)
2383                self.relu = relu
2384
2385            def forward(self, x):
2386                x = self.linear(x)
2387                x = self.relu(x)
2388                return x
2389
2390        class ConvReLUModel(torch.nn.Module):
2391            def __init__(self, relu):
2392                super().__init__()
2393                self.conv = torch.nn.Conv1d(3, 3, 3)
2394                self.relu = relu
2395
2396            def forward(self, x):
2397                x = self.conv(x)
2398                x = self.relu(x)
2399                return x
2400
2401        class ConvBnReLUModel(torch.nn.Module):
2402            def __init__(self, relu):
2403                super().__init__()
2404                self.conv = torch.nn.Conv1d(3, 3, 3)
2405                self.bn = torch.nn.BatchNorm1d(3)
2406                self.relu = relu
2407
2408            def forward(self, x):
2409                x = self.conv(x)
2410                x = self.bn(x)
2411                x = self.relu(x)
2412                return x
2413
2414        for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]:
2415            for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]:
2416                m = model(relu).eval()
2417                qengine = torch.backends.quantized.engine
2418                qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine)
2419                # should not crash as in https://github.com/pytorch/pytorch/issues/75825
2420                prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
2421
2422    # TODO: move QConfigMapping tests to test/quantization/core
2423    def test_qconfig_mapping_set_global(self):
2424        qconfig = get_default_qconfig()
2425        qconfig_mapping = QConfigMapping()
2426        self.assertEqual(qconfig_mapping.global_qconfig, None)
2427        qconfig_mapping.set_global(qconfig)
2428        self.assertEqual(qconfig_mapping.global_qconfig, qconfig)
2429
2430    def test_qconfig_mapping_set_object_type(self):
2431        qconfig1 = get_default_qconfig()
2432        qconfig2 = get_default_qconfig()
2433        qconfig3 = get_default_qconfig()
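            # separate calls build distinct observer factory objects, so these
            # QConfigs compare unequal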
2434        self.assertNotEqual(qconfig1, qconfig2)
2435        self.assertNotEqual(qconfig1, qconfig3)
2436        qconfig_mapping = QConfigMapping()
2437        self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0)
2438        # Insert some entries
2439        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1)
2440        qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2)
2441        self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2)
2442        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1)
2443        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
2444        # Override existing key
2445        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3)
2446        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3)
2447        self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2)
2448        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3)
2449        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2)
2450        self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None)
2451
2452    def test_qconfig_mapping_set_module_name_regex(self):
2453        qconfig1 = get_default_qconfig()
2454        qconfig2 = get_default_qconfig()
2455        qconfig3 = get_default_qconfig()
2456        self.assertNotEqual(qconfig1, qconfig2)
2457        self.assertNotEqual(qconfig1, qconfig3)
2458        qconfig_mapping = QConfigMapping()
2459        self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0)
2460        # Insert some entries
2461        qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1)
2462        qconfig_mapping.set_module_name_regex("foo.*", qconfig2)
2463        self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2)
2464        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1)
2465        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
2466        # Override existing key
2467        qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3)
2468        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3)
2469        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2)
2470        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3)
2471        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3)
2472        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2)
2473        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2)
2474        self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None)
2475
2476    def test_qconfig_mapping_set_module_name(self):
2477        qconfig1 = get_default_qconfig()
2478        qconfig2 = get_default_qconfig()
2479        qconfig3 = get_default_qconfig()
2480        self.assertNotEqual(qconfig1, qconfig2)
2481        self.assertNotEqual(qconfig1, qconfig3)
2482        qconfig_mapping = QConfigMapping()
2483        self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0)
2484        # Insert some entries
2485        qconfig_mapping.set_module_name("mod1", qconfig1)
2486        qconfig_mapping.set_module_name("mod2", qconfig2)
2487        self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2)
2488        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1)
2489        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
2490        # Override existing key
2491        qconfig_mapping.set_module_name("mod1", qconfig3)
2492        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3)
2493        self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2)
2494        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3)
2495        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2)
2496        self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None)
2497
2498    def test_qconfig_mapping_set_module_name_object_type_order(self):
2499        qconfig1 = get_default_qconfig()
2500        qconfig2 = get_default_qconfig()
2501        qconfig3 = get_default_qconfig()
2502        self.assertNotEqual(qconfig1, qconfig2)
2503        self.assertNotEqual(qconfig1, qconfig3)
2504        qconfig_mapping = QConfigMapping()
2505        self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 0)
2506        # Insert some entries
2507        qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig1)
2508        qconfig_mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig2)
2509        self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2)
2510        key1 = ("mod1", torch.nn.Linear, 0)
2511        key2 = ("mod2", torch.nn.ReLU, 1)
2512        self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1)
2513        self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
2514        self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1)
2515        self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
2516        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2517                         qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1)
2518        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2519                         qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
2520        # Override existing key
2521        qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3)
2522        self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2)
2523        self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1)
2524        self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2)
2525        self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3)
2526        self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2)
2527        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2528                         qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3)
2529        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2530                         qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2)
2531        # No match
2532        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2533                         qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None)
2534        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2535                         qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None)
2536        self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order(
2537                         qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None)
2538
2539    def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2):
2540        """
2541        Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods.
2542        """
2543        return {
2544            _GLOBAL_DICT_KEY: global_qconfig,
2545            _OBJECT_TYPE_DICT_KEY: [
2546                (torch.nn.Linear, qconfig1),
2547                (torch.nn.ReLU, qconfig2),
2548            ],
2549            _MODULE_NAME_REGEX_DICT_KEY: [
2550                ("foo.*bar", qconfig1),
2551                ("foo.*", qconfig2),
2552            ],
2553            _MODULE_NAME_DICT_KEY: [
2554                ("bazbaz", qconfig1),
2555                ("borbor", qconfig2),
2556            ],
2557            _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [
2558                ("bazbaz", torch.nn.Linear, 0, qconfig1),
2559                ("foofoo", torch.nn.ReLU, 1, qconfig2),
2560            ],
2561        }
2562
2570    def test_qconfig_mapping_from_dict(self):
2571        global_qconfig = QConfig(123, "global")
2572        qconfig1 = QConfig(1, "one")
2573        qconfig2 = QConfig(2, "two")
2574        qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2)
2575        qconfig_dict["undefined_dict_key"] = [(123, qconfig1), (234, qconfig2)]
2576        qconfig_mapping = QConfigMapping.from_dict(qconfig_dict)
2577        self.assertEqual(qconfig_mapping.global_qconfig, global_qconfig)
2578        self.assertEqual(qconfig_mapping.object_type_qconfigs, OrderedDict({
2579            torch.nn.Linear: qconfig1,
2580            torch.nn.ReLU: qconfig2,
2581        }))
2582        self.assertEqual(qconfig_mapping.module_name_regex_qconfigs, OrderedDict({
2583            "foo.*bar": qconfig1,
2584            "foo.*": qconfig2,
2585        }))
2586        self.assertEqual(qconfig_mapping.module_name_qconfigs, OrderedDict({
2587            "bazbaz": qconfig1,
2588            "borbor": qconfig2,
2589        }))
2590        self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs, OrderedDict({
2591            ("bazbaz", torch.nn.Linear, 0): qconfig1,
2592            ("foofoo", torch.nn.ReLU, 1): qconfig2,
2593        }))
2594
2595    def test_qconfig_mapping_to_dict(self):
2596        global_qconfig = QConfig(123, "global")
2597        qconfig1 = QConfig(1, "one")
2598        qconfig2 = QConfig(2, "two")
2599        qconfig_mapping = QConfigMapping().set_global(global_qconfig) \
2600            .set_object_type(torch.nn.Linear, qconfig1) \
2601            .set_object_type(torch.nn.ReLU, qconfig2) \
2602            .set_module_name_regex("foo.*bar", qconfig1) \
2603            .set_module_name_regex("foo.*", qconfig2) \
2604            .set_module_name("bazbaz", qconfig1) \
2605            .set_module_name("borbor", qconfig2) \
2606            .set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, qconfig1) \
2607            .set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, qconfig2)
2608        qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2)
2609        self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict)
2610
2611    def test_qconfig_mapping_repr(self):
2612        self.assertIsInstance(repr(get_default_qconfig_mapping()), str)
2613
2614    def test_default_qconfig_mapping_override_global(self):
2615        class M(torch.nn.Module):
2616            def __init__(self) -> None:
2617                super().__init__()
2618                self.conv = torch.nn.Conv2d(1, 1, 1)
2619
2620            def forward(self, x):
2621                return self.conv(x)
2622
2623        m = M().eval()
2624        my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer)
2625        qconfig_mapping = get_default_qconfig_mapping()
2626        # Override global qconfig
2627        old_global_qconfig = qconfig_mapping.global_qconfig
2628        qconfig_mapping.set_global(my_qconfig)
2629        # Verify the correct qconfig was used
2630        example_inputs = (torch.randn(1, 1, 1, 1),)
2631        m = prepare_fx(m, qconfig_mapping, example_inputs)
2632        self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver))
2633        self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver))
2634        self.assertTrue(hasattr(m, "activation_post_process_0"))
2635        self.assertTrue(hasattr(m, "activation_post_process_1"))
2636        self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver))
2637        self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver))
2638
2639    # Dummy classes for PrepareCustomConfig testing
2640
2641    class _DummyStandaloneModule:
2642        pass
2643
2644    class _DummyFloatModule:
2645        pass
2646
2647    class _DummyObservedModule:
2648        pass
2649
2650    class _DummyQuantizedModule:
2651        pass
2652
2653    class _DummyNonTraceableModule1:
2654        pass
2655
2656    class _DummyNonTraceableModule2:
2657        pass
2658
2659    def test_prepare_custom_config_set_standalone_module_name(self):
2660        qconfig_mapping = QConfigMapping()
2661        example_inputs = (torch.randn(3),)
2662        child_prepare_custom_config = PrepareCustomConfig()
2663        backend_config = BackendConfig("my_backend")
2664        config_entry = StandaloneModuleConfigEntry(
2665            qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
2666        prepare_custom_config = PrepareCustomConfig()
2667        self.assertEqual(len(prepare_custom_config.standalone_module_names), 0)
2668        prepare_custom_config.set_standalone_module_name(
2669            "module1", qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
2670        self.assertEqual(list(prepare_custom_config.standalone_module_names.keys()), ["module1"])
2671        self.assertEqual(prepare_custom_config.standalone_module_names["module1"], config_entry)
2672
2673    def test_prepare_custom_config_set_standalone_module_class(self):
2674        qconfig_mapping = QConfigMapping()
2675        example_inputs = (torch.randn(3),)
2676        child_prepare_custom_config = PrepareCustomConfig()
2677        backend_config = BackendConfig("my_backend")
2678        config_entry = StandaloneModuleConfigEntry(
2679            qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
2680        prepare_custom_config = PrepareCustomConfig()
2681        self.assertEqual(len(prepare_custom_config.standalone_module_classes), 0)
2682        prepare_custom_config.set_standalone_module_class(
2683            self._DummyStandaloneModule, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config)
2684        self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1)
2685        self.assertTrue(self._DummyStandaloneModule in prepare_custom_config.standalone_module_classes)
2686        self.assertEqual(prepare_custom_config.standalone_module_classes[self._DummyStandaloneModule], config_entry)
2687
2688    def test_prepare_custom_config_set_float_to_observed_mapping(self):
2689        prepare_custom_config = PrepareCustomConfig()
2690        self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 0)
2691        prepare_custom_config.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule, QuantType.STATIC)
2692        self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1)
2693        self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC])
2694        self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1)
2695        self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC])
2696        self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule],
2697                         self._DummyObservedModule)
2698
2699    def test_prepare_custom_config_set_non_traceable_module_names(self):
2700        prepare_custom_config = PrepareCustomConfig()
2701        self.assertEqual(len(prepare_custom_config.non_traceable_module_names), 0)
2702        prepare_custom_config.set_non_traceable_module_names(["module1", "module2"])
2703        self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module1", "module2"])
2704
2705    def test_prepare_custom_config_set_non_traceable_module_classes(self):
2706        prepare_custom_config = PrepareCustomConfig()
2707        self.assertEqual(len(prepare_custom_config.non_traceable_module_classes), 0)
2708        prepare_custom_config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
2709        self.assertEqual(prepare_custom_config.non_traceable_module_classes,
2710                         [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
2711
2712    def test_prepare_custom_config_set_input_quantized_indexes(self):
2713        prepare_custom_config = PrepareCustomConfig()
2714        self.assertEqual(len(prepare_custom_config.input_quantized_indexes), 0)
2715        prepare_custom_config.set_input_quantized_indexes([0, 1])
2716        self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1])
2717
2718    def test_prepare_custom_config_set_output_quantized_indexes(self):
2719        prepare_custom_config = PrepareCustomConfig()
2720        self.assertEqual(len(prepare_custom_config.output_quantized_indexes), 0)
2721        prepare_custom_config.set_output_quantized_indexes([0, 1])
2722        self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1])
2723
2724    def test_prepare_custom_config_set_preserved_attributes(self):
2725        prepare_custom_config = PrepareCustomConfig()
2726        self.assertEqual(len(prepare_custom_config.preserved_attributes), 0)
2727        prepare_custom_config.set_preserved_attributes(["attr1", "attr2"])
2728        self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"])
2729
2730    def _get_dummy_prepare_custom_config_dict(self):
2731        """
2732        Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods.
2733        """
2734        return {
2735            STANDALONE_MODULE_NAME_DICT_KEY: [(
2736                "module1",
2737                QConfigMapping(),
2738                (torch.randn(3),),
2739                PrepareCustomConfig(),
2740                BackendConfig("my_backend"),
2741            )],
2742            STANDALONE_MODULE_CLASS_DICT_KEY: [(
2743                self._DummyStandaloneModule,
2744                QConfigMapping(),
2745                (torch.randn(10),),
2746                PrepareCustomConfig(),
2747                BackendConfig("my_backend"),
2748            )],
2749            FLOAT_TO_OBSERVED_DICT_KEY: {
2750                "static": {
2751                    self._DummyFloatModule: self._DummyObservedModule
2752                },
2753            },
2754            NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"],
2755            NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2],
2756            INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
2757            OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1],
2758            PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]
2759        }
2760
2761    def test_prepare_custom_config_from_dict(self):
2762        prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict()
2763        (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
2764        (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
2765        sm_config_entry1 = StandaloneModuleConfigEntry(qm1, ei1, pcc1, bcd1)
2766        sm_config_entry2 = StandaloneModuleConfigEntry(qm2, ei2, pcc2, bcd2)
2767        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config_dict)
2768
2769        # Standalone modules
2770        self.assertEqual(len(prepare_custom_config.standalone_module_names), 1)
2771        self.assertTrue(sm_name in prepare_custom_config.standalone_module_names)
2772        self.assertEqual(prepare_custom_config.standalone_module_names[sm_name], sm_config_entry1)
2773        self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1)
2774        self.assertTrue(sm_class in prepare_custom_config.standalone_module_classes)
2775        self.assertEqual(prepare_custom_config.standalone_module_classes[sm_class], sm_config_entry2)
2776
2777        # Float to observed mapping
2778        self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1)
2779        self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC])
2780        self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1)
2781        self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC])
2782        self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule],
2783                         self._DummyObservedModule)
2784
2785        # Other
2786        self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module2", "module3"])
2787        self.assertEqual(prepare_custom_config.non_traceable_module_classes,
2788                         [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2])
2789        self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1])
2790        self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1])
2791        self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"])
2792
2793    def test_prepare_custom_config_to_dict(self):
2794        prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict()
2795        (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0]
2796        (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0]
2797        prepare_custom_config = PrepareCustomConfig() \
2798            .set_standalone_module_name(sm_name, qm1, ei1, pcc1, bcd1) \
2799            .set_standalone_module_class(sm_class, qm2, ei2, pcc2, bcd2) \
2800            .set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule) \
2801            .set_non_traceable_module_names(["module2", "module3"]) \
2802            .set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) \
2803            .set_input_quantized_indexes([0, 1]) \
2804            .set_output_quantized_indexes([0, 1]) \
2805            .set_preserved_attributes(["attr1", "attr2"])
2806        # PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts
2807        prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (sm_name, qm1.to_dict(), ei1, pcc1.to_dict(), bcd1)
2808        prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (sm_class, qm2.to_dict(), ei2, pcc2.to_dict(), bcd2)
2809        self.assertEqual(prepare_custom_config.to_dict(), prepare_custom_config_dict)
2810
2811    def test_convert_custom_config_set_observed_to_quantized_mapping(self):
2812        convert_custom_config = ConvertCustomConfig()
2813        self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 0)
2814        convert_custom_config.set_observed_to_quantized_mapping(
2815            self._DummyObservedModule, self._DummyQuantizedModule, QuantType.STATIC)
2816        self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1)
2817        self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC])
2818        self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC])
2819        self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule],
2820                         self._DummyQuantizedModule)
2821
2822    def test_convert_custom_config_set_preserved_attributes(self):
2823        convert_custom_config = ConvertCustomConfig()
2824        self.assertEqual(len(convert_custom_config.preserved_attributes), 0)
2825        convert_custom_config.set_preserved_attributes(["attr1", "attr2"])
2826        self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"])
2827
2828    def _get_dummy_convert_custom_config_dict(self):
2829        """
2830        Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods.
2831        """
2832        return {
2833            OBSERVED_TO_QUANTIZED_DICT_KEY: {
2834                "static": {
2835                    self._DummyObservedModule: self._DummyQuantizedModule
2836                },
2837            },
2838            PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]
2839        }
2840
2841    def test_convert_custom_config_from_dict(self):
2842        convert_custom_config_dict = self._get_dummy_convert_custom_config_dict()
2843        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config_dict)
2844        self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1)
2845        self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC])
2846        self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]), 1)
2847        self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC])
2848        self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule],
2849                         self._DummyQuantizedModule)
2850        self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"])
2851
2852    def test_convert_custom_config_to_dict(self):
2853        convert_custom_config = ConvertCustomConfig() \
2854            .set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule) \
2855            .set_preserved_attributes(["attr1", "attr2"])
2856        self.assertEqual(convert_custom_config.to_dict(), self._get_dummy_convert_custom_config_dict())
2857
2858    def test_fuse_custom_config_set_preserved_attributes(self):
2859        fuse_custom_config = FuseCustomConfig()
2860        self.assertEqual(len(fuse_custom_config.preserved_attributes), 0)
2861        fuse_custom_config.set_preserved_attributes(["attr1", "attr2"])
2862        self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"])
2863
2864    def test_fuse_custom_config_from_dict(self):
2865        fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
2866        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config_dict)
2867        self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"])
2868
2869    def test_fuse_custom_config_to_dict(self):
2870        fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]}
2871        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
2872        self.assertEqual(fuse_custom_config.to_dict(), fuse_custom_config_dict)
2873
2874    def test_remove_qconfig(self):
2875        class M(torch.nn.Module):
2876            def __init__(self) -> None:
2877                super().__init__()
2878                self.avg_pool = torch.nn.AvgPool2d(1)
2879
2880            def forward(self, x):
2881                return self.avg_pool(x)
2882
2883        m = M().eval()
2884        qconfig_dict = {'': default_qconfig}
2885        example_inputs = (torch.randn(1, 1, 1, 1),)
2886        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2887        m(*example_inputs)
2888        m = convert_fx(m)
2889        m(*example_inputs)
2890        for name, module in m.named_modules():
2891            self.assertFalse(hasattr(module, 'qconfig'),
2892                             'qconfig is not removed for ' + name)
2893
2894    def test_return_none(self):
2895        class M(torch.nn.Module):
2896            def forward(self, x):
2897                pass
2898
2899        m = M().eval()
2900        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
2901        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),))
2902        m = convert_fx(m)
2903
2904    def test_default_quant_after_none_qconfig(self):
2905        """ Make sure a default quant node is inserted properly after a module with qconfig None"""
2906        class M(torch.nn.Module):
2907            def __init__(self) -> None:
2908                super().__init__()
2909                self.conv1 = torch.nn.Conv2d(1, 1, 1)
2910                self.conv2 = torch.nn.Conv2d(1, 1, 1)
2911
2912            def forward(self, x):
2913                x = self.conv1(x)
2914                x = x.transpose(1, 2)
2915                return self.conv2(x)
2916
2917        m = M().eval()
2918        qconfig_dict = {
2919            "": default_qconfig,
2920            "module_name": [
2921                ("conv1", None)
2922            ]
2923        }
2924        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
2925        m = convert_fx(m)
2926
2927    def test_qconfig_for_call_method(self):
2928        class Sub(torch.nn.Module):
2929            def __init__(self) -> None:
2930                super().__init__()
2931                self.conv = torch.nn.Conv2d(1, 1, 1)
2932
2933            def forward(self, x):
2934                x = x.transpose(2, 3)
2935                x = self.conv(x)
2936                return x.transpose(2, 3)
2937
2938        class M(torch.nn.Module):
2939            def __init__(self) -> None:
2940                super().__init__()
2941                self.sub = Sub()
2942                self.conv1 = torch.nn.Conv2d(1, 1, 1)
2943                self.conv2 = torch.nn.Conv2d(1, 1, 1)
2944
2945            def forward(self, x):
2946                x = self.conv1(x)
2947                x = self.sub(x)
2948                x = self.conv2(x)
2949                return x.transpose(2, 3)
2950
2951        qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]}
2952        # since sub is configured with qconfig None, we should dequantize the output
2953        # of self.conv1 and quantize the input of self.conv2
2954        # the dequantize after conv2 should happen after the final transpose, since
2955        # the transpose is configured with default_qconfig
2956        # nodes in the Sub module instance are not quantized
2957        node_list1 = [
2958            ns.call_function(torch.quantize_per_tensor),
2959            ns.call_module(nnq.Conv2d),
2960            ns.call_method("dequantize"),
2961            ns.call_method("transpose"),
2962            ns.call_module(nn.Conv2d),
2963            ns.call_method("transpose"),
2964            ns.call_function(torch.quantize_per_tensor),
2965            ns.call_module(nnq.Conv2d),
2966            ns.call_method("transpose"),
2967            ns.call_method("dequantize")
2968        ]
2969
2970        qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]}
2971        # Only nodes in the Sub module instance are quantized
2972        # the first transpose is not quantized because the input is not quantized
2973        node_list2 = [
2974            ns.call_module(nn.Conv2d),
2975            ns.call_function(torch.quantize_per_tensor),
2976            ns.call_method("transpose"),
2977            ns.call_module(nnq.Conv2d),
2978            ns.call_method("transpose"),
2979            ns.call_method("dequantize"),
2980            ns.call_module(nn.Conv2d),
2981            ns.call_method("transpose"),
2982        ]
2983
2984        for qconfig_dict, node_list in [
2985                (qconfig_dict1, node_list1),
2986                (qconfig_dict2, node_list2)
2987        ]:
2988            example_inputs = (torch.randn(2, 1, 3, 3),)
2989            m = M().eval()
2990            m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
2991            m(torch.randn(2, 1, 3, 3))
2992            m = convert_fx(m)
2993            self.checkGraphModuleNodes(m, expected_node_list=node_list)
2994            # make sure it runs
2995            m(*example_inputs)
2996
2997    def test_qconfig_for_call_func(self):
2998        class Linear(torch.nn.Module):
2999            def __init__(self) -> None:
3000                super().__init__()
3001                self.w = torch.ones(5, 5)
3002                self.b = torch.zeros(5)
3003
3004            def forward(self, x):
3005                return torch.nn.functional.linear(x, self.w, self.b)
3006
3007        class M(torch.nn.Module):
3008            def __init__(self) -> None:
3009                super().__init__()
3010                self.mods1 = torch.nn.Sequential(
3011                    Linear(),
3012                    Linear()
3013                )
3014                self.mods2 = Linear()
3015
3016            def forward(self, x):
3017                x = self.mods1(x)
3018                x = self.mods2(x)
3019                return x
3020
3021        model = M().eval()
3022        example_inputs = (torch.rand(5, 5),)
3023        qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]}
3024        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
3025        m(*example_inputs)
3026
3027        m = convert_fx(m)
3028        node_list = [
3029            ns.call_function(torch.quantize_per_tensor),
3030            ns.call_function(torch.ops.quantized.linear),
3031            ns.call_function(torch.ops.quantized.linear),
3032            ns.call_method('dequantize'),
3033            ns.call_function(torch.nn.functional.linear)
3034        ]
3035        self.checkGraphModuleNodes(m, expected_node_list=node_list)
3036        m(torch.rand(5, 5))
3037
3038    def test_preserve_attributes(self):
3039        class M(torch.nn.Module):
3040            def __init__(self) -> None:
3041                super().__init__()
3042                self.conv = torch.nn.Conv2d(1, 1, 1)
3043
3044            def forward(self, x):
3045                return self.conv(x)
3046
3047        m = M()
3048        m.eval()
3049        m.preserved_attr = 3
3050        prepare_custom_config_dict = {
3051            "preserved_attributes": ["preserved_attr"]
3052        }
3053        example_inputs = (torch.randn(1, 1, 1, 1),)
3054        m = prepare_fx(
3055            m,
3056            {"": default_qconfig},
3057            example_inputs=example_inputs,
3058            prepare_custom_config=prepare_custom_config_dict)
3059
3060        def assertAttrPreserved(m):
3061            self.assertTrue(hasattr(m, "preserved_attr"))
3062            self.assertEqual(m.preserved_attr, 3)
3063
3064        assertAttrPreserved(m)
3065        convert_custom_config_dict = {
3066            "preserved_attributes": ["preserved_attr"]
3067        }
3068        m = convert_fx(m, convert_custom_config=convert_custom_config_dict)
3069        assertAttrPreserved(m)
3070
3071    @skipIfNoFBGEMM
3072    def test_qat_and_script(self):
3073        model = LinearModelWithSubmodule().train()
3074        qengine = torch.backends.quantized.engine
3075        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)}
3076        x = torch.randn(5, 5)
3077        example_inputs = (x,)
3078        model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
3079
3080        # ensure scripting works
3081        scripted = torch.jit.script(model)
3082        # run one round to make sure model runs
3083        scripted(x)
3084        FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \
3085                   .run(scripted.graph)
3086
3087        # disable fake_quant and observer
3088        for epoch in range(3):
3089            if epoch == 1:
3090                scripted.apply(torch.ao.quantization.disable_observer)
3091            if epoch == 2:
3092                scripted.apply(torch.ao.quantization.disable_fake_quant)
3093
3094        # ensure the fake_quant and observer have been disabled.
3095        matches = ['.fake_quant_enabled', '.observer_enabled']
3096        for key, v in scripted.state_dict().items():
3097            if any(x in key for x in matches):
3098                self.assertEqual(v, torch.tensor([0], dtype=torch.int64))
3099
3100        # enable them back
3101        scripted.apply(torch.ao.quantization.enable_fake_quant)
3102        scripted.apply(torch.ao.quantization.enable_observer)
3103        for key, v in scripted.state_dict().items():
3104            if any(x in key for x in matches):
3105                self.assertEqual(v, torch.tensor([1], dtype=torch.int64))
3106
3107    @skipIfNoFBGEMM
3108    def test_save_observer_state_dict(self):
3109        orig = LinearModelWithSubmodule().eval()
3110        model = orig
3111        qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
3112        x = torch.randn(5, 5)
3113        model = prepare_fx(model, qconfig_dict, example_inputs=(x,))
3114
3115        # run it through input
3116        model(x)
3117        # save state_dict of model
3118        obs_dict = torch.ao.quantization.get_observer_state_dict(model)
3119
3120        quant = convert_fx(model)
3121
3122        b = io.BytesIO()
3123        torch.save(obs_dict, b)
3124
3125        # Load the stats into new model
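        # exercise both torch.load paths: weights_only=True (safe, weights-only
        # unpickling) and weights_only=False (full unpickling)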
3126        for weights_only in [True, False]:
3127            b.seek(0)
3128            model_2 = orig
3129            model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,))
3130
3131            loaded_dict = torch.load(b, weights_only=weights_only)
3132            torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict)
3133
3134            quant_2 = convert_fx(model_2)
3135
3136            # Verify that loaded state dict produces same results.
3137            self.assertEqual(quant(x), quant_2(x))
3138
3139    @skipIfNoFBGEMM
3140    def test_custom_module_class(self):
3141        class CustomModule(torch.nn.Module):
3142            def __init__(self) -> None:
3143                super().__init__()
3144                self.linear = torch.nn.Linear(3, 3)
3145
3146            def forward(self, x):
3147                return self.linear(x)
3148
3149        class ObservedCustomModule(torch.nn.Module):
3150            def __init__(self, linear):
3151                super().__init__()
3152                self.linear = linear
3153
3154            def forward(self, x):
3155                return self.linear(x)
3156
3157            @classmethod
3158            def from_float(cls, float_module):
3159                assert hasattr(float_module, 'qconfig')
3160                observed = cls(float_module.linear)
3161                observed.qconfig = float_module.qconfig
3162                return observed
3163
3164        class StaticQuantCustomModule(torch.nn.Module):
3165            def __init__(self, linear):
3166                super().__init__()
3167                self.linear = linear
3168
3169            def forward(self, x):
3170                return self.linear(x)
3171
3172            @classmethod
3173            def from_observed(cls, observed_module):
3174                assert hasattr(observed_module, 'qconfig')
3175                assert hasattr(observed_module, 'activation_post_process')
3176                observed_module.linear.activation_post_process = \
3177                    observed_module.activation_post_process
3178                quantized = cls(nnq.Linear.from_float(observed_module.linear))
3179                return quantized
3180
3181        class DynamicQuantCustomModule(torch.nn.Module):
3182            def __init__(self, linear):
3183                super().__init__()
3184                self.linear = linear
3185
3186            def forward(self, x):
3187                return self.linear(x)
3188
3189            @classmethod
3190            def from_observed(cls, observed_module):
3191                assert hasattr(observed_module, 'qconfig')
3192                observed_module.linear.qconfig = observed_module.qconfig
3193                quantized = cls(nnqd.Linear.from_float(observed_module.linear))
3194                return quantized
3195
3196        class M(torch.nn.Module):
3197            def __init__(self) -> None:
3198                super().__init__()
3199                self.linear = torch.nn.Linear(3, 3)
3200                self.custom = CustomModule()
3201
3202            def forward(self, x):
3203                x = self.linear(x)
3204                x = self.custom(x)
3205                return x
3206
3207        class RefM(torch.nn.Module):
3208            def __init__(self) -> None:
3209                super().__init__()
3210                self.linear1 = torch.nn.Linear(3, 3)
3211                self.linear2 = torch.nn.Linear(3, 3)
3212
3213            def forward(self, x):
3214                x = self.linear1(x)
3215                x = self.linear2(x)
3216                return x
3217
3218        # instantiate M and RefM and align the parameters
3219        original_m = M().eval()
3220        original_ref_m = RefM().eval()
3221        original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach())
3222        original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach())
3223        original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach())
3224        original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach())
3225
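        # qconfig observing activations over a ~16-bit range ([0, 2 ** 16]),
        # stored in a qint32 tensor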
3226        a16_qconfig = QConfig(
3227            activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536),
3228            weight=default_weight_observer,
3229        )
3230        test_configs = {
3231            "static": (default_qconfig, StaticQuantCustomModule, 3),
3232            "static_a16": (a16_qconfig, StaticQuantCustomModule, 3),
3233            "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
3234        }
3235
3236        for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]:
3237            key = _get_quant_type_to_str(quant_type)
3238            qconfig, quantized_module_class, num_observers = test_configs[key]
3239            qconfig_dict = {"": qconfig}
3240            if key == "static":
3241                prepare_custom_config_dict = {
3242                    "float_to_observed_custom_module_class": {
3243                        "static": {
3244                            CustomModule: ObservedCustomModule
3245                        }
3246                    }
3247                }
3248                convert_custom_config_dict = {
3249                    "observed_to_quantized_custom_module_class": {
3250                        "static": {
3251                            ObservedCustomModule: quantized_module_class
3252                        }
3253                    }
3254                }
3255            else:
3256                prepare_custom_config_dict = {
3257                    "non_traceable_module_class": [
3258                        CustomModule
3259                    ]
3260                }
3261                convert_custom_config_dict = {
3262                    "observed_to_quantized_custom_module_class": {
3263                        "dynamic": {
3264                            CustomModule: quantized_module_class
3265                        }
3266                    }
3267                }
3268
3269            example_inputs = (torch.randn(3, 3),)
3270            # check prepared model
3271            m = prepare_fx(
3272                copy.deepcopy(original_m),
3273                qconfig_dict,
3274                example_inputs=example_inputs,
3275                prepare_custom_config=prepare_custom_config_dict)
3276            # calibration
3277            m(*example_inputs)
3278            # all activation observers are inserted in the top level module
3279            count_check = {
3280                ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers
3281            }
3282            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
3283
3284            # check converted/quantized model
3285            m = convert_fx(
3286                m,
3287                convert_custom_config=convert_custom_config_dict)
3288            if quant_type == QuantType.STATIC:
3289                count_check = {
3290                    ns.call_function(torch.quantize_per_tensor) : 1,
3291                    ns.call_module(nnq.Linear) : 1,
3292                    ns.call_method('dequantize') : 1,
3293                }
3294                self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
3295            self.assertEqual(type(m.custom), quantized_module_class)
3296            res = m(*example_inputs)
3297
3298            # quantize the reference model
3299            ref_m = prepare_fx(
3300                copy.deepcopy(original_ref_m), qconfig_dict, example_inputs=example_inputs)
3301            ref_m(*example_inputs)
3302            ref_m = convert_fx(ref_m)
3303            ref_res = ref_m(*example_inputs)
3304            self.assertEqual(res, ref_res)
3305
3306    @skipIfNoFBGEMM
3307    def test_custom_module_class_input_has_multiple_users(self):
3308        """ Tests that the flow still works when the input of the custom module
3309        has multiple users
3310        """
3311        class CustomModule(torch.nn.Module):
3312            def __init__(self) -> None:
3313                super().__init__()
3314                self.linear = torch.nn.Linear(3, 3)
3315
3316            def forward(self, x):
3317                return self.linear(x)
3318
3319        class ObservedCustomModule(torch.nn.Module):
3320            def __init__(self, linear):
3321                super().__init__()
3322                self.linear = linear
3323
3324            def forward(self, x):
3325                return self.linear(x)
3326
3327            @classmethod
3328            def from_float(cls, float_module):
3329                assert hasattr(float_module, 'qconfig')
3330                observed = cls(float_module.linear)
3331                observed.qconfig = float_module.qconfig
3332                return observed
3333
3334        class StaticQuantCustomModule(torch.nn.Module):
3335            def __init__(self, linear):
3336                super().__init__()
3337                self.linear = linear
3338
3339            def forward(self, x):
3340                return self.linear(x)
3341
3342            @classmethod
3343            def from_observed(cls, observed_module):
3344                assert hasattr(observed_module, 'qconfig')
3345                assert hasattr(observed_module, 'activation_post_process')
3346                observed_module.linear.activation_post_process = \
3347                    observed_module.activation_post_process
3348                quantized = cls(nnq.Linear.from_float(observed_module.linear))
3349                return quantized
3350
3351        class M(torch.nn.Module):
3352            def __init__(self) -> None:
3353                super().__init__()
3354                self.linear = torch.nn.Linear(3, 3)
3355                self.custom = CustomModule()
3356
3357            def forward(self, x0):
3358                x1 = self.custom(x0)
3359                x2 = self.linear(x0)
3360                return x1 + x2
3361
3362        prepare_custom_config_dict = {
3363            "float_to_observed_custom_module_class": {
3364                "static": {
3365                    CustomModule: ObservedCustomModule
3366                }
3367            }
3368        }
3369        convert_custom_config_dict = {
3370            "observed_to_quantized_custom_module_class": {
3371                "static": {
3372                    ObservedCustomModule: StaticQuantCustomModule
3373                }
3374            }
3375        }
3376        m = M().eval()
3377        example_inputs = (torch.randn(3, 3),)
3378        m = prepare_fx(
3379            m,
3380            {"": default_qconfig},
3381            example_inputs=example_inputs,
3382            prepare_custom_config=prepare_custom_config_dict)
3383        # make sure it works
3384        m = convert_fx(
3385            m,
3386            convert_custom_config=convert_custom_config_dict)
3387        # make sure it runs
3388        m(*example_inputs)
3389
3390    @skipIfNoFBGEMM
3391    def test_custom_module_class_input_has_duplicate_nodes(self):
3392        """ Tests that the flow still works when the graph has
3393        multiple nodes with the same custom module target.
3394        """
3395        class CustomModule(torch.nn.Module):
3396            def __init__(self) -> None:
3397                super().__init__()
3398                self.linear = torch.nn.Linear(3, 3)
3399
3400            def forward(self, x):
3401                return self.linear(x)
3402
3403        class ObservedCustomModule(torch.nn.Module):
3404            def __init__(self, linear):
3405                super().__init__()
3406                self.linear = linear
3407
3408            def forward(self, x):
3409                return self.linear(x)
3410
3411            @classmethod
3412            def from_float(cls, float_module):
3413                assert hasattr(float_module, 'qconfig')
3414                observed = cls(float_module.linear)
3415                observed.qconfig = float_module.qconfig
3416                return observed
3417
3418        class StaticQuantCustomModule(torch.nn.Module):
3419            def __init__(self, linear):
3420                super().__init__()
3421                self.linear = linear
3422
3423            def forward(self, x):
3424                return self.linear(x)
3425
3426            @classmethod
3427            def from_observed(cls, observed_module):
3428                assert hasattr(observed_module, 'qconfig')
3429                assert hasattr(observed_module, 'activation_post_process')
3430                observed_module.linear.activation_post_process = \
3431                    observed_module.activation_post_process
3432                quantized = cls(nnq.Linear.from_float(observed_module.linear))
3433                return quantized
3434
3435        class M(torch.nn.Module):
3436            def __init__(self) -> None:
3437                super().__init__()
3438                self.custom = CustomModule()
3439
3440            def forward(self, x0):
3441                x1 = self.custom(x0)
3442                x2 = self.custom(x0)
3443                return x1 + x2
3444
3445        prepare_custom_config_dict = {
3446            "float_to_observed_custom_module_class": {
3447                "static": {
3448                    CustomModule: ObservedCustomModule
3449                }
3450            }
3451        }
3452        convert_custom_config_dict = {
3453            "observed_to_quantized_custom_module_class": {
3454                "static": {
3455                    ObservedCustomModule: StaticQuantCustomModule
3456                }
3457            }
3458        }
3459        m = M().eval()
3460        example_inputs = (torch.randn(3, 3),)
3461        m = prepare_fx(
3462            m,
3463            {"": default_qconfig},
3464            example_inputs=example_inputs,
3465            prepare_custom_config=prepare_custom_config_dict)
3466        # make sure it works
3467        m = convert_fx(
3468            m,
3469            convert_custom_config=convert_custom_config_dict)
3470        # make sure it runs
3471        m(*example_inputs)
3472
3473    @skipIfNoFBGEMM
3474    def test_non_traceable_module(self):
3475        class NonTraceable(torch.nn.Module):
3476            def forward(self, x):
3477                for k in x.keys():
3478                    print(x[k])
3479                return x
3480
3481        class NonTraceable2(torch.nn.Module):
3482            def forward(self, x):
3483                # data dependent control flow is not traceable
3484                for i in x:
3485                    print(i)
3486                return x
3487
3488        class M(torch.nn.Module):
3489            def __init__(self) -> None:
3490                super().__init__()
3491                self.m1 = NonTraceable()
3492                self.m2 = NonTraceable2()
3493
3494            def forward(self, x):
3495                x = self.m1(x)
3496                x = self.m2(x)
3497                return x
3498
3499        m = M().eval()
3500        qconfig_dict = {"": default_qconfig}
3501        prepare_custom_config_dict = {
3502            "non_traceable_module_name": [
3503                "m1"
3504            ],
3505            "non_traceable_module_class": [
3506                NonTraceable2
3507            ]
3508        }
3509        m = prepare_fx(
3510            m, qconfig_dict,
3511            example_inputs=({"key": torch.randn(1)},),
3512            prepare_custom_config=prepare_custom_config_dict)
3513
3514        node_occurrence = {
3515            ns.call_module(NonTraceable) : 1,
3516            ns.call_module(NonTraceable2) : 1,
3517        }
3518        # make sure these modules are not traced
3519        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
3520
3521    def test_prepared_model_deepcopy(self):
3522        """Ensures that copy.deepcopy works correctly on a prepared model.
3523        """
3524        class M(torch.nn.Module):
3525            def __init__(self) -> None:
3526                super().__init__()
3527                self.conv = torch.nn.Conv2d(1, 1, 1)
3528                self._foobar = 'foobar'
3529                self.foobar2 = 'foobar2'
3530
3531            def forward(self, x):
3532                x = self.conv(x)
3533                return x
3534
3535        m = M()
3536        m.eval()
3537        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
3538        example_inputs = (torch.randn(4, 1, 4, 4),)
3539        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
3540        # calibrate
3541        prepared(*example_inputs)
3542        # copy
3543        prepared_copy = copy.deepcopy(prepared)
3544        # quantize, should run with no errors
3545        quantized = convert_fx(prepared_copy)
3546
3547    def test_quantized_model_type(self):
3548        """ Test that state_dict and deepcopy work properly on the quantized model
3549        """
3550        class M(torch.nn.Module):
3551            def __init__(self) -> None:
3552                super().__init__()
3553                self.linear = torch.nn.Linear(5, 5)
3554
3555            def forward(self, x):
3556                return self.linear(x)
3557
3558        example_inputs = (torch.rand(8, 5),)
3559        m = M().eval()
3560        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
3561        m = convert_fx(m)
3562        # test deepcopy
3563        m_copy = copy.deepcopy(m)
3564        self.assertEqual(m_copy(*example_inputs), m(*example_inputs))
3565
3566        # test state_dict
3567        state_dict = m.state_dict()
3568        m_new = M().eval()
3569        m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs)
3570        m_new = convert_fx(m_new)
3571        m_new.load_state_dict(state_dict)
3572        self.assertEqual(m_new(*example_inputs), m(*example_inputs))
3573
3574    def test_dequantize(self):
3575        r""" Test to make sure dequantize nodes are placed before
3576        non-quantizable nodes
3577        """
3578        class M(torch.nn.Module):
3579            def __init__(self) -> None:
3580                super().__init__()
3581                self.conv = torch.nn.Conv2d(1, 1, 1)
3582                self.act = torch.nn.GELU()
3583
3584            def forward(self, x):
3585                x = self.conv(x)
3586                return self.act(x)
3587
3588        data = torch.rand(5, 1, 3, 3, dtype=torch.float)
3589        for quant_type in self.static_quant_types:
3590            node_list = [
3591                ns.call_module(nnq.Conv2d),
3592                ns.call_method("dequantize"),
3593                ns.call_module(nn.GELU),
3594            ]
3595            self.checkGraphModeFxOp(
3596                M().eval(), (data,), quant_type, expected_node_list=node_list)
3597
3598    def test_sequential(self):
3599        class M(torch.nn.Module):
3600            def __init__(self) -> None:
3601                super().__init__()
3602                self.convs = torch.nn.Sequential(
3603                    torch.nn.Conv2d(1, 1, 1),
3604                    torch.nn.Conv2d(1, 1, 1)
3605                )
3606
3607            def forward(self, x):
3608                x = self.convs(x)
3609                return x
3610
3611        data = torch.rand(5, 1, 3, 3, dtype=torch.float)
3612        for quant_type in self.static_quant_types:
3613            node_list = [
3614                ns.call_module(nnq.Conv2d),
3615                ns.call_module(nnq.Conv2d),
3616            ]
3617            self.checkGraphModeFxOp(
3618                M().eval(), (data,), quant_type, expected_node_list=node_list)
3619
3620    def _test_quantized_inputs_outputs(
3621            self, prepare_custom_config_dict, prepare_count_check,
3622            convert_count_check):
3623        """
3624        Test the option to have inputs and outputs of the graph quantized
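
        prepare_count_check and convert_count_check map expected node patterns
        to the number of times they should occur in the prepared and converted
        graphs, respectively.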
3625        """
3626        class M(torch.nn.Module):
3627            def __init__(self) -> None:
3628                super().__init__()
3629                self.conv1 = torch.nn.Conv2d(1, 1, 1)
3630                self.conv2 = torch.nn.Conv2d(1, 1, 1)
3631
3632            def forward(self, x):
3633                x = self.conv1(x)
3634                x = self.conv2(x)
3635                return x
3636
3637        # quantized input, quantized output
3638        m = M()
3639        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
3640        example_inputs = (torch.randn(1, 1, 4, 4),)
3641        m.eval()
3642        mp = torch.ao.quantization.quantize_fx.prepare_fx(
3643            m, qconfig_dict,
3644            example_inputs=example_inputs,
3645            prepare_custom_config=prepare_custom_config_dict)
3646        self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
3647        mp(*example_inputs)
3648        mq = torch.ao.quantization.quantize_fx.convert_fx(mp)
3649        self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)
3650
3651    def test_quantized_input_quantized_output(self):
3652        prepare_custom_config_dict = {
3653            'input_quantized_idxs': [0], 'output_quantized_idxs': [0]}
3654        prepare_count_check = {
3655            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
3656        }
3657        convert_count_check = {
3658            ns.call_function(torch.quantize_per_tensor): 0,
3659            ns.call_method('dequantize'): 0,
3660        }
3661        self._test_quantized_inputs_outputs(
3662            prepare_custom_config_dict, prepare_count_check, convert_count_check)
3663
3664    def test_fp32_input_quantized_output(self):
3665        prepare_custom_config_dict = {
3666            'output_quantized_idxs': [0]}
3667        prepare_count_check = {
3668            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
3669        }
3670        convert_count_check = {
3671            ns.call_function(torch.quantize_per_tensor): 1,
3672            ns.call_method('dequantize'): 0,
3673        }
3674        self._test_quantized_inputs_outputs(
3675            prepare_custom_config_dict, prepare_count_check, convert_count_check)
3676
3677    def test_quantized_input_fp32_output(self):
3678        prepare_custom_config_dict = {
3679            'input_quantized_idxs': [0]}
3680        prepare_count_check = {
3681            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
3682        }
3683        convert_count_check = {
3684            ns.call_function(torch.quantize_per_tensor): 0,
3685            ns.call_method('dequantize'): 1,
3686        }
3687        self._test_quantized_inputs_outputs(
3688            prepare_custom_config_dict, prepare_count_check, convert_count_check)
3689
3690    def test_fp32_input_fp32_output(self):
3691        prepare_custom_config_dict = {}
3692        prepare_count_check = {
3693            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
3694        }
3695        convert_count_check = {
3696            ns.call_function(torch.quantize_per_tensor): 1,
3697            ns.call_method('dequantize'): 1,
3698        }
3699        self._test_quantized_inputs_outputs(
3700            prepare_custom_config_dict, prepare_count_check, convert_count_check)
3701
3702    @skipIfNoFBGEMM
3703    def test_convtranspose_per_channel_fails_early(self):
3704        r"""
3705        Verifies that attempting to quantize a ConvTranspose module with per-channel
3706        weight observers fails in the prepare step, as opposed to the convert step.
3707        """
3708        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
3709        m.eval()
3710        qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')}
3711        with self.assertRaises(AssertionError) as context:
3712            mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
3713        self.assertTrue(
3714            str(context.exception) ==
3715            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')
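        # per-tensor weight observers do work for ConvTranspose, e.g. the
        # following sketch would prepare without raising:
        #
        #     mp = prepare_fx(m, {'': default_qconfig},
        #                     example_inputs=(torch.randn(1, 1, 1, 1),))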
3716
3717    @skipIfNoFBGEMM
3718    def test_qparams_buffers(self):
3719        class Linear(torch.nn.Module):
3720            def __init__(self) -> None:
3721                super().__init__()
3722                self.w = torch.ones(5, 5)
3723                self.b = torch.zeros(5)
3724
3725            def forward(self, x):
3726                return torch.nn.functional.linear(x, self.w, self.b)
3727
3728        class M(torch.nn.Module):
3729            def __init__(self) -> None:
3730                super().__init__()
3731                self.mods1 = torch.nn.Sequential(
3732                    Linear(),
3733                    Linear()
3734                )
3735                self.mods2 = Linear()
3736
3737            def forward(self, x):
3738                x = self.mods1(x)
3739                x = self.mods2(x)
3740                return x
3741
3742        model = M().eval()
3743        qconfig_dict = {"": default_qconfig}
3744        example_inputs = (torch.rand(5, 5),)
3745        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
3746        m(*example_inputs)
3747        m = convert_fx(m)
3748        keys = m.state_dict().keys()
3749        quant_scale_count = quant_zero_point_count = scale_count = zero_point_count = 0
3750        for k in keys:
3751            if 'input_scale' in k:
3752                quant_scale_count += 1
3753            elif 'input_zero_point' in k:
3754                quant_zero_point_count += 1
3755            elif 'scale' in k:
3756                scale_count += 1
3757            elif 'zero_point' in k:
3758                zero_point_count += 1
3759
3760        # Expect each quantized linear op to have a scale and zero point
3761        self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
3762        self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
3763        m(*example_inputs)
3764        # ensure it is scriptable
3765        scripted = torch.jit.script(m)
3766        scripted_keys = scripted.state_dict().keys()
3767        scripted.mods1_0_packed_weight_0 = m.state_dict()["mods1_0_packed_weight_0"]
3768        non_packed_weight_keys = [key for key in keys if "_packed_weight" not in key]
3769        self.assertTrue(
3770            set(scripted_keys) == set(non_packed_weight_keys),
3771            "Expected the scripted model to preserve the state_dict for non-packed weight attributes")
3772        # TODO: probably don't want to hardcode the attribute names, since they are generated
3773        for attr_name in [
3774                "mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
3775                "mods1_0_scale_1", "mods1_0_zero_point_1",
3776                "mods1_1_scale_1", "mods1_1_zero_point_1",
3777                "mods2_scale_1", "mods2_zero_point_1"]:
3778            self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
3779
3780    @skipIfNoFBGEMM
3781    def test_packed_weight_fused_op(self):
3782        class Linear(torch.nn.Module):
3783            def __init__(self) -> None:
3784                super().__init__()
3785                self.w = torch.ones(5, 5)
3786                self.b = torch.zeros(5)
3787
3788            def forward(self, x):
3789                return F.linear(x, self.w, self.b)
3790
3791        class M(torch.nn.Module):
3792            def __init__(self) -> None:
3793                super().__init__()
3794                self.mods1 = torch.nn.Sequential(
3795                    Linear(),
3796                    Linear()
3797                )
3798                self.mods2 = Linear()
3799                self.relu = F.relu
3800
3801            def forward(self, x):
3802                x = self.mods1(x)
3803                x = self.mods2(x)
3804                x = self.relu(x)
3805                return x
3806
3807        model = M().eval()
3808        example_inputs = (torch.rand(5, 5),)
3809        qconfig_dict = {"": default_qconfig}
3810        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
3811        m(*example_inputs)
3812        m = convert_fx(m)
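        # each functional linear should be lowered to a quantized op whose
        # prepacked weight is registered as an attribute on the root module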
3813        assert hasattr(m, "mods1_0_packed_weight_0")
3814        assert hasattr(m, "mods1_1_packed_weight_0")
3815        assert hasattr(m, "mods2_packed_weight_0")
3816
3817    @skipIfNoFBGEMM
3818    def test_mul_add_fp16_config(self):
3819        with override_quantized_engine('fbgemm'):
3820            class Linear(torch.nn.Module):
3821                def __init__(self) -> None:
3822                    super().__init__()
3823                    self.w = torch.ones(5, 5)
3824                    self.b = torch.zeros(5)
3825
3826                def forward(self, x):
3827                    return torch.nn.functional.linear(x, self.w, self.b)
3828
3829            class M(torch.nn.Module):
3830                def __init__(self) -> None:
3831                    super().__init__()
3832                    self.mods1 = torch.nn.Sequential(
3833                        Linear(),
3834                        Linear()
3835                    )
3836                    self.mods2 = Linear()
3837
3838                def forward(self, x):
3839                    x = x * 5
3840                    x = x + 5
3841                    x = self.mods1(x)
3842                    x = self.mods2(x)
3843                    return x
3844            model = M().eval()
3845            qconfig_dict = {"": float16_dynamic_qconfig}
3846            example_inputs = (torch.rand(5, 5),)
3847            m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
3848            m = convert_fx(m)
3849            # make sure it runs
3850            m(*example_inputs)
3851
3852    def test_getattr_with_nontensor_result(self):
3853        """
3854        Verifies that binary ops get quantized correctly if some
3855        of the args are nodes but not Tensors, such as an `x.ndim`
3856        pattern.
3857        """
3858        class M1(torch.nn.Module):
3859            def forward(self, x):
3860                dims = x.ndim
3861                dims_sub = dims - 1
3862                dims_sub2 = dims_sub - 1
3863                x = torch.add(x, dims_sub2)
3864                return x
3865
3866        class M2(torch.nn.Module):
3867            def forward(self, x):
3868                dims = x.ndim
3869                dims_sub = dims - 2
3870                mul = [1] * dims_sub
3871                dims_list = [-1, x.size(1)] + mul
3872                x = x.view(dims_list)
3873                return x
3874
3875        class M3(torch.nn.Module):
3876            def forward(self, x):
3877                shape = x.shape
3878                x = x.view(shape)
3879                return x
3880
3881        for cls in (M1, M2, M3):
3882            m = cls().eval()
3883            example_inputs = (torch.rand(4, 4, 4, 4),)
3884            m(*example_inputs)
3885            qconfig_dict = {'': torch.ao.quantization.default_qconfig}
3886            mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
3887            mp(torch.rand(4, 4, 4, 4))
3888            mc = convert_fx(mp)
3889
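    # Shared model for the *_nontensor_args_not_observed tests below: `func`
    # is a callable taking (x, y, z), where y and/or z are non-tensor args
    # that should not receive observers.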
3890    class _NonReferenceTestModel(nn.Module):
3891        def __init__(self, func, lin_in, lin_out):
3892            super().__init__()
3893            self.conv1 = nn.Conv2d(3, 6, 5)
3894            self.pool = nn.MaxPool2d(2, 2)
3895            self.lin = nn.Linear(lin_in, lin_out)
3896            self.func = func
3897
3898        def forward(self, x, y, z):
3899            x = self.pool(F.relu(self.conv1(x)))
3900            x = torch.flatten(x, 1)
3901            x = self.func(x, y, z)
3902            x = self.lin(x)
3903            return x
3904
3905    # This function looks at the node specified by the NodeInfo in the key of
3906    # node_info_to_non_tensor_args and checks that the args at specified indices
3907    # are not observed (since they are non-tensors). If the args at those indices
3908    # are a tuple/list (which do not show up as nodes) the function checks the
3909    # individual elements of the tuple/list recursively.
3910    def _check_not_observed(self, model, node_info_to_non_tensor_args):
3911
3912        # this is a helper function (for easier recursion) that checks whether
3913        # arg_node is observed
3914        def _check_node_not_observed(model, arg_node, node):
3915            if isinstance(arg_node, (tuple, list)):
3916                for new_node in arg_node:
3917                    _check_node_not_observed(model, new_node, node)
3918            elif arg_node.op == "call_module":
3919                self.assertTrue(
3920                    not _is_activation_post_process(getattr(model, arg_node.target)),
3921                    f"Arg: {arg_node} of node: {node} is observed but is not a float tensor",
3922                )
3923
3924        for node in model.graph.nodes:
3925            indices = node_info_to_non_tensor_args.get(
3926                NodeInfo(node.op, node.target), []
3927            )
3928            for index in indices:
3929                if index < len(node.args):
3930                    arg_node = node.args[index]
3931                    _check_node_not_observed(model, arg_node, node)
3932
3933    # This test checks that the model gets prepared correctly, does not have
3934    # observers on specific ops (see _check_not_observed), and that it runs
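    # Typical usage (sketch): self._test_dtype_propagation(
    #     model, {NodeInfo("call_method", "view"): [2]}, *example_args)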
3935    def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args):
3936        model.eval()
3937        qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}
3938        prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args))
3939        self._check_not_observed(prepared_model, node_info_to_non_tensor_args)
3940        prepared_model(*args)
3941
3942    def test_masked_fill_nontensor_args_not_observed(self):
3943        def func(x, y, z):
3944            return x.masked_fill(y, z)
3945
3946        model = self._NonReferenceTestModel(func, 1176, 1)
3947        args = [torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1]
3948        node_info_to_non_tensor_args = {NodeInfo("call_method", "masked_fill"): [1, 2]}
3949        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3950
3951    def test_permute_nontensor_args_not_observed(self):
3952        def func(x, y, z):
3953            return x.permute(y, z)
3954
3955        model = self._NonReferenceTestModel(func, 1176, 1)
3956        args = [torch.randn(5, 3, 32, 32), 0, 1]
3957        node_info_to_non_tensor_args = {NodeInfo("call_method", "permute"): [1, 2]}
3958        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3959
3960    def test_repeat_nontensor_args_not_observed(self):
3961        def func(x, y, z):
3962            return x.repeat(y, z)
3963
3964        model = self._NonReferenceTestModel(func, 1176, 1)
3965        args = [torch.randn(5, 3, 32, 32), 2, 1]
3966        node_info_to_non_tensor_args = {NodeInfo("call_method", "repeat"): [1, 2]}
3967        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3968
3969    def test_reshape_nontensor_args_not_observed(self):
3970        def func(x, y, z):
3971            return x.reshape(-1, y)
3972
3973        model = self._NonReferenceTestModel(func, 5, 1)
3974        args = [torch.randn(5, 3, 32, 32), 5, None]
3975        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [2]}
3976        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3977
3978    def test_size_nontensor_args_not_observed(self):
3979        def func(x, y, z):
3980            return x.reshape((-1, x.size(y)))
3981
3982        model = self._NonReferenceTestModel(func, 5, 1)
3983        args = [torch.randn(5, 3, 32, 32), 0, None]
3984        node_info_to_non_tensor_args = {NodeInfo("call_method", "size"): [1]}
3985        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3986
3987    def test_transpose_nontensor_args_not_observed(self):
3988        def func(x, y, z):
3989            return x.transpose(y, z)
3990
3991        model = self._NonReferenceTestModel(func, 5, 1)
3992        args = [torch.randn(5, 3, 32, 32), 0, 1]
3993        node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]}
3994        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
3995
3996    def test_torch_transpose_nontensor_args_not_observed(self):
3997        # TODO: make torch.transpose traceable by fx when using
3998        # variable nontensor arguments
3999        # func = lambda x, y, z: torch.transpose(x, y, z) # error
4000        def func(x, y, z):
4001            return torch.transpose(x, 0, 1)
4002
4003        model = self._NonReferenceTestModel(func, 5, 1)
4004        node_info_to_non_tensor_args = {
4005            NodeInfo("call_method", torch.transpose): [1, 2]
4006        }
4007        args = [torch.randn(5, 3, 32, 32), 0, 1]
4008        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4009
4010    def test_unsqueeze_nontensor_args_not_observed(self):
4011        def func(x, y, z):
4012            return x.unsqueeze(y)
4013
4014        model = self._NonReferenceTestModel(func, 1176, 1)
4015        args = [torch.randn(5, 3, 32, 32), 1, None]
4016        node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze"): [1]}
4017        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4018
4019    def test_unsqueeze__nontensor_args_not_observed(self):
4020        def func(x, y, z):
4021            return x.unsqueeze_(y)
4022
4023        model = self._NonReferenceTestModel(func, 1176, 1)
4024        args = [torch.randn(5, 3, 32, 32), 1, None]
4025        node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze_"): [1]}
4026        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4027
4028    def test_torch_unsqueeze_nontensor_args_not_observed(self):
4029        # TODO: make torch.unsqueeze traceable by fx when using
4030        # variable nontensor arguments
4031        # func = lambda x, y, z: torch.unsqueeze(x, y) # error
4032        def func(x, y, z):
4033            return torch.unsqueeze(x, 1)
4034
4035        model = self._NonReferenceTestModel(func, 1176, 1)
4036        args = [torch.randn(5, 3, 32, 32), 1, None]
4037        node_info_to_non_tensor_args = {NodeInfo("call_method", torch.unsqueeze): [1]}
4038        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4039
4040    def test_view_nontensor_args_not_observed(self):
4041        def func(x, y, z):
4042            return x.view(-1, y)
4043
4044        model = self._NonReferenceTestModel(func, 5, 1)
4045        args = [torch.randn(5, 3, 32, 32), 5, None]
4046        node_info_to_non_tensor_args = {NodeInfo("call_method", "view"): [2]}
4047        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4048
4049    def test_propagate_dtypes_for_known_nodes_list_args(self):
4050        def func(x, y, z):
4051            return x.reshape(y)
4052
4053        model = self._NonReferenceTestModel(func, 5, 1)
4054        args = [torch.randn(5, 3, 32, 32), [-1, 5], None]
4055        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4056        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4057
4058    def test_propagate_dtypes_for_known_nodes_split_list_args(self):
4059        def func(x, y, z):
4060            return x.reshape([y, z])
4061
4062        model = self._NonReferenceTestModel(func, 5, 1)
4063        args = [torch.randn(5, 3, 32, 32), -1, 5]
4064        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4065        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4066
4067    def test_propagate_dtypes_for_known_nodes_tuple_args(self):
4068        def func(x, y, z):
4069            return x.reshape(y)
4070
4071        model = self._NonReferenceTestModel(func, 5, 1)
4072        args = [torch.randn(5, 3, 32, 32), (-1, 5), None]
4073        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4074        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4075
4076    def test_propagate_dtypes_for_known_nodes_split_tuple_args(self):
4077        def func(x, y, z):
4078            return x.reshape((y, z))
4079
4080        model = self._NonReferenceTestModel(func, 5, 1)
4081        args = [torch.randn(5, 3, 32, 32), -1, 5]
4082        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4083        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4084
4085    def test_propagate_dtypes_for_known_nodes_dict_args(self):
4086        def func(x, y, z):
4087            return x.transpose(y["first"], y["second"])
4088
4089        model = self._NonReferenceTestModel(func, 5, 1)
4090        args = [torch.randn(5, 3, 32, 32), {"first": 0, "second": 1}, None]
4091        node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]}
4092        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4093
4094    def test_propagate_dtypes_for_known_nodes_dict_tuple_args(self):
4095        class ReshapeModule(nn.Module):
4096            def forward(self, x, y, z):
4097                return x.reshape(y["shape"])
4098
4099        model = self._NonReferenceTestModel(ReshapeModule(), 5, 1)
4100        args = [torch.randn(5, 3, 32, 32), {"shape": (-1, 5)}, None]
4101        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4102        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4103
4104    def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self):
4105        def func(x, y, z):
4106            return x.reshape((y["first"], y["second"]))
4107
4108        model = self._NonReferenceTestModel(func, 5, 1)
4109        args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None]
4110        node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]}
4111        self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args)
4112
4113    def test_assert_on_size_after_quant_layer(self):
4114        """
4115        Verifies that calculating the size of a quantized tensor works
4116        correctly in quantization passes.
4117        """
4118        class M(torch.nn.Module):
4119            def __init__(self) -> None:
4120                super().__init__()
4121                self.conv1 = nn.Conv2d(1, 1, 1)
4122
4123            def forward(self, x):
4124                x = self.conv1(x)
4125                torch._assert(x.size(1) == 1, 'foobar')
4126                return x
4127
4128        m = M().eval()
4129        example_inputs = (torch.rand(4, 1, 4, 4),)
4130        m(*example_inputs)
4131        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
4132        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
4133        mp(*example_inputs)
4134        mc = convert_fx(mp)
4135        mc(*example_inputs)
4136
4137    def test_fp32_sum(self):
4138        """
4139        Verifies that fp32 sum works correctly if it's before or after
4140        quantized layers.
4141        """
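        # Roughly speaking, torch.stack/torch.sum have no quantized kernels in
        # this flow, so convert has to insert a dequantize before them (M1) and a
        # fresh quantize when a quantized conv consumes the result again (M2);
        # running both models exercises that wiring.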
4142        class M1(torch.nn.Module):
4143            def __init__(self) -> None:
4144                super().__init__()
4145                self.conv1 = nn.Conv2d(1, 1, 1)
4146
4147            def forward(self, x):
4148                x = self.conv1(x)
4149                x = torch.stack([x])
4150                x = torch.sum(x)
4151                return x
4152
4153        class M2(torch.nn.Module):
4154            def __init__(self) -> None:
4155                super().__init__()
4156                self.conv1 = nn.Conv2d(1, 1, 1)
4157                self.conv2 = nn.Conv2d(1, 1, 1)
4158
4159            def forward(self, x):
4160                x = self.conv1(x)
4161                x1 = torch.stack([x])
4162                x1 = torch.sum(x1, dim=0)
4163                x2 = self.conv2(x1)
4164                return x2
4165
4166        for cls in (M1, M2):
4167            m = cls().eval()
4168            example_inputs = (torch.rand(4, 1, 4, 4),)
4169            m(*example_inputs)
4170            qconfig_dict = {'': torch.ao.quantization.default_qconfig}
4171            mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
4172            mp(*example_inputs)
4173            mc = convert_fx(mp)
4174            mc(*example_inputs)
4175
4176    def test_fusion_pattern_unquantized(self):
4177        """
4178        Ensure that leaving a possible fusion pattern of multiple nodes
4179        unquantized runs through the APIs without errors.
4180        """
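        # The add + relu sequence in Child would normally match a fusion pattern;
        # assigning ('child', None) in the qconfig_dict below leaves it in fp32,
        # and prepare/convert must not trip over the partially matched pattern.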
4181        class Child(torch.nn.Module):
4182            def __init__(self) -> None:
4183                super().__init__()
4184                self.relu = nn.ReLU()
4185
4186            def forward(self, x):
4187                x = torch.add(x, 1.0)
4188                x = torch.nn.functional.relu(x)
4189                return x
4190
4191        class Parent(torch.nn.Module):
4192            def __init__(self) -> None:
4193                super().__init__()
4194                self.child = Child()
4195                self.conv = nn.Conv2d(1, 1, 1)
4196
4197            def forward(self, x):
4198                x = self.child(x)
4199                x = self.conv(x)
4200                return x
4201
4202        m = Parent().eval()
4203        qconfig_dict = {
4204            '': torch.ao.quantization.default_qconfig,
4205            'module_name': [
4206                ('child', None),
4207            ],
4208        }
4209        example_inputs = (torch.rand(1, 1, 1, 1),)
4210        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
4211        mp(*example_inputs)
4212        mc = convert_fx(mp)
4213
4214    def test_state_dict(self):
4215        """ Make sure packed params appear in state_dict
4216        """
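        # A minimal sketch (hedged, standalone illustration assuming the fbgemm
        # backend) of the packed-weight round trip this test relies on:
        #
        #   import torch
        #   qw = torch.quantize_per_tensor(torch.randn(4, 30), 0.1, 0, torch.qint8)
        #   packed = torch.ops.quantized.linear_prepack(qw, torch.randn(4))
        #   w, b = torch.ops.quantized.linear_unpack(packed)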
4217
4218        # test linear packed weight
4219        class M1(torch.nn.Module):
4220            def __init__(self) -> None:
4221                super().__init__()
4222                self.w = torch.rand(4, 30)
4223                self.b = torch.rand(4)
4224
4225            def forward(self, x):
4226                return F.linear(x, self.w, self.b)
4227
4228        m = M1().eval()
4229        qconfig_dict = {"": default_qconfig}
4230        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),))
4231        m = convert_fx(m)
4232        state_dict = m.state_dict()
4233        self.assertTrue("_packed_weight_0" in state_dict)
4234
4235        # test conv packed weight
4236        class M2(torch.nn.Module):
4237            def __init__(self) -> None:
4238                super().__init__()
4239                self.w = torch.rand(3, 3, 3, 3)
4240                self.b = torch.rand(3)
4241                self.stride = (1, 1)
4242                self.padding = (0, 0)
4243                self.dilation = (1, 1)
4244                self.groups = 1
4245
4246            def forward(self, x):
4247                return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
4248
4249        m = M2().eval()
4250        qconfig_dict = {"": default_qconfig}
4251        m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),))
4252        m = convert_fx(m)
4253        state_dict = m.state_dict()
4254        self.assertTrue("_packed_weight_0" in state_dict)
4255
4256        # test load
4257        ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"])
4258        data = torch.rand(1, 3, 5, 5)
4259        ref_res = m(data)
4260        m = M2().eval()
4261        m = prepare_fx(m, qconfig_dict, (data,))
4262        m = convert_fx(m)
4263        res = m(data)
4264        weight, bias = m._packed_weight_0.unpack()
4265        # check that random model weight/bias does not match ref weight/bias
4266        self.assertNotEqual(weight, ref_weight)
4267        self.assertNotEqual(bias, ref_bias)
4268        self.assertNotEqual(res, ref_res)
4269        m.load_state_dict(state_dict)
4270
4271        def checkModel(m, data, ref_weight, ref_bias, ref_res):
4272            res = m(data)
4273            weight, bias = m._packed_weight_0.unpack()
4274            # check that weight/bias matches after load the state_dict
4275            # check that weight/bias matches after loading the state_dict
4276            self.assertEqual(bias, ref_bias)
4277            self.assertEqual(res, ref_res)
4278
4279        checkModel(m, data, ref_weight, ref_bias, ref_res)
4280
4281        # Test save to disk and load back
4282        m = M2().eval()
4283        m = prepare_fx(m, qconfig_dict, example_inputs=(data,))
4284        m = convert_fx(m)
4285        m.load_state_dict(state_dict)
4286        with TemporaryFileName() as fname:
4287            torch.save(m.state_dict(), fname)
4288            # weights_only=False as the state_dict contains scripted packed params
4289            m.load_state_dict(torch.load(fname, weights_only=False))
4290
4291        checkModel(m, data, ref_weight, ref_bias, ref_res)
4292
4293    @skipIfNoFBGEMM
4294    def test_preserve_qconfig(self):
4295        """
4296        Test to make sure the temporary config option to preserve qconfig attributes
4297        in the model works
4298        """
4299        with override_quantized_engine('fbgemm'):
4300            class Linear(torch.nn.Module):
4301                def __init__(self) -> None:
4302                    super().__init__()
4303                    self.w = torch.ones(5, 5)
4304                    self.b = torch.zeros(5)
4305
4306                def forward(self, x):
4307                    return torch.nn.functional.linear(x, self.w, self.b)
4308
4309            class M(torch.nn.Module):
4310                def __init__(self) -> None:
4311                    super().__init__()
4312                    self.mods1 = torch.nn.Sequential(
4313                        Linear(),
4314                        Linear()
4315                    )
4316                    self.mods2 = torch.nn.Sigmoid()
4317
4318                def forward(self, x):
4319                    x = self.mods1(x)
4320                    x = self.mods2(x)
4321                    return x
4322
4323            model = M().eval()
4324            qconfig_dict = {
4325                "object_type": [
4326                    (torch.nn.functional.linear, float16_dynamic_qconfig),
4327                ],
4328            }
4329            example_inputs = (torch.rand(5, 5),)
4330            m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
4331            m(*example_inputs)
4332            m = convert_fx(m, _remove_qconfig=False)
4333
4334            self.assertTrue(hasattr(m.mods2, 'qconfig'))
4335
4336    def test_not_used(self):
4337        """ Test quantizing a value that is not used"""
4338
4339        class M(torch.nn.Module):
4340            def forward(self, x):
4341                x = x + x
4342                x.sigmoid_()
4343                return x
4344
4345        m = M().eval()
4346        qconfig_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig)
4347        # make sure quantization runs
4348        m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1),))
4349        m = convert_fx(m)
4350
4351    def test_qparams_fqn(self):
4352        """ Test that the FQN of input_scale/zero_point is set
4353        to that of the first linear use. """
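        # The generated attribute names checked below (e.g. "mods1_0_input_scale_0")
        # appear to be derived from the FQN of the first use ("mods1.0") with dots
        # replaced by underscores; this is an implementation detail, hence the
        # hardcoding caveat further down.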
4354        class Linear(torch.nn.Module):
4355            def __init__(self) -> None:
4356                super().__init__()
4357                self.w = torch.ones(5, 5)
4358                self.b = torch.zeros(5)
4359
4360            def forward(self, x):
4361                return torch.nn.functional.linear(x, self.w, self.b)
4362
4363        class M(torch.nn.Module):
4364            def __init__(self) -> None:
4365                super().__init__()
4366                self.mods1 = torch.nn.Sequential(
4367                    Linear(),
4368                    Linear()
4369                )
4370
4371            def forward(self, x):
4372                x = torch.cat((x,), 1)
4373                tmp = x.size()
4374                x = self.mods1(x)
4375                y = x * tmp[0]
4376                return y
4377
4378        model = M().eval()
4379        qconfig_dict = {
4380            "": None,
4381            "object_type": [
4382                (torch.nn.functional.linear, default_qconfig),
4383                (torch.nn.functional.relu, default_qconfig),
4384            ],
4385        }
4386        example_inputs = (torch.rand(5, 5),)
4387        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
4388        m(*example_inputs)
4389        m = convert_fx(m)
4390        keys = m.state_dict().keys()
4391        m(torch.randn(5, 5))
4392        # TODO: probably don't want to hardcode the attribute names, since they are generated
4393        for attr_name in [
4394                "mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
4395                "mods1_0_scale_0", "mods1_0_zero_point_0",
4396                "mods1_1_scale_0", "mods1_1_zero_point_0"]:
4397            self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
4398
4399    def test_no_obs_between_unmatched_node_and_copy_node(self):
4400        """
4401        Verifies that an observer is not inserted between an unmatched
4402        node and a node matched to CopyNodeQuantizeHandler.  This is done
4403        because observers require activations to be Tensors, and there is
4404        no guarantee that an output of an unmatched node is a Tensor.
4405        """
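        # Sketch of the pattern: _user_func_with_complex_return_type is unmatched,
        # and the `x[0]` getitem after it is matched as a copy node; because the
        # unmatched output is not guaranteed to be a Tensor, no observer may sit
        # between the two.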
4406
4407        class M(nn.Module):
4408            def __init__(self) -> None:
4409                super().__init__()
4410                self.relu = nn.ReLU()
4411
4412            def forward(self, x):
4413                x = _user_func_with_complex_return_type(x)
4414                x1 = x[0] + 1
4415                return x1, x[1]
4416
4417        m = M().eval()
4418
4419        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
4420        example_inputs = (torch.randn(4, 4, 4, 4),)
4421        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
4422        # if an observer is inserted after _user_func_with_complex_return_type,
4423        # the following call will fail
4424        mp(*example_inputs)
4425        mc = convert_fx(mp)
4426        mc(*example_inputs)
4427
4428    def test_fold_quant_dequant(self):
4429        """ Test that sequences of quant-dequant nodes in the
4430            graph get folded and the extra dequant nodes are erased.
4431        """
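        # Illustrative sketch of the fold (not the actual pass): a back-to-back
        #   q = torch.quantize_per_tensor(x, scale, zp, torch.quint8)
        #   y = q.dequantize()
        # pair is collapsed, so the final graph should keep exactly one quantize
        # and one dequantize, which the node counts below assert.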
4432        class M(torch.nn.Module):
4433            def __init__(self) -> None:
4434                super().__init__()
4435                self.w = torch.ones(5, 5)
4436                self.b = torch.zeros(5)
4437
4438            def forward(self, x):
4439                x = torch.cat((x,), 1)
4440                tmp = x.size()
4441                x = torch.nn.functional.linear(x, self.w, self.b)
4442                y = x * tmp[0]
4443                return y
4444
4445        model = M().eval()
4446        qconfig_dict = {
4447            "": None,
4448            "object_type": [
4449                (torch.nn.functional.linear, default_qconfig),
4450            ],
4451        }
4452        example_inputs = (torch.rand(5, 5),)
4453        m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
4454        m(*example_inputs)
4455        m = convert_fx(m)
4456        keys = m.state_dict().keys()
4457        m(*example_inputs)
4458        dequant = 0
4459        quant = 0
4460        for n in m.graph.nodes:
4461            if n.op == "call_method" and n.target == "dequantize":
4462                dequant = dequant + 1
4463            if n.op == "call_function" and n.target == torch.quantize_per_tensor:
4464                quant = quant + 1
4465        self.assertEqual(dequant, 1)
4466        self.assertEqual(quant, 1)
4467
4468    def test_quant_output_always_observed(self):
4469        """
4470        If the output is hardcoded to be quantized, ensure that
4471        there is always an observer, even if the last non-output node is not
4472        quantizable.
4473        """
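        # 'output_quantized_idxs': [0] below pins model output 0 to stay quantized,
        # so convert must keep a quantize op at the output even when the producing
        # node (e.g. nn.Identity) is not itself quantizable.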
4474        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
4475        prepare_custom_config_dict = {'output_quantized_idxs': [0]}
4476        example_inputs = (torch.randn(4, 1, 4, 4),)
4477
4478        # non-quantizable node, quantized output
4479        class M1(torch.nn.Module):
4480            def __init__(self) -> None:
4481                super().__init__()
4482                self.identity = torch.nn.Identity()
4483
4484            def forward(self, x):
4485                x = self.identity(x)
4486                return x
4487
4488        m1 = M1()
4489        self.checkGraphModeFxOp(
4490            m1, example_inputs, QuantType.QAT,
4491            prepare_expected_node_occurrence={
4492                ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
4493            },
4494            expected_node_occurrence={
4495                ns.call_function(torch.quantize_per_tensor): 1,
4496            },
4497            prepare_custom_config=prepare_custom_config_dict)
4498
4499        # quantizable node, quantized output
4500        class M2(torch.nn.Module):
4501            def __init__(self) -> None:
4502                super().__init__()
4503                self.conv = torch.nn.Conv2d(1, 1, 1)
4504
4505            def forward(self, x):
4506                x = self.conv(x)
4507                return x
4508
4509        m2 = M2()
4510        self.checkGraphModeFxOp(
4511            m2, example_inputs, QuantType.QAT,
4512            prepare_expected_node_occurrence={
4513                # one for weights, one for activations
4514                ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
4515            },
4516            expected_node_occurrence={
4517                ns.call_function(torch.quantize_per_tensor): 1,
4518            },
4519            prepare_custom_config=prepare_custom_config_dict)
4520
4521        # quantizable node, quantized dictionary output
4522        class M3(torch.nn.Module):
4523            def __init__(self) -> None:
4524                super().__init__()
4525                self.conv = torch.nn.Conv2d(1, 1, 1)
4526
4527            def forward(self, x):
4528                x = self.conv(x)
4529                return {"output": x}
4530
4531        m3 = M3()
4532        self.checkGraphModeFxOp(
4533            m3, example_inputs, QuantType.QAT,
4534            prepare_expected_node_occurrence={
4535                # one for weights, one for activations
4536                ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2,
4537            },
4538            expected_node_occurrence={
4539                ns.call_function(torch.quantize_per_tensor): 1,
4540            },
4541            prepare_custom_config=prepare_custom_config_dict)
4542
4543    def test_deepcopy_preserve_attributes(self):
4544        class M(torch.nn.Module):
4545            def __init__(self) -> None:
4546                super().__init__()
4547                self.attr = 3
4548
4549            def forward(self, x):
4550                return x
4551
4552        m = M().eval()
4553        m = prepare_fx(
4554            m,
4555            {"": default_qconfig},
4556            example_inputs=(torch.randn(1),),
4557            prepare_custom_config={"preserved_attributes": ["attr"]})
4558        # preserved attributes are also stored in meta so that they don't get lost
4559        # during deepcopy
4560        self.assertTrue(hasattr(m, "attr"))
4561        self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
4562        m2 = copy.deepcopy(m)
4563        self.assertTrue(hasattr(m2, "attr"))
4564        self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
4565        m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]})
4566        self.assertTrue(hasattr(m, "attr"))
4567        self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
4568        m2 = copy.deepcopy(m)
4569        self.assertTrue(hasattr(m2, "attr"))
4570        self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY])
4571
4572    def test_output_lists_and_dicts(self):
4573        """Verify that specifying complicated output types does not crash.
4574        """
4575        class M(torch.nn.Module):
4576            def __init__(self) -> None:
4577                super().__init__()
4578                self.conv = nn.Conv2d(1, 1, 1)
4579
4580            def forward(self, x):
4581                x = self.conv(x)
4582                return {'foo': [x]}, [{'foo': [[x]]}]
4583
4584        m = M().eval()
4585        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
4586        mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),))
4587        mc = convert_fx(mp)
4588
4589    def test_shape_followed_by_quantized_op(self):
4590        """ Make sure that taking the shape does not dequantize
4591        the Tensor before the next operator
4592        """
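        # Note: `x.shape` / `x.size()` are defined directly on quantized Tensors,
        # so no dequantize should be inserted just to read the shape; conv2 can
        # keep consuming the quantized tensor.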
4593        class M(torch.nn.Module):
4594            def __init__(self) -> None:
4595                super().__init__()
4596                self.conv1 = torch.nn.Conv2d(2, 2, 2)
4597                self.conv2 = torch.nn.Conv2d(2, 2, 2)
4598
4599            def forward(self, x):
4600                x = self.conv1(x)
4601                s = x.shape
4602                torch._assert(s == x.shape, "")
4603                x = self.conv2(x)
4604                return x
4605
4606        # make sure quantization runs
4607        m = M().eval()
4608        example_inputs = (torch.randn(2, 2, 4, 4),)
4609        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
4610        m = convert_fx(m)
4611        m(*example_inputs)
4612        node_occurrence = {
4613            ns.call_function(torch.quantize_per_tensor): 1,
4614            ns.call_method("dequantize"): 1
4615        }
4616        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
4617
4618    def test_trace_quantize_per_tensor(self):
4619        class M(torch.nn.Module):
4620            def __init__(self) -> None:
4621                super().__init__()
4622                self.conv = torch.nn.Conv2d(1, 1, 1)
4623
4624            def forward(self, x):
4625                x = self.conv(x)
4626                return x
4627
4628        m = M().eval()
4629        m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),))
4630        m = convert_fx(m)
4631        # Make sure this runs without error
4632        m = torch.fx.Transformer(m).transform()
4633
4634    def test_copy_node_has_shared_actpp_instance(self):
4635        """ Test that the output of a CopyNode has the same
4636        observer/fake_quant instance as its input
4637        """
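        # Hypothetical minimal picture of the sharing asserted below:
        #   obs = torch.ao.quantization.MinMaxObserver()
        #   m.activation_post_process_0 = obs  # observes the avgpool input
        #   m.activation_post_process_1 = obs  # same instance reused for the output
        # named_modules() deduplicates by identity while
        # named_modules(remove_duplicate=False) does not, giving counts of 1 vs 2.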
4638
4639        class M(torch.nn.Module):
4640            def __init__(self) -> None:
4641                super().__init__()
4642                self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3)
4643
4644            def forward(self, x):
4645                x = self.avgpool2d(x)
4646                return x
4647
4648        for quant_type in self.static_quant_types:
4649            m = M()
4650            # Checks that we have an observer for both input and output
4651            occurrence_map = {
4652                QuantType.STATIC: {
4653                    ns.call_module(torch.ao.quantization.MinMaxObserver): 2
4654                },
4655                QuantType.QAT: {
4656                    ns.call_module(torch.ao.quantization.FakeQuantize): 2
4657                }
4658            }
4659            if quant_type == QuantType.QAT:
4660                m.train()
4661                prepare = prepare_qat_fx
4662                qconfig = default_qat_qconfig
4663                actpp_module_class = torch.ao.quantization.FakeQuantize
4664            else:
4665                m.eval()
4666                prepare = prepare_fx
4667                qconfig = default_qconfig
4668                actpp_module_class = torch.ao.quantization.MinMaxObserver
4669
4670            example_inputs = (torch.randn(1, 3, 3, 3),)
4671            m = prepare(m, {"": qconfig}, example_inputs=example_inputs)
4672            # check that the shared observer instance shows up twice when duplicates are kept
4673            actpp_module_count = 0
4674            for name, module in m.named_modules(remove_duplicate=False):
4675                if isinstance(module, actpp_module_class):
4676                    actpp_module_count += 1
4677            self.assertEqual(actpp_module_count, 2)
4678
4679            actpp_module_count = 0
4680            for name, module in m.named_modules():
4681                if isinstance(module, actpp_module_class):
4682                    actpp_module_count += 1
4683            self.assertEqual(actpp_module_count, 1)
4684
4685            m_copy = copy.deepcopy(m)
4686            m = convert_fx(m)
4687            m_reference = convert_to_reference_fx(m_copy)
4688
4689            # checks for non-reference quantized model
4690            node_occurrence = {
4691                ns.call_function(torch.quantize_per_tensor): 1,
4692                ns.call_method("dequantize"): 1
4693            }
4694            node_list = [
4695                ns.call_function(torch.quantize_per_tensor),
4696                ns.call_module(torch.nn.AvgPool2d),
4697                ns.call_method("dequantize"),
4698            ]
4699            self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, expected_node_list=node_list)
4700
4701            # checks for the reference quantized model; for copy nodes we'll have
4702            # dequant - copy_node - quant patterns which will be fused later
4703            # in the backend lowering step
4704            node_occurrence = {
4705                ns.call_function(torch.quantize_per_tensor): 2,
4706                ns.call_method("dequantize"): 2
4707            }
4708            node_list = [
4709                ns.call_function(torch.quantize_per_tensor),
4710                ns.call_method("dequantize"),
4711                ns.call_module(torch.nn.AvgPool2d),
4712                ns.call_function(torch.quantize_per_tensor),
4713                ns.call_method("dequantize"),
4714            ]
4715            self.checkGraphModuleNodes(m_reference, expected_node_occurrence=node_occurrence, expected_node_list=node_list)
4716
4717    def test_linear_qint8_activation(self):
4718        """Test support for qint8 activation in reference pattern
4719        """
4720        class M(torch.nn.Module):
4721            def __init__(self) -> None:
4722                super().__init__()
4723                self.conv = torch.nn.Conv2d(1, 2, 2, 2)
4724                self.linear = torch.nn.Linear(8, 5)
4725
4726            def forward(self, x):
4727                x = self.conv(x)
4728                x = torch.flatten(x, 1)
4729                x = self.linear(x)
4730                return x
4731
4732        m = M().eval()
4733        example_inputs = (torch.rand(2, 1, 5, 5),)
4734        m = prepare_fx(
4735            m,
4736            {"": torch.ao.quantization.QConfig(
4737                activation=torch.ao.quantization.HistogramObserver.with_args(
4738                    qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
4739                ), weight=torch.ao.quantization.default_per_channel_weight_observer)},
4740            example_inputs=example_inputs)
4741        m = convert_to_reference_fx(m)
4742        m(*example_inputs)
4743
4744    def test_preserve_tuple(self):
4745        """ Test that a tuple input type is preserved
4746        """
4747
4748        class LSTM(nn.Module):
4749            def __init__(self) -> None:
4750                super().__init__()
4751                self.lstm = nn.LSTM(50, 50, 1)
4752
4753            def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]):
4754                h = state[0]
4755                c = state[1]
4756                return self.lstm(inputs, (h, c))
4757
4758        m = LSTM().eval()
4759        example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50))
4760        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
4761        # make sure arg[1] of the lstm module is a tuple
4762        for n in m.graph.nodes:
4763            if n.target == "lstm":
4764                self.assertEqual(type(n.args[1]), tuple)
4765
4766    def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence):
4767        """
4768        Helper method to validate the graph of a model with static LSTM.
4769        """
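        # This follows the standard custom-module flow: nn.LSTM is swapped for
        # torch.ao.nn.quantizable.LSTM at prepare time and for
        # torch.ao.nn.quantized.LSTM at convert time via the mappings below.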
4770        qconfig_mapping = get_default_qconfig_mapping()
4771        prepare_custom_config = PrepareCustomConfig() \
4772            .set_float_to_observed_mapping(torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
4773        convert_custom_config = ConvertCustomConfig() \
4774            .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)
4775        example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
4776
4777        model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config)
4778        self.checkGraphModuleNodes(model, expected_node_occurrence=prepare_node_occurrence)
4779        model(*example_inputs)
4780
4781        model = convert_fx(model, convert_custom_config=convert_custom_config)
4782        self.checkGraphModuleNodes(model, expected_node_occurrence=convert_node_occurrence)
4783        model(*example_inputs)
4784
4785    def test_static_lstm(self):
4786        """
4787        Test statically quantized custom module LSTM followed by ops that consume individual
4788        tensors of the output tuple.
4789        """
4790        class MyModel(nn.Module):
4791            def __init__(self) -> None:
4792                super().__init__()
4793                self.lstm = nn.LSTM(50, 50, 1)
4794                self.linear1 = nn.Linear(50, 10)
4795                self.linear2 = nn.Linear(50, 10)
4796                self.linear3 = nn.Linear(50, 10)
4797
4798            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
4799                (out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0))
4800                out = self.linear1(out)
4801                h0_out = self.linear2(h0_out)
4802                c0_out = self.linear3(c0_out)
4803                return (out, (h0_out, c0_out))
4804
4805        m = MyModel()
4806        prepare_node_occurrence = {
4807            ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
4808        }
4809        convert_node_occurrence = {
4810            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
4811            ns.call_function(torch.quantize_per_tensor): 3,
4812            # lstm[0].dequantize()
4813            # lstm[1][0].dequantize()
4814            # lstm[1][1].dequantize()
4815            ns.call_method("dequantize"): 3,
4816            # lstm[0], lstm[1], lstm[1][0], lstm[1][1]
4817            ns.call_function(operator.getitem): 4,
4818            # No tuples are consumed
4819            ns.call_function(tuple): 0,
4820        }
4821        self._test_static_lstm_helper(m, prepare_node_occurrence, convert_node_occurrence)
4822
4823    def test_static_lstm_consume_tuple(self):
4824        """
4825        Test statically quantized custom module LSTM followed by a module that consumes the
4826        output tuple, either as a whole or part of it.
4827        """
4828        class ModuleAfterLSTM(nn.Module):
4829            def __init__(self) -> None:
4830                super().__init__()
4831                self.identity = torch.nn.Identity()
4832
4833            def forward(self, x):
4834                return self.identity(x)
4835
4836        class ConsumeWholeTuple(nn.Module):
4837            def __init__(self) -> None:
4838                super().__init__()
4839                self.lstm = nn.LSTM(50, 50, 1)
4840                self.module_after_lstm = ModuleAfterLSTM()
4841
4842            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
4843                x = self.lstm(inputs, (h0, c0))
4844                x = self.module_after_lstm(x)  # consume tuple (output, (hidden0, hidden1))
4845                return x
4846
4847        class ConsumeHiddenTuple(ConsumeWholeTuple):
4848            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
4849                x = self.lstm(inputs, (h0, c0))
4850                x = self.module_after_lstm(x[1])  # consume tuple (hidden0, hidden1)
4851                return x
4852
4853        # Test consuming the whole tuple (output, (hidden0, hidden1))
4854        m1 = ConsumeWholeTuple()
4855        prepare_node_occurrence = {
4856            ns.call_module(torch.ao.nn.quantizable.LSTM): 1,
4857        }
4858        convert_node_occurrence1 = {
4859            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
4860            ns.call_function(torch.quantize_per_tensor): 3,
4861            # lstm[0].dequantize()
4862            # lstm[1][0].dequantize()
4863            # lstm[1][1].dequantize()
4864            ns.call_method("dequantize"): 3,
4865            # lstm[0], lstm[1], lstm[1][0], lstm[1][1]
4866            ns.call_function(operator.getitem): 4,
4867            # tuple(output_dq, tuple(hidden0_dq, hidden1_dq))
4868            ns.call_function(tuple): 2,
4869        }
4870        self._test_static_lstm_helper(m1, prepare_node_occurrence, convert_node_occurrence1)
4871
4872        # Test consuming just the hidden tuple (hidden0, hidden1)
4873        m2 = ConsumeHiddenTuple()
4874        convert_node_occurrence2 = {
4875            ns.call_module(torch.ao.nn.quantized.LSTM): 1,
4876            ns.call_function(torch.quantize_per_tensor): 3,
4877            # lstm[1][0].dequantize()
4878            # lstm[1][1].dequantize()
4879            ns.call_method("dequantize"): 2,
4880            # lstm[1], lstm[1][0], lstm[1][1]
4881            ns.call_function(operator.getitem): 3,
4882            # tuple(hidden0_dq, hidden1_dq)
4883            ns.call_function(tuple): 1,
4884        }
4885        self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2)
4886
4887    def test_static_lstm_with_custom_fixed_qparams(self):
4888        """
4889        Test statically quantized LSTM with custom fixed qparams assigned to each of the
4890        inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM`
4891        and use the child class in the custom module mapping.
4892        """
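        # How the fixed qparams below encode 16-bit ranges inside qint32 (sketch):
        # with scale 2**-11 and zero_point 2**15 over 2**16 levels, the span is
        # (0 - 2**15) * 2**-11 = -16 up to (2**16 - 1 - 2**15) * 2**-11 ~= 16,
        # i.e. the "uint16, [-16, 16)" annotations on the observers.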
4893        class MyModel(torch.nn.Module):
4894            def __init__(self) -> None:
4895                super().__init__()
4896                self.my_lstm = torch.nn.LSTM(50, 50, 1)
4897
4898            def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor):
4899                x = self.my_lstm(inputs, (h0, c0))
4900                return x
4901
4902        # Construct a BackendConfig that supports qint32 for certain ops
4903        # TODO: build a BackendConfig from scratch instead of modifying an existing one
4904        qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32)
4905        my_backend_config = get_qnnpack_backend_config()
4906        for config in my_backend_config.configs:
4907            if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]:
4908                config.add_dtype_config(qint32_dtype_config)
4909
4910        class UserObservedLSTM(torch.ao.nn.quantizable.LSTM):
4911            """
4912            Example of a user-provided LSTM implementation that assigns fixed qparams
4913            to the inner ops.
4914            """
4915            @classmethod
4916            def from_float(cls, float_lstm):
4917                assert isinstance(float_lstm, cls._FLOAT_MODULE)
4918                # uint16, [-16, 16)
4919                linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32)
4920                # uint16, [0, 1)
4921                sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32)
4922                # uint16, [-1, 1)
4923                tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32)
4924                # int16, [-16, 16)
4925                cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32)
4926                # uint8, [-1, 1)
4927                hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8)
4928                example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50)))
4929                return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts(
4930                    float_lstm=float_lstm,
4931                    example_inputs=example_inputs,
4932                    backend_config=my_backend_config,
4933                    linear_output_obs_ctr=linear_output_obs_ctr,
4934                    sigmoid_obs_ctr=sigmoid_obs_ctr,
4935                    tanh_obs_ctr=tanh_obs_ctr,
4936                    cell_state_obs_ctr=cell_state_obs_ctr,
4937                    hidden_state_obs_ctr=hidden_state_obs_ctr,
4938                )
4939
4940        class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM):
4941            """
4942            Example of a user-provided LSTM implementation that produces a reference
4943            quantized module from a `UserObservedLSTM`.
4944            """
4945            @classmethod
4946            def from_observed(cls, observed_lstm):
4947                assert isinstance(observed_lstm, cls._FLOAT_MODULE)
4948                return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module(
4949                    observed_lstm=observed_lstm,
4950                    backend_config=my_backend_config,
4951                )
4952
4953        # FX graph mode quantization
4954        m = MyModel()
4955        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
4956        example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50))
4957        prepare_custom_config = PrepareCustomConfig() \
4958            .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM)
4959        convert_custom_config = ConvertCustomConfig() \
4960            .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM)
4961        prepared = prepare_fx(
4962            m,
4963            qconfig_mapping,
4964            example_inputs,
4965            prepare_custom_config,
4966            backend_config=my_backend_config,
4967        )
4968        prepared(*example_inputs)
4969        converted = convert_fx(
4970            prepared,
4971            convert_custom_config,
4972            backend_config=my_backend_config,
4973        )
4974        converted(*example_inputs)
4975
4976        # Find the patterns [dq - op - q_to_specific_dtype] in the graph and
4977        # verify that qparams and dtypes are set correctly in the quantize ops
4978        node_name_to_expected_quantize_args = {
4979            "igates": (None, None, torch.quint8),
4980            "hgates": (None, None, torch.quint8),
4981            "add": (2 ** -11, 2 ** 15, torch.qint32),  # gates.add
4982            "input_gate": (2 ** -16, 0, torch.qint32),
4983            "forget_gate": (2 ** -16, 0, torch.qint32),
4984            "cell_gate": (2 ** -15, 2 ** 15, torch.qint32),
4985            "output_gate": (2 ** -16, 0, torch.qint32),
4986            "mul": (2 ** -11, 0, torch.qint32),  # fgate_cx.mul
4987            "mul_1": (2 ** -11, 0, torch.qint32),  # igate_cgate.mul
4988            "add_1": (2 ** -11, 0, torch.qint32),  # fgate_cx_igate_cgate.add
4989            "mul_2": (2 ** -7, 2 ** 7, torch.quint8),  # ogate_cy.mul
4990        }
4991        cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell
4992        matched_names = set()
4993        for node in cell.graph.nodes:
4994            if node.name not in node_name_to_expected_quantize_args:
4995                continue
4996            matched_names.add(node.name)
4997            # Match preceding dequantize
4998            self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
4999            # Match following quantize with the specific qparams and dtypes
5000            expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name]
5001            for user in node.users.keys():
5002                self.assertEqual(user.target, torch.quantize_per_tensor)
5003                if expected_scale is not None:
5004                    self.assertEqual(getattr(cell, user.args[1].target), expected_scale)
5005                if expected_zp is not None:
5006                    self.assertEqual(getattr(cell, user.args[2].target), expected_zp)
5007                self.assertEqual(user.args[-1], expected_dtype)
5008        # Ensure all patterns were matched
5009        self.assertEqual(matched_names, set(node_name_to_expected_quantize_args.keys()))
5010
5011    def test_reroute_tuple_getitem_patterns(self):
5012        """
5013        The following graph should redirect the output to `b`. After the transformation,
5014        all other nodes, including the inputs `a` and `c`, are no longer needed.
5015
5016             a   b     c
5017             |   \\   /
5018             \\   tuple
5019              \\   /
5020               tuple
5021               /  \\
5022              /    \\
5023             |      \\
5024             |       \\
5025             |        \\
5026        getitem0    getitem1
5027             |      /     \\
5028             | getitem0  getitem1
5029             |     \\     /
5030             \\      tuple
5031              \\      /
5032               \\    /
5033                tuple
5034                  |
5035               getitem1
5036                  |
5037               getitem0
5038                  |
5039                output
5040        """
5041        # Construct graph manually because symbolic_trace does not insert tuple and getitem nodes
5042        graph = torch.fx.Graph()
5043        a = graph.create_node("placeholder", "a")
5044        b = graph.create_node("placeholder", "b")
5045        c = graph.create_node("placeholder", "c")
5046        bc = graph.call_function(tuple, args=([b, c],))
5047        abc = graph.call_function(tuple, args=([a, bc],))
5048
5049        # Break down tuple and reconstruct it again
5050        a2 = graph.call_function(operator.getitem, args=(abc, 0))
5051        bc2 = graph.call_function(operator.getitem, args=(abc, 1))
5052        b2 = graph.call_function(operator.getitem, args=(bc2, 0))
5053        c2 = graph.call_function(operator.getitem, args=(bc2, 1))
5054        bc3 = graph.call_function(tuple, args=([b2, c2],))
5055        abc2 = graph.call_function(tuple, args=([a2, bc3],))
5056
5057        # Output tuple[1][0]
5058        bc4 = graph.call_function(operator.getitem, args=(abc2, 1))
5059        b3 = graph.call_function(operator.getitem, args=(bc4, 0))
5060        output = graph.output(b3)
5061
5062        # Do reroute
5063        _reroute_tuple_getitem_pattern(graph)
5064
5065        # Assert that output reroutes to `b` directly, and all other nodes can be removed
5066        output_ancestors = []
5067        def gather_ancestors(current_node):  # noqa: E306
5068            for arg in current_node.args:
5069                output_ancestors.append(arg)
5070                gather_ancestors(arg)
5071        gather_ancestors(output)
5072        self.assertEqual(output_ancestors, [b])
5073        self.assertEqual(output.args[0], b)
5074
5075    def test_relu_lowering(self):
5076        class M(torch.nn.Module):
5077            def forward(self, x):
5078                return torch.nn.functional.relu(x)
5079
5080        m = M().eval()
5081        m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),))
5082        m_copy = copy.deepcopy(m)
5083        m = convert_fx(m)
5084        m_ref = convert_to_reference_fx(m_copy)
5085        node_occurrence = {
5086            ns.call_function(torch.quantize_per_tensor): 1,
5087            ns.call_method("dequantize"): 1
5088        }
5089        node_occurrence_ref = {
5090            ns.call_function(torch.quantize_per_tensor): 2,
5091            ns.call_method("dequantize"): 2
5092        }
5093
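        # relu lowers into the quantized domain, so the lowered model keeps a
        # single q/dq pair at the boundaries; the reference model keeps the
        # explicit dq -> relu -> q pattern, hence two of each.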
5094        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
5095        self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
5096
5097    @skipIfNoFBGEMM
5098    def test_dynamic_with_fusion(self):
5099        """
5100        Tests that dynamic quantization APIs work with Linear + Relu fusion
5101        """
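        # Sketch of what the fused dynamic op computes (roughly):
        #   torch.ops.quantized.linear_relu_dynamic(x, packed_weight)
        # activations are quantized on the fly per batch; only weights are packed
        # ahead of time, and the fp16 variant is linear_relu_dynamic_fp16.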
5102        with override_quantized_engine('fbgemm'):
5103            class LinearRelu(torch.nn.Module):
5104                def __init__(self) -> None:
5105                    super().__init__()
5106                    self.linear = torch.nn.Linear(5, 5)
5107                    self.relu = torch.nn.ReLU()
5108
5109                def forward(self, x):
5110                    x = self.linear(x)
5111                    return self.relu(x)
5112
5113            class Linear(torch.nn.Module):
5114                def __init__(self) -> None:
5115                    super().__init__()
5116                    self.w = torch.ones(5, 5)
5117                    self.b = torch.zeros(5)
5118
5119                def forward(self, x):
5120                    return torch.nn.functional.linear(x, self.w, self.b)
5121
5122            class M(torch.nn.Module):
5123                def __init__(self) -> None:
5124                    super().__init__()
5125                    self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu())
5126                    self.mods2 = Linear()
5127                    self.relu = F.relu
5128
5129                def forward(self, x):
5130                    x = self.mods1(x)
5131                    x = self.mods2(x)
5132                    x = self.relu(x)
5133                    return x
5134
5135            dynamic_quantized_ops = {
5136                float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16,
5137                default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic
5138            }
5139            for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
5140                model = M().eval()
5141                qconfig_dict = {
5142                    "": qconfig
5143                }
5144                example_inputs = (torch.rand(5, 5),)
5145                m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
5146                m = convert_fx(m)
5147                m(*example_inputs)
5148                node_list = [
5149                    ns.call_module(nniqd.LinearReLU),
5150                    ns.call_module(nniqd.LinearReLU),
5151                    ns.call_function(dynamic_quantized_ops[qconfig]),
5152                ]
5153                self.checkGraphModuleNodes(m, expected_node_list=node_list)
5154
5155    @skipIfNoFBGEMM
5156    def test_dynamic_with_fusion_multiple_uses(self):
5157        """
5158        Tests that dynamic quantization APIs work with Linear + Relu fusion when the fused module is used multiple times
5159        """
5160        with override_quantized_engine('fbgemm'):
5161            class LinearRelu(torch.nn.Module):
5162                def __init__(self) -> None:
5163                    super().__init__()
5164                    self.linear = torch.nn.Linear(5, 5)
5165                    self.relu = torch.nn.ReLU()
5166
5167                def forward(self, x):
5168                    x = self.linear(x)
5169                    return self.relu(x)
5170
5171            class M(torch.nn.Module):
5172                def __init__(self) -> None:
5173                    super().__init__()
5174                    self.linear_relu = LinearRelu()
5175
5176                def forward(self, x):
5177                    x = self.linear_relu(x)
5178                    x = self.linear_relu(x)
5179                    return x
5180
5181            for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
5182                model = M().eval()
5183                qconfig_dict = {
5184                    "": qconfig
5185                }
5186                example_inputs = (torch.randn(5, 5),)
5187                m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
5188                m = convert_fx(m)
5189                m(*example_inputs)
5190                node_list = [
5191                    ns.call_module(nniqd.LinearReLU),
5192                    ns.call_module(nniqd.LinearReLU),
5193                ]
5194                self.checkGraphModuleNodes(m, expected_node_list=node_list)
5195
5196    @skipIfNoFBGEMM
5197    def test_dynamic_linear_input_multiple_use(self):
5198        """
5199        Tests that an input consumed by multiple dynamic linear ops is handled correctly
5200        """
5201        with override_quantized_engine('fbgemm'):
5202            class LinearRelu(torch.nn.Module):
5203                def __init__(self) -> None:
5204                    super().__init__()
5205                    self.linear = torch.nn.Linear(5, 5)
5206                    self.relu = torch.nn.ReLU()
5207
5208                def forward(self, x):
5209                    x = self.linear(x)
5210                    return self.relu(x)
5211
5212            class M(torch.nn.Module):
5213                def __init__(self) -> None:
5214                    super().__init__()
5215                    self.mod1 = LinearRelu()
5216                    self.mod2 = LinearRelu()
5217
5218                def forward(self, x):
5219                    y1 = self.mod1(x)
5220                    y2 = self.mod2(x)
5221                    return y1 + y2
5222
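            # Both LinearRelu branches consume the same fp32 input. With dynamic
            # quantization the input qparams are computed inside the op at runtime,
            # so sharing one input between two dynamic linears should need no
            # observers or quantize nodes on x.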
5223            for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
5224                model = M().eval()
5225                qconfig_dict = {
5226                    "": qconfig
5227                }
5228                example_inputs = (torch.rand(5, 5, 5),)
5229                m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
5230                m = convert_fx(m)
5231                m(*example_inputs)
5232                node_list = [
5233                    ns.call_module(nniqd.LinearReLU),
5234                    ns.call_module(nniqd.LinearReLU),
5235                ]
5236                self.checkGraphModuleNodes(m, expected_node_list=node_list)
5237
5238    def test_ref_linear_module(self):
5239        """ Make sure the numerics for models with ref linear module
5240        match models with fbgemm/qnnpack modules
5241        """
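        # convert_fx lowers to backend kernels (nnq.Linear / nniq.LinearReLU),
        # while convert_to_reference_fx keeps the reference pattern
        # (dequantize -> fp32 nnqr.Linear -> quantize). Both paths should produce
        # bit-identical outputs, which torch.equal verifies below.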
5242        class M1(torch.nn.Module):
5243            def __init__(self) -> None:
5244                super().__init__()
5245                self.linear = torch.nn.Linear(10, 5)
5246
5247            def forward(self, x):
5248                return self.linear(x)
5249
5250        class M2(torch.nn.Module):
5251            def __init__(self) -> None:
5252                super().__init__()
5253                self.linear = torch.nn.Linear(10, 5)
5254                self.relu = torch.nn.ReLU()
5255
5256            def forward(self, x):
5257                return self.relu(self.linear(x))
5258
5259        for M in [M1, M2]:
5260            m = M().eval()
5261            example_inputs = (torch.randn(5, 10),)
5262            m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
5263            m_copy = copy.deepcopy(m)
5264            m = convert_fx(m)
5265            m_ref = convert_to_reference_fx(m_copy)
5266            result = m(*example_inputs)
5267            result_ref = m_ref(*example_inputs)
5268            self.assertTrue(torch.equal(result, result_ref))
5269
5270    def test_ref_conv_module(self):
5271        """ Make sure the numerics for models with ref conv module
5272        match models with fbgemm/qnnpack modules
5273        """
5274        convs = {
5275            1: nn.Conv1d,
5276            2: nn.Conv2d,
5277            3: nn.Conv3d,
5278        }
5279
5280        class M1(torch.nn.Module):
5281            def __init__(self, dim):
5282                super().__init__()
5283                self.conv = convs[dim](3, 3, 3)
5284
5285            def forward(self, x):
5286                return self.conv(x)
5287
5288        class M2(torch.nn.Module):
5289            def __init__(self, dim):
5290                super().__init__()
5291                self.conv = convs[dim](3, 3, 3)
5292                self.relu = torch.nn.ReLU()
5293
5294            def forward(self, x):
5295                return self.relu(self.conv(x))
5296
5297        for dim, M in itertools.product([1, 2, 3], [M1, M2]):
5298            m = M(dim).eval()
5299            data = self.img_data_dict[dim][0][0]
5300            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
5301            m_copy = copy.deepcopy(m)
5302            m = convert_fx(m)
5303            m_ref = convert_to_reference_fx(m_copy)
5304            result = m(data)
5305            result_ref = m_ref(data)
5306            self.assertTrue(torch.equal(result, result_ref))
5307
5308    def test_sub_scalar(self):
5309        class M(torch.nn.Module):
5310            def forward(self, x):
5311                x = x + 1
5312                x = x - 1
5313                x = x + 3
5314                x = x - 4
5315                return x
5316
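        # Expected lowering, roughly: x + 1 can run as a quantized add, x - 1 has
        # no quantized kernel so a dequantize is inserted, x + 3 re-quantizes, and
        # x - 4 dequantizes again; hence two quantize_per_tensor and two
        # dequantize nodes below.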
5317        m = M().eval()
5318        m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),))
5319        m = convert_fx(m)
5320        occurrence = {
5321            ns.call_function(torch.quantize_per_tensor): 2,
5322            ns.call_method("dequantize"): 2
5323        }
5324        self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence)
5325
5326    def test_observer_fqn(self):
5327        """
5328        Test to make sure the observer FQN is based on the quantizable op/module that it is observing
5329        and uses the module's FQN to determine the observer name.
5330        """
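        # The names checked below are the auto-generated activation_post_process_<N>
        # attributes on the prepared GraphModule; N reflects the order in which
        # observers were created during prepare, so the indices are not necessarily
        # contiguous.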
5331        class Linear(torch.nn.Module):
5332            def __init__(self) -> None:
5333                super().__init__()
5334                self.w = torch.ones(5, 5)
5335                self.b = torch.zeros(5)
5336
5338            def forward(self, x):
5339                return torch.nn.functional.linear(x, self.w, self.b)
5340
5342        class M(torch.nn.Module):
5343            def __init__(self) -> None:
5344                super().__init__()
5345                self.mods1 = torch.nn.Sequential(
5346                    Linear(),
5347                    Linear()
5348                )
5349                self.mods2 = Linear()
5350                self.mods3 = torch.nn.Linear(5, 5)
5351
5352            def forward(self, x):
5353                x = self.mods1(x)
5354                x = torch.add(x, 4)
5355                x = self.mods2(x)
5356                y = torch.add(x, 2)
5357                z = torch.mul(x, 5)
5358                a = self.mods3(y)
5359                return a, z
5360
5361        model = M().eval()
5362
5363        prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5),))
5364        name_list = []
5365        for name, mod in prepared.named_modules():
5366            if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver):
5367                name_list.append(name)
5368        expected_name_list = ['activation_post_process_0',
5369                              'activation_post_process_1',
5370                              'activation_post_process_2',
5371                              'activation_post_process_3',
5372                              'activation_post_process_4',
5373                              'activation_post_process_6',
5374                              'activation_post_process_7',
5375                              'activation_post_process_10']
5376        self.assertEqual(name_list, expected_name_list)
5377
5378    def test_conv_lowering(self):
5379        convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
5380        qconvs = {1: nn.quantized.Conv1d, 2: nn.quantized.Conv2d, 3: nn.quantized.Conv3d}
5381
5382        class M(torch.nn.Module):
5383            def __init__(self, dim):
5384                super().__init__()
5385                self.conv = convs[dim](3, 3, 3)
5386
5387            def forward(self, x):
5388                return self.conv(x)
5389
5390        for dim in range(1, len(convs) + 1):
5391            m = M(dim).eval()
5392            data = self.img_data_dict[dim][0][0]
5393            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,))
5394            m_ref = copy.deepcopy(m)
5395            m_ref = convert_to_reference_fx(m_ref)
5396            m = convert_fx(m)
5397            out_ref = m_ref(data)
5398            out = m(data)
5399            # check that reference pattern for quantized conv module is fused
5400            expected_node_occurrence = {
5401                ns.call_function(torch.quantize_per_tensor): 1,
5402                ns.call_module(qconvs[dim]): 1,
5403                ns.call_method("dequantize"): 1
5404            }
5405            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence)
5406            # checking result match
5407            self.assertTrue(torch.equal(out_ref, out))
5408
5409    def test_convert_qconfig_mapping(self):
5410        class Linear(torch.nn.Module):
5411            def __init__(self) -> None:
5412                super().__init__()
5413                self.w = torch.ones(5, 5)
5414                self.b = torch.zeros(5)
5415
5416            def forward(self, x):
5417                return torch.nn.functional.linear(x, self.w, self.b)
5418
5420        class M(torch.nn.Module):
5421            def __init__(self) -> None:
5422                super().__init__()
5423                self.mods1 = torch.nn.Sequential(
5424                    Linear(),
5425                    Linear()
5426                )
5427                self.mods3 = torch.nn.Linear(5, 5)
5428
5429            def forward(self, x):
5430                x = self.mods1(x)
5431                x = torch.add(x, 4)
5432                z = torch.mul(x, 5)
5433                x = self.mods3(z)
5434                return x
5435
5436        model = M().train()
5437
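        # A qconfig_mapping passed to convert_fx is expected to be a subset of the
        # one used at prepare time: entries must either match the prepare-time
        # qconfig or be None, where None means "leave this match unquantized".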
5438        for check in ["module_name", "object_type"]:
5439            qconfig_dict = {"": None,
5440                            "object_type": [
5441                                (nn.functional.linear, get_default_qat_qconfig("fbgemm")),
5442                                (torch.add, get_default_qat_qconfig("fbgemm")),
5443                                (nn.Linear, get_default_qat_qconfig("fbgemm")),
5444                            ],
5445                            }
5446            example_inputs = (torch.rand(5, 5),)
5447            prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
5448            prepared(*example_inputs)
5449            if check == "module_name":
5450                convert_qconfig_dict = {"": None,
5451                                        "object_type": [
5452                                            (nn.functional.linear, get_default_qat_qconfig("fbgemm")),
5453                                            (torch.add, get_default_qat_qconfig("fbgemm")),
5454                                            (nn.Linear, get_default_qat_qconfig("fbgemm")),
5455                                        ],
5456                                        "module_name": [("mods1.0", None)]}
5457
5458                node_occurrence = {
5459                    ns.call_function(torch.quantize_per_tensor): 2,
5460                    ns.call_function(torch.nn.functional.linear): 1,
5461                    ns.call_function(torch.ops.quantized.linear): 1,
5462                    ns.call_function(torch.ops.quantized.add): 1,
5463                    ns.call_method("dequantize"): 2
5464                }
5465                order_check = [
5466                    ns.call_function(torch.nn.functional.linear),
5467                    ns.call_function(torch.quantize_per_tensor),
5468                    ns.call_function(torch.ops.quantized.linear),
5469                    ns.call_function(torch.ops.quantized.add),
5470                    ns.call_method("dequantize"),
5471                    ns.call_function(torch.quantize_per_tensor),
5472                    ns.call_module(nnq.Linear),
5473                    ns.call_method("dequantize"),
5474                ]
5475            elif check == "object_type":
5476                convert_qconfig_dict = {"": None,
5477                                        "object_type": [
5478                                            (nn.functional.linear, get_default_qat_qconfig("fbgemm")),
5479                                            (torch.add, get_default_qat_qconfig("fbgemm")),
5480                                            (nn.Linear, None),
5481                                        ]}
5482
5483                node_occurrence = {
5484                    ns.call_function(torch.quantize_per_tensor): 1,
5485                    ns.call_function(torch.ops.quantized.linear): 2,
5486                    ns.call_function(torch.ops.quantized.add): 1,
5487                    ns.call_function(torch.mul): 1,
5488                    ns.call_method("dequantize"): 1
5489                }
5490                order_check = [
5491                    ns.call_function(torch.quantize_per_tensor),
5492                    ns.call_function(torch.ops.quantized.linear),
5493                    ns.call_function(torch.ops.quantized.linear),
5494                    ns.call_function(torch.ops.quantized.add),
5495                    ns.call_method("dequantize"),
5496                    ns.call_function(torch.mul),
5497                    ns.call_module(nn.Linear),
5498                ]
5499
5500            converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict)
5501            converted(torch.rand(5, 5))
5502            self.checkGraphModuleNodes(
5503                converted,
5504                expected_node_occurrence=node_occurrence,
5505                expected_node_list=order_check)
5506
5507    def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2):
5508        self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr)
5509
5510    def test_register_patterns(self):
5511        def cleanUp():
5512            del _DEFAULT_FUSION_PATTERNS["dummy_fusion"]
5513            del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"]
5514            del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"]
5515            del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"]
5516            del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"]
5517            del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"]
5518            del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"]
5519            del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"]
5520        self.addCleanup(cleanUp)
5521
5522        @_register_fusion_pattern("dummy_fusion")
5523        class DummyFusion:
5524            pass
5525
5526        @_register_quant_pattern("dummy_quant")
5527        class DummyQuant:
5528            pass
5529
5530        @_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer)
5531        class DummyQuant2:
5532            pass
5533
5534        @_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer)
5535        class DummyQuant3:
5536            pass
5537
5538        self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion)
5539        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant)
5540        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2)
5541        self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3)
5542        self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer)
5543        self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer)
5544        self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"],
5545                                                  default_fixed_qparams_range_0to1_fake_quant)
5546        self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"],
5547                                                  default_fixed_qparams_range_neg1to1_fake_quant)
5548        output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True)
5549        output_observer_map = get_default_output_activation_post_process_map(is_training=False)
5550        self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer)
5551        self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"),
5552                                                  default_fixed_qparams_range_neg1to1_fake_quant)
5553
5556    def test_reuse_input_qconfig(self):
5557        class M1(torch.nn.Module):
5558            def __init__(self) -> None:
5559                super().__init__()
5560                self.conv = torch.nn.Conv2d(3, 3, 3)
5561
5562            def forward(self, x):
5563                x = self.conv(x)
5564                x = x.reshape()
5565                return x
5566
5567        class M2(torch.nn.Module):
5568            def forward(self, x):
5569                x = x.reshape()
5570                return x
5571
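        # reshape is configured with default_reuse_input_qconfig in the default
        # qconfig mapping, so it should reuse its input's quantization parameters
        # rather than getting its own observer: with a quantized producer (M1) it
        # stays in the quantized domain, and with a fp32 input (M2) no
        # quantize/dequantize is inserted at all.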
5574        m = M1().eval()
5575        example_inputs = (torch.randn(1, 3, 3, 3),)
5576        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
5577        m = convert_fx(m)
5578        node_list = [
5579            ns.call_function(torch.quantize_per_tensor),
5580            ns.call_module(nnq.Conv2d),
5581            ns.call_method("reshape"),
5582            ns.call_method("dequantize"),
5583        ]
5584        self.checkGraphModuleNodes(
5585            m,
5586            expected_node_list=node_list)
5587
5588        m = M2().eval()
5589        m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs)
5590        m = convert_fx(m)
5591        node_occurrence = {
5592            ns.call_function(torch.quantize_per_tensor): 0,
5593            ns.call_method("dequantize"): 0,
5594        }
5595        node_list = [
5596            ns.call_method("reshape"),
5597        ]
5598        self.checkGraphModuleNodes(
5599            m,
5600            expected_node_occurrence=node_occurrence,
5601            expected_node_list=node_list)
5602
5603    def test_stack_trace_preserved_linear(self):
5604        class M(nn.Module):
5605            def __init__(self) -> None:
5606                super().__init__()
5607                self.linear = nn.Linear(1, 1)
5608
5609            def forward(self, x):
5610                x = self.linear(x)
5611                return x
5612
5613        m = M().eval()
5614        mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),))
5615
5616        found_stack_trace = False
5617        for n in mp.graph.nodes:
5618            if n.op == 'call_module' and n.target == 'linear':
5619                found_stack_trace = n.stack_trace is not None
5620                break
5621        self.assertTrue(found_stack_trace)
5622
5623        # test reference model
5624        mq = convert_to_reference_fx(copy.deepcopy(mp))
5625        found_stack_trace = False
5626        for n in mq.graph.nodes:
5627            if n.op == 'call_module' and n.target == 'linear':
5628                found_stack_trace = n.stack_trace is not None
5629                break
5630        self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True")
5631
5632        # test quantized model
5633        mq = convert_fx(mp)
5634        found_stack_trace = False
5635        for n in mq.graph.nodes:
5636            if n.op == 'call_module' and n.target == 'linear':
5637                found_stack_trace = n.stack_trace is not None
5638                break
5639        self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False")
5640
5641    def test_qat_skip_untraced(self):
5642        class UnTraceableModuleClass(nn.Module):
5643            def __init__(self) -> None:
5644                super().__init__()
5645                self.linear = nn.Linear(2, 2)
5646
5647            def forward(self, x):
5648                return self.linear(x)
5649
5650        class UnTraceableModuleName(nn.Module):
5651            def __init__(self) -> None:
5652                super().__init__()
5653                self.linear = nn.Linear(2, 2)
5654
5655            def forward(self, x):
5656                return self.linear(x)
5657
5658        class M(nn.Module):
5659            def __init__(self) -> None:
5660                super().__init__()
5661                self.untraceable_module_class = UnTraceableModuleClass()
5662                self.untraceable_module_name = UnTraceableModuleName()
5663
5664            def forward(self, x):
5665                x = self.untraceable_module_class(x)
5666                x = self.untraceable_module_name(x)
5667                return x
5668
5669        mod = M()
5670
5671        qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()}
5672        prepare_custom_config_dict = {
5673            "non_traceable_module_class": [UnTraceableModuleClass],
5674            "non_traceable_module_name": ["untraceable_module_name"],
5675        }
5676        example_inputs = (torch.randn(2, 2),)
5677        mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx(
5678            mod.train(), qconfig_dict, example_inputs=example_inputs,
5679            prepare_custom_config=prepare_custom_config_dict
5680        )
5685        self.assertTrue(
5686            isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear)
5687        )
5688        self.assertTrue(
5689            isinstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear)
5690        )
5691        self.assertTrue(
5692            type(mod_prep.untraceable_module_class.linear)
5693            is not torch.ao.nn.qat.modules.linear.Linear,
5694            "prepare_qat_fx should not convert anything inside non-traceable module classes",
5695        )
5696        self.assertTrue(
5697            type(mod_prep.untraceable_module_name.linear)
5698            is not torch.ao.nn.qat.modules.linear.Linear,
5699            "prepare_qat_fx should not convert anything inside modules named in non_traceable_module_name",
5700        )
5701
5702    def test_qconfig_dict_setup(self):
5703        class M(torch.nn.Module):
5704            def __init__(self) -> None:
5705                super().__init__()
5706                self.Conv1d = torch.nn.Conv1d(1, 1, 1)
5707                self.Conv2d = torch.nn.Conv2d(1, 1, 1)
5708                self.Conv3d = torch.nn.Conv3d(1, 1, 1)
5709                self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1)
5710                self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1)
5711                self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1)
5712                self.Linear = torch.nn.Linear(1, 1, 1)
5713
5714            def forward(self, x):
5715                x = self.Conv1d(x)
5716                x = self.Conv2d(x)
5717                x = self.Conv3d(x)
5718                x = self.ConvTranspose1d(x)
5719                x = self.ConvTranspose2d(x)
5720                x = self.ConvTranspose3d(x)
5721                x = self.Linear(x)
5722                x = torch.nn.functional.conv1d(x, torch.rand(2, 2))
5723                x = torch.nn.functional.conv2d(x, torch.rand(2, 2))
5724                x = torch.nn.functional.conv3d(x, torch.rand(2, 2))
5725                x = torch.nn.functional.linear(x, torch.rand(2, 2))
5726                return x
5727
5728        backends = ["qnnpack", "fbgemm"]
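        # fbgemm is expected to use reduce_range (quant_max=127) for quint8
        # activations because its x86 kernels accumulate int8 products in int16
        # and could overflow with the full range; qnnpack uses the full 0..255.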
5729        for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]:
5730            for backend in backends:
5731                m = M().eval()
5732                qconfig_mapping = func(backend)
5733                m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1, 1, 1, 1),))
5734                for name, mod in m.named_modules():
5735                    if _is_activation_post_process(mod) and mod.dtype == torch.quint8:
5736                        if backend == "fbgemm":
5737                            lower_bnd = 0
5738                            upper_bnd = 127
5739                        else:
5740                            lower_bnd = 0
5741                            upper_bnd = 255
5742                        if issubclass(type(mod), FakeQuantize):
5743                            self.assertEqual(mod.activation_post_process.quant_min, lower_bnd)
5744                            self.assertEqual(mod.activation_post_process.quant_max, upper_bnd)
5745                        else:
5746                            self.assertEqual(mod.quant_min, lower_bnd)
5747                            self.assertEqual(mod.quant_max, upper_bnd)
5748
5749    def test_prepare_mode(self):
5750        class LinearModel(torch.nn.Module):
5751            def __init__(self) -> None:
5752                super().__init__()
5753                self.linear = torch.nn.Linear(5, 10)
5754
5755            def forward(self, x):
5756                return self.linear(x)
5757
5758        def _test(prepare_fn, qconfig_dict):
5759            m = LinearModel()
5760            m1 = copy.deepcopy(m)
5761            m1.train()
5762            example_inputs = (torch.randn(1, 5),)
5763            prepare_fn(m1, qconfig_dict, example_inputs=example_inputs)
5764            m2 = copy.deepcopy(m)
5765            m2.eval()
5766            prepare_fn(m2, qconfig_dict, example_inputs=example_inputs)
5767
5768        # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes
5769        _test(prepare_fx, get_default_qconfig_mapping())
5770        _test(prepare_qat_fx, get_default_qat_qconfig_mapping())
5771
5772    def _validate_qconfig_against_backend_config_constraints(
5773            self,
5774            model: torch.nn.Module,
5775            qconfig: QConfig,
5776            backend_config: BackendConfig,
5777            satisfies_constraints: bool,
5778            qconfig_name: Optional[str] = None):
5779        """
5780        Helper method to validate whether `qconfig` satisfies the constraints specified in `backend_config`.
5781        """
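        # If `qconfig` violates the backend constraints, prepare/convert should
        # silently skip quantizing the linear, so the node counts below distinguish
        # a lowered nnq.Linear from an untouched nn.Linear.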
5782        qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
5783        example_inputs = (torch.rand((1, 30), dtype=torch.float),)
5784        model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
5785        model(*example_inputs)
5786        model = convert_fx(model, backend_config=backend_config)
5787        if satisfies_constraints:
5788            expected_node_occurrence = {
5789                ns.call_module(torch.ao.nn.quantized.Linear) : 1,
5790                ns.call_module(torch.nn.Linear) : 0,
5791            }
5792        else:
5793            expected_node_occurrence = {
5794                ns.call_module(torch.ao.nn.quantized.Linear) : 0,
5795                ns.call_module(torch.nn.Linear) : 1,
5796            }
5797        try:
5798            self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
5799        except AssertionError as e:
5800            if qconfig_name is not None:
5801                print(f"ERROR: Validation for QConfig '{qconfig_name}' failed")
5802            raise
5803
5804    def test_backend_config_quantization_range(self):
5805        """
5806        Check that quantization ranges specified through the BackendConfig are reflected in
5807        the observers inserted into the model.
5808        """
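        # quant_min_lower_bound / quant_max_upper_bound describe the widest range
        # the backend supports; a qconfig satisfies them only if its observers
        # specify a [quant_min, quant_max] that fits inside those bounds.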
5809        class MyModel(torch.nn.Module):
5810            def __init__(self) -> None:
5811                super().__init__()
5812                self.linear = torch.nn.Linear(30, 4).float()
5813
5814            def forward(self, x):
5815                return self.linear(x)
5816
5817        dtype_config = DTypeConfig(
5818            input_dtype=DTypeWithConstraints(
5819                dtype=torch.quint8,
5820                quant_min_lower_bound=0,
5821                quant_max_upper_bound=31,
5822            ),
5823            output_dtype=DTypeWithConstraints(
5824                dtype=torch.quint8,
5825                quant_min_lower_bound=0,
5826                quant_max_upper_bound=31,
5827            ),
5828            weight_dtype=DTypeWithConstraints(
5829                dtype=torch.qint8,
5830                quant_min_lower_bound=-64,
5831                quant_max_upper_bound=63,
5832            ),
5833            bias_dtype=torch.float,
5834        )
5835        backend_config = BackendConfig() \
5836            .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear)
5837                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E128
5838                .add_dtype_config(dtype_config)
5839                .set_root_module(torch.nn.Linear)
5840                .set_reference_quantized_module(nnqr.Linear))
5841
5842        def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
5843            self._validate_qconfig_against_backend_config_constraints(
5844                MyModel(), qconfig, backend_config, satisfies_constraints)
5845
5846        # Case 1: QConfig ranges fit within backend ranges, OK
5847        qconfig1 = QConfig(
5848            activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8),
5849            weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
5850        validate_qconfig(qconfig1, satisfies_constraints=True)
5851
5852        # Case 2: QConfig activation range falls outside backend range, should fail
5853        qconfig2 = QConfig(
5854            activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8),
5855            weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
5856        validate_qconfig(qconfig2, satisfies_constraints=False)
5857
5858        # Case 3: QConfig weight range falls outside backend range, should fail
5859        qconfig3 = QConfig(
5860            activation=MinMaxObserver.with_args(dtype=torch.quint8),
5861            weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
5862        validate_qconfig(qconfig3, satisfies_constraints=False)
5863
5864        # Case 4: QConfig doesn't specify range, should fail
5865        qconfig4 = QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver)
5866        validate_qconfig(qconfig4, satisfies_constraints=False)
5867
5868    def test_backend_config_scale_min(self):
5869        """
5870        Test QConfig eps validation against the BackendConfig's min scale value.
5871        """
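        # Observers clamp the computed scale from below by their eps, so eps acts
        # as the effective minimum scale; the constraint should be satisfied
        # whenever eps >= scale_min_lower_bound.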
5872        class MyModel(torch.nn.Module):
5873            def __init__(self) -> None:
5874                super().__init__()
5875                self.linear = torch.nn.Linear(30, 4).float()
5876
5877            def forward(self, x):
5878                return self.linear(x)
5879
5880        dtype_config = DTypeConfig(
5881            input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12),
5882            output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12),
5883            weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=2 ** -12),
5884            bias_dtype=torch.float,
5885        )
5886
5887        backend_config = BackendConfig() \
5888            .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear)
5889                .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)  # noqa: E128
5890                .add_dtype_config(dtype_config)
5891                .set_root_module(torch.nn.Linear)
5892                .set_reference_quantized_module(nnqr.Linear))
5893
5894        def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool):
5895            self._validate_qconfig_against_backend_config_constraints(
5896                MyModel(), qconfig, backend_config, satisfies_constraints)
5897
5898        # Case 1: QConfig min scale value == backend min scale value, OK
5899        qconfig1 = QConfig(
5900            activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12),
5901            weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -12))
5902        validate_qconfig(qconfig1, satisfies_constraints=True)
5903
5904        # Case 2: QConfig min scale value > backend min scale value, OK
5905        qconfig2 = QConfig(
5906            activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10),
5907            weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -10))
5908        validate_qconfig(qconfig2, satisfies_constraints=True)
5909
5910        # Case 3: QConfig activation min scale value < backend min scale value, should fail
5911        qconfig3 = QConfig(
5912            activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14),
5913            weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
5914        validate_qconfig(qconfig3, satisfies_constraints=False)
5915
5916        # Case 4: QConfig weight min scale value < backend min scale value, should fail
5917        qconfig4 = QConfig(
5918            activation=MinMaxObserver.with_args(dtype=torch.quint8),
5919            weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -14))
5920        validate_qconfig(qconfig4, satisfies_constraints=False)
5921
5922        # Case 5: QConfig doesn't specify eps, should fail
5923        qconfig5 = QConfig(
5924            activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0),
5925            weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0))
5926        validate_qconfig(qconfig5, satisfies_constraints=False)
5927
5928    def test_qnnpack_backend_config(self):
5929        """
5930        Test whether default QNNPACK QConfigs are compatible with the QNNPACK BackendConfig.
5931        """
5932        class MyModel(torch.nn.Module):
5933            def __init__(self) -> None:
5934                super().__init__()
5935                self.linear = torch.nn.Linear(30, 4).float()
5936
5937            def forward(self, x):
5938                return self.linear(x)
5939
5940        all_qconfigs: List[Tuple[QConfig, str]] = [
5941            (get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"),
5942            (get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"),
5943            (get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"),
5944            (default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"),
5945            (default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"),
5946            # TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862
5947            # (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"),
5948            # (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"),
5949        ]
5950        backend_config = get_qnnpack_backend_config()
5951        for qconfig, qconfig_name in all_qconfigs:
5952            self._validate_qconfig_against_backend_config_constraints(
5953                MyModel(), qconfig, backend_config, satisfies_constraints=True, qconfig_name=qconfig_name)
5954
5955    def test_symmetric_qnnpack_qconfig_mapping(self):
5956        """
5957        Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping`
5958        works with the QNNPACK BackendConfig.
5959        """
5960        if "qnnpack" not in supported_qengines:
5961            return
5962
5963        class MyModel(torch.nn.Module):
5964            def __init__(self) -> None:
5965                super().__init__()
5966                self.linear = torch.nn.Linear(30, 4).float()
5967
5968            def forward(self, x):
5969                return self.linear(x)
5970
5971        with override_quantized_engine("qnnpack"):
5972            qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
5973            example_inputs = (torch.rand((1, 30), dtype=torch.float),)
5974            backend_config = get_qnnpack_backend_config()
5975            model = MyModel()
5976            model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
5977            model(*example_inputs)
5978            model = convert_fx(model, backend_config=backend_config)
5979            expected_node_occurrence = {
5980                ns.call_module(torch.ao.nn.quantized.Linear) : 1,
5981                ns.call_module(torch.nn.Linear) : 0,
5982            }
5983            self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
5984            model(*example_inputs)
5985
5986    def test_symmetric_qnnpack_qat_qconfig_mapping(self):
5987        """
5988        Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping`
5989        works with the QNNPACK BackendConfig.
5990        """
5991        if "qnnpack" not in supported_qengines:
5992            return
5993
5994        class MyModel(torch.nn.Module):
5995            def __init__(self) -> None:
5996                super().__init__()
5997                self.linear = torch.nn.Linear(30, 4).float()
5998
5999            def forward(self, x):
6000                return self.linear(x)
6001
6002        with override_quantized_engine("qnnpack"):
6003            qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping()
6004            example_inputs = (torch.rand((1, 30), dtype=torch.float),)
6005            backend_config = get_qnnpack_backend_config()
6006            model = MyModel()
6007            model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
6008            model(*example_inputs)
6009            model = convert_fx(model, backend_config=backend_config)
6010            expected_node_occurrence = {
6011                ns.call_module(torch.ao.nn.quantized.Linear) : 1,
6012                ns.call_module(torch.nn.Linear) : 0,
6013            }
6014            self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence)
6015            model(*example_inputs)
6016
6018    def test_get_executorch_backend_config(self):
6019        from torch.ao.quantization.backend_config import get_executorch_backend_config
6020        # make sure this runs
6021        executorch_backend_config = get_executorch_backend_config()
6022
6023    def test_backend_config_check_for_weight_and_bias(self):
6024        """ Test to make sure the backend_config check for weight and bias
6025        runs when the qconfig is None for an op that has weight and bias.
6026        Previously the error was not hit because the input is checked first,
6027        so the checks for weight and bias were skipped.
6028        """
6029
6030        class M(torch.nn.Module):
6031            def __init__(self) -> None:
6032                super().__init__()
6033                self.weight = torch.rand(5, 5)
6034                self.bias = torch.rand(5)
6035
6036            def forward(self, x):
6037                return torch.addmm(self.bias, x, self.weight)
6038
6039        m = M().eval()
6040        qconfig_mapping = QConfigMapping()
6041        observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
6042        weighted_op_quint8_dtype_config = DTypeConfig(
6043            input_dtype=torch.quint8,
6044            output_dtype=torch.quint8,
6045            weight_dtype=torch.qint8,
6046            bias_dtype=torch.float,
6047        )
6048        dtype_configs = [weighted_op_quint8_dtype_config]
6049        backend_pattern_config = BackendPatternConfig(torch.addmm) \
6050            .set_observation_type(observation_type) \
6051            .set_dtype_configs(dtype_configs) \
6052            ._set_input_type_to_index({"weight": 2, "bias": 0})
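        # For torch.addmm(bias, x, weight), bias is positional arg 0 and weight is
        # arg 2; _set_input_type_to_index tells the backend which inputs to treat
        # as weight and bias so that their dtype checks run.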
6053        backend_config = BackendConfig() \
6054            .set_backend_pattern_config(backend_pattern_config)
6055        example_inputs = (torch.rand(1, 5),)
6056        # make sure this runs
6057        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
6058
6059    def test_get_default_qconfig_valid_backend(self):
6060        """ Checks that AssertionError is raised when an unsupported backend is specified
6061        """
6062        invalid_backends = ["imaginary_backend", 3]
6063        for invalid_backend in invalid_backends:
6064            with self.assertRaisesRegex(AssertionError, "not supported"):
6065                qconfig = get_default_qconfig(invalid_backend)
6066            with self.assertRaisesRegex(AssertionError, "not supported"):
6067                qconfig = get_default_qat_qconfig(invalid_backend)
6068            with self.assertRaisesRegex(AssertionError, "not supported"):
6069                qconfig_mapping = get_default_qconfig_mapping(invalid_backend)
6070            with self.assertRaisesRegex(AssertionError, "not supported"):
6071                qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend)
6072
6073    def test__convert_to_reference_decomposed_fx(self):
6074        class M(torch.nn.Module):
6075            def __init__(self) -> None:
6076                super().__init__()
6077                self.linear = torch.nn.Linear(5, 10)
6078
6079            def forward(self, x):
6080                return self.linear(x)
6081
6082        m = M().eval()
6083        qconfig_mapping = get_default_qconfig_mapping("fbgemm")
6084        example_inputs = (torch.randn(1, 5),)
6085        m = prepare_fx(m, qconfig_mapping, example_inputs)
6086        m_ref = copy.deepcopy(m)
6087        m_ref = convert_to_reference_fx(m_ref)
6088        m = _convert_to_reference_decomposed_fx(m)
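        # In the decomposed reference flavor, quantize/dequantize show up as
        # quantized_decomposed.* ops on plain integer tensors instead of
        # torch.quantize_per_tensor/Tensor.dequantize; the two per-tensor pairs
        # counted below wrap the linear's input and output activations.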
6089        expected_occurrence = {
6090            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
6091            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
6092        }
6093        self.checkGraphModuleNodes(
6094            m,
6095            expected_node_occurrence=expected_occurrence)
6096        # make sure it runs
6097        res_ref = m_ref(*example_inputs)
6098        res = m(*example_inputs)
6099        self.assertEqual(res, res_ref)
6100
6101    @skipIfNoQNNPACK
6102    def test__convert_to_reference_decomposed_fx_dynamic_quant(self):
6103        class M(torch.nn.Module):
6104            def __init__(self) -> None:
6105                super().__init__()
6106                self.linear = torch.nn.Linear(5, 10)
6107
6108            def forward(self, x):
6109                return self.linear(x)
6110
6111        # to avoid reduce_range
6112        with override_quantized_engine("qnnpack"):
6113            m = M().eval()
6114            qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
6115                .set_object_type(torch.nn.Linear, default_dynamic_qconfig)
6116            example_inputs = (torch.randn(1, 5),)
6117            m = prepare_fx(m, qconfig_mapping, example_inputs)
6118            m(*example_inputs)
6119            m_ref = copy.deepcopy(m)
6120            m_ref = convert_to_reference_fx(m_ref)
6121            m = _convert_to_reference_decomposed_fx(m)
6122            expected_occurrence = {
6123                ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1,
6124                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1,
6125                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1,
6126            }
6127            self.checkGraphModuleNodes(
6128                m,
6129                expected_node_occurrence=expected_occurrence)
6130            # make sure it runs
6131            res_ref = m_ref(*example_inputs)
6132            res = m(*example_inputs)
6133            self.assertEqual(res, res_ref)
6134
6135    def test__convert_to_reference_decomposed_fx_per_channel_quant(self):
6136        class M(torch.nn.Module):
6137            def forward(self, x, weight, bias):
6138                return F.linear(x, weight, bias)
6139
6140        m = M().eval()
6141        qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
6142            .set_object_type(F.linear, default_per_channel_qconfig)
6143        example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,))
6144        m = prepare_fx(m, qconfig_mapping, example_inputs)
6145        m(*example_inputs)
6146        m_ref = copy.deepcopy(m)
6147        m_ref = convert_to_reference_fx(m_ref)
6148        m = _convert_to_reference_decomposed_fx(m)
6149        expected_occurrence = {
6150            # for input and output activations
6151            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
6152            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
6153            # for weight
6154            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
6155            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
6156        }
6157        self.checkGraphModuleNodes(
6158            m,
6159            expected_node_occurrence=expected_occurrence)
6160        # make sure it runs
6161        res_ref = m_ref(*example_inputs)
6162        res = m(*example_inputs)
6163        self.assertEqual(res, res_ref)
6164
6165    def test_change_backend_config_for_fixed_qparam_ops(self):
6166        """ Make sure we can skip validation of qconfigs for fixed qparams ops based
6167        on the BackendConfig
6168        """
6169        class M(nn.Module):
6170            def __init__(self) -> None:
6171                super().__init__()
6172                self.tanh = torch.nn.Tanh()
6173
6174            def forward(self, x: torch.Tensor):
6175                x = self.tanh(x)
6176                return x
6177
6178        model = M().eval()
6179        # we set a global default_qconfig, which will be ignored since the backend
6180        # we defined doesn't support anything
6181        # this is to make sure we don't validate the qconfig when BackendConfig does not
6182        # have fixed qparam op related configurations
6183        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
6184        backend_config = BackendConfig()
6185        # make sure this runs
6186        model = prepare_fx(
6187            model,
6188            qconfig_mapping=qconfig_mapping,
6189            example_inputs=(torch.randn(1, 2, 3, 4),),
6190            backend_config=backend_config
6191        )
6192
6193    def test_channel_shuffle_lowering(self):
6194        # Three versions of channel shuffle
6195        class M1(torch.nn.Module):
6196            def __init__(self) -> None:
6197                super().__init__()
6198                self.op = torch.nn.ChannelShuffle(2)
6199
6200            def forward(self, x):
6201                return self.op(x + x) + x
6202
6203        class M2(torch.nn.Module):
6204            def forward(self, x):
6205                return torch.channel_shuffle(x + x, 2) + x
6206
6207        class M3(torch.nn.Module):
6208            def forward(self, x):
6209                return torch.nn.functional.channel_shuffle(x + x, 2) + x
6210
6211        x = torch.randn(4, 4, 4, 4)
6212        # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle
6213        model_node_pairs = [
6214            (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)),
6215            (M2().eval(), ns.call_function(torch.channel_shuffle)),
6216            (M3().eval(), ns.call_function(torch.channel_shuffle))
6217        ]
6218        for m, node in model_node_pairs:
6219            m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,))
6220            m_copy = copy.deepcopy(m)
6221            m = convert_fx(m)
6222            m_ref = convert_to_reference_fx(m_copy)
6223            node_occurrence = {
6224                node: 1,
6225                ns.call_function(torch.quantize_per_tensor): 1,
6226                ns.call_method("dequantize"): 1
6227            }
6228            node_occurrence_ref = {
6229                node: 1,
6230                ns.call_function(torch.quantize_per_tensor): 4,
6231                ns.call_method("dequantize"): 4
6232            }
6233            self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
6234            self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
6235
6236    def test_match_pattern_with_multiple_args(self):
6237        """ Test that we can match a pattern that has multiple arguments
6238        Pattern:
6239                           shape \
6240        transpose (observed) -> reshape -> output (observed)
6241
6242        where `reshape` has two arguments
6243        """
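        # The complex pattern format below, (torch.reshape, torch.transpose,
        # MatchAllNode), roughly means "a reshape whose first argument comes from a
        # transpose and whose second argument matches anything"; the
        # root_node_getter then picks transpose as the root that observers are
        # attached around.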
6244
6245        def _get_pattern_configs():
6246            backend_pattern_configs = []
6247            observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
6248            weighted_op_quint8_dtype_config = DTypeConfig(
6249                input_dtype=torch.quint8,
6250                output_dtype=torch.quint8,
6251                weight_dtype=torch.qint8,
6252                bias_dtype=torch.float,
6253            )
6254            dtype_configs = [weighted_op_quint8_dtype_config]
6255
6256            def root_node_getter(node_pattern):
6257                reshape, transpose, shape = node_pattern
6258                return transpose
6259
6260            backend_pattern_configs.append(
6261                BackendPatternConfig()
6262                ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode))  # noqa: E131
6263                .set_observation_type(observation_type)
6264                .set_dtype_configs(dtype_configs)
6265                ._set_root_node_getter(root_node_getter)
6266            )
6267            return backend_pattern_configs
6268
6269        backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs())
6270
6271        class M(torch.nn.Module):
6272            def forward(self, x):
6273                x = torch.transpose(x, 0, 1)
6274                x = torch.reshape(x, (-1,))
6275                return x
6276
6277        m = M().eval()
6278        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
6279        example_inputs = (torch.randn(1, 3, 3, 3),)
6280        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
6281        node_occurrence = {
6282            # one for input of the pattern and one for output of the pattern
6283            ns.call_module(MinMaxObserver): 2
6284        }
6285        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
6286
6287    def _test_linear_activation_fusion_lowering_helper(
6288            self, module, example_inputs, qconfig_mapping,
6289            backend_config, fused_module, root_module, activation_module):
6290        node_occurrence = {
6291            ns.call_function(torch.quantize_per_tensor): 1,
6292            ns.call_method("dequantize"): 1,
6293            ns.call_module(fused_module): 1,
6294            ns.call_module(root_module): 0,
6295            ns.call_module(activation_module): 0,
6296        }
6297        node_occurrence_ref = {
6298            ns.call_function(torch.quantize_per_tensor): 2,
6299            ns.call_method("dequantize"): 2,
6300        }
6301        m = module.eval()
6302        m = prepare_fx(m, qconfig_mapping,
6303                       example_inputs=example_inputs,
6304                       backend_config=backend_config)
6305        m_copy = copy.deepcopy(m)
6306        m = convert_fx(m, backend_config=backend_config)
6307        m_ref = convert_to_reference_fx(m_copy)
6308
6309        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
6310        self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref)
6311        m(*example_inputs)
6312
6313    @skipIfNoONEDNN
6314    def test_linear_leaky_relu_lowering(self):
6315        """ Test fusion and lowering of Linear - (bn -) LeakyReLU
6316            by FX. For the onednn backend only.
6317        """
6318        from torch.ao.quantization.backend_config import get_onednn_backend_config
6319        qconfig_mapping = get_default_qconfig_mapping('onednn')
6320        with override_quantized_engine('onednn'):
6321            for with_bn in [True, False]:
6322                m = LinearBnLeakyReluModel(with_bn)
6323                self._test_linear_activation_fusion_lowering_helper(
6324                    m,
6325                    m.get_example_inputs(),
6326                    qconfig_mapping,
6327                    get_onednn_backend_config(),
6328                    nniq.LinearLeakyReLU,
6329                    nn.Linear,
6330                    nn.LeakyReLU)
6331
6332    @skipIfNoONEDNN
6333    def test_linear_tanh_lowering(self):
6334        """ Test fusion and lowering of Linear - Tanh
6335            by FX. For the onednn backend only.
6336        """
6337        from torch.ao.quantization.backend_config import get_onednn_backend_config
6338        qconfig_mapping = get_default_qconfig_mapping('onednn')
6339        # TODO Currently it's required that separate ops in a fused op/module have the same qconfig.
6340        #      Need to be able to support fusion of ops with different qconfigs
6341        # Since tanh must have 'fixed_qparams_qconfig' while linear should use
6342        # the global qconfig, we need to set qconfigs for them manually here for
6343        # fusion and cannot put such configs in onednn's default qconfig_mapping.
6344        # Known issue:
6345        # Cannot fuse linear - tanh and quantize standalone tanh at the same time.
6346        qconfig = get_default_qconfig('onednn')
6347        qconfig_mapping.set_object_type(torch.nn.Linear, qconfig)
6348        qconfig_mapping.set_object_type(torch.nn.Tanh, qconfig)
6349        with override_quantized_engine('onednn'):
6350            m = LinearTanhModel()
6351            self._test_linear_activation_fusion_lowering_helper(
6352                m,
6353                m.get_example_inputs(),
6354                qconfig_mapping,
6355                get_onednn_backend_config(),
6356                nniq.LinearTanh,
6357                nn.Linear,
6358                nn.Tanh)
6359
6360    @override_qengines
6361    def test_linear_size_view(self):
6362        class M(torch.nn.Module):
6363            def __init__(self, use_relu=False):
6364                super().__init__()
6365                self.linear = torch.nn.Linear(16, 32)
6366                self.relu = torch.nn.ReLU()
6367                self.use_relu = use_relu
6368
6369            def forward(self, x):
6370                x = self.linear(x)
6371                if self.use_relu:
6372                    x = self.relu(x)
6373                return x.view(x.size(0), 1, 4, 8)
6374
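        # size() and view() are shape operations that should pass through the
        # quantized tensor unchanged, so the whole linear (+ relu) + view chain is
        # expected to sit inside a single quantize/dequantize pair.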
6375        for use_relu in [False, True]:
6376            model_fp32 = M(use_relu).eval()
6377            qengine = torch.backends.quantized.engine
6378            qconfig_mapping = get_default_qconfig_mapping(qengine)
6379            x = torch.randn((5, 16))
6380            model_fp32(x)
6381            prepared_model = prepare_fx(model_fp32, qconfig_mapping, (x,))
6382            prepared_model(x)
6383            quantized_model = convert_fx(prepared_model)
6384            node_occurrence = {
6385                ns.call_module(nnq.Linear): 0 if use_relu else 1,
6386                ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
6387                ns.call_function(torch.quantize_per_tensor): 1,
6388                ns.call_method("dequantize"): 1
6389            }
6390            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
6391
6392    @override_qengines
6393    def test_linear_shape_view(self):
6394        class M(torch.nn.Module):
6395            def __init__(self, use_relu=False):
6396                super().__init__()
6397                self.linear = torch.nn.Linear(16, 32)
6398                self.relu = torch.nn.ReLU()
6399                self.use_relu = use_relu
6400
6401            def forward(self, x):
6402                x = self.linear(x)
6403                if self.use_relu:
6404                    x = self.relu(x)
6405                return x.view(x.shape[0], 1, 4, 8)
6406
6407        for use_relu in [False, True]:
6408            model_fp32 = M(use_relu).eval()
6409            qengine = torch.backends.quantized.engine
6410            qconfig_mapping = get_default_qconfig_mapping(qengine)
6411            x = torch.randn((5, 16))
6412            model_fp32(x)
6413            prepared_model = prepare_fx(model_fp32, qconfig_mapping, x)
6414            prepared_model(x)
6415            quantized_model = convert_fx(prepared_model)
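            # Same expectation as test_linear_size_view: shape/view run on the
            # quantized tensor, so one quantize and one dequantize suffice.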
6416            node_occurrence = {
6417                ns.call_module(nnq.Linear): 0 if use_relu else 1,
6418                ns.call_module(nniq.LinearReLU): 1 if use_relu else 0,
6419                ns.call_function(torch.quantize_per_tensor): 1,
6420                ns.call_method("dequantize"): 1
6421            }
6422            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
6423
6424    def test_mixed_dtypes(self):
6425        """
6426        Test that multiple dtypes can be used in the same model for different layers,
6427        and the dtypes will be converted correctly between the layers.
6428        """
6429        class MyModule(torch.nn.Module):
6430            def __init__(self) -> None:
6431                super().__init__()
6432                self.linear1 = torch.nn.Linear(5, 5)
6433                self.linear2 = torch.nn.Linear(5, 5)
6434                self.sigmoid = torch.nn.Sigmoid()
6435                self.tanh = torch.nn.Tanh()
6436                self.float_functional = torch.ao.nn.quantized.FloatFunctional()
6437
6438            def forward(self, x: torch.Tensor):
6439                x = self.linear1(x)  # qint32
6440                x = self.linear2(x)  # quint8
6441                linear2 = x
6442                x = self.sigmoid(x)  # back to qint32
6443                x = self.tanh(x)  # back to quint8
6444                x = self.float_functional.add(linear2, x)  # adding two quint8's together
6445                return x
6446
6447        def make_qconfig(scale, zp, dtype):
6448            return QConfig(
6449                activation=FixedQParamsObserver.with_args(scale=scale, zero_point=zp, dtype=dtype),
6450                weight=torch.ao.quantization.default_weight_observer)
6451
6452        # Set up a QConfigMapping that specifies different qparams and dtypes for different layers
6453        qconfig_mapping = QConfigMapping() \
6454            .set_global(get_default_qconfig("qnnpack")) \
6455            .set_module_name("linear1", make_qconfig(1234, 11, torch.qint32)) \
6456            .set_module_name("linear2", make_qconfig(2345, 22, torch.quint8)) \
6457            .set_object_type(torch.nn.Sigmoid, make_qconfig(3456, 33, torch.qint32)) \
6458            .set_object_type(torch.nn.Tanh, make_qconfig(4567, 44, torch.quint8))
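        # Note: set_module_name entries take precedence over set_object_type
        # entries, which in turn take precedence over the global qconfig, so
        # each layer above gets its fixed qparams instead of the qnnpack default.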
6459
6460        # Set up BackendConfig that supports the dtypes configured in the above QConfigMapping
6461        weighted_op_qint32_dtype_config = DTypeConfig(
6462            input_dtype=torch.qint32,
6463            output_dtype=torch.qint32,
6464            weight_dtype=torch.qint8,
6465            bias_dtype=torch.float,
6466        )
6467        fixed_qparams_op_quint8_dtype_config = DTypeConfig(
6468            input_dtype=torch.quint8,
6469            output_dtype=torch.quint8,
6470        )
6471        fixed_qparams_op_qint32_dtype_config = DTypeConfig(
6472            input_dtype=torch.qint32,
6473            output_dtype=torch.qint32,
6474        )
6475        backend_config = get_qnnpack_backend_config()
6476        for config in backend_config.configs:
6477            if config.pattern == torch.nn.Linear:
6478                config.add_dtype_config(weighted_op_qint32_dtype_config)
6479            elif config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh]:
6480                config.add_dtype_config(fixed_qparams_op_quint8_dtype_config)
6481                config.add_dtype_config(fixed_qparams_op_qint32_dtype_config)
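        # prepare_fx only honors a qconfig whose dtypes match one of the
        # pattern's DTypeConfigs, so qint32 support has to be registered here
        # for the qint32 qconfigs above to take effect.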
6482
6483        # Produce the reference quantized model
6484        m = MyModule()
6485        example_inputs = (torch.rand(5, 5),)
6486        prepared = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config)
6487        prepared(*example_inputs)  # calibrate
6488        converted = convert_to_reference_fx(prepared, backend_config=backend_config)
6489        converted(*example_inputs)
6490
6491        # Verify that the reference model is correct
6492        #
6493        # Reference model until add should be:
6494        # fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq ->
6495        # q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) ->
6496        # q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq ->
6497        # q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq)
6498        #
6499        # Complete reference model with add should be:
6500        # [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output
6501
6502        target_to_expected_dtypes = {
6503            "linear1": torch.qint32,
6504            "linear2": torch.quint8,
6505            "sigmoid": torch.qint32,
6506            "tanh": torch.quint8,
6507            torch.add: torch.quint8,
6508        }
6509        # Find the patterns [dq - op_fp32 - q_to_specific_dtype] in the graph
6510        linear2_node = tanh_node = None
6511        for node in converted.graph.nodes:
6512            if node.target not in target_to_expected_dtypes:
6513                continue
6514
6515            # Match preceding dequantize
6516            self.assertTrue(len(node.args) == 1 or len(node.args) == 2)
6517            self.assertTrue(all(arg.target == "dequantize" for arg in node.args))
6518
6519            # Match following quantize with the specific dtypes
6520            self.assertEqual(len(node.users), 1)
6521            user = next(iter(node.users.keys()))
6522            self.assertEqual(user.target, torch.quantize_per_tensor)
6523            self.assertEqual(user.args[-1], target_to_expected_dtypes[node.target])
6524
6525            # Match [dq - torch.add(linear2_dq, tanh_dq) - q]
6526            if node.target == "linear2":
6527                linear2_node = node
6528            elif node.target == "tanh":
6529                tanh_node = node
6530            elif node.target == torch.add:
6531                linear2_dq, tanh_dq = node.args
6532                self.assertEqual(tanh_dq.args[0].args[0], tanh_node)
6533                self.assertEqual(linear2_dq.args[0].args[0], linear2_node)
6534
6535    def test_lowering_functional_conv_with_kwargs(self):
6536        dim_to_op = {
6537            1: F.conv1d,
6538            2: F.conv2d,
6539            3: F.conv3d,
6540        }
6541        dim_to_qop = {
6542            1: torch.ops.quantized.conv1d,
6543            2: torch.ops.quantized.conv2d,
6544            3: torch.ops.quantized.conv3d,
6545        }
6546
6547        class Mod(nn.Module):
6548            def __init__(self, in_channels, out_channels, kernel_size, dimension):
6549                super().__init__()
6550                self.dim = dimension
6551                self.op = dim_to_op[dimension]
6552                kernel_sizes = [kernel_size] * self.dim
6553                self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_sizes))
6554
6555            def forward(self, input):
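                # All arguments after the weight are passed as kwargs; the
                # lowering pass must handle kwargs, not just positional args.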
6556                return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
6557                               padding=[0] * self.dim, dilation=[1] * self.dim, groups=1)
6558
6559        for dimension in [1, 2, 3]:
6560            model = Mod(3, 16, 3, dimension)
6561            model.eval()
6562            qconfig_mapping = get_default_qconfig_mapping()
6563            input_shape = (1, 3, *([8] * dimension))
6564            example_inputs = torch.randn(input_shape)
6565            prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
6566            prepared_model(example_inputs)
6567            quantized_model = convert_fx(prepared_model)
6568            # This should pass
6569            quantized_model(example_inputs)
6570            # Ensure the quantized model has the expected op
6571            node_occurrence = {
6572                ns.call_function(dim_to_qop[dimension]): 1,
6573            }
6574            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
6575
6576    def test_lowering_functional_conv_transpose_with_kwargs(self):
6577        dim_to_op = {
6578            1: F.conv_transpose1d,
6579            2: F.conv_transpose2d,
6580            3: F.conv_transpose3d,
6581        }
6582        dim_to_qop = {
6583            1: torch.ops.quantized.conv_transpose1d,
6584            2: torch.ops.quantized.conv_transpose2d,
6585            3: torch.ops.quantized.conv_transpose3d,
6586        }
6587
6588        class Mod(nn.Module):
6589            def __init__(self, in_channels, out_channels, kernel_size, dimension):
6590                super().__init__()
6591                self.dim = dimension
6592                self.op = dim_to_op[dimension]
6593                kernel_sizes = [kernel_size] * self.dim
6594                self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes))
6595
6596            def forward(self, input):
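                # As above, all conv_transpose arguments (including
                # output_padding) are passed as kwargs to exercise kwarg
                # handling in the lowering pass.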
6597                return self.op(input, self.weight, bias=None, stride=[1] * self.dim,
6598                               padding=[0] * self.dim, output_padding=[0] * self.dim,
6599                               dilation=[1] * self.dim, groups=1)
6600
6601        for dimension in [1, 2, 3]:
6602            model = Mod(3, 16, 3, dimension)
6603            model.eval()
6604            qconfig_mapping = get_default_qconfig_mapping()
6605            input_shape = (1, 3, *([8] * dimension))
6606            example_inputs = torch.randn(input_shape)
6607            prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
6608            prepared_model(example_inputs)
6609            quantized_model = convert_fx(prepared_model)
6610            # This should pass
6611            quantized_model(example_inputs)
6612            # Ensure the quantized model has the expected op
6613            node_occurrence = {
6614                ns.call_function(dim_to_qop[dimension]): 1,
6615            }
6616            self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
6617
6618    def test_lowering_functional_linear_with_kwargs(self):
6619        class Mod(nn.Module):
6620            def __init__(self, in_channels, out_channels):
6621                super().__init__()
6622                self.weight = nn.Parameter(torch.randn(out_channels, in_channels))
6623
6624            def forward(self, input):
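                # bias is passed as a kwarg to exercise kwarg handling in the
                # linear lowering pass.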
6625                return F.linear(input, self.weight, bias=None)
6626
6627        model = Mod(8, 4)
6628        model.eval()
6629        qconfig_mapping = get_default_qconfig_mapping()
6630        example_inputs = torch.randn(1, 8)
6631        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
6632        prepared_model(example_inputs)
6633        quantized_model = convert_fx(prepared_model)
6634        # This should pass
6635        quantized_model(example_inputs)
6636        # Ensure the quantized model has the expected op
6637        node_occurrence = {
6638            ns.call_function(torch.ops.quantized.linear): 1,
6639        }
6640        self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
6641
6642@skipIfNoFBGEMM
6643class TestQuantizeFxOps(QuantizationTestCase):
6644    def setUp(self):
6645        super().setUp()
6646        self.custom_qconfig = torch.ao.quantization.QConfig(
6647            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
6648                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
6649            ),
6650            weight=torch.ao.quantization.default_per_channel_weight_observer
6651        )
6652        self.common_quant_patterns = {
6653            torch.nn.ConvTranspose1d: DefaultNodeQuantizeHandler,
6654            torch.nn.ConvTranspose2d: DefaultNodeQuantizeHandler,
6655            torch.nn.ELU: DefaultNodeQuantizeHandler,
6656            torch.nn.LeakyReLU: DefaultNodeQuantizeHandler,
6657            torch.nn.Hardswish: DefaultNodeQuantizeHandler,
6658            torch.nn.InstanceNorm1d: DefaultNodeQuantizeHandler,
6659            torch.nn.InstanceNorm2d: DefaultNodeQuantizeHandler,
6660            torch.nn.InstanceNorm3d: DefaultNodeQuantizeHandler,
6661            torch.nn.LayerNorm: DefaultNodeQuantizeHandler,
6662            torch.nn.SiLU: DefaultNodeQuantizeHandler,
6663            torch.nn.Mish: DefaultNodeQuantizeHandler,
6664            torch.nn.GELU: DefaultNodeQuantizeHandler,
6665            torch.nn.Softmax: DefaultNodeQuantizeHandler,
6666            torch.nn.functional.elu: DefaultNodeQuantizeHandler,
6667            torch.nn.functional.hardswish: DefaultNodeQuantizeHandler,
6668            torch.nn.functional.instance_norm: DefaultNodeQuantizeHandler,
6669            torch.nn.functional.layer_norm: DefaultNodeQuantizeHandler,
6670            torch.nn.functional.leaky_relu: DefaultNodeQuantizeHandler,
6671            torch.nn.functional.silu: DefaultNodeQuantizeHandler,
6672            torch.nn.functional.mish: DefaultNodeQuantizeHandler,
6673            torch.nn.functional.gelu: DefaultNodeQuantizeHandler,
6674            torch.nn.functional.softmax: DefaultNodeQuantizeHandler,
6675            torch.sum: DefaultNodeQuantizeHandler
6676        }
6677
6678    """Unit tests for individual ops
6679    """
6680    @skipIfNoFBGEMM
6681    def test_linear_module(self):
6682        with override_quantized_engine('fbgemm'):
6683            class LinearModel(torch.nn.Module):
6684                def __init__(self) -> None:
6685                    super().__init__()
6686                    self.linear = torch.nn.Linear(30, 4).float()
6687
6688                def forward(self, x):
6689                    return self.linear(x)
6690
6691            class LinearReLUModel(torch.nn.Module):
6692                def __init__(self, f_relu=False):
6693                    super().__init__()
6694                    self.linear = torch.nn.Linear(30, 4).float()
6695                    if f_relu:
6696                        self.relu = F.relu
6697                    else:
6698                        self.relu = torch.nn.ReLU()
6699
6700                def forward(self, x):
6701                    x = self.linear(x)
6702                    x = self.relu(x)
6703                    return x
6704
6705            class LinearBnModel(torch.nn.Module):
6706                def __init__(self) -> None:
6707                    super().__init__()
6708                    self.linear = torch.nn.Linear(4, 4).float()
6709                    self.bn = torch.nn.BatchNorm1d(4)
6710
6711                def forward(self, x):
6712                    x = self.linear(x)
6713                    x = self.bn(x)
6714                    return x
6715
6716            # Test linear
6717            data = (torch.rand((1, 30), dtype=torch.float),)
6718            for quant_type in self.all_quant_types:
6719                model = LinearModel()
6720                quantized_module = nnqd.Linear if quant_type == QuantType.DYNAMIC else nnq.Linear
6721                quantized_node = ns.call_module(quantized_module)
6722                result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
6723                if quant_type in self.static_quant_types:
6724                    self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
6725
6726            # TODO: enable test for dynamic quant
6727            # Test linear-relu
6728            for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
6729                model = LinearReLUModel(f_relu)
6730                quantized_node = ns.call_module(nniq.LinearReLU)
6731                result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
6732                self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
6733
6734            # Test linear-bn
6735            data = (torch.rand((4, 4), dtype=torch.float),)
6736            for quant_type in self.static_quant_types:
6737                model = LinearBnModel()
6738                quantized_node = ns.call_module(nnq.Linear)
6739                result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node)
6740                self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
6741
6742    @skipIfNoFBGEMM
6743    def test_functional_linear(self):
6744        with override_quantized_engine('fbgemm'):
6745            class FuncLinear(torch.nn.Module):
6746                def __init__(self, use_bias, has_relu, f_relu):
6747                    super().__init__()
6748                    self.w = torch.randn(4, 30)
6749                    self.b = torch.randn(4)
6750                    self.use_bias = use_bias
6751                    if has_relu:
6752                        if f_relu:
6753                            self.relu_or_id = F.relu
6754                        else:
6755                            self.relu_or_id = torch.nn.ReLU()
6756                    else:
6757                        self.relu_or_id = torch.nn.Identity()
6758
6759                def forward(self, x):
6760                    if self.use_bias:
6761                        x = F.linear(x, self.w, self.b)
6762                    else:
6763                        x = F.linear(x, self.w)
6764                    x = self.relu_or_id(x)
6765                    return x
6766
6767            data = (torch.rand((1, 30), dtype=torch.float),)
6768            quant_type_to_qlinear_fun = {
6769                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
6770                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
6771                QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
6772            }
6773            quant_type_to_qlinear_relu_fun = {
6774                # dynamic linear + relu lowers to the fused linear_relu_dynamic op
6775                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic),
6776                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
6777                QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
6778            }
6779
6780            options = itertools.product(
6781                self.all_quant_types,
6782                (True, False),  # use_bias
6783                (True, False),  # has_relu
6784                (True, False),  # functional relu
6785            )
6786            for quant_type, use_bias, has_relu, f_relu in options:
6787                # When has_relu is False, the model uses an nn.Identity instead.
6788                # nn.Identity is a copy node, so an extra observer/fake_quant is
6789                # inserted for its output; that is why the expected counts below
6790                # are higher when has_relu is False.
6791                quant_type_to_prepare_expected_node_occurrence = {
6792                    QuantType.DYNAMIC: {
6793                        ns.call_module(torch.ao.quantization.PlaceholderObserver): 1,
6794                        ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
6795                    },
6796                    # There should be 3 observers: input, weight and activation,
6797                    # plus one more for torch.nn.Identity when there is no relu.
6798                    QuantType.STATIC: {
6799                        ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
6800                        ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
6801                    },
6802                    # 3 fake_quants for input, weight and activation; one more for nn.Identity when there is no relu.
6803                    QuantType.QAT: {
6804                        ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
6805                    },
6806                }
6807                model = FuncLinear(use_bias, has_relu, f_relu)
6808                if has_relu:
6809                    qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type]
6810                else:
6811                    qlinear_fun = quant_type_to_qlinear_fun[quant_type]
6812
6813                if quant_type != QuantType.DYNAMIC:
6814                    num_dequantize = 1
6815                else:
6816                    # we currently get an extra quantize_per_tensor_dynamic + dequantize
6817                    # for nn.Identity; this will be fixed once backend_config is used
6818                    # to configure the default PyTorch backend
6819                    num_dequantize = int(not has_relu)
6820
6821                convert_node_occurrence = {
6822                    ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
6823                    qlinear_fun: 1,
6824                    ns.call_method("dequantize"): num_dequantize if quant_type != QuantType.DYNAMIC else 0,
6825                }
6826                prepare_expected_node_occurrence = \
6827                    quant_type_to_prepare_expected_node_occurrence[quant_type]
6828                result_dict = self.checkGraphModeFxOp(
6829                    model, data, quant_type, qlinear_fun,
6830                    prepare_expected_node_occurrence=prepare_expected_node_occurrence,
6831                    expected_node_occurrence=convert_node_occurrence)
6832                if quant_type != QuantType.DYNAMIC:
6833                    self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
6834                    # Ensure packed weights in lowered models are folded
6835                    self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())
6836
6837    @skipIfNoFBGEMM
6838    def test_linear_dynamic_fp16(self):
6839        with override_quantized_engine('fbgemm'):
6840            class FuncLinear(torch.nn.Module):
6841                def __init__(self, use_bias, has_relu, f_relu):
6842                    super().__init__()
6843                    self.w = torch.randn(4, 30)
6844                    self.b = torch.randn(4)
6845                    self.use_bias = use_bias
6846                    if has_relu:
6847                        if f_relu:
6848                            self.relu = F.relu
6849                        else:
6850                            self.relu = torch.nn.ReLU()
6851                    else:
6852                        self.relu = torch.nn.Identity()
6853
6854                def forward(self, x):
6855                    if self.use_bias:
6856                        x = F.linear(x, self.w, self.b)
6857                    else:
6858                        x = F.linear(x, self.w)
6859                    x = self.relu(x)
6860                    return x
6861
6862            data = (torch.rand((1, 30), dtype=torch.float),)
6863            options = itertools.product(
6864                (True, False),  # use_bias
6865                (True, False),  # has_relu
6866                (True, False),  # functional relu
6867                (True, False),  # is_reference
6868            )
6869            for use_bias, has_relu, f_relu, is_reference in options:
6870                model = FuncLinear(use_bias, has_relu, f_relu)
6871                if is_reference:
6872                    qlinear_fun = ns.call_function(torch.nn.functional.linear)
6873                else:
6874                    if has_relu:
6875                        qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16)
6876                    else:
6877                        qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16)
6878                prepare_node_occurrence = {
6879                    # activation and weight
6880                    ns.call_module(torch.ao.quantization.PlaceholderObserver): 2
6881                }
6882                convert_node_occurrence = {
6883                    qlinear_fun: 1,
6884                    # weight
6885                    ns.call_method("to"): 1 if is_reference else 0
6886                }
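                # The fp16 weight cast survives only in the reference model; in
                # the lowered model it is folded into the packed weight.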
6887                self.checkGraphModeFxOp(
6888                    model, data, QuantType.DYNAMIC, qlinear_fun,
6889                    is_reference=is_reference,
6890                    custom_qconfig_dict={"": float16_dynamic_qconfig},
6891                    prepare_expected_node_occurrence=prepare_node_occurrence,
6892                    expected_node_occurrence=convert_node_occurrence)
6893
6894    def test_linear_static_fp16(self):
6895        class FuncLinear(torch.nn.Module):
6896            def __init__(self, use_bias, has_relu, f_relu):
6897                super().__init__()
6898                self.w = torch.randn(4, 30)
6899                self.b = torch.randn(4)
6900                self.use_bias = use_bias
6901                if has_relu:
6902                    if f_relu:
6903                        self.relu = F.relu
6904                    else:
6905                        self.relu = torch.nn.ReLU()
6906                else:
6907                    self.relu = torch.nn.Identity()
6908
6909            def forward(self, x):
6910                if self.use_bias:
6911                    x = F.linear(x, self.w, self.b)
6912                else:
6913                    x = F.linear(x, self.w)
6914                x = self.relu(x)
6915                return x
6916
6917        data = (torch.rand((1, 30), dtype=torch.float),)
6918        options = itertools.product(
6919            (True, False),  # use_bias
6920            (True, False),  # has_relu
6921            (True, False),  # functional relu
6922            (True, False),  # is_reference
6923        )
6924        backend_config = get_test_only_legacy_native_backend_config()
6925        for use_bias, has_relu, f_relu, is_reference in options:
6926            model = FuncLinear(use_bias, has_relu, f_relu)
6927            linear_fun = ns.call_function(torch.nn.functional.linear)
6928            # When has_relu is False, the model uses an nn.Identity instead.
6929            # nn.Identity is a copy node, so an extra observer/fake_quant is
6930            # inserted for its output; that is why the expected counts below
6931            # are higher when has_relu is False.
6932            prepare_node_occurrence = {
6933                # activation, weight, bias and output
6934                ns.call_module(torch.ao.quantization.PlaceholderObserver): 3 + int(use_bias) + int(not has_relu),
6935            }
6936            # When is_reference is True and has_relu is False, there are
6937            # extra "to" and "dequantize" calls: in that case the model
6938            # contains an nn.Identity, which is a CopyNode, and reference
6939            # patterns insert an extra quant - dequant pair around
6940            # CopyNodes.
6941            convert_node_occurrence = {
6942                # we don't support static fp16 ops, so the linear function
6943                # is unfused
6944                linear_fun: 1,
6945                # activation, weight, bias and output
6946                ns.call_method("to"): 3 + int(use_bias) + int(not has_relu and is_reference),
6947                ns.call_method("dequantize"): 3 + int(use_bias) + int(not has_relu and is_reference)
6948            }
6949            self.checkGraphModeFxOp(
6950                model, data, QuantType.DYNAMIC, linear_fun,
6951                is_reference=is_reference,
6952                custom_qconfig_dict={"": float16_static_qconfig},
6953                prepare_expected_node_occurrence=prepare_node_occurrence,
6954                expected_node_occurrence=convert_node_occurrence,
6955                backend_config=backend_config)
6956
6957    @skipIfNoFBGEMM
6958    def test_conv_module(self):
6959        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
6960
6961        class ConvWrapper(torch.nn.Module):
6962            def __init__(self, dim):
6963                super().__init__()
6964                self.conv = conv_module[dim](3, 3, 3).float()
6965
6966            def forward(self, x):
6967                return self.conv(x)
6968
6969        options = itertools.product([1, 2, 3], self.static_quant_types)
6970        quantized_nodes = {
6971            # dim
6972            1: ns.call_module(nnq.Conv1d),
6973            2: ns.call_module(nnq.Conv2d),
6974            3: ns.call_module(nnq.Conv3d),
6975        }
6976        for dim, quant_type in options:
6977            self.checkGraphModeFxOp(
6978                ConvWrapper(dim), self.img_data_dict[dim], quant_type,
6979                quantized_nodes[dim])
6980
6981    @skipIfNoFBGEMM
6982    def test_functional_conv(self):
6983        with override_quantized_engine('fbgemm'):
6984            """ Test for functional conv and functional conv + relu
6985            """
6986            convs = {
6987                1: torch.nn.functional.conv1d,
6988                2: torch.nn.functional.conv2d,
6989                3: torch.nn.functional.conv3d,
6990            }
6991
6992            class FuncConv(torch.nn.Module):
6993                def __init__(self, dim, use_bias, has_relu, f_relu):
6994                    super().__init__()
6995                    self.dim = dim
6996                    self.w = torch.randn(tuple([3] * (dim + 2)))
6997                    self.b = torch.randn(3) if use_bias else None
6998                    self.stride = tuple([1] * dim)
6999                    self.padding = tuple([0] * dim)
7000                    self.dilation = tuple([1] * dim)
7001                    self.groups = 1
7002                    self.use_bias = use_bias
7003                    if has_relu:
7004                        if f_relu:
7005                            self.relu = F.relu
7006                        else:
7007                            self.relu = torch.nn.ReLU()
7008                    else:
7009                        self.relu = torch.nn.Identity()
7010
7011                def forward(self, x):
7012                    x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
7013                    x = self.relu(x)
7014                    return x
7015
7016            quant_type_to_qconv_fun = {
7017                QuantType.STATIC: {
7018                    1: ns.call_function(torch.ops.quantized.conv1d),
7019                    2: ns.call_function(torch.ops.quantized.conv2d),
7020                    3: ns.call_function(torch.ops.quantized.conv3d)
7021                },
7022                QuantType.QAT: {
7023                    1: ns.call_function(torch.ops.quantized.conv1d),
7024                    2: ns.call_function(torch.ops.quantized.conv2d),
7025                    3: ns.call_function(torch.ops.quantized.conv3d)
7026                },
7027            }
7028            quant_type_to_qconv_relu_fun = {
7029                QuantType.STATIC: {
7030                    1: ns.call_function(torch.ops.quantized.conv1d_relu),
7031                    2: ns.call_function(torch.ops.quantized.conv2d_relu),
7032                    3: ns.call_function(torch.ops.quantized.conv3d_relu)
7033                },
7034                QuantType.QAT: {
7035                    1: ns.call_function(torch.ops.quantized.conv1d_relu),
7036                    2: ns.call_function(torch.ops.quantized.conv2d_relu),
7037                    3: ns.call_function(torch.ops.quantized.conv3d_relu)
7038                },
7039            }
7040
7041            options = itertools.product(
7042                [1, 2, 3],  # dims
7043                self.static_quant_types,
7044                (True, False),  # use_bias
7045                (True, False),  # has_relu
7046                (True, False),  # functional relu
7047            )
7048            for dim, quant_type, use_bias, has_relu, f_relu in options:
7049                # When has_relu is False, the model uses an nn.Identity instead.
7050                # nn.Identity is a copy node, so an extra observer/fake_quant is
7051                # inserted for its output; that is why the expected counts below
7052                # are higher when has_relu is False.
7053                quant_type_to_prepare_expected_node_occurrence = {
7054                    QuantType.DYNAMIC: {},
7055                    # 3 observers for input, weight and activation; one more for nn.Identity when there is no relu.
7056                    QuantType.STATIC: {
7057                        ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3,
7058                        ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1,
7059                    },
7060                    # 3 fake_quants for input, weight and activation; one more for nn.Identity when there is no relu.
7061                    QuantType.QAT: {
7062                        ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4,
7063                    },
7064                }
7065                data_dims = [2, 3] + [4] * dim
7066                data = (torch.randn(tuple(data_dims), dtype=torch.float),)
7067                model = FuncConv(dim, use_bias, has_relu, f_relu)
7068                if has_relu:
7069                    qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim]
7070                else:
7071                    qconv_fun = quant_type_to_qconv_fun[quant_type][dim]
7072
7073                convert_node_occurrence = {
7074                    ns.call_function(torch.quantize_per_tensor): 1,
7075                    qconv_fun: 1,
7076                    ns.call_method("dequantize"): 1
7077                }
7078                prepare_expected_node_occurrence = \
7079                    quant_type_to_prepare_expected_node_occurrence[quant_type]
7080                result_dict = self.checkGraphModeFxOp(
7081                    model, data, quant_type, qconv_fun,
7082                    prepare_expected_node_occurrence=prepare_expected_node_occurrence,
7083                    expected_node_occurrence=convert_node_occurrence)
7084                if quant_type != QuantType.DYNAMIC:
7085                    self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
7086                    # Ensure packed weights in lowered models are folded
7087                    self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys())
7088
7089    @skipIfNoFBGEMM
7090    def test_quantized_conv_relu(self):
7091        """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
7092        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
7093
7094        class ConvNdRelu(torch.nn.Module):
7095            def __init__(self, dim, inplace):
7096                super().__init__()
7097                self.conv = conv_module[dim](3, 3, 3).float()
7098                self.relu = torch.nn.ReLU(inplace)
7099
7100            def forward(self, x):
7101                return self.relu(self.conv(x))
7102
7103        class ConvNdFunctionalRelu(torch.nn.Module):
7104            def __init__(self, dim):
7105                super().__init__()
7106                self.conv = conv_module[dim](3, 3, 3).float()
7107
7108            def forward(self, x):
7109                return F.relu(self.conv(x))
7110
7111        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
7112            def __init__(self, dim):
7113                super().__init__()
7114                self.conv = conv_module[dim](3, 3, 3).float()
7115
7116            def forward(self, x):
7117                return F.relu(self.conv(x), True)
7118
7119        options = itertools.product([1, 2, 3], self.static_quant_types)
7120        quantized_nodes = {
7121            # dim
7122            1: ns.call_module(nniq.ConvReLU1d),
7123            2: ns.call_module(nniq.ConvReLU2d),
7124            3: ns.call_module(nniq.ConvReLU3d),
7125        }
7126        for dim, quant_type in options:
7127            for m in [ConvNdRelu(dim, True),
7128                      ConvNdRelu(dim, False),
7129                      ConvNdFunctionalRelu(dim),
7130                      ConvNdInplaceFunctionalRelu(dim)]:
7131                self.checkGraphModeFxOp(
7132                    m, self.img_data_dict[dim], quant_type,
7133                    quantized_nodes[dim])
7134
7135
7136    def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op):
7137        data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
7138                torch.randn(1, 1, 1, 1, dtype=torch.float))
7139        options = itertools.product([True, False], [True, False], [True, False])
7140        quant_type = QuantType.STATIC
7141        # testing for default int8 static quant
7142        for is_inplace, is_scalar, is_reference in options:
7143            if is_reference:
7144                node_list = [
7145                    ns.call_method("dequantize"),
7146                    ns.call_function(binary_op),
7147                    ns.call_function(torch.quantize_per_tensor)
7148                ]
7149                quantized_node = None
7150            else:
7151                node_list = None
7152                quantized_node = ns.call_function(quantized_op)
7153
7154            self.checkGraphModeFxOp(
7155                BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
7156                quantized_node, expected_node_list=node_list, is_reference=is_reference)
7157            # This tests that the binary op is quantized even when it is not
7158            # fed a quantized input
7159            self.checkGraphModeFxOp(
7160                BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar),
7161                data, quant_type, quantized_node,
7162                expected_node_list=node_list, is_reference=is_reference)
7163
7164
7165    def _test_binary_op_float16_impl(self, binary_op, ibinary_op):
7166        data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
7167                torch.randn(1, 1, 1, 1, dtype=torch.float))
7168        quant_type = QuantType.STATIC
7169        # testing for fp16 static quant
7170        # we are producing fp16 patterns
7171        options = itertools.product([True, False], [True, False])
7172        custom_qconfig_dict = {
7173            "object_type": [(binary_op, float16_static_qconfig)]
7174        }
7175        backend_config = get_test_only_legacy_native_backend_config()
7176        for is_inplace, is_scalar in options:
7177            node_occurrence = {
7178                # output_conv1, output_add1, output_add2 for scalar
7179                # output_conv1, output_conv2, output_add1, output_add2 for non-scalar
7180                ns.call_method("to"): 3 if is_scalar else 4
7181            }
7182            self.checkGraphModeFxOp(
7183                BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
7184                expected_node_occurrence=node_occurrence,
7185                custom_qconfig_dict=custom_qconfig_dict,
7186                backend_config=backend_config)
7187
7188            node_occurrence = {
7189                # input_add, output_add for scalar
7190                # input_add1, input_add2, output_add for non-scalar
7191                ns.call_method("to"): 2 if is_scalar else 3
7192            }
7193            self.checkGraphModeFxOp(
7194                BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type,
7195                expected_node_occurrence=node_occurrence,
7196                custom_qconfig_dict=custom_qconfig_dict,
7197                backend_config=backend_config)
7198
7199    def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op):
7200        data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
7201                torch.rand((1, 1, 1, 1), dtype=torch.float))
7202        quant_type = QuantType.STATIC
7203        quantized_node = ns.call_function(quantized_op)
7204        options = itertools.product(
7205            [True, False], [nn.ReLU, F.relu, torch.relu], [True, False])
7206        for is_inplace_op, relu_callable, is_scalar in options:
7207            model = BinaryOpRelu(
7208                binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar)
7209            self.checkGraphModeFxOp(
7210                model, data, quant_type, quantized_node)
7211
7212    def _test_binary_op_relu_float16_impl(self, binary_op, ibinary_op):
7213        data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
7214                torch.rand((1, 1, 1, 1), dtype=torch.float))
7215        quant_type = QuantType.STATIC
7216        options = itertools.product(
7217            [True, False], [nn.ReLU, F.relu, torch.relu], [True, False])
7218        custom_qconfig_dict = {
7219            "": float16_static_qconfig,
7220            "object_type": [(torch.nn.Conv2d, None)]
7221        }
7222        backend_config = get_test_only_legacy_native_backend_config()
7223        for is_inplace_op, is_functional_relu, is_scalar in options:
7224            node_occurrence = {
7225                ns.call_method("to"): 3 if is_scalar else 4
7226            }
7227            model = BinaryOpRelu(
7228                binary_op, ibinary_op, is_inplace_op, is_functional_relu, is_scalar)
7229            self.checkGraphModeFxOp(
7230                model, data, quant_type, custom_qconfig_dict=custom_qconfig_dict,
7231                expected_node_occurrence=node_occurrence,
7232                backend_config=backend_config)
7233
7234
7235    @skipIfNoFBGEMM
7236    def test_add(self):
7237        self._test_binary_op_int8_impl(
7238            operator.add, operator.iadd, torch.ops.quantized.add)
7239        self._test_binary_op_float16_impl(
7240            operator.add, operator.iadd)
7241
7242    @unittest.skip("This is no longer needed right now, can enable later with new api")
7243    def test_sub(self):
7244        self._test_binary_op_float16_impl(operator.sub, operator.isub)
7245        self._test_binary_op_float16_impl(torch.sub, None)
7246
7247    @unittest.skip("This is no longer needed right now, can enable later with new api")
7248    def test_div(self):
7249        self._test_binary_op_float16_impl(operator.truediv, operator.itruediv)
7250        self._test_binary_op_float16_impl(torch.div, None)
7251
7252    @skipIfNoFBGEMM
7253    def test_mul(self):
7254        self._test_binary_op_int8_impl(
7255            operator.mul, operator.imul, torch.ops.quantized.mul)
7256        self._test_binary_op_float16_impl(operator.mul, operator.imul)
7257
7258    @unittest.skip("This is no longer needed right now, can enable later with new api")
7259    def test_sum(self):
7260        class Sum(torch.nn.Module):
7261            def forward(self, x):
7262                x = torch.sum(x, [1], keepdim=True)
7263                x = torch.sum(x, [1])
7264                return x
7265
7266        data = torch.randn(1, 2, 3, 4, dtype=torch.float)
7267        quant_type = QuantType.STATIC
7268        # testing for fp16 static quant
7269        # we are producing fp16 patterns
7270        custom_qconfig_dict = {
7271            "object_type": [(torch.sum, float16_static_qconfig)]
7272        }
7273        node_occurrence = {
7274            # input_sum1, output_sum1, output_sum2
7275            ns.call_method("to"): 3
7276        }
7277        self.checkGraphModeFxOp(
7278            Sum(), data, quant_type,
7279            expected_node_occurrence=node_occurrence,
7280            custom_qconfig_dict=custom_qconfig_dict)
7281
7282    @unittest.skip("This is no longer needed right now, can enable later with new api")
7283    def test_bmm(self):
7284        class BMMMethod(torch.nn.Module):
7285            def forward(self, x, y):
7286                return x.bmm(y)
7287
7288        data = (torch.randn(1, 1, 1, dtype=torch.float),
7289                torch.randn(1, 1, 1, dtype=torch.float))
7290        quant_type = QuantType.STATIC
7291        # testing for fp16 static quant
7292        # we are producing fp16 patterns
7293        custom_qconfig_dict = {
7294            "object_type": [(torch.bmm, float16_static_qconfig),
7295                            ("bmm", float16_static_qconfig)]
7296        }
7297        node_occurrence = {
7298            # input_bmm1, input_bmm2, output_bmm
7299            ns.call_method("to"): 3
7300        }
7301        self.checkGraphModeFxOp(
7302            BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type,
7303            expected_node_occurrence=node_occurrence,
7304            custom_qconfig_dict=custom_qconfig_dict)
7305
7306        # TODO: support call_method("bmm")
7307        # we can transform call_method("bmm") to call_function(torch.bmm)
7308        # self.checkGraphModeFxOp(
7309        #     BMMMethod(), data, quant_type,
7310        #     expected_node_occurrence=node_occurrence,
7311        #     custom_qconfig_dict=custom_qconfig_dict,
7312        #     print_debug_info=True)
7313
7314    @skipIfNoFBGEMM
7315    def test_add_relu(self):
7316        self._test_binary_op_relu_int8_impl(
7317            operator.add, operator.iadd, torch.ops.quantized.add_relu)
7318        self._test_binary_op_relu_float16_impl(
7319            operator.add, operator.iadd)
7320
7321    @skipIfNoFBGEMM
7322    def test_add_relu_multiple_uses_of_relu(self):
7323        class Sub(torch.nn.Module):
7324            def __init__(self) -> None:
7325                super().__init__()
7326                self.relu = torch.nn.ReLU(inplace=True)
7327
7328        class M(torch.nn.Module):
7329            def __init__(self) -> None:
7330                super().__init__()
7331                self.sub = Sub()
7332
7333            def forward(self, x, y):
7334                x = x + y
7335                x = self.sub.relu(x)
7336                x = x + y
7337                x = self.sub.relu(x)
7338                return x
7339
7340        m = M().eval()
7341        example_inputs = (torch.randn(3), torch.randn(3))
7342        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
7343        m = convert_fx(m)
7344        node_occurrence = {
7345            ns.call_function(torch.quantize_per_tensor): 2,
7346            ns.call_function(torch.ops.quantized.add_relu): 2,
7347            ns.call_method("dequantize"): 1,
7348        }
7349        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
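        # Both add + relu pairs should fuse into quantized.add_relu even though
        # the same ReLU module instance is reused for both.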
7350        # check the model is scriptable
7351        m = torch.jit.script(m)
7352        # check the model is runnable
7353        m(*example_inputs)
7354
7355    @skipIfNoFBGEMM
7356    def test_mul_relu(self):
7357        self._test_binary_op_relu_int8_impl(
7358            operator.mul, operator.imul, torch.ops.quantized.mul_relu)
7359        self._test_binary_op_relu_float16_impl(
7360            operator.mul, operator.imul)
7361
7362    # TODO(future PR): make more generic
7363    def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence):
7364        qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}
7365        mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)
7366        self.checkGraphModuleNodes(
7367            mp, expected_node_occurrence=expected_node_occurrence)
7368
7369    @skipIfNoFBGEMM
7370    def test_quantized_add_qat(self):
7371        class M(torch.nn.Module):
7372            def __init__(self) -> None:
7373                super().__init__()
7374                self.conv1 = torch.nn.Conv2d(1, 1, 1)
7375                self.conv2 = torch.nn.Conv2d(1, 1, 1)
7376
7377            def forward(self, x):
7378                x = torch.add(x, 1.0)
7379                x = self.conv1(x)
7380                x = torch.add(x, 1.0)
7381                x = torch.relu(x)
7382                x = self.conv2(x)
7383                return x
7384
7385        m = M()
7386        example_inputs = (torch.randn(1, 1, 1, 1),)
7387        expected_node_occurrence = {
7388            ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
7389        }
7390        self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence)
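        # The 5 fake_quants correspond to the model input plus the outputs of
        # the first add, conv1, the fused add + relu, and conv2.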
7391
7392    @skipIfNoFBGEMM
7393    def test_quantized_mul_qat(self):
7394        class M(torch.nn.Module):
7395            def __init__(self) -> None:
7396                super().__init__()
7397                self.conv1 = torch.nn.Conv2d(1, 1, 1)
7398                self.conv2 = torch.nn.Conv2d(1, 1, 1)
7399
7400            def forward(self, x):
7401                x = torch.mul(x, 1.0)
7402                x = self.conv1(x)
7403                x = torch.mul(x, 1.0)
7404                x = torch.relu(x)
7405                x = self.conv2(x)
7406                return x
7407
7408        m = M()
7409        example_inputs = (torch.randn(1, 1, 1, 1),)
7410        expected_node_occurrence = {
7411            ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5,
7412        }
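        # Same structure as test_quantized_add_qat: 5 fake_quants expected.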
7413        self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence)
7414
7415    def test_int8_input_no_unnecessary_fq(self):
7416        """
7417        If the graph inputs are already quantized and the only node
7418        does not need an activation observer, verify that no
7419        activation observer is inserted.
7420        """
7421        class M(nn.Module):
7422            def __init__(self, scalar):
7423                super().__init__()
7424                self.scalar = scalar
7425                self.add_func = torch.ao.nn.quantized.FloatFunctional()
7426
7427            def forward(self, x):
7428                return self.add_func.add_scalar(x, self.scalar)
7429
7430        m = M(0.5)
7431        mp = torch.ao.quantization.quantize_fx.prepare_qat_fx(
7432            m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')},
7433            example_inputs=(torch.randn(1),),
7434            prepare_custom_config={"input_quantized_idxs": [0]})
7435        expected_node_occurrence = {
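        # input_quantized_idxs marks graph input 0 as already quantized, so
        # prepare should not insert a fake_quant for it.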
7436            ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1,
7437        }
7438        self.checkGraphModuleNodes(
7439            mp, expected_node_occurrence=expected_node_occurrence)
7440
7441    @skipIfNoFBGEMM
7442    def test_cat(self):
7443        """ Quantization of the output of cat depends on the inputs of cat:
7444        we only quantize the output of cat when its inputs are quantized.
7445        """
7446        class M(torch.nn.Module):
7447            def __init__(self) -> None:
7448                super().__init__()
7449                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
7450                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
7451
7452            def forward(self, x, y):
7453                x = self.conv1(x)
7454                y = self.conv2(y)
7455                return torch.cat([x, y], 1)
7456
7457        example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float),
7458                          torch.randn(1, 2, 5, 5, dtype=torch.float))
7459        quantized_node = ns.call_function(torch.cat)
7460        options = itertools.product(self.static_quant_types, [True, False])
7461        for quant_type, is_reference in options:
7462            if is_reference:
7463                converted_node_list = [
7464                    ns.call_method("dequantize"),
7465                    ns.call_function(torch.cat),
7466                    ns.call_function(torch.quantize_per_tensor)
7467                ]
7468                converted_node_occurrence = {
7469                    # inputs and outputs of the two conv, and output of cat
7470                    ns.call_method("dequantize"): 5,
7471                    ns.call_function(torch.cat): 1,
7472                    # inputs and outputs of the two conv, and output of cat
7473                    ns.call_function(torch.quantize_per_tensor): 5,
7474                }
7475            else:
7476                converted_node_list = None
7477                converted_node_occurrence = {
7478                    # output of cat
7479                    ns.call_method("dequantize"): 1,
7480                    ns.call_function(torch.cat): 1,
7481                    # for two inputs
7482                    ns.call_function(torch.quantize_per_tensor): 2,
7483                }
7484
7485            self.checkGraphModeFxOp(
7486                M(),
7487                example_inputs,
7488                quant_type,
7489                quantized_node,
7490                expected_node_list=converted_node_list,
7491                expected_node_occurrence=converted_node_occurrence,
7492                is_reference=is_reference)
7493
7494        # check cat is using the same observer for input and output
7495        m = M().eval()
7496        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
7497        # The two inputs and the output of torch.cat share the same observer,
7498        # so there are 2 duplicated observer modules.
7499        all_observers = len(dict(m.named_modules(remove_duplicate=False)))
7500        distinct_observers = len(dict(m.named_modules()))
7501        self.assertEqual(all_observers, distinct_observers + 2)
7502        # make sure the converted model runs
7503        m = convert_fx(m)
7504        m(*example_inputs)
7505
7506    @skipIfNoFBGEMM
7507    def test_qbatch_norm(self):
7508        bn_module = {
7509            # TODO: quantized batchnorm 1d module is missing
7510            # 1 : torch.nn.BatchNorm1d,
7511            2 : torch.nn.BatchNorm2d,
7512            3 : torch.nn.BatchNorm3d,
7513        }
7514
7515        class M(torch.nn.Module):
7516            def __init__(self, dim):
7517                super().__init__()
7518                self.bn = bn_module[dim](3).to(torch.float)
7519
7520            def forward(self, x):
7521                return self.bn(x)
7522
7523        options = itertools.product(self.static_quant_types, [2, 3], [True, False])
7524        quantized_nodes = {
7525            False: {
7526                # 1: ns.call_module(nnq.BatchNorm1d),
7527                2: ns.call_module(nnq.BatchNorm2d),
7528                3: ns.call_module(nnq.BatchNorm3d),
7529            },
7530            True: {
7531                # 1: ns.call_module(nn.BatchNorm1d),
7532                2: ns.call_module(nn.BatchNorm2d),
7533                3: ns.call_module(nn.BatchNorm3d),
7534            }
7535        }
7536        for quant_type, dim, is_reference in options:
7537            self.checkGraphModeFxOp(
7538                M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference)
7539
7540    @skipIfNoFBGEMM
7541    def test_qbatch_norm_relu(self):
7542        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
7543
7544        class BNRelu(torch.nn.Module):
7545            def __init__(self, dim, inplace):
7546                super().__init__()
7547                self.bn = bn_module[dim](3).to(torch.float)
7548                self.relu = torch.nn.ReLU(inplace=inplace)
7549
7550            def forward(self, x):
7551                return self.relu(self.bn(x))
7552
7553        class BNFuncRelu(torch.nn.Module):
7554            def __init__(self, dim):
7555                super().__init__()
7556                self.bn = bn_module[dim](3).to(torch.float)
7557
7558            def forward(self, x):
7559                return F.relu(self.bn(x), False)
7560
7561        class BNFuncInplaceRelu(torch.nn.Module):
7562            def __init__(self, dim):
7563                super().__init__()
7564                self.bn = bn_module[dim](3).to(torch.float)
7565
7566            def forward(self, x):
7567                return F.relu(self.bn(x), True)
7568
7569        options = itertools.product(self.static_quant_types, [2, 3], [True, False])
7570        quantized_nodes = {
7571            True: {
7572                2: ns.call_module(nni.BNReLU2d),
7573                3: ns.call_module(nni.BNReLU3d),
7574            },
7575            False: {
7576                2: ns.call_module(nniq.BNReLU2d),
7577                3: ns.call_module(nniq.BNReLU3d),
7578            }
7579        }
7580        for quant_type, dim, is_reference in options:
7581            for instance in [BNRelu(dim, True), BNRelu(dim, False),
7582                             BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
7583                self.checkGraphModeFxOp(
7584                    instance, self.img_data_dict[dim], quant_type,
7585                    quantized_nodes[is_reference][dim], is_reference=is_reference)
7586
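    # Illustrative sketch (not part of the upstream suite, hypothetical module
    # and shapes, assuming an engine with quantized batchnorm such as fbgemm):
    # the end-to-end prepare/convert flow that test_qbatch_norm_relu drives
    # through checkGraphModeFxOp.
    def _sketch_bn_relu_fusion_flow(self):
        class BNRelu2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.bn(x))

        m = BNRelu2d().eval()
        example_inputs = (torch.randn(1, 3, 8, 8),)
        # prepare_fx fuses bn + relu and inserts observers
        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
        m(*example_inputs)  # calibrate
        # convert_fx swaps the fused pattern for the quantized nniq.BNReLU2d
        m = convert_fx(m)
        return m(*example_inputs)
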
7587    def _test_activation_impl(
7588            self, float_module, float_op, quantized_module, quantized_op):
7589        ''' Test for an activation op (with inplace options); float_op can be
7590        a torch op or a functional op
7591        '''
7592        class M(torch.nn.Module):
7593            def __init__(self, is_module, inplace):
7594                super().__init__()
7595                self.is_module = is_module
7596                self.inplace = inplace
7597                if self.is_module:
7598                    self.op = float_module(self.inplace)
7599                else:
7600                    self.op = float_op
7601
7602            def forward(self, input):
7603                if self.is_module:
7604                    return self.op(input)
7605                else:
7606                    return self.op(input, self.inplace)
7607
7608        options = itertools.product([True, False], [True, False], self.static_quant_types, [True, False])
7609        quantized_nodes = {
7610            # is_module
7611            True: {
7612                # is_reference
7613                True: ns.call_module(float_module),
7614                False: ns.call_module(quantized_module),
7615            },
7616            False: {
7617                True: ns.call_function(float_op),
7618                False: ns.call_function(quantized_op),
7619            }
7620        }
7621
7622        for is_module, is_inplace, quant_type, is_reference in options:
7623            self.checkGraphModeFxOp(
7624                M(is_module, is_inplace), self.img_data_2d,
7625                quant_type, quantized_nodes[is_module][is_reference], is_reference=is_reference)
7626
7627    def test_hardswish(self):
7628        self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish)
7629
7630    def test_elu(self):
7631        self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)
7632
7633    def test_leaky_relu(self):
7634        self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
7635
7636    def test_prelu(self):
7637        class M(torch.nn.Module):
7638            def __init__(self, num_param: int):
7639                super().__init__()
7640                self.op = torch.nn.PReLU(num_parameters=num_param)
7641
7642            def forward(self, input):
7643                return self.op(input)
7644
7645        X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]]
7646        options = itertools.product([1, 4], self.static_quant_types, [True, False])
7647        quantized_nodes = {
7648            # is_reference
7649            True: ns.call_module(torch.nn.PReLU),
7650            False: ns.call_module(torch.ao.nn.quantized.PReLU),
7651        }
7652
7653        for num_parameter, quant_type, is_reference in options:
7654            self.checkGraphModeFxOp(
7655                M(num_parameter), X, quant_type, quantized_nodes[is_reference],
7656                is_reference=is_reference)
7657
7658    def _test_norm_impl(
7659            self, float_module, float_op, op_args, data, quantized_module, quantized_op,
7660            skip_op_arg_for_functional=False):
7661        ''' Test for a normalization op; float_op can be a torch op or a functional
7662        op, and op_args is a list of positional arguments for the module/op
7663        '''
7664        class M(torch.nn.Module):
7665            def __init__(self, is_module):
7666                super().__init__()
7667                self.is_module = is_module
7668                if self.is_module:
7669                    self.op = float_module(*op_args)
7670                else:
7671                    self.op = float_op
7672
7673            def forward(self, input):
7674                if self.is_module:
7675                    return self.op(input)
7676                else:
7677                    args = [input]
7678                    if not skip_op_arg_for_functional:
7679                        args += op_args
7680                    return self.op(*args)
7681
7682        options = itertools.product([True, False], self.static_quant_types)
7683        quantized_nodes = {
7684            # is_module
7685            True: ns.call_module(quantized_module),
7686            False: ns.call_function(quantized_op),
7687        }
7688
7689        for is_module, quant_type in options:
7690            self.checkGraphModeFxOp(
7691                M(is_module), data, quant_type, quantized_nodes[is_module])
7692
7693    def _test_norm_float16_impl(
7694            self, float_module, float_op, op_args, data,
7695            skip_op_arg_for_functional=False):
7696        ''' Test for a normalization op; float_op can be a torch op or a functional
7697        op, and op_args is a list of positional arguments for the module/op
7698        '''
7699        class M(torch.nn.Module):
7700            def __init__(self, is_module):
7701                super().__init__()
7702                self.is_module = is_module
7703                if self.is_module:
7704                    self.op = float_module(*op_args)
7705                else:
7706                    self.op = float_op
7707
7708            def forward(self, input):
7709                if self.is_module:
7710                    return self.op(input)
7711                else:
7712                    args = [input]
7713                    if not skip_op_arg_for_functional:
7714                        args += op_args
7715                    return self.op(*args)
7716
7717        options = itertools.product([True, False], self.static_quant_types)
7718        qconfig_dict = {
7719            "object_type": [
7720                (float_module, float16_static_qconfig),
7721                (float_op, float16_static_qconfig)
7722            ]
7723        }
7724        node_occurrence = {
7725            ns.call_method("to"): 2
7726        }
7727        for is_module, quant_type in options:
7728            self.checkGraphModeFxOp(
7729                M(is_module), data, quant_type, custom_qconfig_dict=qconfig_dict, expected_node_occurrence=node_occurrence)
7730
7731    def test_layer_norm(self):
7732        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
7733        self._test_norm_impl(
7734            nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm)
7735
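    # Sketch (hypothetical shapes, mirroring test_layer_norm above): the
    # functional path that _test_norm_impl exercises, where F.layer_norm is
    # converted to torch.ops.quantized.layer_norm.
    def _sketch_functional_layer_norm(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return F.layer_norm(x, [2, 5, 5])

        example_inputs = (torch.rand(1, 2, 5, 5),)
        m = prepare_fx(M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        m(*example_inputs)  # calibrate observers
        m = convert_fx(m)
        return m(*example_inputs)
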
7736    def test_instance_norm(self):
7737        data_1d = (torch.rand((1, 4, 5), dtype=torch.float),)
7738        data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),)
7739        data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),)
7740        data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d}
7741        instance_norm_modules = {1 : nn.InstanceNorm1d,
7742                                 2 : nn.InstanceNorm2d,
7743                                 3 : nn.InstanceNorm3d}
7744        quantized_instance_norm_modules = {
7745            1 : nnq.InstanceNorm1d,
7746            2 : nnq.InstanceNorm2d,
7747            3 : nnq.InstanceNorm3d
7748        }
7749        for dim in [1, 2, 3]:
7750            data = data_dict[dim]
7751            module = instance_norm_modules[dim]
7752            quantized_module = quantized_instance_norm_modules[dim]
7753            self._test_norm_impl(
7754                module, F.instance_norm, [4], data,
7755                quantized_module, torch.ops.quantized.instance_norm,
7756                skip_op_arg_for_functional=True)
7757
7758    def test_norm_weight_bias(self):
7759        class Linear(torch.nn.Module):
7760            def __init__(self) -> None:
7761                super().__init__()
7762                self.w = torch.ones(5, 5)
7763                self.b = torch.zeros(5)
7764
7765            def forward(self, x):
7766                return torch.nn.functional.linear(x, self.w, self.b)
7767
7768        class M(torch.nn.Module):
7769            def __init__(self) -> None:
7770                super().__init__()
7771                self.mods1 = Linear()
7772                self.scale = torch.randn(5, 5)
7773                self.bias = torch.randn(5, 5)
7774
7775            def forward(self, x):
7776                x1 = self.mods1(x)
7777                y = F.layer_norm(x1, [5, 5], weight=self.scale, bias=self.bias)
7778                return y
7779
7780        model = M()
7781        expected_occurrence = {
7782            ns.call_function(torch.quantize_per_tensor): 1,
7783            ns.call_function(torch.ops.quantized.linear): 1,
7784            ns.call_function(torch.ops.quantized.layer_norm): 1,
7785            ns.call_method("dequantize"): 1,
7786        }
7787
7788        self.checkGraphModeFxOp(
7789            model,
7790            (torch.rand(5, 5),),
7791            QuantType.STATIC,
7792            expected_node_occurrence=expected_occurrence,
7793            custom_qconfig_dict=get_default_qconfig_mapping().to_dict()
7794        )
7795
7796    def _test_default_node_quant_handler_ops(
7797            self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None
7798    ):
7799        class M(torch.nn.Module):
7800            def __init__(self, mod, func):
7801                super().__init__()
7802                self.module = mod()
7803                self.functional = func
7804
7805            def forward(self, x):
7806                x = self.module(x)
7807                x = self.functional(x)
7808                return x
7809
7810        if node_list is None:
7811            node_list = []
7812        if additional_quant_pattern_dict is None:
7813            additional_quant_pattern_dict = {}
7814
7815        data = torch.randn((2, 2, 2, 2))
7816        quant_type = QuantType.STATIC
7817        prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict}
7818        qconfig_dict = {"": qconfig}
7819
7820        m = M(module, functional).eval()
7821        m_prep = prepare_fx(m, qconfig_dict, example_inputs=(data,))  # TODO: hook additional_quant_pattern back up via the backend_config api
7822        m_prep(data)
7823        convert_fn = convert_to_reference_fx if is_reference else convert_fx
7824        m_quant = convert_fn(m_prep)
7825        m_quant(data)
7826
7827        self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
7828
7829    @unittest.skip("TODO: reenable with backend_config api")
7830    def test_gelu_normal(self):
7831        module = torch.nn.GELU
7832        functional = torch.nn.functional.gelu
7833        qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
7834        is_reference = False
7835        node_list = [
7836            ns.call_module(module),
7837            ns.call_function(functional),
7838        ]
7839        self._test_default_node_quant_handler_ops(
7840            module, functional, qconfig, is_reference, node_list)
7841
7842    @unittest.skip("TODO: reenable with backend_config api")
7843    def test_softmax_normal(self):
7844        module = torch.nn.Softmax
7845        functional = torch.nn.functional.softmax
7846        qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
7847        is_reference = False
7848        node_list = [
7849            ns.call_module(torch.ao.nn.quantized.Softmax),
7850            ns.call_function(functional),
7851        ]
7852        self._test_default_node_quant_handler_ops(
7853            module, functional, qconfig, is_reference, node_list)
7854
7855    @unittest.skip("This is no longer needed right now, can enable later with new api")
7856    def test_gelu_reference(self):
7857        module = torch.nn.GELU
7858        functional = torch.nn.functional.gelu
7859        qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
7860        is_reference = True
7861        node_list = [
7862            ns.call_function(torch.quantize_per_tensor),
7863            ns.call_method("dequantize"),
7864            ns.call_module(module),
7865            ns.call_function(torch.quantize_per_tensor),
7866            ns.call_method('dequantize'),
7867            ns.call_function(functional),
7868            ns.call_function(torch.quantize_per_tensor),
7869            ns.call_method('dequantize')
7870        ]
7871        # TODO: change these to use backend_config
7872        additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler,
7873                               torch.nn.functional.gelu: DefaultNodeQuantizeHandler}
7874        self._test_default_node_quant_handler_ops(
7875            module, functional, qconfig, is_reference, node_list, additional_patterns)
7876
7877        self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
7878                                                  additional_quant_pattern_dict=self.common_quant_patterns)
7879
7880    @unittest.skip("This is no longer needed right now, can enable later with new api")
7881    def test_softmax_reference(self):
7882        module = torch.nn.Softmax
7883        functional = torch.nn.functional.softmax
7884        qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
7885        is_reference = True
7886        node_list = [
7887            ns.call_function(torch.quantize_per_tensor),
7888            ns.call_method("dequantize"),
7889            ns.call_module(module),
7890            ns.call_function(torch.quantize_per_tensor),
7891            ns.call_method('dequantize'),
7892            ns.call_function(functional),
7893            ns.call_function(torch.quantize_per_tensor),
7894            ns.call_method('dequantize')
7895        ]
7896        additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler,
7897                               torch.nn.functional.softmax: DefaultNodeQuantizeHandler}
7898        self._test_default_node_quant_handler_ops(
7899            module, functional, qconfig, is_reference, node_list, additional_patterns)
7900
7901        self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
7902                                                  additional_quant_pattern_dict=self.common_quant_patterns)
7903
7904    @unittest.skip("This is no longer needed right now, can enable later with new api")
7905    def test_silu_reference(self):
7906        module = torch.nn.SiLU
7907        functional = torch.nn.functional.silu
7908        qconfig = float16_static_qconfig
7909        is_reference = True
7910        node_list = [
7911            ns.call_method("to"),
7912            ns.call_method("dequantize"),
7913            ns.call_module(module),
7914            ns.call_method("to"),
7915            ns.call_method('dequantize'),
7916            ns.call_function(functional),
7917            ns.call_method("to"),
7918            ns.call_method('dequantize')
7919        ]
7920        self._test_default_node_quant_handler_ops(
7921            module, functional, qconfig, is_reference, node_list)
7922
7923        node_list = [
7924            ns.call_function(torch.quantize_per_tensor),
7925            ns.call_method("dequantize"),
7926            ns.call_module(module),
7927            ns.call_function(torch.quantize_per_tensor),
7928            ns.call_method("dequantize"),
7929            ns.call_function(functional),
7930            ns.call_function(torch.quantize_per_tensor),
7931            ns.call_method("dequantize")
7932        ]
7933        self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
7934                                                  additional_quant_pattern_dict=self.common_quant_patterns)
7935
7936    @unittest.skip("This is no longer needed right now, can enable later with new api")
7937    def test_mish_reference(self):
7938        module = torch.nn.Mish
7939        functional = torch.nn.functional.mish
7940        qconfig = float16_static_qconfig
7941        is_reference = True
7942        node_list = [
7943            ns.call_method("to"),
7944            ns.call_method("dequantize"),
7945            ns.call_module(module),
7946            ns.call_method("to"),
7947            ns.call_method('dequantize'),
7948            ns.call_function(functional),
7949            ns.call_method("to"),
7950            ns.call_method('dequantize')
7951        ]
7952        self._test_default_node_quant_handler_ops(
7953            module, functional, qconfig, is_reference, node_list)
7954
7955        node_list = [
7956            ns.call_function(torch.quantize_per_tensor),
7957            ns.call_method("dequantize"),
7958            ns.call_module(module),
7959            ns.call_function(torch.quantize_per_tensor),
7960            ns.call_method("dequantize"),
7961            ns.call_function(functional),
7962            ns.call_function(torch.quantize_per_tensor),
7963            ns.call_method("dequantize")
7964        ]
7965        self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list,
7966                                                  additional_quant_pattern_dict=self.common_quant_patterns)
7967
7968    def test_bmm_int_reference(self):
7969        """ int8 is not supported for bmm, so we won't produce a reference
7970            pattern for it
7971        """
7972        class M(torch.nn.Module):
7973            def __init__(self) -> None:
7974                super().__init__()
7975                self.bmm = torch.bmm
7976
7977            def forward(self, x, y):
7978                out = self.bmm(x, y)
7979                return out
7980
7981        data_x = torch.randn((2, 2, 2,))
7982        data_y = torch.randn((2, 2, 2,))
7983        example_inputs = (data_x, data_y)
7984        qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")}
7985        is_reference = True
7986        node_list = [
7987            ns.call_function(torch.bmm),
7988        ]
7989
7990        m = M().eval()
7991        m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
7992        m_prep(*example_inputs)
7993        convert_fn = convert_to_reference_fx if is_reference else convert_fx
7994        m_quant = convert_fn(m_prep)
7995        m_quant(*example_inputs)
7996
7997        self.checkGraphModuleNodes(m_quant, expected_node_list=node_list)
7998
7999    @skipIfNoFBGEMM
8000    def test_clamp(self):
8001        class M(torch.nn.Module):
8002            def __init__(self) -> None:
8003                super().__init__()
8004                self.conv = torch.nn.Conv2d(2, 2, 2).float()
8005                self.relu6 = torch.nn.ReLU6()
8006                self.relu6_ = torch.nn.ReLU6(True)
8007                self.hardtanh = torch.nn.Hardtanh()
8008                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)
8009
8010            def forward(self, x):
8011                x = self.conv(x)
8012                x = self.relu6(x)
8013                self.relu6_(x)
8014                x = F.relu6(x)
8015                x = torch.clamp(x, -3, 3)
8016                x = x.clamp(-2.5, 2.5)
8017                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
8018                x = self.hardtanh(x)
8019                self.hardtanh_(x)
8020                x = F.hardtanh(x)
8021                return x
8022
8023        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
8024        # list of nodes that should occur in this order
8025        node_list = [
8026            ns.call_function(torch.quantize_per_tensor),
8027            ns.call_module(nnq.Conv2d),
8028            ns.call_method('dequantize')
8029        ]
8030        for quant_type in self.static_quant_types:
8031            self.checkGraphModeFxOp(
8032                M(), data, quant_type, expected_node_list=node_list)
8033
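    # Sketch (hypothetical module): clamp-style ops reuse their input's
    # quantization parameters, which is why test_clamp above only checks for
    # the input quantize, the quantized conv, and the final dequantize.
    def _sketch_clamp_passthrough(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2)

            def forward(self, x):
                return torch.clamp(self.conv(x), -3, 3)

        example_inputs = (torch.rand(1, 2, 5, 5),)
        m = prepare_fx(M().eval(), {"": default_qconfig}, example_inputs=example_inputs)
        m(*example_inputs)  # calibrate
        m = convert_fx(m)
        return m(*example_inputs)
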
8034    def test_fixed_qparams_ops_fp16(self):
8035        class M(torch.nn.Module):
8036            def __init__(self) -> None:
8037                super().__init__()
8038                self.sigmoid = torch.nn.Sigmoid()
8039                self.tanh = torch.nn.Tanh()
8040
8041            def forward(self, x):
8042                x = self.sigmoid(x)
8043                x = torch.sigmoid(x)
8044                x = x.sigmoid()
8045                x = self.tanh(x)
8046                x = torch.tanh(x)
8047                x = x.tanh()
8048                return x
8049
8050        data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
8051        quant_type = QuantType.STATIC
8052        # TODO: use get_default_qconfig_mapping once it handles fp16
8053        qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig)
8054        backend_config = get_test_only_legacy_native_backend_config()
8055        node_occurrence = {
8056            ns.call_method("to"): 7
8057        }
8058        self.checkGraphModeFxOp(
8059            M(), data, quant_type, custom_qconfig_dict=qconfig_mapping,
8060            expected_node_occurrence=node_occurrence,
8061            backend_config=backend_config)
8062
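    # Sketch (hypothetical module, assuming the legacy native backend config
    # used above): fp16 "static" quantization shows up as to(torch.float16) /
    # dequantize pairs instead of quantize_per_tensor calls, which is what the
    # node_occurrence check in test_fixed_qparams_ops_fp16 counts.
    def _sketch_fp16_static_flow(self):
        class Sig(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        backend_config = get_test_only_legacy_native_backend_config()
        m = prepare_fx(Sig().eval(), {"": float16_static_qconfig},
                       example_inputs=(torch.randn(2, 2),),
                       backend_config=backend_config)
        m = convert_fx(m, backend_config=backend_config)
        return m(torch.randn(2, 2))
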
8063    def test_fixed_qparams_ops_qint8(self):
8064        class M(torch.nn.Module):
8065            def __init__(self) -> None:
8066                super().__init__()
8067                self.sigmoid = torch.nn.Sigmoid()
8068                self.tanh = torch.nn.Tanh()
8069
8070            def forward(self, x):
8071                x = self.sigmoid(x)
8072                x = torch.sigmoid(x)
8073                x = x.sigmoid()
8074                x = self.tanh(x)
8075                x = torch.tanh(x)
8076                x = x.tanh()
8077                return x
8078
8079        data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
8080        quant_type = QuantType.STATIC
8081        qconfig = torch.ao.quantization.QConfig(
8082            activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8),
8083            weight=default_weight_observer)
8084        qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig)
8085        node_occurrence = {
8086            ns.call_function(torch.quantize_per_tensor): 7,
8087            ns.call_method("dequantize"): 7
8088        }
8089        self.checkGraphModeFxOp(
8090            M(), data, quant_type, custom_qconfig_dict=qconfig_mapping,
8091            expected_node_occurrence=node_occurrence, is_reference=True)
8092
8093    def test_fixed_qparams_ops_wrong_qconfig(self):
8094        """ Test that wrong qconfigs for fixed qparams ops result in the ops not being quantized.
8095        """
8096        class M(torch.nn.Module):
8097            def __init__(self) -> None:
8098                super().__init__()
8099                self.sigmoid = torch.nn.Sigmoid()
8100                self.tanh = torch.nn.Tanh()
8101
8102            def forward(self, x):
8103                x = self.sigmoid(x)
8104                x = torch.sigmoid(x)
8105                x = x.sigmoid()
8106                x = self.tanh(x)
8107                x = torch.tanh(x)
8108                x = x.tanh()
8109                return x
8110
8111        data = (torch.randn((2, 2, 2, 2), dtype=torch.float),)
8112        qconfig_mapping = QConfigMapping().set_global(default_qconfig)
8113        m = M().eval()
8114        node_occurrence = {
8115            ns.call_function(torch.quantize_per_tensor): 0,
8116            ns.call_method("dequantize"): 0,
8117        }
8118        self.checkGraphModeFxOp(
8119            m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping,
8120            expected_node_occurrence=node_occurrence, is_reference=True)
8121        self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid))
8122        self.assertTrue(isinstance(m.tanh, torch.nn.Tanh))
8123
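    # Sketch (hypothetical module): with the correct default mapping, sigmoid
    # gets a FixedQParamsObserver rather than the plain MinMax observer from
    # default_qconfig, which is why the "wrong qconfig" test above ends up
    # with no quantize/dequantize nodes at all.
    def _sketch_fixed_qparams_observer(self):
        class Sig(torch.nn.Module):
            def forward(self, x):
                return torch.sigmoid(x)

        m = prepare_fx(Sig().eval(), get_default_qconfig_mapping(),
                       example_inputs=(torch.randn(2, 2),))
        return [name for name, mod in m.named_modules()
                if isinstance(mod, FixedQParamsObserver)]
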
8124    @skipIfNoFBGEMM
8125    def test_general_shape_ops(self):
8126        """ A test that checks that dequantize is propagated through
8127        all supported general shape ops like aten::flatten,
8128        without actually executing these ops
8129        """
8130        class M(torch.nn.Module):
8131            def __init__(self) -> None:
8132                super().__init__()
8133                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
8134                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
8135                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
8136                self.dropout = torch.nn.Dropout()
8137                self.conv1 = torch.nn.Conv2d(3, 3, 3)
8138                self.conv2 = torch.nn.Conv2d(3, 3, 3)
8139                self.relu = torch.nn.ReLU()
8140
8141            def forward(self, x):
8142                x = self.conv1(x)
8143                # add_scalar
8144                x = x + 3
8145                # mul_scalar
8146                x = x * 3
8147                # add_scalar_out
8148                x += 3
8149                # mul_scalar_out
8150                x *= 3
8151                # add_scalar_relu
8152                x = x + 3
8153                x = F.relu(x)
8154                # add_scalar_relu_out
8155                x += 3
8156                x = F.relu(x)
8157                # mul_scalar_relu
8158                x = x * 3
8159                x = F.relu(x)
8160                # mul_scalar_relu_out
8161                x *= 3
8162                x = F.relu(x)
8163                x = self.maxpool1d(x)
8164                x = self.maxpool2d(x)
8165                x = self.maxpool3d(x)
8166                x = torch.flatten(x)
8167                x = x.reshape([-1])
8168                x = x.resize_(1, 1, 1)  # sizes must be ints (model is never executed)
8169                x = x.view(-1)
8170                # prim::ListConstruct
8171                xs = [x, x]
8172                # prim::ListUnpack
8173                x, y = xs
8174                # prim::TupleConstruct
8175                xs = (x, x)
8176                # prim::TupleUnpack
8177                x, y = xs
8178                x = x.transpose(1, 2)
8179                x = x.contiguous()
8180                # chunk is not supported since the observer currently only
8181                # supports observing a single Tensor
8182                x, y = torch.chunk(x, 2)
8183                x = F.dropout(x)
8184                x = self.dropout(x)
8185                x = x.permute(0, 2, 3, 1)
8186                x = x.repeat_interleave(3, 1)
8187                x = torch.repeat_interleave(x, 3, 1)
8188                x = self.relu(x)
8189                x = F.relu(x)
8190                x = F.relu(x, inplace=True)
8191                x = x.relu()
8192                x.relu_()
8193                x = x.squeeze(0)
8194                x.squeeze_(0)
8195                x = torch.squeeze(x, 0)
8196                x = x.unsqueeze(0)
8197                x.unsqueeze_(0)
8198                x = torch.unsqueeze(x, 0)
8199                x = x.detach()
8200                x.detach_()
8201                x = x.repeat(4, 2)
8202                y = []
8203                y.append(x)
8204                z = torch.stack(y, 0)
8205                z = [z, z]
8206                x, _ = z
8207                x = self.conv2(x)
8208                return x
8209
8210        example_inputs = (torch.rand(1, 3, 10, 10),)
8211        # This model is not executable since we just put all ops
8212        # in the same forward
8213        m = M().eval()
8214        qconfig_dict = {'': default_qconfig}
8215        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
8216        # not runnable
8217        quantized = convert_fx(prepared)
8218
8219        # This checks that the dequantize from the output of the first conv
8220        # is propagated to the end, so that we don't insert extra
8221        # observers, and that both convs were lowered to quantized::conv2d
8222        # patterns successfully.
8223        # There is one quantize_per_tensor for the model input.
8224        # Check exact counts of quantize and dequantize:
8225        count_check = {
8226            # input of the first conv, plus re-quantization after the unsupported chunk
8227            ns.call_function(torch.quantize_per_tensor) : 2,
8228            # dequantize before the unsupported chunk, plus the model output
8229            ns.call_method('dequantize') : 2
8230        }
8231        order_check = [
8232            ns.call_function(torch.quantize_per_tensor),
8233            ns.call_module(nnq.Conv2d),
8234            ns.call_module(nnq.Conv2d),
8235            ns.call_method('dequantize'),
8236        ]
8237        self.checkGraphModuleNodes(
8238            quantized,
8239            expected_node_occurrence=count_check,
8240            expected_node_list=order_check)
8241
8242
8243        # Checking the is_reference output
8244        m = M().eval()
8245        qconfig_dict = {'': default_qconfig}
8246        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
8247        # not runnable
8248        quantized = convert_to_reference_fx(prepared)
8249
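    # Minimal sketch (hypothetical module): the dequantize propagation that
    # test_general_shape_ops checks, i.e. shape ops between two convs do not
    # introduce extra quantize/dequantize pairs.
    def _sketch_shape_op_propagation(self):
        class TwoConvs(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                x = self.conv1(x)
                x = x.permute(0, 2, 3, 1)  # general shape op
                x = x.permute(0, 3, 1, 2)
                return self.conv2(x)

        example_inputs = (torch.rand(1, 3, 8, 8),)
        m = prepare_fx(TwoConvs().eval(), {"": default_qconfig},
                       example_inputs=example_inputs)
        m(*example_inputs)  # calibrate
        m = convert_fx(m)
        # expected: one quantize at the input, one dequantize at the output,
        # with the two quantized convs connected directly through the permutes
        return m(*example_inputs)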
8250
8251    @skipIfNoFBGEMM
8252    def test_ave_pool_with_custom_cfg(self):
8253        """ A test that checks correct patterns are produced for
8254        avg_pool2d with a customized config
8255        """
8256        class M(torch.nn.Module):
8257            def __init__(self) -> None:
8258                super().__init__()
8259                self.avg_pool2d = torch.nn.AvgPool2d(3)
8260
8261
8262            def forward(self, x):
8263                x = self.avg_pool2d(x)
8264                return x
8265
8266        # This model is not executable since we just put all ops
8267        # in the same forward
8268        m = M().eval()
8269        # nothing to fuse so skipping the fuse step
8270        qconfig_dict = {'': default_qconfig}
8271        example_inputs = (torch.randn(1, 3, 3, 3),)
8272        prepared = prepare_fx(
8273            m, qconfig_dict, example_inputs=example_inputs,
8274            prepare_custom_config={"input_quantized_idxs": [0]})
8275
8276        # not runnable
8277        quantized = convert_fx(prepared)
8278
8279        # This checks that the quantized input is propagated through the
8280        # avg_pool to the single dequantize at the end, so that we don't
8281        # insert extra observers
8282        # check exact counts of quantize and dequantize
8283        count_check = {
8284            ns.call_method('dequantize') : 1
8285        }
8286        order_check = [
8287            ns.call_module(nn.AvgPool2d),
8288            ns.call_method('dequantize'),
8289        ]
8290        self.checkGraphModuleNodes(
8291            quantized,
8292            expected_node_occurrence=count_check,
8293            expected_node_list=order_check)
8294
8295    @skipIfNoFBGEMM
8296    def test_general_value_ops(self):
8297        """ A test that checks correct patterns are produced for
8298        all supported general value ops like aten::avg_pool2d,
8299        without actually checking for execution of these ops
8300        """
8301        class M(torch.nn.Module):
8302            def __init__(self) -> None:
8303                super().__init__()
8304                self.conv = torch.nn.Conv2d(3, 3, 3)
8305                self.avg_pool1d = torch.nn.AvgPool1d(3)
8306                self.avg_pool2d = torch.nn.AvgPool2d(3)
8307                self.avg_pool3d = torch.nn.AvgPool3d(3)
8308                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1)
8309                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
8310                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
8311
8312            def forward(self, x):
8313                x = self.conv(x)
8314                x = self.avg_pool1d(x)
8315                x = self.avg_pool2d(x)
8316                x = self.avg_pool3d(x)
8317                x = self.adaptive_avg_pool1d(x)
8318                x = self.adaptive_avg_pool2d(x)
8319                x = self.adaptive_avg_pool3d(x)
8320                x = F.avg_pool1d(x, 3)
8321                x = F.avg_pool2d(x, 3)
8322                x = F.avg_pool3d(x, 3)
8323                x = F.adaptive_avg_pool1d(x, (1))
8324                x = F.adaptive_avg_pool2d(x, (1, 1))
8325                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
8326                x = torch.mean(x)
8327                x = torch.mean(x, [2, 3], False)
8328                x = x.mean()
8329                x = x.mean([2, 3], True)
8330                x = F.interpolate(x, 4, mode='nearest')
8331                x = F.interpolate(x, 4, mode='linear')
8332                x = self.conv(x)
8333                return x
8334
8335        # This model is not executable since we just put all ops
8336        # in the same forward
8337        m = M().eval()
8338        # nothing to fuse so skipping the fuse step
8339        qconfig_dict = {'': default_qconfig}
8340        example_inputs = (torch.randn(1, 3, 3, 3),)
8341        prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
8342        # not runnable
8343        quantized = convert_fx(prepared)
8344
8345        # This checks that the dequantize from the output of the first conv
8346        # is propagated to the end, so that we don't insert extra
8347        # observers
8348        # check exact counts of quantize and dequantize
8349        count_check = {
8350            ns.call_function(torch.quantize_per_tensor) : 1,
8351            ns.call_method('dequantize') : 1
8352        }
8353        order_check = [
8354            ns.call_function(torch.quantize_per_tensor),
8355            ns.call_module(nnq.Conv2d),
8356            ns.call_module(nnq.Conv2d),
8357            ns.call_method('dequantize'),
8358        ]
8359        self.checkGraphModuleNodes(
8360            quantized,
8361            expected_node_occurrence=count_check,
8362            expected_node_list=order_check)
8363
8364    def test_copy_node_fp32_input(self):
8365        """ CopyNode works for both fp32 and int8 inputs; this test makes
8366        sure that a CopyNode can be successfully quantized in both cases
8367        """
8368        class M(torch.nn.Module):
8369            def forward(self, x):
8370                x = x.relu()
8371                return x
8372
8373        m = M().eval()
8374        m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),))
8375        m = convert_fx(m)
8376        # make sure it runs
8377        m(torch.rand(1))
8378
8379    def test_getitem(self):
8380        """ Make sure we only insert an observer for getitem if the following node is matched
8381        or needs to be quantized
8382        """
8383        class M(torch.nn.Module):
8384            def forward(self, xs):
8385                x = xs[0]
8386                return x
8387
8388        m = M().eval()
8389        example_inputs = (torch.rand(1, 2),)
8390        qconfig_mapping = get_default_qconfig_mapping()
8391        m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs)
8392        self.checkGraphModuleNodes(m, expected_node_occurrence={
8393            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
8394        })
8395        m = convert_fx(m)
8396        m(*example_inputs)
8397
8398        class M2(torch.nn.Module):
8399            def forward(self, xs):
8400                x = xs[0]
8401                x = torch.sigmoid(x)
8402                return x
8403
8404        m2 = M2().eval()
8405        example_inputs = ([torch.rand(1, 2)],)
8406        qconfig_mapping = get_default_qconfig_mapping()
8407        m2 = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs)
8408        self.checkGraphModuleNodes(m2, expected_node_occurrence={
8409            ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2
8410        })
8411        m2 = convert_fx(m2)
8412        self.checkGraphModuleNodes(m2, expected_node_list=[
8413            ns.call_function(torch.quantize_per_tensor),
8414            ns.call_method("dequantize")
8415        ])
8416        m2(*example_inputs)
8417
8418        # test that prepare recognizes a non-Tensor input for getitem
8419        class M3(torch.nn.Module):
8420            def forward(self, x):
8421                s = x.shape
8422                n, c = s[:2]
8423                x = torch.sigmoid(x)
8424                return x
8425
8426        m3 = M3().eval()
8427        example_inputs = (torch.rand(1, 2, 3, 4),)
8428        qconfig_mapping = get_default_qconfig_mapping()
8429        m3 = prepare_fx(m3, qconfig_mapping, example_inputs=example_inputs)
8430        self.checkGraphModuleNodes(m3, expected_node_occurrence={
8431            ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2
8432        })
8433        m3 = convert_fx(m3)
8434        self.checkGraphModuleNodes(m3, expected_node_list=[
8435            ns.call_function(torch.quantize_per_tensor),
8436            ns.call_method("dequantize")
8437        ])
8438        m3(*example_inputs)
8439
8440
8441    @skipIfNoFBGEMM
8442    def test_fixed_qparams_ops(self):
8443        class M(torch.nn.Module):
8444            def __init__(self) -> None:
8445                super().__init__()
8446                self.conv = torch.nn.Conv2d(3, 3, 3)
8447                self.sigmoid = torch.nn.Sigmoid()
8448                self.hardsigmoid = torch.nn.Hardsigmoid()
8449                self.tanh = torch.nn.Tanh()
8450                self.softmax = torch.nn.Softmax(dim=0)
8451
8452            def forward(self, x):
8453                x = self.conv(x)
8454                # F.sigmoid is deprecated
8455                x = self.sigmoid(x)
8456                x = torch.sigmoid(x)
8457                x = x.sigmoid()
8458                x = self.hardsigmoid(x)
8459                x = F.hardsigmoid(x)
8460                x = F.hardsigmoid(x, inplace=True)
8461                x = self.tanh(x)
8462                # F.tanh is deprecated
8463                x = torch.tanh(x)
8464                x = x.tanh()
8465                # TODO(future PR): handle F.softmax
8466                x = self.softmax(x)
8467                return x
8468
8469        for eval_mode in [True, False]:
8470            # This model is not executable since we just put all ops
8471            # in the same forward
8472            m = M()
8473            if eval_mode:
8474                m.eval()
8475                qconfig_mapping = get_default_qconfig_mapping()
8476                prepare = prepare_fx
8477                fq_count = 10
8478            else:
8479                m.train()
8480                qconfig_mapping = get_default_qat_qconfig_mapping()
8481                prepare = prepare_qat_fx
8482                fq_count = 10
8483            # nothing to fuse so skipping the fuse step
8484            m_copy = copy.deepcopy(m)
8485            example_inputs = (torch.rand(3, 3, 3, 3),)
8486            prepared = prepare(m, qconfig_mapping, example_inputs=example_inputs)
8487            prepared_copy = copy.deepcopy(prepared)
8488            # check that prepare does not change model result
8489            if eval_mode:
8490                self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs))
8491            # check the correct number of activation_post_process is inserted
8492            expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize
8493            count_check = {
8494                ns.call_module(expected_activation_post_process) : fq_count,
8495            }
8496            self.checkGraphModuleNodes(
8497                prepared,
8498                expected_node_occurrence=count_check)
8499            # not runnable
8500            quantized = convert_fx(prepared)
8501            quantized_reference = convert_to_reference_fx(prepared_copy)
8502
8503            # This checks that the dequantize from the output of the first conv
8504            # is propagated to the end, so that we don't insert extra
8505            # observers
8506            # check exact counts of quantize and dequantize
8507            count_check = {
8508                ns.call_function(torch.quantize_per_tensor) : 1,
8509                ns.call_method('dequantize') : 1
8510            }
8511            order_check = [
8512                ns.call_function(torch.quantize_per_tensor),
8513                ns.call_module(nnq.Conv2d),
8514                ns.call_module(nn.Sigmoid),
8515                ns.call_module(nnq.Softmax),
8516                ns.call_method('dequantize'),
8517            ]
8518            self.checkGraphModuleNodes(
8519                quantized,
8520                expected_node_occurrence=count_check,
8521                expected_node_list=order_check)
8522
8523            reference_count_check = {
8524                ns.call_function(torch.quantize_per_tensor) : 12,
8525                ns.call_method('dequantize') : 12
8526            }
8527            reference_order_check = [
8528                ns.call_function(torch.quantize_per_tensor),
8529                ns.call_method('dequantize'),
8530                ns.call_module(nnqr.Conv2d),
8531                ns.call_function(torch.quantize_per_tensor),
8532                ns.call_method('dequantize'),
8533                ns.call_module(nn.Sigmoid),
8534                ns.call_function(torch.quantize_per_tensor),
8535                ns.call_method('dequantize'),
8536                ns.call_module(nn.Softmax),
8537                ns.call_function(torch.quantize_per_tensor),
8538                ns.call_method('dequantize'),
8539            ]
8540            self.checkGraphModuleNodes(
8541                quantized_reference,
8542                expected_node_occurrence=reference_count_check,
8543                expected_node_list=reference_order_check)
8544
8545            # Verify that softmax scale and zero_point are correct
8546            self.assertTrue(abs(quantized.softmax.scale - (1.0 / 256)) <= 1e-8)
8547            self.assertTrue(quantized.softmax.zero_point == 0)
8548
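    # Sketch of where the asserted softmax qparams come from: fixed-qparams
    # ops hard-code their output quantization parameters. Softmax outputs lie
    # in [0, 1], so with quint8 the scale is (1 - 0) / 256 and zero_point is 0.
    def _sketch_softmax_fixed_qparams(self):
        scale = (1.0 - 0.0) / 256  # output range divided by number of quint8 bins
        zero_point = 0             # 0.0 maps exactly to quantized value 0
        return scale, zero_point
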
8549    def test_float_functional(self):
8550        class TorchAdd(nn.Module):
8551            """Wrapper around torch.add so that all ops can be found at module build time"""
8552            def __init__(self) -> None:
8553                super().__init__()
8554                self.add_func = nnq.FloatFunctional()
8555
8556            def forward(self, x, y):
8557                return self.add_func.add(x, y)
8558
8559        class M(torch.nn.Module):
8560            def __init__(self) -> None:
8561                super().__init__()
8562                self.ff1 = TorchAdd()
8563                self.ff2 = nnq.FloatFunctional()
8564                self.ff3 = nnq.FloatFunctional()
8565                self.ff4 = nnq.FloatFunctional()
8566                self.ff5 = nnq.FloatFunctional()
8567                self.ff6 = nnq.FloatFunctional()
8568
8569            def forward(self, x):
8570                x = self.ff1(x, x)
8571                x = self.ff2.add_scalar(x, 3)
8572                x = self.ff3.mul(x, x)
8573                x = self.ff4.mul_scalar(x, 3)
8574                x = self.ff5.add_relu(x, x)
8575                x = self.ff6.cat([x])
8576                return x
8577
8578        example_inputs = (torch.rand(3, 3),)
8579        # Note: the QAT test succeeded by chance; to make it actually work
8580        # we need to fix eager mode FloatFunctional by removing
8581        # activation_post_process in add_scalar and mul_scalar
8582        for quant_type in self.static_quant_types:
8583            m = M()
8584            ref_m = torch.ao.quantization.QuantWrapper(M())
8585            is_qat = quant_type == QuantType.QAT
8586            if is_qat:
8587                m.train()
8588                ref_m.train()
8589                qconfig = default_qat_qconfig
8590                expected_act_post_process = torch.ao.quantization.FakeQuantize
8591            else:
8592                m.eval()
8593                ref_m.eval()
8594                qconfig = default_qconfig
8595                expected_act_post_process = torch.ao.quantization.MinMaxObserver
8596
8597            prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx
8598            qconfig_dict = {"": qconfig}
8599            m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs)
8600            node_occurrence = {
8601                ns.call_module(expected_act_post_process): 7,
8602                ns.call_module(torch.ao.nn.quantized.FloatFunctional): 0
8603            }
8604            self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
8605            m(*example_inputs)
8606            node_list = [
8607                ns.call_function(torch.quantize_per_tensor),
8608                ns.call_function(torch.ops.quantized.add),
8609                ns.call_function(torch.ops.quantized.add),
8610                ns.call_function(torch.ops.quantized.mul),
8611                ns.call_function(torch.ops.quantized.mul),
8612                ns.call_function(torch.ops.quantized.add_relu),
8613                ns.call_function(torch.cat),
8614                ns.call_method('dequantize')
8615            ]
8616            m = convert_fx(m)
8617            self.checkGraphModuleNodes(m, expected_node_list=node_list)
8618
8619            # make sure numerics match with eager mode
8620            ref_m.qconfig = qconfig
8621            prepare_function = prepare_qat if is_qat else prepare
8622            ref_m = prepare_function(ref_m)
8623            ref_m(*example_inputs)
8624            ref_m = convert(ref_m)
8625            # FX Graph Mode and Eager Mode now diverge in the numerics of add_scalar and mul_scalar
8626            # self.assertEqual(m(data), ref_m(data))
8627
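    # Sketch (hypothetical module): in FX graph mode, FloatFunctional methods
    # are traced to the underlying torch ops, so no FloatFunctional modules
    # remain after prepare, matching the node_occurrence check above.
    def _sketch_float_functional_tracing(self):
        class Add(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ff = nnq.FloatFunctional()

            def forward(self, x):
                return self.ff.add(x, x)

        example_inputs = (torch.rand(2, 2),)
        m = prepare_fx(Add().eval(), {"": default_qconfig},
                       example_inputs=example_inputs)
        m(*example_inputs)  # calibrate
        m = convert_fx(m)
        return m(*example_inputs)  # runs as torch.ops.quantized.add
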
8628    def test_embedding(self):
8629        class M(torch.nn.Module):
8630            def __init__(self) -> None:
8631                super().__init__()
8632                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
8633
8634            def forward(self, indices):
8635                return self.emb(indices)
8636
8637        for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]:
8638            model = M().eval()
8639            indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
8640            example_inputs = (indices,)
8641            quantized_node = ns.call_module(nnq.Embedding)
8642
8643            # check dynamic quant
8644            self.checkGraphModeFxOp(
8645                model,
8646                example_inputs,
8647                QuantType.DYNAMIC,
8648                quantized_node,
8649                custom_qconfig_dict={"": qconfig_type}
8650            )
8651            model = M().eval()
8652
8653            configs = [
8654                (qconfig_type, ns.call_module(nnq.Embedding)),
8655                (None, ns.call_module(nn.Embedding)),
8656                (default_qconfig, ns.call_module(nn.Embedding)),
8657            ]
8658
8659            # check static quantization
8660            for qconfig, node in configs:
8661                qconfig_dict = {"": qconfig}
8662                m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
8663                self.checkGraphModuleNodes(m, expected_node_occurrence={
8664                    ns.call_module(torch.ao.quantization.MinMaxObserver): 0
8665                })
8666                m = convert_fx(m)
8667                self.checkGraphModuleNodes(m, expected_node=node)
8668                # make sure it runs
8669                m(*example_inputs)
8670
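    # Sketch (hypothetical sizes): the minimal qconfig needed to quantize an
    # nn.Embedding in FX graph mode, mirroring the dynamic-quant check above.
    def _sketch_embedding_quant(self):
        emb = torch.nn.Sequential(torch.nn.Embedding(10, 12)).eval()
        indices = torch.randint(0, 10, (8,))
        m = prepare_fx(emb, {"": float_qparams_weight_only_qconfig},
                       example_inputs=(indices,))
        m = convert_fx(m)  # swaps nn.Embedding for nnq.Embedding
        return m(indices)
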
8671    def test_embedding_bag(self):
8672        class M(torch.nn.Module):
8673            def __init__(self) -> None:
8674                super().__init__()
8675                self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True)
8676
8677            def forward(self, indices, offsets):
8678                return self.emb(indices, offsets)
8679
8680        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
8681        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
8682        quantized_node = ns.call_module(nnq.EmbeddingBag)
8683        example_inputs = (indices, offsets)
8684
8685        for dtype in [torch.quint8, torch.quint4x2]:
8686            model = M().eval()
8687            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
8688                                                                        qscheme=torch.per_channel_affine_float_qparams,
8689                                                                        ch_axis=0)
8690            float_qparams_qconfig = QConfig(activation=default_placeholder_observer,
8691                                            weight=float_qparams_observer)
8692            self.checkGraphModeFxOp(
8693                model,
8694                example_inputs,
8695                QuantType.DYNAMIC,
8696                quantized_node,
8697                custom_qconfig_dict={"": float_qparams_qconfig}
8698            )
8699
8700        # check that it works with both a None and a static qconfig
8701        for qconfig in [None, default_qconfig]:
8702            qconfig_dict = {"": qconfig}
8703            m = M().eval()
8704            m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
8705            self.checkGraphModuleNodes(m, expected_node_occurrence={
8706                ns.call_module(torch.ao.quantization.MinMaxObserver): 0
8707            })
8708            m = convert_fx(m)
8709            self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
8710            # make sure it runs
8711            m(*example_inputs)
8712
8713    def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
8714        options = itertools.product(qconfigs, module_type_strs)
8715        for qconfig, module_type_str in options:
8716            model_eager = M(module_type_str).eval()
8717            model_graph = copy.deepcopy(model_eager)
8718            if torch.backends.quantized.engine == 'qnnpack' and \
8719               qconfig is float16_dynamic_qconfig:
8720                # fp16 dynamic quant is not supported for qnnpack
8721                continue
8722
8723            eager_qconfig_dict = dict.fromkeys(module_types, qconfig)
8724            model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)
8725
8726            graph_qconfig_dict = {
8727                "object_type": [
8728                    (x, qconfig) for x in module_types
8729                ]
8730            }
8731            model_graph = prepare_fx(model_graph, graph_qconfig_dict, example_inputs=(sample_input,))
8732            model_graph = convert_fx(model_graph)
8733            self.assertEqual(model_eager(sample_input), model_graph(sample_input))
8734            self.checkScriptable(model_graph, [[sample_input]], True)
8735
8736    @override_qengines
8737    def test_rnn_cell(self):
8738        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
8739            return
8740        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
8741        module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']
8742        module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell]
8743        sample_input = torch.tensor([[100, -155],
8744                                     [-155, 100],
8745                                     [100, -155]], dtype=torch.float)
8746        self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input)
8747
8748    @override_qengines
8749    def test_rnn(self):
8750        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
8751            return
8752        qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
8753        module_type_strs = ['LSTM', 'GRU']
8754        module_types = [torch.nn.LSTM, torch.nn.GRU]
8755        niter = 10
8756        sample_input = torch.tensor([[100, -155],
8757                                     [-155, 100],
8758                                     [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
8759        self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
8760
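    # Sketch (assuming an engine that supports dynamic quant, as checked in
    # test_rnn): dynamic quantization of an LSTM through FX graph mode, the
    # same flow that _test_rnn_impl drives with an object_type qconfig.
    def _sketch_dynamic_lstm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lstm = torch.nn.LSTM(2, 2)

            def forward(self, x):
                out, _ = self.lstm(x)
                return out

        x = torch.randn(5, 1, 2)
        qconfig_dict = {"object_type": [(torch.nn.LSTM, default_dynamic_qconfig)]}
        m = prepare_fx(M().eval(), qconfig_dict, example_inputs=(x,))
        m = convert_fx(m)  # swaps nn.LSTM for its dynamically quantized version
        return m(x)
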
8761    def _test_conv_transpose_impl(
8762            self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
8763        with override_quantized_engine('qnnpack'):
8764            # Create fp32 versions of FX and Eager models
8765            m1 = torch.nn.Sequential(float_cls(1, 1, 1))
8766            m2 = torch.nn.Sequential(float_cls(1, 1, 1))
8767            m2.load_state_dict(m1.state_dict())
8768            m2 = torch.ao.quantization.QuantWrapper(m2)
8769            # FX graph
8770            result_dict = self.checkGraphModeFxOp(
8771                m1, (data,), QuantType.STATIC,
8772                expected_node_occurrence={
8773                    ns.call_module(q_cls): 1,
8774                })
8775            q_result1 = result_dict["quantized_output"]
8776            # Eager
8777            m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
8778            m2.eval()
8779            m2p = torch.ao.quantization.prepare(m2)
8780            m2p(data)
8781            m2q = torch.ao.quantization.convert(m2p)
8782            q_result2 = m2q(data)
8783            # verify results match
8784            self.assertEqual(q_result1, q_result2)
8785
8786    @unittest.skipUnless('qnnpack' in supported_qengines,
8787                         "This PyTorch build has not been built with or does not support QNNPACK")
8788    def test_conv_transpose_1d(self):
8789        self._test_conv_transpose_impl(
8790            torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))
8791
8792    @unittest.skipUnless('qnnpack' in supported_qengines,
8793                         "This PyTorch build has not been built with or does not support QNNPACK")
8794    def test_conv_transpose_2d(self):
8795        self._test_conv_transpose_impl(
8796            torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
8797
8798    def test_reshape_fp16(self):
8799        class M(torch.nn.Module):
8800            def __init__(self, w, b):
8801                super().__init__()
8802                self.w = w
8803                self.b = b
8804
8805            def forward(self, x):
8806                x = torch.nn.functional.linear(x, self.w)
8807                x = x.reshape(-1, 4)
8808                x = torch.nn.functional.linear(x, self.w)
8809                return x
8810
8811        w = torch.randn(4, 4)
8812        b = torch.randn(4)
8813        m = M(w, b).eval()
8814        qconfig_dict = {
8815            # reshape will be quantized to fp16 as requested by this qconfig
8816            "": float16_static_qconfig,
8817            "object_type": [
8818                (torch.nn.functional.linear, default_qconfig)
8819            ]
8820        }
8821        backend_config = get_test_only_legacy_native_backend_config()
8822        example_inputs = (torch.randn(1, 4),)
8823        m = prepare_fx(
8824            m, qconfig_dict, example_inputs=example_inputs,
8825            backend_config=backend_config)
8826        expected_occurrence = {
8827            # input and weight of first and second linear, output of first and second linear
8828            ns.call_module(torch.ao.quantization.MinMaxObserver): 6,
8829            # we insert placeholder observer for both input and output of reshape
8830            ns.call_module(torch.ao.quantization.PlaceholderObserver): 2
8831        }
8832        self.checkGraphModuleNodes(
8833            m,
8834            expected_node_occurrence=expected_occurrence
8835        )
8836        m = convert_fx(m, backend_config=backend_config)
8837        expected_occurrence = {
8838            ns.call_function(torch.quantize_per_tensor): 2,
8839            # dequantize after first linear, before reshape and before output
8840            ns.call_method("dequantize"): 3,
8841            # before reshape, to(fp16)
8842            ns.call_method("to"): 1,
8843            ns.call_function(torch.ops.quantized.linear): 2
8844        }
8845        self.checkGraphModuleNodes(
8846            m,
8847            expected_node_occurrence=expected_occurrence
8848        )
8849        # make sure it runs
8850        m(torch.randn(2, 4))

    def test_multiple_qconfigs_for_single_value(self):
        """Test multiple qconfigs for a single value."""
        class M(torch.nn.Module):
            def __init__(self, w, b):
                super().__init__()
                self.w = w
                self.b = b

            def forward(self, x):
                x = torch.nn.functional.linear(x, self.w)
                x = torch.sigmoid(x)
                return x

        w = torch.randn(4, 4)
        b = torch.randn(4)
        m = M(w, b).eval()
        # TODO: use get_default_qconfig_mapping once it handles fp16
        qconfig_mapping = QConfigMapping() \
            .set_global(float16_static_qconfig) \
            .set_object_type(torch.nn.functional.linear, default_qconfig)
        example_inputs = (torch.randn(1, 4),)
        backend_config = get_test_only_legacy_native_backend_config()
        m = prepare_fx(
            m, qconfig_mapping, example_inputs=example_inputs,
            backend_config=backend_config)
        expected_occurrence = {
            # input and weight of linear, output of linear
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
            # input and output of sigmoid
            ns.call_module(torch.ao.quantization.PlaceholderObserver): 2,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )
        # make sure conversion runs, then check the quantized graph
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method("dequantize"): 3,
            ns.call_method("to"): 2
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence
        )
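        # Note: in a QConfigMapping the object_type setting takes precedence
        # over the global one for matching nodes, so the functional linear is
        # statically quantized to int8 while sigmoid falls back to the global
        # float16_static_qconfig; the same value (the linear output) therefore
        # gets both an int8 dequantize and an fp16 "to" cast around it.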

    def test_boolean_tensor(self):
        """Make sure we don't insert observers for boolean Tensors."""
        class M(torch.nn.Module):
            def forward(self, x, mask):
                mask = mask.unsqueeze(0)
                mask = mask.unsqueeze(1)
                x = x.masked_fill(mask, 1)
                return x

        m = M().eval()
        example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool())
        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
        expected_occurrence = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence)
        m = convert_fx(m)
        m(*example_inputs)

    def test_chunk(self):
        class M(torch.nn.Module):
            def forward(self, x):
                x, y = torch.chunk(x, 2)
                x = x + y
                return x
        m = M().eval()
        example_inputs = (torch.rand(2, 2, 2, 2),)
        m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs)
        # make sure everything runs
        m(*example_inputs)
        m = convert_fx(m)
        m(*example_inputs)

    def test_ref_pattern_multi_use(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.linear1 = torch.nn.Linear(5, 5)

            def forward(self, x):
                y = self.linear(x)
                z = self.linear1(x)
                a = torch.mul(z, 5)
                b = torch.add(z, 5)
                return (y, a, b)

        m = M().eval()
        qconfig_dict = {
            "": None,
            "object_type": [
                (torch.nn.Linear, get_default_qconfig("fbgemm")),
                (torch.nn.ReLU, get_default_qconfig("fbgemm")),
            ],
        }
        example_inputs = (torch.randn(1, 5),)
        m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nnq.Linear): 2,
            ns.call_method("dequantize"): 2,
            ns.call_function(torch.add): 1,
            ns.call_function(torch.mul): 1,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence)
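        # torch.add and torch.mul stay in fp32 here because the global qconfig
        # is None; each quantized linear output is dequantized exactly once,
        # and the dequantized output of linear1 is shared by both the add and
        # the mul, which is why only two dequantize nodes are expected.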

    def test_qmatmul(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                z = torch.matmul(x, y)
                return z

        m = M().eval()
        example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
        qconfig_dict = get_default_qconfig_mapping("fbgemm")
        mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
        mp(*example_inputs)
        mq = convert_fx(mp)
        expected_occurrence = {
            ns.call_function(torch.matmul): 0,
            ns.call_function(torch.ops.quantized.matmul): 1,
        }
        self.checkGraphModuleNodes(
            mq,
            expected_node_occurrence=expected_occurrence)
        # verify no crash
        mq(*example_inputs)
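        # the output scale/zero_point of quantized::matmul come from the
        # observer statistics collected during the mp(*example_inputs)
        # calibration call above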

    def test_pixel_shuffle(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = nn.functional.pixel_shuffle(x, 2)
                x = x.view(-1, 8, 2, 2)
                bias = self.bias.bias
                return x + bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        model = MyModel()
        m = prepare_fx(
            model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(torch.randn(1, 8, 3, 3),),
            backend_config=backend_config
        )
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)
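        # pixel_shuffle and view only rearrange elements, so the quantized
        # tensor can flow through them without a dequantize/quantize round
        # trip; the two quantize_per_tensor nodes cover the model input and
        # the bias parameter feeding the quantized add, and the single
        # dequantize sits at the output.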

    def test_pixel_shuffle_module(self) -> None:
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(8))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.ps = nn.PixelShuffle(upscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = self.ps(x)
                x = x.view(-1, 8, 2, 2)
                bias = self.bias.bias
                return x + bias

        backend_config = get_qnnpack_backend_config()
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        model = MyModel()
        m = prepare_fx(
            model,
            qconfig_mapping=qconfig_mapping,
            example_inputs=(torch.randn(1, 8, 3, 3),),
            backend_config=backend_config
        )
        m = convert_fx(m)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_method("dequantize"): 1,
            ns.call_module(nn.PixelShuffle): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_pixel_unshuffle(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = nn.functional.pixel_unshuffle(x, 2)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 6, 6),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_pixel_unshuffle_module(self) -> None:
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(64))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.unshuffle = nn.PixelUnshuffle(downscale_factor=2)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = self.unshuffle(x)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 6, 6),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
                ns.call_module(nn.PixelUnshuffle): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

    def test_narrow(self):
        class MyBias(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = nn.Parameter(torch.randn(4))

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)
                self.bias = MyBias()

            def forward(self, x):
                x = self.conv(x)
                x = torch.narrow(x, 1, 0, 4)
                bias = self.bias.bias
                return x + bias

        for backend in ["fbgemm", "qnnpack"]:
            if backend == "fbgemm":
                backend_config = get_fbgemm_backend_config()
            else:
                backend_config = get_qnnpack_backend_config()
            qconfig_mapping = get_default_qconfig_mapping(backend)
            model = MyModel()
            m = prepare_fx(
                model,
                qconfig_mapping=qconfig_mapping,
                example_inputs=(torch.randn(1, 8, 3, 3),),
                backend_config=backend_config
            )
            m = convert_fx(m)
            expected_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 2,
                ns.call_method("dequantize"): 1,
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence)

class TestQuantizeFxModels(QuantizationTestCase):
    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_static_gpu_convert_basic(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.relu1(self.conv1(x))
                y = self.linear1(x.view(-1))
                return y

        input = torch.randn((5, 1, 6, 6)).to('cuda')
        example_inputs = (input,)
        model = Net().to('cuda').eval()
        qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
        model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
        model_prepared(*example_inputs)
        model_quantized = convert_to_reference_fx(model_prepared)
        out = model_quantized(*example_inputs)
        self.assertEqual(out.device.type, 'cuda')
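        # convert_to_reference_fx produces reference quantized modules made of
        # explicit quantize/dequantize ops around fp32 computation, so unlike
        # backend-lowered modules the converted model can still run on CUDA.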

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_switch_device_prepare_convert(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu1 = nn.ReLU()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.relu1(self.conv1(x))
                y = self.linear1(x.view(-1))
                return y

        for device in ['cuda', 'cpu']:
            device_after = 'cuda' if device == 'cpu' else 'cpu'
            input = torch.randn((5, 1, 6, 6)).to(device)
            model = Net().to(device).eval()
            qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
            model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,))
            model_prepared(input)
            model_prepared.to(device_after)
            model_quantized = convert_to_reference_fx(model_prepared)
            out = model_quantized(input.to(device_after))
            self.assertEqual(out.device.type, device_after)

    @skipIfNoFBGEMM
    @unittest.skipIf(not TEST_CUDA, "gpu is not available.")
    def test_prepare_serialize_switch_device_convert(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.linear1 = nn.Linear(120, 1)

            def forward(self, x):
                x = self.conv1(x)
                y = self.linear1(x.view(-1))
                return y

        for device in ['cuda', 'cpu']:
            for device_after in ['cuda', 'cpu']:
                input = torch.randn((5, 1, 6, 6)).to(device)
                model = Net().to(device).eval()
                qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')}
                model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,))
                model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,))
                model_prepared_first(input)
                state_dict = model_prepared_first.state_dict()
                del model_prepared_first
                model_prepared_second.load_state_dict(state_dict)
                model_prepared_second.to(device_after)
                model_quantized = convert_to_reference_fx(model_prepared_second)
                out = model_quantized(input.to(device_after))
                self.assertEqual(out.device.type, device_after)
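                # observer statistics are part of the state_dict, so a
                # calibrated prepared model can be serialized, loaded into a
                # fresh prepared copy, moved to another device and converted.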

    @skipIfTorchDynamo("too slow")
    @skip_if_no_torchvision
    def test_model_dropout(self):
        from torchvision import models
        m = models.mobilenet_v3_small()
        qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm')
        example_inputs = (torch.randn(1, 3, 224, 224),)
        mp = prepare_qat_fx(m, qconfig_mapping, example_inputs=example_inputs)
        mp(*example_inputs)
        with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext():
            mq = convert_fx(mp)
        mq(*example_inputs)
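        # the fbgemm-based default engine isn't usable on ARM64, hence the
        # override to qnnpack for the convert step above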

    def _test_model_impl(
            self, mode, name, model, eager_quantizable_model,
            check_with_eager=True,
            diff_of_quant=None,
            diff_from_eager=None):
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1,))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        script = torch.jit.script(model)

        # make sure graph module and script module are both runnable
        original_out = model(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)

        # set to train just before quantization
        prepare_fx_fn = prepare_fx
        if mode != 'static':
            model.train()
            prepare_fx_fn = prepare_qat_fx

        prepared = prepare_fx_fn(model, qconfig_dict, example_inputs=(input_value,))

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),  # noqa: F821
                     nprocs=world_size,  # noqa: F821
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
        else:
            for _ in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = convert_fx(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script_out), 'graph and scripted graph outputs should match'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),  # noqa: F821
                         nprocs=world_size,  # noqa: F821
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
            else:
                for _ in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        f'Result of graph mode quantization and eager mode '
                        f'quantization on model: {name} should match. '
                        f'Mode: {mode} diff: {diff_from_eager[mode][name]}')

    def _test_building_block(self, quant_type, BB):
        eager = BB().float()
        graph = copy.deepcopy(eager)

        if quant_type == QuantType.STATIC:
            qconfig = default_qconfig
            eager_prepare = prepare
            graph_prepare = prepare_fx
            eager.eval()
            graph.eval()
            calibrate_or_train = test_only_eval_fn
            data = self.img_data_2d
        else:
            assert quant_type == QuantType.QAT
            qconfig = default_qat_qconfig
            eager_prepare = prepare_qat
            graph_prepare = prepare_qat_fx
            eager.train()
            graph.train()
            calibrate_or_train = test_only_train_fn
            data = self.img_data_2d_train

        if hasattr(eager, "fuse_model"):
            eager.fuse_model()
        eager = QuantWrapper(eager)
        eager.qconfig = qconfig
        eager = eager_prepare(eager)

        qconfig_dict = {"": qconfig}
        graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],))

        eager_out = eager(data[0][0])
        graph_out = graph(data[0][0])
        # Eager mode and FX graph mode quantization now differ in numerics,
        # both for post-training and for QAT, because FX graph mode uses the
        # same fake_quant instance for the input and output of a CopyNode.
        # self.assertEqual(eager_out, graph_out)

        calibrate_or_train(eager, data)
        calibrate_or_train(graph, data)

        eager = convert(eager)
        graph = convert_fx(graph)

        # smoke test: just make sure the converted models run
        eager_out = eager(data[0][0])
        graph_out = graph(data[0][0])

    @override_qengines
    def test_resnet_base(self):
        models = [ResNetBase]
        options = itertools.product(self.static_quant_types, models)
        for quant_type, M in options:
            self._test_building_block(quant_type, M)

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("skip for now since tbb failed")
    def test_torchvision(self):
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        from torchvision.models.quantization.utils import _replace_relu

        def get_available_classification_models(models):
            return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

        model_list = get_available_classification_models(models)
        quantized_model_list = get_available_classification_models(quantized_models)

        quantized_model_list = set(quantized_model_list)
        # test eager and graph consistency
        model_list = quantized_model_list
        # mobilenet/inception_v3/googlenet qat is not working due to AdaptiveAveragePool qat;
        # we might observe the output of AdaptiveAveragePool in the future
        # and re-enable the test
        # (relu6 is also replaced with relu in mobilenet_v2, which changes numerics)
        fx_eager_not_matching = [
            ("mobilenet_v2", "qat"),
            ("inception_v3", "qat"),
            ("googlenet", "qat"),
        ]

        diff_of_quant = {}
        diff_from_eager = {}
        modes = ['static', 'qat']
        options = itertools.product(modes, model_list)
        for mode, name in options:
            kwargs = {}
            # turn off transform_input for inception_v3 since it's not
            # quantized in eager mode, and in fx graph mode we can't skip
            # quantizing a method right now (might be supported in the future)
            if name in ["inception_v3", "googlenet"]:
                kwargs["transform_input"] = False
            eager_quantizable_model = None
            if name in quantized_model_list:
                eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float()
            # load a pretrained model to compare with the eager mode quantized
            # model when one is available
            pretrained = eager_quantizable_model is not None
            model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float()
            if name == "mobilenet_v2":
                _replace_relu(model)
            # disable aux logits
            if hasattr(model, "aux_logits"):
                model.aux_logits = False
                model.AuxLogits = None
                if eager_quantizable_model:
                    eager_quantizable_model.aux_logits = False
                    eager_quantizable_model.AuxLogits = None

            check_with_eager = (name, mode) not in fx_eager_not_matching
            self._test_model_impl(
                mode, name, model, eager_quantizable_model,
                check_with_eager,
                diff_of_quant, diff_from_eager)

        def print_diffs(diffs):
            for mode, diffs_for_mode in diffs.items():
                print('mode:', mode)
                for name, diff in diffs_for_mode.items():
                    print(name, ':', diff)

        # print('differences between float and quantized')
        # print_diffs(diff_of_quant)
        # print('----------------------')
        # print('differences between graph mode and eager mode')
        # print_diffs(diff_from_eager)
        # print('----------------------')

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979")
    def test_resnet18_ddp(self):
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        eager_quantizable_model = quantized_models.__dict__['resnet18'](pretrained=False, quantize=False).eval().float()
        model = models.__dict__['resnet18'](pretrained=False).eval().float()
        self._test_model_impl(
            'ddp', 'resnet18', model, eager_quantizable_model)

    @override_qengines
    def test_qat_embeddingbag_linear(self):
        for device in get_supported_device_types():
            class EmbeddingBagLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
                            per_sample_weights: Optional[torch.Tensor] = None):
                    x = self.emb(input, offsets, per_sample_weights)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            qconfig_mapping = QConfigMapping() \
                .set_global(get_default_qat_qconfig(qengine)) \
                .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig)

            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingBagLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_mapping, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_mapping)

            def checkQuantized(model):
                # Make sure EmbeddingBag is now a quantized EmbeddingBag.
                self.assertEqual(type(model.emb), nnq.EmbeddingBag)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)
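            # Note: default_embedding_qat_qconfig fake-quantizes only the
            # embedding weights (activations are not observed), so the
            # EmbeddingBag lowers to a weight-only quantized module while the
            # Linear gets regular QAT from the engine's default qconfig.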

    @override_qengines
    def test_qat_embedding_linear(self):
        for device in get_supported_device_types():
            class EmbeddingLinear(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
                    self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float)

                def forward(self, input: torch.Tensor):
                    x = torch.sum(self.emb(input), dim=1)
                    x = self.linear(x)
                    return x

            qengine = torch.backends.quantized.engine
            qconfig_dict = {"": get_default_qat_qconfig(qengine),
                            "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]}

            train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)]
            eval_output = [[torch.randint(0, 10, (12, 1))]]

            model = EmbeddingLinear().train()
            prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],))
            test_only_train_fn(prepared_fx_model, train_indices)
            quant_model = convert_fx(prepared_fx_model,
                                     qconfig_mapping=qconfig_dict)

            def checkQuantized(model):
                # Make sure Embedding is now a quantized Embedding.
                self.assertEqual(type(model.emb), nnq.Embedding)
                # Also test that Linear has been quantized.
                self.assertEqual(type(model.linear), nnq.Linear)

                test_only_eval_fn(model, eval_output)
                self.checkScriptable(model, eval_output)
                self.checkNoQconfig(model)
            checkQuantized(quant_model)

    @given(
        device=st.sampled_from(
            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        )
    )
    @settings(deadline=None)
    @override_qengines
    def test_qat_functional_linear(self, device):
        if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'):
            return

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods1 = torch.nn.Sequential(Linear(), Linear())
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        model = M().train()
        ref_fake_quant = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            reduce_range=False,
        )
        ref_weight_fake_quant = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            reduce_range=False,
        )
        ref_qat_qconfig = QConfig(
            activation=ref_fake_quant, weight=ref_weight_fake_quant
        )
        qconfig_dict = {"": ref_qat_qconfig}
        example_inputs = (torch.randn(1, 5),)
        prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs)

        custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8,
            reduce_range=False,
        )
        custom_weight_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            reduce_range=False,
        )
        custom_qconfig = QConfig(
            activation=custom_fake_quant, weight=custom_weight_fake_quant
        )
        custom_qconfig_dict = {"": custom_qconfig}
        prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs)

        prepared.to(device)
        prepared_ref.to(device)

        prepared.apply(torch.ao.quantization.disable_fake_quant)
        prepared.apply(torch.ao.quantization.disable_observer)
        prepared_ref.apply(torch.ao.quantization.disable_fake_quant)
        prepared_ref.apply(torch.ao.quantization.disable_observer)

        for i in range(10):
            if i == 2:
                prepared.apply(torch.ao.quantization.enable_observer)
                prepared_ref.apply(torch.ao.quantization.enable_observer)
            if i == 4:
                prepared.apply(torch.ao.quantization.enable_fake_quant)
                prepared_ref.apply(torch.ao.quantization.enable_fake_quant)

            inp = torch.randn(5, 5, device=device, requires_grad=True)
            out_ref = prepared_ref(inp)
            out = prepared(inp)
            torch.testing.assert_close(out, out_ref)

            # try backward pass
            labels = torch.randn(5, 5, device=device)
            loss = (out - labels).sum()
            grad = torch.autograd.grad(loss, [inp])
            loss_ref = (out_ref - labels).sum()
            grad_ref = torch.autograd.grad(loss_ref, [inp])
            torch.testing.assert_close(grad[0], grad_ref[0])

        if 'fbgemm' in torch.backends.quantized.supported_engines:
            # During the lowering step in convert, fold_weight calls
            # quantized::linear_prepack, which doesn't support the
            # QuantizedCUDA backend.
            prepared.cpu()
            prepared_ref.cpu()
            converted = convert_fx(prepared)
            converted_ref = convert_fx(prepared_ref)
            inp = torch.rand(5, 5)
            out = converted(inp)
            out_ref = converted_ref(inp)

            torch.testing.assert_close(out, out_ref)
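            # FusedMovingAvgObsFakeQuantize is a fused, single-kernel
            # implementation of the same observe + fake-quantize math that
            # FakeQuantize performs with MovingAverageMinMaxObserver, so the
            # two lowered models are expected to agree here as well.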


if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")