1# mypy: ignore-errors
2
3r"""Importing this file includes common utility methods and base clases for
4checking quantization api and properties of resulting modules.
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
11import torch.ao.nn.quantized as nnq
12import torch.ao.nn.quantized.dynamic as nnqd
13from torch.ao.nn.intrinsic import _FusedModule
14import torch.distributed as dist
15from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
16
17from torch._export import capture_pre_autograd_graph
18from torch.ao.quantization import (
19    QuantType,
20    default_dynamic_qat_qconfig,
21    default_embedding_qat_qconfig,
22    default_symmetric_qnnpack_qat_qconfig,
23)
24from torch.ao.quantization.quantize_pt2e import (
25    _convert_to_reference_decomposed_fx,
26    convert_pt2e,
27    prepare_pt2e,
28    prepare_qat_pt2e,
29)
30from torch.ao.quantization.backend_config import (
31    get_executorch_backend_config,
32)
33from torch.ao.quantization.quantizer.xnnpack_quantizer import (
34    XNNPACKQuantizer,
35    get_symmetric_quantization_config,
36)
37from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \
38    default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \
39    propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \
40    get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \
41    QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping
42from torch.ao.quantization.quantization_mappings import (
43    get_default_dynamic_quant_module_mappings,
44    get_default_qconfig_propagation_list,
45    get_default_qat_module_mappings,
46)
47from torch.testing._internal.common_quantized import (
48    override_quantized_engine,
49)
50from torch.jit.mobile import _load_for_lite_interpreter
51
52try:
53    # graph mode quantization based on fx
54    from torch.ao.quantization.quantize_fx import (
55        prepare_fx,
56        prepare_qat_fx,
57        convert_fx,
58        convert_to_reference_fx,
59    )
60    from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph
61    from torch.fx.graph import Node
62    from torch.fx import GraphModule
63    HAS_FX = True
64except ImportError:
65    HAS_FX = False
66
67import copy
68import io
69import functools
70import time
71import os
72
73import unittest
74import numpy as np
75from torch.testing import FileCheck
76from typing import Callable, Tuple, Dict, Any, Union, Type, Optional
77import torch._dynamo as torchdynamo
78import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
79from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
80import contextlib
81
82class NodeSpec:
83    ''' Used for checking GraphModule Node
84    '''
85    def __init__(self, op, target):
86        '''
87        op: call_function | call_method | call_module
88        target:
89          for call_function, target would be a function; for call_method, a method name
90          for call_module, target would be the type of PyTorch module
91        '''
92        self.op = op
93        self.target = target
94
95    @classmethod
96    def call_function(cls, target):
97        return NodeSpec('call_function', target)
98
99    @classmethod
100    def call_method(cls, target):
101        return NodeSpec('call_method', target)
102
103    @classmethod
104    def call_module(cls, target):
105        return NodeSpec('call_module', target)
106
107    def __hash__(self):
108        return hash((self.op, self.target))
109
110    def __eq__(self, other):
111        if not isinstance(other, NodeSpec):
112            return NotImplemented
113
114        return self.op == other.op and self.target == other.target
115
116    def __repr__(self):
117        return repr(self.op) + " " + repr(self.target)
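
# NodeSpec examples (illustrative, for documentation only):
#   NodeSpec.call_function(torch.quantize_per_tensor) matches a call_function
#   node whose target is torch.quantize_per_tensor, and
#   NodeSpec.call_module(nnq.Conv2d) matches a call_module node whose target
#   module is of type nnq.Conv2d (see checkGraphModuleNodes below).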
118
119def get_supported_device_types():
120    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']
121
122def test_only_eval_fn(model, calib_data):
123    r"""
124    Default evaluation function takes a torch.utils.data.Dataset or a list of
125    input Tensors and runs the model on the dataset.
126    """
127    for inp in calib_data:
128        output = model(*inp)
129
130_default_loss_fn = torch.nn.CrossEntropyLoss()
131def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
132    r"""
133    Default train function takes a torch.utils.data.Dataset and trains the model
134    on the dataset.
135    """
136    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
137    train_loss, correct, total = 0, 0, 0
138    for i in range(10):
139        model.train()
140
141        for data, target in train_data:
142            optimizer.zero_grad()
143            output = model(data)
144            loss = loss_fn(output, target)
145            loss.backward()
146            optimizer.step()
147            train_loss += loss.item()
148            _, predicted = torch.max(output, 1)
149            total += target.size(0)
150            correct += (predicted == target).sum().item()
151    return train_loss, correct, total
152
153class AverageMeter:
154    """Computes and stores the average and current value"""
155    def __init__(self, name, fmt=':f'):
156        self.name = name
157        self.fmt = fmt
158        self.reset()
159
160    def reset(self):
161        self.val = 0
162        self.avg = 0
163        self.sum = 0
164        self.count = 0
165
166    def update(self, val, n=1):
167        self.val = val
168        self.sum += val * n
169        self.count += n
170        self.avg = self.sum / self.count
171
172    def __str__(self):
173        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
174        return fmtstr.format(**self.__dict__)
175
176
177def accuracy(output, target, topk=(1,)):
178    """Computes the accuracy over the k top predictions for the specified values of k"""
179    with torch.no_grad():
180        maxk = max(topk)
181        batch_size = target.size(0)
182
183        _, pred = output.topk(maxk, 1, True, True)
184        pred = pred.t()
185        correct = pred.eq(target.view(1, -1).expand_as(pred))
186
187        res = []
188        for k in topk:
189            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
190            res.append(correct_k.mul_(100.0 / batch_size))
191        return res
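
# Illustrative call: for logits `output` of shape (8, 10) and integer labels
# `target` of shape (8,),
#   acc1, acc5 = accuracy(output, target, topk=(1, 5))
# returns two one-element tensors holding the top-1 and top-5 accuracy of the
# batch, expressed as percentages.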
192
193def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
194    model.train()
195    cnt = 0
196    for image, target in data_loader:
197        start_time = time.time()
198        print('.', end='')
199        cnt += 1
200        image, target = image.to(device), target.to(device)
201        output = model(image)
202        loss = criterion(output, target)
203        optimizer.zero_grad()
204        loss.backward()
205        optimizer.step()
206        acc1, acc5 = accuracy(output, target, topk=(1, 5))
207        if cnt >= ntrain_batches:
208            return
209    return
210
211def ddp_setup(rank, world_size):
212    os.environ['MASTER_ADDR'] = 'localhost'
213    os.environ['MASTER_PORT'] = '12355'
214
215    # initialize the process group
216    dist.init_process_group("gloo", rank=rank, world_size=world_size)
217
218def ddp_cleanup():
219    dist.destroy_process_group()
220
221def run_ddp(rank, world_size, prepared):
222    ddp_setup(rank, world_size)
223    prepared.cuda()
224    prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank])
225    prepared.to(rank)
226    model_with_ddp = prepared
227    optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001)
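    # NOTE: `criterion` and `dataset` are assumed to be provided by the calling
    # scope; they are not defined in this file (hence the `noqa: F821` below).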
228    train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1)  # noqa: F821
229    ddp_cleanup()
230
231
232def convert_dynamic(module):
233    convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)
234
235def prepare_dynamic(model, qconfig_dict=None):
236    propagate_qconfig_(model, qconfig_dict)
237
238def _make_conv_test_input(
239    batch_size, in_channels_per_group, input_feature_map_size,
240    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
241    W_zero_point, use_bias, use_channelwise,
242):
243    in_channels = in_channels_per_group * groups
244    out_channels = out_channels_per_group * groups
245
246    (X_value_min, X_value_max) = (0, 4)
247    X_init = torch.randint(
248        X_value_min, X_value_max,
249        (batch_size, in_channels,) + input_feature_map_size)
250    X = X_scale * (X_init - X_zero_point).float()
251    X_q = torch.quantize_per_tensor(
252        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
253
254    W_scale = W_scale * out_channels
255    W_zero_point = W_zero_point * out_channels
256    # Resize the W_scale and W_zero_point arrays to have out_channels elements
257    W_scale = W_scale[:out_channels]
258    W_zero_point = W_zero_point[:out_channels]
259    # For testing, we use small values for weights and for activations so that
260    # no overflow occurs in the vpmaddubsw instruction. If overflow occurred in
261    # the qconv implementation, we could not exactly match the results against
262    # the reference implementation.
263    # Please see the comment in qconv implementation file
264    #   aten/src/ATen/native/quantized/cpu/qconv.cpp for more details.
265    (W_value_min, W_value_max) = (-5, 5)
266    # The operator expects them in the format
267    # (out_channels, in_channels/groups,) + kernel_size
268    W_init = torch.randint(
269        W_value_min, W_value_max,
270        (out_channels, in_channels_per_group,) + kernel_size)
271    b_init = torch.randint(0, 10, (out_channels,))
272
273    if use_channelwise:
274        W_shape = (-1, 1) + (1,) * len(kernel_size)
275        W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
276        W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
277        W = W_scales_tensor.reshape(*W_shape) * (
278            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
279        b = X_scale * W_scales_tensor * b_init.float()
280        W_q = torch.quantize_per_channel(
281            W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0,
282            dtype=torch.qint8)
283    else:
284        W = W_scale[0] * (W_init - W_zero_point[0]).float()
285        b = X_scale * W_scale[0] * b_init.float()
286        W_q = torch.quantize_per_tensor(
287            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)
288
289    return (X, X_q, W, W_q, b if use_bias else None)
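
# _make_conv_test_input returns both the float and quantized versions of the
# activation (X, X_q) and weight (W, W_q), plus an optional float bias, so that
# a quantized conv can be checked against its float reference on the same data.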
290
291def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
292    (X_value_min, X_value_max) = (0, 4)
293    X_init = torch.randint(
294        X_value_min,
295        X_value_max,
296        sizes  # Infer the size of tensor to do the add
297    )
298    X = scale * (X_init - zero_point).float()
299    X_q = torch.quantize_per_tensor(
300        X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
301    return X, X_q
302
303def skipIfNoFBGEMM(fn):
304    reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with AVX2 instruction set support or newer.'
305    if isinstance(fn, type):
306        if 'fbgemm' not in torch.backends.quantized.supported_engines:
307            fn.__unittest_skip__ = True
308            fn.__unittest_skip_why__ = reason
309        return fn
310
311    @functools.wraps(fn)
312    def wrapper(*args, **kwargs):
313        if 'fbgemm' not in torch.backends.quantized.supported_engines:
314            raise unittest.SkipTest(reason)
315        else:
316            fn(*args, **kwargs)
317    return wrapper
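
# Typical usage (illustrative; the test class and method names below are
# hypothetical): the skip decorators in this file can be applied either to a
# TestCase subclass or to an individual test method.
#
#   @skipIfNoFBGEMM
#   class TestQuantizedConv(QuantizationTestCase):
#       ...
#
#   class TestMisc(QuantizationTestCase):
#       @skipIfNoQNNPACK
#       def test_linear(self):
#           ...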
318
319def skipIfNoQNNPACK(fn):
320    reason = 'Quantized operations require QNNPACK.'
321    if isinstance(fn, type):
322        if 'qnnpack' not in torch.backends.quantized.supported_engines:
323            fn.__unittest_skip__ = True
324            fn.__unittest_skip_why__ = reason
325        return fn
326
327    @functools.wraps(fn)
328    def wrapper(*args, **kwargs):
329        if 'qnnpack' not in torch.backends.quantized.supported_engines:
330            raise unittest.SkipTest(reason)
331        else:
332            fn(*args, **kwargs)
333    return wrapper
334
335def withQNNPACKBackend(fn):
336    # TODO(future PR): consider combining with skipIfNoQNNPACK,
337    # will require testing of existing callsites
338    reason = 'Quantized operations require QNNPACK.'
339    if isinstance(fn, type):
340        if 'qnnpack' not in torch.backends.quantized.supported_engines:
341            fn.__unittest_skip__ = True
342            fn.__unittest_skip_why__ = reason
343        return fn
344
345    @functools.wraps(fn)
346    def wrapper(*args, **kwargs):
347        if 'qnnpack' not in torch.backends.quantized.supported_engines:
348            raise unittest.SkipTest(reason)
349        with override_quantized_engine('qnnpack'):
350            fn(*args, **kwargs)
351
352    return wrapper
353
354def skipIfNoONEDNN(fn):
355    reason = 'Quantized operations require ONEDNN.'
356    if isinstance(fn, type):
357        if 'onednn' not in torch.backends.quantized.supported_engines:
358            fn.__unittest_skip__ = True
359            fn.__unittest_skip_why__ = reason
360        return fn
361
362    @functools.wraps(fn)
363    def wrapper(*args, **kwargs):
364        if 'onednn' not in torch.backends.quantized.supported_engines:
365            raise unittest.SkipTest(reason)
366        else:
367            fn(*args, **kwargs)
368    return wrapper
369
370def skipIfNoONEDNNBF16(fn):
371    reason = 'Quantized operations require BF16 support.'
372    if isinstance(fn, type):
373        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
374            fn.__unittest_skip__ = True
375            fn.__unittest_skip_why__ = reason
376        return fn
377
378    @functools.wraps(fn)
379    def wrapper(*args, **kwargs):
380        if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
381            raise unittest.SkipTest(reason)
382        else:
383            fn(*args, **kwargs)
384    return wrapper
385
386def skipIfNoX86(fn):
387    reason = 'Quantized operations require X86.'
388    if isinstance(fn, type):
389        if 'x86' not in torch.backends.quantized.supported_engines:
390            fn.__unittest_skip__ = True
391            fn.__unittest_skip_why__ = reason
392        return fn
393
394    @functools.wraps(fn)
395    def wrapper(*args, **kwargs):
396        if 'x86' not in torch.backends.quantized.supported_engines:
397            raise unittest.SkipTest(reason)
398        else:
399            fn(*args, **kwargs)
400    return wrapper
401
402def skipIfNoDynamoSupport(fn):
403    reason = "dynamo is not supported."
404    if isinstance(fn, type):
405        if not torchdynamo.is_dynamo_supported():
406            fn.__unittest_skip__ = True
407            fn.__unittest_skip_why__ = reason
408        return fn
409
410    @functools.wraps(fn)
411    def wrapper(*args, **kwargs):
412        if not torchdynamo.is_dynamo_supported():
413            raise unittest.SkipTest(reason)
414        else:
415            fn(*args, **kwargs)
416    return wrapper
417
418def skipIfNoInductorSupport(fn):
419    reason = "inductor is not supported."
420    if isinstance(fn, type):
421        if not torchdynamo.is_inductor_supported():
422            fn.__unittest_skip__ = True
423            fn.__unittest_skip_why__ = reason
424        return fn
425
426    @functools.wraps(fn)
427    def wrapper(*args, **kwargs):
428        if not torchdynamo.is_inductor_supported():
429            raise unittest.SkipTest(reason)
430        else:
431            fn(*args, **kwargs)
432    return wrapper
433
434try:
435    import torchvision  # noqa: F401
436    HAS_TORCHVISION = True
437except ImportError:
438    HAS_TORCHVISION = False
439skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
440
441def get_script_module(model, tracing, data):
442    return torch.jit.trace(model, data) if tracing else torch.jit.script(model)
443
444def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
445    """
446    Convert lengths to offsets for embedding_bag
447    """
448    tt = np.zeros((t.shape[0] + 1,), dtype=offset_type)
449    tt[1:] = t
450    tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type))
451    if use_begin_offset:
452        return tt[:-1]
453    return tt[1:]
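
# Worked example (illustrative): lengths [2, 3, 1] give cumulative sums
# [0, 2, 5, 6]; with use_begin_offset=True the begin offsets [0, 2, 5] are
# returned, otherwise the end offsets [2, 5, 6].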
454
455
456def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
457    assert w.dim() == 2
458    w = w.transpose(0, 1).contiguous()
459    assert q_group_size > 1
460    assert w.shape[-1] % q_group_size == 0
461
462    to_quant = w.reshape(-1, q_group_size)
463    assert torch.isnan(to_quant).sum() == 0
464
465    max_val = to_quant.amax(dim=1, keepdim=True)
466    min_val = to_quant.amin(dim=1, keepdim=True)
467    max_int = 2 ** n_bit - 1
468    min_int = 0
469    scales = (max_val - min_val).clamp(min=1e-6) / max_int
470    assert torch.isnan(scales).sum() == 0
471
472    zeros = min_val + scales * (2 ** (n_bit - 1))
473    assert torch.isnan(zeros).sum() == 0
474
475    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
476    assert torch.isnan(out).sum() == 0
477
478    out = out.to(dtype=torch.int32).reshape(w.shape)
479    out_uint8 = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
480
481    # Scales and zeros for the same q-group should be contiguous, so we can
482    # load as a 32-bit word
483    scales = scales.view(w.shape[0], -1)
484    zeros = zeros.view(w.shape[0], -1)
485    scales_and_zeros = (
486        torch.cat(
487            [
488                scales.reshape(scales.size(0), scales.size(1), 1),
489                zeros.reshape(zeros.size(0), zeros.size(1), 1),
490            ],
491            2,
492        ).transpose(0, 1).contiguous()
493    )
494
495    return out_uint8, scales_and_zeros
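
# Illustrative shapes (not asserted by the function itself): for a 2-D `w` of
# shape (m, n) with m divisible by q_group_size, `out_uint8` has shape
# (n, m // 2), packing two 4-bit codes per byte along the last dimension, and
# `scales_and_zeros` has shape (m // q_group_size, n, 2), pairing a scale and a
# zero point for each quantization group.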
496
497
498def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
499    # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
500    # default setup for affine quantization of activations
501    x_dtype = x.dtype
502    x = x.float()
503    eps = torch.finfo(torch.float32).eps
504
505    # get min and max
506    min_val, max_val = torch.aminmax(x, dim=1)
507
508    # calculate scales and zero_points based on min and max
509    # reference: https://fburl.com/code/srbiybme
510    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
511    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
512    device = min_val_neg.device
513
514    # reference: https://fburl.com/code/4wll53rk
515    max_val_pos = torch.max(-min_val_neg, max_val_pos)
516    scales = max_val_pos / (float(quant_max - quant_min) / 2)
517    # ensure scales is the same dtype as the original tensor
518    scales = torch.clamp(scales, min=eps).to(x.dtype)
519    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
520
521    # quantize based on qmin/qmax/scales/zp
522    x_div = x / scales.unsqueeze(-1)
523    x_round = torch.round(x_div)
524    x_zp = x_round + zero_points.unsqueeze(-1)
525    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
526
527    return quant, scales.to(x_dtype), zero_points
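
# Illustrative call (int8 range): for a 2-D activation `x`,
#   quant, scales, zero_points = _dynamically_quantize_per_channel(
#       x, -128, 127, torch.int8)
# quantizes each row symmetrically (all zero_points are 0) and returns the
# per-row scales cast back to the original dtype of `x`.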
528
529
530
531# QuantizationTestCase used as a base class for testing quantization on modules
532class QuantizationTestCase(TestCase):
533    def setUp(self):
534        super().setUp()
535        self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
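        # Each calibration entry above is a list of positional inputs (a single
        # (2, 5) float tensor), matching the `model(*inp)` call in
        # test_only_eval_fn.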
536        self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)]
537        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)]
538                            for _ in range(2)]
539        self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)]
540                            for _ in range(2)]
541        self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)]
542                            for _ in range(2)]
543        self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float),
544                                   torch.randint(0, 1, (1,), dtype=torch.long)]
545                                  for _ in range(2)]
546        self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float),
547                                   torch.randint(0, 1, (1,), dtype=torch.long)]
548                                  for _ in range(2)]
549        self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
550                                   torch.randint(0, 1, (1,), dtype=torch.long)]
551                                  for _ in range(2)]
552
553        self.img_data_dict = {1 : self.img_data_1d,
554                              2 : self.img_data_2d,
555                              3 : self.img_data_3d}
556
557        # Quant types that produce statically quantized ops
558        self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
559        # All quant types for (fx based) graph mode quantization
560        self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT]
561
562    def checkNoPrepModules(self, module):
563        r"""Checks the module does not contain child
564            modules for quantization preparation, e.g.
565            quant, dequant and observer
566        """
567        self.assertFalse(hasattr(module, 'quant'))
568        self.assertFalse(hasattr(module, 'dequant'))
569
570    def checkNoQconfig(self, module):
571        r"""Checks the module does not contain qconfig
572        """
573        self.assertFalse(hasattr(module, 'qconfig'))
574
575        for child in module.children():
576            self.checkNoQconfig(child)
577
578    def checkHasPrepModules(self, module):
579        r"""Checks the module contains child
580            modules for quantization preparation, e.g.
581            quant, dequant and observer
582        """
583        self.assertTrue(hasattr(module, 'module'))
584        self.assertTrue(hasattr(module, 'quant'))
585        self.assertTrue(hasattr(module, 'dequant'))
586
587    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
588        r"""Checks the module or module's leaf descendants
589            have observers in preparation for quantization
590        """
591        if propagate_qconfig_list is None:
592            propagate_qconfig_list = get_default_qconfig_propagation_list()
593        if prepare_custom_config_dict is None:
594            prepare_custom_config_dict = {}
595        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
596
597        # check if a module is a leaf module, ignoring activation_post_process attribute
598        def is_leaf_module(module):
599            submodule_name_count = 0
600            for name, _ in module.named_children():
601                if name != 'activation_post_process':
602                    submodule_name_count += 1
603            return submodule_name_count == 0
604
605        if hasattr(module, 'qconfig') and module.qconfig is not None and \
606           ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
607            and type(module) in propagate_qconfig_list) or
608           type(module) in float_to_observed_module_class_mapping.keys()) and \
609           not isinstance(module, torch.ao.quantization.DeQuantStub):
610            self.assertTrue(hasattr(module, 'activation_post_process'),
611                            'module: ' + str(type(module)) + ' does not have an observer')
612        # we don't need to check observers for child modules of the
613        # qat modules
614        if type(module) not in get_default_qat_module_mappings().values() and \
615           type(module) not in float_to_observed_module_class_mapping.values() and \
616           not isinstance(module, _FusedModule):
617            for child in module.children():
618                if type(child) in [nn.Dropout]:
619                    continue
620                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
621
622    def checkQuantDequant(self, mod):
623        r"""Checks that mod has nn.Quantize and
624            nn.DeQuantize submodules inserted
625        """
626        self.assertEqual(type(mod.quant), nnq.Quantize)
627        self.assertEqual(type(mod.dequant), nnq.DeQuantize)
628
629    def checkWrappedQuantizedLinear(self, mod):
630        r"""Checks that mod has been swapped for an nnq.Linear
631            module, the bias is qint32, and that the module
632            has Quantize and DeQuantize submodules
633        """
634        self.assertEqual(type(mod.module), nnq.Linear)
635        self.checkQuantDequant(mod)
636
637    def checkQuantizedLinear(self, mod):
638        self.assertEqual(type(mod), nnq.Linear)
639
640    def checkDynamicQuantizedLinear(self, mod, dtype):
641        r"""Checks that mod has been swapped for an nnqd.Linear
642            module, the bias is float.
643        """
644        self.assertEqual(type(mod), nnqd.Linear)
645        self.assertEqual(mod._packed_params.dtype, dtype)
646
647    def checkDynamicQuantizedLinearRelu(self, mod, dtype):
648        r"""Checks that mod has been swapped for an nnqd.Linear
649            module, the bias is float.
650        """
651        self.assertEqual(type(mod), nniqd.LinearReLU)
652        self.assertEqual(mod._packed_params.dtype, dtype)
653
654    def check_eager_serialization(self, ref_model, loaded_model, x):
655        # Check state dict serialization and torch.save APIs
656        model_dict = ref_model.state_dict()
657        b = io.BytesIO()
658        torch.save(model_dict, b)
659        b.seek(0)
660        # weights_only=False as we sometimes get a ScriptObject here (weird)
661        loaded_dict = torch.load(b, weights_only=False)
662        loaded_model.load_state_dict(loaded_dict)
663        ref_out = ref_model(*x)
664        load_out = loaded_model(*x)
665
666        def check_outputs(ref_out, load_out):
667            self.assertEqual(ref_out[0], load_out[0])
668            if isinstance(ref_out[1], tuple):
669                self.assertEqual(ref_out[1][0], load_out[1][0])
670                self.assertEqual(ref_out[1][1], load_out[1][1])
671            else:
672                self.assertEqual(ref_out[1], load_out[1])
673
674        check_outputs(ref_out, load_out)
675        b = io.BytesIO()
676        torch.save(ref_model, b)
677        b.seek(0)
678        # weights_only=False as this is legacy code that saves the model
679        loaded = torch.load(b, weights_only=False)
680        load_out = loaded(*x)
681        check_outputs(ref_out, load_out)
682
683    def check_weight_bias_api(self, ref_model, weight_keys, bias_keys):
684        weight = ref_model.get_weight()
685        bias = ref_model.get_bias()
686        self.assertEqual(weight_keys ^ weight.keys(), set())
687        self.assertEqual(bias_keys ^ bias.keys(), set())
688
689    def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
690        r"""Checks that mod has been swapped for an nnqd.LSTM type
691            module, the bias is float.
692        """
693        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
694        self.assertEqual(type(mod), reference_module_type)
695        for packed_params in mod._all_weight_values:
696            self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])
697
698    def checkLinear(self, mod):
699        self.assertEqual(type(mod), torch.nn.Linear)
700
701    def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
702        r"""Checks that mod has been swapped for the expected dynamically
703            quantized module type and that the packed weight dtype matches.
704        """
705        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
706        self.assertEqual(type(mod), reference_module_type)
707        if hasattr(mod, '_all_weight_values'):
708            for packed_params in mod._all_weight_values:
709                self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])
710
711    def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
712        scripted = torch.jit.script(orig_mod)
713        self._checkScriptable(orig_mod, scripted, calib_data, check_save_load)
714
715        # Use first calib_data entry as trace input
716        traced = torch.jit.trace(orig_mod, calib_data[0])
717        self._checkScriptable(orig_mod, traced, calib_data, check_save_load)
718
719    # Call this twice: once for a scripted module and once for a traced module
720    def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load):
721        self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data)
722
723        # Test save/load
724        buffer = io.BytesIO()
725        torch.jit.save(script_mod, buffer)
726
727        buffer.seek(0)
728        loaded_mod = torch.jit.load(buffer)
729        # Pending __getstate__ and __setstate__ support
730        # See tracking task https://github.com/pytorch/pytorch/issues/23984
731        if check_save_load:
732            self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data)
733
734    def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data):
735        for inp in calib_data:
736            ref_output = orig_mod(*inp)
737            scripted_output = test_mod(*inp)
738            self.assertEqual(scripted_output, ref_output)
739
740
741    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
742                         check=True, eval_mode=True, dynamic=False, qconfig=None):
743        if debug:
744            print('Testing:', str(module))
745        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
746
747        if eval_mode:
748            module = module.eval()
749        if dynamic:
750            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
751        model = get_script_module(module, tracing, inputs[0]).eval()
752        if debug:
753            print('input graph:', model.graph)
754        models = {}
755        outputs = {}
756        for debug in [True, False]:
757            if dynamic:
758                models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug)
759                # make sure it runs
760                outputs[debug] = models[debug](inputs)
761            else:
762                # module under test can contain in-place ops, and we depend on
763                # input data staying constant for comparisons
764                inputs_copy = copy.deepcopy(inputs)
765                models[debug] = quantize_jit(
766                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
767                    debug=debug)
768                # make sure it runs
769                outputs[debug] = models[debug](*inputs[0])
770
771        if debug:
772            print('debug graph:', models[True].graph)
773            print('non debug graph:', models[False].graph)
774
775        if check:
776            # debug and non-debug options should have the same numerics
777            self.assertEqual(outputs[True], outputs[False])
778
779            # the non-debug graph should produce the quantized op
780            FileCheck().check(quantized_op) \
781                       .run(models[False].graph)
782
783        return models[False]
784
785    def checkGraphModuleNodes(
786            self, graph_module,
787            expected_node=None,
788            expected_node_occurrence=None,
789            expected_node_list=None):
790        """ Check if GraphModule contains the target node
791        Args:
792            graph_module: the GraphModule instance we want to check
793            expected_node, expected_node_occurrence, expected_node_list:
794               see docs for checkGraphModeFxOp
795        """
796        nodes_in_graph = {}
797        node_list = []
798        modules = dict(graph_module.named_modules(remove_duplicate=False))
799        for node in graph_module.graph.nodes:
800            n = None
801            if node.op == 'call_function' or node.op == 'call_method':
802                n = NodeSpec(node.op, node.target)
803            elif node.op == 'call_module':
804                n = NodeSpec(node.op, type(modules[node.target]))
805
806            if n is not None:
807                node_list.append(n)
808                if n in nodes_in_graph:
809                    nodes_in_graph[n] += 1
810                else:
811                    nodes_in_graph[n] = 1
812
813        if expected_node is not None:
814            self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) +
815                            ' not found in the graph module')
816
817        if expected_node_occurrence is not None:
818            for expected_node, occurrence in expected_node_occurrence.items():
819                if occurrence != 0:
820                    self.assertTrue(
821                        expected_node in nodes_in_graph,
822                        'Check failed for node:' + str(expected_node) +
823                        ' not found')
824                    self.assertTrue(
825                        nodes_in_graph[expected_node] == occurrence,
826                        'Check failed for node:' + str(expected_node) +
827                        ' Expected occurrence:' + str(occurrence) +
828                        ' Found occurrence:' + str(nodes_in_graph[expected_node]))
829                else:
830                    self.assertTrue(
831                        expected_node not in nodes_in_graph,
832                        'Check failed for node:' + str(expected_node) +
833                        ' expected no occurrence but found')
834
835        if expected_node_list is not None:
836            cur_index = 0
837            for n in node_list:
838                if cur_index == len(expected_node_list):
839                    return
840                if n == expected_node_list[cur_index]:
841                    cur_index += 1
842            self.assertTrue(
843                cur_index == len(expected_node_list),
844                "Check failed for graph:" +
845                self.printGraphModule(graph_module, print_str=False) +
846                "Expected ordered list:" +
847                str(expected_node_list))
848
849    def printGraphModule(self, graph_module, print_str=True):
850        modules = dict(graph_module.named_modules(remove_duplicate=False))
851        node_infos = []
852        for n in graph_module.graph.nodes:
853            node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs]))
854            if n.op == 'call_module':
855                node_info += ' module type: ' + repr(type(modules[n.target]))
856            node_infos.append(node_info)
857        str_to_print = '\n'.join(node_infos)
858        if print_str:
859            print(str_to_print)
860        return str_to_print
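
    # Illustrative usage of checkGraphModuleNodes (with a hypothetical quantized
    # GraphModule `gm`):
    #   self.checkGraphModuleNodes(
    #       gm,
    #       expected_node_occurrence={
    #           NodeSpec.call_function(torch.quantize_per_tensor): 1,
    #           NodeSpec.call_method('dequantize'): 1,
    #       },
    #   )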
861
862    if HAS_FX:
863
864        def assert_types_for_matched_subgraph_pairs(
865            self,
866            matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
867            expected_types: Dict[str, Tuple[Tuple[Callable, Callable], Tuple[Callable, Callable]]],
868            gm_a: GraphModule,
869            gm_b: GraphModule,
870        ) -> None:
871            """
872            Verifies that the types specified in expected_types match
873            the underlying objects pointed to by the nodes in matched_subgraph_pairs.
874
875            An example successful test case:
876
877              matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)}
878              expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)}
879
880            The function tests for key equivalence, and verifies types with
881            instance checks.
882            """
883
884            def _get_underlying_op_type(
885                node: Node, gm: GraphModule
886            ) -> Union[Callable, str]:
887                if node.op == 'call_module':
888                    mod = getattr(gm, node.target)
889                    return type(mod)
890                else:
891                    assert node.op in ('call_function', 'call_method')
892                    return node.target
893
894            self.assertTrue(
895                len(matched_subgraph_pairs) == len(expected_types),
896                f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}'
897            )
898            for k, v in expected_types.items():
899                expected_types_a, expected_types_b = v
900                exp_type_start_a, exp_type_end_a = expected_types_a
901                exp_type_start_b, exp_type_end_b = expected_types_b
902                subgraph_a, subgraph_b = matched_subgraph_pairs[k]
903
904                act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a)
905                act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b)
906                act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a)
907                act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b)
908                types_match = (exp_type_start_a is act_type_start_a) and \
909                    (exp_type_end_a is act_type_end_a) and \
910                    (exp_type_start_b is act_type_start_b) and \
911                    (exp_type_end_b is act_type_end_b)
912                self.assertTrue(
913                    types_match,
914                    f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, '
915                    f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}'
916                )
917
918        def assert_ns_compare_dict_valid(
919            self,
920            act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]],
921        ) -> None:
922            """
923            Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
924            1. for each layer, results are recorded for two models
925            2. number of seen tensors match
926            3. shapes of each pair of seen tensors match
927            """
928            for layer_name, result_type_to_data in act_compare_dict.items():
929                for result_type, layer_data in result_type_to_data.items():
930                    self.assertTrue(
931                        len(layer_data) == 2,
932                        f"Layer {layer_name} does not have exactly two model results.")
933                    model_name_0, model_name_1 = layer_data.keys()
934                    for res_idx in range(len(layer_data[model_name_0])):
935                        layer_data_0 = layer_data[model_name_0][res_idx]
936                        layer_data_1 = layer_data[model_name_1][res_idx]
937                        self.assertTrue(
938                            layer_data_0['type'] == layer_data_1['type'],
939                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.")
940
941                        self.assertTrue(
942                            len(layer_data_0['values']) ==
943                            len(layer_data_1['values']),
944                            f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.")
945
946                        # F.conv1d weight has rank 3, and toq.conv1d unpacked weight
947                        # has rank 4. For now, skip the shape check for conv1d only.
948                        is_weight_functional_conv1d = (
949                            result_type == NSSingleResultValuesType.WEIGHT.value and
950                            (
951                                'conv1d' in layer_data_0['prev_node_target_type'] or
952                                'conv1d' in layer_data_1['prev_node_target_type']
953                            )
954                        )
955                        if not is_weight_functional_conv1d:
956                            for idx in range(len(layer_data_0['values'])):
957                                values_0 = layer_data_0['values'][idx]
958                                values_1 = layer_data_1['values'][idx]
959                                if isinstance(values_0, torch.Tensor):
960                                    self.assertTrue(
961                                        values_0.shape == values_1.shape,
962                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
963                                        f"have a shape mismatch at idx {idx}.")
964                                elif isinstance(values_0, list):
965                                    values_0 = values_0[0]
966                                    values_1 = values_1[0]
967                                    self.assertTrue(
968                                        values_0.shape == values_1.shape,
969                                        f"Layer {layer_name}, {model_name_0} and {model_name_1} " +
970                                        f"have a shape mismatch at idx {idx}.")
971                                else:
972                                    assert isinstance(values_0, tuple), \
973                                        f"unhandled type {type(values_0)}"
974                                    assert len(values_0) == 2
975                                    assert len(values_0[1]) == 2
976                                    assert values_0[0].shape == values_1[0].shape
977                                    assert values_0[1][0].shape == values_1[1][0].shape
978                                    assert values_0[1][1].shape == values_1[1][1].shape
979
980                        # verify that ref_node_name is valid
981                        ref_node_name_0 = layer_data_0['ref_node_name']
982                        ref_node_name_1 = layer_data_1['ref_node_name']
983                        prev_node_name_0 = layer_data_0['prev_node_name']
984                        prev_node_name_1 = layer_data_1['prev_node_name']
985                        if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value:
986                            self.assertTrue(ref_node_name_0 == prev_node_name_0)
987                            self.assertTrue(ref_node_name_1 == prev_node_name_1)
988                        elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value:
989                            self.assertTrue(ref_node_name_0 != prev_node_name_0)
990                            self.assertTrue(ref_node_name_1 != prev_node_name_1)
991
992        def checkGraphModeFxOp(
993                self,
994                model,
995                inputs,
996                quant_type,
997                expected_node=None,
998                expected_node_occurrence=None,
999                expected_node_list=None,
1000                is_reference=False,
1001                print_debug_info=False,
1002                custom_qconfig_dict=None,
1003                prepare_expected_node=None,
1004                prepare_expected_node_occurrence=None,
1005                prepare_expected_node_list=None,
1006                prepare_custom_config=None,
1007                backend_config=None):
1008            """ Quantizes model with graph mode quantization on fx and checks if the
1009                quantized model contains the expected quantized_node
1010
1011                Args:
1012                    model: floating point torch.nn.Module
1013                    inputs: one positional sample input arguments for model
1014                    expected_node: NodeSpec
1015                        e.g. NodeSpec.call_function(torch.quantize_per_tensor)
1016                    expected_node_occurrence: a dict from NodeSpec to
1017                        expected number of occurrences (int)
1018                        e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
1019                                NodeSpec.call_method('dequantize'): 1}
1020                    expected_node_list: a list of NodeSpec, used to check the order
1021                        of the occurrence of Node
1022                        e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
1023                                NodeSpec.call_module(nnq.Conv2d),
1024                                NodeSpec.call_function(F.hardtanh_),
1025                                NodeSpec.call_method('dequantize')]
1026                    is_reference: if True, enables reference mode
1027                    print_debug_info: if True, prints debug info
1028                    custom_qconfig_dict: overrides default qconfig_dict
1029                    prepare_expected_node: same as expected_node, but for prepare
1030                    prepare_expected_node_occurrence: same as
1031                        expected_node_occurrence, but for prepare
1032                    prepare_expected_node_list: same as expected_node_list, but
1033                        for prepare
1034
1035                Returns:
1036                    A dictionary with the following structure:
1037                   {
1038                       "prepared": ...,  # the prepared model
1039                       "quantized": ...,  # the quantized non-reference model
1040                       "quantized_reference": ...,  # the quantized reference model
1041                       "quantized_output": ...,  # the output of the quantized model
1042                       "quantized_reference_output": ...,  # the output of the
1043                                       # quantized reference model
1044                   }
1045            """
1046            # TODO: make img_data a single example instead of a list
1047            if type(inputs) == list:
1048                inputs = inputs[0]
1049
1050            if quant_type == QuantType.QAT:
1051                qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine)
1052                model.train()
1053            elif quant_type == QuantType.STATIC:
1054                qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine)
1055                model.eval()
1056            else:
1057                qconfig = default_dynamic_qconfig
1058                qconfig_mapping = QConfigMapping().set_global(qconfig)
1059                model.eval()
1060
1061            if quant_type == QuantType.QAT:
1062                prepare = prepare_qat_fx
1063            else:
1064                prepare = prepare_fx
1065
1066            # overwrite qconfig_mapping with custom_qconfig_dict
1067            if custom_qconfig_dict is not None:
1068                assert type(custom_qconfig_dict) in (QConfigMapping, dict), \
1069                    'custom_qconfig_dict should be a QConfigMapping or a dict'
1070                if isinstance(custom_qconfig_dict, QConfigMapping):
1071                    qconfig_mapping = custom_qconfig_dict
1072                else:
1073                    qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict)
1074            prepared = prepare(
1075                model, qconfig_mapping,
1076                example_inputs=inputs,
1077                prepare_custom_config=prepare_custom_config,
1078                backend_config=backend_config)
1079            if not quant_type == QuantType.DYNAMIC:
1080                prepared(*inputs)
1081
1082            if print_debug_info:
1083                print()
1084                print('quant type:\n', quant_type)
1085                print('original model:\n', model)
1086                print()
1087                print('prepared model:\n', prepared)
1088
1089            self.checkGraphModuleNodes(
1090                prepared, prepare_expected_node,
1091                prepare_expected_node_occurrence, prepare_expected_node_list)
1092
1093            prepared_copy = copy.deepcopy(prepared)
1094            qgraph = convert_fx(copy.deepcopy(prepared))
1095            qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared))
1096            result = qgraph(*inputs)
1097            result_reference = qgraph_reference(*inputs)
1098            qgraph_copy = copy.deepcopy(qgraph)
1099            qgraph_reference_copy = copy.deepcopy(qgraph_reference)
1100
1101            qgraph_to_check = qgraph_reference if is_reference else qgraph
1102            if print_debug_info:
1103                print()
1104                print('quantized model:\n', qgraph_to_check)
1105                self.printGraphModule(qgraph_to_check)
1106                print()
1107            self.checkGraphModuleNodes(
1108                qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
1109            return {"prepared": prepared_copy,
1110                    "quantized": qgraph_copy,
1111                    "quantized_reference": qgraph_reference_copy,
1112                    "quantized_output": result,
1113                    "quantized_reference_output": result_reference}
1114
1115
1116    def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets,
1117                                    set_qconfig, is_emb_bag, dtype=torch.quint8):
1118        # Test serialization of dynamic EmbeddingBag module using state_dict
1119        if is_emb_bag:
1120            inputs = [indices, offsets]
1121        else:
1122            inputs = [indices]
1123        emb_dict = qemb.state_dict()
1124        b = io.BytesIO()
1125        torch.save(emb_dict, b)
1126        b.seek(0)
1127        loaded_dict = torch.load(b)
1128        embedding_unpack = torch.ops.quantized.embedding_bag_unpack
1129        # Check unpacked weight values explicitly
1130        for key in emb_dict:
1131            if isinstance(emb_dict[key], torch._C.ScriptObject):
1132                assert isinstance(loaded_dict[key], torch._C.ScriptObject)
1133                emb_weight = embedding_unpack(emb_dict[key])
1134                loaded_weight = embedding_unpack(loaded_dict[key])
1135                self.assertEqual(emb_weight, loaded_weight)
1136
1137        # Check state dict serialization and torch.save APIs
1138        if is_emb_bag:
1139            loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
1140                                           include_last_offset=True, mode='sum', dtype=dtype)
1141        else:
1142            loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
1143        self.check_eager_serialization(qemb, loaded_qemb, inputs)
1144
1145        loaded_qemb.load_state_dict(loaded_dict)
1146        self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight),
1147                         embedding_unpack(loaded_qemb._packed_params._packed_weight))
1148
1149
1150        # Test JIT serialization
1151        self.checkScriptable(qemb, [inputs], check_save_load=True)
1152
1153        # Test from_float call
1154        if is_emb_bag:
1155            float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim,
1156                                                    include_last_offset=True, scale_grad_by_freq=False, mode='sum')
1157        else:
1158            float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
1159
1160        if set_qconfig:
1161            float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
1162                                                                        qscheme=torch.per_channel_affine_float_qparams,
1163                                                                        ch_axis=0)
1164            float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer,
1165                                              weight=float_qparams_observer)
1166
1167        prepare_dynamic(float_embedding)
1168
1169        float_embedding(*inputs)
1170        if is_emb_bag:
1171            q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding)
1172            expected_name = "QuantizedEmbeddingBag"
1173        else:
1174            q_embeddingbag = nnq.Embedding.from_float(float_embedding)
1175            expected_name = "QuantizedEmbedding"
1176
1177        q_embeddingbag(*inputs)
1178
1179        self.assertTrue(expected_name in str(q_embeddingbag))
1180
1181class QuantizationLiteTestCase(QuantizationTestCase):
1182    def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
1183        # Creates quantized model for testing mobile script modules
1184        qengine = "qnnpack"
1185        with override_quantized_engine(qengine):
1186            qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1187            model = model_class(**kwargs)
1188            model = quantize(model, test_only_eval_fn, [self.calib_data])
1189
1190        return model
1191
1192    def _compare_script_and_mobile(self,
1193                                   model: torch.nn.Module,
1194                                   input: torch.Tensor):
1195        # Compares the numerical outputs for script and lite modules
1196        qengine = "qnnpack"
1197        with override_quantized_engine(qengine):
1198            script_module = torch.jit.script(model)
1199            script_module_result = script_module(input)
1200
1201            max_retry = 5
1202            for retry in range(1, max_retry + 1):
1203                # retries up to `max_retry` times; breaks on success, re-raises the exception on the final attempt
1204                try:
1205                    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
1206                    buffer.seek(0)
1207                    mobile_module = _load_for_lite_interpreter(buffer)
1208
1209                    mobile_module_result = mobile_module(input)
1210
1211                    torch.testing.assert_close(script_module_result, mobile_module_result)
1212                    mobile_module_forward_result = mobile_module.forward(input)
1213                    torch.testing.assert_close(script_module_result, mobile_module_forward_result)
1214
1215                    mobile_module_run_method_result = mobile_module.run_method("forward", input)
1216                    torch.testing.assert_close(script_module_result, mobile_module_run_method_result)
1217                except AssertionError as e:
1218                    if retry == max_retry:
1219                        raise e
1220                    else:
1221                        continue
1222                break
1223
1224
1225class PT2EQuantizationTestCase(QuantizationTestCase):
1226    """
1227    Base QuantizationTestCase for PT2 with some helper methods.
1228    """
1229    _MAP_TO_FX_TRACED_OPS = {
1230        torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default,
1231        torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default,
1232        torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default,
1233        torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default,
1234        torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
1235        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
1236    }
1237
1238    def _test_quantizer(
1239        self,
1240        model,
1241        example_inputs,
1242        quantizer,
1243        expected_node_occurrence,
1244        expected_node_list=None,
1245        check_against_fx_quant=False,
1246        fx_qconfig_mapping=None,
1247        export_with_dynamic_shape=False,
1248        is_qat=False,
1249        is_debug_mode=False,
1250        capture_pre_autograd_graph_node_occurrence=None,
1251    ):
1252        # resetting dynamo cache
1253        torch._dynamo.reset()
1254        m_eager = model.eval()
1255
1256        # program capture
1257        m = copy.deepcopy(m_eager)
1258        dynamic_shapes = tuple(
1259            {0: torch.export.Dim("dim")} if i == 0 else None
1260            for i in range(len(example_inputs))
1261        )
1262        m = capture_pre_autograd_graph(
1263            m,
1264            example_inputs,
1265            dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
1266        )
1267
1268        if is_qat:
1269            m = prepare_qat_pt2e(m, quantizer)
1270        else:
1271            m = prepare_pt2e(m, quantizer)
1272        # Calibrate
1273        m(*example_inputs)
1274        m = convert_pt2e(m)
1275        if is_debug_mode:
1276            print("quantized model", m)
1277
1278        pt2_quant_output = m(*example_inputs)
1279        ns = NodeSpec
1280        node_occurrence = {
1281            ns.call_function(k): v for k, v in expected_node_occurrence.items()
1282        }
1283        if expected_node_list is None:
1284            expected_node_list = []
1285        node_list = [ns.call_function(n) for n in expected_node_list]
1286        self.checkGraphModuleNodes(
1287            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
1288        )
1289        if check_against_fx_quant:
1290            qconfig_mapping = fx_qconfig_mapping
1291            backend_config = get_executorch_backend_config()
1292            m_copy = copy.deepcopy(m_eager)
1293            m_fx = prepare_fx(
1294                m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
1295            )
1296            m_fx(*example_inputs)
1297            m_fx = _convert_to_reference_decomposed_fx(
1298                m_fx, backend_config=backend_config
1299            )
1300            m_fx = capture_pre_autograd_graph(
1301                m_fx,
1302                example_inputs,
1303                dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
1304            )
1305            node_occurrence = {}
1306            for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items():
1307                if k in expected_node_occurrence:
1308                    node_occurrence[ns.call_function(v)] = expected_node_occurrence[k]
1309            if capture_pre_autograd_graph_node_occurrence is not None:
1310                node_occurrence = {
1311                    ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items()
1312                }
1313            self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence)
1314            fx_quant_output = m_fx(*example_inputs)
1315            self.assertEqual(fx_quant_output, pt2_quant_output)
1316        return m
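
    # A minimal usage sketch (hedged; the model, quantizer, and expected node
    # count below are illustrative rather than copied from a real test):
    #
    #     quantizer = XNNPACKQuantizer()
    #     quantizer.set_global(get_symmetric_quantization_config())
    #     self._test_quantizer(
    #         TwoLayerLinearModel(),
    #         TwoLayerLinearModel().get_example_inputs(),
    #         quantizer,
    #         expected_node_occurrence={
    #             torch.ops.quantized_decomposed.quantize_per_tensor: 3,
    #         },
    #     )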
1317
1318    def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False):
1319        # resetting dynamo cache
1320        torch._dynamo.reset()
1321
1322        m = capture_pre_autograd_graph(
1323            m,
1324            example_inputs,
1325        )
1326        if is_qat:
1327            m = prepare_qat_pt2e(m, quantizer)
1328        else:
1329            m = prepare_pt2e(m, quantizer)
1330        m(*example_inputs)
1331        m = convert_pt2e(m)
1332        return m
1333
1334    def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
1335        class M(torch.nn.Module):
1336            def __init__(self) -> None:
1337                super().__init__()
1338                self.linear = torch.nn.Linear(2, 2)
1339
1340            def forward(self, x):
1341                return self.linear(x)
1342
1343        quantizer = XNNPACKQuantizer()
1344        operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel)
1345        quantizer.set_global(operator_config)
1346        example_inputs = (torch.randn(2, 2),)
1347        m = M().eval()
1348        return self._quantize(m, quantizer, example_inputs)
1349
1350# Below are a series of toy models to use in testing quantization
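# Most of them expose a get_example_inputs() method that returns a tuple of
# tensors suitable for tracing, calibration, or export.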
1351
1352class SingleLayerLinearModel(torch.nn.Module):
1353    def __init__(self) -> None:
1354        super().__init__()
1355        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1356
1357    def forward(self, x):
1358        x = self.fc1(x)
1359        return x
1360
1361    def get_example_inputs(self) -> Tuple[Any, ...]:
1362        return (torch.rand(1, 5),)
1363
1364class AnnotatedSingleLayerLinearModel(torch.nn.Module):
1365    def __init__(self, qengine='fbgemm'):
1366        super().__init__()
1367        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1368        self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
1369
1370    def forward(self, x):
1371        x = self.fc1(x)
1372        return x
1373
1374    def get_example_inputs(self) -> Tuple[Any, ...]:
1375        return (torch.rand(1, 5),)
1376
1377class SingleLayerLinearDynamicModel(torch.nn.Module):
1378    def __init__(self, qengine='fbgemm'):
1379        super().__init__()
1380        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1381        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1382
1383    def forward(self, x):
1384        x = self.fc1(x)
1385        return x
1386
1387    def get_example_inputs(self) -> Tuple[Any, ...]:
1388        return (torch.rand(1, 5),)
1389
1390class LinearAddModel(nn.Module):
1391    def __init__(self) -> None:
1392        super().__init__()
1393        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1394        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
1395
1396    def forward(self, x):
1397        x = self.fc1(x)
1398        x = torch.add(x, 5)
1399        x = self.fc2(x)
1400        return x
1401
1402    def get_example_inputs(self) -> Tuple[Any, ...]:
1403        return (torch.rand(1, 5),)
1404
1405class RNNDynamicModel(torch.nn.Module):
1406    def __init__(self, mod_type):
1407        super().__init__()
1408        self.qconfig = default_dynamic_qconfig
1409        if mod_type == 'GRU':
1410            self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float)
1411        if mod_type == 'LSTM':
1412            self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float)
1413
1414    def forward(self, x):
1415        x = self.mod(x)
1416        return x
1417
1418class RNNCellDynamicModel(torch.nn.Module):
1419    def __init__(self, mod_type):
1420        super().__init__()
1421        self.qconfig = default_dynamic_qconfig
1422        if mod_type == 'GRUCell':
1423            self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float)
1424        if mod_type == 'LSTMCell':
1425            self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float)
1426        if mod_type == 'RNNReLU':
1427            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float)
1428        if mod_type == 'RNNTanh':
1429            self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float)
1430
1431    def forward(self, x):
1432        x = self.mod(x)
1433        return x
1434
1435class LSTMwithHiddenDynamicModel(torch.nn.Module):
1436    def __init__(self, qengine='fbgemm'):
1437        super().__init__()
1438        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1439        self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)
1440
1441    def forward(self, x, hid):
1442        x, hid = self.lstm(x, hid)
1443        return x, hid
1444
1445class ConvModel(torch.nn.Module):
1446    def __init__(self) -> None:
1447        super().__init__()
1448        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1449
1450    def forward(self, x):
1451        x = self.conv(x)
1452        return x
1453
1454    def get_example_inputs(self) -> Tuple[Any, ...]:
1455        return (torch.rand(1, 3, 5, 5),)
1456
1457class ConvTransposeModel(torch.nn.Module):
1458    def __init__(self) -> None:
1459        super().__init__()
1460        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
1461
1462    def forward(self, x):
1463        x = self.conv(x)
1464        return x
1465
1466    def get_example_inputs(self) -> Tuple[Any, ...]:
1467        return (torch.rand(1, 3, 5, 5),)
1468
1469class AnnotatedConvModel(torch.nn.Module):
1470    def __init__(self, qengine):
1471        super().__init__()
1472        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1473        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1474        self.quant = QuantStub()
1475        self.dequant = DeQuantStub()
1476
1477    def forward(self, x):
1478        x = self.quant(x)
1479        x = self.conv(x)
1480        x = self.dequant(x)
1481        return x
1482
1483    def get_example_inputs(self) -> Tuple[Any, ...]:
1484        return (torch.rand(1, 3, 5, 5),)
1485
1486class AnnotatedConvTransposeModel(torch.nn.Module):
1487    def __init__(self, qengine):
1488        super().__init__()
1489        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1490        self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
1491        self.quant = QuantStub()
1492        self.dequant = DeQuantStub()
1493
1494    def forward(self, x):
1495        x = self.quant(x)
1496        x = self.conv(x)
1497        x = self.dequant(x)
1498        return x
1499
1500    def get_example_inputs(self) -> Tuple[Any, ...]:
1501        return (torch.rand(1, 3, 5, 5),)
1502
1503class ConvBnModel(torch.nn.Module):
1504    def __init__(self) -> None:
1505        super().__init__()
1506        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1507        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
1508
1509    def forward(self, x):
1510        x = self.conv(x)
1511        x = self.bn(x)
1512        return x
1513
1514    def get_example_inputs(self) -> Tuple[Any, ...]:
1515        return (torch.rand(1, 3, 5, 5),)
1516
1517class AnnotatedConvBnModel(torch.nn.Module):
1518    def __init__(self) -> None:
1519        super().__init__()
1520        self.qconfig = default_qconfig
1521        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1522        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
1523        self.quant = QuantStub()
1524        self.dequant = DeQuantStub()
1525
1526    def forward(self, x):
1527        x = self.quant(x)
1528        x = self.conv(x)
1529        x = self.bn(x)
1530        x = self.dequant(x)
1531        return x
1532
1533    def get_example_inputs(self) -> Tuple[Any, ...]:
1534        return (torch.rand(1, 3, 5, 5),)
1535
1536class ConvBnReLUModel(torch.nn.Module):
1537    def __init__(self) -> None:
1538        super().__init__()
1539        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1540        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
1541        self.relu = nn.ReLU(inplace=True)
1542
1543    def forward(self, x):
1544        x = self.conv(x)
1545        x = self.bn(x)
1546        x = self.relu(x)
1547        return x
1548
1549    def get_example_inputs(self) -> Tuple[Any, ...]:
1550        return (torch.rand(1, 3, 5, 5),)
1551
1552class AnnotatedConvBnReLUModel(torch.nn.Module):
1553    def __init__(self, qengine='fbgemm'):
1554        super().__init__()
1555        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
1556        self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1557        self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float)
1558        self.relu = nn.ReLU(inplace=True)
1559        self.quant = QuantStub()
1560        self.dequant = DeQuantStub()
1561
1562    def forward(self, x):
1563        x = self.quant(x)
1564        x = self.conv(x)
1565        x = self.bn(x)
1566        x = self.relu(x)
1567        x = self.dequant(x)
1568        return x
1569
1570    def fuse_model(self):
1571        # TODO: remove this check and define two fuse_modules functions on this module
1572        if self.training:
1573            torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True)
1574        else:
1575            torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)
1576
1577    def get_example_inputs(self) -> Tuple[Any, ...]:
1578        return (torch.rand(1, 3, 5, 5),)
1579
1580class TwoLayerConvModel(torch.nn.Module):
1581    def __init__(self) -> None:
1582        super().__init__()
1583        self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float)
1584        self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float)
1585
1586    def forward(self, x):
1587        x = self.conv1(x)
1588        x = self.conv2(x)
1589        return x
1590
1591    def get_example_inputs(self) -> Tuple[Any, ...]:
1592        return (torch.rand(1, 3, 5, 5),)
1593
1594class TwoLayerLinearModel(torch.nn.Module):
1595    def __init__(self) -> None:
1596        super().__init__()
1597        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1598        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
1599
1600    def forward(self, x):
1601        x = self.fc1(x)
1602        x = self.fc2(x)
1603        return x
1604
1605    def get_example_inputs(self) -> Tuple[Any, ...]:
1606        return (torch.rand(1, 5),)
1607
1608class LinearModelWithSubmodule(nn.Module):
1609    def __init__(self) -> None:
1610        super().__init__()
1611        self.subm = TwoLayerLinearModel()
1612        self.fc = nn.Linear(5, 5)
1613
1614    def forward(self, x):
1615        x = self.subm(x)
1616        x = self.fc(x)
1617        return x
1618
1619    def get_example_inputs(self) -> Tuple[Any, ...]:
1620        return self.subm.get_example_inputs()
1621
1622class AnnotatedTwoLayerLinearModel(torch.nn.Module):
1623    def __init__(self) -> None:
1624        super().__init__()
1625        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1626        self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float))
1627        self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
1628
1629    def forward(self, x):
1630        x = self.fc1(x)
1631        x = self.fc2(x)
1632        return x
1633
1634    def get_example_inputs(self) -> Tuple[Any, ...]:
1635        return (torch.rand(1, 5),)
1636
1637class ActivationsTestModel(torch.nn.Module):
1638    def __init__(self) -> None:
1639        super().__init__()
1640        self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
1641        self.quant = torch.ao.quantization.QuantStub()
1642        self.hardswish = torch.nn.Hardswish().to(dtype=torch.float)
1643        self.elu = torch.nn.ELU().to(dtype=torch.float)
1644        self.dequant = torch.ao.quantization.DeQuantStub()
1645
1646    def forward(self, x):
1647        x = self.quant(x)
1648        x = self.hardswish(x)
1649        x = self.elu(x)
1650        x = self.dequant(x)
1651        return x
1652
1653class LinearReluModel(torch.nn.Module):
1654    def __init__(self) -> None:
1655        super().__init__()
1656        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
1657        self.relu = torch.nn.ReLU()
1658
1659    def forward(self, x):
1660        x = self.relu(self.fc(x))
1661        return x
1662
1663    def get_example_inputs(self) -> Tuple[Any, ...]:
1664        return (torch.rand(1, 5),)
1665
1666
1667class LinearReluLinearModel(torch.nn.Module):
1668    def __init__(self) -> None:
1669        super().__init__()
1670        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1671        self.relu = torch.nn.ReLU()
1672        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
1673
1674    def forward(self, x):
1675        x = self.fc1(x)
1676        x = self.relu(x)
1677        x = self.fc2(x)
1678        return x
1679
1680    def get_example_inputs(self) -> Tuple[Any, ...]:
1681        return (torch.rand(1, 5),)
1682
1683class LinearReluAddModel(torch.nn.Module):
1684    def __init__(self) -> None:
1685        super().__init__()
1686        self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1687        self.relu = torch.nn.ReLU()
1688        self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1689
1690    def forward(self, x):
1691        x = self.fc1(x)
1692        x = self.relu(x)
1693        x = torch.add(x, 5)
1694        x = self.fc2(x)
1695        self.relu = torch.nn.ReLU()
1696        return x
1697
1698    def get_example_inputs(self) -> Tuple[Any, ...]:
1699        return (torch.rand(1, 5),)
1700
1701class LinearBnLeakyReluModel(torch.nn.Module):
1702    def __init__(self, with_bn=True):
1703        super().__init__()
1704        self.linear = nn.Linear(5, 5)
1705        self.bn1d = nn.BatchNorm1d(5)
1706        self.leaky_relu = nn.LeakyReLU(0.01)
1707        self.with_bn = with_bn
1708
1709    def forward(self, x):
1710        x = self.linear(x)
1711        if self.with_bn:
1712            x = self.bn1d(x)
1713        x = self.leaky_relu(x)
1714        return x
1715
1716    def get_example_inputs(self) -> Tuple[Any, ...]:
1717        return (torch.rand(1, 5),)
1718
1719class LinearTanhModel(torch.nn.Module):
1720    def __init__(self) -> None:
1721        super().__init__()
1722        self.linear = nn.Linear(5, 5)
1723        self.tanh = nn.Tanh()
1724
1725    def forward(self, x):
1726        x = self.linear(x)
1727        x = self.tanh(x)
1728        return x
1729
1730    def get_example_inputs(self) -> Tuple[Any, ...]:
1731        return (torch.rand(1, 5),)
1732
1733class ConvBnAddReluModel(torch.nn.Module):
1734    def __init__(self,
1735                 with_bn=True,
1736                 with_relu=True,
1737                 left_conv=True,
1738                 two_conv=True,
1739                 use_torch_add=True):
1740        super().__init__()
1741        self.conv = nn.Conv2d(5, 5, (2, 2))
1742        self.conv2 = nn.Conv2d(5, 5, (2, 2))
1743        self.bn = nn.BatchNorm2d(5)
1744        self.relu = nn.ReLU()
1745        self.with_bn = with_bn
1746        self.with_relu = with_relu
1747        self.two_conv = two_conv
1748        self.left_conv = left_conv
1749        self.use_torch_add = use_torch_add
1750
1751    def forward(self, x1, x2):
1752        if self.two_conv:
1753            if self.use_torch_add:
1754                if self.with_bn:
1755                    x = torch.add(self.bn(self.conv(x1)), self.conv2(x1))
1756                else:
1757                    x = torch.add(self.conv(x1), self.conv2(x1))
1758            else:
1759                if self.with_bn:
1760                    x = self.bn(self.conv(x1)) + self.conv2(x1)
1761                else:
1762                    x = self.conv(x1) + self.conv2(x1)
1763        else:
1764            if self.use_torch_add:
1765                if self.left_conv:
1766                    if self.with_bn:
1767                        x = torch.add(self.bn(self.conv(x1)), x2)
1768                    else:
1769                        x = torch.add(self.conv(x1), x2)
1770                else:
1771                    if self.with_bn:
1772                        x = torch.add(x2, self.bn(self.conv(x1)))
1773                    else:
1774                        x = torch.add(x2, self.conv(x1))
1775            else:
1776                if self.left_conv:
1777                    if self.with_bn:
1778                        x = self.bn(self.conv(x1)) + x2
1779                    else:
1780                        x = self.conv(x1) + x2
1781                else:
1782                    if self.with_bn:
1783                        x = x2 + self.bn(self.conv(x1))
1784                    else:
1785                        x = x2 + self.conv(x1)
1786        if self.with_relu:
1787            x = self.relu(x)
1788        return x
1789
1790    def get_example_inputs(self) -> Tuple[Any, ...]:
1791        return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2))
1792
1793# TODO: self.fc should be self.conv
1794class ConvReluModel(torch.nn.Module):
1795    def __init__(self) -> None:
1796        super().__init__()
1797        self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
1798        self.relu = torch.nn.ReLU()
1799
1800    def forward(self, x):
1801        x = self.relu(self.fc(x))
1802        return x
1803
1804    def get_example_inputs(self) -> Tuple[Any, ...]:
1805        return (torch.rand(1, 3, 5, 5),)
1806
1807# TODO: self.fc should be self.conv
1808class ConvReluConvModel(torch.nn.Module):
1809    def __init__(self) -> None:
1810        super().__init__()
1811        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
1812        self.relu = torch.nn.ReLU()
1813        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
1814
1815    def forward(self, x):
1816        x = self.fc1(x)
1817        x = self.relu(x)
1818        x = self.fc2(x)
1819        return x
1820
1821    def get_example_inputs(self) -> Tuple[Any, ...]:
1822        return (torch.rand(1, 3, 5, 5),)
1823
1824# TODO: self.fc should be self.conv
1825class ConvReluAddModel(torch.nn.Module):
1826    def __init__(self) -> None:
1827        super().__init__()
1828        self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float)
1829        self.relu = torch.nn.ReLU()
1830        self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float)
1831
1832    def forward(self, x):
1833        x = self.fc1(x)
1834        x = self.relu(x)
1835        x = torch.add(x, 5)
1836        x = self.fc2(x)
1837        self.relu = torch.nn.ReLU()
1838        return x
1839
1840    def get_example_inputs(self) -> Tuple[Any, ...]:
1841        return (torch.rand(1, 3, 5, 5),)
1842
1843class NormalizationTestModel(torch.nn.Module):
1844    def __init__(self) -> None:
1845        super().__init__()
1846        self.quant = torch.ao.quantization.QuantStub()
1847        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1848        self.layer_norm = torch.nn.LayerNorm(8)
1849        self.group_norm = torch.nn.GroupNorm(2, 8)
1850        self.instance_norm1d = torch.nn.InstanceNorm1d(8)
1851        self.instance_norm2d = torch.nn.InstanceNorm2d(8)
1852        self.instance_norm3d = torch.nn.InstanceNorm3d(8)
1853
1854    def forward(self, x):
1855        x = self.quant(x)
1856        x = self.fc1(x)
1857        x = self.layer_norm(x)
1858        x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3))
1859        x = self.instance_norm1d(x)
1860        x = self.instance_norm2d(x.unsqueeze(-1))
1861        x = self.instance_norm3d(x.unsqueeze(-1))
1862        return x
1863
1864class NestedModel(torch.nn.Module):
1865    def __init__(self) -> None:
1866        super().__init__()
1867        self.sub1 = LinearReluModel()
1868        self.sub2 = TwoLayerLinearModel()
1869        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1870
1871    def forward(self, x):
1872        x = self.sub1(x)
1873        x = self.sub2(x)
1874        x = self.fc3(x)
1875        return x
1876
1877class AnnotatedNestedModel(torch.nn.Module):
1878    def __init__(self, qengine):
1879        super().__init__()
1880        self.sub1 = LinearReluModel()
1881        self.sub2 = TwoLayerLinearModel()
1882        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
1883        self.fc3.qconfig = default_qconfig
1884        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
1885        if qengine == 'fbgemm':
1886            self.sub2.fc1.qconfig = default_per_channel_qconfig
1887        else:
1888            self.sub2.fc1.qconfig = default_qconfig
1889
1890    def forward(self, x):
1891        x = self.sub1(x)
1892        x = self.sub2(x)
1893        x = self.fc3(x)
1894        return x
1895
1896class AnnotatedSubNestedModel(torch.nn.Module):
1897    def __init__(self) -> None:
1898        super().__init__()
1899        self.sub1 = LinearReluModel()
1900        self.sub2 = QuantWrapper(TwoLayerLinearModel())
1901        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
1902        self.fc3.qconfig = default_qconfig
1903        self.sub2.qconfig = default_qconfig
1904
1905    def forward(self, x):
1906        x = self.sub1(x)
1907        x = self.sub2(x)
1908        x = self.fc3(x)
1909        return x
1910
1911class AnnotatedCustomConfigNestedModel(torch.nn.Module):
1912    def __init__(self) -> None:
1913        super().__init__()
1914        self.sub1 = LinearReluModel()
1915        self.sub2 = TwoLayerLinearModel()
1916        self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float))
1917        self.fc3.qconfig = default_qconfig
1918        self.sub2.qconfig = default_qconfig
1919
1920        custom_options = {
1921            'dtype': torch.quint8,
1922            'qscheme': torch.per_tensor_affine
1923        }
1924        custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options),
1925                                 weight=default_weight_observer)
1926        self.sub2.fc1.qconfig = custom_qconfig
1927
1928        self.sub2.fc1 = QuantWrapper(self.sub2.fc1)
1929        self.sub2.fc2 = QuantWrapper(self.sub2.fc2)
1930
1931    def forward(self, x):
1932        x = self.sub1(x)
1933        x = self.sub2(x)
1934        x = self.fc3(x)
1935        return x
1936
1937class QuantSubModel(torch.nn.Module):
1938    def __init__(self) -> None:
1939        super().__init__()
1940        self.sub1 = LinearReluModel()
1941        self.sub2 = QuantWrapper(TwoLayerLinearModel())
1942        self.sub2.qconfig = default_qconfig
1943        self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float)
1944        self.fc3.qconfig = default_qconfig
1945
1946    def forward(self, x):
1947        x = self.sub1(x)
1948        x = self.sub2(x)
1949        x = self.fc3(x)
1950        return x
1951
1952class InnerModule(torch.nn.Module):
1953    def __init__(self) -> None:
1954        super().__init__()
1955        self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float)
1956        self.relu1 = torch.nn.ReLU()
1957        self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float)
1958        self.relu2 = torch.nn.ReLU()
1959
1960    def forward(self, x):
1961        return self.relu2(self.fc2(self.relu1(self.fc1(x))))
1962
1963    def fuse_modules(self):
1964        fusable_layers = []
1965        named_children = list(self.named_children())
1966        for idx, (current_name, layer) in enumerate(named_children):
1967            if isinstance(layer, torch.nn.Linear):
1968                if idx >= len(named_children) - 1:
1969                    break
1970                if isinstance(named_children[idx + 1][1], torch.nn.ReLU):
1971                    fusable_layers.append([current_name,
1972                                           named_children[idx + 1][0]])
1973        # TODO: remove this check and define two fuse_modules functions on this module
1974        if self.training:
1975            torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True)
1976        else:
1977            torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True)
1978
1979class FunctionalLinear(torch.nn.Module):
1980    def __init__(self) -> None:
1981        super().__init__()
1982        self.weight = torch.rand((5, 5))
1983        self.bias = torch.zeros(5)
1984
1985    def forward(self, x):
1986        return F.linear(x, self.weight, self.bias)
1987
1988    def get_example_inputs(self) -> Tuple[Any, ...]:
1989        return (torch.rand(1, 5),)
1990
1991class SingleLayerFunctionalLinearModel(torch.nn.Module):
1992    def __init__(self) -> None:
1993        super().__init__()
1994        self.linear1 = FunctionalLinear()
1995
1996    def forward(self, x):
1997        x = self.linear1(x)
1998        return x
1999
2000    def get_example_inputs(self) -> Tuple[Any, ...]:
2001        return self.linear1.get_example_inputs()
2002
2003class TwoLayerFunctionalLinearModel(torch.nn.Module):
2004    def __init__(self) -> None:
2005        super().__init__()
2006        self.linear1 = FunctionalLinear()
2007        self.linear2 = FunctionalLinear()
2008
2009    def forward(self, x):
2010        x = self.linear1(x)
2011        x = self.linear2(x)
2012        return x
2013
2014    def get_example_inputs(self) -> Tuple[Any, ...]:
2015        return self.linear1.get_example_inputs()
2016
2017class FunctionalLinearAddModel(torch.nn.Module):
2018    def __init__(self) -> None:
2019        super().__init__()
2020        self.linear1 = FunctionalLinear()
2021        self.linear2 = FunctionalLinear()
2022
2023    def forward(self, x):
2024        x = self.linear1(x)
2025        x = torch.add(x, 5)
2026        x = self.linear2(x)
2027        return x
2028
2029    def get_example_inputs(self) -> Tuple[Any, ...]:
2030        return self.linear1.get_example_inputs()
2031
2032class FunctionalLinearReluModel(nn.Module):
2033    def __init__(self) -> None:
2034        super().__init__()
2035        self.linear = FunctionalLinear()
2036
2037    def forward(self, x):
2038        x = self.linear(x)
2039        x = F.relu(x)
2040        return x
2041
2042    def get_example_inputs(self) -> Tuple[Any, ...]:
2043        return self.linear.get_example_inputs()
2044
2045class FunctionalLinearReluLinearModel(nn.Module):
2046    def __init__(self) -> None:
2047        super().__init__()
2048        self.linear1 = FunctionalLinear()
2049        self.relu = nn.ReLU()
2050        self.linear2 = FunctionalLinear()
2051
2052    def forward(self, x):
2053        x = self.linear1(x)
2054        x = self.relu(x)
2055        x = self.linear2(x)
2056        return x
2057
2058    def get_example_inputs(self) -> Tuple[Any, ...]:
2059        return self.linear1.get_example_inputs()
2060
2061class FunctionalConv2d(torch.nn.Module):
2062    def __init__(self) -> None:
2063        super().__init__()
2064        self.weight = torch.rand(3, 3, 3, 3)
2065        self.bias = torch.rand(3)
2066        self.stride = (1, 1)
2067        self.padding = (0, 0)
2068        self.dilation = (1, 1)
2069        self.groups = 1
2070
2071    def forward(self, x):
2072        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
2073
2074    def get_example_inputs(self) -> Tuple[Any, ...]:
2075        return (torch.rand(1, 3, 5, 5),)
2076
2077class SingleLayerFunctionalConvModel(torch.nn.Module):
2078    def __init__(self) -> None:
2079        super().__init__()
2080        self.conv1 = FunctionalConv2d()
2081
2082    def forward(self, x):
2083        x = self.conv1(x)
2084        return x
2085
2086    def get_example_inputs(self) -> Tuple[Any, ...]:
2087        return self.conv1.get_example_inputs()
2088
2089class TwoLayerFunctionalConvModel(torch.nn.Module):
2090    def __init__(self) -> None:
2091        super().__init__()
2092        self.conv1 = FunctionalConv2d()
2093        self.conv2 = FunctionalConv2d()
2094
2095    def forward(self, x):
2096        x = self.conv1(x)
2097        x = self.conv2(x)
2098        return x
2099
2100    def get_example_inputs(self) -> Tuple[Any, ...]:
2101        return self.conv1.get_example_inputs()
2102
2103class FunctionalConvReluModel(nn.Module):
2104    def __init__(self) -> None:
2105        super().__init__()
2106        self.conv = FunctionalConv2d()
2107
2108    def forward(self, x):
2109        x = self.conv(x)
2110        x = F.relu(x)
2111        return x
2112
2113    def get_example_inputs(self) -> Tuple[Any, ...]:
2114        return self.conv.get_example_inputs()
2115
2116class FunctionalConvReluConvModel(nn.Module):
2117    def __init__(self) -> None:
2118        super().__init__()
2119        self.conv1 = FunctionalConv2d()
2120        self.relu = nn.ReLU()
2121        self.conv2 = FunctionalConv2d()
2122
2123    def forward(self, x):
2124        x = self.conv1(x)
2125        x = self.relu(x)
2126        x = self.conv2(x)
2127        return x
2128
2129    def get_example_inputs(self) -> Tuple[Any, ...]:
2130        return self.conv1.get_example_inputs()
2131
2132class SkipQuantModel(torch.nn.Module):
2133    r"""We can skip quantization by explicitly
2134    setting the qconfig of a submodule to None
2135    """
2136    def __init__(self) -> None:
2137        super().__init__()
2138        self.sub = InnerModule()
2139        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
2140
2141    def forward(self, x):
2142        return self.fc(self.sub(x))
2143
2144    def fuse_modules(self):
2145        self.sub.fuse_modules()
2146
2147class AnnotatedSkipQuantModel(torch.nn.Module):
2148    r"""We can skip quantization by explicitly
2149    setting the qconfig of a submodule to None
2150    """
2151    def __init__(self, qengine):
2152        super().__init__()
2153        self.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
2154        self.sub = QuantWrapper(InnerModule())
2155        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
2156        # don't quantize this fc
2157        self.fc.qconfig = None
2158
2159    def forward(self, x):
2160        return self.fc(self.sub(x))
2161
2162    def fuse_modules(self):
2163        self.sub.module.fuse_modules()
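
# A hedged sketch of how the qconfig=None skip above plays out in eager-mode
# post-training quantization (calibration data below is illustrative):
#
#     m = AnnotatedSkipQuantModel('qnnpack').eval()
#     m.fuse_modules()
#     torch.ao.quantization.prepare(m, inplace=True)
#     m(torch.rand(1, 5))  # calibrate
#     torch.ao.quantization.convert(m, inplace=True)
#     # m.sub is quantized, while m.fc stays a float nn.Linear because its
#     # qconfig is None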
2164
2165class QuantStubModel(torch.nn.Module):
2166    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
2167    """
2168    def __init__(self) -> None:
2169        super().__init__()
2170        self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
2171        self.quant = QuantStub()
2172        self.dequant = DeQuantStub()
2173        self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
2174
2175    def forward(self, x):
2176        x = self.quant(x)
2177        x = self.fc(x)
2178        return self.dequant(x)
2179
2180class ManualLinearQATModel(torch.nn.Module):
2181    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
2182    """
2183    def __init__(self, qengine):
2184        super().__init__()
2185        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
2186        self.quant = QuantStub()
2187        self.dequant = DeQuantStub()
2188        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
2189        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
2190
2191    def forward(self, x):
2192        x = self.quant(x)
2193        x = self.fc1(x)
2194        x = self.fc2(x)
2195        return self.dequant(x)
2196
2197class ManualDropoutQATModel(torch.nn.Module):
2198    r"""A Module with manually inserted `QuantStub` and `DeQuantStub`
2199    """
2200    def __init__(self, qengine):
2201        super().__init__()
2202        self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
2203        self.quant = QuantStub()
2204        self.dequant = DeQuantStub()
2205        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
2206        self.dropout = torch.nn.Dropout(0.5)
2207
2208    def forward(self, x):
2209        x = self.quant(x)
2210        x = self.fc1(x)
2211        x = self.dropout(x)
2212        return self.dequant(x)
2213
2214class ManualLinearDynamicQATModel(torch.nn.Module):
2215    r"""A Module that uses a dynamic QAT qconfig by default.
2216    """
2217    def __init__(self, qconfig=None):
2218        super().__init__()
2219        self.qconfig = qconfig or default_dynamic_qat_qconfig
2220        self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float)
2221        self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float)
2222
2223    def forward(self, x):
2224        x = self.fc1(x)
2225        x = self.fc2(x)
2226        return x
2227
2228class ManualConvLinearQATModel(torch.nn.Module):
2229    r"""A module with manually inserted `QuantStub` and `DeQuantStub`
2230    that contains both linear and conv modules
2231    """
2232    def __init__(self, qconfig=None):
2233        super().__init__()
2234        self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack")
2235        self.quant = QuantStub()
2236        self.dequant = DeQuantStub()
2237        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float)
2238        self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float)
2239        self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float)
2240
2241    def forward(self, x):
2242        x = self.quant(x)
2243        x = self.conv(x)
2244        x = x.view(-1, 64).contiguous()
2245        x = self.fc1(x)
2246        x = self.fc2(x)
2247        return self.dequant(x)
2248
2249class ManualConvLinearSymmQATModel(ManualConvLinearQATModel):
2250    r"""Same as ManualConvLinearQATModel but with symmetric quantization.
2251    Supported only with qnnpack.
2252    """
2253    def __init__(self) -> None:
2254        super().__init__(default_symmetric_qnnpack_qat_qconfig)
2255
2256class ManualEmbeddingBagLinear(nn.Module):
2257    def __init__(self) -> None:
2258        super().__init__()
2259        self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')
2260        self.emb.qconfig = default_embedding_qat_qconfig
2261        self.quant = QuantStub()
2262        self.dequant = DeQuantStub()
2263        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
2264        self.qconfig = get_default_qat_qconfig("qnnpack")
2265
2266    def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None,
2267                per_sample_weights: Optional[torch.Tensor] = None):
2268        x = self.emb(input, offsets, per_sample_weights)
2269        x = self.quant(x)
2270        x = self.linear(x)
2271        return self.dequant(x)
2272
2273class DeFusedEmbeddingBagLinear(nn.Module):
2274    r"""A module that simulates a QAT embedding bag followed by a linear layer.
2275    This module uses a separate embedding and a bagging op, similar to
2276    what is described in the EmbeddingBag documentation.
2277
2278    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html
2279    """
2280    def __init__(self) -> None:
2281        super().__init__()
2282        self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12)
2283        self.emb.qconfig = default_embedding_qat_qconfig
2284        self.bagging_op = torch.sum
2285        self.quant = QuantStub()
2286        self.dequant = DeQuantStub()
2287        self.linear = nn.Linear(12, 1).to(dtype=torch.float)
2288        self.qconfig = get_default_qat_qconfig("qnnpack")
2289
2290    def forward(self, input: torch.Tensor) -> torch.Tensor:
2291        x = self.bagging_op(self.emb(input), dim=1)
2292        x = self.quant(x)
2293        x = self.linear(x)
2294        return self.dequant(x)
2295
2296class SubModelForFusion(nn.Module):
2297    def __init__(self) -> None:
2298        super().__init__()
2299        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
2300        self.bn = nn.BatchNorm2d(2).to(dtype=torch.float)
2301
2302    def forward(self, x):
2303        x = self.conv(x)
2304        x = self.bn(x)
2305        return x
2306
2307
2308class SubModelWithoutFusion(nn.Module):
2309    def __init__(self) -> None:
2310        super().__init__()
2311        self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float)
2312        self.relu = nn.ReLU(inplace=False).to(dtype=torch.float)
2313
2314    def forward(self, x):
2315        return self.relu(self.conv(x))
2316
2317class ModelForFusion(nn.Module):
2318    def __init__(self, qconfig):
2319        super().__init__()
2320        self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float)
2321        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
2322        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
2323        self.sub1 = SubModelForFusion()
2324        self.sub2 = SubModelWithoutFusion()
2325        self.fc = nn.Linear(36, 10).to(dtype=torch.float)
2326        self.quant = QuantStub()
2327        self.dequant = DeQuantStub()
2328        self.qconfig = qconfig
2329        self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float)
2330        self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
2331        self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float)
2332        self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float)
2333        self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float)
2334        self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float)
2335        self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float)
2336        # don't quantize sub2
2337        self.sub2.qconfig = None
2338        self.fc.qconfig = None
2339
2340    def forward(self, x):
2341        x = x.squeeze(2)
2342        x = self.quant(x)
2343        x = self.conv3(x)
2344        x = self.bn3(x)
2345        x = self.relu4(x)
2346        x = x.unsqueeze(2)
2347        y = x.unsqueeze(2)
2348        x = self.conv1(x)
2349        x = self.bn1(x)
2350        x = self.relu1(x)
2351        x = self.sub1(x)
2352        x = self.dequant(x)
2353        x = self.sub2(x)
2354        x = x.reshape(-1, 36).contiguous()
2355        x = self.fc(x)
2356        y = self.conv2(y)
2357        y = self.relu2(y)
2358        y = self.bn2(y)
2359        y = self.relu3(y)
2360        y = self.dequant(y)
2361        return x
2362
2363class ConvBNReLU(nn.Sequential):
2364    def __init__(self) -> None:
2365        super().__init__(
2366            nn.Conv2d(3, 3, 1, 1, bias=False),
2367            nn.BatchNorm2d(3),
2368            nn.ReLU(inplace=False)
2369        )
2370
2371class ModelWithSequentialFusion(nn.Module):
2372    def __init__(self) -> None:
2373        super().__init__()
2374        self.conv1 = nn.Conv2d(3, 3, 1)
2375        self.relu1 = nn.ReLU(inplace=False)
2376        layers = []
2377        for i in range(3):
2378            layers.append(ConvBNReLU())
2379        self.features = nn.Sequential(*layers)
2380        head = [nn.Linear(300, 10), nn.ReLU(inplace=False)]
2381        self.classifier = nn.Sequential(*head)
2382        self.seq = nn.Sequential()
2383        self.quant = QuantStub()
2384        self.dequant = DeQuantStub()
2385
2386    def forward(self, x):
2387        x = self.quant(x)
2388        x = self.conv1(x)
2389        x = self.relu1(x)
2390        x = self.features(x)
2391        x = torch.reshape(x, (-1, 3 * 10 * 10))
2392        x = self.classifier(x)
2393        x = self.seq(x)
2394        x = self.dequant(x)
2395        return x
2396
2397class ModelForFusionWithBias(nn.Module):
2398    def __init__(self) -> None:
2399        super().__init__()
2400        self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float)
2401        self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float)
2402        self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float)
2403        self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float)
2404        self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float)
2405        self.quant = QuantStub()
2406        self.dequant = DeQuantStub()
2407
2408    def forward(self, x):
2409        x = self.quant(x)
2410        x = self.conv1(x)
2411        x = self.bn1(x)
2412        x = self.relu1(x)
2413        x = self.conv2(x)
2414        x = self.bn2(x)
2415        x = self.dequant(x)
2416        return x
2417
2418class ModelForLinearBNFusion(nn.Module):
2419    def __init__(self) -> None:
2420        super().__init__()
2421        self.fc = nn.Linear(20, 10)
2422        self.bn = nn.BatchNorm1d(10)
2423        nn.init.uniform_(self.bn.weight)
2424        nn.init.uniform_(self.bn.bias)
2425
2426    def forward(self, x):
2427        return self.bn(self.fc(x))
2428
2429class DummyObserver(torch.nn.Module):
2430    def calculate_qparams(self):
2431        return 1.0, 0
2432
2433    def forward(self, x):
2434        return x
2435
2436
2437class ModelForConvTransposeBNFusion(nn.Module):
2438    def __init__(self) -> None:
2439        super().__init__()
2440        self.conv1 = nn.ConvTranspose1d(3, 3, 1)
2441        self.bn1 = nn.BatchNorm1d(3)
2442        self.conv2 = nn.ConvTranspose2d(3, 3, 1)
2443        self.bn2 = nn.BatchNorm2d(3)
2444        self.conv3 = nn.ConvTranspose3d(3, 3, 1)
2445        self.bn3 = nn.BatchNorm3d(3)
2446
2447    def forward(self, x):
2448        x = self.conv1(x)
2449        x = self.bn1(x)
2450        x = x.unsqueeze(2)
2451        x = self.conv2(x)
2452        x = self.bn2(x)
2453        x = x.unsqueeze(2)
2454        x = self.conv3(x)
2455        x = self.bn3(x)
2456        return x
2457
2458
2459class ModelWithFunctionals(torch.nn.Module):
2460    def __init__(self) -> None:
2461        super().__init__()
2462        self.mycat = nnq.FloatFunctional()
2463        self.myadd = nnq.FloatFunctional()
2464        self.myadd_relu = nnq.FloatFunctional()
2465        self.mymatmul = nnq.FloatFunctional()
2466        # Tracing doesn't work yet for c10 ops with scalar inputs
2467        # https://github.com/pytorch/pytorch/issues/27097
2468        # self.my_scalar_add = nnq.FloatFunctional()
2469        # self.my_scalar_mul = nnq.FloatFunctional()
2470
2471    def forward(self, x):
2472        y = self.mycat.cat([x, x, x])
2473        z = self.myadd.add(y, y)
2474        w = self.myadd_relu.add_relu(z, z)
2475        u = self.mymatmul.matmul(w, w.T)
2476        # Tracing doesn't work yet for c10 ops with scalar inputs
2477        # https://github.com/pytorch/pytorch/issues/27097
2478        # w = self.my_scalar_add.add_scalar(w, -0.5)
2479        # w = self.my_scalar_mul.mul_scalar(w, 0.5)
2480        return u
2481
2482
2483class ResNetBase(torch.nn.Module):
2484    def __init__(self) -> None:
2485        super().__init__()
2486        norm_layer = nn.BatchNorm2d
2487        inplanes = 3
2488        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
2489        self.bn1 = norm_layer(inplanes)
2490        self.relu1 = nn.ReLU()
2491        self.relu2 = nn.ReLU()
2492        self.downsample = torch.nn.Identity()
2493        self.myop = nn.quantized.FloatFunctional()
2494        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
2495        self.fc = torch.nn.Linear(inplanes, 1)
2496
2497    def forward(self, x):
2498        out = self.conv1(x)
2499        out = self.bn1(out)
2500        out = self.relu1(out)
2501        identity = self.downsample(x)
2502        out = self.myop.add(out, identity)
2503        out = self.relu2(out)
2504        out = self.avgpool(out)
2505        out = torch.flatten(out, 1)
2506        out = self.fc(out)
2507        return out
2508
2509    def fuse_model(self):
2510        # TODO: remove this check and define two fuse_model functions on this module
2511        if self.training:
2512            torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True)
2513        else:
2514            torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True)
2515
2516class ModelMultipleOps(torch.nn.Module):
2517    def __init__(self) -> None:
2518        super().__init__()
2519        norm_layer = nn.BatchNorm2d
2520        inplanes = 3
2521        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
2522        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
2523        self.bn1 = norm_layer(inplanes)
2524        self.relu1 = nn.ReLU()
2525        self.relu2 = nn.ReLU()
2526        self.downsample = torch.nn.Identity()
2527        self.skip_add = nn.quantized.FloatFunctional()
2528        self.cat = nn.quantized.FloatFunctional()
2529        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))
2530        self.fc = nn.Linear(12, 6)
2531
2532    def forward(self, x):
2533        out = self.conv1(x)
2534        out = self.bn1(out)
2535        out = self.relu1(out)
2536        identity = self.downsample(x)
2537        out = self.skip_add.add(out, identity)
2538        out = self.relu2(out)
2539        out = self.avgpool(out)
2540        out = self.conv2(out)
2541        out = torch.nn.functional.max_pool2d(out, 2, 2)
2542        out = self.cat.cat([out, out])
2543        out = out.reshape(-1, 3 * 2 * 2)
2544        out = self.fc(out)
2545        return out
2546
2547# Model to ensure consistency of fake quant with true quant.
2548# Average pooling and mean operations are not modelled
2549# accurately with fake-quant, so this model does not
2550# contain those operations.
2551class ModelMultipleOpsNoAvgPool(torch.nn.Module):
2552    def __init__(self) -> None:
2553        super().__init__()
2554        norm_layer = nn.BatchNorm2d
2555        inplanes = 3
2556        self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
2557        self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False)
2558        self.bn1 = norm_layer(inplanes)
2559        self.relu1 = nn.ReLU()
2560        self.relu2 = nn.ReLU()
2561        self.skip_add = nn.quantized.FloatFunctional()
2562        self.cat = nn.quantized.FloatFunctional()
2563        self.maxpool = nn.MaxPool2d((4, 4))
2564        self.fc = nn.Linear(12, 6)
2565
2566    def forward(self, x):
2567        out = self.conv1(x)
2568        out = self.bn1(out)
2569        out = self.relu1(out)
2570        skip = self.conv2(x)
2571        out = self.skip_add.add(out, skip)
2572        out = self.relu2(out)
2573        out = self.maxpool(out)
2574        out = self.conv2(out)
2575        out = torch.nn.functional.max_pool2d(out, 2, 2)
2576        out = self.cat.cat([out, out])
2577        out = out.reshape(-1, 3 * 2 * 2)
2578        out = self.fc(out)
2579        return out
2580
2581class EmbeddingBagModule(torch.nn.Module):
2582    def __init__(self) -> None:
2583        super().__init__()
2584        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
2585                                         include_last_offset=True, scale_grad_by_freq=False, mode='sum')
2586
2587    def forward(self, indices, offsets, per_sample_weights):
2588        return self.emb(indices, offsets, per_sample_weights)
2589
2590class EmbeddingModule(torch.nn.Module):
2591    def __init__(self) -> None:
2592        super().__init__()
2593        self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
2594
2595    def forward(self, indices):
2596        return self.emb(indices)
2597
2598class EmbeddingWithStaticLinear(torch.nn.Module):
2599    def __init__(self) -> None:
2600        super().__init__()
2601        self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12)
2602        self.fc = torch.nn.Linear(4, 2)
2603        self.emb.qconfig = float_qparams_weight_only_qconfig
2604        self.qconfig = default_qconfig
2605        self.quant = QuantStub()
2606        self.dequant = DeQuantStub()
2607
2608    def forward(self, indices, offsets, linear_in):
2609        emb = self.emb(indices, offsets)
2610        q_x = self.quant(linear_in)
2611        fc = self.fc(q_x)
2612        fc = self.dequant(fc)
2613        features = torch.cat([fc] + [emb], dim=1)
2614        return features
2615
2616class DenseTopMLP(nn.Module):
2617
2618    def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None:
2619        super().__init__()
2620
2621        self.dense_mlp = nn.Sequential(
2622            nn.Linear(dense_dim, dense_out),
2623        )
2624        self.top_mlp = nn.Sequential(
2625            nn.Linear(dense_out + embedding_dim, top_out_in),
2626            nn.Linear(top_out_in, top_out_out),
2627        )
2628
2629    def forward(
2630        self,
2631        sparse_feature: torch.Tensor,
2632        dense: torch.Tensor,
2633    ) -> torch.Tensor:
2634        dense_feature = self.dense_mlp(dense)
2635        features = torch.cat([dense_feature] + [sparse_feature], dim=1)
2636
2637        out = self.top_mlp(features)
2638        return out
2639
2640# Thin wrapper around nn.EmbeddingBag: tracing inside nn.EmbeddingBag is not
2641# supported at the moment, so the embedding bag is kept as a top-level module.
2642class EmbBagWrapper(nn.Module):
2643    def __init__(self, num_embeddings, embedding_dim):
2644        super().__init__()
2645        self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum')
2646
2647    def forward(self, indices, offsets):
2648        return self.emb_bag(indices, offsets)
2649
2650class SparseNNModel(nn.Module):
2651    _NUM_EMBEDDINGS = 10
2652    _EMBEDDING_DIM = 5
2653    _DENSE_DIM = 4
2654    _DENSE_OUTPUT = 2
2655    _TOP_OUT_IN = 2
2656    _TOP_OUT_OUT = 2
2657    _TOP_MLP_DIM = 1
2658
2659    def __init__(self) -> None:
2660        super().__init__()
2661
2662        self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM)
2663        self.dense_top = DenseTopMLP(
2664            self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN,
2665            self._TOP_OUT_OUT)
2666
2667    def forward(
2668        self,
2669        sparse_indices: torch.Tensor,
2670        sparse_offsets: torch.Tensor,
2671        dense: torch.Tensor,
2672    ) -> torch.Tensor:
2673
2674        sparse_feature = self.model_sparse(sparse_indices, sparse_offsets)
2675        out = self.dense_top(sparse_feature, dense)
2676
2677        return out
2678
2679class TestHelperModules:
2680    class Conv2dPropAnnotaton(torch.nn.Module):
2681        def __init__(self) -> None:
2682            super().__init__()
2683            self.conv = torch.nn.Conv2d(3, 3, 3)
2684            self.linear = torch.nn.Linear(3, 3)
2685
2686        def forward(self, x):
2687            x = self.conv(x)
2688            x = x.view(-1, 3)
2689            x = torch.nn.functional.hardtanh(x, -0.5, 0.5)
2690            x = self.linear(x)
2691            return x
2692
2693    class Conv2dWithObsSharingOps(torch.nn.Module):
2694        def __init__(self) -> None:
2695            super().__init__()
2696            self.conv = torch.nn.Conv2d(3, 3, 3)
2697            self.hardtanh = torch.nn.Hardtanh()
2698            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
2699
2700        def forward(self, x):
2701            x = self.conv(x)
2702            x = self.adaptive_avg_pool2d(x)
2703            x = self.hardtanh(x)
2704            x = torch.mean(x)
2705            return x
2706
2707    class Conv2dWithTwoLinearPermute(torch.nn.Module):
2708        def __init__(self) -> None:
2709            super().__init__()
2710            self.conv = torch.nn.Conv2d(3, 16, 3)
2711            self.linear1 = torch.nn.Linear(16, 8, bias=False)
2712            self.linear2 = torch.nn.Linear(8, 8)
2713
2714        def forward(self, x):
2715            conv_out = self.conv(x)
2716            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
2717            return self.linear2(self.linear1(permute_out))
2718
2719    class Conv2dWithTwoLinear(torch.nn.Module):
2720        def __init__(self) -> None:
2721            super().__init__()
2722            self.conv = torch.nn.Conv2d(3, 16, 3)
2723            self.linear1 = torch.nn.Linear(64, 8, bias=False)
2724            self.linear2 = torch.nn.Linear(8, 8)
2725
2726        def forward(self, x):
2727            conv_out = self.conv(x)
2728            reshape_out = torch.reshape(conv_out, (2, 64))
2729            return self.linear2(self.linear1(reshape_out))

    class ConvLinearWPermute(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.linear1 = torch.nn.Linear(8, 8)

        def forward(self, x):
            conv_out = self.conv(x)
            permute_out = torch.permute(conv_out, (0, 2, 3, 1))
            return self.linear1(permute_out)

    class TwoLinearModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 16, bias=False)
            self.linear2 = torch.nn.Linear(16, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    class ConvMaxPool2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(2, 2, 1)
            self.pool = torch.nn.MaxPool2d(1, 1)

        def forward(self, x):
            x = self.conv(x)
            x = self.pool(x)
            return x

    class ConvWithAdaptiveAvgPool2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            return x

    class ConvWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.conv = convs[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return self.relu(x)
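        # Typical instantiations (input shapes assumed for illustration):
        #   ConvWithBNRelu(relu=True) on torch.randn(1, 3, 8, 8) -> the conv-bn-relu fusion pattern
        #   ConvWithBNRelu(relu=True, dim=1) on torch.randn(1, 3, 8) -> the Conv1d/BatchNorm1d path
        #   ConvWithBNRelu(relu=False, bn=False) on torch.randn(1, 3, 8, 8) -> plain conv via the Identity fallbacks
        # The Identity fallbacks keep forward() structurally identical across configurations.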

    class ConvTWithBNRelu(torch.nn.Module):
        def __init__(self, relu, dim=2, bn=True, bias=True):
            super().__init__()
            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
            self.convt = convts[dim](3, 3, 3, bias=bias)

            if bn:
                self.bn = bns[dim](3)
            else:
                self.bn = torch.nn.Identity()
            if relu:
                self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            x = self.convt(x)
            x = self.bn(x)
            return self.relu(x)

    class Conv2dThenConv1d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1d = torch.nn.Conv1d(3, 3, 3)
            self.conv2d = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv2d(x)
            x = x.squeeze(0)
            x = self.conv1d(x)
            return x

        def example_inputs(self):
            return (torch.randn(1, 3, 5, 5),)
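        # With the example input (1, 3, 5, 5), conv2d produces (1, 3, 3, 3) and
        # squeeze(0) drops the batch dim; conv1d then treats the (3, 3, 3) tensor as
        # (N=3, C=3, L=3) and returns a (3, 3, 1) output.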

    class Conv2dWithCat(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = torch.cat([x, y], dim=1)
            return z

    class Conv2dWithTwoCat(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x1, x2, x3, x4):
            x1 = self.conv1(x1)
            x2 = self.conv2(x2)
            y = torch.cat([x1, x2], dim=1)
            z = x3 + x4
            w = torch.cat([z, y])
            return w
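        # Note the second cat uses the default dim=0. With assumed inputs
        # x1 = x2 = torch.randn(1, 3, 5, 5) and x3 = x4 = torch.randn(1, 6, 3, 3),
        # the convs give (1, 3, 3, 3) each, y = (1, 6, 3, 3), z = (1, 6, 3, 3), and
        # w = cat([z, y]) stacks them along dim 0 into (2, 6, 3, 3). cat inputs are
        # the usual place where a quantizer has to unify qparams across branches.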

    class ThreeAdd(torch.nn.Module):
        def forward(self, x1, x2, x3, x4):
            y = x1 + x2
            z = x3 + x4
            w = y + z
            return w

    class EmbeddingModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

        def forward(self, indices):
            return self.emb(indices)

    class EmbeddingConvLinearModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
            self.conv = torch.nn.Conv2d(8, 16, (1, 3))
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, indices):
            embeddings = self.emb(indices)
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.permute(embeddings, (0, 3, 1, 2))
            conv_out = self.conv(embeddings)
            conv_out = torch.permute(conv_out, (0, 2, 3, 1))
            conv_out = torch.squeeze(conv_out, dim=0)
            return self.linear(conv_out)
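        # Shape flow for an assumed indices tensor torch.randint(0, 10, (4, 6)):
        # emb -> (4, 6, 8), unsqueeze -> (1, 4, 6, 8), permute -> (1, 8, 4, 6),
        # conv with (1, 3) kernel -> (1, 16, 4, 4), permute -> (1, 4, 4, 16),
        # squeeze -> (4, 4, 16), linear -> (4, 4, 8). This lets tests mix an
        # embedding (typically weight-only quantized) with conv/linear in one graph.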

    class AddInplaceAdd(torch.nn.Module):
        def forward(self, x, y):
            x = x + y
            x += y
            return x

    class MulInplaceMul(torch.nn.Module):
        def forward(self, x, y):
            x = x * y
            x *= y
            return x

    class AddMulScalar(torch.nn.Module):
        def forward(self, x):
            x = x + 3
            x = x * 3
            x += 3
            x *= 3
            return x
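        # These three modules cover out-of-place, in-place, and scalar variants of
        # add/mul; quantizers generally have to match each of these aten overloads
        # as a separate pattern, which is why all variants appear in one forward.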

    class ConvBnReLU2dAndLinearReLU(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True)
            self.linear = torch.nn.Linear(3, 8, bias=False)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv_bn_relu(x)
            permute_out = torch.permute(x, (0, 2, 3, 1))
            linear_out = self.linear(permute_out)
            return linear_out
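        # Note: self.relu is constructed but never applied in forward(); the only
        # ReLU in the traced graph is the one inside conv_bn_relu, so despite the
        # class name the linear output is returned without a ReLU.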

    class GroupwiseConv2d(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(4, 4, 3, groups=2)

        def forward(self, x):
            return self.conv(x)

        def example_inputs(self):
            return (torch.randn(2, 4, 10, 10),)
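        # groups=2 splits the 4 input/output channels into two groups, so the weight
        # has shape (4, 2, 3, 3); this is the interesting case for per-channel weight
        # observers, whose qparams are typically computed along the output-channel
        # dimension (dim 0).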

    class LinearReluModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc(x))
            return x

def _generate_qdq_quantized_model(
    mod, inputs, is_qat=False, is_dynamic=False, quantizer=None
):
    """Export ``mod`` with capture_pre_autograd_graph, prepare it with the given
    quantizer (or a default X86InductorQuantizer configuration), run one forward
    pass on ``inputs`` for calibration, and return the convert_pt2e'd QDQ model.
    """

    def get_default_quantizer(is_qat, is_dynamic):
        quantizer = X86InductorQuantizer()
        quantizer.set_global(
            xiq.get_default_x86_inductor_quantization_config(
                is_qat=is_qat, is_dynamic=is_dynamic
            )
        )
        return quantizer

    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
    with maybe_no_grad:
        export_model = capture_pre_autograd_graph(
            mod,
            inputs,
        )
        quantizer = (
            quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic)
        )
        prepare_model = (
            prepare_qat_pt2e(export_model, quantizer)
            if is_qat
            else prepare_pt2e(export_model, quantizer)
        )
        prepare_model(*inputs)
        torch.ao.quantization.move_exported_model_to_eval(prepare_model)
        convert_model = convert_pt2e(prepare_model)
        return convert_model

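# Usage sketch for _generate_qdq_quantized_model (assumed, illustrative shapes;
# not executed at import time):
#
#     model = TestHelperModules.ConvWithBNRelu(relu=True).eval()
#     example_inputs = (torch.randn(1, 3, 8, 8),)
#     converted = _generate_qdq_quantized_model(model, example_inputs)
#     out = converted(*example_inputs)
#
# Passing is_qat=True routes through prepare_qat_pt2e and skips the no_grad
# context; a different quantizer (e.g. an XNNPACKQuantizer configured with
# get_symmetric_quantization_config()) can be supplied to override the default
# X86InductorQuantizer.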