1# mypy: ignore-errors 2 3r"""Importing this file includes common utility methods and base clases for 4checking quantization api and properties of resulting modules. 5""" 6 7import torch 8import torch.nn as nn 9import torch.nn.functional as F 10import torch.ao.nn.intrinsic.quantized.dynamic as nniqd 11import torch.ao.nn.quantized as nnq 12import torch.ao.nn.quantized.dynamic as nnqd 13from torch.ao.nn.intrinsic import _FusedModule 14import torch.distributed as dist 15from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM 16 17from torch._export import capture_pre_autograd_graph 18from torch.ao.quantization import ( 19 QuantType, 20 default_dynamic_qat_qconfig, 21 default_embedding_qat_qconfig, 22 default_symmetric_qnnpack_qat_qconfig, 23) 24from torch.ao.quantization.quantize_pt2e import ( 25 _convert_to_reference_decomposed_fx, 26 convert_pt2e, 27 prepare_pt2e, 28 prepare_qat_pt2e, 29) 30from torch.ao.quantization.backend_config import ( 31 get_executorch_backend_config, 32) 33from torch.ao.quantization.quantizer.xnnpack_quantizer import ( 34 XNNPACKQuantizer, 35 get_symmetric_quantization_config, 36) 37from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \ 38 default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ 39 propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ 40 get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \ 41 QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping 42from torch.ao.quantization.quantization_mappings import ( 43 get_default_dynamic_quant_module_mappings, 44 get_default_qconfig_propagation_list, 45 get_default_qat_module_mappings, 46) 47from torch.testing._internal.common_quantized import ( 48 override_quantized_engine, 49) 50from torch.jit.mobile import _load_for_lite_interpreter 51 52try: 53 # graph mode quantization based on fx 54 from torch.ao.quantization.quantize_fx import ( 55 prepare_fx, 56 prepare_qat_fx, 57 convert_fx, 58 convert_to_reference_fx, 59 ) 60 from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph 61 from torch.fx.graph import Node 62 from torch.fx import GraphModule 63 HAS_FX = True 64except ImportError: 65 HAS_FX = False 66 67import copy 68import io 69import functools 70import time 71import os 72 73import unittest 74import numpy as np 75from torch.testing import FileCheck 76from typing import Callable, Tuple, Dict, Any, Union, Type, Optional 77import torch._dynamo as torchdynamo 78import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq 79from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer 80import contextlib 81 82class NodeSpec: 83 ''' Used for checking GraphModule Node 84 ''' 85 def __init__(self, op, target): 86 ''' 87 op: call_function | call_module 88 target: 89 for call_function, target would be a function 90 for call_module, target would be the type of PyTorch module 91 ''' 92 self.op = op 93 self.target = target 94 95 @classmethod 96 def call_function(cls, target): 97 return NodeSpec('call_function', target) 98 99 @classmethod 100 def call_method(cls, target): 101 return NodeSpec('call_method', target) 102 103 @classmethod 104 def call_module(cls, target): 105 return NodeSpec('call_module', target) 106 107 def __hash__(self): 108 return hash((self.op, self.target)) 109 110 def __eq__(self, other): 111 if not isinstance(other, NodeSpec): 112 return NotImplemented 113 114 return self.op == other.op and self.target == other.target 115 116 def __repr__(self): 117 return repr(self.op) + " " + repr(self.target) 118 119def get_supported_device_types(): 120 return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu'] 121 122def test_only_eval_fn(model, calib_data): 123 r""" 124 Default evaluation function takes a torch.utils.data.Dataset or a list of 125 input Tensors and run the model on the dataset 126 """ 127 for inp in calib_data: 128 output = model(*inp) 129 130_default_loss_fn = torch.nn.CrossEntropyLoss() 131def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): 132 r""" 133 Default train function takes a torch.utils.data.Dataset and train the model 134 on the dataset 135 """ 136 optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 137 train_loss, correct, total = 0, 0, 0 138 for i in range(10): 139 model.train() 140 141 for data, target in train_data: 142 optimizer.zero_grad() 143 output = model(data) 144 loss = loss_fn(output, target) 145 loss.backward() 146 optimizer.step() 147 train_loss += loss.item() 148 _, predicted = torch.max(output, 1) 149 total += target.size(0) 150 correct += (predicted == target).sum().item() 151 return train_loss, correct, total 152 153class AverageMeter: 154 """Computes and stores the average and current value""" 155 def __init__(self, name, fmt=':f'): 156 self.name = name 157 self.fmt = fmt 158 self.reset() 159 160 def reset(self): 161 self.val = 0 162 self.avg = 0 163 self.sum = 0 164 self.count = 0 165 166 def update(self, val, n=1): 167 self.val = val 168 self.sum += val * n 169 self.count += n 170 self.avg = self.sum / self.count 171 172 def __str__(self): 173 fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 174 return fmtstr.format(**self.__dict__) 175 176 177def accuracy(output, target, topk=(1,)): 178 """Computes the accuracy over the k top predictions for the specified values of k""" 179 with torch.no_grad(): 180 maxk = max(topk) 181 batch_size = target.size(0) 182 183 _, pred = output.topk(maxk, 1, True, True) 184 pred = pred.t() 185 correct = pred.eq(target.view(1, -1).expand_as(pred)) 186 187 res = [] 188 for k in topk: 189 correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 190 res.append(correct_k.mul_(100.0 / batch_size)) 191 return res 192 193def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): 194 model.train() 195 cnt = 0 196 for image, target in data_loader: 197 start_time = time.time() 198 print('.', end='') 199 cnt += 1 200 image, target = image.to(device), target.to(device) 201 output = model(image) 202 loss = criterion(output, target) 203 optimizer.zero_grad() 204 loss.backward() 205 optimizer.step() 206 acc1, acc5 = accuracy(output, target, topk=(1, 5)) 207 if cnt >= ntrain_batches: 208 return 209 return 210 211def ddp_setup(rank, world_size): 212 os.environ['MASTER_ADDR'] = 'localhost' 213 os.environ['MASTER_PORT'] = '12355' 214 215 # initialize the process group 216 dist.init_process_group("gloo", rank=rank, world_size=world_size) 217 218def ddp_cleanup(): 219 dist.destroy_process_group() 220 221def run_ddp(rank, world_size, prepared): 222 ddp_setup(rank, world_size) 223 prepared.cuda() 224 prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank]) 225 prepared.to(rank) 226 model_with_ddp = prepared 227 optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001) 228 train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1) # noqa: F821 229 ddp_cleanup() 230 231 232def convert_dynamic(module): 233 convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) 234 235def prepare_dynamic(model, qconfig_dict=None): 236 propagate_qconfig_(model, qconfig_dict) 237 238def _make_conv_test_input( 239 batch_size, in_channels_per_group, input_feature_map_size, 240 out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale, 241 W_zero_point, use_bias, use_channelwise, 242): 243 in_channels = in_channels_per_group * groups 244 out_channels = out_channels_per_group * groups 245 246 (X_value_min, X_value_max) = (0, 4) 247 X_init = torch.randint( 248 X_value_min, X_value_max, 249 (batch_size, in_channels,) + input_feature_map_size) 250 X = X_scale * (X_init - X_zero_point).float() 251 X_q = torch.quantize_per_tensor( 252 X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8) 253 254 W_scale = W_scale * out_channels 255 W_zero_point = W_zero_point * out_channels 256 # Resize W_scale and W_zero_points arrays equal to out_channels 257 W_scale = W_scale[:out_channels] 258 W_zero_point = W_zero_point[:out_channels] 259 # For testing, we use small values for weights and for activations so that 260 # no overflow occurs in vpmaddubsw instruction. If the overflow occurs in 261 # qconv implementation and if there is no overflow. 262 # In reference we can't exactly match the results with reference. 263 # Please see the comment in qconv implementation file 264 # aten/src/ATen/native/quantized/cpu/qconv.cpp for more details. 265 (W_value_min, W_value_max) = (-5, 5) 266 # The operator expects them in the format 267 # (out_channels, in_channels/groups,) + kernel_size 268 W_init = torch.randint( 269 W_value_min, W_value_max, 270 (out_channels, in_channels_per_group,) + kernel_size) 271 b_init = torch.randint(0, 10, (out_channels,)) 272 273 if use_channelwise: 274 W_shape = (-1, 1) + (1,) * len(kernel_size) 275 W_scales_tensor = torch.tensor(W_scale, dtype=torch.float) 276 W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float) 277 W = W_scales_tensor.reshape(*W_shape) * ( 278 W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() 279 b = X_scale * W_scales_tensor * b_init.float() 280 W_q = torch.quantize_per_channel( 281 W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0, 282 dtype=torch.qint8) 283 else: 284 W = W_scale[0] * (W_init - W_zero_point[0]).float() 285 b = X_scale * W_scale[0] * b_init.float() 286 W_q = torch.quantize_per_tensor( 287 W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8) 288 289 return (X, X_q, W, W_q, b if use_bias else None) 290 291def _make_conv_add_extra_input_tensor(scale, zero_point, sizes): 292 (X_value_min, X_value_max) = (0, 4) 293 X_init = torch.randint( 294 X_value_min, 295 X_value_max, 296 sizes # Infer the size of tensor to do the add 297 ) 298 X = scale * (X_init - zero_point).float() 299 X_q = torch.quantize_per_tensor( 300 X, scale=scale, zero_point=zero_point, dtype=torch.quint8) 301 return X, X_q 302 303def skipIfNoFBGEMM(fn): 304 reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.' 305 if isinstance(fn, type): 306 if 'fbgemm' not in torch.backends.quantized.supported_engines: 307 fn.__unittest_skip__ = True 308 fn.__unittest_skip_why__ = reason 309 return fn 310 311 @functools.wraps(fn) 312 def wrapper(*args, **kwargs): 313 if 'fbgemm' not in torch.backends.quantized.supported_engines: 314 raise unittest.SkipTest(reason) 315 else: 316 fn(*args, **kwargs) 317 return wrapper 318 319def skipIfNoQNNPACK(fn): 320 reason = 'Quantized operations require QNNPACK.' 321 if isinstance(fn, type): 322 if 'qnnpack' not in torch.backends.quantized.supported_engines: 323 fn.__unittest_skip__ = True 324 fn.__unittest_skip_why__ = reason 325 return fn 326 327 @functools.wraps(fn) 328 def wrapper(*args, **kwargs): 329 if 'qnnpack' not in torch.backends.quantized.supported_engines: 330 raise unittest.SkipTest(reason) 331 else: 332 fn(*args, **kwargs) 333 return wrapper 334 335def withQNNPACKBackend(fn): 336 # TODO(future PR): consider combining with skipIfNoQNNPACK, 337 # will require testing of existing callsites 338 reason = 'Quantized operations require QNNPACK.' 339 if isinstance(fn, type): 340 if 'qnnpack' not in torch.backends.quantized.supported_engines: 341 fn.__unittest_skip__ = True 342 fn.__unittest_skip_why__ = reason 343 return fn 344 345 @functools.wraps(fn) 346 def wrapper(*args, **kwargs): 347 if 'qnnpack' not in torch.backends.quantized.supported_engines: 348 raise unittest.SkipTest(reason) 349 with override_quantized_engine('qnnpack'): 350 fn(*args, **kwargs) 351 352 return wrapper 353 354def skipIfNoONEDNN(fn): 355 reason = 'Quantized operations require ONEDNN.' 356 if isinstance(fn, type): 357 if 'onednn' not in torch.backends.quantized.supported_engines: 358 fn.__unittest_skip__ = True 359 fn.__unittest_skip_why__ = reason 360 return fn 361 362 @functools.wraps(fn) 363 def wrapper(*args, **kwargs): 364 if 'onednn' not in torch.backends.quantized.supported_engines: 365 raise unittest.SkipTest(reason) 366 else: 367 fn(*args, **kwargs) 368 return wrapper 369 370def skipIfNoONEDNNBF16(fn): 371 reason = 'Quantized operations require BF16 support.' 372 if isinstance(fn, type): 373 if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): 374 fn.__unittest_skip__ = True 375 fn.__unittest_skip_why__ = reason 376 return fn 377 378 @functools.wraps(fn) 379 def wrapper(*args, **kwargs): 380 if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): 381 raise unittest.SkipTest(reason) 382 else: 383 fn(*args, **kwargs) 384 return wrapper 385 386def skipIfNoX86(fn): 387 reason = 'Quantized operations require X86.' 388 if isinstance(fn, type): 389 if 'x86' not in torch.backends.quantized.supported_engines: 390 fn.__unittest_skip__ = True 391 fn.__unittest_skip_why__ = reason 392 return fn 393 394 @functools.wraps(fn) 395 def wrapper(*args, **kwargs): 396 if 'x86' not in torch.backends.quantized.supported_engines: 397 raise unittest.SkipTest(reason) 398 else: 399 fn(*args, **kwargs) 400 return wrapper 401 402def skipIfNoDynamoSupport(fn): 403 reason = "dynamo doesn't support." 404 if isinstance(fn, type): 405 if not torchdynamo.is_dynamo_supported(): 406 fn.__unittest_skip__ = True 407 fn.__unittest_skip_why__ = reason 408 return fn 409 410 @functools.wraps(fn) 411 def wrapper(*args, **kwargs): 412 if not torchdynamo.is_dynamo_supported(): 413 raise unittest.SkipTest(reason) 414 else: 415 fn(*args, **kwargs) 416 return wrapper 417 418def skipIfNoInductorSupport(fn): 419 reason = "inductor doesn't support." 420 if isinstance(fn, type): 421 if not torchdynamo.is_inductor_supported(): 422 fn.__unittest_skip__ = True 423 fn.__unittest_skip_why__ = reason 424 return fn 425 426 @functools.wraps(fn) 427 def wrapper(*args, **kwargs): 428 if not torchdynamo.is_inductor_supported(): 429 raise unittest.SkipTest(reason) 430 else: 431 fn(*args, **kwargs) 432 return wrapper 433 434try: 435 import torchvision # noqa: F401 436 HAS_TORCHVISION = True 437except ImportError: 438 HAS_TORCHVISION = False 439skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 440 441def get_script_module(model, tracing, data): 442 return torch.jit.trace(model, data) if tracing else torch.jit.script(model) 443 444def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): 445 """ 446 Convert lengths to offsets for embedding_bag 447 """ 448 tt = np.zeros((t.shape[0] + 1,), dtype=offset_type) 449 tt[1:] = t 450 tt = torch.from_numpy(np.cumsum(tt, dtype=offset_type)) 451 if use_begin_offset: 452 return tt[:-1] 453 return tt[1:] 454 455 456def _group_quantize_tensor(w, n_bit=4, q_group_size=16): 457 assert w.dim() == 2 458 w = w.transpose(0, 1).contiguous() 459 assert q_group_size > 1 460 assert w.shape[-1] % q_group_size == 0 461 462 to_quant = w.reshape(-1, q_group_size) 463 assert torch.isnan(to_quant).sum() == 0 464 465 max_val = to_quant.amax(dim=1, keepdim=True) 466 min_val = to_quant.amin(dim=1, keepdim=True) 467 max_int = 2 ** n_bit - 1 468 min_int = 0 469 scales = (max_val - min_val).clamp(min=1e-6) / max_int 470 assert torch.isnan(scales).sum() == 0 471 472 zeros = min_val + scales * (2 ** (n_bit - 1)) 473 assert torch.isnan(zeros).sum() == 0 474 475 out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) 476 assert torch.isnan(out).sum() == 0 477 478 out = out.to(dtype=torch.int32).reshape(w.shape) 479 out_uint8 = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) 480 481 # Scales and zeros for the same q-group should be contiguous, so we can 482 # load as a 32-bit word 483 scales = scales.view(w.shape[0], -1) 484 zeros = zeros.view(w.shape[0], -1) 485 scales_and_zeros = ( 486 torch.cat( 487 [ 488 scales.reshape(scales.size(0), scales.size(1), 1), 489 zeros.reshape(zeros.size(0), zeros.size(1), 1), 490 ], 491 2, 492 ).transpose(0, 1).contiguous() 493 ) 494 495 return out_uint8, scales_and_zeros 496 497 498def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): 499 # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py 500 # default setup for affine quantization of activations 501 x_dtype = x.dtype 502 x = x.float() 503 eps = torch.finfo(torch.float32).eps 504 505 # get min and max 506 min_val, max_val = torch.aminmax(x, dim=1) 507 508 # calculate scales and zero_points based on min and max 509 # reference: https://fburl.com/code/srbiybme 510 min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 511 max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 512 device = min_val_neg.device 513 514 # reference: https://fburl.com/code/4wll53rk 515 max_val_pos = torch.max(-min_val_neg, max_val_pos) 516 scales = max_val_pos / (float(quant_max - quant_min) / 2) 517 # ensure scales is the same dtype as the original tensor 518 scales = torch.clamp(scales, min=eps).to(x.dtype) 519 zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 520 521 # quantize based on qmin/qmax/scales/zp 522 x_div = x / scales.unsqueeze(-1) 523 x_round = torch.round(x_div) 524 x_zp = x_round + zero_points.unsqueeze(-1) 525 quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) 526 527 return quant, scales.to(x_dtype), zero_points 528 529 530 531# QuantizationTestCase used as a base class for testing quantization on modules 532class QuantizationTestCase(TestCase): 533 def setUp(self): 534 super().setUp() 535 self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] 536 self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)] 537 self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] 538 for _ in range(2)] 539 self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)] 540 for _ in range(2)] 541 self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] 542 for _ in range(2)] 543 self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float), 544 torch.randint(0, 1, (1,), dtype=torch.long)] 545 for _ in range(2)] 546 self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float), 547 torch.randint(0, 1, (1,), dtype=torch.long)] 548 for _ in range(2)] 549 self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float), 550 torch.randint(0, 1, (1,), dtype=torch.long)] 551 for _ in range(2)] 552 553 self.img_data_dict = {1 : self.img_data_1d, 554 2 : self.img_data_2d, 555 3 : self.img_data_3d} 556 557 # Quant types that produce statically quantized ops 558 self.static_quant_types = [QuantType.STATIC, QuantType.QAT] 559 # All quant types for (fx based) graph mode quantization 560 self.all_quant_types = [QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT] 561 562 def checkNoPrepModules(self, module): 563 r"""Checks the module does not contain child 564 modules for quantization preparation, e.g. 565 quant, dequant and observer 566 """ 567 self.assertFalse(hasattr(module, 'quant')) 568 self.assertFalse(hasattr(module, 'dequant')) 569 570 def checkNoQconfig(self, module): 571 r"""Checks the module does not contain qconfig 572 """ 573 self.assertFalse(hasattr(module, 'qconfig')) 574 575 for child in module.children(): 576 self.checkNoQconfig(child) 577 578 def checkHasPrepModules(self, module): 579 r"""Checks the module contains child 580 modules for quantization preparation, e.g. 581 quant, dequant and observer 582 """ 583 self.assertTrue(hasattr(module, 'module')) 584 self.assertTrue(hasattr(module, 'quant')) 585 self.assertTrue(hasattr(module, 'dequant')) 586 587 def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None): 588 r"""Checks the module or module's leaf descendants 589 have observers in preparation for quantization 590 """ 591 if propagate_qconfig_list is None: 592 propagate_qconfig_list = get_default_qconfig_propagation_list() 593 if prepare_custom_config_dict is None: 594 prepare_custom_config_dict = {} 595 float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) 596 597 # check if a module is a leaf module, ignoring activation_post_process attribute 598 def is_leaf_module(module): 599 submodule_name_count = 0 600 for name, _ in module.named_children(): 601 if name != 'activation_post_process': 602 submodule_name_count += 1 603 return submodule_name_count == 0 604 605 if hasattr(module, 'qconfig') and module.qconfig is not None and \ 606 ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) 607 and type(module) in propagate_qconfig_list) or 608 type(module) in float_to_observed_module_class_mapping.keys()) and \ 609 not isinstance(module, torch.ao.quantization.DeQuantStub): 610 self.assertTrue(hasattr(module, 'activation_post_process'), 611 'module: ' + str(type(module)) + ' do not have observer') 612 # we don't need to check observers for child modules of the 613 # qat modules 614 if type(module) not in get_default_qat_module_mappings().values() and \ 615 type(module) not in float_to_observed_module_class_mapping.values() and \ 616 not isinstance(module, _FusedModule): 617 for child in module.children(): 618 if type(child) in [nn.Dropout]: 619 continue 620 self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict) 621 622 def checkQuantDequant(self, mod): 623 r"""Checks that mod has nn.Quantize and 624 nn.DeQuantize submodules inserted 625 """ 626 self.assertEqual(type(mod.quant), nnq.Quantize) 627 self.assertEqual(type(mod.dequant), nnq.DeQuantize) 628 629 def checkWrappedQuantizedLinear(self, mod): 630 r"""Checks that mod has been swapped for an nnq.Linear 631 module, the bias is qint32, and that the module 632 has Quantize and DeQuantize submodules 633 """ 634 self.assertEqual(type(mod.module), nnq.Linear) 635 self.checkQuantDequant(mod) 636 637 def checkQuantizedLinear(self, mod): 638 self.assertEqual(type(mod), nnq.Linear) 639 640 def checkDynamicQuantizedLinear(self, mod, dtype): 641 r"""Checks that mod has been swapped for an nnqd.Linear 642 module, the bias is float. 643 """ 644 self.assertEqual(type(mod), nnqd.Linear) 645 self.assertEqual(mod._packed_params.dtype, dtype) 646 647 def checkDynamicQuantizedLinearRelu(self, mod, dtype): 648 r"""Checks that mod has been swapped for an nnqd.Linear 649 module, the bias is float. 650 """ 651 self.assertEqual(type(mod), nniqd.LinearReLU) 652 self.assertEqual(mod._packed_params.dtype, dtype) 653 654 def check_eager_serialization(self, ref_model, loaded_model, x): 655 # Check state dict serialization and torch.save APIs 656 model_dict = ref_model.state_dict() 657 b = io.BytesIO() 658 torch.save(model_dict, b) 659 b.seek(0) 660 # weights_only=False as we sometimes get a ScriptObect here (weird) 661 loaded_dict = torch.load(b, weights_only=False) 662 loaded_model.load_state_dict(loaded_dict) 663 ref_out = ref_model(*x) 664 load_out = loaded_model(*x) 665 666 def check_outputs(ref_out, load_out): 667 self.assertEqual(ref_out[0], load_out[0]) 668 if isinstance(ref_out[1], tuple): 669 self.assertEqual(ref_out[1][0], load_out[1][0]) 670 self.assertEqual(ref_out[1][1], load_out[1][1]) 671 else: 672 self.assertEqual(ref_out[1], load_out[1]) 673 674 check_outputs(ref_out, load_out) 675 b = io.BytesIO() 676 torch.save(ref_model, b) 677 b.seek(0) 678 # weights_only=False as this is legacy code that saves the model 679 loaded = torch.load(b, weights_only=False) 680 load_out = loaded(*x) 681 check_outputs(ref_out, load_out) 682 683 def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): 684 weight = ref_model.get_weight() 685 bias = ref_model.get_bias() 686 self.assertEqual(weight_keys ^ weight.keys(), set()) 687 self.assertEqual(bias_keys ^ bias.keys(), set()) 688 689 def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): 690 r"""Checks that mod has been swapped for an nnqd.LSTM type 691 module, the bias is float. 692 """ 693 wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} 694 self.assertEqual(type(mod), reference_module_type) 695 for packed_params in mod._all_weight_values: 696 self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) 697 698 def checkLinear(self, mod): 699 self.assertEqual(type(mod), torch.nn.Linear) 700 701 def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): 702 r"""Checks that mod has been swapped for an nnqd.Linear 703 module, the bias is float. 704 """ 705 wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} 706 self.assertEqual(type(mod), reference_module_type) 707 if hasattr(mod, '_all_weight_values'): 708 for packed_params in mod._all_weight_values: 709 self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) 710 711 def checkScriptable(self, orig_mod, calib_data, check_save_load=False): 712 scripted = torch.jit.script(orig_mod) 713 self._checkScriptable(orig_mod, scripted, calib_data, check_save_load) 714 715 # Use first calib_data entry as trace input 716 traced = torch.jit.trace(orig_mod, calib_data[0]) 717 self._checkScriptable(orig_mod, traced, calib_data, check_save_load) 718 719 # Call this twice: once for a scripted module and once for a traced module 720 def _checkScriptable(self, orig_mod, script_mod, calib_data, check_save_load): 721 self._checkModuleCorrectnessAgainstOrig(orig_mod, script_mod, calib_data) 722 723 # Test save/load 724 buffer = io.BytesIO() 725 torch.jit.save(script_mod, buffer) 726 727 buffer.seek(0) 728 loaded_mod = torch.jit.load(buffer) 729 # Pending __get_state_ and __set_state__ support 730 # See tracking task https://github.com/pytorch/pytorch/issues/23984 731 if check_save_load: 732 self._checkModuleCorrectnessAgainstOrig(orig_mod, loaded_mod, calib_data) 733 734 def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): 735 for inp in calib_data: 736 ref_output = orig_mod(*inp) 737 scripted_output = test_mod(*inp) 738 self.assertEqual(scripted_output, ref_output) 739 740 741 def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False, 742 check=True, eval_mode=True, dynamic=False, qconfig=None): 743 if debug: 744 print('Testing:', str(module)) 745 qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} 746 747 if eval_mode: 748 module = module.eval() 749 if dynamic: 750 qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig} 751 model = get_script_module(module, tracing, inputs[0]).eval() 752 if debug: 753 print('input graph:', model.graph) 754 models = {} 755 outputs = {} 756 for debug in [True, False]: 757 if dynamic: 758 models[debug] = quantize_dynamic_jit(model, qconfig_dict, debug=debug) 759 # make sure it runs 760 outputs[debug] = models[debug](inputs) 761 else: 762 # module under test can contain in-place ops, and we depend on 763 # input data staying constant for comparisons 764 inputs_copy = copy.deepcopy(inputs) 765 models[debug] = quantize_jit( 766 model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False, 767 debug=debug) 768 # make sure it runs 769 outputs[debug] = models[debug](*inputs[0]) 770 771 if debug: 772 print('debug graph:', models[True].graph) 773 print('non debug graph:', models[False].graph) 774 775 if check: 776 # debug and non-debug option should have the same numerics 777 self.assertEqual(outputs[True], outputs[False]) 778 779 # non debug graph should produce quantized op 780 FileCheck().check(quantized_op) \ 781 .run(models[False].graph) 782 783 return models[False] 784 785 def checkGraphModuleNodes( 786 self, graph_module, 787 expected_node=None, 788 expected_node_occurrence=None, 789 expected_node_list=None): 790 """ Check if GraphModule contains the target node 791 Args: 792 graph_module: the GraphModule instance we want to check 793 expected_node, expected_node_occurrence, expected_node_list: 794 see docs for checkGraphModeFxOp 795 """ 796 nodes_in_graph = {} 797 node_list = [] 798 modules = dict(graph_module.named_modules(remove_duplicate=False)) 799 for node in graph_module.graph.nodes: 800 n = None 801 if node.op == 'call_function' or node.op == 'call_method': 802 n = NodeSpec(node.op, node.target) 803 elif node.op == 'call_module': 804 n = NodeSpec(node.op, type(modules[node.target])) 805 806 if n is not None: 807 node_list.append(n) 808 if n in nodes_in_graph: 809 nodes_in_graph[n] += 1 810 else: 811 nodes_in_graph[n] = 1 812 813 if expected_node is not None: 814 self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) + 815 ' not found in the graph module') 816 817 if expected_node_occurrence is not None: 818 for expected_node, occurrence in expected_node_occurrence.items(): 819 if occurrence != 0: 820 self.assertTrue( 821 expected_node in nodes_in_graph, 822 'Check failed for node:' + str(expected_node) + 823 ' not found') 824 self.assertTrue( 825 nodes_in_graph[expected_node] == occurrence, 826 'Check failed for node:' + str(expected_node) + 827 ' Expected occurrence:' + str(occurrence) + 828 ' Found occurrence:' + str(nodes_in_graph[expected_node])) 829 else: 830 self.assertTrue( 831 expected_node not in nodes_in_graph, 832 'Check failed for node:' + str(expected_node) + 833 ' expected no occurrence but found') 834 835 if expected_node_list is not None: 836 cur_index = 0 837 for n in node_list: 838 if cur_index == len(expected_node_list): 839 return 840 if n == expected_node_list[cur_index]: 841 cur_index += 1 842 self.assertTrue( 843 cur_index == len(expected_node_list), 844 "Check failed for graph:" + 845 self.printGraphModule(graph_module, print_str=False) + 846 "Expected ordered list:" + 847 str(expected_node_list)) 848 849 def printGraphModule(self, graph_module, print_str=True): 850 modules = dict(graph_module.named_modules(remove_duplicate=False)) 851 node_infos = [] 852 for n in graph_module.graph.nodes: 853 node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) 854 if n.op == 'call_module': 855 node_info += ' module type: ' + repr(type(modules[n.target])) 856 node_infos.append(node_info) 857 str_to_print = '\n'.join(node_infos) 858 if print_str: 859 print(str_to_print) 860 return str_to_print 861 862 if HAS_FX: 863 864 def assert_types_for_matched_subgraph_pairs( 865 self, 866 matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]], 867 expected_types: Dict[str, Tuple[Tuple[Callable, Callable], Tuple[Callable, Callable]]], 868 gm_a: GraphModule, 869 gm_b: GraphModule, 870 ) -> None: 871 """ 872 Verifies that the types specified in expected_types match 873 the underlying objects pointed to by the nodes in matched_subgraph_pairs. 874 875 An example successful test case: 876 877 matched_subgraph_pairs = {'x0': (graph_a_conv_0_node, graph_b_conv_0_node)} 878 expected_types = {'x0': (nn.Conv2d, nnq.Conv2d)} 879 880 The function tests for key equivalence, and verifies types with 881 instance checks. 882 """ 883 884 def _get_underlying_op_type( 885 node: Node, gm: GraphModule 886 ) -> Union[Callable, str]: 887 if node.op == 'call_module': 888 mod = getattr(gm, node.target) 889 return type(mod) 890 else: 891 assert node.op in ('call_function', 'call_method') 892 return node.target 893 894 self.assertTrue( 895 len(matched_subgraph_pairs) == len(expected_types), 896 f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}' 897 ) 898 for k, v in expected_types.items(): 899 expected_types_a, expected_types_b = v 900 exp_type_start_a, exp_type_end_a = expected_types_a 901 exp_type_start_b, exp_type_end_b = expected_types_b 902 subgraph_a, subgraph_b = matched_subgraph_pairs[k] 903 904 act_type_start_a = _get_underlying_op_type(subgraph_a.start_node, gm_a) 905 act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) 906 act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) 907 act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) 908 types_match = (exp_type_start_a is act_type_start_a) and \ 909 (exp_type_end_a is act_type_end_a) and \ 910 (exp_type_start_b is act_type_start_b) and \ 911 (exp_type_end_b is act_type_end_b) 912 self.assertTrue( 913 types_match, 914 f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, ' 915 f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}' 916 ) 917 918 def assert_ns_compare_dict_valid( 919 self, 920 act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]], 921 ) -> None: 922 """ 923 Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid: 924 1. for each layer, results are recorded for two models 925 2. number of seen tensors match 926 3. shapes of each pair of seen tensors match 927 """ 928 for layer_name, result_type_to_data in act_compare_dict.items(): 929 for result_type, layer_data in result_type_to_data.items(): 930 self.assertTrue( 931 len(layer_data) == 2, 932 f"Layer {layer_name} does not have exactly two model results.") 933 model_name_0, model_name_1 = layer_data.keys() 934 for res_idx in range(len(layer_data[model_name_0])): 935 layer_data_0 = layer_data[model_name_0][res_idx] 936 layer_data_1 = layer_data[model_name_1][res_idx] 937 self.assertTrue( 938 layer_data_0['type'] == layer_data_0['type'], 939 f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.") 940 941 self.assertTrue( 942 len(layer_data_0['values']) == 943 len(layer_data_1['values']), 944 f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.") 945 946 # F.conv1d weight has rank 3, and toq.conv1d unpacked weight 947 # has rank 4. For now, skip the length check for conv1d only. 948 is_weight_functional_conv1d = ( 949 result_type == NSSingleResultValuesType.WEIGHT.value and 950 ( 951 'conv1d' in layer_data_0['prev_node_target_type'] or 952 'conv1d' in layer_data_1['prev_node_target_type'] 953 ) 954 ) 955 if not is_weight_functional_conv1d: 956 for idx in range(len(layer_data_0['values'])): 957 values_0 = layer_data_0['values'][idx] 958 values_1 = layer_data_1['values'][idx] 959 if isinstance(values_0, torch.Tensor): 960 self.assertTrue( 961 values_0.shape == values_1.shape, 962 f"Layer {layer_name}, {model_name_0} and {model_name_1} " + 963 f"have a shape mismatch at idx {idx}.") 964 elif isinstance(values_0, list): 965 values_0 = values_0[0] 966 values_1 = values_1[0] 967 self.assertTrue( 968 values_0.shape == values_1.shape, 969 f"Layer {layer_name}, {model_name_0} and {model_name_1} " + 970 f"have a shape mismatch at idx {idx}.") 971 else: 972 assert isinstance(values_0, tuple), \ 973 f"unhandled type {type(values_0)}" 974 assert len(values_0) == 2 975 assert len(values_0[1]) == 2 976 assert values_0[0].shape == values_1[0].shape 977 assert values_0[1][0].shape == values_1[1][0].shape 978 assert values_0[1][1].shape == values_1[1][1].shape 979 980 # verify that ref_node_name is valid 981 ref_node_name_0 = layer_data_0['ref_node_name'] 982 ref_node_name_1 = layer_data_1['ref_node_name'] 983 prev_node_name_0 = layer_data_0['prev_node_name'] 984 prev_node_name_1 = layer_data_1['prev_node_name'] 985 if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value: 986 self.assertTrue(ref_node_name_0 == prev_node_name_0) 987 self.assertTrue(ref_node_name_1 == prev_node_name_1) 988 elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value: 989 self.assertTrue(ref_node_name_0 != prev_node_name_0) 990 self.assertTrue(ref_node_name_1 != prev_node_name_1) 991 992 def checkGraphModeFxOp( 993 self, 994 model, 995 inputs, 996 quant_type, 997 expected_node=None, 998 expected_node_occurrence=None, 999 expected_node_list=None, 1000 is_reference=False, 1001 print_debug_info=False, 1002 custom_qconfig_dict=None, 1003 prepare_expected_node=None, 1004 prepare_expected_node_occurrence=None, 1005 prepare_expected_node_list=None, 1006 prepare_custom_config=None, 1007 backend_config=None): 1008 """ Quantizes model with graph mode quantization on fx and check if the 1009 quantized model contains the quantized_node 1010 1011 Args: 1012 model: floating point torch.nn.Module 1013 inputs: one positional sample input arguments for model 1014 expected_node: NodeSpec 1015 e.g. NodeSpec.call_function(torch.quantize_per_tensor) 1016 expected_node_occurrence: a dict from NodeSpec to 1017 expected number of occurrences (int) 1018 e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, 1019 NodeSpec.call_method('dequantize'): 1} 1020 expected_node_list: a list of NodeSpec, used to check the order 1021 of the occurrence of Node 1022 e.g. [NodeSpec.call_function(torch.quantize_per_tensor), 1023 NodeSpec.call_module(nnq.Conv2d), 1024 NodeSpec.call_function(F.hardtanh_), 1025 NodeSpec.call_method('dequantize')] 1026 is_reference: if True, enables reference mode 1027 print_debug_info: if True, prints debug info 1028 custom_qconfig_dict: overrides default qconfig_dict 1029 prepare_expected_node: same as expected_node, but for prepare 1030 prepare_expected_node_occurrence: same as 1031 expected_node_occurrence, but for prepare 1032 prepare_expected_node_list: same as expected_node_list, but 1033 for prepare 1034 1035 Returns: 1036 A dictionary with the following structure: 1037 { 1038 "prepared": ..., # the prepared model 1039 "quantized": ..., # the quantized non-reference model 1040 "quantized_reference": ..., # the quantized reference model 1041 "result": ..., # the result for either quantized or 1042 # quantized_reference model depending on the 1043 # is_reference argument 1044 } 1045 """ 1046 # TODO: make img_data a single example instead of a list 1047 if type(inputs) == list: 1048 inputs = inputs[0] 1049 1050 if quant_type == QuantType.QAT: 1051 qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine) 1052 model.train() 1053 elif quant_type == QuantType.STATIC: 1054 qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine) 1055 model.eval() 1056 else: 1057 qconfig = default_dynamic_qconfig 1058 qconfig_mapping = QConfigMapping().set_global(qconfig) 1059 model.eval() 1060 1061 if quant_type == QuantType.QAT: 1062 prepare = prepare_qat_fx 1063 else: 1064 prepare = prepare_fx 1065 1066 # overwrite qconfig_dict with custom_qconfig_dict 1067 if custom_qconfig_dict is not None: 1068 assert type(custom_qconfig_dict) in (QConfigMapping, dict), \ 1069 'custom_qconfig_dict should be a QConfigMapping or a dict' 1070 if isinstance(custom_qconfig_dict, QConfigMapping): 1071 qconfig_mapping = custom_qconfig_dict 1072 else: 1073 qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) 1074 prepared = prepare( 1075 model, qconfig_mapping, 1076 example_inputs=inputs, 1077 prepare_custom_config=prepare_custom_config, 1078 backend_config=backend_config) 1079 if not quant_type == QuantType.DYNAMIC: 1080 prepared(*inputs) 1081 1082 if print_debug_info: 1083 print() 1084 print('quant type:\n', quant_type) 1085 print('original model:\n', model) 1086 print() 1087 print('prepared model:\n', prepared) 1088 1089 self.checkGraphModuleNodes( 1090 prepared, prepare_expected_node, 1091 prepare_expected_node_occurrence, prepare_expected_node_list) 1092 1093 prepared_copy = copy.deepcopy(prepared) 1094 qgraph = convert_fx(copy.deepcopy(prepared)) 1095 qgraph_reference = convert_to_reference_fx(copy.deepcopy(prepared)) 1096 result = qgraph(*inputs) 1097 result_reference = qgraph_reference(*inputs) 1098 qgraph_copy = copy.deepcopy(qgraph) 1099 qgraph_reference_copy = copy.deepcopy(qgraph_reference) 1100 1101 qgraph_to_check = qgraph_reference if is_reference else qgraph 1102 if print_debug_info: 1103 print() 1104 print('quantized model:\n', qgraph_to_check) 1105 self.printGraphModule(qgraph_to_check) 1106 print() 1107 self.checkGraphModuleNodes( 1108 qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) 1109 return {"prepared": prepared_copy, 1110 "quantized": qgraph_copy, 1111 "quantized_reference": qgraph_reference_copy, 1112 "quantized_output": result, 1113 "quantized_reference_output": result_reference} 1114 1115 1116 def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, 1117 set_qconfig, is_emb_bag, dtype=torch.quint8): 1118 # Test serialization of dynamic EmbeddingBag module using state_dict 1119 if is_emb_bag: 1120 inputs = [indices, offsets] 1121 else: 1122 inputs = [indices] 1123 emb_dict = qemb.state_dict() 1124 b = io.BytesIO() 1125 torch.save(emb_dict, b) 1126 b.seek(0) 1127 loaded_dict = torch.load(b) 1128 embedding_unpack = torch.ops.quantized.embedding_bag_unpack 1129 # Check unpacked weight values explicitly 1130 for key in emb_dict: 1131 if isinstance(emb_dict[key], torch._C.ScriptObject): 1132 assert isinstance(loaded_dict[key], torch._C.ScriptObject) 1133 emb_weight = embedding_unpack(emb_dict[key]) 1134 loaded_weight = embedding_unpack(loaded_dict[key]) 1135 self.assertEqual(emb_weight, loaded_weight) 1136 1137 # Check state dict serialization and torch.save APIs 1138 if is_emb_bag: 1139 loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, 1140 include_last_offset=True, mode='sum', dtype=dtype) 1141 else: 1142 loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) 1143 self.check_eager_serialization(qemb, loaded_qemb, inputs) 1144 1145 loaded_qemb.load_state_dict(loaded_dict) 1146 self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight), 1147 embedding_unpack(loaded_qemb._packed_params._packed_weight)) 1148 1149 1150 # Test JIT serialization 1151 self.checkScriptable(qemb, [inputs], check_save_load=True) 1152 1153 # Test from_float call 1154 if is_emb_bag: 1155 float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, 1156 include_last_offset=True, scale_grad_by_freq=False, mode='sum') 1157 else: 1158 float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 1159 1160 if set_qconfig: 1161 float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, 1162 qscheme=torch.per_channel_affine_float_qparams, 1163 ch_axis=0) 1164 float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer, 1165 weight=float_qparams_observer) 1166 1167 prepare_dynamic(float_embedding) 1168 1169 float_embedding(*inputs) 1170 if is_emb_bag: 1171 q_embeddingbag = nnq.EmbeddingBag.from_float(float_embedding) 1172 expected_name = "QuantizedEmbeddingBag" 1173 else: 1174 q_embeddingbag = nnq.Embedding.from_float(float_embedding) 1175 expected_name = "QuantizedEmbedding" 1176 1177 q_embeddingbag(*inputs) 1178 1179 self.assertTrue(expected_name in str(q_embeddingbag)) 1180 1181class QuantizationLiteTestCase(QuantizationTestCase): 1182 def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs): 1183 # Creates quantized model for testing mobile script modules 1184 qengine = "qnnpack" 1185 with override_quantized_engine(qengine): 1186 qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1187 model = model_class(**kwargs) 1188 model = quantize(model, test_only_eval_fn, [self.calib_data]) 1189 1190 return model 1191 1192 def _compare_script_and_mobile(self, 1193 model: torch.nn.Module, 1194 input: torch.Tensor): 1195 # Compares the numerical outputs for script and lite modules 1196 qengine = "qnnpack" 1197 with override_quantized_engine(qengine): 1198 script_module = torch.jit.script(model) 1199 script_module_result = script_module(input) 1200 1201 max_retry = 5 1202 for retry in range(1, max_retry + 1): 1203 # retries `max_retry` times; breaks iff succeeds else throws exception 1204 try: 1205 buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) 1206 buffer.seek(0) 1207 mobile_module = _load_for_lite_interpreter(buffer) 1208 1209 mobile_module_result = mobile_module(input) 1210 1211 torch.testing.assert_close(script_module_result, mobile_module_result) 1212 mobile_module_forward_result = mobile_module.forward(input) 1213 torch.testing.assert_close(script_module_result, mobile_module_forward_result) 1214 1215 mobile_module_run_method_result = mobile_module.run_method("forward", input) 1216 torch.testing.assert_close(script_module_result, mobile_module_run_method_result) 1217 except AssertionError as e: 1218 if retry == max_retry: 1219 raise e 1220 else: 1221 continue 1222 break 1223 1224 1225class PT2EQuantizationTestCase(QuantizationTestCase): 1226 """ 1227 Base QuantizationTestCase for PT2 with some helper methods. 1228 """ 1229 _MAP_TO_FX_TRACED_OPS = { 1230 torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, 1231 torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, 1232 torch.ops.quantized_decomposed.quantize_per_channel: torch.ops.quantized_decomposed.quantize_per_channel.default, 1233 torch.ops.quantized_decomposed.dequantize_per_channel: torch.ops.quantized_decomposed.dequantize_per_channel.default, 1234 torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 1235 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 1236 } 1237 1238 def _test_quantizer( 1239 self, 1240 model, 1241 example_inputs, 1242 quantizer, 1243 expected_node_occurrence, 1244 expected_node_list=None, 1245 check_against_fx_quant=False, 1246 fx_qconfig_mapping=None, 1247 export_with_dynamic_shape=False, 1248 is_qat=False, 1249 is_debug_mode=False, 1250 capture_pre_autograd_graph_node_occurrence=None, 1251 ): 1252 # resetting dynamo cache 1253 torch._dynamo.reset() 1254 m_eager = model.eval() 1255 1256 # program capture 1257 m = copy.deepcopy(m_eager) 1258 dynamic_shapes = tuple( 1259 {0: torch.export.Dim("dim")} if i == 0 else None 1260 for i in range(len(example_inputs)) 1261 ) 1262 m = capture_pre_autograd_graph( 1263 m, 1264 example_inputs, 1265 dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, 1266 ) 1267 1268 if is_qat: 1269 m = prepare_qat_pt2e(m, quantizer) 1270 else: 1271 m = prepare_pt2e(m, quantizer) 1272 # Calibrate 1273 m(*example_inputs) 1274 m = convert_pt2e(m) 1275 if is_debug_mode: 1276 print("quantized model", m) 1277 1278 pt2_quant_output = m(*example_inputs) 1279 ns = NodeSpec 1280 node_occurrence = { 1281 ns.call_function(k): v for k, v in expected_node_occurrence.items() 1282 } 1283 if expected_node_list is None: 1284 expected_node_list = [] 1285 node_list = [ns.call_function(n) for n in expected_node_list] 1286 self.checkGraphModuleNodes( 1287 m, expected_node_occurrence=node_occurrence, expected_node_list=node_list 1288 ) 1289 if check_against_fx_quant: 1290 qconfig_mapping = fx_qconfig_mapping 1291 backend_config = get_executorch_backend_config() 1292 m_copy = copy.deepcopy(m_eager) 1293 m_fx = prepare_fx( 1294 m_copy, qconfig_mapping, example_inputs, backend_config=backend_config 1295 ) 1296 m_fx(*example_inputs) 1297 m_fx = _convert_to_reference_decomposed_fx( 1298 m_fx, backend_config=backend_config 1299 ) 1300 m_fx = capture_pre_autograd_graph( 1301 m_fx, 1302 example_inputs, 1303 dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, 1304 ) 1305 node_occurrence = {} 1306 for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): 1307 if k in expected_node_occurrence: 1308 node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] 1309 if capture_pre_autograd_graph_node_occurrence is not None: 1310 node_occurrence = { 1311 ns.call_function(k): v for k, v in capture_pre_autograd_graph_node_occurrence.items() 1312 } 1313 self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) 1314 fx_quant_output = m_fx(*example_inputs) 1315 self.assertEqual(fx_quant_output, pt2_quant_output) 1316 return m 1317 1318 def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): 1319 # resetting dynamo cache 1320 torch._dynamo.reset() 1321 1322 m = capture_pre_autograd_graph( 1323 m, 1324 example_inputs, 1325 ) 1326 if is_qat: 1327 m = prepare_qat_pt2e(m, quantizer) 1328 else: 1329 m = prepare_pt2e(m, quantizer) 1330 m(*example_inputs) 1331 m = convert_pt2e(m) 1332 return m 1333 1334 def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: 1335 class M(torch.nn.Module): 1336 def __init__(self) -> None: 1337 super().__init__() 1338 self.linear = torch.nn.Linear(2, 2) 1339 1340 def forward(self, x): 1341 return self.linear(x) 1342 1343 quantizer = XNNPACKQuantizer() 1344 operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel) 1345 quantizer.set_global(operator_config) 1346 example_inputs = (torch.randn(2, 2),) 1347 m = M().eval() 1348 return self._quantize(m, quantizer, example_inputs) 1349 1350# Below are a series of toy models to use in testing quantization 1351 1352class SingleLayerLinearModel(torch.nn.Module): 1353 def __init__(self) -> None: 1354 super().__init__() 1355 self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1356 1357 def forward(self, x): 1358 x = self.fc1(x) 1359 return x 1360 1361 def get_example_inputs(self) -> Tuple[Any, ...]: 1362 return (torch.rand(1, 5),) 1363 1364class AnnotatedSingleLayerLinearModel(torch.nn.Module): 1365 def __init__(self, qengine='fbgemm'): 1366 super().__init__() 1367 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1368 self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) 1369 1370 def forward(self, x): 1371 x = self.fc1(x) 1372 return x 1373 1374 def get_example_inputs(self) -> Tuple[Any, ...]: 1375 return (torch.rand(1, 5),) 1376 1377class SingleLayerLinearDynamicModel(torch.nn.Module): 1378 def __init__(self, qengine='fbgemm'): 1379 super().__init__() 1380 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1381 self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1382 1383 def forward(self, x): 1384 x = self.fc1(x) 1385 return x 1386 1387 def get_example_inputs(self) -> Tuple[Any, ...]: 1388 return (torch.rand(1, 5),) 1389 1390class LinearAddModel(nn.Module): 1391 def __init__(self) -> None: 1392 super().__init__() 1393 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1394 self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) 1395 1396 def forward(self, x): 1397 x = self.fc1(x) 1398 x = torch.add(x, 5) 1399 x = self.fc2(x) 1400 return x 1401 1402 def get_example_inputs(self) -> Tuple[Any, ...]: 1403 return (torch.rand(1, 5),) 1404 1405class RNNDynamicModel(torch.nn.Module): 1406 def __init__(self, mod_type): 1407 super().__init__() 1408 self.qconfig = default_dynamic_qconfig 1409 if mod_type == 'GRU': 1410 self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) 1411 if mod_type == 'LSTM': 1412 self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) 1413 1414 def forward(self, x): 1415 x = self.mod(x) 1416 return x 1417 1418class RNNCellDynamicModel(torch.nn.Module): 1419 def __init__(self, mod_type): 1420 super().__init__() 1421 self.qconfig = default_dynamic_qconfig 1422 if mod_type == 'GRUCell': 1423 self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) 1424 if mod_type == 'LSTMCell': 1425 self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) 1426 if mod_type == 'RNNReLU': 1427 self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float) 1428 if mod_type == 'RNNTanh': 1429 self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float) 1430 1431 def forward(self, x): 1432 x = self.mod(x) 1433 return x 1434 1435class LSTMwithHiddenDynamicModel(torch.nn.Module): 1436 def __init__(self, qengine='fbgemm'): 1437 super().__init__() 1438 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1439 self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) 1440 1441 def forward(self, x, hid): 1442 x, hid = self.lstm(x, hid) 1443 return x, hid 1444 1445class ConvModel(torch.nn.Module): 1446 def __init__(self) -> None: 1447 super().__init__() 1448 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1449 1450 def forward(self, x): 1451 x = self.conv(x) 1452 return x 1453 1454 def get_example_inputs(self) -> Tuple[Any, ...]: 1455 return (torch.rand(1, 3, 5, 5),) 1456 1457class ConvTransposeModel(torch.nn.Module): 1458 def __init__(self) -> None: 1459 super().__init__() 1460 self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) 1461 1462 def forward(self, x): 1463 x = self.conv(x) 1464 return x 1465 1466 def get_example_inputs(self) -> Tuple[Any, ...]: 1467 return (torch.rand(1, 3, 5, 5),) 1468 1469class AnnotatedConvModel(torch.nn.Module): 1470 def __init__(self, qengine): 1471 super().__init__() 1472 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1473 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1474 self.quant = QuantStub() 1475 self.dequant = DeQuantStub() 1476 1477 def forward(self, x): 1478 x = self.quant(x) 1479 x = self.conv(x) 1480 x = self.dequant(x) 1481 return x 1482 1483 def get_example_inputs(self) -> Tuple[Any, ...]: 1484 return (torch.rand(1, 3, 5, 5),) 1485 1486class AnnotatedConvTransposeModel(torch.nn.Module): 1487 def __init__(self, qengine): 1488 super().__init__() 1489 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1490 self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float) 1491 self.quant = QuantStub() 1492 self.dequant = DeQuantStub() 1493 1494 def forward(self, x): 1495 x = self.quant(x) 1496 x = self.conv(x) 1497 x = self.dequant(x) 1498 return x 1499 1500 def get_example_inputs(self) -> Tuple[Any, ...]: 1501 return (torch.rand(1, 3, 5, 5),) 1502 1503class ConvBnModel(torch.nn.Module): 1504 def __init__(self) -> None: 1505 super().__init__() 1506 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1507 self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) 1508 1509 def forward(self, x): 1510 x = self.conv(x) 1511 x = self.bn(x) 1512 return x 1513 1514 def get_example_inputs(self) -> Tuple[Any, ...]: 1515 return (torch.rand(1, 3, 5, 5),) 1516 1517class AnnotatedConvBnModel(torch.nn.Module): 1518 def __init__(self) -> None: 1519 super().__init__() 1520 self.qconfig = default_qconfig 1521 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1522 self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) 1523 self.quant = QuantStub() 1524 self.dequant = DeQuantStub() 1525 1526 def forward(self, x): 1527 x = self.quant(x) 1528 x = self.conv(x) 1529 x = self.bn(x) 1530 x = self.dequant(x) 1531 return x 1532 1533 def get_example_inputs(self) -> Tuple[Any, ...]: 1534 return (torch.rand(1, 3, 5, 5),) 1535 1536class ConvBnReLUModel(torch.nn.Module): 1537 def __init__(self) -> None: 1538 super().__init__() 1539 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1540 self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) 1541 self.relu = nn.ReLU(inplace=True) 1542 1543 def forward(self, x): 1544 x = self.conv(x) 1545 x = self.bn(x) 1546 x = self.relu(x) 1547 return x 1548 1549 def get_example_inputs(self) -> Tuple[Any, ...]: 1550 return (torch.rand(1, 3, 5, 5),) 1551 1552class AnnotatedConvBnReLUModel(torch.nn.Module): 1553 def __init__(self, qengine='fbgemm'): 1554 super().__init__() 1555 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 1556 self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1557 self.bn = torch.nn.BatchNorm2d(5).to(dtype=torch.float) 1558 self.relu = nn.ReLU(inplace=True) 1559 self.quant = QuantStub() 1560 self.dequant = DeQuantStub() 1561 1562 def forward(self, x): 1563 x = self.quant(x) 1564 x = self.conv(x) 1565 x = self.bn(x) 1566 x = self.relu(x) 1567 x = self.dequant(x) 1568 return x 1569 1570 def fuse_model(self): 1571 # TODO: remove this check and define two fuse_modules function on this module 1572 if self.training: 1573 torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True) 1574 else: 1575 torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True) 1576 1577 def get_example_inputs(self) -> Tuple[Any, ...]: 1578 return (torch.rand(1, 3, 5, 5),) 1579 1580class TwoLayerConvModel(torch.nn.Module): 1581 def __init__(self) -> None: 1582 super().__init__() 1583 self.conv1 = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) 1584 self.conv2 = torch.nn.Conv2d(5, 5, 1, bias=False).to(dtype=torch.float) 1585 1586 def forward(self, x): 1587 x = self.conv1(x) 1588 x = self.conv2(x) 1589 return x 1590 1591 def get_example_inputs(self) -> Tuple[Any, ...]: 1592 return (torch.rand(1, 3, 5, 5),) 1593 1594class TwoLayerLinearModel(torch.nn.Module): 1595 def __init__(self) -> None: 1596 super().__init__() 1597 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1598 self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) 1599 1600 def forward(self, x): 1601 x = self.fc1(x) 1602 x = self.fc2(x) 1603 return x 1604 1605 def get_example_inputs(self) -> Tuple[Any, ...]: 1606 return (torch.rand(1, 5),) 1607 1608class LinearModelWithSubmodule(nn.Module): 1609 def __init__(self) -> None: 1610 super().__init__() 1611 self.subm = TwoLayerLinearModel() 1612 self.fc = nn.Linear(5, 5) 1613 1614 def forward(self, x): 1615 x = self.subm(x) 1616 x = self.fc(x) 1617 return x 1618 1619 def get_example_inputs(self) -> Tuple[Any, ...]: 1620 return self.subm.get_example_inputs() 1621 1622class AnnotatedTwoLayerLinearModel(torch.nn.Module): 1623 def __init__(self) -> None: 1624 super().__init__() 1625 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1626 self.fc2 = QuantWrapper(torch.nn.Linear(8, 5).to(dtype=torch.float)) 1627 self.fc2.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 1628 1629 def forward(self, x): 1630 x = self.fc1(x) 1631 x = self.fc2(x) 1632 return x 1633 1634 def get_example_inputs(self) -> Tuple[Any, ...]: 1635 return (torch.rand(1, 5),) 1636 1637class ActivationsTestModel(torch.nn.Module): 1638 def __init__(self) -> None: 1639 super().__init__() 1640 self.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 1641 self.quant = torch.ao.quantization.QuantStub() 1642 self.hardswish = torch.nn.Hardswish().to(dtype=torch.float) 1643 self.elu = torch.nn.ELU().to(dtype=torch.float) 1644 self.dequant = torch.ao.quantization.DeQuantStub() 1645 1646 def forward(self, x): 1647 x = self.quant(x) 1648 x = self.hardswish(x) 1649 x = self.elu(x) 1650 x = self.dequant(x) 1651 return x 1652 1653class LinearReluModel(torch.nn.Module): 1654 def __init__(self) -> None: 1655 super().__init__() 1656 self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) 1657 self.relu = torch.nn.ReLU() 1658 1659 def forward(self, x): 1660 x = self.relu(self.fc(x)) 1661 return x 1662 1663 def get_example_inputs(self) -> Tuple[Any, ...]: 1664 return (torch.rand(1, 5),) 1665 1666 1667class LinearReluLinearModel(torch.nn.Module): 1668 def __init__(self) -> None: 1669 super().__init__() 1670 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1671 self.relu = torch.nn.ReLU() 1672 self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) 1673 1674 def forward(self, x): 1675 x = self.fc1(x) 1676 x = self.relu(x) 1677 x = self.fc2(x) 1678 return x 1679 1680 def get_example_inputs(self) -> Tuple[Any, ...]: 1681 return (torch.rand(1, 5),) 1682 1683class LinearReluAddModel(torch.nn.Module): 1684 def __init__(self) -> None: 1685 super().__init__() 1686 self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1687 self.relu = torch.nn.ReLU() 1688 self.fc2 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1689 1690 def forward(self, x): 1691 x = self.fc1(x) 1692 x = self.relu(x) 1693 x = torch.add(x, 5) 1694 x = self.fc2(x) 1695 self.relu = torch.nn.ReLU() 1696 return x 1697 1698 def get_example_inputs(self) -> Tuple[Any, ...]: 1699 return (torch.rand(1, 5),) 1700 1701class LinearBnLeakyReluModel(torch.nn.Module): 1702 def __init__(self, with_bn=True): 1703 super().__init__() 1704 self.linear = nn.Linear(5, 5) 1705 self.bn1d = nn.BatchNorm1d(5) 1706 self.leaky_relu = nn.LeakyReLU(0.01) 1707 self.with_bn = with_bn 1708 1709 def forward(self, x): 1710 x = self.linear(x) 1711 if self.with_bn: 1712 x = self.bn1d(x) 1713 x = self.leaky_relu(x) 1714 return x 1715 1716 def get_example_inputs(self) -> Tuple[Any, ...]: 1717 return (torch.rand(1, 5),) 1718 1719class LinearTanhModel(torch.nn.Module): 1720 def __init__(self) -> None: 1721 super().__init__() 1722 self.linear = nn.Linear(5, 5) 1723 self.tanh = nn.Tanh() 1724 1725 def forward(self, x): 1726 x = self.linear(x) 1727 x = self.tanh(x) 1728 return x 1729 1730 def get_example_inputs(self) -> Tuple[Any, ...]: 1731 return (torch.rand(1, 5),) 1732 1733class ConvBnAddReluModel(torch.nn.Module): 1734 def __init__(self, 1735 with_bn=True, 1736 with_relu=True, 1737 left_conv=True, 1738 two_conv=True, 1739 use_torch_add=True): 1740 super().__init__() 1741 self.conv = nn.Conv2d(5, 5, (2, 2)) 1742 self.conv2 = nn.Conv2d(5, 5, (2, 2)) 1743 self.bn = nn.BatchNorm2d(5) 1744 self.relu = nn.ReLU() 1745 self.with_bn = with_bn 1746 self.with_relu = with_relu 1747 self.two_conv = two_conv 1748 self.left_conv = left_conv 1749 self.use_torch_add = use_torch_add 1750 1751 def forward(self, x1, x2): 1752 if self.two_conv: 1753 if self.use_torch_add: 1754 if self.with_bn: 1755 x = torch.add(self.bn(self.conv(x1)), self.conv2(x1)) 1756 else: 1757 x = torch.add(self.conv(x1), self.conv2(x1)) 1758 else: 1759 if self.with_bn: 1760 x = self.bn(self.conv(x1)) + self.conv2(x1) 1761 else: 1762 x = self.conv(x1) + self.conv2(x1) 1763 else: 1764 if self.use_torch_add: 1765 if self.left_conv: 1766 if self.with_bn: 1767 x = torch.add(self.bn(self.conv(x1)), x2) 1768 else: 1769 x = torch.add(self.conv(x1), x2) 1770 else: 1771 if self.with_bn: 1772 x = torch.add(x2, self.bn(self.conv(x1))) 1773 else: 1774 x = torch.add(x2, self.conv(x1)) 1775 else: 1776 if self.left_conv: 1777 if self.with_bn: 1778 x = self.bn(self.conv(x1)) + x2 1779 else: 1780 x = self.conv(x1) + x2 1781 else: 1782 if self.with_bn: 1783 x = x2 + self.bn(self.conv(x1)) 1784 else: 1785 x = x2 + self.conv(x1) 1786 if self.with_relu: 1787 x = self.relu(x) 1788 return x 1789 1790 def get_example_inputs(self) -> Tuple[Any, ...]: 1791 return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) 1792 1793# TODO: self.fc should be self.conv 1794class ConvReluModel(torch.nn.Module): 1795 def __init__(self) -> None: 1796 super().__init__() 1797 self.fc = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) 1798 self.relu = torch.nn.ReLU() 1799 1800 def forward(self, x): 1801 x = self.relu(self.fc(x)) 1802 return x 1803 1804 def get_example_inputs(self) -> Tuple[Any, ...]: 1805 return (torch.rand(1, 3, 5, 5),) 1806 1807# TODO: self.fc should be self.conv 1808class ConvReluConvModel(torch.nn.Module): 1809 def __init__(self) -> None: 1810 super().__init__() 1811 self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) 1812 self.relu = torch.nn.ReLU() 1813 self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) 1814 1815 def forward(self, x): 1816 x = self.fc1(x) 1817 x = self.relu(x) 1818 x = self.fc2(x) 1819 return x 1820 1821 def get_example_inputs(self) -> Tuple[Any, ...]: 1822 return (torch.rand(1, 3, 5, 5),) 1823 1824# TODO: self.fc should be self.conv 1825class ConvReluAddModel(torch.nn.Module): 1826 def __init__(self) -> None: 1827 super().__init__() 1828 self.fc1 = torch.nn.Conv2d(3, 5, 3).to(dtype=torch.float) 1829 self.relu = torch.nn.ReLU() 1830 self.fc2 = torch.nn.Conv2d(5, 5, 1).to(dtype=torch.float) 1831 1832 def forward(self, x): 1833 x = self.fc1(x) 1834 x = self.relu(x) 1835 x = torch.add(x, 5) 1836 x = self.fc2(x) 1837 self.relu = torch.nn.ReLU() 1838 return x 1839 1840 def get_example_inputs(self) -> Tuple[Any, ...]: 1841 return (torch.rand(1, 3, 5, 5),) 1842 1843class NormalizationTestModel(torch.nn.Module): 1844 def __init__(self) -> None: 1845 super().__init__() 1846 self.quant = torch.ao.quantization.QuantStub() 1847 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1848 self.layer_norm = torch.nn.LayerNorm(8) 1849 self.group_norm = torch.nn.GroupNorm(2, 8) 1850 self.instance_norm1d = torch.nn.InstanceNorm1d(8) 1851 self.instance_norm2d = torch.nn.InstanceNorm2d(8) 1852 self.instance_norm3d = torch.nn.InstanceNorm3d(8) 1853 1854 def forward(self, x): 1855 x = self.quant(x) 1856 x = self.fc1(x) 1857 x = self.layer_norm(x) 1858 x = self.group_norm(x.unsqueeze(-1).repeat(1, 1, 3)) 1859 x = self.instance_norm1d(x) 1860 x = self.instance_norm2d(x.unsqueeze(-1)) 1861 x = self.instance_norm3d(x.unsqueeze(-1)) 1862 return x 1863 1864class NestedModel(torch.nn.Module): 1865 def __init__(self) -> None: 1866 super().__init__() 1867 self.sub1 = LinearReluModel() 1868 self.sub2 = TwoLayerLinearModel() 1869 self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1870 1871 def forward(self, x): 1872 x = self.sub1(x) 1873 x = self.sub2(x) 1874 x = self.fc3(x) 1875 return x 1876 1877class AnnotatedNestedModel(torch.nn.Module): 1878 def __init__(self, qengine): 1879 super().__init__() 1880 self.sub1 = LinearReluModel() 1881 self.sub2 = TwoLayerLinearModel() 1882 self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) 1883 self.fc3.qconfig = default_qconfig 1884 self.sub2.fc1 = QuantWrapper(self.sub2.fc1) 1885 if qengine == 'fbgemm': 1886 self.sub2.fc1.qconfig = default_per_channel_qconfig 1887 else: 1888 self.sub2.fc1.qconfig = default_qconfig 1889 1890 def forward(self, x): 1891 x = self.sub1(x) 1892 x = self.sub2(x) 1893 x = self.fc3(x) 1894 return x 1895 1896class AnnotatedSubNestedModel(torch.nn.Module): 1897 def __init__(self) -> None: 1898 super().__init__() 1899 self.sub1 = LinearReluModel() 1900 self.sub2 = QuantWrapper(TwoLayerLinearModel()) 1901 self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) 1902 self.fc3.qconfig = default_qconfig 1903 self.sub2.qconfig = default_qconfig 1904 1905 def forward(self, x): 1906 x = self.sub1(x) 1907 x = self.sub2(x) 1908 x = self.fc3(x) 1909 return x 1910 1911class AnnotatedCustomConfigNestedModel(torch.nn.Module): 1912 def __init__(self) -> None: 1913 super().__init__() 1914 self.sub1 = LinearReluModel() 1915 self.sub2 = TwoLayerLinearModel() 1916 self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) 1917 self.fc3.qconfig = default_qconfig 1918 self.sub2.qconfig = default_qconfig 1919 1920 custom_options = { 1921 'dtype': torch.quint8, 1922 'qscheme': torch.per_tensor_affine 1923 } 1924 custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options), 1925 weight=default_weight_observer) 1926 self.sub2.fc1.qconfig = custom_qconfig 1927 1928 self.sub2.fc1 = QuantWrapper(self.sub2.fc1) 1929 self.sub2.fc2 = QuantWrapper(self.sub2.fc2) 1930 1931 def forward(self, x): 1932 x = self.sub1(x) 1933 x = self.sub2(x) 1934 x = self.fc3(x) 1935 return x 1936 1937class QuantSubModel(torch.nn.Module): 1938 def __init__(self) -> None: 1939 super().__init__() 1940 self.sub1 = LinearReluModel() 1941 self.sub2 = QuantWrapper(TwoLayerLinearModel()) 1942 self.sub2.qconfig = default_qconfig 1943 self.fc3 = torch.nn.Linear(5, 5).to(dtype=torch.float) 1944 self.fc3.qconfig = default_qconfig 1945 1946 def forward(self, x): 1947 x = self.sub1(x) 1948 x = self.sub2(x) 1949 x = self.fc3(x) 1950 return x 1951 1952class InnerModule(torch.nn.Module): 1953 def __init__(self) -> None: 1954 super().__init__() 1955 self.fc1 = torch.nn.Linear(5, 8).to(dtype=torch.float) 1956 self.relu1 = torch.nn.ReLU() 1957 self.fc2 = torch.nn.Linear(8, 5).to(dtype=torch.float) 1958 self.relu2 = torch.nn.ReLU() 1959 1960 def forward(self, x): 1961 return self.relu2(self.fc2(self.relu1(self.fc1(x)))) 1962 1963 def fuse_modules(self): 1964 fusable_layers = [] 1965 named_children = list(self.named_children()) 1966 for idx, (current_name, layer) in enumerate(named_children): 1967 if isinstance(layer, torch.nn.Linear): 1968 if idx >= len(named_children) - 1: 1969 break 1970 if isinstance(named_children[idx + 1][1], torch.nn.ReLU): 1971 fusable_layers.append([current_name, 1972 named_children[idx + 1][0]]) 1973 # TODO: remove this check and define two fuse_modules function on this module 1974 if self.training: 1975 torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) 1976 else: 1977 torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) 1978 1979class FunctionalLinear(torch.nn.Module): 1980 def __init__(self) -> None: 1981 super().__init__() 1982 self.weight = torch.rand((5, 5)) 1983 self.bias = torch.zeros(5) 1984 1985 def forward(self, x): 1986 return F.linear(x, self.weight, self.bias) 1987 1988 def get_example_inputs(self) -> Tuple[Any, ...]: 1989 return (torch.rand(1, 5),) 1990 1991class SingleLayerFunctionalLinearModel(torch.nn.Module): 1992 def __init__(self) -> None: 1993 super().__init__() 1994 self.linear1 = FunctionalLinear() 1995 1996 def forward(self, x): 1997 x = self.linear1(x) 1998 return x 1999 2000 def get_example_inputs(self) -> Tuple[Any, ...]: 2001 return self.linear1.get_example_inputs() 2002 2003class TwoLayerFunctionalLinearModel(torch.nn.Module): 2004 def __init__(self) -> None: 2005 super().__init__() 2006 self.linear1 = FunctionalLinear() 2007 self.linear2 = FunctionalLinear() 2008 2009 def forward(self, x): 2010 x = self.linear1(x) 2011 x = self.linear2(x) 2012 return x 2013 2014 def get_example_inputs(self) -> Tuple[Any, ...]: 2015 return self.linear1.get_example_inputs() 2016 2017class FunctionalLinearAddModel(torch.nn.Module): 2018 def __init__(self) -> None: 2019 super().__init__() 2020 self.linear1 = FunctionalLinear() 2021 self.linear2 = FunctionalLinear() 2022 2023 def forward(self, x): 2024 x = self.linear1(x) 2025 x = torch.add(x, 5) 2026 x = self.linear2(x) 2027 return x 2028 2029 def get_example_inputs(self) -> Tuple[Any, ...]: 2030 return self.linear1.get_example_inputs() 2031 2032class FunctionalLinearReluModel(nn.Module): 2033 def __init__(self) -> None: 2034 super().__init__() 2035 self.linear = FunctionalLinear() 2036 2037 def forward(self, x): 2038 x = self.linear(x) 2039 x = F.relu(x) 2040 return x 2041 2042 def get_example_inputs(self) -> Tuple[Any, ...]: 2043 return self.linear.get_example_inputs() 2044 2045class FunctionalLinearReluLinearModel(nn.Module): 2046 def __init__(self) -> None: 2047 super().__init__() 2048 self.linear1 = FunctionalLinear() 2049 self.relu = nn.ReLU() 2050 self.linear2 = FunctionalLinear() 2051 2052 def forward(self, x): 2053 x = self.linear1(x) 2054 x = self.relu(x) 2055 x = self.linear2(x) 2056 return x 2057 2058 def get_example_inputs(self) -> Tuple[Any, ...]: 2059 return self.linear1.get_example_inputs() 2060 2061class FunctionalConv2d(torch.nn.Module): 2062 def __init__(self) -> None: 2063 super().__init__() 2064 self.weight = torch.rand(3, 3, 3, 3) 2065 self.bias = torch.rand(3) 2066 self.stride = (1, 1) 2067 self.padding = (0, 0) 2068 self.dilation = (1, 1) 2069 self.groups = 1 2070 2071 def forward(self, x): 2072 return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 2073 2074 def get_example_inputs(self) -> Tuple[Any, ...]: 2075 return (torch.rand(1, 3, 5, 5),) 2076 2077class SingleLayerFunctionalConvModel(torch.nn.Module): 2078 def __init__(self) -> None: 2079 super().__init__() 2080 self.conv1 = FunctionalConv2d() 2081 2082 def forward(self, x): 2083 x = self.conv1(x) 2084 return x 2085 2086 def get_example_inputs(self) -> Tuple[Any, ...]: 2087 return self.conv1.get_example_inputs() 2088 2089class TwoLayerFunctionalConvModel(torch.nn.Module): 2090 def __init__(self) -> None: 2091 super().__init__() 2092 self.conv1 = FunctionalConv2d() 2093 self.conv2 = FunctionalConv2d() 2094 2095 def forward(self, x): 2096 x = self.conv1(x) 2097 x = self.conv2(x) 2098 return x 2099 2100 def get_example_inputs(self) -> Tuple[Any, ...]: 2101 return self.conv1.get_example_inputs() 2102 2103class FunctionalConvReluModel(nn.Module): 2104 def __init__(self) -> None: 2105 super().__init__() 2106 self.conv = FunctionalConv2d() 2107 2108 def forward(self, x): 2109 x = self.conv(x) 2110 x = F.relu(x) 2111 return x 2112 2113 def get_example_inputs(self) -> Tuple[Any, ...]: 2114 return self.conv.get_example_inputs() 2115 2116class FunctionalConvReluConvModel(nn.Module): 2117 def __init__(self) -> None: 2118 super().__init__() 2119 self.conv1 = FunctionalConv2d() 2120 self.relu = nn.ReLU() 2121 self.conv2 = FunctionalConv2d() 2122 2123 def forward(self, x): 2124 x = self.conv1(x) 2125 x = self.relu(x) 2126 x = self.conv2(x) 2127 return x 2128 2129 def get_example_inputs(self) -> Tuple[Any, ...]: 2130 return self.conv1.get_example_inputs() 2131 2132class SkipQuantModel(torch.nn.Module): 2133 r"""We can skip quantization by explicitly 2134 setting qconfig of a submodule to None 2135 """ 2136 def __init__(self) -> None: 2137 super().__init__() 2138 self.sub = InnerModule() 2139 self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) 2140 2141 def forward(self, x): 2142 return self.fc(self.sub(x)) 2143 2144 def fuse_modules(self): 2145 self.sub.fuse_modules() 2146 2147class AnnotatedSkipQuantModel(torch.nn.Module): 2148 r"""We can skip quantization by explicitly 2149 setting qconfig of a submodule to None 2150 """ 2151 def __init__(self, qengine): 2152 super().__init__() 2153 self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) 2154 self.sub = QuantWrapper(InnerModule()) 2155 self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) 2156 # don't quantize this fc 2157 self.fc.qconfig = None 2158 2159 def forward(self, x): 2160 return self.fc(self.sub(x)) 2161 2162 def fuse_modules(self): 2163 self.sub.module.fuse_modules() 2164 2165class QuantStubModel(torch.nn.Module): 2166 r"""A Module with manually inserted `QuantStub` and `DeQuantStub` 2167 """ 2168 def __init__(self) -> None: 2169 super().__init__() 2170 self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 2171 self.quant = QuantStub() 2172 self.dequant = DeQuantStub() 2173 self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) 2174 2175 def forward(self, x): 2176 x = self.quant(x) 2177 x = self.fc(x) 2178 return self.dequant(x) 2179 2180class ManualLinearQATModel(torch.nn.Module): 2181 r"""A Module with manually inserted `QuantStub` and `DeQuantStub` 2182 """ 2183 def __init__(self, qengine): 2184 super().__init__() 2185 self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) 2186 self.quant = QuantStub() 2187 self.dequant = DeQuantStub() 2188 self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) 2189 self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) 2190 2191 def forward(self, x): 2192 x = self.quant(x) 2193 x = self.fc1(x) 2194 x = self.fc2(x) 2195 return self.dequant(x) 2196 2197class ManualDropoutQATModel(torch.nn.Module): 2198 r"""A Module with manually inserted `QuantStub` and `DeQuantStub` 2199 """ 2200 def __init__(self, qengine): 2201 super().__init__() 2202 self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) 2203 self.quant = QuantStub() 2204 self.dequant = DeQuantStub() 2205 self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) 2206 self.dropout = torch.nn.Dropout(0.5) 2207 2208 def forward(self, x): 2209 x = self.quant(x) 2210 x = self.fc1(x) 2211 x = self.dropout(x) 2212 return self.dequant(x) 2213 2214class ManualLinearDynamicQATModel(torch.nn.Module): 2215 r"""A Module that uses a dynamic QAT by default. 2216 """ 2217 def __init__(self, qconfig=None): 2218 super().__init__() 2219 self.qconfig = qconfig or default_dynamic_qat_qconfig 2220 self.fc1 = torch.nn.Linear(5, 1).to(dtype=torch.float) 2221 self.fc2 = torch.nn.Linear(1, 10).to(dtype=torch.float) 2222 2223 def forward(self, x): 2224 x = self.fc1(x) 2225 x = self.fc2(x) 2226 return x 2227 2228class ManualConvLinearQATModel(torch.nn.Module): 2229 r"""A module with manually inserted `QuantStub` and `DeQuantStub` 2230 and contains both linear and conv modules 2231 """ 2232 def __init__(self, qconfig=None): 2233 super().__init__() 2234 self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack") 2235 self.quant = QuantStub() 2236 self.dequant = DeQuantStub() 2237 self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) 2238 self.fc1 = torch.nn.Linear(64, 10).to(dtype=torch.float) 2239 self.fc2 = torch.nn.Linear(10, 10).to(dtype=torch.float) 2240 2241 def forward(self, x): 2242 x = self.quant(x) 2243 x = self.conv(x) 2244 x = x.view(-1, 64).contiguous() 2245 x = self.fc1(x) 2246 x = self.fc2(x) 2247 return self.dequant(x) 2248 2249class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): 2250 r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. 2251 Supported only with qnnpack. 2252 """ 2253 def __init__(self) -> None: 2254 super().__init__(default_symmetric_qnnpack_qat_qconfig) 2255 2256class ManualEmbeddingBagLinear(nn.Module): 2257 def __init__(self) -> None: 2258 super().__init__() 2259 self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') 2260 self.emb.qconfig = default_embedding_qat_qconfig 2261 self.quant = QuantStub() 2262 self.dequant = DeQuantStub() 2263 self.linear = nn.Linear(12, 1).to(dtype=torch.float) 2264 self.qconfig = get_default_qat_qconfig("qnnpack") 2265 2266 def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, 2267 per_sample_weights: Optional[torch.Tensor] = None): 2268 x = self.emb(input, offsets, per_sample_weights) 2269 x = self.quant(x) 2270 x = self.linear(x) 2271 return self.dequant(x) 2272 2273class DeFusedEmbeddingBagLinear(nn.Module): 2274 r"""A module to simulate QAT embedding bag with a linear layer, 2275 this module uses a separate embedding and bagging op, similar 2276 to that which is described in the EmbeddingBag documentation. 2277 2278 https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html 2279 """ 2280 def __init__(self) -> None: 2281 super().__init__() 2282 self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) 2283 self.emb.qconfig = default_embedding_qat_qconfig 2284 self.bagging_op = torch.sum 2285 self.quant = QuantStub() 2286 self.dequant = DeQuantStub() 2287 self.linear = nn.Linear(12, 1).to(dtype=torch.float) 2288 self.qconfig = get_default_qat_qconfig("qnnpack") 2289 2290 def forward(self, input: torch.Tensor) -> torch.Tensor: 2291 x = self.bagging_op(self.emb(input), dim=1) 2292 x = self.quant(x) 2293 x = self.linear(x) 2294 return self.dequant(x) 2295 2296class SubModelForFusion(nn.Module): 2297 def __init__(self) -> None: 2298 super().__init__() 2299 self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) 2300 self.bn = nn.BatchNorm2d(2).to(dtype=torch.float) 2301 2302 def forward(self, x): 2303 x = self.conv(x) 2304 x = self.bn(x) 2305 return x 2306 2307 2308class SubModelWithoutFusion(nn.Module): 2309 def __init__(self) -> None: 2310 super().__init__() 2311 self.conv = nn.Conv2d(2, 2, 1, bias=None).to(dtype=torch.float) 2312 self.relu = nn.ReLU(inplace=False).to(dtype=torch.float) 2313 2314 def forward(self, x): 2315 return self.relu(self.conv(x)) 2316 2317class ModelForFusion(nn.Module): 2318 def __init__(self, qconfig): 2319 super().__init__() 2320 self.conv1 = nn.Conv2d(3, 2, 1, bias=None).to(dtype=torch.float) 2321 self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) 2322 self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) 2323 self.sub1 = SubModelForFusion() 2324 self.sub2 = SubModelWithoutFusion() 2325 self.fc = nn.Linear(36, 10).to(dtype=torch.float) 2326 self.quant = QuantStub() 2327 self.dequant = DeQuantStub() 2328 self.qconfig = qconfig 2329 self.conv2 = nn.Conv3d(3, 2, (1, 1, 1), bias=None).to(dtype=torch.float) 2330 self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float) 2331 self.bn2 = nn.BatchNorm3d(2).to(dtype=torch.float) 2332 self.relu3 = nn.ReLU(inplace=True).to(dtype=torch.float) 2333 self.conv3 = nn.Conv1d(3, 3, 2).to(dtype=torch.float) 2334 self.bn3 = nn.BatchNorm1d(3).to(dtype=torch.float) 2335 self.relu4 = nn.ReLU(inplace=True).to(dtype=torch.float) 2336 # don't quantize sub2 2337 self.sub2.qconfig = None 2338 self.fc.qconfig = None 2339 2340 def forward(self, x): 2341 x = x.squeeze(2) 2342 x = self.quant(x) 2343 x = self.conv3(x) 2344 x = self.bn3(x) 2345 x = self.relu4(x) 2346 x = x.unsqueeze(2) 2347 y = x.unsqueeze(2) 2348 x = self.conv1(x) 2349 x = self.bn1(x) 2350 x = self.relu1(x) 2351 x = self.sub1(x) 2352 x = self.dequant(x) 2353 x = self.sub2(x) 2354 x = x.reshape(-1, 36).contiguous() 2355 x = self.fc(x) 2356 y = self.conv2(y) 2357 y = self.relu2(y) 2358 y = self.bn2(y) 2359 y = self.relu3(y) 2360 y = self.dequant(y) 2361 return x 2362 2363class ConvBNReLU(nn.Sequential): 2364 def __init__(self) -> None: 2365 super().__init__( 2366 nn.Conv2d(3, 3, 1, 1, bias=False), 2367 nn.BatchNorm2d(3), 2368 nn.ReLU(inplace=False) 2369 ) 2370 2371class ModelWithSequentialFusion(nn.Module): 2372 def __init__(self) -> None: 2373 super().__init__() 2374 self.conv1 = nn.Conv2d(3, 3, 1) 2375 self.relu1 = nn.ReLU(inplace=False) 2376 layers = [] 2377 for i in range(3): 2378 layers.append(ConvBNReLU()) 2379 self.features = nn.Sequential(*layers) 2380 head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] 2381 self.classifier = nn.Sequential(*head) 2382 self.seq = nn.Sequential() 2383 self.quant = QuantStub() 2384 self.dequant = DeQuantStub() 2385 2386 def forward(self, x): 2387 x = self.quant(x) 2388 x = self.conv1(x) 2389 x = self.relu1(x) 2390 x = self.features(x) 2391 x = torch.reshape(x, (-1, 3 * 10 * 10)) 2392 x = self.classifier(x) 2393 x = self.seq(x) 2394 x = self.dequant(x) 2395 return x 2396 2397class ModelForFusionWithBias(nn.Module): 2398 def __init__(self) -> None: 2399 super().__init__() 2400 self.conv1 = nn.Conv2d(3, 2, 5, bias=True).to(dtype=torch.float) 2401 self.bn1 = nn.BatchNorm2d(2).to(dtype=torch.float) 2402 self.relu1 = nn.ReLU(inplace=True).to(dtype=torch.float) 2403 self.conv2 = nn.Conv2d(2, 2, 1, bias=True).to(dtype=torch.float) 2404 self.bn2 = nn.BatchNorm2d(2).to(dtype=torch.float) 2405 self.quant = QuantStub() 2406 self.dequant = DeQuantStub() 2407 2408 def forward(self, x): 2409 x = self.quant(x) 2410 x = self.conv1(x) 2411 x = self.bn1(x) 2412 x = self.relu1(x) 2413 x = self.conv2(x) 2414 x = self.bn2(x) 2415 x = self.dequant(x) 2416 return x 2417 2418class ModelForLinearBNFusion(nn.Module): 2419 def __init__(self) -> None: 2420 super().__init__() 2421 self.fc = nn.Linear(20, 10) 2422 self.bn = nn.BatchNorm1d(10) 2423 nn.init.uniform_(self.bn.weight) 2424 nn.init.uniform_(self.bn.bias) 2425 2426 def forward(self, x): 2427 return self.bn(self.fc(x)) 2428 2429class DummyObserver(torch.nn.Module): 2430 def calculate_qparams(self): 2431 return 1.0, 0 2432 2433 def forward(self, x): 2434 return x 2435 2436 2437class ModelForConvTransposeBNFusion(nn.Module): 2438 def __init__(self) -> None: 2439 super().__init__() 2440 self.conv1 = nn.ConvTranspose1d(3, 3, 1) 2441 self.bn1 = nn.BatchNorm1d(3) 2442 self.conv2 = nn.ConvTranspose2d(3, 3, 1) 2443 self.bn2 = nn.BatchNorm2d(3) 2444 self.conv3 = nn.ConvTranspose3d(3, 3, 1) 2445 self.bn3 = nn.BatchNorm3d(3) 2446 2447 def forward(self, x): 2448 x = self.conv1(x) 2449 x = self.bn1(x) 2450 x = x.unsqueeze(2) 2451 x = self.conv2(x) 2452 x = self.bn2(x) 2453 x = x.unsqueeze(2) 2454 x = self.conv3(x) 2455 x = self.bn3(x) 2456 return x 2457 2458 2459class ModelWithFunctionals(torch.nn.Module): 2460 def __init__(self) -> None: 2461 super().__init__() 2462 self.mycat = nnq.FloatFunctional() 2463 self.myadd = nnq.FloatFunctional() 2464 self.myadd_relu = nnq.FloatFunctional() 2465 self.mymatmul = nnq.FloatFunctional() 2466 # Tracing doesnt work yet for c10 ops with scalar inputs 2467 # https://github.com/pytorch/pytorch/issues/27097 2468 # self.my_scalar_add = nnq.FloatFunctional() 2469 # self.my_scalar_mul = nnq.FloatFunctional() 2470 2471 def forward(self, x): 2472 y = self.mycat.cat([x, x, x]) 2473 z = self.myadd.add(y, y) 2474 w = self.myadd_relu.add_relu(z, z) 2475 u = self.mymatmul.matmul(w, w.T) 2476 # Tracing doesnt work yet for c10 ops with scalar inputs 2477 # https://github.com/pytorch/pytorch/issues/27097 2478 # w = self.my_scalar_add.add_scalar(w, -0.5) 2479 # w = self.my_scalar_mul.mul_scalar(w, 0.5) 2480 return u 2481 2482 2483class ResNetBase(torch.nn.Module): 2484 def __init__(self) -> None: 2485 super().__init__() 2486 norm_layer = nn.BatchNorm2d 2487 inplanes = 3 2488 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) 2489 self.bn1 = norm_layer(inplanes) 2490 self.relu1 = nn.ReLU() 2491 self.relu2 = nn.ReLU() 2492 self.downsample = torch.nn.Identity() 2493 self.myop = nn.quantized.FloatFunctional() 2494 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 2495 self.fc = torch.nn.Linear(inplanes, 1) 2496 2497 def forward(self, x): 2498 out = self.conv1(x) 2499 out = self.bn1(out) 2500 out = self.relu1(out) 2501 identity = self.downsample(x) 2502 out = self.myop.add(out, identity) 2503 out = self.relu2(out) 2504 out = self.avgpool(out) 2505 out = torch.flatten(out, 1) 2506 out = self.fc(out) 2507 return out 2508 2509 def fuse_model(self): 2510 # TODO: remove this check and define two fuse_model function on this module 2511 if self.training: 2512 torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True) 2513 else: 2514 torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) 2515 2516class ModelMultipleOps(torch.nn.Module): 2517 def __init__(self) -> None: 2518 super().__init__() 2519 norm_layer = nn.BatchNorm2d 2520 inplanes = 3 2521 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) 2522 self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) 2523 self.bn1 = norm_layer(inplanes) 2524 self.relu1 = nn.ReLU() 2525 self.relu2 = nn.ReLU() 2526 self.downsample = torch.nn.Identity() 2527 self.skip_add = nn.quantized.FloatFunctional() 2528 self.cat = nn.quantized.FloatFunctional() 2529 self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) 2530 self.fc = nn.Linear(12, 6) 2531 2532 def forward(self, x): 2533 out = self.conv1(x) 2534 out = self.bn1(out) 2535 out = self.relu1(out) 2536 identity = self.downsample(x) 2537 out = self.skip_add.add(out, identity) 2538 out = self.relu2(out) 2539 out = self.avgpool(out) 2540 out = self.conv2(out) 2541 out = torch.nn.functional.max_pool2d(out, 2, 2) 2542 out = self.cat.cat([out, out]) 2543 out = out.reshape(-1, 3 * 2 * 2) 2544 out = self.fc(out) 2545 return out 2546 2547# Model to ensure consistency of fake quant with true quant 2548# Average pooling and mean operations are not modelled 2549# accurately with fake-quant so this model does not 2550# contain those operations 2551class ModelMultipleOpsNoAvgPool(torch.nn.Module): 2552 def __init__(self) -> None: 2553 super().__init__() 2554 norm_layer = nn.BatchNorm2d 2555 inplanes = 3 2556 self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) 2557 self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) 2558 self.bn1 = norm_layer(inplanes) 2559 self.relu1 = nn.ReLU() 2560 self.relu2 = nn.ReLU() 2561 self.skip_add = nn.quantized.FloatFunctional() 2562 self.cat = nn.quantized.FloatFunctional() 2563 self.maxpool = nn.MaxPool2d((4, 4)) 2564 self.fc = nn.Linear(12, 6) 2565 2566 def forward(self, x): 2567 out = self.conv1(x) 2568 out = self.bn1(out) 2569 out = self.relu1(out) 2570 skip = self.conv2(x) 2571 out = self.skip_add.add(out, skip) 2572 out = self.relu2(out) 2573 out = self.maxpool(out) 2574 out = self.conv2(out) 2575 out = torch.nn.functional.max_pool2d(out, 2, 2) 2576 out = self.cat.cat([out, out]) 2577 out = out.reshape(-1, 3 * 2 * 2) 2578 out = self.fc(out) 2579 return out 2580 2581class EmbeddingBagModule(torch.nn.Module): 2582 def __init__(self) -> None: 2583 super().__init__() 2584 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, 2585 include_last_offset=True, scale_grad_by_freq=False, mode='sum') 2586 2587 def forward(self, indices, offsets, per_sample_weights): 2588 return self.emb(indices, offsets, per_sample_weights) 2589 2590class EmbeddingModule(torch.nn.Module): 2591 def __init__(self) -> None: 2592 super().__init__() 2593 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) 2594 2595 def forward(self, indices): 2596 return self.emb(indices) 2597 2598class EmbeddingWithStaticLinear(torch.nn.Module): 2599 def __init__(self) -> None: 2600 super().__init__() 2601 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12) 2602 self.fc = torch.nn.Linear(4, 2) 2603 self.emb.qconfig = float_qparams_weight_only_qconfig 2604 self.qconfig = default_qconfig 2605 self.quant = QuantStub() 2606 self.dequant = DeQuantStub() 2607 2608 def forward(self, indices, offsets, linear_in): 2609 emb = self.emb(indices, offsets) 2610 q_x = self.quant(linear_in) 2611 fc = self.fc(q_x) 2612 fc = self.dequant(fc) 2613 features = torch.cat([fc] + [emb], dim=1) 2614 return features 2615 2616class DenseTopMLP(nn.Module): 2617 2618 def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None: 2619 super().__init__() 2620 2621 self.dense_mlp = nn.Sequential( 2622 nn.Linear(dense_dim, dense_out), 2623 ) 2624 self.top_mlp = nn.Sequential( 2625 nn.Linear(dense_out + embedding_dim, top_out_in), 2626 nn.Linear(top_out_in, top_out_out), 2627 ) 2628 2629 def forward( 2630 self, 2631 sparse_feature: torch.Tensor, 2632 dense: torch.Tensor, 2633 ) -> torch.Tensor: 2634 dense_feature = self.dense_mlp(dense) 2635 features = torch.cat([dense_feature] + [sparse_feature], dim=1) 2636 2637 out = self.top_mlp(features) 2638 return out 2639 2640# thin wrapper around embedding bag, because tracing inside nn.Embedding 2641# bag is not supported at the moment and this is top level 2642class EmbBagWrapper(nn.Module): 2643 def __init__(self, num_embeddings, embedding_dim): 2644 super().__init__() 2645 self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum') 2646 2647 def forward(self, indices, offsets): 2648 return self.emb_bag(indices, offsets) 2649 2650class SparseNNModel(nn.Module): 2651 _NUM_EMBEDDINGS = 10 2652 _EMBEDDING_DIM = 5 2653 _DENSE_DIM = 4 2654 _DENSE_OUTPUT = 2 2655 _TOP_OUT_IN = 2 2656 _TOP_OUT_OUT = 2 2657 _TOP_MLP_DIM = 1 2658 2659 def __init__(self) -> None: 2660 super().__init__() 2661 2662 self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) 2663 self.dense_top = DenseTopMLP( 2664 self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN, 2665 self._TOP_OUT_OUT) 2666 2667 def forward( 2668 self, 2669 sparse_indices: torch.Tensor, 2670 sparse_offsets: torch.Tensor, 2671 dense: torch.Tensor, 2672 ) -> torch.Tensor: 2673 2674 sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) 2675 out = self.dense_top(sparse_feature, dense) 2676 2677 return out 2678 2679class TestHelperModules: 2680 class Conv2dPropAnnotaton(torch.nn.Module): 2681 def __init__(self) -> None: 2682 super().__init__() 2683 self.conv = torch.nn.Conv2d(3, 3, 3) 2684 self.linear = torch.nn.Linear(3, 3) 2685 2686 def forward(self, x): 2687 x = self.conv(x) 2688 x = x.view(-1, 3) 2689 x = torch.nn.functional.hardtanh(x, -0.5, 0.5) 2690 x = self.linear(x) 2691 return x 2692 2693 class Conv2dWithObsSharingOps(torch.nn.Module): 2694 def __init__(self) -> None: 2695 super().__init__() 2696 self.conv = torch.nn.Conv2d(3, 3, 3) 2697 self.hardtanh = torch.nn.Hardtanh() 2698 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) 2699 2700 def forward(self, x): 2701 x = self.conv(x) 2702 x = self.adaptive_avg_pool2d(x) 2703 x = self.hardtanh(x) 2704 x = torch.mean(x) 2705 return x 2706 2707 class Conv2dWithTwoLinearPermute(torch.nn.Module): 2708 def __init__(self) -> None: 2709 super().__init__() 2710 self.conv = torch.nn.Conv2d(3, 16, 3) 2711 self.linear1 = torch.nn.Linear(16, 8, bias=False) 2712 self.linear2 = torch.nn.Linear(8, 8) 2713 2714 def forward(self, x): 2715 conv_out = self.conv(x) 2716 permute_out = torch.permute(conv_out, (0, 2, 3, 1)) 2717 return self.linear2(self.linear1(permute_out)) 2718 2719 class Conv2dWithTwoLinear(torch.nn.Module): 2720 def __init__(self) -> None: 2721 super().__init__() 2722 self.conv = torch.nn.Conv2d(3, 16, 3) 2723 self.linear1 = torch.nn.Linear(64, 8, bias=False) 2724 self.linear2 = torch.nn.Linear(8, 8) 2725 2726 def forward(self, x): 2727 conv_out = self.conv(x) 2728 reshape_out = torch.reshape(conv_out, (2, 64)) 2729 return self.linear2(self.linear1(reshape_out)) 2730 2731 class ConvLinearWPermute(torch.nn.Module): 2732 def __init__(self) -> None: 2733 super().__init__() 2734 self.conv = torch.nn.Conv2d(3, 8, 3) 2735 self.linear1 = torch.nn.Linear(8, 8) 2736 2737 def forward(self, x): 2738 conv_out = self.conv(x) 2739 permute_out = torch.permute(conv_out, (0, 2, 3, 1)) 2740 return self.linear1(permute_out) 2741 2742 class TwoLinearModule(torch.nn.Module): 2743 def __init__(self) -> None: 2744 super().__init__() 2745 self.linear1 = torch.nn.Linear(8, 16, bias=False) 2746 self.linear2 = torch.nn.Linear(16, 8) 2747 2748 def forward(self, x): 2749 return self.linear2(self.linear1(x)) 2750 2751 class ConvMaxPool2d(torch.nn.Module): 2752 def __init__(self) -> None: 2753 super().__init__() 2754 self.conv = torch.nn.Conv2d(2, 2, 1) 2755 self.pool = torch.nn.MaxPool2d(1, 1) 2756 2757 def forward(self, x): 2758 x = self.conv(x) 2759 x = self.pool(x) 2760 return x 2761 2762 class ConvWithAdaptiveAvgPool2d(torch.nn.Module): 2763 def __init__(self) -> None: 2764 super().__init__() 2765 self.conv = torch.nn.Conv2d(3, 3, 3) 2766 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) 2767 2768 def forward(self, x): 2769 x = self.conv(x) 2770 x = self.adaptive_avg_pool2d(x) 2771 return x 2772 2773 class ConvWithBNRelu(torch.nn.Module): 2774 def __init__(self, relu, dim=2, bn=True, bias=True): 2775 super().__init__() 2776 convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d} 2777 bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d} 2778 self.conv = convs[dim](3, 3, 3, bias=bias) 2779 2780 if bn: 2781 self.bn = bns[dim](3) 2782 else: 2783 self.bn = torch.nn.Identity() 2784 if relu: 2785 self.relu = torch.nn.ReLU() 2786 else: 2787 self.relu = torch.nn.Identity() 2788 2789 def forward(self, x): 2790 x = self.conv(x) 2791 x = self.bn(x) 2792 return self.relu(x) 2793 2794 class ConvTWithBNRelu(torch.nn.Module): 2795 def __init__(self, relu, dim=2, bn=True, bias=True): 2796 super().__init__() 2797 convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d} 2798 bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d} 2799 self.convt = convts[dim](3, 3, 3, bias=bias) 2800 2801 if bn: 2802 self.bn = bns[dim](3) 2803 else: 2804 self.bn = torch.nn.Identity() 2805 if relu: 2806 self.relu = torch.nn.ReLU() 2807 else: 2808 self.relu = torch.nn.Identity() 2809 2810 def forward(self, x): 2811 x = self.convt(x) 2812 x = self.bn(x) 2813 return self.relu(x) 2814 2815 class Conv2dThenConv1d(torch.nn.Module): 2816 def __init__(self) -> None: 2817 super().__init__() 2818 self.conv1d = torch.nn.Conv1d(3, 3, 3) 2819 self.conv2d = torch.nn.Conv2d(3, 3, 3) 2820 2821 def forward(self, x): 2822 x = self.conv2d(x) 2823 x = x.squeeze(0) 2824 x = self.conv1d(x) 2825 return x 2826 2827 def example_inputs(self): 2828 return (torch.randn(1, 3, 5, 5),) 2829 2830 class Conv2dWithCat(torch.nn.Module): 2831 def __init__(self) -> None: 2832 super().__init__() 2833 self.conv1 = torch.nn.Conv2d(3, 3, 3) 2834 self.conv2 = torch.nn.Conv2d(3, 3, 3) 2835 2836 def forward(self, x, y): 2837 x = self.conv1(x) 2838 y = self.conv2(y) 2839 z = torch.cat([x, y], dim=1) 2840 return z 2841 2842 class Conv2dWithTwoCat(torch.nn.Module): 2843 def __init__(self) -> None: 2844 super().__init__() 2845 self.conv1 = torch.nn.Conv2d(3, 3, 3) 2846 self.conv2 = torch.nn.Conv2d(3, 3, 3) 2847 2848 def forward(self, x1, x2, x3, x4): 2849 x1 = self.conv1(x1) 2850 x2 = self.conv2(x2) 2851 y = torch.cat([x1, x2], dim=1) 2852 z = x3 + x4 2853 w = torch.cat([z, y]) 2854 return w 2855 2856 class ThreeAdd(torch.nn.Module): 2857 def forward(self, x1, x2, x3, x4): 2858 y = x1 + x2 2859 z = x3 + x4 2860 w = y + z 2861 return w 2862 2863 class EmbeddingModule(torch.nn.Module): 2864 def __init__(self) -> None: 2865 super().__init__() 2866 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) 2867 2868 def forward(self, indices): 2869 return self.emb(indices) 2870 2871 class EmbeddingConvLinearModule(torch.nn.Module): 2872 def __init__(self) -> None: 2873 super().__init__() 2874 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=8) 2875 self.conv = torch.nn.Conv2d(8, 16, (1, 3)) 2876 self.linear = torch.nn.Linear(16, 8) 2877 2878 def forward(self, indices): 2879 embeddings = self.emb(indices) 2880 embeddings = torch.unsqueeze(embeddings, dim=0) 2881 embeddings = torch.permute(embeddings, (0, 3, 1, 2)) 2882 conv_out = self.conv(embeddings) 2883 conv_out = torch.permute(conv_out, (0, 2, 3, 1)) 2884 conv_out = torch.squeeze(conv_out, dim=0) 2885 return self.linear(conv_out) 2886 2887 class AddInplaceAdd(torch.nn.Module): 2888 def forward(self, x, y): 2889 x = x + y 2890 x += y 2891 return x 2892 2893 class MulInplaceMul(torch.nn.Module): 2894 def forward(self, x, y): 2895 x = x * y 2896 x *= y 2897 return x 2898 2899 class AddMulScalar(torch.nn.Module): 2900 def forward(self, x): 2901 x = x + 3 2902 x = x * 3 2903 x += 3 2904 x *= 3 2905 return x 2906 2907 class ConvBnReLU2dAndLinearReLU(torch.nn.Module): 2908 def __init__(self) -> None: 2909 super().__init__() 2910 self.conv_bn_relu = TestHelperModules.ConvWithBNRelu(relu=True) 2911 self.linear = torch.nn.Linear(3, 8, bias=False) 2912 self.relu = torch.nn.ReLU() 2913 2914 def forward(self, x): 2915 x = self.conv_bn_relu(x) 2916 permute_out = torch.permute(x, (0, 2, 3, 1)) 2917 linear_out = self.linear(permute_out) 2918 return linear_out 2919 2920 class GroupwiseConv2d(torch.nn.Module): 2921 def __init__(self) -> None: 2922 super().__init__() 2923 self.conv = torch.nn.Conv2d(4, 4, 3, groups=2) 2924 2925 def forward(self, x): 2926 return self.conv(x) 2927 2928 def example_inputs(self): 2929 return (torch.randn(2, 4, 10, 10),) 2930 2931 class LinearReluModel(torch.nn.Module): 2932 def __init__(self) -> None: 2933 super().__init__() 2934 self.fc = torch.nn.Linear(5, 5).to(dtype=torch.float) 2935 self.relu = torch.nn.ReLU() 2936 2937 def forward(self, x): 2938 x = self.relu(self.fc(x)) 2939 return x 2940 2941def _generate_qdq_quantized_model( 2942 mod, inputs, is_qat=False, is_dynamic=False, quantizer=None 2943): 2944 2945 def get_default_quantizer(is_qat, is_dynamic): 2946 quantizer = X86InductorQuantizer() 2947 quantizer.set_global( 2948 xiq.get_default_x86_inductor_quantization_config( 2949 is_qat=is_qat, is_dynamic=is_dynamic 2950 ) 2951 ) 2952 return quantizer 2953 2954 maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() 2955 with maybe_no_grad: 2956 export_model = capture_pre_autograd_graph( 2957 mod, 2958 inputs, 2959 ) 2960 quantizer = ( 2961 quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic) 2962 ) 2963 prepare_model = ( 2964 prepare_qat_pt2e(export_model, quantizer) 2965 if is_qat 2966 else prepare_pt2e(export_model, quantizer) 2967 ) 2968 prepare_model(*inputs) 2969 torch.ao.quantization.move_exported_model_to_eval(prepare_model) 2970 convert_model = convert_pt2e(prepare_model) 2971 return convert_model 2972