1# Owner(s): ["oncall: quantization"] 2 3from collections import OrderedDict 4import contextlib 5import torch 6import torch.nn.functional as F 7import torch.nn as nn 8import torch.ao.nn.quantized as nnq 9import torch.ao.nn.quantized.reference as nnqr 10import torch.ao.nn.quantized.dynamic as nnqd 11import torch.ao.nn.intrinsic as nni 12import torch.ao.nn.intrinsic.quantized as nniq 13import torch.ao.nn.intrinsic.quantized.dynamic as nniqd 14import torch.multiprocessing as mp 15from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY 16 17# graph mode quantization based on fx 18from torch.ao.quantization.quantize_fx import ( 19 prepare_fx, 20 convert_fx, 21 convert_to_reference_fx, 22 _convert_to_reference_decomposed_fx, 23 prepare_qat_fx, 24 fuse_fx, 25) 26 27 28from torch.ao.quantization.fx.quantize_handler import DefaultNodeQuantizeHandler 29 30from torch.ao.quantization.fx.match_utils import ( 31 _is_match, 32 MatchAllNode, 33) 34 35from torch.ao.quantization import ( 36 QuantType, 37) 38from torch.ao.quantization.quant_type import _get_quant_type_to_str 39 40from torch.ao.quantization import ( 41 QuantStub, 42 DeQuantStub, 43 QuantWrapper, 44 default_qconfig, 45 default_dynamic_qconfig, 46 default_per_channel_qconfig, 47 default_qat_qconfig, 48 default_reuse_input_qconfig, 49 default_symmetric_qnnpack_qconfig, 50 default_symmetric_qnnpack_qat_qconfig, 51 per_channel_dynamic_qconfig, 52 float16_dynamic_qconfig, 53 float16_static_qconfig, 54 float_qparams_weight_only_qconfig, 55 float_qparams_weight_only_qconfig_4bit, 56 get_default_qconfig, 57 get_default_qat_qconfig, 58 get_default_qconfig_mapping, 59 get_default_qat_qconfig_mapping, 60 fuse_modules, 61 fuse_modules_qat, 62 prepare, 63 prepare_qat, 64 convert, 65 quantize_dynamic, 66 default_placeholder_observer, 67 default_weight_observer, 68 PerChannelMinMaxObserver, 69 FixedQParamsFakeQuantize, 70 FixedQParamsObserver, 71 FusedMovingAvgObsFakeQuantize, 72 FakeQuantize, 73 MovingAverageMinMaxObserver, 74 HistogramObserver, 75 ReuseInputObserver, 76 QConfig, 77 default_embedding_qat_qconfig, 78) 79 80from torch.ao.quantization.backend_config import ( 81 get_fbgemm_backend_config, 82 get_qnnpack_backend_config, 83 BackendConfig, 84 BackendPatternConfig, 85 DTypeConfig, 86 DTypeWithConstraints, 87 ObservationType 88) 89from torch.ao.quantization.backend_config.native import ( 90 get_test_only_legacy_native_backend_config, 91) 92 93from torch.ao.quantization.qconfig_mapping import ( 94 _get_symmetric_qnnpack_qconfig_mapping, 95 _get_symmetric_qnnpack_qat_qconfig_mapping, 96 _GLOBAL_DICT_KEY, 97 _MODULE_NAME_DICT_KEY, 98 _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, 99 _MODULE_NAME_REGEX_DICT_KEY, 100 _OBJECT_TYPE_DICT_KEY, 101 QConfigMapping, 102) 103 104from torch.ao.quantization.fx.qconfig_mapping_utils import ( 105 _get_object_type_qconfig, 106 _get_module_name_qconfig, 107 _get_module_name_regex_qconfig, 108 _maybe_adjust_qconfig_for_module_name_object_type_order, 109) 110 111from torch.ao.quantization.fx.pattern_utils import ( 112 _DEFAULT_FUSION_PATTERNS, 113 _DEFAULT_QUANTIZATION_PATTERNS, 114 _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP, 115 _DEFAULT_OUTPUT_OBSERVER_MAP, 116 _register_fusion_pattern, 117 _register_quant_pattern, 118 get_default_output_activation_post_process_map 119) 120 121from torch.ao.quantization.fx.custom_config import ( 122 STANDALONE_MODULE_NAME_DICT_KEY, 123 STANDALONE_MODULE_CLASS_DICT_KEY, 124 FLOAT_TO_OBSERVED_DICT_KEY, 125 OBSERVED_TO_QUANTIZED_DICT_KEY, 126 NON_TRACEABLE_MODULE_NAME_DICT_KEY, 127 NON_TRACEABLE_MODULE_CLASS_DICT_KEY, 128 INPUT_QUANTIZED_INDEXES_DICT_KEY, 129 OUTPUT_QUANTIZED_INDEXES_DICT_KEY, 130 PRESERVED_ATTRIBUTES_DICT_KEY, 131 FuseCustomConfig, 132 ConvertCustomConfig, 133 PrepareCustomConfig, 134 StandaloneModuleConfigEntry, 135) 136import torch.ao.quantization.fx.lstm_utils 137 138from torch.ao.quantization.fx.utils import ( 139 _reroute_tuple_getitem_pattern, 140 NodeInfo, 141) 142 143from torch.ao.quantization.fake_quantize import ( 144 default_fixed_qparams_range_0to1_fake_quant, 145 default_fixed_qparams_range_neg1to1_fake_quant, 146) 147 148from torch.ao.quantization.observer import ( 149 default_fixed_qparams_range_0to1_observer, 150 default_fixed_qparams_range_neg1to1_observer, 151 MinMaxObserver, 152 _is_activation_post_process, 153) 154 155# test utils 156from hypothesis import given, settings 157from hypothesis import strategies as st 158from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA 159from torch.testing._internal.common_quantization import ( 160 LinearReluLinearModel, 161 LinearReluModel, 162 LinearBnLeakyReluModel, 163 LinearTanhModel, 164 ConvBnAddReluModel, 165 QuantizationTestCase, 166 skipIfNoFBGEMM, 167 skipIfNoQNNPACK, 168 skip_if_no_torchvision, 169 train_one_epoch, 170 run_ddp, 171 test_only_eval_fn, 172 test_only_train_fn, 173 ModelForConvTransposeBNFusion, 174 get_supported_device_types, 175 skipIfNoONEDNN, 176) 177 178from torch.testing._internal.common_quantization import ( 179 LinearModelWithSubmodule, 180 ResNetBase, 181 RNNDynamicModel, 182 RNNCellDynamicModel, 183) 184 185from torch.testing._internal.common_quantized import ( 186 supported_qengines, 187 override_qengines, 188 override_quantized_engine, 189) 190 191from torch.testing._internal.common_utils import ( 192 TemporaryFileName, 193 IS_ARM64, 194 skipIfTorchDynamo, 195) 196 197from torch.testing._internal.common_quantization import NodeSpec as ns 198 199from torch.testing import FileCheck 200 201import copy 202import itertools 203import operator 204import unittest 205import io 206from typing import Callable, Optional, List, Tuple 207 208class BinaryOp(torch.nn.Module): 209 def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): 210 """ ibinary_op means inplace binary op 211 """ 212 super().__init__() 213 self.conv1 = torch.nn.Conv2d(1, 1, 1).float() 214 self.conv2 = torch.nn.Conv2d(1, 1, 1).float() 215 self.is_scalar = is_scalar 216 self.op = ibinary_op if ibinary_op and is_inplace else binary_op 217 218 def forward(self, x, y): 219 x = self.conv1(x) 220 y = 3 if self.is_scalar else self.conv2(y) 221 # x = x + y 222 x = self.op(x, y) 223 # x = y + x 224 x = self.op(y, x) 225 return x 226 227class BinaryOpNonQuantizedInput(torch.nn.Module): 228 def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): 229 """ ibinary_op means inplace binary op 230 """ 231 super().__init__() 232 self.is_scalar = is_scalar 233 self.op = ibinary_op if ibinary_op and is_inplace else binary_op 234 235 def forward(self, x, y): 236 y = 3 if self.is_scalar else y 237 x = self.op(x, y) 238 return x 239 240class BinaryOpRelu(torch.nn.Module): 241 def __init__(self, binary_op, ibinary_op, is_inplace, relu_callable, 242 is_scalar): 243 """ ibinary_op means inplace binary op 244 """ 245 super().__init__() 246 self.conv1 = torch.nn.Conv2d(1, 1, 1).float() 247 self.conv2 = torch.nn.Conv2d(1, 1, 1).float() 248 self.op = ibinary_op if ibinary_op and is_inplace else binary_op 249 self.relu_callable = relu_callable 250 self.is_scalar = is_scalar 251 if relu_callable is torch.nn.ReLU: 252 self.relu = torch.nn.ReLU() 253 else: 254 self.relu = relu_callable 255 256 def forward(self, x, y): 257 x = self.conv1(x) 258 y = 3 if self.is_scalar else self.conv2(y) 259 x = self.op(x, y) 260 x = self.relu(x) 261 x = self.op(y, x) 262 x = self.relu(x) 263 return x 264 265@torch.fx.wrap 266def _user_func_with_complex_return_type(x): 267 return list(torch.split(x, 1, 1)) 268 269class TestFuseFx(QuantizationTestCase): 270 def test_fuse_conv_bn_relu(self): 271 class M(torch.nn.Module): 272 def __init__(self) -> None: 273 super().__init__() 274 self.conv1d = nn.Conv1d(1, 1, 1) 275 self.conv2d = nn.Conv2d(1, 1, 1) 276 self.conv3d = nn.Conv3d(1, 1, 1) 277 self.bn1d = nn.BatchNorm1d(1) 278 self.bn2d = nn.BatchNorm2d(1) 279 self.bn3d = nn.BatchNorm3d(1) 280 self.conv1d2 = nn.Conv1d(1, 1, 1) 281 self.conv2d2 = nn.Conv2d(1, 1, 1) 282 self.conv3d2 = nn.Conv3d(1, 1, 1) 283 self.bn1d2 = nn.BatchNorm1d(1) 284 self.bn2d2 = nn.BatchNorm2d(1) 285 self.bn3d2 = nn.BatchNorm3d(1) 286 self.relu = nn.ReLU() 287 288 def forward(self, x): 289 x = self.conv1d(x) 290 x = self.bn1d(x) 291 x = self.conv2d(x) 292 x = self.bn2d(x) 293 x = self.conv3d(x) 294 x = self.bn3d(x) 295 x = self.conv1d2(x) 296 x = self.bn1d2(x) 297 x = self.relu(x) 298 x = self.conv2d2(x) 299 x = self.bn2d2(x) 300 x = self.relu(x) 301 x = self.conv3d2(x) 302 x = self.bn3d2(x) 303 x = self.relu(x) 304 return x 305 306 # test train mode 307 m = M().train() 308 # currently we don't check if the module are configured with qconfig before fusion 309 # TODO: if we decide to do that in the future, this test needs to 310 # be updated 311 # train mode fuse_fx is called in prepare_qat_fx 312 m = prepare_qat_fx(m, {}, example_inputs=(torch.randn(1, 1, 1, 1),)) 313 expected_nodes = [ 314 ns.call_module(nni.ConvBn1d), 315 ns.call_module(nni.ConvBn2d), 316 ns.call_module(nni.ConvBn3d), 317 ns.call_module(nni.ConvBnReLU1d), 318 ns.call_module(nni.ConvBnReLU2d), 319 ns.call_module(nni.ConvBnReLU3d), 320 ] 321 expected_occurrence = { 322 ns.call_module(nn.ReLU): 0 323 } 324 self.checkGraphModuleNodes( 325 m, 326 expected_node_list=expected_nodes, 327 expected_node_occurrence=expected_occurrence) 328 329 # test eval mode 330 m = M().eval() 331 # fuse_fx is a top level api and only supports eval mode 332 m = fuse_fx(m) 333 expected_nodes = [ 334 ns.call_module(nn.Conv1d), 335 ns.call_module(nn.Conv2d), 336 ns.call_module(nn.Conv3d), 337 ns.call_module(nni.ConvReLU1d), 338 ns.call_module(nni.ConvReLU2d), 339 ns.call_module(nni.ConvReLU3d), 340 ] 341 # ConvBnRelu1d is not fused 342 expected_occurrence = { 343 ns.call_module(nn.ReLU): 0 344 } 345 self.checkGraphModuleNodes( 346 m, 347 expected_node_list=expected_nodes, 348 expected_node_occurrence=expected_occurrence) 349 350 def test_fuse_linear_bn_eval(self): 351 class M(torch.nn.Module): 352 def __init__(self) -> None: 353 super().__init__() 354 self.linear = nn.Linear(1, 1) 355 self.bn1d = nn.BatchNorm1d(1) 356 357 def forward(self, x): 358 x = self.linear(x) 359 x = self.bn1d(x) 360 return x 361 362 # test eval mode 363 m = M().eval() 364 # fuse_fx is a top level api and only supports eval mode 365 m = fuse_fx(m) 366 expected_nodes = [ 367 ns.call_module(nn.Linear), 368 ] 369 expected_occurrence = { 370 ns.call_module(nn.BatchNorm1d): 0, 371 } 372 self.checkGraphModuleNodes( 373 m, 374 expected_node_list=expected_nodes, 375 expected_node_occurrence=expected_occurrence) 376 377 @skipIfNoONEDNN 378 def test_fuse_linear_bn_leaky_relu_onednn(self): 379 # linear - bn - leaky_relu is fused for onednn backend only 380 from torch.ao.quantization.backend_config import get_onednn_backend_config 381 expected_nodes = [ 382 ns.call_module(nni.LinearLeakyReLU), 383 ] 384 expected_occurrence = { 385 ns.call_module(nn.BatchNorm1d): 0, 386 ns.call_module(nn.LeakyReLU): 0, 387 } 388 389 for with_bn in [True, False]: 390 # test eval mode 391 m = LinearBnLeakyReluModel(with_bn).eval() 392 # fuse_fx is a top level api and only supports eval mode 393 m = fuse_fx(m, 394 backend_config=get_onednn_backend_config()) 395 self.checkGraphModuleNodes( 396 m, 397 expected_node_list=expected_nodes, 398 expected_node_occurrence=expected_occurrence) 399 400 def test_linear_bn_leaky_relu_not_fused_by_default(self): 401 # Make sure linear - bn - leaky_relu is not fused by default 402 for with_bn in [True, False]: 403 # test eval mode 404 m = LinearBnLeakyReluModel(with_bn).eval() 405 # fuse_fx is a top level api and only supports eval mode 406 m = fuse_fx(m) 407 expected_nodes = [ 408 ns.call_module(nn.Linear), 409 ns.call_module(nn.LeakyReLU), 410 ] 411 expected_occurrence = { 412 ns.call_module(nni.LinearLeakyReLU): 0, 413 } 414 self.checkGraphModuleNodes( 415 m, 416 expected_node_list=expected_nodes, 417 expected_node_occurrence=expected_occurrence) 418 419 @skipIfNoONEDNN 420 def test_fuse_linear_tanh_for_onednn_backend(self): 421 # linear - tanh is fused for onednn backend only 422 from torch.ao.quantization.backend_config import get_onednn_backend_config 423 expected_nodes = [ 424 ns.call_module(nni.LinearTanh), 425 ] 426 expected_occurrence = { 427 ns.call_module(nn.Linear): 0, 428 ns.call_module(nn.Tanh): 0, 429 } 430 431 # test eval mode 432 m = LinearTanhModel().eval() 433 # fuse_fx is a top level api and only supports eval mode 434 m = fuse_fx(m, 435 backend_config=get_onednn_backend_config()) 436 self.checkGraphModuleNodes( 437 m, 438 expected_node_list=expected_nodes, 439 expected_node_occurrence=expected_occurrence) 440 441 def test_linear_tanh_not_fused_by_default(self): 442 # Make sure linear - tanh is not fused by default 443 # test eval mode 444 m = LinearTanhModel().eval() 445 # fuse_fx is a top level api and only supports eval mode 446 m = fuse_fx(m) 447 expected_nodes = [ 448 ns.call_module(nn.Linear), 449 ns.call_module(nn.Tanh), 450 ] 451 expected_occurrence = { 452 ns.call_module(nni.LinearTanh): 0, 453 } 454 self.checkGraphModuleNodes( 455 m, 456 expected_node_list=expected_nodes, 457 expected_node_occurrence=expected_occurrence) 458 459 def test_fuse_conv_bn_add_relu_onednn(self): 460 # conv - bn - add - relu is fused for onednn backend only 461 from torch.ao.quantization.backend_config import get_onednn_backend_config 462 options = itertools.product( 463 [True, False], # with_bn 464 [True, False], # with_relu 465 [True, False], # conv in the left 466 [True, False], # with_two_conv 467 [True, False], # use_torch_add 468 ) 469 for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: 470 expected_nodes = [ 471 ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d), 472 ] 473 expected_occurrence = { 474 ns.call_module(nni.ConvAddReLU2d if with_relu else nni.ConvAdd2d): 1, 475 ns.call_module(nn.BatchNorm2d): 0, 476 } 477 478 # test eval mode 479 m = ConvBnAddReluModel( 480 with_bn=with_bn, 481 with_relu=with_relu, 482 left_conv=left_conv, 483 two_conv=two_conv, 484 use_torch_add=use_torch_add).eval() 485 486 m = fuse_fx(m, 487 backend_config=get_onednn_backend_config()) 488 self.checkGraphModuleNodes( 489 m, 490 expected_node_list=expected_nodes, 491 expected_node_occurrence=expected_occurrence) 492 493 def test_fuse_conv_bn_add_relu_by_default(self): 494 options = itertools.product( 495 [True, False], # with_bn 496 [True, False], # with_relu 497 [True, False], # conv in the left 498 [True, False], # with_two_conv 499 [True, False], # use_torch_add 500 ) 501 for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: 502 # test eval mode 503 expected_nodes = [ 504 ns.call_module(nn.Conv2d), 505 ] 506 expected_occurrence = { 507 ns.call_module(nni.ConvAdd2d): 0, 508 } 509 m = ConvBnAddReluModel( 510 with_bn=with_bn, 511 with_relu=with_relu, 512 left_conv=left_conv, 513 two_conv=two_conv, 514 use_torch_add=use_torch_add).eval() 515 m = fuse_fx(m) 516 self.checkGraphModuleNodes( 517 m, 518 expected_node_list=expected_nodes, 519 expected_node_occurrence=expected_occurrence) 520 521 @skipIfNoONEDNN 522 def test_fuse_conv_bn_add_relu_lowering(self): 523 """ Test fusion and lowering of Conv2d - (bn -) ReLU 524 by FX. For onednn backedn only. 525 """ 526 from torch.ao.quantization.backend_config import get_onednn_backend_config 527 qconfig_mapping = get_default_qconfig_mapping('onednn') 528 with override_quantized_engine('onednn'): 529 options = itertools.product( 530 [True, False], # with_bn 531 [True, False], # with_relu 532 [True, False], # conv in the left 533 [True, False], # two_conv 534 [True, False], # use_torch_add 535 ) 536 for with_bn, with_relu, left_conv, two_conv, use_torch_add in options: 537 node_occurrence = { 538 ns.call_function(torch.quantize_per_tensor): 1 if two_conv else 2, 539 ns.call_method("dequantize"): 1, 540 ns.call_module(nniq.ConvAddReLU2d if with_relu else nniq.ConvAdd2d): 1, 541 ns.call_module(nn.Conv2d): 0, 542 ns.call_module(nn.ReLU): 0, 543 } 544 node_occurrence_ref = { 545 ns.call_function(torch.quantize_per_tensor): 3, 546 ns.call_method("dequantize"): 3, 547 } 548 549 # test eval mode 550 m = ConvBnAddReluModel( 551 with_bn=with_bn, 552 with_relu=with_relu, 553 left_conv=left_conv, 554 two_conv=two_conv, 555 use_torch_add=use_torch_add).eval() 556 example_x = m.get_example_inputs() 557 m = prepare_fx(m, qconfig_mapping, 558 example_inputs=example_x, 559 backend_config=get_onednn_backend_config()) 560 m_copy = copy.deepcopy(m) 561 m = convert_fx(m, backend_config=get_onednn_backend_config()) 562 m_ref = convert_to_reference_fx(m_copy) 563 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 564 self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) 565 m(*example_x) 566 567 def test_fuse_convtranspose_bn_eval(self): 568 569 m = ModelForConvTransposeBNFusion().eval() 570 m = fuse_fx(m) 571 572 expected_nodes = [ 573 ns.call_module(nn.ConvTranspose1d), 574 ns.call_module(nn.ConvTranspose2d), 575 ns.call_module(nn.ConvTranspose3d), 576 ] 577 expected_occurrence = { 578 ns.call_module(nn.BatchNorm1d): 0, 579 ns.call_module(nn.BatchNorm2d): 0, 580 ns.call_module(nn.BatchNorm3d): 0, 581 } 582 self.checkGraphModuleNodes( 583 m, 584 expected_node_list=expected_nodes, 585 expected_node_occurrence=expected_occurrence) 586 587 588 def test_fuse_module_relu(self): 589 class M(torch.nn.Module): 590 def __init__(self) -> None: 591 super().__init__() 592 self.conv1d = nn.Conv1d(1, 1, 1) 593 self.conv2d = nn.Conv2d(1, 1, 1) 594 self.conv3d = nn.Conv3d(1, 1, 1) 595 self.bn1d = nn.BatchNorm1d(1) 596 self.bn2d = nn.BatchNorm2d(1) 597 self.bn3d = nn.BatchNorm3d(1) 598 self.relu = nn.ReLU() 599 600 def forward(self, x): 601 x = self.conv1d(x) 602 x = self.relu(x) 603 x = self.conv2d(x) 604 x = self.relu(x) 605 x = self.conv3d(x) 606 x = self.relu(x) 607 x = self.bn1d(x) 608 x = self.relu(x) 609 x = self.bn2d(x) 610 x = self.relu(x) 611 x = self.bn3d(x) 612 x = self.relu(x) 613 return x 614 615 m = M().eval() 616 m = fuse_fx(m) 617 expected_nodes = [ 618 ns.call_module(nni.ConvReLU1d), 619 ns.call_module(nni.ConvReLU2d), 620 ns.call_module(nni.ConvReLU3d), 621 ns.call_module(nni.BNReLU2d), 622 ns.call_module(nni.BNReLU3d), 623 ] 624 self.checkGraphModuleNodes(m, expected_node_list=expected_nodes) 625 626 @skipIfNoFBGEMM 627 def test_qconfig_fused_module(self): 628 """ TODO: add test for all fused modules 629 """ 630 qconfig_dict = { 631 "": None, 632 "object_type": [(nn.Linear, default_qconfig), 633 (nn.ReLU, default_qconfig), 634 (F.relu, default_qconfig)] 635 } 636 637 linearRelu_node_list = [ 638 ns.call_function(torch.quantize_per_tensor), 639 ns.call_module(nniq.LinearReLU), 640 ns.call_method('dequantize') 641 ] 642 643 linearReluLinear_node_list = [ 644 ns.call_function(torch.quantize_per_tensor), 645 ns.call_module(nniq.LinearReLU), 646 ns.call_module(nnq.Linear), 647 ns.call_method('dequantize') 648 ] 649 650 tests = [(LinearReluModel, linearRelu_node_list), 651 (LinearReluLinearModel, linearReluLinear_node_list)] 652 653 for M, node_list in tests: 654 m = M().eval() 655 example_inputs = (torch.rand(5, 5),) 656 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 657 658 prepared(*example_inputs) 659 quantized = convert_fx(prepared) 660 661 self.checkGraphModuleNodes(quantized, expected_node_list=node_list) 662 663 def test_problematic_fuse_example(self): 664 class LinearRelu(nn.Sequential): 665 def __init__(self) -> None: 666 super().__init__( 667 nn.Linear(5, 5), 668 nn.ReLU(), 669 ) 670 671 class M(torch.nn.Module): 672 def __init__(self) -> None: 673 super().__init__() 674 self.lin_relu = LinearRelu() 675 self.linear = nn.Linear(5, 5) 676 677 def forward(self, x): 678 x = self.lin_relu(x) 679 x = self.linear(x) 680 return x 681 682 model = M().eval() 683 # these qconfigs somehow fail equality where default_qconfig does not 684 qconfig_dict = { 685 "": None, 686 "object_type": [ 687 (torch.nn.Linear, get_default_qconfig('fbgemm')), 688 (torch.nn.ReLU, get_default_qconfig('fbgemm')), 689 ], 690 } 691 m = prepare_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) 692 693 self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.ao.nn.intrinsic.modules.fused.LinearReLU)) 694 695 @unittest.skip("Temporarily skipping the test case, will enable after the simple" 696 "pattern format is supported") 697 def test_fuse_addtional_fuser_method(self): 698 class MyConvReLU(torch.nn.Module): 699 pass 700 701 def my_conv_relu_fuser(conv, relu): 702 return MyConvReLU() 703 704 class M(torch.nn.Module): 705 def __init__(self) -> None: 706 super().__init__() 707 self.conv = torch.nn.Conv2d(3, 3, 3) 708 self.relu = torch.nn.ReLU() 709 710 def forward(self, x): 711 return self.relu(self.conv(x)) 712 713 m = M().eval() 714 m = fuse_fx(m, fuse_custom_config={ 715 "additional_fuser_method_mapping": { 716 (torch.nn.Conv2d, torch.nn.ReLU): my_conv_relu_fuser 717 } 718 }) 719 self.checkGraphModuleNodes(m, expected_node=ns.call_module(MyConvReLU)) 720 721 def test_fuse_custom_pattern(self): 722 class M(torch.nn.Module): 723 def __init__(self, use_torch_add=True): 724 super().__init__() 725 self.conv = torch.nn.Conv2d(3, 3, 3) 726 self.bn = torch.nn.BatchNorm2d(3) 727 self.relu = torch.nn.ReLU() 728 self.maxpool = torch.nn.MaxPool2d(3) 729 if use_torch_add: 730 self.add = torch.add 731 else: 732 self.add = operator.add 733 734 def forward(self, x): 735 y = x 736 y = self.maxpool(x) 737 x = self.conv(x) 738 x = self.bn(x) 739 x = self.add(y, x) 740 x = self.relu(x) 741 return x 742 743 for use_torch_add in [True, False]: 744 m = M(use_torch_add).eval() 745 746 def fuse_conv_bn_relu(is_qat, relu, add_pattern): 747 _, _, bn_pattern = add_pattern 748 bn, conv = bn_pattern 749 return conv 750 751 conv_bn_res_relu_config1 = BackendPatternConfig() \ 752 ._set_pattern_complex_format((nn.ReLU, (torch.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \ 753 .set_fuser_method(fuse_conv_bn_relu) 754 conv_bn_res_relu_config2 = BackendPatternConfig() \ 755 ._set_pattern_complex_format((nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))) \ 756 .set_fuser_method(fuse_conv_bn_relu) 757 backend_config = BackendConfig() \ 758 .set_backend_pattern_config(conv_bn_res_relu_config1) \ 759 .set_backend_pattern_config(conv_bn_res_relu_config2) 760 m = fuse_fx(m, backend_config=backend_config) 761 self.assertEqual(type(m.conv), torch.nn.Conv2d) 762 # check bn and relu are gone since we replaced the whole pattern to conv 763 self.assertFalse(hasattr(m, "bn")) 764 self.assertFalse(hasattr(m, "relu")) 765 766 def test_fusion_pattern_with_multiple_inputs(self): 767 """ This test tests two keys in backend_config: root_node_getter and 768 extra_inputs_getter, 769 root_node_getter is used to identify a "root" module in the node pattern, 770 the node that we'll keep after fusion. 771 extra_inputs_getter will return a list of node that needs to be added to the 772 fused node as extra inputs. 773 """ 774 class M(torch.nn.Module): 775 def __init__(self) -> None: 776 super().__init__() 777 self.conv = torch.nn.Conv2d(3, 3, 3) 778 self.bn = torch.nn.BatchNorm2d(3) 779 self.relu = torch.nn.ReLU() 780 self.maxpool = torch.nn.MaxPool2d(3) 781 782 def forward(self, x): 783 y = x 784 y = self.maxpool(x) 785 x = self.conv(x) 786 x = self.bn(x) 787 x = torch.add(x, y) 788 x = self.relu(x) 789 return x 790 791 m = M().eval() 792 793 def fuse_conv_bn_relu(is_qat, relu, add_pattern): 794 _, bn_pattern, _ = add_pattern 795 bn, conv = bn_pattern 796 return conv 797 798 def conv_bn_res_relu_root_node_getter(pattern): 799 relu, add_pattern = pattern 800 _, bn_pattern, _ = add_pattern 801 bn, conv = bn_pattern 802 return conv 803 804 def conv_bn_res_relu_extra_inputs_getter(pattern): 805 """ get inputs pattern for extra inputs, inputs for root node 806 are assumed to be copied over from root node to the fused node 807 """ 808 relu, add_pattern = pattern 809 _, bn_pattern, extra_input = add_pattern 810 bn, conv = bn_pattern 811 return [extra_input] 812 813 conv_bn_res_relu_config = BackendPatternConfig() \ 814 ._set_pattern_complex_format((nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))) \ 815 .set_fuser_method(fuse_conv_bn_relu) \ 816 ._set_root_node_getter(conv_bn_res_relu_root_node_getter) \ 817 ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter) 818 backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config) 819 m = fuse_fx(m, backend_config=backend_config) 820 self.assertEqual(type(m.conv), torch.nn.Conv2d) 821 # check bn and relu are gone since we replaced the whole pattern to conv 822 self.assertFalse(hasattr(m, "bn")) 823 self.assertFalse(hasattr(m, "relu")) 824 825 # check conv module has two inputs 826 named_modules = dict(m.named_modules()) 827 for node in m.graph.nodes: 828 if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d: 829 self.assertTrue(len(node.args) == 2), "Expecting the fused op to have two arguments" 830 831 def test_fusion_pattern_with_matchallnode(self): 832 """This test tests that the node matched by MatchAllNode will be regared as an input 833 instead of a module to be fused. For instance, we have two patterns: 834 (nn.ReLU, (torch.add, MatchAllNode, nn.Conv2d)) 835 (nn.ReLU, nn.Conv2d) 836 And we wanna fuse the following model 837 Conv2d -> ReLU + 838 Conv2d ------ Add -> ReLU 839 ReLU in the first row is matched as MatchAllNode in the residual pattern. But it won't be 840 fused as part of that pattnern. It needs to be properly fused with the upstream Conv2d. 841 """ 842 843 class M(torch.nn.Module): 844 def __init__(self) -> None: 845 super().__init__() 846 self.conv1 = torch.nn.Conv2d(3, 3, 3) 847 self.relu1 = torch.nn.ReLU() 848 self.conv2 = torch.nn.Conv2d(3, 3, 3) 849 self.relu2 = torch.nn.ReLU() 850 851 def forward(self, x): 852 y = self.conv1(x) 853 y = self.relu1(y) 854 855 x = self.conv2(x) 856 x = torch.add(x, y) 857 x = self.relu2(x) 858 return x 859 860 m = M().eval() 861 862 def fuse_conv_relu(is_qat, conv, relu): 863 return conv 864 865 def fuse_conv_res_relu(is_qat, relu, add_pattern): 866 _, conv, _ = add_pattern 867 return conv 868 869 def conv_res_relu_root_node_getter(pattern): 870 relu, (_, conv, _) = pattern 871 return conv 872 873 def conv_res_relu_extra_inputs_getter(pattern): 874 relu, (_, _, extra_input) = pattern 875 return [extra_input] 876 877 conv_relu_config = BackendPatternConfig((nn.Conv2d, nn.ReLU)) \ 878 .set_fuser_method(fuse_conv_relu) 879 conv_res_relu_config = BackendPatternConfig() \ 880 ._set_pattern_complex_format((nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))) \ 881 .set_fuser_method(fuse_conv_res_relu) \ 882 ._set_root_node_getter(conv_res_relu_root_node_getter) \ 883 ._set_extra_inputs_getter(conv_res_relu_extra_inputs_getter) 884 backend_config = BackendConfig() \ 885 .set_backend_pattern_config(conv_relu_config) \ 886 .set_backend_pattern_config(conv_res_relu_config) 887 m = fuse_fx(m, backend_config=backend_config) 888 self.assertEqual(type(m.conv1), torch.nn.Conv2d) 889 self.assertEqual(type(m.conv2), torch.nn.Conv2d) 890 # check relu are gone since we replaced both patterns to conv 891 self.assertFalse(hasattr(m, "relu1")) 892 self.assertFalse(hasattr(m, "relu2")) 893 894 895@skipIfNoFBGEMM 896class TestQuantizeFx(QuantizationTestCase): 897 def test_pattern_match(self): 898 """ test MatchAllNode with 899 conv - bn - add - relu pattern 900 """ 901 class M(torch.nn.Module): 902 def __init__(self) -> None: 903 super().__init__() 904 self.conv = nn.Conv2d(1, 1, 1) 905 self.bn = nn.BatchNorm2d(1) 906 self.relu = nn.ReLU() 907 908 def forward(self, x, y): 909 x = self.conv(x) 910 x = self.bn(x) 911 x = x + y 912 x = self.relu(x) 913 return x 914 915 pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) 916 m = torch.fx.symbolic_trace(M()) 917 modules = dict(m.named_modules()) 918 for n in m.graph.nodes: 919 if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU: 920 self.assertTrue(_is_match(modules, n, pattern)) 921 922 def test_pattern_match_constant(self): 923 class M(torch.nn.Module): 924 def forward(self, x): 925 x, _ = torch.ops.aten.max_pool2d_with_indices.default(x) 926 return x 927 928 pattern = (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) 929 m = torch.fx.symbolic_trace(M()) 930 # eliminate the code that get the second output of maxpool, so that the pattern 931 # can be matched 932 m.graph.eliminate_dead_code() 933 modules = dict(m.named_modules()) 934 for n in m.graph.nodes: 935 if n.op == "call_function" and n.target == operator.getitem: 936 self.assertTrue(_is_match(modules, n, pattern)) 937 938 def test_fused_module_qat_swap(self): 939 class Tmp(torch.nn.Module): 940 def __init__(self) -> None: 941 super().__init__() 942 self.tmp = torch.nn.Linear(5, 5) 943 self.relu = torch.nn.ReLU() 944 945 def forward(self, x): 946 x = self.tmp(x) 947 return self.relu(x) 948 949 950 class M(torch.nn.Module): 951 def __init__(self) -> None: 952 super().__init__() 953 self.mods1 = torch.nn.Sequential(Tmp(), torch.nn.Linear(5, 5)) 954 self.mods2 = torch.nn.Linear(5, 5) 955 956 def forward(self, x): 957 a = self.mods1(x) 958 x = torch.add(x, 5) 959 x = self.mods2(x) 960 x = torch.add(x, 5) 961 return a, x 962 963 964 model = M().train() 965 qconfig_dict = { 966 "": None, 967 "object_type": [ 968 (torch.nn.Linear, default_qat_qconfig), 969 (torch.nn.ReLU, default_qat_qconfig), 970 ], 971 } 972 prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=(torch.randn(1, 5),)) 973 self.assertTrue(isinstance(getattr(prepared.mods1, "0").tmp, torch.ao.nn.intrinsic.qat.LinearReLU)) 974 975 def _get_conv_linear_test_cases(self, is_reference): 976 """ Returns a list of test cases, with format: 977 is_dynamic, ModuleClass, module_constructor_inputs, 978 inputs, quantized_node, weight_prepack_op 979 """ 980 class FunctionalConv1d(torch.nn.Module): 981 def __init__(self, weight): 982 super().__init__() 983 self.weight = torch.nn.Parameter(weight) 984 self.stride = 1 985 self.padding = 0 986 self.dilation = 1 987 self.groups = 1 988 989 def forward(self, x): 990 return F.conv1d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups) 991 992 993 class Conv1d(torch.nn.Module): 994 def __init__(self, *args): 995 super().__init__() 996 self.conv = torch.nn.Conv1d(*args) 997 998 def forward(self, x): 999 return self.conv(x) 1000 1001 conv1d_input = torch.rand(1, 3, 224) 1002 conv1d_weight = torch.rand(3, 3, 3) 1003 conv1d_module_args = (3, 3, 3) 1004 1005 class FunctionalConv2d(torch.nn.Module): 1006 def __init__(self, weight): 1007 super().__init__() 1008 self.weight = torch.nn.Parameter(weight) 1009 self.stride = (1, 1) 1010 self.padding = (0, 0) 1011 self.dilation = (1, 1) 1012 self.groups = 1 1013 1014 def forward(self, x): 1015 return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups) 1016 1017 class Conv2d(torch.nn.Module): 1018 def __init__(self, *args): 1019 super().__init__() 1020 self.conv = torch.nn.Conv2d(*args) 1021 1022 def forward(self, x): 1023 return self.conv(x) 1024 1025 conv2d_input = torch.rand(1, 3, 224, 224) 1026 conv2d_weight = torch.rand(3, 3, 3, 3) 1027 conv2d_module_args = (3, 3, 3) 1028 1029 class FunctionalConv3d(torch.nn.Module): 1030 def __init__(self, weight): 1031 super().__init__() 1032 self.weight = torch.nn.Parameter(weight) 1033 self.stride = (1, 1, 1) 1034 self.padding = (0, 0, 0) 1035 self.dilation = (1, 1, 1) 1036 self.groups = 1 1037 1038 def forward(self, x): 1039 return F.conv3d( 1040 x, 1041 self.weight, 1042 None, 1043 self.stride, 1044 self.padding, 1045 self.dilation, 1046 self.groups, 1047 ) 1048 1049 class Conv3d(torch.nn.Module): 1050 def __init__(self, *args): 1051 super().__init__() 1052 self.conv = torch.nn.Conv3d(*args) 1053 1054 def forward(self, x): 1055 return self.conv(x) 1056 1057 conv3d_input = torch.rand(1, 3, 32, 224, 224) 1058 conv3d_weight = torch.rand(3, 3, 3, 3, 3) 1059 conv3d_module_args = (3, 3, 3) 1060 1061 class Linear(torch.nn.Module): 1062 def __init__(self, weight): 1063 super().__init__() 1064 self.weight = torch.nn.Parameter(weight) 1065 1066 def forward(self, x): 1067 return F.linear(x, self.weight) 1068 1069 linear_input = torch.rand(8, 5) 1070 linear_weight = torch.rand(10, 5) 1071 1072 class LinearModule(torch.nn.Module): 1073 def __init__(self) -> None: 1074 super().__init__() 1075 self.linear = torch.nn.Linear(5, 10) 1076 1077 def forward(self, x): 1078 return self.linear(x) 1079 1080 linear_module_input = torch.rand(8, 5) 1081 1082 # is_dynamic, ModuleClass, module_constructor_inputs, 1083 # inputs, quantized_node, weight_prepack_node 1084 tests = [ 1085 ( 1086 False, 1087 FunctionalConv1d, 1088 (conv1d_weight,), 1089 (conv1d_input,), 1090 ns.call_function(torch.nn.functional.conv1d if is_reference else torch.ops.quantized.conv1d) , 1091 ns.call_function(torch.ops.quantized.conv1d_prepack), 1092 ), 1093 ( 1094 False, 1095 FunctionalConv2d, 1096 (conv2d_weight,), 1097 (conv2d_input,), 1098 ns.call_function(torch.nn.functional.conv2d if is_reference else torch.ops.quantized.conv2d), 1099 ns.call_function(torch.ops.quantized.conv2d_prepack), 1100 ), 1101 ( 1102 False, 1103 FunctionalConv3d, 1104 (conv3d_weight,), 1105 (conv3d_input,), 1106 ns.call_function(torch.nn.functional.conv3d if is_reference else torch.ops.quantized.conv3d), 1107 ns.call_function(torch.ops.quantized.conv3d_prepack), 1108 ), 1109 ( 1110 False, 1111 Conv1d, 1112 conv1d_module_args, 1113 (conv1d_input,), 1114 ns.call_module(nnqr.Conv1d if is_reference else nnq.Conv1d), 1115 None 1116 ), 1117 ( 1118 False, 1119 Conv2d, 1120 conv2d_module_args, 1121 (conv2d_input,), 1122 ns.call_module(nnqr.Conv2d if is_reference else nnq.Conv2d), 1123 None 1124 ), 1125 ( 1126 False, 1127 Conv3d, 1128 conv3d_module_args, 1129 (conv3d_input,), 1130 ns.call_module(nnqr.Conv3d if is_reference else nnq.Conv3d), 1131 None 1132 ), 1133 ( 1134 True, 1135 Linear, 1136 (linear_weight,), 1137 (linear_input,), 1138 None if is_reference else ns.call_function(torch.ops.quantized.linear_dynamic), 1139 ns.call_function(torch.ops.quantized.linear_prepack), 1140 ), 1141 ( 1142 False, 1143 Linear, 1144 (linear_weight,), 1145 (linear_input,), 1146 ns.call_function(torch.nn.functional.linear if is_reference else torch.ops.quantized.linear), 1147 ns.call_function(torch.ops.quantized.linear_prepack), 1148 ), 1149 ( 1150 True, 1151 LinearModule, 1152 (), 1153 (linear_module_input,), 1154 ns.call_module(nnqr.Linear) if is_reference else ns.call_module(nnqd.Linear), 1155 None, 1156 ), 1157 ( 1158 False, 1159 LinearModule, 1160 (), 1161 (linear_module_input,), 1162 ns.call_module(nnqr.Linear if is_reference else nnq.Linear), 1163 None, 1164 ), 1165 ] 1166 return tests 1167 1168 @skipIfNoFBGEMM 1169 def test_conv_linear_not_reference(self): 1170 """ Test quantizing conv and linear 1171 """ 1172 tests = self._get_conv_linear_test_cases(is_reference=False) 1173 for (is_dynamic, ModuleClass, module_constructor_inputs, 1174 inputs, quantized_node, weight_prepack_node) in tests: 1175 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1176 node_occurrence = {} 1177 if weight_prepack_node: 1178 node_occurrence[weight_prepack_node] = 0 1179 self.checkGraphModeFxOp( 1180 ModuleClass(*module_constructor_inputs), 1181 inputs, quant_type, 1182 expected_node=quantized_node, 1183 expected_node_occurrence=node_occurrence, 1184 is_reference=False) 1185 1186 @skipIfNoFBGEMM 1187 def test_conv_linear_reference(self): 1188 """ Test quantizing functional conv and linear with reference option 1189 """ 1190 tests = self._get_conv_linear_test_cases(is_reference=True) 1191 1192 def _get_keys(prefix, is_dynamic): 1193 all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] 1194 if not is_dynamic: 1195 all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) 1196 return all_keys 1197 1198 for (is_dynamic, ModuleClass, module_constructor_inputs, 1199 inputs, quantized_node, weight_prepack_node) in tests: 1200 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1201 node_occurrence = {} 1202 if weight_prepack_node: 1203 node_occurrence[weight_prepack_node] = 0 1204 result_dict = self.checkGraphModeFxOp( 1205 ModuleClass(*module_constructor_inputs), 1206 inputs, quant_type, 1207 expected_node=quantized_node, 1208 expected_node_occurrence=node_occurrence, 1209 is_reference=True) 1210 qr = result_dict["quantized_reference"] 1211 1212 def checkWeightQParams(model): 1213 for module_name in ("linear", "conv"): 1214 if hasattr(model, module_name): 1215 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) 1216 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) 1217 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) 1218 self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) 1219 1220 def checkSerDeser(model, is_dynamic): 1221 for module_name in ("linear", "conv"): 1222 if hasattr(model, module_name): 1223 # make sure seralization works 1224 state_dict = copy.deepcopy(model.state_dict()) 1225 all_keys = _get_keys(module_name, is_dynamic) 1226 for key in all_keys: 1227 self.assertTrue(key in state_dict) 1228 # check load_state_dict restores states 1229 module = getattr(model, module_name) 1230 prev_scale = module.weight_scale 1231 module.weight_scale = None 1232 model.load_state_dict(state_dict) 1233 module = getattr(model, module_name) 1234 self.assertTrue(torch.equal(prev_scale, module.weight_scale)) 1235 1236 1237 checkWeightQParams(qr) 1238 qr = copy.deepcopy(qr) 1239 # make sure the qparams are preserved after copy 1240 checkWeightQParams(qr) 1241 1242 checkSerDeser(qr, is_dynamic) 1243 1244 def _get_conv_transpose_test_cases(self, use_relu, is_reference): 1245 """ Returns a list of test cases, with format: 1246 is_dynamic, ModuleClass, module_constructor_inputs, 1247 inputs, quantized_node, weight_prepack_op 1248 """ 1249 class FunctionalConvTranspose1d(torch.nn.Module): 1250 def __init__(self, weight): 1251 super().__init__() 1252 self.weight = torch.nn.Parameter(weight) 1253 self.stride = 1 1254 self.padding = 0 1255 self.output_padding = 0 1256 self.dilation = 1 1257 self.groups = 1 1258 1259 def forward(self, x): 1260 y = F.conv_transpose1d( 1261 x, 1262 self.weight, 1263 None, 1264 self.stride, 1265 self.padding, 1266 self.output_padding, 1267 self.groups, 1268 self.dilation 1269 ) 1270 if use_relu: 1271 y = F.relu(y) 1272 return y 1273 1274 class ConvTranspose1d(torch.nn.Module): 1275 def __init__(self, *args): 1276 super().__init__() 1277 self.deconv = torch.nn.ConvTranspose1d(*args) 1278 self.relu = torch.nn.ReLU() 1279 1280 def forward(self, x): 1281 y = self.deconv(x) 1282 if use_relu: 1283 y = self.relu(y) 1284 return y 1285 1286 conv_transpose1d_input = torch.rand(1, 3, 224) 1287 conv_transpose1d_weight = torch.rand(3, 3, 3) 1288 conv_transpose1d_module_args = (3, 3, 3) 1289 1290 class FunctionalConvTranspose2d(torch.nn.Module): 1291 def __init__(self, weight): 1292 super().__init__() 1293 self.weight = torch.nn.Parameter(weight) 1294 self.stride = (1, 1) 1295 self.padding = (0, 0) 1296 self.output_padding = (0, 0) 1297 self.dilation = (1, 1) 1298 self.groups = 1 1299 1300 def forward(self, x): 1301 y = F.conv_transpose2d( 1302 x, 1303 self.weight, 1304 None, 1305 self.stride, 1306 self.padding, 1307 self.output_padding, 1308 self.groups, 1309 self.dilation 1310 ) 1311 if use_relu: 1312 y = F.relu(y) 1313 return y 1314 1315 class ConvTranspose2d(torch.nn.Module): 1316 def __init__(self, *args): 1317 super().__init__() 1318 self.deconv = torch.nn.ConvTranspose2d(*args) 1319 self.relu = torch.nn.ReLU() 1320 1321 def forward(self, x): 1322 y = self.deconv(x) 1323 if use_relu: 1324 y = self.relu(y) 1325 return y 1326 1327 conv_transpose2d_input = torch.rand(1, 3, 224, 224) 1328 conv_transpose2d_weight = torch.rand(3, 3, 3, 3) 1329 conv_transpose2d_module_args = (3, 3, 3) 1330 1331 class FunctionalConvTranspose3d(torch.nn.Module): 1332 def __init__(self, weight): 1333 super().__init__() 1334 self.weight = torch.nn.Parameter(weight) 1335 self.stride = (1, 1, 1) 1336 self.padding = (0, 0, 0) 1337 self.output_padding = (0, 0, 0) 1338 self.dilation = (1, 1, 1) 1339 self.groups = 1 1340 1341 def forward(self, x): 1342 y = F.conv_transpose3d( 1343 x, 1344 self.weight, 1345 None, 1346 self.stride, 1347 self.padding, 1348 self.output_padding, 1349 self.groups, 1350 self.dilation 1351 ) 1352 if use_relu: 1353 y = F.relu(y) 1354 return y 1355 1356 class ConvTranspose3d(torch.nn.Module): 1357 def __init__(self, *args): 1358 super().__init__() 1359 self.deconv = torch.nn.ConvTranspose3d(*args) 1360 self.relu = torch.nn.ReLU() 1361 1362 def forward(self, x): 1363 y = self.deconv(x) 1364 if use_relu: 1365 y = self.relu(y) 1366 return y 1367 1368 conv_transpose3d_input = torch.rand(1, 3, 32, 224, 224) 1369 conv_transpose3d_weight = torch.rand(3, 3, 3, 3, 3) 1370 conv_transpose3d_module_args = (3, 3, 3) 1371 1372 # is_dynamic, ModuleClass, module_constructor_inputs, 1373 # inputs, quantized_node, weight_prepack_node 1374 tests = [ 1375 ( 1376 False, 1377 FunctionalConvTranspose1d, 1378 (conv_transpose1d_weight,), 1379 (conv_transpose1d_input,), 1380 ns.call_function( 1381 torch.nn.functional.conv_transpose1d if is_reference else torch.ops.quantized.conv_transpose1d 1382 ), 1383 ns.call_function(torch.ops.quantized.conv_transpose1d_prepack), 1384 ), 1385 ( 1386 False, 1387 FunctionalConvTranspose2d, 1388 (conv_transpose2d_weight,), 1389 (conv_transpose2d_input,), 1390 ns.call_function( 1391 torch.nn.functional.conv_transpose2d if is_reference else torch.ops.quantized.conv_transpose2d 1392 ), 1393 ns.call_function(torch.ops.quantized.conv_transpose2d_prepack), 1394 ), 1395 ( 1396 False, 1397 FunctionalConvTranspose3d, 1398 (conv_transpose3d_weight,), 1399 (conv_transpose3d_input,), 1400 ns.call_function( 1401 torch.nn.functional.conv_transpose3d if is_reference else torch.ops.quantized.conv_transpose3d), 1402 ns.call_function(torch.ops.quantized.conv_transpose3d_prepack), 1403 ), 1404 ( 1405 False, 1406 ConvTranspose1d, 1407 conv_transpose1d_module_args, 1408 (conv_transpose1d_input,), 1409 ns.call_module(nnqr.ConvTranspose1d if is_reference else nnq.ConvTranspose1d), 1410 None 1411 ), 1412 ( 1413 False, 1414 ConvTranspose2d, 1415 conv_transpose2d_module_args, 1416 (conv_transpose2d_input,), 1417 ns.call_module(nnqr.ConvTranspose2d if is_reference else nnq.ConvTranspose2d), 1418 None 1419 ), 1420 ( 1421 False, 1422 ConvTranspose3d, 1423 conv_transpose3d_module_args, 1424 (conv_transpose3d_input,), 1425 ns.call_module(nnqr.ConvTranspose3d if is_reference else nnq.ConvTranspose3d), 1426 None 1427 ), 1428 ] 1429 return tests 1430 1431 @skipIfNoFBGEMM 1432 def test_conv_transpose_not_reference(self): 1433 """ Test quantizing transposed conv 1434 """ 1435 tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=False) 1436 for (is_dynamic, ModuleClass, module_constructor_inputs, 1437 inputs, quantized_node, weight_prepack_node) in tests: 1438 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1439 node_occurrence = {} 1440 if weight_prepack_node: 1441 node_occurrence[weight_prepack_node] = 0 1442 self.checkGraphModeFxOp( 1443 ModuleClass(*module_constructor_inputs), 1444 inputs, quant_type, 1445 expected_node=quantized_node, 1446 expected_node_occurrence=node_occurrence, 1447 is_reference=False) 1448 1449 @skipIfNoFBGEMM 1450 def test_conv_transpose_reference(self): 1451 """ Test quantizing transposed conv with reference option 1452 """ 1453 tests = self._get_conv_transpose_test_cases(use_relu=False, is_reference=True) 1454 1455 def _get_keys(prefix, is_dynamic): 1456 all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] 1457 if not is_dynamic: 1458 all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) 1459 return all_keys 1460 1461 for (is_dynamic, ModuleClass, module_constructor_inputs, 1462 inputs, quantized_node, weight_prepack_node) in tests: 1463 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1464 node_occurrence = {} 1465 if weight_prepack_node: 1466 node_occurrence[weight_prepack_node] = 0 1467 result_dict = self.checkGraphModeFxOp( 1468 ModuleClass(*module_constructor_inputs), 1469 inputs, quant_type, 1470 expected_node=quantized_node, 1471 expected_node_occurrence=node_occurrence, 1472 is_reference=True) 1473 qr = result_dict["quantized_reference"] 1474 1475 def checkWeightQParams(model): 1476 module_name = "deconv" 1477 if hasattr(model, module_name): 1478 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) 1479 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) 1480 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) 1481 self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) 1482 1483 def checkSerDeser(model, is_dynamic): 1484 module_name = "deconv" 1485 if hasattr(model, module_name): 1486 # make sure seralization works 1487 state_dict = copy.deepcopy(model.state_dict()) 1488 all_keys = _get_keys(module_name, is_dynamic) 1489 for key in all_keys: 1490 self.assertTrue(key in state_dict) 1491 # check load_state_dict restores states 1492 module = getattr(model, module_name) 1493 prev_scale = module.weight_scale 1494 module.weight_scale = None 1495 model.load_state_dict(state_dict) 1496 module = getattr(model, module_name) 1497 self.assertTrue(torch.equal(prev_scale, module.weight_scale)) 1498 1499 1500 checkWeightQParams(qr) 1501 qr = copy.deepcopy(qr) 1502 # make sure the qparams are preserved after copy 1503 checkWeightQParams(qr) 1504 1505 checkSerDeser(qr, is_dynamic) 1506 1507 def test_conv_transpose_relu_not_reference(self): 1508 """ Test quantizing transposed conv + relu 1509 Fusion with relu is not supported. 1510 """ 1511 tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=False) 1512 for (is_dynamic, ModuleClass, module_constructor_inputs, 1513 inputs, quantized_node, weight_prepack_node) in tests: 1514 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1515 node_occurrence = {} 1516 if weight_prepack_node: 1517 node_occurrence[weight_prepack_node] = 0 1518 if quantized_node.op == 'call_module': 1519 node_occurrence[ns.call_module(nn.ReLU)] = 1 1520 else: 1521 node_occurrence[ns.call_function(F.relu)] = 1 1522 self.checkGraphModeFxOp( 1523 ModuleClass(*module_constructor_inputs), 1524 inputs, quant_type, 1525 expected_node=quantized_node, 1526 expected_node_occurrence=node_occurrence, 1527 is_reference=False) 1528 1529 @skipIfNoFBGEMM 1530 def test_conv_transpose_relu_reference(self): 1531 """ Test quantizing transposed conv with reference option 1532 Fusion with relu is not supported. 1533 """ 1534 tests = self._get_conv_transpose_test_cases(use_relu=True, is_reference=True) 1535 1536 def _get_keys(prefix, is_dynamic): 1537 all_keys = [prefix + "." + k for k in ["weight_qscheme", "weight_dtype"]] 1538 if not is_dynamic: 1539 all_keys.extend([prefix + "." + k for k in ["weight_scale", "weight_zero_point"]]) 1540 return all_keys 1541 1542 for (is_dynamic, ModuleClass, module_constructor_inputs, 1543 inputs, quantized_node, weight_prepack_node) in tests: 1544 quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC 1545 node_occurrence = {} 1546 if weight_prepack_node: 1547 node_occurrence[weight_prepack_node] = 0 1548 if quantized_node.op == 'call_module': 1549 node_occurrence[ns.call_module(nn.ReLU)] = 1 1550 else: 1551 node_occurrence[ns.call_function(F.relu)] = 1 1552 result_dict = self.checkGraphModeFxOp( 1553 ModuleClass(*module_constructor_inputs), 1554 inputs, quant_type, 1555 expected_node=quantized_node, 1556 expected_node_occurrence=node_occurrence, 1557 is_reference=True) 1558 qr = result_dict["quantized_reference"] 1559 1560 def checkWeightQParams(model): 1561 module_name = "deconv" 1562 if hasattr(model, module_name): 1563 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_qscheme")) 1564 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_scale")) 1565 self.assertTrue(hasattr(qr.get_submodule(module_name), "weight_zero_point")) 1566 self.assertTrue("Reference" in qr.get_submodule(module_name)._get_name()) 1567 1568 def checkSerDeser(model, is_dynamic): 1569 module_name = "deconv" 1570 if hasattr(model, module_name): 1571 # make sure seralization works 1572 state_dict = copy.deepcopy(model.state_dict()) 1573 all_keys = _get_keys(module_name, is_dynamic) 1574 for key in all_keys: 1575 self.assertTrue(key in state_dict) 1576 # check load_state_dict restores states 1577 module = getattr(model, module_name) 1578 prev_scale = module.weight_scale 1579 module.weight_scale = None 1580 model.load_state_dict(state_dict) 1581 module = getattr(model, module_name) 1582 self.assertTrue(torch.equal(prev_scale, module.weight_scale)) 1583 1584 1585 checkWeightQParams(qr) 1586 qr = copy.deepcopy(qr) 1587 # make sure the qparams are preserved after copy 1588 checkWeightQParams(qr) 1589 1590 checkSerDeser(qr, is_dynamic) 1591 1592 @skipIfNoFBGEMM 1593 def test_dynamic_quant_weight_observer(self): 1594 ''' Test that weight observer is run in convert step 1595 ''' 1596 1597 class M(torch.nn.Module): 1598 def __init__(self, weight): 1599 super().__init__() 1600 self.weight = torch.nn.Parameter(weight) 1601 1602 def forward(self, x): 1603 return F.linear(x, self.weight) 1604 1605 m = M(torch.rand(1, 1)).eval() 1606 qconfig = default_dynamic_qconfig 1607 qconfig_dict = {'': qconfig} 1608 example_inputs = (torch.rand(1, 1),) 1609 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 1610 quantized = convert_to_reference_fx(prepared) 1611 qparams = (quantized._scale_0, quantized._zero_point_0) 1612 weight_obs = qconfig.weight() 1613 weight_obs(quantized.weight) 1614 # Get the actual value to avoid tensor size mismatch error, torch.Size([]) vs torch.Size([1]) 1615 ref_qparams = (weight_obs.calculate_qparams()[0].item(), weight_obs.calculate_qparams()[1].item()) 1616 self.assertEqual(qparams, ref_qparams) 1617 1618 def test_conv_bn_relu(self): 1619 """ Tests fusion and quantization for "Conv - Bn" and "Conv - Bn - ReLU" 1620 """ 1621 convs = { 1622 1: nn.Conv1d, 1623 2: nn.Conv2d, 1624 3: nn.Conv3d, 1625 } 1626 bns = { 1627 1: nn.BatchNorm1d, 1628 2: nn.BatchNorm2d, 1629 3: nn.BatchNorm3d, 1630 } 1631 quantized_convs = { 1632 1: nnq.Conv1d, 1633 2: nnq.Conv2d, 1634 3: nnq.Conv3d, 1635 } 1636 quantized_conv_relus = { 1637 1: nniq.ConvReLU1d, 1638 2: nniq.ConvReLU2d, 1639 3: nniq.ConvReLU3d, 1640 } 1641 1642 class M(torch.nn.Module): 1643 def __init__(self, dim, has_relu): 1644 super().__init__() 1645 self.conv = convs[dim](3, 3, 3) 1646 self.bn = bns[dim](3) 1647 self.relu = nn.ReLU() if has_relu else nn.Identity() 1648 self.has_relu = has_relu 1649 self.quant = QuantStub() 1650 self.dequant = DeQuantStub() 1651 1652 def forward(self, x): 1653 x = self.quant(x) 1654 x = self.conv(x) 1655 x = self.bn(x) 1656 if self.has_relu: 1657 x = self.relu(x) 1658 x = self.dequant(x) 1659 return x 1660 1661 options = itertools.product([1, 2, 3], [True, False], self.static_quant_types) 1662 for dim, has_relu, quant_type in options: 1663 expected_node = ns.call_module( 1664 quantized_conv_relus[dim] if has_relu 1665 else quantized_convs[dim]) 1666 m = M(dim, has_relu) 1667 m_eager = copy.deepcopy(m) 1668 result_dict = self.checkGraphModeFxOp( 1669 m, 1670 self.img_data_dict[dim], 1671 quant_type, 1672 expected_node=expected_node, 1673 ) 1674 result = result_dict["quantized_output"] 1675 1676 # check numerics 1677 qengine = torch.backends.quantized.engine 1678 if quant_type == QuantType.STATIC: 1679 m_eager.eval() 1680 qconfig = get_default_qconfig(qengine) 1681 prepare_fn = prepare 1682 is_qat = False 1683 else: 1684 m_eager.train() 1685 qconfig = get_default_qat_qconfig(qengine) 1686 prepare_fn = prepare_qat 1687 is_qat = True 1688 1689 fuse_list = ["conv", "bn"] 1690 if has_relu: 1691 fuse_list.append("relu") 1692 if is_qat: 1693 fuse_modules_qat(m_eager, fuse_list, inplace=True) 1694 else: 1695 fuse_modules(m_eager, fuse_list, inplace=True) 1696 m_eager.qconfig = qconfig 1697 m_eager = prepare_fn(m_eager) 1698 prepared_fx = result_dict["prepared"] 1699 1700 m_eager(*self.img_data_dict[dim][0]) 1701 m_eager = convert(m_eager) 1702 result_eager = m_eager(*self.img_data_dict[dim][0]) 1703 self.assertEqual(result, result_eager) 1704 1705 def test_linear_bn(self): 1706 class M(torch.nn.Module): 1707 def __init__(self) -> None: 1708 super().__init__() 1709 self.linear = nn.Linear(4, 4) 1710 self.bn = nn.BatchNorm1d(4) 1711 self.quant = QuantStub() 1712 self.dequant = DeQuantStub() 1713 1714 def forward(self, x): 1715 x = self.quant(x) 1716 x = self.linear(x) 1717 x = self.bn(x) 1718 x = self.dequant(x) 1719 return x 1720 1721 data = (torch.randn(4, 4),) 1722 for quant_type in self.static_quant_types: 1723 expected_node = ns.call_module(nnq.Linear) 1724 m = M() 1725 m_eager = copy.deepcopy(m) 1726 result_dict = self.checkGraphModeFxOp(m, data, quant_type, expected_node=expected_node) 1727 result = result_dict["quantized_output"] 1728 1729 # check numerics vs eager mode 1730 fuse_list = ["linear", "bn"] 1731 qengine = torch.backends.quantized.engine 1732 if quant_type == QuantType.STATIC: 1733 m_eager.eval() 1734 qconfig = get_default_qconfig(qengine) 1735 prepare_fn = prepare 1736 fuse_modules(m_eager, fuse_list, inplace=True) 1737 else: 1738 m_eager.train() 1739 qconfig = get_default_qat_qconfig(qengine) 1740 prepare_fn = prepare_qat 1741 fuse_modules_qat(m_eager, fuse_list, inplace=True) 1742 m_eager.qconfig = qconfig 1743 m_eager = prepare_fn(m_eager) 1744 m_eager(*data) 1745 m_eager = convert(m_eager) 1746 result_eager = m_eager(*data) 1747 self.assertEqual(result, result_eager) 1748 1749 @skipIfNoFBGEMM 1750 def test_dynamic_quant_fp16(self): 1751 with override_quantized_engine('fbgemm'): 1752 class Linear(torch.nn.Module): 1753 def __init__(self, weight): 1754 super().__init__() 1755 self.weight = torch.nn.Parameter(weight) 1756 1757 def forward(self, x): 1758 return F.linear(x, self.weight) 1759 1760 linear_input = torch.rand(8, 5) 1761 linear_weight = torch.rand(10, 5) 1762 1763 class LinearModule(torch.nn.Module): 1764 def __init__(self) -> None: 1765 super().__init__() 1766 self.linear = torch.nn.Linear(5, 10) 1767 1768 def forward(self, x): 1769 return self.linear(x) 1770 1771 linear_module_input = torch.rand(8, 5) 1772 1773 tests = [ 1774 (Linear, (linear_weight,), (linear_input,), 1775 ns.call_function(torch.ops.quantized.linear_dynamic_fp16), 1776 ns.call_function(torch.ops.quantized.linear_prepack_fp16)), 1777 (LinearModule, (), (linear_module_input,), 1778 ns.call_module(nnqd.Linear), 1779 None), 1780 ] 1781 for (ModuleClass, module_constructor_inputs, 1782 inputs, quantized_node, weight_prepack_node) in tests: 1783 for is_reference in [True, False]: 1784 node_occurrence = {} 1785 if weight_prepack_node: 1786 node_occurrence[weight_prepack_node] = 0 1787 m = ModuleClass(*module_constructor_inputs).eval() 1788 qconfig_dict = {"": float16_dynamic_qconfig} 1789 m = prepare_fx(m, qconfig_dict, example_inputs=inputs) 1790 convert_fn = convert_to_reference_fx if is_reference else convert_fx 1791 m = convert_fn(m) 1792 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 1793 1794 1795 1796 @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") 1797 @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") 1798 @override_qengines 1799 def test_qat_prepare_device_affinity(self): 1800 """ 1801 Tests that FX QAT prepare pass respects device affinity 1802 """ 1803 class Model(nn.Module): 1804 1805 def __init__(self) -> None: 1806 super().__init__() 1807 self.conv = nn.Conv2d(1, 1, 1) 1808 self.bn = nn.BatchNorm2d(1) 1809 self.relu = nn.ReLU() 1810 1811 def forward(self, x): 1812 x = self.conv(x) 1813 x = self.bn(x) 1814 x = self.relu(x) 1815 return x 1816 1817 model = Model() 1818 qengine = torch.backends.quantized.engine 1819 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} 1820 device = torch.device('cuda:0') 1821 model.to(device) 1822 1823 example_inputs = (torch.randn(4, 1, 4, 4, device=device),) 1824 # QAT prepare 1825 model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 1826 1827 # ensure that running an input on CUDA works without any needed changes 1828 model(*example_inputs) 1829 1830 # ensure all buffers and parameters are on the device we expect 1831 model_devices = {p.device for p in model.parameters()} | \ 1832 {p.device for p in model.buffers()} 1833 self.assertEqual(len(model_devices), 1) 1834 model_device = next(iter(model_devices)) 1835 self.assertEqual(model_device, device) 1836 1837 @skipIfNoFBGEMM 1838 def test_dict_output(self): 1839 """ Make sure quantization runs for models with dictionary output 1840 """ 1841 class M(torch.nn.Module): 1842 def __init__(self) -> None: 1843 super().__init__() 1844 self.conv = torch.nn.Conv2d(1, 1, 1) 1845 1846 def forward(self, x): 1847 return {"output": self.conv(x["input"])} 1848 1849 example_inputs = ({"input": torch.randn(1, 1, 1, 1)},) 1850 m = M().eval() 1851 qconfig_dict = {"": default_qconfig} 1852 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 1853 m(*example_inputs) 1854 m = convert_fx(m) 1855 m(*example_inputs) 1856 1857 @override_qengines 1858 def test_attention(self): 1859 """ Make sure quantization runs for a corner case in attention module 1860 """ 1861 class M(torch.nn.Module): 1862 def __init__(self) -> None: 1863 super().__init__() 1864 self.conv = torch.nn.Conv2d(1, 1, 1) 1865 1866 def forward(self, x): 1867 x = self.conv(x) 1868 q, k, v = x.chunk(3, dim=0) 1869 q = q.contiguous().view(-1, 1).transpose(0, 1) 1870 k = k.contiguous().view(-1, 1).transpose(0, 1) 1871 v = v.contiguous().view(-1, 1).transpose(0, 1) 1872 torch._assert( 1873 k.size(1) == 1, "key size should be equal to 1" 1874 ) 1875 r = torch.mm(k, v) 1876 return q * k + r 1877 1878 example_inputs = (torch.randn(3, 1, 1, 1),) 1879 m = M().eval() 1880 qconfig_dict = { 1881 "": None, 1882 "object_type": [ 1883 (nn.Conv2d, default_qconfig), 1884 ] 1885 } 1886 # make sure it runs 1887 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 1888 m(*example_inputs) 1889 m = convert_fx(m) 1890 m(*example_inputs) 1891 1892 def _test_standalone_module( 1893 self, 1894 interface_config, 1895 prepare_count_check, 1896 standalone_prepare_count_check, 1897 convert_count_check, 1898 standalone_convert_count_check): 1899 """ Test standalone module with different quantized input/quantized output 1900 configurations 1901 """ 1902 class StandaloneModule(torch.nn.Module): 1903 def __init__(self) -> None: 1904 super().__init__() 1905 self.conv = torch.nn.Conv2d(1, 1, 1) 1906 1907 def forward(self, x): 1908 return self.conv(x) 1909 1910 class M(torch.nn.Module): 1911 def __init__(self) -> None: 1912 super().__init__() 1913 self.conv = torch.nn.Conv2d(1, 1, 1) 1914 self.standalone = StandaloneModule() 1915 1916 def forward(self, x): 1917 x = self.conv(x) 1918 x = self.standalone(x) 1919 return x 1920 1921 class RefM(torch.nn.Module): 1922 def __init__(self) -> None: 1923 super().__init__() 1924 self.conv1 = torch.nn.Conv2d(1, 1, 1) 1925 self.conv2 = torch.nn.Conv2d(1, 1, 1) 1926 1927 def forward(self, x): 1928 x = self.conv1(x) 1929 x = self.conv2(x) 1930 return x 1931 1932 example_inputs = (torch.randn(1, 1, 1, 1),) 1933 # instantiate M and RefM and align the parameters 1934 original_m = M().eval() 1935 original_ref_m = RefM().eval() 1936 original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) 1937 original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) 1938 original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach()) 1939 original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach()) 1940 1941 for is_name in [True, False]: 1942 sm_example_inputs = example_inputs 1943 if is_name: 1944 prepare_config = { 1945 "standalone_module_name": [("standalone", None, sm_example_inputs, interface_config, None)] 1946 } 1947 else: 1948 prepare_config = { 1949 "standalone_module_class": [(StandaloneModule, None, sm_example_inputs, interface_config, None)] 1950 } 1951 1952 original_m_copy = copy.deepcopy(original_m) 1953 original_ref_m_copy = copy.deepcopy(original_ref_m) 1954 1955 qconfig_dict = {"": default_qconfig} 1956 # check prepared model 1957 m = prepare_fx( 1958 original_m_copy, 1959 qconfig_dict, 1960 example_inputs=example_inputs, 1961 prepare_custom_config=prepare_config) 1962 # calibration 1963 m(*example_inputs) 1964 self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check) 1965 self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check) 1966 1967 # check converted/quantized model 1968 m = convert_fx(m) 1969 self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check) 1970 self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check) 1971 res = m(*example_inputs) 1972 1973 # quantize the reference model 1974 ref_m = prepare_fx( 1975 original_ref_m_copy, 1976 qconfig_dict, 1977 example_inputs=example_inputs, 1978 ) 1979 ref_m(*example_inputs) 1980 ref_m = convert_fx(ref_m) 1981 ref_res = ref_m(*example_inputs) 1982 self.assertEqual(res, ref_res) 1983 1984 def test_standalone_module_float_interface(self): 1985 float_interface_config = { 1986 "input_quantized_idxs": [], # float input 1987 "output_quantized_idxs": [], # float output 1988 } 1989 interface_config = float_interface_config 1990 # input and output of first conv, observer for standalone module 1991 # will be inserted in the standalone module itself 1992 prepare_count_check = { 1993 ns.call_module(torch.ao.quantization.MinMaxObserver): 2 1994 } 1995 # for input and output of conv in the standalone module 1996 standalone_prepare_count_check = { 1997 ns.call_module(torch.ao.quantization.MinMaxObserver): 2 1998 } 1999 convert_count_check = { 2000 ns.call_function(torch.quantize_per_tensor) : 1, 2001 ns.call_module(nnq.Conv2d) : 1, 2002 ns.call_method("dequantize") : 1, 2003 } 2004 standalone_convert_count_check = { 2005 # standalone module will take float as input and output 2006 # so we'll see quantize and dequantize in the modoule 2007 ns.call_function(torch.quantize_per_tensor) : 1, 2008 ns.call_module(nnq.Conv2d): 1, 2009 ns.call_method("dequantize") : 1, 2010 } 2011 self._test_standalone_module( 2012 interface_config, 2013 prepare_count_check, 2014 standalone_prepare_count_check, 2015 convert_count_check, 2016 standalone_convert_count_check) 2017 2018 def test_standalone_module_quantized_interface(self): 2019 quantized_interface_config = { 2020 "input_quantized_idxs": [0], # quantized input 2021 "output_quantized_idxs": [0], # quantized output 2022 } 2023 interface_config = quantized_interface_config 2024 # observer for input and output of first conv 2025 prepare_count_check = { 2026 ns.call_module(torch.ao.quantization.MinMaxObserver): 2 2027 } 2028 # for output of conv in the standalone module 2029 standalone_prepare_count_check = { 2030 ns.call_module(torch.ao.quantization.MinMaxObserver): 1 2031 } 2032 convert_count_check = { 2033 # quantizing input for conv 2034 ns.call_function(torch.quantize_per_tensor) : 1, 2035 ns.call_module(nnq.Conv2d) : 1, 2036 # dequantizing output of standalone module 2037 ns.call_method("dequantize") : 1, 2038 } 2039 standalone_convert_count_check = { 2040 # quantization of input happens in parent module 2041 # quantization of output happens in the quantized conv module 2042 ns.call_function(torch.quantize_per_tensor) : 0, 2043 ns.call_module(nnq.Conv2d): 1, 2044 # dequantization for output happens in parent module 2045 ns.call_method("dequantize") : 0, 2046 } 2047 self._test_standalone_module( 2048 interface_config, 2049 prepare_count_check, 2050 standalone_prepare_count_check, 2051 convert_count_check, 2052 standalone_convert_count_check) 2053 2054 @skipIfNoFBGEMM 2055 def test_qconfig_none(self): 2056 class M(torch.nn.Module): 2057 def __init__(self) -> None: 2058 super().__init__() 2059 self.conv1 = nn.Conv2d(1, 1, 1) 2060 self.conv2 = nn.Conv2d(1, 1, 1) 2061 2062 def forward(self, x): 2063 x = self.conv1(x) 2064 x = self.conv2(x) 2065 return x 2066 2067 m = M().eval() 2068 qconfig_dict = {"": default_qconfig, 2069 "module_name": [("conv2", None)]} 2070 example_inputs = (torch.randn(1, 1, 1, 1),) 2071 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2072 m(*example_inputs) 2073 m = convert_fx(m) 2074 m(*example_inputs) 2075 # first conv is quantized, second conv is not quantized 2076 node_list = [ 2077 ns.call_function(torch.quantize_per_tensor), 2078 ns.call_module(nnq.Conv2d), 2079 ns.call_method("dequantize"), 2080 ns.call_module(nn.Conv2d), 2081 ] 2082 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2083 2084 def test_qconfig_module_type(self): 2085 class M(torch.nn.Module): 2086 def __init__(self) -> None: 2087 super().__init__() 2088 self.conv = nn.Conv2d(1, 1, 1) 2089 self.linear = nn.Linear(9, 3) 2090 2091 def forward(self, x): 2092 x = self.conv(x) 2093 x = x.reshape((1, -1)) 2094 x = self.linear(x) 2095 return x 2096 2097 m = M().eval() 2098 qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]} 2099 example_inputs = (torch.randn(1, 1, 3, 3),) 2100 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2101 m(*example_inputs) 2102 m = convert_fx(m) 2103 m(*example_inputs) 2104 # conv is quantized, linear is not quantized 2105 node_list = [ 2106 ns.call_function(torch.quantize_per_tensor), 2107 ns.call_module(nnq.Conv2d), 2108 ns.call_method("dequantize"), 2109 ns.call_module(nn.Linear), 2110 ] 2111 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2112 2113 def test_qconfig_qat_module_type(self): 2114 class LinearRelu(nn.Sequential): 2115 def __init__(self) -> None: 2116 super().__init__( 2117 nn.Linear(5, 5), 2118 nn.ReLU(), 2119 ) 2120 2121 class M(torch.nn.Module): 2122 def __init__(self) -> None: 2123 super().__init__() 2124 self.lin_relu = LinearRelu() 2125 self.linear = nn.Linear(5, 5) 2126 2127 def forward(self, x): 2128 x = self.lin_relu(x) 2129 x = self.linear(x) 2130 return x 2131 2132 model = M().train() 2133 2134 qconfig_dict = { 2135 "": None, 2136 "object_type": [ 2137 (torch.nn.Linear, default_qat_qconfig), 2138 (torch.nn.ReLU, default_qat_qconfig), 2139 ], 2140 } 2141 example_inputs = (torch.rand(5, 5),) 2142 m = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 2143 m(*example_inputs) 2144 m = convert_fx(m) 2145 m(*example_inputs) 2146 node_list = [ 2147 ns.call_function(torch.quantize_per_tensor), 2148 ns.call_module(nniq.LinearReLU), 2149 ns.call_module(nnq.Linear), 2150 ns.call_method("dequantize"), 2151 ] 2152 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2153 2154 def test_qconfig_function(self): 2155 class M(torch.nn.Module): 2156 def forward(self, x, y): 2157 return x + y 2158 2159 m = M().eval() 2160 qconfig_dict = {"object_type": [(operator.add, default_qconfig)]} 2161 data = torch.randn(1, 1, 1, 1) 2162 example_inputs = (data, data) 2163 m = prepare_fx(m, qconfig_dict, example_inputs) 2164 m(*example_inputs) 2165 m = convert_fx(m) 2166 m(*example_inputs) 2167 # first conv is quantized, second conv is not quantized 2168 node_list = [ 2169 ns.call_function(torch.quantize_per_tensor), 2170 ns.call_function(torch.ops.quantized.add), 2171 ns.call_method("dequantize"), 2172 ] 2173 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2174 2175 def test_qconfig_module_name_regex(self): 2176 class M(torch.nn.Module): 2177 def __init__(self) -> None: 2178 super().__init__() 2179 self.conv1 = nn.Conv2d(1, 1, 1) 2180 self.conv2 = nn.Conv2d(1, 1, 1) 2181 2182 def forward(self, x): 2183 x = self.conv1(x) 2184 x = self.conv2(x) 2185 return x 2186 2187 m = M().eval() 2188 qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]} 2189 example_inputs = (torch.randn(1, 1, 1, 1),) 2190 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2191 m(*example_inputs) 2192 m = convert_fx(m) 2193 m(*example_inputs) 2194 # first conv is quantized, second conv is not quantized 2195 node_list = [ 2196 ns.call_function(torch.quantize_per_tensor), 2197 ns.call_module(nnq.Conv2d), 2198 ns.call_module(nnq.Conv2d), 2199 ns.call_method("dequantize"), 2200 ] 2201 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2202 2203 def test_qconfig_precedence(self): 2204 for device in get_supported_device_types(): 2205 class M(torch.nn.Module): 2206 def __init__(self) -> None: 2207 super().__init__() 2208 self.linear = nn.Linear(1, 1) 2209 self.conv = nn.Conv2d(1, 1, 1) 2210 self.module_conv1 = nn.Conv2d(1, 1, 1) 2211 self.module_conv2 = nn.Conv2d(1, 1, 1) 2212 2213 def forward(self, x): 2214 # global 2215 x = self.linear(x) 2216 # global + object_type --> object_type 2217 x = self.conv(x) 2218 # global + object_type + module_name_regex --> module_name_regex 2219 x = self.module_conv1(x) 2220 # global + object_type + module_name_regex + module_name --> module_name 2221 x = self.module_conv2(x) 2222 return x 2223 2224 m = M().to(device).eval() 2225 2226 global_qconfig = default_qconfig 2227 object_type_qconfig = default_dynamic_qconfig 2228 module_name_regex_qconfig = float16_dynamic_qconfig 2229 module_name_qconfig = default_qat_qconfig 2230 qconfig_dict = { 2231 "": global_qconfig, 2232 "object_type": [(nn.Conv2d, object_type_qconfig)], 2233 "module_name_regex": [("module_conv*", module_name_regex_qconfig)], 2234 "module_name": [("module_conv2", module_name_qconfig)]} 2235 m_prep = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1),)) 2236 self.assertEqual(m_prep.linear.qconfig.activation.p.func, global_qconfig.activation.p.func) 2237 self.assertEqual(m_prep.linear.qconfig.weight.p.func, global_qconfig.weight.p.func) 2238 self.assertEqual(m_prep.conv.qconfig.activation.p.func, object_type_qconfig.activation.p.func) 2239 self.assertEqual(m_prep.conv.qconfig.weight.p.func, object_type_qconfig.weight.p.func) 2240 self.assertEqual(m_prep.module_conv1.qconfig.activation.p.func, module_name_regex_qconfig.activation.p.func) 2241 self.assertEqual(m_prep.module_conv1.qconfig.weight.p.func, module_name_regex_qconfig.weight.p.func) 2242 self.assertEqual(m_prep.module_conv2.qconfig.activation.p.func, module_name_qconfig.activation.p.func) 2243 self.assertEqual(m_prep.module_conv2.qconfig.weight.p.func, module_name_qconfig.weight.p.func) 2244 2245 def test_qconfig_module_name_object_type_order(self): 2246 class M1(torch.nn.Module): 2247 def __init__(self) -> None: 2248 super().__init__() 2249 self.fc1 = nn.Linear(1, 1) 2250 self.fc2 = nn.Linear(1, 1) 2251 2252 def forward(self, x): 2253 x = self.fc1(x) 2254 x = self.fc2(x) 2255 x = torch.add(x, x) 2256 x = torch.add(x, x) 2257 return x 2258 2259 class M2(torch.nn.Module): 2260 def __init__(self) -> None: 2261 super().__init__() 2262 self.fc1 = nn.Linear(1, 1) 2263 self.fc2 = nn.Linear(1, 1) 2264 self.m1 = M1() 2265 2266 def forward(self, x): 2267 x = self.fc1(x) 2268 x = self.fc2(x) 2269 x = torch.add(x, x) 2270 x = torch.add(x, x) 2271 x = self.m1(x) 2272 return x 2273 2274 class M3(torch.nn.Module): 2275 def __init__(self) -> None: 2276 super().__init__() 2277 self.fc1 = nn.Linear(1, 1) 2278 self.fc2 = nn.Linear(1, 1) 2279 self.m2 = M2() 2280 2281 def forward(self, x): 2282 x = self.fc1(x) 2283 x = self.fc2(x) 2284 x = torch.add(x, x) 2285 x = torch.add(x, x) 2286 x = self.m2(x) 2287 return x 2288 2289 m = M3().eval() 2290 qconfig_dict = { 2291 "module_name_object_type_order": [ 2292 # test various FQNs: global, single child, multiple children 2293 ("", nn.Linear, 0, torch.ao.quantization.default_qconfig), 2294 ("", torch.add, 0, torch.ao.quantization.default_qconfig), 2295 ("m2", nn.Linear, 1, torch.ao.quantization.default_qconfig), 2296 ("m2", torch.add, 1, torch.ao.quantization.default_qconfig), 2297 ("m2.m1", nn.Linear, 0, torch.ao.quantization.default_qconfig), 2298 ("m2.m1", torch.add, 0, torch.ao.quantization.default_qconfig), 2299 ], 2300 } 2301 example_inputs = (torch.randn(1, 1, 1, 1),) 2302 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2303 m(*example_inputs) 2304 m = convert_fx(m) 2305 m(*example_inputs) 2306 2307 node_list = [ 2308 # m3 2309 ns.call_function(torch.quantize_per_tensor), 2310 ns.call_module(nnq.Linear), 2311 ns.call_method("dequantize"), 2312 ns.call_module(nn.Linear), 2313 ns.call_function(torch.quantize_per_tensor), 2314 ns.call_function(torch.ops.quantized.add), 2315 ns.call_method("dequantize"), 2316 ns.call_function(torch.add), 2317 # m2 2318 ns.call_module(nn.Linear), 2319 ns.call_function(torch.quantize_per_tensor), 2320 ns.call_module(nnq.Linear), 2321 ns.call_method("dequantize"), 2322 ns.call_function(torch.add), 2323 ns.call_function(torch.quantize_per_tensor), 2324 ns.call_function(torch.ops.quantized.add), 2325 # m1 2326 ns.call_module(nnq.Linear), 2327 ns.call_method("dequantize"), 2328 ns.call_module(nn.Linear), 2329 ns.call_function(torch.quantize_per_tensor), 2330 ns.call_function(torch.ops.quantized.add), 2331 ns.call_method("dequantize"), 2332 ns.call_function(torch.add), 2333 ] 2334 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2335 2336 # test that function order overrides global qconfig 2337 class M4(torch.nn.Module): 2338 def __init__(self) -> None: 2339 super().__init__() 2340 self.fc1 = nn.Linear(1, 1) 2341 self.fc2 = nn.Linear(1, 1) 2342 2343 def forward(self, x): 2344 x = self.fc1(x) 2345 x = self.fc2(x) 2346 x = torch.add(x, x) 2347 x = torch.add(x, x) 2348 return x 2349 2350 m = M4().eval() 2351 qconfig_dict = { 2352 "": torch.ao.quantization.default_qconfig, 2353 "module_name_object_type_order": [ 2354 ("", nn.Linear, 1, None), 2355 ("", torch.add, 1, None), 2356 ], 2357 } 2358 example_inputs = (torch.randn(1, 1, 1, 1),) 2359 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2360 m(*example_inputs) 2361 m = convert_fx(m) 2362 m(*example_inputs) 2363 2364 node_list = [ 2365 ns.call_function(torch.quantize_per_tensor), 2366 ns.call_module(nnq.Linear), 2367 ns.call_method("dequantize"), 2368 ns.call_module(nn.Linear), 2369 ns.call_function(torch.quantize_per_tensor), 2370 ns.call_function(torch.ops.quantized.add), 2371 ns.call_method("dequantize"), 2372 ns.call_function(torch.add), 2373 ] 2374 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2375 2376 2377 @override_qengines 2378 def test_qconfig_dict_with_fused_modules(self): 2379 class LinearReLUModel(torch.nn.Module): 2380 def __init__(self, relu): 2381 super().__init__() 2382 self.linear = torch.nn.Linear(3, 3) 2383 self.relu = relu 2384 2385 def forward(self, x): 2386 x = self.linear(x) 2387 x = self.relu(x) 2388 return x 2389 2390 class ConvReLUModel(torch.nn.Module): 2391 def __init__(self, relu): 2392 super().__init__() 2393 self.conv = torch.nn.Conv1d(3, 3, 3) 2394 self.relu = relu 2395 2396 def forward(self, x): 2397 x = self.conv(x) 2398 x = self.relu(x) 2399 return x 2400 2401 class ConvBnReLUModel(torch.nn.Module): 2402 def __init__(self, relu): 2403 super().__init__() 2404 self.conv = torch.nn.Conv1d(3, 3, 3) 2405 self.bn = torch.nn.BatchNorm1d(3) 2406 self.relu = relu 2407 2408 def forward(self, x): 2409 x = self.conv(x) 2410 x = self.bn(x) 2411 x = self.relu(x) 2412 return x 2413 2414 for model in [LinearReLUModel, ConvReLUModel, ConvBnReLUModel]: 2415 for relu in [torch.nn.ReLU(), torch.nn.functional.relu, torch.relu]: 2416 m = model(relu).eval() 2417 qengine = torch.backends.quantized.engine 2418 qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping(qengine) 2419 # should not crash as in https://github.com/pytorch/pytorch/issues/75825 2420 prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) 2421 2422 # TODO: move QConfigMapping tests to test/quantization/core 2423 def test_qconfig_mapping_set_global(self): 2424 qconfig = get_default_qconfig() 2425 qconfig_mapping = QConfigMapping() 2426 self.assertEqual(qconfig_mapping.global_qconfig, None) 2427 qconfig_mapping.set_global(qconfig) 2428 self.assertEqual(qconfig_mapping.global_qconfig, qconfig) 2429 2430 def test_qconfig_mapping_set_object_type(self): 2431 qconfig1 = get_default_qconfig() 2432 qconfig2 = get_default_qconfig() 2433 qconfig3 = get_default_qconfig() 2434 self.assertNotEqual(qconfig1, qconfig2) 2435 self.assertNotEqual(qconfig1, qconfig3) 2436 qconfig_mapping = QConfigMapping() 2437 self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 0) 2438 # Insert some entries 2439 qconfig_mapping.set_object_type(torch.nn.Linear, qconfig1) 2440 qconfig_mapping.set_object_type(torch.nn.ReLU, qconfig2) 2441 self.assertEqual(len(qconfig_mapping.object_type_qconfigs), 2) 2442 self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig1) 2443 self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) 2444 # Override existing key 2445 qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3) 2446 self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3) 2447 self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) 2448 self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) 2449 self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) 2450 self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) 2451 2452 def test_qconfig_mapping_set_module_name_regex(self): 2453 qconfig1 = get_default_qconfig() 2454 qconfig2 = get_default_qconfig() 2455 qconfig3 = get_default_qconfig() 2456 self.assertNotEqual(qconfig1, qconfig2) 2457 self.assertNotEqual(qconfig1, qconfig3) 2458 qconfig_mapping = QConfigMapping() 2459 self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 0) 2460 # Insert some entries 2461 qconfig_mapping.set_module_name_regex("foo.*bar", qconfig1) 2462 qconfig_mapping.set_module_name_regex("foo.*", qconfig2) 2463 self.assertEqual(len(qconfig_mapping.module_name_regex_qconfigs), 2) 2464 self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig1) 2465 self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) 2466 # Override existing key 2467 qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3) 2468 self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3) 2469 self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) 2470 self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) 2471 self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) 2472 self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) 2473 self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) 2474 self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) 2475 2476 def test_qconfig_mapping_set_module_name(self): 2477 qconfig1 = get_default_qconfig() 2478 qconfig2 = get_default_qconfig() 2479 qconfig3 = get_default_qconfig() 2480 self.assertNotEqual(qconfig1, qconfig2) 2481 self.assertNotEqual(qconfig1, qconfig3) 2482 qconfig_mapping = QConfigMapping() 2483 self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 0) 2484 # Insert some entries 2485 qconfig_mapping.set_module_name("mod1", qconfig1) 2486 qconfig_mapping.set_module_name("mod2", qconfig2) 2487 self.assertEqual(len(qconfig_mapping.module_name_qconfigs), 2) 2488 self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig1) 2489 self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) 2490 # Override existing key 2491 qconfig_mapping.set_module_name("mod1", qconfig3) 2492 self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3) 2493 self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) 2494 self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) 2495 self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) 2496 self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) 2497 2498 def test_qconfig_mapping_set_module_name_object_type_order(self): 2499 qconfig1 = get_default_qconfig() 2500 qconfig2 = get_default_qconfig() 2501 qconfig3 = get_default_qconfig() 2502 self.assertNotEqual(qconfig1, qconfig2) 2503 self.assertNotEqual(qconfig1, qconfig3) 2504 qconfig_mapping = QConfigMapping() 2505 self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 0) 2506 # Insert some entries 2507 qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig1) 2508 qconfig_mapping.set_module_name_object_type_order("mod2", torch.nn.ReLU, 1, qconfig2) 2509 self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) 2510 key1 = ("mod1", torch.nn.Linear, 0) 2511 key2 = ("mod2", torch.nn.ReLU, 1) 2512 self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1) 2513 self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) 2514 self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig1) 2515 self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) 2516 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2517 qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig1) 2518 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2519 qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2) 2520 # Override existing key 2521 qconfig_mapping.set_module_name_object_type_order("mod1", torch.nn.Linear, 0, qconfig3) 2522 self.assertEqual(len(qconfig_mapping.module_name_object_type_order_qconfigs), 2) 2523 self.assertEqual(next(iter(qconfig_mapping.module_name_object_type_order_qconfigs)), key1) 2524 self.assertEqual(list(qconfig_mapping.module_name_object_type_order_qconfigs)[1], key2) 2525 self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key1], qconfig3) 2526 self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs[key2], qconfig2) 2527 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2528 qconfig_mapping, "mod1", torch.nn.Linear, 0, None), qconfig3) 2529 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2530 qconfig_mapping, "mod2", torch.nn.ReLU, 1, None), qconfig2) 2531 # No match 2532 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2533 qconfig_mapping, "mod123", torch.nn.Linear, 0, None), None) 2534 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2535 qconfig_mapping, "mod1", torch.nn.Linear, 35, None), None) 2536 self.assertEqual(_maybe_adjust_qconfig_for_module_name_object_type_order( 2537 qconfig_mapping, "mod2", torch.nn.Conv2d, 1, None), None) 2538 2539 def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, qconfig2): 2540 """ 2541 Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. 2542 """ 2543 return { 2544 _GLOBAL_DICT_KEY: global_qconfig, 2545 _OBJECT_TYPE_DICT_KEY: [ 2546 (torch.nn.Linear, qconfig1), 2547 (torch.nn.ReLU, qconfig2), 2548 ], 2549 _MODULE_NAME_REGEX_DICT_KEY: [ 2550 ("foo.*bar", qconfig1), 2551 ("foo.*", qconfig2), 2552 ], 2553 _MODULE_NAME_DICT_KEY: [ 2554 ("bazbaz", qconfig1), 2555 ("borbor", qconfig2), 2556 ], 2557 _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ 2558 ("bazbaz", torch.nn.Linear, 0, qconfig1), 2559 ("foofoo", torch.nn.ReLU, 1, qconfig2), 2560 ], 2561 } 2562 2563 with self.assertRaises(ValueError) as context: 2564 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) # noqa: F821 2565 self.assertTrue( 2566 'Expected qconfig_dict to have the following keys:' in str(context.exception) 2567 ) 2568 self.assertTrue('But found \'object_typo\' instead.' in str(context.exception)) 2569 2570 def test_qconfig_mapping_from_dict(self): 2571 global_qconfig = QConfig(123, "global") 2572 qconfig1 = QConfig(1, "one") 2573 qconfig2 = QConfig(2, "two") 2574 qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) 2575 qconfig_dict["undefined_dict_key"] = [(123, qconfig1), (234, qconfig2)] 2576 qconfig_mapping = QConfigMapping.from_dict(qconfig_dict) 2577 self.assertEqual(qconfig_mapping.global_qconfig, global_qconfig) 2578 self.assertEqual(qconfig_mapping.object_type_qconfigs, OrderedDict({ 2579 torch.nn.Linear: qconfig1, 2580 torch.nn.ReLU: qconfig2, 2581 })) 2582 self.assertEqual(qconfig_mapping.module_name_regex_qconfigs, OrderedDict({ 2583 "foo.*bar": qconfig1, 2584 "foo.*": qconfig2, 2585 })) 2586 self.assertEqual(qconfig_mapping.module_name_qconfigs, OrderedDict({ 2587 "bazbaz": qconfig1, 2588 "borbor": qconfig2, 2589 })) 2590 self.assertEqual(qconfig_mapping.module_name_object_type_order_qconfigs, OrderedDict({ 2591 ("bazbaz", torch.nn.Linear, 0): qconfig1, 2592 ("foofoo", torch.nn.ReLU, 1): qconfig2, 2593 })) 2594 2595 def test_qconfig_mapping_to_dict(self): 2596 global_qconfig = QConfig(123, "global") 2597 qconfig1 = QConfig(1, "one") 2598 qconfig2 = QConfig(2, "two") 2599 qconfig_mapping = QConfigMapping().set_global(global_qconfig) \ 2600 .set_object_type(torch.nn.Linear, qconfig1) \ 2601 .set_object_type(torch.nn.ReLU, qconfig2) \ 2602 .set_module_name_regex("foo.*bar", qconfig1) \ 2603 .set_module_name_regex("foo.*", qconfig2) \ 2604 .set_module_name("bazbaz", qconfig1) \ 2605 .set_module_name("borbor", qconfig2) \ 2606 .set_module_name_object_type_order("bazbaz", torch.nn.Linear, 0, qconfig1) \ 2607 .set_module_name_object_type_order("foofoo", torch.nn.ReLU, 1, qconfig2) 2608 qconfig_dict = self._get_qconfig_dict_for_qconfig_mapping_test(global_qconfig, qconfig1, qconfig2) 2609 self.assertEqual(qconfig_mapping.to_dict(), qconfig_dict) 2610 2611 def test_qconfig_mapping_repr(self): 2612 self.assertTrue(isinstance(get_default_qconfig_mapping().__repr__(), str)) 2613 2614 def test_default_qconfig_mapping_override_global(self): 2615 class M(torch.nn.Module): 2616 def __init__(self) -> None: 2617 super().__init__() 2618 self.conv = torch.nn.Conv2d(1, 1, 1) 2619 2620 def forward(self, x): 2621 return self.conv(x) 2622 2623 m = M().eval() 2624 my_qconfig = QConfig(activation=MinMaxObserver, weight=default_weight_observer) 2625 qconfig_mapping = get_default_qconfig_mapping() 2626 # Override global qconfig 2627 old_global_qconfig = qconfig_mapping.global_qconfig 2628 qconfig_mapping.set_global(my_qconfig) 2629 # Verify the correct qconfig was used 2630 example_inputs = (torch.randn(1, 1, 1, 1),) 2631 m = prepare_fx(m, qconfig_mapping, example_inputs) 2632 self.assertTrue(isinstance(old_global_qconfig.activation(), HistogramObserver)) 2633 self.assertTrue(isinstance(my_qconfig.activation(), MinMaxObserver)) 2634 self.assertTrue(hasattr(m, "activation_post_process_0")) 2635 self.assertTrue(hasattr(m, "activation_post_process_1")) 2636 self.assertTrue(isinstance(m.activation_post_process_0, MinMaxObserver)) 2637 self.assertTrue(isinstance(m.activation_post_process_1, MinMaxObserver)) 2638 2639 # Dummy classes for PrepareCustomConfig testing 2640 2641 class _DummyStandaloneModule: 2642 pass 2643 2644 class _DummyFloatModule: 2645 pass 2646 2647 class _DummyObservedModule: 2648 pass 2649 2650 class _DummyQuantizedModule: 2651 pass 2652 2653 class _DummyNonTraceableModule1: 2654 pass 2655 2656 class _DummyNonTraceableModule2: 2657 pass 2658 2659 def test_prepare_custom_config_set_standalone_module_name(self): 2660 qconfig_mapping = QConfigMapping() 2661 example_inputs = (torch.randn(3),) 2662 child_prepare_custom_config = PrepareCustomConfig() 2663 backend_config = BackendConfig("my_backend") 2664 config_entry = StandaloneModuleConfigEntry( 2665 qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) 2666 prepare_custom_config = PrepareCustomConfig() 2667 self.assertEqual(len(prepare_custom_config.standalone_module_names), 0) 2668 prepare_custom_config.set_standalone_module_name( 2669 "module1", qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) 2670 self.assertEqual(list(prepare_custom_config.standalone_module_names.keys()), ["module1"]) 2671 self.assertEqual(prepare_custom_config.standalone_module_names["module1"], config_entry) 2672 2673 def test_prepare_custom_config_set_standalone_module_class(self): 2674 qconfig_mapping = QConfigMapping() 2675 example_inputs = (torch.randn(3),) 2676 child_prepare_custom_config = PrepareCustomConfig() 2677 backend_config = BackendConfig("my_backend") 2678 config_entry = StandaloneModuleConfigEntry( 2679 qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) 2680 prepare_custom_config = PrepareCustomConfig() 2681 self.assertEqual(len(prepare_custom_config.standalone_module_classes), 0) 2682 prepare_custom_config.set_standalone_module_class( 2683 self._DummyStandaloneModule, qconfig_mapping, example_inputs, child_prepare_custom_config, backend_config) 2684 self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) 2685 self.assertTrue(self._DummyStandaloneModule in prepare_custom_config.standalone_module_classes) 2686 self.assertEqual(prepare_custom_config.standalone_module_classes[self._DummyStandaloneModule], config_entry) 2687 2688 def test_prepare_custom_config_set_float_to_observed_mapping(self): 2689 prepare_custom_config = PrepareCustomConfig() 2690 self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 0) 2691 prepare_custom_config.set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule, QuantType.STATIC) 2692 self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) 2693 self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) 2694 self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) 2695 self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) 2696 self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], 2697 self._DummyObservedModule) 2698 2699 def test_prepare_custom_config_set_non_traceable_module_names(self): 2700 prepare_custom_config = PrepareCustomConfig() 2701 self.assertEqual(len(prepare_custom_config.non_traceable_module_names), 0) 2702 prepare_custom_config.set_non_traceable_module_names(["module1", "module2"]) 2703 self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module1", "module2"]) 2704 2705 def test_prepare_custom_config_set_non_traceable_module_classes(self): 2706 prepare_custom_config = PrepareCustomConfig() 2707 self.assertEqual(len(prepare_custom_config.non_traceable_module_classes), 0) 2708 prepare_custom_config.set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) 2709 self.assertEqual(prepare_custom_config.non_traceable_module_classes, 2710 [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) 2711 2712 def test_prepare_custom_config_set_input_quantized_indexes(self): 2713 prepare_custom_config = PrepareCustomConfig() 2714 self.assertEqual(len(prepare_custom_config.input_quantized_indexes), 0) 2715 prepare_custom_config.set_input_quantized_indexes([0, 1]) 2716 self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) 2717 2718 def test_prepare_custom_config_set_output_quantized_indexes(self): 2719 prepare_custom_config = PrepareCustomConfig() 2720 self.assertEqual(len(prepare_custom_config.output_quantized_indexes), 0) 2721 prepare_custom_config.set_output_quantized_indexes([0, 1]) 2722 self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) 2723 2724 def test_prepare_custom_config_set_preserved_attributes(self): 2725 prepare_custom_config = PrepareCustomConfig() 2726 self.assertEqual(len(prepare_custom_config.preserved_attributes), 0) 2727 prepare_custom_config.set_preserved_attributes(["attr1", "attr2"]) 2728 self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) 2729 2730 def _get_dummy_prepare_custom_config_dict(self): 2731 """ 2732 Return a dummy prepare_custom_config_dict to test PrepareCustomConfig's to_dict and from_dict methods. 2733 """ 2734 return { 2735 STANDALONE_MODULE_NAME_DICT_KEY: [( 2736 "module1", 2737 QConfigMapping(), 2738 (torch.randn(3),), 2739 PrepareCustomConfig(), 2740 BackendConfig("my_backend"), 2741 )], 2742 STANDALONE_MODULE_CLASS_DICT_KEY: [( 2743 self._DummyStandaloneModule, 2744 QConfigMapping(), 2745 (torch.randn(10),), 2746 PrepareCustomConfig(), 2747 BackendConfig("my_backend"), 2748 )], 2749 FLOAT_TO_OBSERVED_DICT_KEY: { 2750 "static": { 2751 self._DummyFloatModule: self._DummyObservedModule 2752 }, 2753 }, 2754 NON_TRACEABLE_MODULE_NAME_DICT_KEY: ["module2", "module3"], 2755 NON_TRACEABLE_MODULE_CLASS_DICT_KEY: [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2], 2756 INPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], 2757 OUTPUT_QUANTIZED_INDEXES_DICT_KEY: [0, 1], 2758 PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] 2759 } 2760 2761 def test_prepare_custom_config_from_dict(self): 2762 prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() 2763 (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] 2764 (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] 2765 sm_config_entry1 = StandaloneModuleConfigEntry(qm1, ei1, pcc1, bcd1) 2766 sm_config_entry2 = StandaloneModuleConfigEntry(qm2, ei2, pcc2, bcd2) 2767 prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config_dict) 2768 2769 # Standalone modules 2770 self.assertEqual(len(prepare_custom_config.standalone_module_names), 1) 2771 self.assertTrue(sm_name in prepare_custom_config.standalone_module_names) 2772 self.assertEqual(prepare_custom_config.standalone_module_names[sm_name], sm_config_entry1) 2773 self.assertEqual(len(prepare_custom_config.standalone_module_classes), 1) 2774 self.assertTrue(sm_class in prepare_custom_config.standalone_module_classes) 2775 self.assertEqual(prepare_custom_config.standalone_module_classes[sm_class], sm_config_entry2) 2776 2777 # Float to observed mapping 2778 self.assertEqual(len(prepare_custom_config.float_to_observed_mapping), 1) 2779 self.assertEqual(list(prepare_custom_config.float_to_observed_mapping.keys()), [QuantType.STATIC]) 2780 self.assertEqual(len(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]), 1) 2781 self.assertTrue(self._DummyFloatModule in prepare_custom_config.float_to_observed_mapping[QuantType.STATIC]) 2782 self.assertEqual(prepare_custom_config.float_to_observed_mapping[QuantType.STATIC][self._DummyFloatModule], 2783 self._DummyObservedModule) 2784 2785 # Other 2786 self.assertEqual(prepare_custom_config.non_traceable_module_names, ["module2", "module3"]) 2787 self.assertEqual(prepare_custom_config.non_traceable_module_classes, 2788 [self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) 2789 self.assertEqual(prepare_custom_config.input_quantized_indexes, [0, 1]) 2790 self.assertEqual(prepare_custom_config.output_quantized_indexes, [0, 1]) 2791 self.assertEqual(prepare_custom_config.preserved_attributes, ["attr1", "attr2"]) 2792 2793 def test_prepare_custom_config_to_dict(self): 2794 prepare_custom_config_dict = self._get_dummy_prepare_custom_config_dict() 2795 (sm_name, qm1, ei1, pcc1, bcd1) = prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] 2796 (sm_class, qm2, ei2, pcc2, bcd2) = prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] 2797 prepare_custom_config = PrepareCustomConfig() \ 2798 .set_standalone_module_name(sm_name, qm1, ei1, pcc1, bcd1) \ 2799 .set_standalone_module_class(sm_class, qm2, ei2, pcc2, bcd2) \ 2800 .set_float_to_observed_mapping(self._DummyFloatModule, self._DummyObservedModule) \ 2801 .set_non_traceable_module_names(["module2", "module3"]) \ 2802 .set_non_traceable_module_classes([self._DummyNonTraceableModule1, self._DummyNonTraceableModule2]) \ 2803 .set_input_quantized_indexes([0, 1]) \ 2804 .set_output_quantized_indexes([0, 1]) \ 2805 .set_preserved_attributes(["attr1", "attr2"]) 2806 # PrepareCustomConfig.to_dict also converts internal QConfigMappings and PrepareCustomConfigs to dicts 2807 prepare_custom_config_dict[STANDALONE_MODULE_NAME_DICT_KEY][0] = (sm_name, qm1.to_dict(), ei1, pcc1.to_dict(), bcd1) 2808 prepare_custom_config_dict[STANDALONE_MODULE_CLASS_DICT_KEY][0] = (sm_class, qm2.to_dict(), ei2, pcc2.to_dict(), bcd2) 2809 self.assertEqual(prepare_custom_config.to_dict(), prepare_custom_config_dict) 2810 2811 def test_convert_custom_config_set_observed_to_quantized_mapping(self): 2812 convert_custom_config = ConvertCustomConfig() 2813 self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 0) 2814 convert_custom_config.set_observed_to_quantized_mapping( 2815 self._DummyObservedModule, self._DummyQuantizedModule, QuantType.STATIC) 2816 self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) 2817 self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) 2818 self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) 2819 self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], 2820 self._DummyQuantizedModule) 2821 2822 def test_convert_custom_config_set_preserved_attributes(self): 2823 convert_custom_config = ConvertCustomConfig() 2824 self.assertEqual(len(convert_custom_config.preserved_attributes), 0) 2825 convert_custom_config.set_preserved_attributes(["attr1", "attr2"]) 2826 self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) 2827 2828 def _get_dummy_convert_custom_config_dict(self): 2829 """ 2830 Return a dummy convert_custom_config_dict to test ConvertCustomConfig's to_dict and from_dict methods. 2831 """ 2832 return { 2833 OBSERVED_TO_QUANTIZED_DICT_KEY: { 2834 "static": { 2835 self._DummyObservedModule: self._DummyQuantizedModule 2836 }, 2837 }, 2838 PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"] 2839 } 2840 2841 def test_convert_custom_config_from_dict(self): 2842 convert_custom_config_dict = self._get_dummy_convert_custom_config_dict() 2843 convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config_dict) 2844 self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping), 1) 2845 self.assertEqual(list(convert_custom_config.observed_to_quantized_mapping.keys()), [QuantType.STATIC]) 2846 self.assertEqual(len(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]), 1) 2847 self.assertTrue(self._DummyObservedModule in convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC]) 2848 self.assertEqual(convert_custom_config.observed_to_quantized_mapping[QuantType.STATIC][self._DummyObservedModule], 2849 self._DummyQuantizedModule) 2850 self.assertEqual(convert_custom_config.preserved_attributes, ["attr1", "attr2"]) 2851 2852 def test_convert_custom_config_to_dict(self): 2853 convert_custom_config = ConvertCustomConfig() \ 2854 .set_observed_to_quantized_mapping(self._DummyObservedModule, self._DummyQuantizedModule) \ 2855 .set_preserved_attributes(["attr1", "attr2"]) 2856 self.assertEqual(convert_custom_config.to_dict(), self._get_dummy_convert_custom_config_dict()) 2857 2858 def test_fuse_custom_config_set_preserved_attributes(self): 2859 fuse_custom_config = FuseCustomConfig() 2860 self.assertEqual(len(fuse_custom_config.preserved_attributes), 0) 2861 fuse_custom_config.set_preserved_attributes(["attr1", "attr2"]) 2862 self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) 2863 2864 def test_fuse_custom_config_from_dict(self): 2865 fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} 2866 fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config_dict) 2867 self.assertEqual(fuse_custom_config.preserved_attributes, ["attr1", "attr2"]) 2868 2869 def test_fuse_custom_config_to_dict(self): 2870 fuse_custom_config_dict = {PRESERVED_ATTRIBUTES_DICT_KEY: ["attr1", "attr2"]} 2871 fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) 2872 self.assertEqual(fuse_custom_config.to_dict(), fuse_custom_config_dict) 2873 2874 def test_remove_qconfig(self): 2875 class M(torch.nn.Module): 2876 def __init__(self) -> None: 2877 super().__init__() 2878 self.avg_pool = torch.nn.AvgPool2d(1) 2879 2880 def forward(self, x): 2881 return self.avg_pool(x) 2882 2883 m = M().eval() 2884 qconfig_dict = {'': default_qconfig} 2885 example_inputs = (torch.randn(1, 1, 1, 1),) 2886 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2887 m(*example_inputs) 2888 m = convert_fx(m) 2889 m(*example_inputs) 2890 for name, module in m.named_modules(): 2891 self.assertFalse(hasattr(module, 'qconfig'), 2892 'qconfig is not removed for ' + name) 2893 2894 def test_return_none(self): 2895 class M(torch.nn.Module): 2896 def forward(self, x): 2897 pass 2898 2899 m = M().eval() 2900 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 2901 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1),)) 2902 m = convert_fx(m) 2903 2904 def test_default_quant_after_none_qconfig(self): 2905 """ Make sure default quant is inserted properly""" 2906 class M(torch.nn.Module): 2907 def __init__(self) -> None: 2908 super().__init__() 2909 self.conv1 = torch.nn.Conv2d(1, 1, 1) 2910 self.conv2 = torch.nn.Conv2d(1, 1, 1) 2911 2912 def forward(self, x): 2913 x = self.conv1(x) 2914 x = x.transpose(1, 2) 2915 x = self.conv2(x) 2916 2917 m = M().eval() 2918 qconfig_dict = { 2919 "": default_qconfig, 2920 "module_name": [ 2921 ("conv1", None) 2922 ] 2923 } 2924 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) 2925 m = convert_fx(m) 2926 2927 def test_qconfig_for_call_method(self): 2928 class Sub(torch.nn.Module): 2929 def __init__(self) -> None: 2930 super().__init__() 2931 self.conv = torch.nn.Conv2d(1, 1, 1) 2932 2933 def forward(self, x): 2934 x = x.transpose(2, 3) 2935 x = self.conv(x) 2936 return x.transpose(2, 3) 2937 2938 class M(torch.nn.Module): 2939 def __init__(self) -> None: 2940 super().__init__() 2941 self.sub = Sub() 2942 self.conv1 = torch.nn.Conv2d(1, 1, 1) 2943 self.conv2 = torch.nn.Conv2d(1, 1, 1) 2944 2945 def forward(self, x): 2946 x = self.conv1(x) 2947 x = self.sub(x) 2948 x = self.conv2(x) 2949 return x.transpose(2, 3) 2950 2951 qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]} 2952 # since sub is configured to have qconfig None, we should dequantize the output 2953 # of self.conv1 and quantize the input of self.conv2 2954 # dequantize after conv2 should happen after transpose since 2955 # it is configured with default_qconfig 2956 # nodes in Sub module instance is not quantized 2957 node_list1 = [ 2958 ns.call_function(torch.quantize_per_tensor), 2959 ns.call_module(nnq.Conv2d), 2960 ns.call_method("dequantize"), 2961 ns.call_method("transpose"), 2962 ns.call_module(nn.Conv2d), 2963 ns.call_method("transpose"), 2964 ns.call_function(torch.quantize_per_tensor), 2965 ns.call_module(nnq.Conv2d), 2966 ns.call_method("transpose"), 2967 ns.call_method("dequantize") 2968 ] 2969 2970 qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]} 2971 # Only nodes in Sub module instance are quantized 2972 # the first transpose is not quantized because the input is not quantized 2973 node_list2 = [ 2974 ns.call_module(nn.Conv2d), 2975 ns.call_function(torch.quantize_per_tensor), 2976 ns.call_method("transpose"), 2977 ns.call_module(nnq.Conv2d), 2978 ns.call_method("transpose"), 2979 ns.call_method("dequantize"), 2980 ns.call_module(nn.Conv2d), 2981 ns.call_method("transpose"), 2982 ] 2983 2984 for qconfig_dict, node_list in [ 2985 (qconfig_dict1, node_list1), 2986 (qconfig_dict2, node_list2) 2987 ]: 2988 example_inputs = (torch.randn(2, 1, 3, 3),) 2989 m = M().eval() 2990 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 2991 m(torch.randn(2, 1, 3, 3)) 2992 m = convert_fx(m) 2993 self.checkGraphModuleNodes(m, expected_node_list=node_list) 2994 # make sure it runs 2995 m(*example_inputs) 2996 2997 def test_qconfig_for_call_func(self): 2998 class Linear(torch.nn.Module): 2999 def __init__(self) -> None: 3000 super().__init__() 3001 self.w = torch.ones(5, 5) 3002 self.b = torch.zeros(5) 3003 3004 def forward(self, x): 3005 return torch.nn.functional.linear(x, self.w, self.b) 3006 3007 class M(torch.nn.Module): 3008 def __init__(self) -> None: 3009 super().__init__() 3010 self.mods1 = torch.nn.Sequential( 3011 Linear(), 3012 Linear() 3013 ) 3014 self.mods2 = Linear() 3015 3016 def forward(self, x): 3017 x = self.mods1(x) 3018 x = self.mods2(x) 3019 return x 3020 3021 model = M().eval() 3022 example_inputs = (torch.rand(5, 5),) 3023 qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]} 3024 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 3025 m(*example_inputs) 3026 3027 m = convert_fx(m) 3028 node_list = [ 3029 ns.call_function(torch.quantize_per_tensor), 3030 ns.call_function(torch.ops.quantized.linear), 3031 ns.call_function(torch.ops.quantized.linear), 3032 ns.call_method('dequantize'), 3033 ns.call_function(torch.nn.functional.linear) 3034 ] 3035 self.checkGraphModuleNodes(m, expected_node_list=node_list) 3036 m(torch.rand(5, 5)) 3037 3038 def test_preserve_attributes(self): 3039 class M(torch.nn.Module): 3040 def __init__(self) -> None: 3041 super().__init__() 3042 self.conv = torch.nn.Conv2d(1, 1, 1) 3043 3044 def forward(self, x): 3045 return self.conv(x) 3046 3047 m = M() 3048 m.eval() 3049 m.preserved_attr = 3 3050 prepare_custom_config_dict = { 3051 "preserved_attributes": ["preserved_attr"] 3052 } 3053 example_inputs = (torch.randn(1, 1, 1, 1),) 3054 m = prepare_fx( 3055 m, 3056 {"": default_qconfig}, 3057 example_inputs=example_inputs, 3058 prepare_custom_config=prepare_custom_config_dict) 3059 3060 def assertAttrPreserved(m): 3061 self.assertTrue(hasattr(m, "preserved_attr")) 3062 self.assertEqual(m.preserved_attr, 3) 3063 3064 assertAttrPreserved(m) 3065 convert_custom_config_dict = { 3066 "preserved_attributes": ["preserved_attr"] 3067 } 3068 m = convert_fx(m, convert_custom_config=convert_custom_config_dict) 3069 assertAttrPreserved(m) 3070 3071 @skipIfNoFBGEMM 3072 def test_qat_and_script(self): 3073 model = LinearModelWithSubmodule().train() 3074 qengine = torch.backends.quantized.engine 3075 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig(qengine)} 3076 x = torch.randn(5, 5) 3077 example_inputs = (x,) 3078 model = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 3079 3080 # ensure scripting works 3081 scripted = torch.jit.script(model) 3082 # run one round to make sure model runs 3083 scripted(x) 3084 FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \ 3085 .run(scripted.graph) 3086 3087 # disable fake_quant and observer 3088 for epoch in range(3): 3089 if epoch == 1: 3090 scripted.apply(torch.ao.quantization.disable_observer) 3091 if epoch == 2: 3092 scripted.apply(torch.ao.quantization.disable_fake_quant) 3093 3094 # ensure the fake_quant and observer have been disabled. 3095 matches = ['.fake_quant_enabled', '.observer_enabled'] 3096 for key, v in scripted.state_dict().items(): 3097 if any(x in key for x in matches): 3098 self.assertEqual(v, torch.tensor([0], dtype=torch.int64)) 3099 3100 # enable them back 3101 scripted.apply(torch.ao.quantization.enable_fake_quant) 3102 scripted.apply(torch.ao.quantization.enable_observer) 3103 for key, v in scripted.state_dict().items(): 3104 if any(x in key for x in matches): 3105 self.assertEqual(v, torch.tensor([1], dtype=torch.int64)) 3106 3107 @skipIfNoFBGEMM 3108 def test_save_observer_state_dict(self): 3109 orig = LinearModelWithSubmodule().eval() 3110 model = orig 3111 qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} 3112 x = torch.randn(5, 5) 3113 model = prepare_fx(model, qconfig_dict, example_inputs=(x,)) 3114 3115 # run it through input 3116 model(x) 3117 # save state_dict of model 3118 obs_dict = torch.ao.quantization.get_observer_state_dict(model) 3119 3120 quant = convert_fx(model) 3121 3122 b = io.BytesIO() 3123 torch.save(obs_dict, b) 3124 3125 # Load the stats into new model 3126 for weights_only in [True, False]: 3127 b.seek(0) 3128 model_2 = orig 3129 model_2 = prepare_fx(model_2, qconfig_dict, example_inputs=(x,)) 3130 3131 loaded_dict = torch.load(b, weights_only=weights_only) 3132 torch.ao.quantization.load_observer_state_dict(model_2, loaded_dict) 3133 3134 quant_2 = convert_fx(model_2) 3135 3136 # Verify that loaded state dict produces same results. 3137 self.assertEqual(quant(x), quant_2(x)) 3138 3139 @skipIfNoFBGEMM 3140 def test_custom_module_class(self): 3141 class CustomModule(torch.nn.Module): 3142 def __init__(self) -> None: 3143 super().__init__() 3144 self.linear = torch.nn.Linear(3, 3) 3145 3146 def forward(self, x): 3147 return self.linear(x) 3148 3149 class ObservedCustomModule(torch.nn.Module): 3150 def __init__(self, linear): 3151 super().__init__() 3152 self.linear = linear 3153 3154 def forward(self, x): 3155 return self.linear(x) 3156 3157 @classmethod 3158 def from_float(cls, float_module): 3159 assert hasattr(float_module, 'qconfig') 3160 observed = cls(float_module.linear) 3161 observed.qconfig = float_module.qconfig 3162 return observed 3163 3164 class StaticQuantCustomModule(torch.nn.Module): 3165 def __init__(self, linear): 3166 super().__init__() 3167 self.linear = linear 3168 3169 def forward(self, x): 3170 return self.linear(x) 3171 3172 @classmethod 3173 def from_observed(cls, observed_module): 3174 assert hasattr(observed_module, 'qconfig') 3175 assert hasattr(observed_module, 'activation_post_process') 3176 observed_module.linear.activation_post_process = \ 3177 observed_module.activation_post_process 3178 quantized = cls(nnq.Linear.from_float(observed_module.linear)) 3179 return quantized 3180 3181 class DynamicQuantCustomModule(torch.nn.Module): 3182 def __init__(self, linear): 3183 super().__init__() 3184 self.linear = linear 3185 3186 def forward(self, x): 3187 return self.linear(x) 3188 3189 @classmethod 3190 def from_observed(cls, observed_module): 3191 assert hasattr(observed_module, 'qconfig') 3192 observed_module.linear.qconfig = observed_module.qconfig 3193 quantized = cls(nnqd.Linear.from_float(observed_module.linear)) 3194 return quantized 3195 3196 class M(torch.nn.Module): 3197 def __init__(self) -> None: 3198 super().__init__() 3199 self.linear = torch.nn.Linear(3, 3) 3200 self.custom = CustomModule() 3201 3202 def forward(self, x): 3203 x = self.linear(x) 3204 x = self.custom(x) 3205 return x 3206 3207 class RefM(torch.nn.Module): 3208 def __init__(self) -> None: 3209 super().__init__() 3210 self.linear1 = torch.nn.Linear(3, 3) 3211 self.linear2 = torch.nn.Linear(3, 3) 3212 3213 def forward(self, x): 3214 x = self.linear1(x) 3215 x = self.linear2(x) 3216 return x 3217 3218 # instantiate M and RefM and align the parameters 3219 original_m = M().eval() 3220 original_ref_m = RefM().eval() 3221 original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach()) 3222 original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach()) 3223 original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach()) 3224 original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach()) 3225 3226 a16_qconfig = QConfig( 3227 activation=MinMaxObserver.with_args(dtype=torch.qint32, quant_min=0, quant_max=65536), 3228 weight=default_weight_observer, 3229 ) 3230 test_configs = { 3231 "static": (default_qconfig, StaticQuantCustomModule, 3), 3232 "static_a16": (a16_qconfig, StaticQuantCustomModule, 3), 3233 "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0) 3234 } 3235 3236 for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]: 3237 key = _get_quant_type_to_str(quant_type) 3238 qconfig, quantized_module_class, num_observers = test_configs[key] 3239 qconfig_dict = {"": qconfig} 3240 if key == "static": 3241 prepare_custom_config_dict = { 3242 "float_to_observed_custom_module_class": { 3243 "static": { 3244 CustomModule: ObservedCustomModule 3245 } 3246 } 3247 } 3248 convert_custom_config_dict = { 3249 "observed_to_quantized_custom_module_class": { 3250 "static": { 3251 ObservedCustomModule: quantized_module_class 3252 } 3253 } 3254 } 3255 else: 3256 prepare_custom_config_dict = { 3257 "non_traceable_module_class": [ 3258 CustomModule 3259 ] 3260 } 3261 convert_custom_config_dict = { 3262 "observed_to_quantized_custom_module_class": { 3263 "dynamic": { 3264 CustomModule: quantized_module_class 3265 } 3266 } 3267 } 3268 3269 example_inputs = (torch.randn(3, 3),) 3270 # check prepared model 3271 m = prepare_fx( 3272 copy.deepcopy(original_m), 3273 qconfig_dict, 3274 example_inputs=example_inputs, 3275 prepare_custom_config=prepare_custom_config_dict) 3276 # calibration 3277 m(*example_inputs) 3278 # all activation observers are inserted in the top level module 3279 count_check = { 3280 ns.call_module(torch.ao.quantization.MinMaxObserver): num_observers 3281 } 3282 self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) 3283 3284 # check converted/quantized model 3285 m = convert_fx( 3286 m, 3287 convert_custom_config=convert_custom_config_dict) 3288 if quant_type == QuantType.STATIC: 3289 count_check = { 3290 ns.call_function(torch.quantize_per_tensor) : 1, 3291 ns.call_module(nnq.Linear) : 1, 3292 ns.call_method('dequantize') : 1, 3293 } 3294 self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) 3295 self.assertEqual(type(m.custom), quantized_module_class) 3296 res = m(*example_inputs) 3297 3298 # quantize the reference model 3299 ref_m = prepare_fx( 3300 copy.deepcopy(original_ref_m), qconfig_dict, example_inputs=example_inputs) 3301 ref_m(*example_inputs) 3302 ref_m = convert_fx(ref_m) 3303 ref_res = ref_m(*example_inputs) 3304 self.assertEqual(res, ref_res) 3305 3306 @skipIfNoFBGEMM 3307 def test_custom_module_class_input_has_multiple_users(self): 3308 """ Tests that the flow still works when the input of custom module 3309 has multiple users 3310 """ 3311 class CustomModule(torch.nn.Module): 3312 def __init__(self) -> None: 3313 super().__init__() 3314 self.linear = torch.nn.Linear(3, 3) 3315 3316 def forward(self, x): 3317 return self.linear(x) 3318 3319 class ObservedCustomModule(torch.nn.Module): 3320 def __init__(self, linear): 3321 super().__init__() 3322 self.linear = linear 3323 3324 def forward(self, x): 3325 return self.linear(x) 3326 3327 @classmethod 3328 def from_float(cls, float_module): 3329 assert hasattr(float_module, 'qconfig') 3330 observed = cls(float_module.linear) 3331 observed.qconfig = float_module.qconfig 3332 return observed 3333 3334 class StaticQuantCustomModule(torch.nn.Module): 3335 def __init__(self, linear): 3336 super().__init__() 3337 self.linear = linear 3338 3339 def forward(self, x): 3340 return self.linear(x) 3341 3342 @classmethod 3343 def from_observed(cls, observed_module): 3344 assert hasattr(observed_module, 'qconfig') 3345 assert hasattr(observed_module, 'activation_post_process') 3346 observed_module.linear.activation_post_process = \ 3347 observed_module.activation_post_process 3348 quantized = cls(nnq.Linear.from_float(observed_module.linear)) 3349 return quantized 3350 3351 class M(torch.nn.Module): 3352 def __init__(self) -> None: 3353 super().__init__() 3354 self.linear = torch.nn.Linear(3, 3) 3355 self.custom = CustomModule() 3356 3357 def forward(self, x0): 3358 x1 = self.custom(x0) 3359 x2 = self.linear(x0) 3360 return x1 + x2 3361 3362 prepare_custom_config_dict = { 3363 "float_to_observed_custom_module_class": { 3364 "static": { 3365 CustomModule: ObservedCustomModule 3366 } 3367 } 3368 } 3369 convert_custom_config_dict = { 3370 "observed_to_quantized_custom_module_class": { 3371 "static": { 3372 ObservedCustomModule: StaticQuantCustomModule 3373 } 3374 } 3375 } 3376 m = M().eval() 3377 example_inputs = (torch.randn(3, 3),) 3378 m = prepare_fx( 3379 m, 3380 {"": default_qconfig}, 3381 example_inputs=example_inputs, 3382 prepare_custom_config=prepare_custom_config_dict) 3383 # make sure it works 3384 m = convert_fx( 3385 m, 3386 convert_custom_config=convert_custom_config_dict) 3387 # make sure it runs 3388 m(*example_inputs) 3389 3390 @skipIfNoFBGEMM 3391 def test_custom_module_class_input_has_duplicate_nodes(self): 3392 """ Tests that the flow still works when the graph has 3393 multiple nodes with the same custom module target. 3394 """ 3395 class CustomModule(torch.nn.Module): 3396 def __init__(self) -> None: 3397 super().__init__() 3398 self.linear = torch.nn.Linear(3, 3) 3399 3400 def forward(self, x): 3401 return self.linear(x) 3402 3403 class ObservedCustomModule(torch.nn.Module): 3404 def __init__(self, linear): 3405 super().__init__() 3406 self.linear = linear 3407 3408 def forward(self, x): 3409 return self.linear(x) 3410 3411 @classmethod 3412 def from_float(cls, float_module): 3413 assert hasattr(float_module, 'qconfig') 3414 observed = cls(float_module.linear) 3415 observed.qconfig = float_module.qconfig 3416 return observed 3417 3418 class StaticQuantCustomModule(torch.nn.Module): 3419 def __init__(self, linear): 3420 super().__init__() 3421 self.linear = linear 3422 3423 def forward(self, x): 3424 return self.linear(x) 3425 3426 @classmethod 3427 def from_observed(cls, observed_module): 3428 assert hasattr(observed_module, 'qconfig') 3429 assert hasattr(observed_module, 'activation_post_process') 3430 observed_module.linear.activation_post_process = \ 3431 observed_module.activation_post_process 3432 quantized = cls(nnq.Linear.from_float(observed_module.linear)) 3433 return quantized 3434 3435 class M(torch.nn.Module): 3436 def __init__(self) -> None: 3437 super().__init__() 3438 self.custom = CustomModule() 3439 3440 def forward(self, x0): 3441 x1 = self.custom(x0) 3442 x2 = self.custom(x0) 3443 return x1 + x2 3444 3445 prepare_custom_config_dict = { 3446 "float_to_observed_custom_module_class": { 3447 "static": { 3448 CustomModule: ObservedCustomModule 3449 } 3450 } 3451 } 3452 convert_custom_config_dict = { 3453 "observed_to_quantized_custom_module_class": { 3454 "static": { 3455 ObservedCustomModule: StaticQuantCustomModule 3456 } 3457 } 3458 } 3459 m = M().eval() 3460 example_inputs = (torch.randn(3, 3),) 3461 m = prepare_fx( 3462 m, 3463 {"": default_qconfig}, 3464 example_inputs=example_inputs, 3465 prepare_custom_config=prepare_custom_config_dict) 3466 # make sure it works 3467 m = convert_fx( 3468 m, 3469 convert_custom_config=convert_custom_config_dict) 3470 # make sure it runs 3471 m(*example_inputs) 3472 3473 @skipIfNoFBGEMM 3474 def test_non_traceable_module(self): 3475 class NonTraceable(torch.nn.Module): 3476 def forward(self, x): 3477 for k in x.keys(): 3478 print(x[k]) 3479 return x 3480 3481 class NonTraceable2(torch.nn.Module): 3482 def forward(self, x): 3483 # data dependent control flow is not traceable 3484 for i in x: 3485 print(i) 3486 return x 3487 3488 class M(torch.nn.Module): 3489 def __init__(self) -> None: 3490 super().__init__() 3491 self.m1 = NonTraceable() 3492 self.m2 = NonTraceable2() 3493 3494 def forward(self, x): 3495 x = self.m1(x) 3496 x = self.m2(x) 3497 return x 3498 3499 m = M().eval() 3500 qconfig_dict = {"": default_qconfig} 3501 prepare_custom_config_dict = { 3502 "non_traceable_module_name": [ 3503 "m1" 3504 ], 3505 "non_traceable_module_class": [ 3506 NonTraceable2 3507 ] 3508 } 3509 m = prepare_fx( 3510 m, qconfig_dict, 3511 example_inputs=({"key": torch.randn(1)},), 3512 prepare_custom_config=prepare_custom_config_dict) 3513 3514 node_occurrence = { 3515 ns.call_module(NonTraceable) : 1, 3516 ns.call_module(NonTraceable2) : 1, 3517 } 3518 # make sure these modules are not traced 3519 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 3520 3521 def test_prepared_model_deepcopy(self): 3522 """Ensures that copy.deepcopy works correctly on a prepared model. 3523 """ 3524 class M(torch.nn.Module): 3525 def __init__(self) -> None: 3526 super().__init__() 3527 self.conv = torch.nn.Conv2d(1, 1, 1) 3528 self._foobar = 'foobar' 3529 self.foobar2 = 'foobar2' 3530 3531 def forward(self, x): 3532 x = self.conv(x) 3533 return x 3534 3535 m = M() 3536 m.eval() 3537 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 3538 example_inputs = (torch.randn(4, 1, 4, 4),) 3539 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 3540 # calibrate 3541 prepared(*example_inputs) 3542 # copy 3543 prepared_copy = copy.deepcopy(prepared) 3544 # quantize, should run with no errors 3545 quantized = convert_fx(prepared_copy) 3546 3547 def test_quantized_model_type(self): 3548 """ Test state_dict and deepcopy works properly in the quantized model 3549 """ 3550 class M(torch.nn.Module): 3551 def __init__(self) -> None: 3552 super().__init__() 3553 self.linear = torch.nn.Linear(5, 5) 3554 3555 def forward(self, x): 3556 return self.linear(x) 3557 3558 example_inputs = (torch.rand(8, 5),) 3559 m = M().eval() 3560 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 3561 m = convert_fx(m) 3562 # test deepcopy 3563 m_copy = copy.deepcopy(m) 3564 self.assertEqual(m_copy(*example_inputs), m(*example_inputs)) 3565 3566 # test state_dict 3567 state_dict = m.state_dict() 3568 m_new = M().eval() 3569 m_new = prepare_fx(m_new, {"": default_qconfig}, example_inputs=example_inputs) 3570 m_new = convert_fx(m_new) 3571 m_new.load_state_dict(state_dict) 3572 self.assertEqual(m_new(*example_inputs), m(*example_inputs)) 3573 3574 def test_dequantize(self): 3575 r""" Test to make sure dequantize node are placed before 3576 non-quantizable node 3577 """ 3578 class M(torch.nn.Module): 3579 def __init__(self) -> None: 3580 super().__init__() 3581 self.conv = torch.nn.Conv2d(1, 1, 1) 3582 self.act = torch.nn.GELU() 3583 3584 def forward(self, x): 3585 x = self.conv(x) 3586 return self.act(x) 3587 3588 data = torch.rand(5, 1, 3, 3, dtype=torch.float) 3589 for quant_type in self.static_quant_types: 3590 node_list = [ 3591 ns.call_module(nnq.Conv2d), 3592 ns.call_method("dequantize"), 3593 ns.call_module(nn.GELU), 3594 ] 3595 self.checkGraphModeFxOp( 3596 M().eval(), (data,), quant_type, expected_node_list=node_list) 3597 3598 def test_sequential(self): 3599 class M(torch.nn.Module): 3600 def __init__(self) -> None: 3601 super().__init__() 3602 self.convs = torch.nn.Sequential( 3603 torch.nn.Conv2d(1, 1, 1), 3604 torch.nn.Conv2d(1, 1, 1) 3605 ) 3606 3607 def forward(self, x): 3608 x = self.convs(x) 3609 return x 3610 3611 data = torch.rand(5, 1, 3, 3, dtype=torch.float) 3612 for quant_type in self.static_quant_types: 3613 node_list = [ 3614 ns.call_module(nnq.Conv2d), 3615 ns.call_module(nnq.Conv2d), 3616 ] 3617 self.checkGraphModeFxOp( 3618 M().eval(), (data,), quant_type, expected_node_list=node_list) 3619 3620 def _test_quantized_inputs_outputs( 3621 self, prepare_custom_config_dict, prepare_count_check, 3622 convert_count_check): 3623 """ 3624 Test the option to have inputs and outputs of the graph quantized 3625 """ 3626 class M(torch.nn.Module): 3627 def __init__(self) -> None: 3628 super().__init__() 3629 self.conv1 = torch.nn.Conv2d(1, 1, 1) 3630 self.conv2 = torch.nn.Conv2d(1, 1, 1) 3631 3632 def forward(self, x): 3633 x = self.conv1(x) 3634 x = self.conv2(x) 3635 return x 3636 3637 # quantized input, quantized output 3638 m = M() 3639 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 3640 example_inputs = (torch.randn(1, 1, 4, 4),) 3641 m.eval() 3642 mp = torch.ao.quantization.quantize_fx.prepare_fx( 3643 m, qconfig_dict, 3644 example_inputs=example_inputs, 3645 prepare_custom_config=prepare_custom_config_dict) 3646 self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check) 3647 mp(*example_inputs) 3648 mq = torch.ao.quantization.quantize_fx.convert_fx(mp) 3649 self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check) 3650 3651 def test_quantized_input_quantized_output(self): 3652 prepare_custom_config_dict = { 3653 'input_quantized_idxs': [0], 'output_quantized_idxs': [0]} 3654 prepare_count_check = { 3655 ns.call_module(torch.ao.quantization.MinMaxObserver): 2, 3656 } 3657 convert_count_check = { 3658 ns.call_function(torch.quantize_per_tensor): 0, 3659 ns.call_method('dequantize'): 0, 3660 } 3661 self._test_quantized_inputs_outputs( 3662 prepare_custom_config_dict, prepare_count_check, convert_count_check) 3663 3664 def test_fp32_input_quantized_output(self): 3665 prepare_custom_config_dict = { 3666 'output_quantized_idxs': [0]} 3667 prepare_count_check = { 3668 ns.call_module(torch.ao.quantization.MinMaxObserver): 3, 3669 } 3670 convert_count_check = { 3671 ns.call_function(torch.quantize_per_tensor): 1, 3672 ns.call_method('dequantize'): 0, 3673 } 3674 self._test_quantized_inputs_outputs( 3675 prepare_custom_config_dict, prepare_count_check, convert_count_check) 3676 3677 def test_quantized_input_fp32_output(self): 3678 prepare_custom_config_dict = { 3679 'input_quantized_idxs': [0]} 3680 prepare_count_check = { 3681 ns.call_module(torch.ao.quantization.MinMaxObserver): 2, 3682 } 3683 convert_count_check = { 3684 ns.call_function(torch.quantize_per_tensor): 0, 3685 ns.call_method('dequantize'): 1, 3686 } 3687 self._test_quantized_inputs_outputs( 3688 prepare_custom_config_dict, prepare_count_check, convert_count_check) 3689 3690 def test_fp32_input_fp32_output(self): 3691 prepare_custom_config_dict = {} 3692 prepare_count_check = { 3693 ns.call_module(torch.ao.quantization.MinMaxObserver): 3, 3694 } 3695 convert_count_check = { 3696 ns.call_function(torch.quantize_per_tensor): 1, 3697 ns.call_method('dequantize'): 1, 3698 } 3699 self._test_quantized_inputs_outputs( 3700 prepare_custom_config_dict, prepare_count_check, convert_count_check) 3701 3702 @skipIfNoFBGEMM 3703 def test_convtranspose_per_channel_fails_early(self): 3704 r""" 3705 Verifies that attempting to quantize a ConvTranspose module with per-Channel 3706 weight observers fails in the prepare step, as opposed to the convert step. 3707 """ 3708 m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1)) 3709 m.eval() 3710 qconfig_dict = {'': torch.ao.quantization.get_default_qconfig('fbgemm')} 3711 with self.assertRaises(AssertionError) as context: 3712 mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) 3713 self.assertTrue( 3714 str(context.exception) == 3715 'Per channel weight observer is not supported yet for ConvTranspose{n}d.') 3716 3717 @skipIfNoFBGEMM 3718 def test_qparams_buffers(self): 3719 class Linear(torch.nn.Module): 3720 def __init__(self) -> None: 3721 super().__init__() 3722 self.w = torch.ones(5, 5) 3723 self.b = torch.zeros(5) 3724 3725 def forward(self, x): 3726 return torch.nn.functional.linear(x, self.w, self.b) 3727 3728 class M(torch.nn.Module): 3729 def __init__(self) -> None: 3730 super().__init__() 3731 self.mods1 = torch.nn.Sequential( 3732 Linear(), 3733 Linear() 3734 ) 3735 self.mods2 = Linear() 3736 3737 def forward(self, x): 3738 x = self.mods1(x) 3739 x = self.mods2(x) 3740 return x 3741 3742 model = M().eval() 3743 qconfig_dict = {"": default_qconfig} 3744 example_inputs = (torch.rand(5, 5),) 3745 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 3746 m(*example_inputs) 3747 m = convert_fx(m) 3748 keys = m.state_dict().keys() 3749 quant_scale_count = quant_zero_point = scale_count = zero_point_count = 0 3750 for k in keys: 3751 if 'input_scale' in k: 3752 quant_scale_count = quant_scale_count + 1 3753 elif 'input_zero_point' in k: 3754 quant_zero_point = quant_zero_point + 1 3755 elif 'scale' in k: 3756 scale_count = scale_count + 1 3757 elif 'zero_point' in k: 3758 zero_point_count = zero_point_count + 1 3759 3760 # Expect each quantized linear op to have a scale and zero point 3761 self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict") 3762 self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict") 3763 m(*example_inputs) 3764 # ensure it is scriptable 3765 scripted = torch.jit.script(m) 3766 scripted_keys = scripted.state_dict().keys() 3767 scripted.mods1_0_packed_weight_0 = m.state_dict()["mods1_0_packed_weight_0"] 3768 non_packed_weight_keys = [key for key in keys if "_packed_weight" not in key] 3769 self.assertTrue( 3770 set(scripted_keys) == set(non_packed_weight_keys), 3771 "Expected the scripted model to preserve the state_dict for non-packed weight attributes") 3772 # TODO: probably don't want to hardcode the attribute names, since they are generated 3773 for attr_name in [ 3774 "mods1_0_input_scale_0", "mods1_0_input_zero_point_0", 3775 "mods1_0_scale_1", "mods1_0_zero_point_1", 3776 "mods1_1_scale_1", "mods1_1_zero_point_1", 3777 "mods2_scale_1", "mods2_zero_point_1"]: 3778 self.assertTrue(hasattr(m, attr_name), attr_name + " not found.") 3779 3780 @skipIfNoFBGEMM 3781 def test_packed_weight_fused_op(self): 3782 class Linear(torch.nn.Module): 3783 def __init__(self) -> None: 3784 super().__init__() 3785 self.w = torch.ones(5, 5) 3786 self.b = torch.zeros(5) 3787 3788 def forward(self, x): 3789 return F.linear(x, self.w, self.b) 3790 3791 class M(torch.nn.Module): 3792 def __init__(self) -> None: 3793 super().__init__() 3794 self.mods1 = torch.nn.Sequential( 3795 Linear(), 3796 Linear() 3797 ) 3798 self.mods2 = Linear() 3799 self.relu = F.relu 3800 3801 def forward(self, x): 3802 x = self.mods1(x) 3803 x = self.mods2(x) 3804 x = self.relu(x) 3805 return x 3806 3807 model = M().eval() 3808 example_inputs = (torch.rand(5, 5),) 3809 qconfig_dict = {"": default_qconfig} 3810 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 3811 m(*example_inputs) 3812 m = convert_fx(m) 3813 assert hasattr(m, "mods1_0_packed_weight_0") 3814 assert hasattr(m, "mods1_1_packed_weight_0") 3815 assert hasattr(m, "mods2_packed_weight_0") 3816 3817 @skipIfNoFBGEMM 3818 def test_mul_add_fp16_config(self): 3819 with override_quantized_engine('fbgemm'): 3820 class Linear(torch.nn.Module): 3821 def __init__(self) -> None: 3822 super().__init__() 3823 self.w = torch.ones(5, 5) 3824 self.b = torch.zeros(5) 3825 3826 def forward(self, x): 3827 return torch.nn.functional.linear(x, self.w, self.b) 3828 3829 class M(torch.nn.Module): 3830 def __init__(self) -> None: 3831 super().__init__() 3832 self.mods1 = torch.nn.Sequential( 3833 Linear(), 3834 Linear() 3835 ) 3836 self.mods2 = Linear() 3837 3838 def forward(self, x): 3839 x = x * 5 3840 x = x + 5 3841 x = self.mods1(x) 3842 x = self.mods2(x) 3843 return x 3844 model = M().eval() 3845 qconfig_dict = {"": float16_dynamic_qconfig} 3846 example_inputs = (torch.rand(5, 5),) 3847 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 3848 m = convert_fx(m) 3849 # make sure it runs 3850 m(*example_inputs) 3851 3852 def test_getattr_with_nontensor_result(self): 3853 """ 3854 Verifies that binary ops get quantized correctly if some 3855 of the args are nodes but not Tensors, such as an `x.ndim` 3856 pattern. 3857 """ 3858 class M1(torch.nn.Module): 3859 def forward(self, x): 3860 dims = x.ndim 3861 dims_sub = dims - 1 3862 dims_sub2 = dims_sub - 1 3863 x = torch.add(x, dims_sub2) 3864 return x 3865 3866 class M2(torch.nn.Module): 3867 def forward(self, x): 3868 dims = x.ndim 3869 dims_sub = dims - 2 3870 mul = [1] * dims_sub 3871 dims_list = [-1, x.size(1)] + mul 3872 x = x.view(dims_list) 3873 return x 3874 3875 class M3(torch.nn.Module): 3876 def forward(self, x): 3877 shape = x.shape 3878 x = x.view(shape) 3879 return x 3880 3881 for cls in (M1, M2, M3): 3882 m = cls().eval() 3883 example_inputs = (torch.rand(4, 4, 4, 4),) 3884 m(*example_inputs) 3885 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 3886 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 3887 mp(torch.rand(4, 4, 4, 4)) 3888 mc = convert_fx(mp) 3889 3890 class _NonReferenceTestModel(nn.Module): 3891 def __init__(self, func, lin_in, lin_out): 3892 super().__init__() 3893 self.conv1 = nn.Conv2d(3, 6, 5) 3894 self.pool = nn.MaxPool2d(2, 2) 3895 self.lin = nn.Linear(lin_in, lin_out) 3896 self.func = func 3897 3898 def forward(self, x, y, z): 3899 x = self.pool(F.relu(self.conv1(x))) 3900 x = torch.flatten(x, 1) 3901 x = self.func(x, y, z) 3902 x = self.lin(x) 3903 return x 3904 3905 # This function looks at the node specified by the NodeInfo in the key of 3906 # node_info_to_non_tensor_args and checks that the args at specified indices 3907 # are not observed (since they are non tensors). If the args at those indices 3908 # are a tuple/list (which do not show up as nodes) the function checks the 3909 # individual elements of the tuple/list recursively. 3910 def _check_not_observed(self, model, node_info_to_non_tensor_args): 3911 3912 # this is a helper function (for easier recursion) that checks whether 3913 # arg_node is observed 3914 def _check_node_not_observed(model, arg_node, node): 3915 if isinstance(arg_node, (tuple, list)): 3916 for new_node in arg_node: 3917 _check_node_not_observed(model, new_node, node) 3918 elif arg_node.op == "call_module": 3919 self.assertTrue( 3920 not _is_activation_post_process(getattr(model, arg_node.target)), 3921 f"Arg: {arg_node} of node: {node} is observed but is not a float tensor", 3922 ) 3923 3924 for node in model.graph.nodes: 3925 indices = node_info_to_non_tensor_args.get( 3926 NodeInfo(node.op, node.target), [] 3927 ) 3928 for index in indices: 3929 if index < len(node.args): 3930 arg_node = node.args[index] 3931 _check_node_not_observed(model, arg_node, node) 3932 3933 # This test checks that the model gets prepared correct, doesn't have observers 3934 # on specific ops (see _check_not_observed) and that the prepared model runs 3935 def _test_dtype_propagation(self, model, node_info_to_non_tensor_args, *args): 3936 model.eval() 3937 qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} 3938 prepared_model = prepare_fx(model, qconfig_dict, example_inputs=tuple(args)) 3939 self._check_not_observed(prepared_model, node_info_to_non_tensor_args) 3940 prepared_model(*args) 3941 3942 def test_masked_fill_nontensor_args_not_observed(self): 3943 def func(x, y, z): 3944 return x.masked_fill(y, z) 3945 3946 model = self._NonReferenceTestModel(func, 1176, 1) 3947 args = [torch.randn(5, 3, 32, 32), torch.randn(1176) > 0, 0.1] 3948 node_info_to_non_tensor_args = {NodeInfo("call_method", "masked_fill"): [1, 2]} 3949 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3950 3951 def test_permute_nontensor_args_not_observed(self): 3952 def func(x, y, z): 3953 return x.permute(y, z) 3954 3955 model = self._NonReferenceTestModel(func, 1176, 1) 3956 args = [torch.randn(5, 3, 32, 32), 0, 1] 3957 node_info_to_non_tensor_args = {NodeInfo("call_method", "permute"): [1, 2]} 3958 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3959 3960 def test_repeat_nontensor_args_not_observed(self): 3961 def func(x, y, z): 3962 return x.repeat(y, z) 3963 3964 model = self._NonReferenceTestModel(func, 1176, 1) 3965 args = [torch.randn(5, 3, 32, 32), 2, 1] 3966 node_info_to_non_tensor_args = {NodeInfo("call_method", "repeat"): [1, 2]} 3967 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3968 3969 def test_reshape_nontensor_args_not_observed(self): 3970 def func(x, y, z): 3971 return x.reshape(-1, y) 3972 3973 model = self._NonReferenceTestModel(func, 5, 1) 3974 args = [torch.randn(5, 3, 32, 32), 5, None] 3975 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [2]} 3976 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3977 3978 def test_size_nontensor_args_not_observed(self): 3979 def func(x, y, z): 3980 return x.reshape((-1, x.size(y))) 3981 3982 model = self._NonReferenceTestModel(func, 5, 1) 3983 args = [torch.randn(5, 3, 32, 32), 0, None] 3984 node_info_to_non_tensor_args = {NodeInfo("call_method", "size"): [1]} 3985 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3986 3987 def test_transpose_nontensor_args_not_observed(self): 3988 def func(x, y, z): 3989 return x.transpose(y, z) 3990 3991 model = self._NonReferenceTestModel(func, 5, 1) 3992 args = [torch.randn(5, 3, 32, 32), 0, 1] 3993 node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]} 3994 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 3995 3996 def test_torch_transpose_nontensor_args_not_observed(self): 3997 # TODO: make torch.transpose traceable by fx when using 3998 # variable nontensor arguments 3999 # func = lambda x, y, z: torch.transpose(x, y, z) # error 4000 def func(x, y, z): 4001 return torch.transpose(x, 0, 1) 4002 4003 model = self._NonReferenceTestModel(func, 5, 1) 4004 node_info_to_non_tensor_args = { 4005 NodeInfo("call_method", torch.transpose): [1, 2] 4006 } 4007 args = [torch.randn(5, 3, 32, 32), 0, 1] 4008 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4009 4010 def test_unsqueeze_nontensor_args_not_observed(self): 4011 def func(x, y, z): 4012 return x.unsqueeze(y) 4013 4014 model = self._NonReferenceTestModel(func, 1176, 1) 4015 args = [torch.randn(5, 3, 32, 32), 1, None] 4016 node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze"): [1]} 4017 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4018 4019 def test_unsqueeze__nontensor_args_not_observed(self): 4020 def func(x, y, z): 4021 return x.unsqueeze_(y) 4022 4023 model = self._NonReferenceTestModel(func, 1176, 1) 4024 args = [torch.randn(5, 3, 32, 32), 1, None] 4025 node_info_to_non_tensor_args = {NodeInfo("call_method", "unsqueeze_"): [1]} 4026 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4027 4028 def test_torch_unsqueeze_nontensor_args_not_observed(self): 4029 # TODO: make torch.unsqueeze scriptable by fx when using 4030 # variable nontensor arguments 4031 # func = lambda x, y, z: torch.unsqueeze(x, y) # error 4032 def func(x, y, z): 4033 return torch.unsqueeze(x, 1) 4034 4035 model = self._NonReferenceTestModel(func, 1176, 1) 4036 args = [torch.randn(5, 3, 32, 32), 1, None] 4037 node_info_to_non_tensor_args = {NodeInfo("call_method", torch.unsqueeze): [1]} 4038 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4039 4040 def test_view_nontensor_args_not_observed(self): 4041 def func(x, y, z): 4042 return x.view(-1, y) 4043 4044 model = self._NonReferenceTestModel(func, 5, 1) 4045 args = [torch.randn(5, 3, 32, 32), 5, None] 4046 node_info_to_non_tensor_args = {NodeInfo("call_method", "view"): [2]} 4047 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4048 4049 def test_propagate_dtypes_for_known_nodes_list_args(self): 4050 def func(x, y, z): 4051 return x.reshape(y) 4052 4053 model = self._NonReferenceTestModel(func, 5, 1) 4054 args = [torch.randn(5, 3, 32, 32), [-1, 5], None] 4055 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} 4056 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4057 4058 def test_propagate_dtypes_for_known_nodes_split_list_args(self): 4059 def func(x, y, z): 4060 return x.reshape([y, z]) 4061 4062 model = self._NonReferenceTestModel(func, 5, 1) 4063 args = [torch.randn(5, 3, 32, 32), -1, 5] 4064 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} 4065 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4066 4067 def test_propagate_dtypes_for_known_nodes_tuple_args(self): 4068 def func(x, y, z): 4069 return x.reshape(y) 4070 4071 model = self._NonReferenceTestModel(func, 5, 1) 4072 args = [torch.randn(5, 3, 32, 32), (-1, 5), None] 4073 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} 4074 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4075 4076 def test_propagate_dtypes_for_known_nodes_split_tuple_args(self): 4077 def func(x, y, z): 4078 return x.reshape((y, z)) 4079 4080 model = self._NonReferenceTestModel(func, 5, 1) 4081 args = [torch.randn(5, 3, 32, 32), -1, 5] 4082 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} 4083 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4084 4085 def test_propagate_dtypes_for_known_nodes_dict_args(self): 4086 def func(x, y, z): 4087 return x.transpose(y["first"], y["second"]) 4088 4089 model = self._NonReferenceTestModel(func, 5, 1) 4090 args = [torch.randn(5, 3, 32, 32), {"first": 0, "second": 1}, None] 4091 node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1, 2]} 4092 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4093 4094 def test_propagate_dtypes_for_known_nodes_dict_tuple_args(self): 4095 class reshape_module(nn.Module): 4096 def forward(self, x, y, z): 4097 return x.reshape(y["shape"]) 4098 4099 model = self._NonReferenceTestModel(reshape_module(), 5, 1) 4100 args = [torch.randn(5, 3, 32, 32), {"shape": (-1, 5)}, None] 4101 node_info_to_non_tensor_args = {NodeInfo("call_method", "reshape"): [1]} 4102 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4103 4104 def test_propagate_dtypes_for_known_nodes_dict_split_tuple_args(self): 4105 def func(x, y, z): 4106 return x.reshape((y["first"], y["second"])) 4107 4108 model = self._NonReferenceTestModel(func, 5, 1) 4109 args = [torch.randn(5, 3, 32, 32), {"first": -1, "second": 5}, None] 4110 node_info_to_non_tensor_args = {NodeInfo("call_method", "transpose"): [1]} 4111 self._test_dtype_propagation(model, node_info_to_non_tensor_args, *args) 4112 4113 def test_assert_on_size_after_quant_layer(self): 4114 """ 4115 Verifies that calculating a size of a quantized tensor works 4116 correctly in quantization passes. 4117 """ 4118 class M(torch.nn.Module): 4119 def __init__(self) -> None: 4120 super().__init__() 4121 self.conv1 = nn.Conv2d(1, 1, 1) 4122 4123 def forward(self, x): 4124 x = self.conv1(x) 4125 torch._assert(x.size(1) == 1, 'foobar') 4126 return x 4127 4128 m = M().eval() 4129 example_inputs = (torch.rand(4, 1, 4, 4),) 4130 m(*example_inputs) 4131 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 4132 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 4133 mp(*example_inputs) 4134 mc = convert_fx(mp) 4135 mc(*example_inputs) 4136 4137 def test_fp32_sum(self): 4138 """ 4139 Verifies that fp32 sum works correctly if it's before or after 4140 quantized layers. 4141 """ 4142 class M1(torch.nn.Module): 4143 def __init__(self) -> None: 4144 super().__init__() 4145 self.conv1 = nn.Conv2d(1, 1, 1) 4146 4147 def forward(self, x): 4148 x = self.conv1(x) 4149 x = torch.stack([x]) 4150 x = torch.sum(x) 4151 return x 4152 4153 class M2(torch.nn.Module): 4154 def __init__(self) -> None: 4155 super().__init__() 4156 self.conv1 = nn.Conv2d(1, 1, 1) 4157 self.conv2 = nn.Conv2d(1, 1, 1) 4158 4159 def forward(self, x): 4160 x = self.conv1(x) 4161 x1 = torch.stack([x]) 4162 x1 = torch.sum(x1, dim=0) 4163 x2 = self.conv2(x1) 4164 return x2 4165 4166 for cls in (M1, M2): 4167 m = cls().eval() 4168 example_inputs = (torch.rand(4, 1, 4, 4),) 4169 m(*example_inputs) 4170 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 4171 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 4172 mp(*example_inputs) 4173 mc = convert_fx(mp) 4174 mc(*example_inputs) 4175 4176 def test_fusion_pattern_unquantized(self): 4177 """ 4178 Ensure that leaving a possible fusion pattern of multiple nodes 4179 unquantized runs through the APIs without errors. 4180 """ 4181 class Child(torch.nn.Module): 4182 def __init__(self) -> None: 4183 super().__init__() 4184 self.relu = nn.ReLU() 4185 4186 def forward(self, x): 4187 x = torch.add(x, 1.0) 4188 x = torch.nn.functional.relu(x) 4189 return x 4190 4191 class Parent(torch.nn.Module): 4192 def __init__(self) -> None: 4193 super().__init__() 4194 self.child = Child() 4195 self.conv = nn.Conv2d(1, 1, 1) 4196 4197 def forward(self, x): 4198 x = self.child(x) 4199 x = self.conv(x) 4200 return x 4201 4202 m = Parent().eval() 4203 qconfig_dict = { 4204 '': torch.ao.quantization.default_qconfig, 4205 'module_name': [ 4206 ('child', None), 4207 ], 4208 } 4209 example_inputs = (torch.rand(1, 1, 1, 1),) 4210 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 4211 mp(*example_inputs) 4212 mc = convert_fx(mp) 4213 4214 def test_state_dict(self): 4215 """ Make sure packed params appear in state_dict 4216 """ 4217 4218 # test linear packed weight 4219 class M1(torch.nn.Module): 4220 def __init__(self) -> None: 4221 super().__init__() 4222 self.w = torch.rand(4, 30) 4223 self.b = torch.rand(4) 4224 4225 def forward(self, x): 4226 return F.linear(x, self.w, self.b) 4227 4228 m = M1().eval() 4229 qconfig_dict = {"": default_qconfig} 4230 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 30),)) 4231 m = convert_fx(m) 4232 state_dict = m.state_dict() 4233 self.assertTrue("_packed_weight_0" in state_dict) 4234 4235 # test conv packed weight 4236 class M2(torch.nn.Module): 4237 def __init__(self) -> None: 4238 super().__init__() 4239 self.w = torch.rand(3, 3, 3, 3) 4240 self.b = torch.rand(3) 4241 self.stride = (1, 1) 4242 self.padding = (0, 0) 4243 self.dilation = (1, 1) 4244 self.groups = 1 4245 4246 def forward(self, x): 4247 return F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) 4248 4249 m = M2().eval() 4250 qconfig_dict = {"": default_qconfig} 4251 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 3, 3),)) 4252 m = convert_fx(m) 4253 state_dict = m.state_dict() 4254 self.assertTrue("_packed_weight_0" in state_dict) 4255 4256 # test load 4257 ref_weight, ref_bias = torch.ops.quantized.conv2d_unpack(state_dict["_packed_weight_0"]) 4258 data = torch.rand(1, 3, 5, 5) 4259 ref_res = m(data) 4260 m = M2().eval() 4261 m = prepare_fx(m, qconfig_dict, (data,)) 4262 m = convert_fx(m) 4263 res = m(data) 4264 weight, bias = m._packed_weight_0.unpack() 4265 # check that random model weight/bias does not match ref weight/bias 4266 self.assertNotEqual(weight, ref_weight) 4267 self.assertNotEqual(bias, ref_bias) 4268 self.assertNotEqual(res, ref_res) 4269 m.load_state_dict(state_dict) 4270 4271 def checkModel(m, data, ref_weight, ref_bias, ref_res): 4272 res = m(data) 4273 weight, bias = m._packed_weight_0.unpack() 4274 # check that weight/bias matches after load the state_dict 4275 self.assertEqual(weight, ref_weight) 4276 self.assertEqual(bias, ref_bias) 4277 self.assertEqual(res, ref_res) 4278 4279 checkModel(m, data, ref_weight, ref_bias, ref_res) 4280 4281 # Test save to disk and load back 4282 m = M2().eval() 4283 m = prepare_fx(m, qconfig_dict, example_inputs=(data,)) 4284 m = convert_fx(m) 4285 m.load_state_dict(state_dict) 4286 with TemporaryFileName() as fname: 4287 torch.save(m.state_dict(), fname) 4288 # weights_only=False as this is loading a ScriptModule 4289 m.load_state_dict(torch.load(fname, weights_only=False)) 4290 4291 checkModel(m, data, ref_weight, ref_bias, ref_res) 4292 4293 @skipIfNoFBGEMM 4294 def test_preserve_qconfig(self): 4295 """ 4296 Test to make sure the temporary config option to preserve qconfig attributes 4297 in the model works 4298 """ 4299 with override_quantized_engine('fbgemm'): 4300 class Linear(torch.nn.Module): 4301 def __init__(self) -> None: 4302 super().__init__() 4303 self.w = torch.ones(5, 5) 4304 self.b = torch.zeros(5) 4305 4306 def forward(self, x): 4307 return torch.nn.functional.linear(x, self.w, self.b) 4308 4309 class M(torch.nn.Module): 4310 def __init__(self) -> None: 4311 super().__init__() 4312 self.mods1 = torch.nn.Sequential( 4313 Linear(), 4314 Linear() 4315 ) 4316 self.mods2 = torch.nn.Sigmoid() 4317 4318 def forward(self, x): 4319 x = self.mods1(x) 4320 x = self.mods2(x) 4321 return x 4322 4323 model = M().eval() 4324 qconfig_dict = { 4325 "object_type": [ 4326 (torch.nn.functional.linear, float16_dynamic_qconfig), 4327 ], 4328 } 4329 example_inputs = (torch.rand(5, 5),) 4330 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 4331 m(*example_inputs) 4332 m = convert_fx(m, _remove_qconfig=False) 4333 4334 self.assertTrue(hasattr(m.mods2, 'qconfig')) 4335 4336 def test_not_used(self): 4337 """ Test quantizing a not used value""" 4338 4339 class M(torch.nn.Module): 4340 def forward(self, x): 4341 x = x + x 4342 x.sigmoid_() 4343 return x 4344 4345 m = M().eval() 4346 qconfig_mapping = get_default_qconfig_mapping().set_global(float16_static_qconfig) 4347 # make sure quantization runs 4348 m = prepare_fx(m, qconfig_mapping, example_inputs=(torch.randn(1),)) 4349 m = convert_fx(m) 4350 4351 def test_qparams_fqn(self): 4352 """ Test that the FQN of input_scale/zero_point is set 4353 to that of first linear use. """ 4354 class Linear(torch.nn.Module): 4355 def __init__(self) -> None: 4356 super().__init__() 4357 self.w = torch.ones(5, 5) 4358 self.b = torch.zeros(5) 4359 4360 def forward(self, x): 4361 return torch.nn.functional.linear(x, self.w, self.b) 4362 4363 class M(torch.nn.Module): 4364 def __init__(self) -> None: 4365 super().__init__() 4366 self.mods1 = torch.nn.Sequential( 4367 Linear(), 4368 Linear() 4369 ) 4370 4371 def forward(self, x): 4372 x = torch.cat((x,), 1) 4373 tmp = x.size() 4374 x = self.mods1(x) 4375 y = x * tmp[0] 4376 return y 4377 4378 model = M().eval() 4379 qconfig_dict = { 4380 "": None, 4381 "object_type": [ 4382 (torch.nn.functional.linear, default_qconfig), 4383 (torch.nn.functional.relu, default_qconfig), 4384 ], 4385 } 4386 example_inputs = (torch.rand(5, 5),) 4387 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 4388 m(*example_inputs) 4389 m = convert_fx(m) 4390 keys = m.state_dict().keys() 4391 m(torch.randn(5, 5)) 4392 # TODO: probably don't want to hardcode the attribute names, since they are generated 4393 for attr_name in [ 4394 "mods1_0_input_scale_0", "mods1_0_input_zero_point_0", 4395 "mods1_0_scale_0", "mods1_0_zero_point_0", 4396 "mods1_1_scale_0", "mods1_1_zero_point_0"]: 4397 self.assertTrue(hasattr(m, attr_name), attr_name + " not found.") 4398 4399 def test_no_obs_between_unmatched_node_and_copy_node(self): 4400 """ 4401 Verifies that an observer is not inserted between an unmatched 4402 node and a node matched to CopyNodeQuantizeHandler. This is done 4403 because observers require activations to be Tensors, and there is 4404 no guarantee that an output of an unmatched node is a Tensor. 4405 """ 4406 4407 class M(nn.Module): 4408 def __init__(self) -> None: 4409 super().__init__() 4410 self.relu = nn.ReLU() 4411 4412 def forward(self, x): 4413 x = _user_func_with_complex_return_type(x) 4414 x1 = x[0] + 1 4415 return x1, x[1] 4416 4417 m = M().eval() 4418 4419 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 4420 example_inputs = (torch.randn(4, 4, 4, 4),) 4421 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 4422 # if an observer is inserted after _user_func_with_complex_return_type, 4423 # the following call will fail 4424 mp(*example_inputs) 4425 mc = convert_fx(mp) 4426 mc(*example_inputs) 4427 4428 def test_fold_quant_dequant(self): 4429 """ Test that the sequence of quant-dequant nodes in the 4430 graph, get folded and we erase the extra dequant nodes. 4431 """ 4432 class M(torch.nn.Module): 4433 def __init__(self) -> None: 4434 super().__init__() 4435 self.w = torch.ones(5, 5) 4436 self.b = torch.zeros(5) 4437 4438 def forward(self, x): 4439 x = torch.cat((x,), 1) 4440 tmp = x.size() 4441 x = torch.nn.functional.linear(x, self.w, self.b) 4442 y = x * tmp[0] 4443 return y 4444 4445 model = M().eval() 4446 qconfig_dict = { 4447 "": None, 4448 "object_type": [ 4449 (torch.nn.functional.linear, default_qconfig), 4450 ], 4451 } 4452 example_inputs = (torch.rand(5, 5),) 4453 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 4454 m(*example_inputs) 4455 m = convert_fx(m) 4456 keys = m.state_dict().keys() 4457 m(*example_inputs) 4458 dequant = 0 4459 quant = 0 4460 for n in m.graph.nodes: 4461 if n.op == "call_method" and n.target == "dequantize": 4462 dequant = dequant + 1 4463 if n.op == "call_function" and n.target == torch.quantize_per_tensor: 4464 quant = quant + 1 4465 self.assertEqual(dequant, 1) 4466 self.assertEqual(quant, 1) 4467 4468 def test_quant_output_always_observed(self): 4469 """ 4470 If the output is hardcoded to be quantized, ensure that 4471 there is always an observer, even if the last non-output node is not 4472 quantizeable. 4473 """ 4474 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} 4475 prepare_custom_config_dict = {'output_quantized_idxs': [0]} 4476 example_inputs = (torch.randn(4, 1, 4, 4),) 4477 4478 # non-quantizeable node, quantized output 4479 class M1(torch.nn.Module): 4480 def __init__(self) -> None: 4481 super().__init__() 4482 self.identity = torch.nn.Identity() 4483 4484 def forward(self, x): 4485 x = self.identity(x) 4486 return x 4487 4488 m1 = M1() 4489 self.checkGraphModeFxOp( 4490 m1, example_inputs, QuantType.QAT, 4491 prepare_expected_node_occurrence={ 4492 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, 4493 }, 4494 expected_node_occurrence={ 4495 ns.call_function(torch.quantize_per_tensor): 1, 4496 }, 4497 prepare_custom_config=prepare_custom_config_dict) 4498 4499 # quantizeable node, quantized output 4500 class M2(torch.nn.Module): 4501 def __init__(self) -> None: 4502 super().__init__() 4503 self.conv = torch.nn.Conv2d(1, 1, 1) 4504 4505 def forward(self, x): 4506 x = self.conv(x) 4507 return x 4508 4509 m2 = M2() 4510 self.checkGraphModeFxOp( 4511 m2, example_inputs, QuantType.QAT, 4512 prepare_expected_node_occurrence={ 4513 # one for weights, one for activations 4514 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, 4515 }, 4516 expected_node_occurrence={ 4517 ns.call_function(torch.quantize_per_tensor): 1, 4518 }, 4519 prepare_custom_config=prepare_custom_config_dict) 4520 4521 # quantizeable node, quantized dictionary output 4522 class M3(torch.nn.Module): 4523 def __init__(self) -> None: 4524 super().__init__() 4525 self.conv = torch.nn.Conv2d(1, 1, 1) 4526 4527 def forward(self, x): 4528 x = self.conv(x) 4529 return {"output": x} 4530 4531 m3 = M3() 4532 self.checkGraphModeFxOp( 4533 m3, example_inputs, QuantType.QAT, 4534 prepare_expected_node_occurrence={ 4535 # one for weights, one for activations 4536 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 2, 4537 }, 4538 expected_node_occurrence={ 4539 ns.call_function(torch.quantize_per_tensor): 1, 4540 }, 4541 prepare_custom_config=prepare_custom_config_dict) 4542 4543 def test_deepcopy_preserve_attributes(self): 4544 class M(torch.nn.Module): 4545 def __init__(self) -> None: 4546 super().__init__() 4547 self.attr = 3 4548 4549 def forward(self, x): 4550 return x 4551 4552 m = M().eval() 4553 m = prepare_fx( 4554 m, 4555 {"": default_qconfig}, 4556 example_inputs=(torch.randn(1),), 4557 prepare_custom_config={"preserved_attributes": ["attr"]}) 4558 # preserved attributes are also stored in meta so that it doesn't get lost 4559 # during deepcopy 4560 self.assertTrue(hasattr(m, "attr")) 4561 self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) 4562 m2 = copy.deepcopy(m) 4563 self.assertTrue(hasattr(m2, "attr")) 4564 self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) 4565 m = convert_fx(m, convert_custom_config={"preserved_attributes": ["attr"]}) 4566 self.assertTrue(hasattr(m, "attr")) 4567 self.assertTrue("attr" in m.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) 4568 m2 = copy.deepcopy(m) 4569 self.assertTrue(hasattr(m2, "attr")) 4570 self.assertTrue("attr" in m2.meta[_USER_PRESERVED_ATTRIBUTES_KEY]) 4571 4572 def test_output_lists_and_dicts(self): 4573 """Verify that specifying complicated output types does not crash. 4574 """ 4575 class M(torch.nn.Module): 4576 def __init__(self) -> None: 4577 super().__init__() 4578 self.conv = nn.Conv2d(1, 1, 1) 4579 4580 def forward(self, x): 4581 x = self.conv(x) 4582 return {'foo': [x]}, [{'foo': [[x]]}] 4583 4584 m = M().eval() 4585 qconfig_dict = {'': torch.ao.quantization.default_qconfig} 4586 mp = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1),)) 4587 mc = convert_fx(mp) 4588 4589 def test_shape_followed_by_quantized_op(self): 4590 """ Make sure that shape does not dequantize 4591 the Tensor before the next operator 4592 """ 4593 class M(torch.nn.Module): 4594 def __init__(self) -> None: 4595 super().__init__() 4596 self.conv1 = torch.nn.Conv2d(2, 2, 2) 4597 self.conv2 = torch.nn.Conv2d(2, 2, 2) 4598 4599 def forward(self, x): 4600 x = self.conv1(x) 4601 s = x.shape 4602 torch._assert(s == x.shape, "") 4603 x = self.conv2(x) 4604 return x 4605 4606 # make sure quantization runs 4607 m = M().eval() 4608 example_inputs = (torch.randn(2, 2, 4, 4),) 4609 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 4610 m = convert_fx(m) 4611 m(*example_inputs) 4612 node_occurrence = { 4613 ns.call_function(torch.quantize_per_tensor): 1, 4614 ns.call_method("dequantize"): 1 4615 } 4616 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 4617 4618 def test_trace_quantize_per_tensor(self): 4619 class M(torch.nn.Module): 4620 def __init__(self) -> None: 4621 super().__init__() 4622 self.conv = torch.nn.Conv2d(1, 1, 1) 4623 4624 def forward(self, x): 4625 x = self.conv(x) 4626 return x 4627 4628 m = M().eval() 4629 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1, 1, 3, 3),)) 4630 m = convert_fx(m) 4631 # Make sure this runs without error 4632 m = torch.fx.Transformer(m).transform() 4633 4634 def test_copy_node_has_shared_actpp_instance(self): 4635 """ Test the output of CopyNode to have the same 4636 observer/fake_quant instance as the input 4637 """ 4638 4639 class M(torch.nn.Module): 4640 def __init__(self) -> None: 4641 super().__init__() 4642 self.avgpool2d = torch.nn.AvgPool2d(kernel_size=3) 4643 4644 def forward(self, x): 4645 x = self.avgpool2d(x) 4646 return x 4647 4648 for quant_type in self.static_quant_types: 4649 m = M() 4650 # Checks that we have an observer for both input and output 4651 occurrence_map = { 4652 QuantType.STATIC: { 4653 ns.call_module(torch.ao.quantization.MinMaxObserver): 2 4654 }, 4655 QuantType.QAT: { 4656 ns.call_module(torch.ao.quantization.FakeQuantize): 2 4657 } 4658 } 4659 if quant_type == QuantType.QAT: 4660 m.train() 4661 prepare = prepare_qat_fx 4662 qconfig = default_qat_qconfig 4663 actpp_module_class = torch.ao.quantization.FakeQuantize 4664 else: 4665 m.eval() 4666 prepare = prepare_fx 4667 qconfig = default_qconfig 4668 actpp_module_class = torch.ao.quantization.MinMaxObserver 4669 4670 example_inputs = (torch.randn(1, 3, 3, 3),) 4671 m = prepare(m, {"": qconfig}, example_inputs=example_inputs) 4672 # check that there is a duplicated observer instance 4673 actpp_module_count = 0 4674 for name, module in m.named_modules(remove_duplicate=False): 4675 if isinstance(module, actpp_module_class): 4676 actpp_module_count += 1 4677 self.assertEqual(actpp_module_count, 2) 4678 4679 actpp_module_count = 0 4680 for name, module in m.named_modules(): 4681 if isinstance(module, actpp_module_class): 4682 actpp_module_count += 1 4683 self.assertEqual(actpp_module_count, 1) 4684 4685 m_copy = copy.deepcopy(m) 4686 m = convert_fx(m) 4687 m_reference = convert_to_reference_fx(m_copy) 4688 4689 # checks for non-reference quantized model 4690 node_occurrence = { 4691 ns.call_function(torch.quantize_per_tensor): 1, 4692 ns.call_method("dequantize"): 1 4693 } 4694 node_list = [ 4695 ns.call_function(torch.quantize_per_tensor), 4696 ns.call_module(torch.nn.AvgPool2d), 4697 ns.call_method("dequantize"), 4698 ] 4699 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence, expected_node_list=node_list) 4700 4701 # checks for reference quantized model, for copy nodes we'll have 4702 # dequant - copy_node - quant patterns which will be fused later 4703 # in the backend lowering step 4704 node_occurrence = { 4705 ns.call_function(torch.quantize_per_tensor): 2, 4706 ns.call_method("dequantize"): 2 4707 } 4708 node_list = [ 4709 ns.call_function(torch.quantize_per_tensor), 4710 ns.call_method("dequantize"), 4711 ns.call_module(torch.nn.AvgPool2d), 4712 ns.call_function(torch.quantize_per_tensor), 4713 ns.call_method("dequantize"), 4714 ] 4715 self.checkGraphModuleNodes(m_reference, expected_node_occurrence=node_occurrence, expected_node_list=node_list) 4716 4717 def test_linear_qint8_activation(self): 4718 """Test support for qint8 activation in reference pattern 4719 """ 4720 class M(torch.nn.Module): 4721 def __init__(self) -> None: 4722 super().__init__() 4723 self.conv = torch.nn.Conv2d(1, 2, 2, 2) 4724 self.linear = torch.nn.Linear(8, 5) 4725 4726 def forward(self, x): 4727 x = self.conv(x) 4728 x = torch.flatten(x, 1) 4729 x = self.linear(x) 4730 return x 4731 4732 m = M().eval() 4733 example_inputs = (torch.rand(2, 1, 5, 5),) 4734 m = prepare_fx( 4735 m, 4736 {"": torch.ao.quantization.QConfig( 4737 activation=torch.ao.quantization.HistogramObserver.with_args( 4738 qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 4739 ), weight=torch.ao.quantization.default_per_channel_weight_observer)}, 4740 example_inputs=example_inputs) 4741 m = convert_to_reference_fx(m) 4742 m(*example_inputs) 4743 4744 def test_preserve_tuple(self): 4745 """ Test tuple input type is preserved 4746 """ 4747 4748 class LSTM(nn.Module): 4749 def __init__(self) -> None: 4750 super().__init__() 4751 self.lstm = nn.LSTM(50, 50, 1) 4752 4753 def forward(self, inputs: torch.Tensor, state: List[torch.Tensor]): 4754 h = state[0] 4755 c = state[1] 4756 return self.lstm(inputs, (h, c)) 4757 4758 m = LSTM().eval() 4759 example_inputs = (torch.randn(5, 3, 50), torch.randn(2, 3, 50), torch.randn(2, 3, 50)) 4760 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 4761 # make sure the arg[1] of lstm module is a tuple 4762 for n in m.graph.nodes: 4763 if n.target == "lstm": 4764 self.assertEqual(type(n.args[1]), tuple) 4765 4766 def _test_static_lstm_helper(self, model, prepare_node_occurrence, convert_node_occurrence): 4767 """ 4768 Helper method to validate the graph of a model with static LSTM. 4769 """ 4770 qconfig_mapping = get_default_qconfig_mapping() 4771 prepare_custom_config = PrepareCustomConfig() \ 4772 .set_float_to_observed_mapping(torch.nn.LSTM, torch.ao.nn.quantizable.LSTM) 4773 convert_custom_config = ConvertCustomConfig() \ 4774 .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM) 4775 example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) 4776 4777 model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) 4778 self.checkGraphModuleNodes(model, expected_node_occurrence=prepare_node_occurrence) 4779 model(*example_inputs) 4780 4781 model = convert_fx(model, convert_custom_config=convert_custom_config) 4782 self.checkGraphModuleNodes(model, expected_node_occurrence=convert_node_occurrence) 4783 model(*example_inputs) 4784 4785 def test_static_lstm(self): 4786 """ 4787 Test statically quantized custom module LSTM followed by ops that consume individual 4788 tensors of the output tuple. 4789 """ 4790 class MyModel(nn.Module): 4791 def __init__(self) -> None: 4792 super().__init__() 4793 self.lstm = nn.LSTM(50, 50, 1) 4794 self.linear1 = nn.Linear(50, 10) 4795 self.linear2 = nn.Linear(50, 10) 4796 self.linear3 = nn.Linear(50, 10) 4797 4798 def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): 4799 (out, (h0_out, c0_out)) = self.lstm(inputs, (h0, c0)) 4800 out = self.linear1(out) 4801 h0_out = self.linear2(h0_out) 4802 c0_out = self.linear3(c0_out) 4803 return (out, (h0_out, c0_out)) 4804 4805 m = MyModel() 4806 prepare_node_occurrence = { 4807 ns.call_module(torch.ao.nn.quantizable.LSTM): 1, 4808 } 4809 convert_node_occurrence = { 4810 ns.call_module(torch.ao.nn.quantized.LSTM): 1, 4811 ns.call_function(torch.quantize_per_tensor): 3, 4812 # lstm[0].dequantize() 4813 # lstm[1][0].dequantize() 4814 # lstm[1][1].dequantize() 4815 ns.call_method("dequantize"): 3, 4816 # lstm[0], lstm[1], lstm[1][0], lstm[1][1] 4817 ns.call_function(operator.getitem): 4, 4818 # No tuples are consumed 4819 ns.call_function(tuple): 0, 4820 } 4821 self._test_static_lstm_helper(m, prepare_node_occurrence, convert_node_occurrence) 4822 4823 def test_static_lstm_consume_tuple(self): 4824 """ 4825 Test statically quantized custom module LSTM followed by a module that consumes the 4826 output tuple, either as a whole or part of it. 4827 """ 4828 class ModuleAfterLSTM(nn.Module): 4829 def __init__(self) -> None: 4830 super().__init__() 4831 self.identity = torch.nn.Identity() 4832 4833 def forward(self, x): 4834 return self.identity(x) 4835 4836 class ConsumeWholeTuple(nn.Module): 4837 def __init__(self) -> None: 4838 super().__init__() 4839 self.lstm = nn.LSTM(50, 50, 1) 4840 self.module_after_lstm = ModuleAfterLSTM() 4841 4842 def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): 4843 x = self.lstm(inputs, (h0, c0)) 4844 x = self.module_after_lstm(x) # consume tuple (output, (hidden0, hidden1)) 4845 return x 4846 4847 class ConsumeHiddenTuple(ConsumeWholeTuple): 4848 def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): 4849 x = self.lstm(inputs, (h0, c0)) 4850 x = self.module_after_lstm(x[1]) # consume tuple (hidden0, hidden1) 4851 return x 4852 4853 # Test consuming the whole tuple (output, (hidden0, hidden1)) 4854 m1 = ConsumeWholeTuple() 4855 prepare_node_occurrence = { 4856 ns.call_module(torch.ao.nn.quantizable.LSTM): 1, 4857 } 4858 convert_node_occurrence1 = { 4859 ns.call_module(torch.ao.nn.quantized.LSTM): 1, 4860 ns.call_function(torch.quantize_per_tensor): 3, 4861 # lstm[0].dequantize() 4862 # lstm[1][0].dequantize() 4863 # lstm[1][1].dequantize() 4864 ns.call_method("dequantize"): 3, 4865 # lstm[0], lstm[1], lstm[1][0], lstm[1][1] 4866 ns.call_function(operator.getitem): 4, 4867 # tuple(output_dq, tuple(hidden0_dq, hidden1_dq)) 4868 ns.call_function(tuple): 2, 4869 } 4870 self._test_static_lstm_helper(m1, prepare_node_occurrence, convert_node_occurrence1) 4871 4872 # Test consuming just the hidden tuple (hidden0, hidden1) 4873 m2 = ConsumeHiddenTuple() 4874 convert_node_occurrence2 = { 4875 ns.call_module(torch.ao.nn.quantized.LSTM): 1, 4876 ns.call_function(torch.quantize_per_tensor): 3, 4877 # lstm[1][0].dequantize() 4878 # lstm[1][1].dequantize() 4879 ns.call_method("dequantize"): 2, 4880 # lstm[1], lstm[1][0], lstm[1][1] 4881 ns.call_function(operator.getitem): 3, 4882 # tuple(hidden0_dq, hidden1_dq) 4883 ns.call_function(tuple): 1, 4884 } 4885 self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2) 4886 4887 def test_static_lstm_with_custom_fixed_qparams(self): 4888 """ 4889 Test statically quantized LSTM with custom fixed qparams assigned to each of the 4890 inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM` 4891 and use the child class in the custom module mapping. 4892 """ 4893 class MyModel(torch.nn.Module): 4894 def __init__(self) -> None: 4895 super().__init__() 4896 self.my_lstm = torch.nn.LSTM(50, 50, 1) 4897 4898 def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): 4899 x = self.my_lstm(inputs, (h0, c0)) 4900 return x 4901 4902 # Construct a BackendConfig that supports qint32 for certain ops 4903 # TODO: build a BackendConfig from scratch instead of modifying an existing one 4904 qint32_dtype_config = DTypeConfig(input_dtype=torch.qint32, output_dtype=torch.qint32) 4905 my_backend_config = get_qnnpack_backend_config() 4906 for config in my_backend_config.configs: 4907 if config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh, torch.add, torch.mul]: 4908 config.add_dtype_config(qint32_dtype_config) 4909 4910 class UserObservedLSTM(torch.ao.nn.quantizable.LSTM): 4911 """ 4912 Example of user provided LSTM implementation that assigns fixed qparams 4913 to the inner ops. 4914 """ 4915 @classmethod 4916 def from_float(cls, float_lstm): 4917 assert isinstance(float_lstm, cls._FLOAT_MODULE) 4918 # uint16, [-16, 16) 4919 linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) 4920 # uint16, [0, 1) 4921 sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32) 4922 # uint16, [-1, 1) 4923 tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) 4924 # int16, [-16, 16) 4925 cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32) 4926 # uint8, [-1, 1) 4927 hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) 4928 example_inputs = (torch.rand(5, 3, 50), (torch.rand(1, 3, 50), torch.randn(1, 3, 50))) 4929 return torch.ao.quantization.fx.lstm_utils._get_lstm_with_individually_observed_parts( 4930 float_lstm=float_lstm, 4931 example_inputs=example_inputs, 4932 backend_config=my_backend_config, 4933 linear_output_obs_ctr=linear_output_obs_ctr, 4934 sigmoid_obs_ctr=sigmoid_obs_ctr, 4935 tanh_obs_ctr=tanh_obs_ctr, 4936 cell_state_obs_ctr=cell_state_obs_ctr, 4937 hidden_state_obs_ctr=hidden_state_obs_ctr, 4938 ) 4939 4940 class UserQuantizedLSTM(torch.ao.nn.quantized.LSTM): 4941 """ 4942 Example of user provided LSTM implementation that produces a reference 4943 quantized module from a `UserObservedLSTM`. 4944 """ 4945 @classmethod 4946 def from_observed(cls, observed_lstm): 4947 assert isinstance(observed_lstm, cls._FLOAT_MODULE) 4948 return torch.ao.quantization.fx.lstm_utils._get_reference_quantized_lstm_module( 4949 observed_lstm=observed_lstm, 4950 backend_config=my_backend_config, 4951 ) 4952 4953 # FX graph mode quantization 4954 m = MyModel() 4955 qconfig_mapping = get_default_qconfig_mapping("qnnpack") 4956 example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) 4957 prepare_custom_config = PrepareCustomConfig() \ 4958 .set_float_to_observed_mapping(torch.nn.LSTM, UserObservedLSTM) 4959 convert_custom_config = ConvertCustomConfig() \ 4960 .set_observed_to_quantized_mapping(torch.ao.nn.quantizable.LSTM, UserQuantizedLSTM) 4961 prepared = prepare_fx( 4962 m, 4963 qconfig_mapping, 4964 example_inputs, 4965 prepare_custom_config, 4966 backend_config=my_backend_config, 4967 ) 4968 prepared(*example_inputs) 4969 converted = convert_fx( 4970 prepared, 4971 convert_custom_config, 4972 backend_config=my_backend_config, 4973 ) 4974 converted(*example_inputs) 4975 4976 # Find the patterns [dq - op - q_to_specific_dtype] in the graph and 4977 # verify that qparams and dtypes are set correctly in the quantize ops 4978 node_name_to_expected_quantize_args = { 4979 "igates": (None, None, torch.quint8), 4980 "hgates": (None, None, torch.quint8), 4981 "add": (2 ** -11, 2 ** 15, torch.qint32), # gates.add 4982 "input_gate": (2 ** -16, 0, torch.qint32), 4983 "forget_gate": (2 ** -16, 0, torch.qint32), 4984 "cell_gate": (2 ** -15, 2 ** 15, torch.qint32), 4985 "output_gate": (2 ** -16, 0, torch.qint32), 4986 "mul": (2 ** -11, 0, torch.qint32), # fgate_cx.mul 4987 "mul_1": (2 ** -11, 0, torch.qint32), # igate_cgate.mul 4988 "add_1": (2 ** -11, 0, torch.qint32), # fgate_cx_igate_cgate.add 4989 "mul_2": (2 ** -7, 2 ** 7, torch.quint8), # ogate_cy.mul 4990 } 4991 cell = converted.my_lstm.layers.get_submodule("0").layer_fw.cell 4992 matched_names = set() 4993 for node in cell.graph.nodes: 4994 if node.name not in node_name_to_expected_quantize_args: 4995 continue 4996 matched_names.add(node.name) 4997 # Match preceding dequantize 4998 self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) 4999 # Match following quantize with the specific qparams and dtypes 5000 expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name] 5001 for user in node.users.keys(): 5002 self.assertEqual(user.target, torch.quantize_per_tensor) 5003 if expected_scale is not None: 5004 self.assertEqual(getattr(cell, user.args[1].target), expected_scale) 5005 if expected_zp is not None: 5006 self.assertEqual(getattr(cell, user.args[2].target), expected_zp) 5007 self.assertEqual(user.args[-1], expected_dtype) 5008 # Ensure all patterns were matched 5009 self.assertEqual(matched_names, set(node_name_to_expected_quantize_args.keys())) 5010 5011 def test_reroute_tuple_getitem_patterns(self): 5012 """ 5013 The following graph should redirect the output to `b`. After the transformation, 5014 all other nodes, including the inputs `a` and `c`, are no longer needed. 5015 5016 a b c 5017 | \\ / 5018 \\ tuple 5019 \\ / 5020 tuple 5021 / \\ 5022 / \\ 5023 | \\ 5024 | \\ 5025 | \\ 5026 getitem0 getitem1 5027 | / \\ 5028 | getitem0 getitem1 5029 | \\ / 5030 \\ tuple 5031 \\ / 5032 \\ / 5033 tuple 5034 | 5035 getitem1 5036 | 5037 getitem0 5038 | 5039 output 5040 """ 5041 # Construct graph manually because symbolic_trace does not insert tuple and getitem nodes 5042 graph = torch.fx.Graph() 5043 a = graph.create_node("placeholder", "a") 5044 b = graph.create_node("placeholder", "b") 5045 c = graph.create_node("placeholder", "c") 5046 bc = graph.call_function(tuple, args=([b, c],)) 5047 abc = graph.call_function(tuple, args=([a, bc],)) 5048 5049 # Break down tuple and reconstruct it again 5050 a2 = graph.call_function(operator.getitem, args=(abc, 0)) 5051 bc2 = graph.call_function(operator.getitem, args=(abc, 1)) 5052 b2 = graph.call_function(operator.getitem, args=(bc2, 0)) 5053 c2 = graph.call_function(operator.getitem, args=(bc2, 1)) 5054 bc3 = graph.call_function(tuple, args=([b2, c2],)) 5055 abc2 = graph.call_function(tuple, args=([a2, bc3],)) 5056 5057 # Output tuple[1][0] 5058 bc4 = graph.call_function(operator.getitem, args=(abc2, 1)) 5059 b3 = graph.call_function(operator.getitem, args=(bc4, 0)) 5060 output = graph.output(b3) 5061 5062 # Do reroute 5063 _reroute_tuple_getitem_pattern(graph) 5064 5065 # Assert that output reroutes to `b` directly, and all other nodes can be removed 5066 output_ancestors = [] 5067 def gather_ancestors(current_node): # noqa: E306 5068 for arg in current_node.args: 5069 output_ancestors.append(arg) 5070 gather_ancestors(arg) 5071 gather_ancestors(output) 5072 self.assertEqual(output_ancestors, [b]) 5073 self.assertEqual(output.args[0], b) 5074 5075 def test_relu_lowering(self): 5076 class M(torch.nn.Module): 5077 def forward(self, x): 5078 return torch.nn.functional.relu(x) 5079 5080 m = M().eval() 5081 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.randn(1),)) 5082 m_copy = copy.deepcopy(m) 5083 m = convert_fx(m) 5084 m_ref = convert_to_reference_fx(m_copy) 5085 node_occurrence = { 5086 ns.call_function(torch.quantize_per_tensor): 1, 5087 ns.call_method("dequantize"): 1 5088 } 5089 node_occurrence_ref = { 5090 ns.call_function(torch.quantize_per_tensor): 2, 5091 ns.call_method("dequantize"): 2 5092 } 5093 5094 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 5095 self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) 5096 5097 @skipIfNoFBGEMM 5098 def test_dynamic_with_fusion(self): 5099 """ 5100 Tests that dynamic quantization APIs work with Linear + Relu fusion 5101 """ 5102 with override_quantized_engine('fbgemm'): 5103 class LinearRelu(torch.nn.Module): 5104 def __init__(self) -> None: 5105 super().__init__() 5106 self.linear = torch.nn.Linear(5, 5) 5107 self.relu = torch.nn.ReLU() 5108 5109 def forward(self, x): 5110 x = self.linear(x) 5111 return self.relu(x) 5112 5113 class Linear(torch.nn.Module): 5114 def __init__(self) -> None: 5115 super().__init__() 5116 self.w = torch.ones(5, 5) 5117 self.b = torch.zeros(5) 5118 5119 def forward(self, x): 5120 return torch.nn.functional.linear(x, self.w, self.b) 5121 5122 class M(torch.nn.Module): 5123 def __init__(self) -> None: 5124 super().__init__() 5125 self.mods1 = torch.nn.Sequential(LinearRelu(), LinearRelu()) 5126 self.mods2 = Linear() 5127 self.relu = F.relu 5128 5129 def forward(self, x): 5130 x = self.mods1(x) 5131 x = self.mods2(x) 5132 x = self.relu(x) 5133 return x 5134 5135 dynamic_quantized_ops = { 5136 float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16, 5137 default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic 5138 } 5139 for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: 5140 model = M().eval() 5141 qconfig_dict = { 5142 "": qconfig 5143 } 5144 example_inputs = (torch.rand(5, 5),) 5145 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 5146 m = convert_fx(m) 5147 m(*example_inputs) 5148 node_list = [ 5149 ns.call_module(nniqd.LinearReLU), 5150 ns.call_module(nniqd.LinearReLU), 5151 ns.call_function(dynamic_quantized_ops[qconfig]), 5152 ] 5153 self.checkGraphModuleNodes(m, expected_node_list=node_list) 5154 5155 @skipIfNoFBGEMM 5156 def test_dynamic_with_fusion_multiple_uses(self): 5157 """ 5158 Tests that dynamic quantization APIs work with Linear + Relu fusion 5159 """ 5160 with override_quantized_engine('fbgemm'): 5161 class LinearRelu(torch.nn.Module): 5162 def __init__(self) -> None: 5163 super().__init__() 5164 self.linear = torch.nn.Linear(5, 5) 5165 self.relu = torch.nn.ReLU() 5166 5167 def forward(self, x): 5168 x = self.linear(x) 5169 return self.relu(x) 5170 5171 class M(torch.nn.Module): 5172 def __init__(self) -> None: 5173 super().__init__() 5174 self.linear_relu = LinearRelu() 5175 5176 def forward(self, x): 5177 x = self.linear_relu(x) 5178 x = self.linear_relu(x) 5179 return x 5180 5181 for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: 5182 model = M().eval() 5183 qconfig_dict = { 5184 "": qconfig 5185 } 5186 example_inputs = (torch.randn(5, 5),) 5187 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 5188 m = convert_fx(m) 5189 m(*example_inputs) 5190 node_list = [ 5191 ns.call_module(nniqd.LinearReLU), 5192 ns.call_module(nniqd.LinearReLU), 5193 ] 5194 self.checkGraphModuleNodes(m, expected_node_list=node_list) 5195 5196 @skipIfNoFBGEMM 5197 def test_dynamic_linear_input_multiple_use(self): 5198 """ 5199 Tests input for dynamic linear being used by multiple ops 5200 """ 5201 with override_quantized_engine('fbgemm'): 5202 class LinearRelu(torch.nn.Module): 5203 def __init__(self) -> None: 5204 super().__init__() 5205 self.linear = torch.nn.Linear(5, 5) 5206 self.relu = torch.nn.ReLU() 5207 5208 def forward(self, x): 5209 x = self.linear(x) 5210 return self.relu(x) 5211 5212 class M(torch.nn.Module): 5213 def __init__(self) -> None: 5214 super().__init__() 5215 self.mod1 = LinearRelu() 5216 self.mod2 = LinearRelu() 5217 5218 def forward(self, x): 5219 y1 = self.mod1(x) 5220 y2 = self.mod2(x) 5221 return y1 + y2 5222 5223 for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: 5224 model = M().eval() 5225 qconfig_dict = { 5226 "": qconfig 5227 } 5228 example_inputs = (torch.rand(5, 5, 5),) 5229 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 5230 m = convert_fx(m) 5231 m(*example_inputs) 5232 node_list = [ 5233 ns.call_module(nniqd.LinearReLU), 5234 ns.call_module(nniqd.LinearReLU), 5235 ] 5236 self.checkGraphModuleNodes(m, expected_node_list=node_list) 5237 5238 def test_ref_linear_module(self): 5239 """ Make sure the numerics for models with ref linear module 5240 matches models with fbgemm/qnnpack module 5241 """ 5242 class M1(torch.nn.Module): 5243 def __init__(self) -> None: 5244 super().__init__() 5245 self.linear = torch.nn.Linear(10, 5) 5246 5247 def forward(self, x): 5248 return self.linear(x) 5249 5250 class M2(torch.nn.Module): 5251 def __init__(self) -> None: 5252 super().__init__() 5253 self.linear = torch.nn.Linear(10, 5) 5254 self.relu = torch.nn.ReLU() 5255 5256 def forward(self, x): 5257 return self.relu(self.linear(x)) 5258 5259 for M in [M1, M2]: 5260 m = M().eval() 5261 example_inputs = (torch.randn(5, 10),) 5262 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 5263 m_copy = copy.deepcopy(m) 5264 m = convert_fx(m) 5265 m_ref = convert_to_reference_fx(m_copy) 5266 result = m(*example_inputs) 5267 result_ref = m_ref(*example_inputs) 5268 self.assertTrue(torch.equal(result, result_ref)) 5269 5270 def test_ref_conv_module(self): 5271 """ Make sure the numerics for models with ref conv module 5272 matches models with fbgemm/qnnpack module 5273 """ 5274 convs = { 5275 1: nn.Conv1d, 5276 2: nn.Conv2d, 5277 3: nn.Conv3d, 5278 } 5279 5280 class M1(torch.nn.Module): 5281 def __init__(self, dim): 5282 super().__init__() 5283 self.conv = convs[dim](3, 3, 3) 5284 5285 def forward(self, x): 5286 return self.conv(x) 5287 5288 class M2(torch.nn.Module): 5289 def __init__(self, dim): 5290 super().__init__() 5291 self.conv = convs[dim](3, 3, 3) 5292 self.relu = torch.nn.ReLU() 5293 5294 def forward(self, x): 5295 return self.relu(self.conv(x)) 5296 5297 for dim, M in itertools.product([1, 2, 3], [M1, M2]): 5298 m = M(dim).eval() 5299 data = self.img_data_dict[dim][0][0] 5300 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) 5301 m_copy = copy.deepcopy(m) 5302 m = convert_fx(m) 5303 m_ref = convert_to_reference_fx(m_copy) 5304 result = m(data) 5305 result_ref = m_ref(data) 5306 self.assertTrue(torch.equal(result, result_ref)) 5307 5308 def test_sub_scalar(self): 5309 class M(torch.nn.Module): 5310 def forward(self, x): 5311 x = x + 1 5312 x = x - 1 5313 x = x + 3 5314 x = x - 4 5315 return x 5316 5317 m = M().eval() 5318 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(torch.rand(3),)) 5319 m = convert_fx(m) 5320 occurrence = { 5321 ns.call_function(torch.quantize_per_tensor): 2, 5322 ns.call_method("dequantize"): 2 5323 } 5324 self.checkGraphModuleNodes(m, expected_node_occurrence=occurrence) 5325 5326 def test_observer_fqn(self): 5327 """ 5328 Test to make sure the observer FQN is based on the quantizable op/module that it is observing 5329 and uses the modules FQN to determine the observer name. 5330 """ 5331 class Linear(torch.nn.Module): 5332 def __init__(self) -> None: 5333 super().__init__() 5334 self.w = torch.ones(5, 5) 5335 self.b = torch.zeros(5) 5336 5337 5338 def forward(self, x): 5339 return torch.nn.functional.linear(x, self.w, self.b) 5340 5341 5342 class M(torch.nn.Module): 5343 def __init__(self) -> None: 5344 super().__init__() 5345 self.mods1 = torch.nn.Sequential( 5346 Linear(), 5347 Linear() 5348 ) 5349 self.mods2 = Linear() 5350 self.mods3 = torch.nn.Linear(5, 5) 5351 5352 def forward(self, x): 5353 x = self.mods1(x) 5354 x = torch.add(x, 4) 5355 x = self.mods2(x) 5356 y = torch.add(x, 2) 5357 z = torch.mul(x, 5) 5358 a = self.mods3(y) 5359 return a, z 5360 5361 model = M().eval() 5362 5363 prepared = prepare_fx(model, {"": default_qconfig}, example_inputs=(torch.randn(1, 5))) 5364 name_list = [] 5365 for name, mod in prepared.named_modules(): 5366 if isinstance(mod, torch.ao.quantization.observer.MinMaxObserver): 5367 name_list.append(name) 5368 expected_name_list = ['activation_post_process_0', 5369 'activation_post_process_1', 5370 'activation_post_process_2', 5371 'activation_post_process_3', 5372 'activation_post_process_4', 5373 'activation_post_process_6', 5374 'activation_post_process_7', 5375 'activation_post_process_10'] 5376 assert name_list == expected_name_list 5377 5378 def test_conv_lowering(self): 5379 convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d} 5380 qconvs = {1: nn.quantized.Conv1d, 2: nn.quantized.Conv2d, 3: nn.quantized.Conv3d} 5381 5382 class M(torch.nn.Module): 5383 def __init__(self, dim): 5384 super().__init__() 5385 self.conv = convs[dim](3, 3, 3) 5386 5387 def forward(self, x): 5388 return self.conv(x) 5389 5390 for dim in range(1, len(convs) + 1): 5391 m = M(dim).eval() 5392 data = self.img_data_dict[dim][0][0] 5393 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(data,)) 5394 m_ref = copy.deepcopy(m) 5395 m_ref = convert_to_reference_fx(m_ref) 5396 m = convert_fx(m) 5397 out_ref = m_ref(data) 5398 out = m(data) 5399 # check that reference pattern for quantized conv module is fused 5400 expected_node_occurrence = { 5401 ns.call_function(torch.quantize_per_tensor): 1, 5402 ns.call_module(qconvs[dim]): 1, 5403 ns.call_method("dequantize"): 1 5404 } 5405 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_node_occurrence) 5406 # checking result match 5407 self.assertTrue(torch.equal(out_ref, out)) 5408 5409 def test_convert_qconfig_mapping(self): 5410 class Linear(torch.nn.Module): 5411 def __init__(self) -> None: 5412 super().__init__() 5413 self.w = torch.ones(5, 5) 5414 self.b = torch.zeros(5) 5415 5416 def forward(self, x): 5417 return torch.nn.functional.linear(x, self.w, self.b) 5418 5419 5420 class M(torch.nn.Module): 5421 def __init__(self) -> None: 5422 super().__init__() 5423 self.mods1 = torch.nn.Sequential( 5424 Linear(), 5425 Linear() 5426 ) 5427 self.mods3 = torch.nn.Linear(5, 5) 5428 5429 def forward(self, x): 5430 x = self.mods1(x) 5431 x = torch.add(x, 4) 5432 z = torch.mul(x, 5) 5433 x = self.mods3(z) 5434 return x 5435 5436 model = M().train() 5437 5438 for check in ["module_name", "object_type"]: 5439 qconfig_dict = {"": None, 5440 "object_type": [ 5441 (nn.functional.linear, get_default_qat_qconfig("fbgemm")), 5442 (torch.add, get_default_qat_qconfig("fbgemm")), 5443 (nn.Linear, get_default_qat_qconfig("fbgemm")), 5444 ], 5445 } 5446 example_inputs = (torch.rand(5, 5),) 5447 prepared = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 5448 prepared(*example_inputs) 5449 if check == "module_name": 5450 convert_qconfig_dict = {"": None, 5451 "object_type": [ 5452 (nn.functional.linear, get_default_qat_qconfig("fbgemm")), 5453 (torch.add, get_default_qat_qconfig("fbgemm")), 5454 (nn.Linear, get_default_qat_qconfig("fbgemm")), 5455 ], 5456 "module_name": [("mods1.0", None)]} 5457 5458 node_occurrence = { 5459 ns.call_function(torch.quantize_per_tensor): 2, 5460 ns.call_function(torch.nn.functional.linear): 1, 5461 ns.call_function(torch.ops.quantized.linear): 1, 5462 ns.call_function(torch.ops.quantized.add): 1, 5463 ns.call_method("dequantize"): 2 5464 } 5465 order_check = [ 5466 ns.call_function(torch.nn.functional.linear), 5467 ns.call_function(torch.quantize_per_tensor), 5468 ns.call_function(torch.ops.quantized.linear), 5469 ns.call_function(torch.ops.quantized.add), 5470 ns.call_method("dequantize"), 5471 ns.call_function(torch.quantize_per_tensor), 5472 ns.call_module(nnq.Linear), 5473 ns.call_method("dequantize"), 5474 ] 5475 elif check == "object_type": 5476 convert_qconfig_dict = {"": None, 5477 "object_type": [ 5478 (nn.functional.linear, get_default_qat_qconfig("fbgemm")), 5479 (torch.add, get_default_qat_qconfig("fbgemm")), 5480 (nn.Linear, None), 5481 ]} 5482 5483 node_occurrence = { 5484 ns.call_function(torch.quantize_per_tensor): 1, 5485 ns.call_function(torch.ops.quantized.linear): 2, 5486 ns.call_function(torch.ops.quantized.add): 1, 5487 ns.call_function(torch.mul): 1, 5488 ns.call_method("dequantize"): 1 5489 } 5490 order_check = [ 5491 ns.call_function(torch.quantize_per_tensor), 5492 ns.call_function(torch.ops.quantized.linear), 5493 ns.call_function(torch.ops.quantized.linear), 5494 ns.call_function(torch.ops.quantized.add), 5495 ns.call_method("dequantize"), 5496 ns.call_function(torch.mul), 5497 ns.call_module(nn.Linear), 5498 ] 5499 5500 converted = convert_fx(prepared, qconfig_mapping=convert_qconfig_dict) 5501 converted(torch.rand(5, 5)) 5502 self.checkGraphModuleNodes( 5503 converted, 5504 expected_node_occurrence=node_occurrence, 5505 expected_node_list=order_check) 5506 5507 def _assertFixedQParamsFakeQuantizeEqual(self, fq1, fq2): 5508 self.assertEqual(fq1()._observer_ctr, fq2()._observer_ctr) 5509 5510 def test_register_patterns(self): 5511 def cleanUp(): 5512 del _DEFAULT_FUSION_PATTERNS["dummy_fusion"] 5513 del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"] 5514 del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"] 5515 del _DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"] 5516 del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"] 5517 del _DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"] 5518 del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"] 5519 del _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"] 5520 self.addCleanup(cleanUp) 5521 5522 @_register_fusion_pattern("dummy_fusion") 5523 class DummyFusion: 5524 pass 5525 5526 @_register_quant_pattern("dummy_quant") 5527 class DummyQuant: 5528 pass 5529 5530 @_register_quant_pattern("dummy_quant2", default_fixed_qparams_range_0to1_observer) 5531 class DummyQuant2: 5532 pass 5533 5534 @_register_quant_pattern("dummy_quant3", default_fixed_qparams_range_neg1to1_observer) 5535 class DummyQuant3: 5536 pass 5537 5538 self.assertEqual(_DEFAULT_FUSION_PATTERNS["dummy_fusion"], DummyFusion) 5539 self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant"], DummyQuant) 5540 self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant2"], DummyQuant2) 5541 self.assertEqual(_DEFAULT_QUANTIZATION_PATTERNS["dummy_quant3"], DummyQuant3) 5542 self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant2"], default_fixed_qparams_range_0to1_observer) 5543 self.assertEqual(_DEFAULT_OUTPUT_OBSERVER_MAP["dummy_quant3"], default_fixed_qparams_range_neg1to1_observer) 5544 self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant2"], 5545 default_fixed_qparams_range_0to1_fake_quant) 5546 self._assertFixedQParamsFakeQuantizeEqual(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP["dummy_quant3"], 5547 default_fixed_qparams_range_neg1to1_fake_quant) 5548 output_fake_quantize_map = get_default_output_activation_post_process_map(is_training=True) 5549 output_observer_map = get_default_output_activation_post_process_map(is_training=False) 5550 self.assertEqual(output_observer_map.get("dummy_quant3"), default_fixed_qparams_range_neg1to1_observer) 5551 self._assertFixedQParamsFakeQuantizeEqual(output_fake_quantize_map.get("dummy_quant3"), 5552 default_fixed_qparams_range_neg1to1_fake_quant) 5553 5554 5555 5556 def test_reuse_input_qconfig(self): 5557 class M1(torch.nn.Module): 5558 def __init__(self) -> None: 5559 super().__init__() 5560 self.conv = torch.nn.Conv2d(3, 3, 3) 5561 5562 def forward(self, x): 5563 x = self.conv(x) 5564 x = x.reshape() 5565 return x 5566 5567 class M2(torch.nn.Module): 5568 def forward(self, x): 5569 x = x.reshape() 5570 return x 5571 5572 options = itertools.product([M1, M2], [True, False]) 5573 for M, is_qat in options: 5574 m = M1().eval() 5575 example_inputs = (torch.randn(1, 3, 3, 3),) 5576 m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs) 5577 m = convert_fx(m) 5578 node_list = [ 5579 ns.call_function(torch.quantize_per_tensor), 5580 ns.call_module(nnq.Conv2d), 5581 ns.call_method("reshape"), 5582 ns.call_method("dequantize"), 5583 ] 5584 self.checkGraphModuleNodes( 5585 m, 5586 expected_node_list=node_list) 5587 5588 m = M2().eval() 5589 m = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=example_inputs) 5590 m = convert_fx(m) 5591 node_occurrence = { 5592 ns.call_function(torch.quantize_per_tensor): 0, 5593 ns.call_method("dequnatize"): 0, 5594 } 5595 node_list = [ 5596 ns.call_method("reshape"), 5597 ] 5598 self.checkGraphModuleNodes( 5599 m, 5600 expected_node_occurrence=node_occurrence, 5601 expected_node_list=node_list) 5602 5603 def test_stack_trace_preserved_linear(self): 5604 class M(nn.Module): 5605 def __init__(self) -> None: 5606 super().__init__() 5607 self.linear = nn.Linear(1, 1) 5608 5609 def forward(self, x): 5610 x = self.linear(x) 5611 return x 5612 5613 m = M().eval() 5614 mp = prepare_fx(m, get_default_qconfig_mapping(), example_inputs=(torch.randn(1, 1),)) 5615 5616 found_stack_trace = False 5617 for n in mp.graph.nodes: 5618 if n.op == 'call_module' and n.target == 'linear': 5619 found_stack_trace = n.stack_trace is not None 5620 break 5621 self.assertTrue(found_stack_trace) 5622 5623 # test reference model 5624 mq = convert_to_reference_fx(copy.deepcopy(mp)) 5625 found_stack_trace = False 5626 for n in mq.graph.nodes: 5627 if n.op == 'call_module' and n.target == 'linear': 5628 found_stack_trace = n.stack_trace is not None 5629 break 5630 self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: True") 5631 5632 # test quantized model 5633 mq = convert_fx(mp) 5634 found_stack_trace = False 5635 for n in mq.graph.nodes: 5636 if n.op == 'call_module' and n.target == 'linear': 5637 found_stack_trace = n.stack_trace is not None 5638 break 5639 self.assertTrue(found_stack_trace, f"stack trace not found, node: {n.format_node()}, is_reference: False") 5640 5641 def test_qat_skip_untraced(self): 5642 class UnTraceableModuleClass(nn.Module): 5643 def __init__(self) -> None: 5644 super().__init__() 5645 self.linear = nn.Linear(2, 2) 5646 5647 def forward(self, x): 5648 return self.linear(x) 5649 5650 class UnTraceableModuleName(nn.Module): 5651 def __init__(self) -> None: 5652 super().__init__() 5653 self.linear = nn.Linear(2, 2) 5654 5655 def forward(self, x): 5656 return self.linear(x) 5657 5658 class M(nn.Module): 5659 def __init__(self) -> None: 5660 super().__init__() 5661 self.untraceable_module_class = UnTraceableModuleClass() 5662 self.untraceable_module_name = UnTraceableModuleClass() 5663 5664 def forward(self, x): 5665 x = self.untraceable_module_class(x) 5666 x = self.untraceable_module_name(x) 5667 return x 5668 5669 mod = M() 5670 5671 qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig()} 5672 prepare_custom_config_dict = { 5673 "non_traceable_module_class": [UnTraceableModuleClass], 5674 "non_traceable_module_name": ["untraceable_module_name"], 5675 } 5676 example_inputs = (torch.randn(2, 2),) 5677 mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( 5678 mod.train(), qconfig_dict, example_inputs=example_inputs, 5679 prepare_custom_config=prepare_custom_config_dict 5680 ) 5681 mod_prep = torch.ao.quantization.quantize_fx.prepare_qat_fx( 5682 mod.train(), qconfig_dict, example_inputs=example_inputs, 5683 prepare_custom_config=prepare_custom_config_dict 5684 ) 5685 self.assertTrue( 5686 isinstance(mod_prep.untraceable_module_class.linear, torch.nn.Linear) 5687 ) 5688 self.assertTrue( 5689 isinstance(mod_prep.untraceable_module_name.linear, torch.nn.Linear) 5690 ) 5691 self.assertTrue( 5692 type(mod_prep.untraceable_module_class.linear) 5693 is not torch.ao.nn.qat.modules.linear.Linear, 5694 "prepare_qat_fx shold not convert anything inside untraced module classes", 5695 ) 5696 self.assertTrue( 5697 type(mod_prep.untraceable_module_name.linear) 5698 is not torch.ao.nn.qat.modules.linear.Linear, 5699 "prepare_qat_fx shold not convert anything inside modules named in untraced_module_names", 5700 ) 5701 5702 def test_qconfig_dict_setup(self): 5703 class M(torch.nn.Module): 5704 def __init__(self) -> None: 5705 super().__init__() 5706 self.Conv1d = torch.nn.Conv1d(1, 1, 1) 5707 self.Conv2d = torch.nn.Conv2d(1, 1, 1) 5708 self.Conv3d = torch.nn.Conv3d(1, 1, 1) 5709 self.ConvTranspose1d = torch.nn.ConvTranspose1d(1, 1, 1) 5710 self.ConvTranspose2d = torch.nn.ConvTranspose2d(1, 1, 1) 5711 self.ConvTranspose3d = torch.nn.ConvTranspose3d(1, 1, 1) 5712 self.Linear = torch.nn.Linear(1, 1, 1) 5713 5714 def forward(self, x): 5715 x = self.Conv1d(x) 5716 x = self.Conv2d(x) 5717 x = self.Conv3d(x) 5718 x = self.ConvTranspose1d(x) 5719 x = self.ConvTranspose2d(x) 5720 x = self.ConvTranspose3d(x) 5721 x = self.Linear(x) 5722 x = torch.nn.functional.conv1d(x, torch.rand(2, 2)) 5723 x = torch.nn.functional.conv2d(x, torch.rand(2, 2)) 5724 x = torch.nn.functional.conv3d(x, torch.rand(2, 2)) 5725 x = torch.nn.functional.linear(x, torch.rand(2, 2)) 5726 return x 5727 5728 backends = ["qnnpack", "fbgemm"] 5729 for func in [get_default_qconfig_mapping, get_default_qat_qconfig_mapping]: 5730 for backend in backends: 5731 m = M().eval() 5732 qconfig_dict = func(backend) 5733 m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) 5734 for name, mod in m.named_modules(): 5735 if _is_activation_post_process(mod) and mod.dtype == torch.quint8: 5736 if backend == "fbgemm": 5737 lower_bnd = 0 5738 upper_bnd = 127 5739 else: 5740 lower_bnd = 0 5741 upper_bnd = 255 5742 if issubclass(type(mod), FakeQuantize): 5743 self.assertEqual(mod.activation_post_process.quant_min, lower_bnd) 5744 self.assertEqual(mod.activation_post_process.quant_max, upper_bnd) 5745 else: 5746 self.assertEqual(mod.quant_min, lower_bnd) 5747 self.assertEqual(mod.quant_max, upper_bnd) 5748 5749 def test_prepare_mode(self): 5750 class LinearModel(torch.nn.Module): 5751 def __init__(self) -> None: 5752 super().__init__() 5753 self.linear = torch.nn.Linear(5, 10) 5754 5755 def forward(self, x): 5756 return self.linear(x) 5757 5758 def _test(prepare_fn, qconfig_dict): 5759 m = LinearModel() 5760 m1 = copy.deepcopy(m) 5761 m1.train() 5762 example_inputs = (torch.randn(1, 5),) 5763 prepare_fn(m1, qconfig_dict, example_inputs=example_inputs) 5764 m2 = copy.deepcopy(m) 5765 m2.eval() 5766 prepare_fn(m2, qconfig_dict, example_inputs=example_inputs) 5767 5768 # Ensure prepare_fx and prepare_qat_fx work in both training and eval modes 5769 _test(prepare_fx, get_default_qconfig_mapping()) 5770 _test(prepare_qat_fx, get_default_qat_qconfig_mapping()) 5771 5772 def _validate_qconfig_against_backend_config_constraints( 5773 self, 5774 model: torch.nn.Module, 5775 qconfig: QConfig, 5776 backend_config: BackendConfig, 5777 satisfies_constraints: bool, 5778 qconfig_name: Optional[str] = None): 5779 """ 5780 Helper method to validate whether `qconfig` satisfies the constraints specified in `backend_config`. 5781 """ 5782 qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) 5783 example_inputs = (torch.rand((1, 30), dtype=torch.float),) 5784 model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) 5785 model(*example_inputs) 5786 model = convert_fx(model, backend_config=backend_config) 5787 if satisfies_constraints: 5788 expected_node_occurrence = { 5789 ns.call_module(torch.ao.nn.quantized.Linear) : 1, 5790 ns.call_module(torch.nn.Linear) : 0, 5791 } 5792 else: 5793 expected_node_occurrence = { 5794 ns.call_module(torch.ao.nn.quantized.Linear) : 0, 5795 ns.call_module(torch.nn.Linear) : 1, 5796 } 5797 try: 5798 self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) 5799 except AssertionError as e: 5800 if qconfig_name is not None: 5801 print(f"ERROR: Validation for QConfig '{qconfig_name}' failed") 5802 raise e 5803 5804 def test_backend_config_quantization_range(self): 5805 """ 5806 Check that quantization ranges specified through the BackendConfig are reflected in 5807 the observers inserted into the model. 5808 """ 5809 class MyModel(torch.nn.Module): 5810 def __init__(self) -> None: 5811 super().__init__() 5812 self.linear = torch.nn.Linear(30, 4).float() 5813 5814 def forward(self, x): 5815 return self.linear(x) 5816 5817 dtype_config = DTypeConfig( 5818 input_dtype=DTypeWithConstraints( 5819 dtype=torch.quint8, 5820 quant_min_lower_bound=0, 5821 quant_max_upper_bound=31, 5822 ), 5823 output_dtype=DTypeWithConstraints( 5824 dtype=torch.quint8, 5825 quant_min_lower_bound=0, 5826 quant_max_upper_bound=31, 5827 ), 5828 weight_dtype=DTypeWithConstraints( 5829 dtype=torch.qint8, 5830 quant_min_lower_bound=-64, 5831 quant_max_upper_bound=63, 5832 ), 5833 bias_dtype=torch.float, 5834 ) 5835 backend_config = BackendConfig() \ 5836 .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) 5837 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 5838 .add_dtype_config(dtype_config) 5839 .set_root_module(torch.nn.Linear) 5840 .set_reference_quantized_module(nnqr.Linear)) 5841 5842 def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): 5843 self._validate_qconfig_against_backend_config_constraints( 5844 MyModel(), qconfig, backend_config, satisfies_constraints) 5845 5846 # Case 1: QConfig ranges fit within backend ranges, OK 5847 qconfig1 = QConfig( 5848 activation=MinMaxObserver.with_args(quant_min=0, quant_max=15, dtype=torch.quint8), 5849 weight=MinMaxObserver.with_args(quant_min=-32, quant_max=31, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) 5850 validate_qconfig(qconfig1, satisfies_constraints=True) 5851 5852 # Case 2: QConfig activation range falls outside backend range, should fail 5853 qconfig2 = QConfig( 5854 activation=MinMaxObserver.with_args(quant_min=0, quant_max=63, dtype=torch.quint8), 5855 weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) 5856 validate_qconfig(qconfig2, satisfies_constraints=False) 5857 5858 # Case 3: QConfig weight range falls outside backend range, should fail 5859 qconfig3 = QConfig( 5860 activation=MinMaxObserver.with_args(dtype=torch.quint8), 5861 weight=MinMaxObserver.with_args(quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) 5862 validate_qconfig(qconfig3, satisfies_constraints=False) 5863 5864 # Case 4: QConfig doesn't specify range, should fail 5865 qconfig4 = QConfig(activation=ReuseInputObserver, weight=ReuseInputObserver) 5866 validate_qconfig(qconfig4, satisfies_constraints=False) 5867 5868 def test_backend_config_scale_min(self): 5869 """ 5870 Test QConfig eps validation against the BackendConfig's min scale value. 5871 """ 5872 class MyModel(torch.nn.Module): 5873 def __init__(self) -> None: 5874 super().__init__() 5875 self.linear = torch.nn.Linear(30, 4).float() 5876 5877 def forward(self, x): 5878 return self.linear(x) 5879 5880 dtype_config = DTypeConfig( 5881 input_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), 5882 output_dtype=DTypeWithConstraints(dtype=torch.quint8, scale_min_lower_bound=2 ** -12), 5883 weight_dtype=DTypeWithConstraints(dtype=torch.qint8, scale_min_lower_bound=2 ** -12), 5884 bias_dtype=torch.float, 5885 ) 5886 5887 backend_config = BackendConfig() \ 5888 .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) 5889 .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 5890 .add_dtype_config(dtype_config) 5891 .set_root_module(torch.nn.Linear) 5892 .set_reference_quantized_module(nnqr.Linear)) 5893 5894 def validate_qconfig(qconfig: QConfig, satisfies_constraints: bool): 5895 self._validate_qconfig_against_backend_config_constraints( 5896 MyModel(), qconfig, backend_config, satisfies_constraints) 5897 5898 # Case 1: QConfig min scale value == backend min scale value, OK 5899 qconfig1 = QConfig( 5900 activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -12), 5901 weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -12)) 5902 validate_qconfig(qconfig1, satisfies_constraints=True) 5903 5904 # Case 2: QConfig min scale value > backend min scale value, OK 5905 qconfig2 = QConfig( 5906 activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -10), 5907 weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -10)) 5908 validate_qconfig(qconfig2, satisfies_constraints=True) 5909 5910 # Case 3: QConfig activation min scale value < backend min scale value, should fail 5911 qconfig3 = QConfig( 5912 activation=MinMaxObserver.with_args(dtype=torch.quint8, eps=2 ** -14), 5913 weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)) 5914 validate_qconfig(qconfig3, satisfies_constraints=False) 5915 5916 # Case 3: QConfig weight min scale value < backend min scale value, should fail 5917 qconfig4 = QConfig( 5918 activation=MinMaxObserver.with_args(dtype=torch.quint8), 5919 weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, eps=2 ** -14)) 5920 validate_qconfig(qconfig4, satisfies_constraints=False) 5921 5922 # Case 5: QConfig doesn't specify eps, should fail 5923 qconfig5 = QConfig( 5924 activation=FixedQParamsObserver.with_args(scale=1.0, zero_point=0), 5925 weight=FixedQParamsObserver.with_args(scale=1.0, zero_point=0)) 5926 validate_qconfig(qconfig5, satisfies_constraints=False) 5927 5928 def test_qnnpack_backend_config(self): 5929 """ 5930 Test whether default QNNPACK QConfigs are compatible with the QNNPACK BackendConfig. 5931 """ 5932 class MyModel(torch.nn.Module): 5933 def __init__(self) -> None: 5934 super().__init__() 5935 self.linear = torch.nn.Linear(30, 4).float() 5936 5937 def forward(self, x): 5938 return self.linear(x) 5939 5940 all_qconfigs: List[Tuple[QConfig, str]] = [ 5941 (get_default_qconfig("qnnpack", version=0), "default_qnnpack_qconfig_v0"), 5942 (get_default_qat_qconfig("qnnpack", version=0), "default_qat_qnnpack_qconfig_v0"), 5943 (get_default_qat_qconfig("qnnpack", version=1), "default_qat_qnnpack_qconfig_v1"), 5944 (default_symmetric_qnnpack_qconfig, "default_symmetric_qnnpack_qconfig"), 5945 (default_symmetric_qnnpack_qat_qconfig, "default_symmetric_qnnpack_qat_qconfig"), 5946 # TODO: Test these QConfigs once they are fixed, see https://github.com/pytorch/pytorch/issues/85862 5947 # (default_per_channel_symmetric_qnnpack_qconfig, "default_per_channel_symmetric_qnnpack_qconfig"), 5948 # (default_per_channel_symmetric_qnnpack_qat_qconfig, "default_per_channel_symmetric_qnnpack_qat_qconfig"), 5949 ] 5950 backend_config = get_qnnpack_backend_config() 5951 for qconfig, qconfig_name in all_qconfigs: 5952 self._validate_qconfig_against_backend_config_constraints( 5953 MyModel(), qconfig, backend_config, satisfies_constraints=True, qconfig_name=qconfig_name) 5954 5955 def test_symmetric_qnnpack_qconfig_mapping(self): 5956 """ 5957 Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qconfig_mapping` 5958 works with the QNNPACK BackendConfig. 5959 """ 5960 if "qnnpack" not in supported_qengines: 5961 return 5962 5963 class MyModel(torch.nn.Module): 5964 def __init__(self) -> None: 5965 super().__init__() 5966 self.linear = torch.nn.Linear(30, 4).float() 5967 5968 def forward(self, x): 5969 return self.linear(x) 5970 5971 with override_quantized_engine("qnnpack"): 5972 qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping() 5973 example_inputs = (torch.rand((1, 30), dtype=torch.float),) 5974 backend_config = get_qnnpack_backend_config() 5975 model = MyModel() 5976 model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) 5977 model(*example_inputs) 5978 model = convert_fx(model, backend_config=backend_config) 5979 expected_node_occurrence = { 5980 ns.call_module(torch.ao.nn.quantized.Linear) : 1, 5981 ns.call_module(torch.nn.Linear) : 0, 5982 } 5983 self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) 5984 model(*example_inputs) 5985 5986 def test_symmetric_qnnpack_qat_qconfig_mapping(self): 5987 """ 5988 Test whether `torch.ao.quantization.qconfig_mapping._get_symmetric_qnnpack_qat_qconfig_mapping` 5989 works with the QNNPACK BackendConfig. 5990 """ 5991 if "qnnpack" not in supported_qengines: 5992 return 5993 5994 class MyModel(torch.nn.Module): 5995 def __init__(self) -> None: 5996 super().__init__() 5997 self.linear = torch.nn.Linear(30, 4).float() 5998 5999 def forward(self, x): 6000 return self.linear(x) 6001 6002 with override_quantized_engine("qnnpack"): 6003 qconfig_mapping = _get_symmetric_qnnpack_qat_qconfig_mapping() 6004 example_inputs = (torch.rand((1, 30), dtype=torch.float),) 6005 backend_config = get_qnnpack_backend_config() 6006 model = MyModel() 6007 model = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config) 6008 model(*example_inputs) 6009 model = convert_fx(model, backend_config=backend_config) 6010 expected_node_occurrence = { 6011 ns.call_module(torch.ao.nn.quantized.Linear) : 1, 6012 ns.call_module(torch.nn.Linear) : 0, 6013 } 6014 self.checkGraphModuleNodes(model, expected_node_occurrence=expected_node_occurrence) 6015 model(*example_inputs) 6016 6017 6018 def test_get_executorch_backend_config(self): 6019 from torch.ao.quantization.backend_config import get_executorch_backend_config 6020 # make sure this runs 6021 executorch_backend_config = get_executorch_backend_config() 6022 6023 def test_backend_config_check_for_weight_and_bias(self): 6024 """ Test to make sure the backend_config check for weight and bias 6025 runs when the qconfig is None for the ops with weight and bias 6026 previously the error was not hit because we first check input, and 6027 the check for weight and bias are skipped. 6028 """ 6029 6030 class M(torch.nn.Module): 6031 def __init__(self) -> None: 6032 super().__init__() 6033 self.weight = torch.tensor((5, 5)) 6034 self.bias = torch.tensor((5,)) 6035 6036 def forward(self, x): 6037 return torch.addmm(self.bias, x, self.weight) 6038 6039 m = M().eval() 6040 qconfig_mapping = QConfigMapping() 6041 observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT 6042 weighted_op_quint8_dtype_config = DTypeConfig( 6043 input_dtype=torch.quint8, 6044 output_dtype=torch.quint8, 6045 weight_dtype=torch.qint8, 6046 bias_dtype=torch.float, 6047 ) 6048 dtype_configs = [weighted_op_quint8_dtype_config] 6049 backend_pattern_config = BackendPatternConfig(torch.addmm) \ 6050 .set_observation_type(observation_type) \ 6051 .set_dtype_configs(dtype_configs) \ 6052 ._set_input_type_to_index({"weight": 2, "bias": 0}) 6053 backend_config = BackendConfig() \ 6054 .set_backend_pattern_config(backend_pattern_config) 6055 example_inputs = (torch.rand(1, 5),) 6056 # make sure this runs 6057 m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) 6058 6059 def test_get_default_qconfig_valid_backend(self): 6060 """ Checks that AssertionError is raised when non expected backend input is specified 6061 """ 6062 invalid_backends = ["imaginary_backend", 3] 6063 for invalid_backend in invalid_backends: 6064 with self.assertRaisesRegex(AssertionError, "not supported"): 6065 qconfig = get_default_qconfig(invalid_backend) 6066 with self.assertRaisesRegex(AssertionError, "not supported"): 6067 qconfig = get_default_qat_qconfig(invalid_backend) 6068 with self.assertRaisesRegex(AssertionError, "not supported"): 6069 qconfig_mapping = get_default_qconfig_mapping(invalid_backend) 6070 with self.assertRaisesRegex(AssertionError, "not supported"): 6071 qconfig_mapping = get_default_qat_qconfig_mapping(invalid_backend) 6072 6073 def test__convert_to_reference_decomposed_fx(self): 6074 class M(torch.nn.Module): 6075 def __init__(self) -> None: 6076 super().__init__() 6077 self.linear = torch.nn.Linear(5, 10) 6078 6079 def forward(self, x): 6080 return self.linear(x) 6081 6082 m = M().eval() 6083 qconfig_mapping = get_default_qconfig_mapping("fbgemm") 6084 example_inputs = (torch.randn(1, 5),) 6085 m = prepare_fx(m, qconfig_mapping, example_inputs) 6086 m_ref = copy.deepcopy(m) 6087 m_ref = convert_to_reference_fx(m_ref) 6088 m = _convert_to_reference_decomposed_fx(m) 6089 expected_occurrence = { 6090 ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, 6091 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, 6092 } 6093 self.checkGraphModuleNodes( 6094 m, 6095 expected_node_occurrence=expected_occurrence) 6096 # make sure it runs 6097 res_ref = m_ref(*example_inputs) 6098 res = m(*example_inputs) 6099 self.assertEqual(res, res_ref) 6100 6101 @skipIfNoQNNPACK 6102 def test__convert_to_reference_decomposed_fx_dynamic_quant(self): 6103 class M(torch.nn.Module): 6104 def __init__(self) -> None: 6105 super().__init__() 6106 self.linear = torch.nn.Linear(5, 10) 6107 6108 def forward(self, x): 6109 return self.linear(x) 6110 6111 # to avoid reduce_range 6112 with override_quantized_engine("qnnpack"): 6113 m = M().eval() 6114 qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ 6115 .set_object_type(torch.nn.Linear, default_dynamic_qconfig) 6116 example_inputs = (torch.randn(1, 5),) 6117 m = prepare_fx(m, qconfig_mapping, example_inputs) 6118 m(*example_inputs) 6119 m_ref = copy.deepcopy(m) 6120 m_ref = convert_to_reference_fx(m_ref) 6121 m = _convert_to_reference_decomposed_fx(m) 6122 expected_occurrence = { 6123 ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1, 6124 ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1, 6125 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1, 6126 } 6127 self.checkGraphModuleNodes( 6128 m, 6129 expected_node_occurrence=expected_occurrence) 6130 # make sure it runs 6131 res_ref = m_ref(*example_inputs) 6132 res = m(*example_inputs) 6133 self.assertEqual(res, res_ref) 6134 6135 def test__convert_to_reference_decomposed_fx_per_channel_quant(self): 6136 class M(torch.nn.Module): 6137 def forward(self, x, weight, bias): 6138 return F.linear(x, weight, bias) 6139 6140 m = M().eval() 6141 qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ 6142 .set_object_type(F.linear, default_per_channel_qconfig) 6143 example_inputs = (torch.randn(1, 5), torch.randn(10, 5), torch.randn(10,)) 6144 m = prepare_fx(m, qconfig_mapping, example_inputs) 6145 m(*example_inputs) 6146 m_ref = copy.deepcopy(m) 6147 m_ref = convert_to_reference_fx(m_ref) 6148 m = _convert_to_reference_decomposed_fx(m) 6149 expected_occurrence = { 6150 # for input and output activations 6151 ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, 6152 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, 6153 # for weight 6154 ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1, 6155 ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, 6156 } 6157 self.checkGraphModuleNodes( 6158 m, 6159 expected_node_occurrence=expected_occurrence) 6160 # make sure it runs 6161 res_ref = m_ref(*example_inputs) 6162 res = m(*example_inputs) 6163 self.assertEqual(res, res_ref) 6164 6165 def test_change_backend_config_for_fixed_qparam_ops(self): 6166 """ Making sure we can skip validation of qconfigs for fixedqparam ops based 6167 on BackendConfig 6168 """ 6169 class M(nn.Module): 6170 def __init__(self) -> None: 6171 super().__init__() 6172 self.tanh = torch.nn.Tanh() 6173 6174 def forward(self, x: torch.Tensor): 6175 x = self.tanh(x) 6176 return x 6177 6178 model = M().eval() 6179 # we set a global default_qconfig, which will be ignored since the backend 6180 # we defined doesn't support anything 6181 # this is to make sure we don't validate the qconfig when BackendConfig does not 6182 # have fixed qparam op related configurations 6183 qconfig_mapping = QConfigMapping().set_global(default_qconfig) 6184 backend_config = BackendConfig() 6185 # make sure this runs 6186 model = prepare_fx( 6187 model, 6188 qconfig_mapping=qconfig_mapping, 6189 example_inputs=(torch.randn(1, 2, 3, 4),), 6190 backend_config=backend_config 6191 ) 6192 6193 def test_channel_shuffle_lowering(self): 6194 # Three versions of channel shuffle 6195 class M1(torch.nn.Module): 6196 def __init__(self) -> None: 6197 super().__init__() 6198 self.op = torch.nn.ChannelShuffle(2) 6199 6200 def forward(self, x): 6201 return self.op(x + x) + x 6202 6203 class M2(torch.nn.Module): 6204 def forward(self, x): 6205 return torch.channel_shuffle(x + x, 2) + x 6206 6207 class M3(torch.nn.Module): 6208 def forward(self, x): 6209 return torch.nn.functional.channel_shuffle(x + x, 2) + x 6210 6211 x = torch.randn(4, 4, 4, 4) 6212 # torch.channel_shuffle is equivalent to torch.nn.functional.channel_shuffle 6213 model_node_pairs = [ 6214 (M1().eval(), ns.call_module(torch.nn.ChannelShuffle)), 6215 (M2().eval(), ns.call_function(torch.channel_shuffle)), 6216 (M3().eval(), ns.call_function(torch.channel_shuffle)) 6217 ] 6218 for m, node in model_node_pairs: 6219 m = prepare_fx(m, {"": default_qconfig}, example_inputs=(x,)) 6220 m_copy = copy.deepcopy(m) 6221 m = convert_fx(m) 6222 m_ref = convert_to_reference_fx(m_copy) 6223 node_occurrence = { 6224 node: 1, 6225 ns.call_function(torch.quantize_per_tensor): 1, 6226 ns.call_method("dequantize"): 1 6227 } 6228 node_occurrence_ref = { 6229 node: 1, 6230 ns.call_function(torch.quantize_per_tensor): 4, 6231 ns.call_method("dequantize"): 4 6232 } 6233 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 6234 self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) 6235 6236 def test_match_pattern_with_multiple_args(self): 6237 """ Test that we can match a pattern that has multiple arguments 6238 Pattern: 6239 shape \ 6240 transpose (observed) -> reshape -> output (observed) -> 6241 6242 where `reshape` has two arguments 6243 """ 6244 6245 def _get_pattern_configs(): 6246 backend_pattern_configs = [] 6247 observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT 6248 weighted_op_quint8_dtype_config = DTypeConfig( 6249 input_dtype=torch.quint8, 6250 output_dtype=torch.quint8, 6251 weight_dtype=torch.qint8, 6252 bias_dtype=torch.float, 6253 ) 6254 dtype_configs = [weighted_op_quint8_dtype_config] 6255 6256 def root_node_getter(node_pattern): 6257 reshape, transpose, shape = node_pattern 6258 return transpose 6259 6260 backend_pattern_configs.append( 6261 BackendPatternConfig() 6262 ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131 6263 .set_observation_type(observation_type) 6264 .set_dtype_configs(dtype_configs) 6265 ._set_root_node_getter(root_node_getter) 6266 ) 6267 return backend_pattern_configs 6268 6269 backend_config = BackendConfig().set_backend_pattern_configs(_get_pattern_configs()) 6270 6271 class M(torch.nn.Module): 6272 def forward(self, x): 6273 x = torch.transpose(x, 0, 1) 6274 x = torch.reshape(x, (-1,)) 6275 return x 6276 6277 m = M().eval() 6278 qconfig_mapping = QConfigMapping().set_global(default_qconfig) 6279 example_inputs = (torch.randn(1, 3, 3, 3),) 6280 m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) 6281 node_occurrence = { 6282 # one for input of the pattern and one for output of the pattern 6283 ns.call_module(MinMaxObserver): 2 6284 } 6285 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 6286 6287 def _test_linear_activation_fusion_lowering_helper( 6288 self, module, example_inputs, qconfig_mapping, 6289 backend_config, fused_module, root_module, activation_module): 6290 node_occurrence = { 6291 ns.call_function(torch.quantize_per_tensor): 1, 6292 ns.call_method("dequantize"): 1, 6293 ns.call_module(fused_module): 1, 6294 ns.call_module(root_module): 0, 6295 ns.call_module(activation_module): 0, 6296 } 6297 node_occurrence_ref = { 6298 ns.call_function(torch.quantize_per_tensor): 2, 6299 ns.call_method("dequantize"): 2, 6300 } 6301 m = module.eval() 6302 m = prepare_fx(m, qconfig_mapping, 6303 example_inputs=example_inputs, 6304 backend_config=backend_config) 6305 m_copy = copy.deepcopy(m) 6306 m = convert_fx(m, backend_config=backend_config) 6307 m_ref = convert_to_reference_fx(m_copy) 6308 6309 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 6310 self.checkGraphModuleNodes(m_ref, expected_node_occurrence=node_occurrence_ref) 6311 m(*example_inputs) 6312 6313 @skipIfNoONEDNN 6314 def test_linear_leaky_relu_lowering(self): 6315 """ Test fusion and lowering of Linear - (bn -) LeakyReLU 6316 by FX. For onednn backedn only. 6317 """ 6318 from torch.ao.quantization.backend_config import get_onednn_backend_config 6319 qconfig_mapping = get_default_qconfig_mapping('onednn') 6320 with override_quantized_engine('onednn'): 6321 for with_bn in [True, False]: 6322 m = LinearBnLeakyReluModel(with_bn) 6323 self._test_linear_activation_fusion_lowering_helper( 6324 m, 6325 m.get_example_inputs(), 6326 qconfig_mapping, 6327 get_onednn_backend_config(), 6328 nniq.LinearLeakyReLU, 6329 nn.Linear, 6330 nn.LeakyReLU) 6331 6332 @skipIfNoONEDNN 6333 def test_linear_tanh_lowering(self): 6334 """ Test fusion and lowering of Linear - Tanh 6335 by FX. For onednn backedn only. 6336 """ 6337 from torch.ao.quantization.backend_config import get_onednn_backend_config 6338 qconfig_mapping = get_default_qconfig_mapping('onednn') 6339 # TODO Currently it's required that separate ops in a fused op/module have the same qconfig. 6340 # Need to be able to support fusion of ops with different qconfigs 6341 # Since tanh must have 'fixed_qparams_qconfig' while linear should use 6342 # the global qconfig, we need to set qconfigs for them manually here for 6343 # fusion and cannot put such configs in onednn's default qconfig_mapping. 6344 # Known issue: 6345 # Cannot fuse linear - tanh and quantize standalone tanh at the same time. 6346 qconfig = get_default_qconfig('onednn') 6347 qconfig_mapping.set_object_type(torch.nn.Linear, qconfig) 6348 qconfig_mapping.set_object_type(torch.nn.Tanh, qconfig) 6349 with override_quantized_engine('onednn'): 6350 m = LinearTanhModel() 6351 self._test_linear_activation_fusion_lowering_helper( 6352 m, 6353 m.get_example_inputs(), 6354 qconfig_mapping, 6355 get_onednn_backend_config(), 6356 nniq.LinearTanh, 6357 nn.Linear, 6358 nn.Tanh) 6359 6360 @override_qengines 6361 def test_linear_size_view(self): 6362 class M(torch.nn.Module): 6363 def __init__(self, use_relu=False): 6364 super().__init__() 6365 self.linear = torch.nn.Linear(16, 32) 6366 self.relu = torch.nn.ReLU() 6367 self.use_relu = use_relu 6368 6369 def forward(self, x): 6370 x = self.linear(x) 6371 if self.use_relu: 6372 x = self.relu(x) 6373 return x.view(x.size(0), 1, 4, 8) 6374 6375 for use_relu in [False, True]: 6376 model_fp32 = M(use_relu).eval() 6377 qengine = torch.backends.quantized.engine 6378 qconfig_mapping = get_default_qconfig_mapping(qengine) 6379 x = torch.randn((5, 16)) 6380 model_fp32(x) 6381 prepared_model = prepare_fx(model_fp32, qconfig_mapping, x) 6382 prepared_model(x) 6383 quantized_model = convert_fx(prepared_model) 6384 node_occurrence = { 6385 ns.call_module(nnq.Linear): 0 if use_relu else 1, 6386 ns.call_module(nniq.LinearReLU): 1 if use_relu else 0, 6387 ns.call_function(torch.quantize_per_tensor): 1, 6388 ns.call_method("dequantize"): 1 6389 } 6390 self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) 6391 6392 @override_qengines 6393 def test_linear_shape_view(self): 6394 class M(torch.nn.Module): 6395 def __init__(self, use_relu=False): 6396 super().__init__() 6397 self.linear = torch.nn.Linear(16, 32) 6398 self.relu = torch.nn.ReLU() 6399 self.use_relu = use_relu 6400 6401 def forward(self, x): 6402 x = self.linear(x) 6403 if self.use_relu: 6404 x = self.relu(x) 6405 return x.view(x.shape[0], 1, 4, 8) 6406 6407 for use_relu in [False, True]: 6408 model_fp32 = M(use_relu).eval() 6409 qengine = torch.backends.quantized.engine 6410 qconfig_mapping = get_default_qconfig_mapping(qengine) 6411 x = torch.randn((5, 16)) 6412 model_fp32(x) 6413 prepared_model = prepare_fx(model_fp32, qconfig_mapping, x) 6414 prepared_model(x) 6415 quantized_model = convert_fx(prepared_model) 6416 node_occurrence = { 6417 ns.call_module(nnq.Linear): 0 if use_relu else 1, 6418 ns.call_module(nniq.LinearReLU): 1 if use_relu else 0, 6419 ns.call_function(torch.quantize_per_tensor): 1, 6420 ns.call_method("dequantize"): 1 6421 } 6422 self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) 6423 6424 def test_mixed_dtypes(self): 6425 """ 6426 Test that multiple dtypes can be used in the same model for different layers, 6427 and the dtypes will be converted correctly between the layers. 6428 """ 6429 class MyModule(torch.nn.Module): 6430 def __init__(self) -> None: 6431 super().__init__() 6432 self.linear1 = torch.nn.Linear(5, 5) 6433 self.linear2 = torch.nn.Linear(5, 5) 6434 self.sigmoid = torch.nn.Sigmoid() 6435 self.tanh = torch.nn.Tanh() 6436 self.float_functional = torch.ao.nn.quantized.FloatFunctional() 6437 6438 def forward(self, x: torch.Tensor): 6439 x = self.linear1(x) # qint32 6440 x = self.linear2(x) # quint8 6441 linear2 = x 6442 x = self.sigmoid(x) # back to qint32 6443 x = self.tanh(x) # back to quint8 6444 x = self.float_functional.add(linear2, x) # adding two quint8's together 6445 return x 6446 6447 def make_qconfig(scale, zp, dtype): 6448 return QConfig( 6449 activation=FixedQParamsObserver.with_args(scale=scale, zero_point=zp, dtype=dtype), 6450 weight=torch.ao.quantization.default_weight_observer) 6451 6452 # Set up a QConfigMapping that specifies different qparams and dtypes for different layers 6453 qconfig_mapping = QConfigMapping() \ 6454 .set_global(get_default_qconfig("qnnpack")) \ 6455 .set_module_name("linear1", make_qconfig(1234, 11, torch.qint32)) \ 6456 .set_module_name("linear2", make_qconfig(2345, 22, torch.quint8)) \ 6457 .set_object_type(torch.nn.Sigmoid, make_qconfig(3456, 33, torch.qint32)) \ 6458 .set_object_type(torch.nn.Tanh, make_qconfig(4567, 44, torch.quint8)) 6459 6460 # Set up BackendConfig that supports the dtypes configured in the above QConfigMapping 6461 weighted_op_qint32_dtype_config = DTypeConfig( 6462 input_dtype=torch.qint32, 6463 output_dtype=torch.qint32, 6464 weight_dtype=torch.qint8, 6465 bias_dtype=torch.float, 6466 ) 6467 fixed_qparams_op_quint8_dtype_config = DTypeConfig( 6468 input_dtype=torch.quint8, 6469 output_dtype=torch.quint8, 6470 ) 6471 fixed_qparams_op_qint32_dtype_config = DTypeConfig( 6472 input_dtype=torch.qint32, 6473 output_dtype=torch.qint32, 6474 ) 6475 backend_config = get_qnnpack_backend_config() 6476 for config in backend_config.configs: 6477 if config.pattern == torch.nn.Linear: 6478 config.add_dtype_config(weighted_op_qint32_dtype_config) 6479 elif config.pattern in [torch.nn.Sigmoid, torch.nn.Tanh]: 6480 config.add_dtype_config(fixed_qparams_op_quint8_dtype_config) 6481 config.add_dtype_config(fixed_qparams_op_qint32_dtype_config) 6482 6483 # Produce the reference quantized model 6484 m = MyModule() 6485 example_inputs = (torch.rand(5, 5),) 6486 prepared = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=backend_config) 6487 prepared(*example_inputs) # calibrate 6488 converted = convert_to_reference_fx(prepared, backend_config=backend_config) 6489 converted(*example_inputs) 6490 6491 # Verify that the reference model is correct 6492 # 6493 # Reference model until add should be: 6494 # fp32_input -> q_to_int32 -> [dq -> linear1_fp32 -> q_to_int32] -> dq -> 6495 # q_to_uint8 -> [dq -> linear2_fp32 -> q_to_uint8] -> dq (linear2_dq) -> 6496 # q_to_int32 -> [dq -> sigmoid_fp32 -> q_to_int32] -> dq -> 6497 # q_to_uint8 -> [dq -> tanh_fp32 -> q_to_uint8] -> dq (tanh_dq) 6498 # 6499 # Complete reference model with add should be: 6500 # [(linear2_dq, tanh_dq) -> add_fp32 -> q_to_uint8] -> dq -> fp32_output 6501 6502 target_to_expected_dtypes = { 6503 "linear1": torch.qint32, 6504 "linear2": torch.quint8, 6505 "sigmoid": torch.qint32, 6506 "tanh": torch.quint8, 6507 torch.add: torch.quint8, 6508 } 6509 # Find the patterns [dq - op_fp32 - q_to_specific_dtype] in the graph 6510 linear2_node = tanh_node = None 6511 for node in converted.graph.nodes: 6512 if node.target not in target_to_expected_dtypes: 6513 continue 6514 6515 # Match preceding dequantize 6516 self.assertTrue(len(node.args) == 1 or len(node.args) == 2) 6517 self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) 6518 6519 # Match following quantize with the specific dtypes 6520 self.assertEqual(len(node.users), 1) 6521 user = next(iter(node.users.keys())) 6522 self.assertEqual(user.target, torch.quantize_per_tensor) 6523 self.assertEqual(user.args[-1], target_to_expected_dtypes[node.target]) 6524 6525 # Match [dq - torch.add(linear2_dq, tanh_dq) - q] 6526 if node.target == "linear2": 6527 linear2_node = node 6528 elif node.target == "tanh": 6529 tanh_node = node 6530 elif node.target == torch.add: 6531 linear2_dq, tanh_dq = node.args 6532 self.assertEqual(tanh_dq.args[0].args[0], tanh_node) 6533 self.assertEqual(linear2_dq.args[0].args[0], linear2_node) 6534 6535 def test_lowering_functional_conv_with_kwargs(self): 6536 dim_to_op = { 6537 1: F.conv1d, 6538 2: F.conv2d, 6539 3: F.conv3d, 6540 } 6541 dim_to_qop = { 6542 1: torch.ops.quantized.conv1d, 6543 2: torch.ops.quantized.conv2d, 6544 3: torch.ops.quantized.conv3d, 6545 } 6546 6547 class Mod(nn.Module): 6548 def __init__(self, in_channels, out_channels, kernel_size, dimension): 6549 super().__init__() 6550 self.dim = dimension 6551 self.op = dim_to_op[dimension] 6552 kernel_sizes = [kernel_size] * self.dim 6553 self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *kernel_sizes)) 6554 6555 def forward(self, input): 6556 return self.op(input, self.weight, bias=None, stride=[1] * self.dim, 6557 padding=[0] * self.dim, dilation=[1] * self.dim, groups=1) 6558 6559 for dimension in [1, 2, 3]: 6560 model = Mod(3, 16, 3, dimension) 6561 model.eval() 6562 qconfig_mapping = get_default_qconfig_mapping() 6563 input_shape = (1, 3, *([8] * dimension)) 6564 example_inputs = torch.randn(input_shape) 6565 prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) 6566 prepared_model(example_inputs) 6567 quantized_model = convert_fx(prepared_model) 6568 # This should pass 6569 quantized_model(example_inputs) 6570 # Ensure the quantized model has the expected op 6571 node_occurrence = { 6572 ns.call_function(dim_to_qop[dimension]): 1, 6573 } 6574 self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) 6575 6576 def test_lowering_functional_conv_transpose_with_kwargs(self): 6577 dim_to_op = { 6578 1: F.conv_transpose1d, 6579 2: F.conv_transpose2d, 6580 3: F.conv_transpose3d, 6581 } 6582 dim_to_qop = { 6583 1: torch.ops.quantized.conv_transpose1d, 6584 2: torch.ops.quantized.conv_transpose2d, 6585 3: torch.ops.quantized.conv_transpose3d, 6586 } 6587 6588 class Mod(nn.Module): 6589 def __init__(self, in_channels, out_channels, kernel_size, dimension): 6590 super().__init__() 6591 self.dim = dimension 6592 self.op = dim_to_op[dimension] 6593 kernel_sizes = [kernel_size] * self.dim 6594 self.weight = nn.Parameter(torch.randn(in_channels, out_channels, *kernel_sizes)) 6595 6596 def forward(self, input): 6597 return self.op(input, self.weight, bias=None, stride=[1] * self.dim, 6598 padding=[0] * self.dim, output_padding=[0] * self.dim, 6599 dilation=[1] * self.dim, groups=1) 6600 6601 for dimension in [1, 2, 3]: 6602 model = Mod(3, 16, 3, dimension) 6603 model.eval() 6604 qconfig_mapping = get_default_qconfig_mapping() 6605 input_shape = (1, 3, *([8] * dimension)) 6606 example_inputs = torch.randn(input_shape) 6607 prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) 6608 prepared_model(example_inputs) 6609 quantized_model = convert_fx(prepared_model) 6610 # This should pass 6611 quantized_model(example_inputs) 6612 # Ensure the quantized model has the expected op 6613 node_occurrence = { 6614 ns.call_function(dim_to_qop[dimension]): 1, 6615 } 6616 self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) 6617 6618 def test_lowering_functional_linear_with_kwargs(self): 6619 class Mod(nn.Module): 6620 def __init__(self, in_channels, out_channels): 6621 super().__init__() 6622 self.weight = nn.Parameter(torch.randn(out_channels, in_channels)) 6623 6624 def forward(self, input): 6625 return F.linear(input, self.weight, bias=None) 6626 6627 model = Mod(8, 4) 6628 model.eval() 6629 qconfig_mapping = get_default_qconfig_mapping() 6630 example_inputs = torch.randn(1, 8) 6631 prepared_model = prepare_fx(model, qconfig_mapping, example_inputs) 6632 prepared_model(example_inputs) 6633 quantized_model = convert_fx(prepared_model) 6634 # This should pass 6635 quantized_model(example_inputs) 6636 # Ensure the quantized model has the expected op 6637 node_occurrence = { 6638 ns.call_function(torch.ops.quantized.linear): 1, 6639 } 6640 self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence) 6641 6642@skipIfNoFBGEMM 6643class TestQuantizeFxOps(QuantizationTestCase): 6644 def setUp(self): 6645 super().setUp() 6646 self.custom_qconfig = torch.ao.quantization.QConfig( 6647 activation=torch.ao.quantization.observer.HistogramObserver.with_args( 6648 qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 6649 ), 6650 weight=torch.ao.quantization.default_per_channel_weight_observer 6651 ) 6652 self.common_quant_patterns = { 6653 torch.nn.ConvTranspose1d: DefaultNodeQuantizeHandler, 6654 torch.nn.ConvTranspose2d: DefaultNodeQuantizeHandler, 6655 torch.nn.ELU: DefaultNodeQuantizeHandler, 6656 torch.nn.LeakyReLU: DefaultNodeQuantizeHandler, 6657 torch.nn.Hardswish: DefaultNodeQuantizeHandler, 6658 torch.nn.InstanceNorm1d: DefaultNodeQuantizeHandler, 6659 torch.nn.InstanceNorm2d: DefaultNodeQuantizeHandler, 6660 torch.nn.InstanceNorm3d: DefaultNodeQuantizeHandler, 6661 torch.nn.LayerNorm: DefaultNodeQuantizeHandler, 6662 torch.nn.SiLU: DefaultNodeQuantizeHandler, 6663 torch.nn.Mish: DefaultNodeQuantizeHandler, 6664 torch.nn.GELU: DefaultNodeQuantizeHandler, 6665 torch.nn.Softmax: DefaultNodeQuantizeHandler, 6666 torch.nn.functional.elu: DefaultNodeQuantizeHandler, 6667 torch.nn.functional.hardswish: DefaultNodeQuantizeHandler, 6668 torch.nn.functional.instance_norm: DefaultNodeQuantizeHandler, 6669 torch.nn.functional.layer_norm: DefaultNodeQuantizeHandler, 6670 torch.nn.functional.leaky_relu: DefaultNodeQuantizeHandler, 6671 torch.nn.functional.silu: DefaultNodeQuantizeHandler, 6672 torch.nn.functional.mish: DefaultNodeQuantizeHandler, 6673 torch.nn.functional.gelu: DefaultNodeQuantizeHandler, 6674 torch.nn.functional.softmax: DefaultNodeQuantizeHandler, 6675 torch.sum: DefaultNodeQuantizeHandler 6676 } 6677 6678 """Unit tests for individual ops 6679 """ 6680 @skipIfNoFBGEMM 6681 def test_linear_module(self): 6682 with override_quantized_engine('fbgemm'): 6683 class LinearModel(torch.nn.Module): 6684 def __init__(self) -> None: 6685 super().__init__() 6686 self.linear = torch.nn.Linear(30, 4).float() 6687 6688 def forward(self, x): 6689 return self.linear(x) 6690 6691 class LinearReLUModel(torch.nn.Module): 6692 def __init__(self, f_relu=False): 6693 super().__init__() 6694 self.linear = torch.nn.Linear(30, 4).float() 6695 if f_relu: 6696 self.relu = F.relu 6697 else: 6698 self.relu = torch.nn.ReLU() 6699 6700 def forward(self, x): 6701 x = self.linear(x) 6702 x = self.relu(x) 6703 return x 6704 6705 class LinearBnModel(torch.nn.Module): 6706 def __init__(self) -> None: 6707 super().__init__() 6708 self.linear = torch.nn.Linear(4, 4).float() 6709 self.bn = torch.nn.BatchNorm1d(4) 6710 6711 def forward(self, x): 6712 x = self.linear(x) 6713 x = self.bn(x) 6714 return x 6715 6716 # Test linear 6717 data = (torch.rand((1, 30), dtype=torch.float),) 6718 for quant_type in self.all_quant_types: 6719 model = LinearModel() 6720 quantized_module = nnqd.Linear if quant_type == QuantType.DYNAMIC else nnq.Linear 6721 quantized_node = ns.call_module(quantized_module) 6722 result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) 6723 if quant_type in self.static_quant_types: 6724 self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) 6725 6726 # TODO: enable test for dynamic quant 6727 # Test linear-relu 6728 for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]): 6729 model = LinearReLUModel(f_relu) 6730 quantized_node = ns.call_module(nniq.LinearReLU) 6731 result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) 6732 self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) 6733 6734 # Test linear-bn 6735 data = (torch.rand((4, 4), dtype=torch.float),) 6736 for quant_type in self.static_quant_types: 6737 model = LinearBnModel() 6738 quantized_node = ns.call_module(nnq.Linear) 6739 result_dict = self.checkGraphModeFxOp(model, data, quant_type, quantized_node) 6740 self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) 6741 6742 @skipIfNoFBGEMM 6743 def test_functional_linear(self): 6744 with override_quantized_engine('fbgemm'): 6745 class FuncLinear(torch.nn.Module): 6746 def __init__(self, use_bias, has_relu, f_relu): 6747 super().__init__() 6748 self.w = torch.randn(4, 30) 6749 self.b = torch.randn(4) 6750 self.use_bias = use_bias 6751 if has_relu: 6752 if f_relu: 6753 self.relu_or_id = F.relu 6754 else: 6755 self.relu_or_id = torch.nn.ReLU() 6756 else: 6757 self.relu_or_id = torch.nn.Identity() 6758 6759 def forward(self, x): 6760 if self.use_bias: 6761 x = F.linear(x, self.w, self.b) 6762 else: 6763 x = F.linear(x, self.w) 6764 x = self.relu_or_id(x) 6765 return x 6766 6767 data = (torch.rand((1, 30), dtype=torch.float),) 6768 quant_type_to_qlinear_fun = { 6769 QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic), 6770 QuantType.STATIC: ns.call_function(torch.ops.quantized.linear), 6771 QuantType.QAT: ns.call_function(torch.ops.quantized.linear), 6772 } 6773 quant_type_to_qlinear_relu_fun = { 6774 # we don't have linear_relu_dynamic 6775 QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_relu_dynamic), 6776 QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu), 6777 QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu), 6778 } 6779 6780 options = itertools.product( 6781 self.all_quant_types, 6782 (True, False), # use_bias 6783 (True, False), # has_relu 6784 (True, False), # functional relu 6785 ) 6786 for quant_type, use_bias, has_relu, f_relu in options: 6787 # when has_relu is False, we are using an nn.Identity and 6788 # we will insert observer/fake_quant for the output of nn.Identity since 6789 # it is a copy node, that's why we have extra observer/fake_quant 6790 # when has_relu is False 6791 quant_type_to_prepare_expected_node_occurrence = { 6792 QuantType.DYNAMIC: { 6793 ns.call_module(torch.ao.quantization.PlaceholderObserver): 1, 6794 ns.call_module(torch.ao.quantization.MinMaxObserver): 1, 6795 }, 6796 # There should be 3 observers: after input, weight and activation. 6797 # one more observer for torch.nn.Identity when there is no relu 6798 QuantType.STATIC: { 6799 ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3, 6800 ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1, 6801 }, 6802 # There should be 3 observers: after input, weight and activation. 6803 QuantType.QAT: { 6804 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4, 6805 }, 6806 } 6807 model = FuncLinear(use_bias, has_relu, f_relu) 6808 if has_relu: 6809 qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type] 6810 else: 6811 qlinear_fun = quant_type_to_qlinear_fun[quant_type] 6812 6813 if quant_type != QuantType.DYNAMIC: 6814 num_dequantize = 1 6815 else: 6816 # we will have an extra quantize_per_tensor_dynamic + dequantize for 6817 # nn.Identity right now, but it will be fixed after we use 6818 # backend_config to configure the default pt backend 6819 num_dequantize = int(not has_relu) 6820 6821 convert_node_occurrence = { 6822 ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0, 6823 qlinear_fun: 1, 6824 ns.call_method("dequantize"): num_dequantize if quant_type != QuantType.DYNAMIC else 0, 6825 } 6826 prepare_expected_node_occurrence = \ 6827 quant_type_to_prepare_expected_node_occurrence[quant_type] 6828 result_dict = self.checkGraphModeFxOp( 6829 model, data, quant_type, qlinear_fun, 6830 prepare_expected_node_occurrence=prepare_expected_node_occurrence, 6831 expected_node_occurrence=convert_node_occurrence) 6832 if quant_type != QuantType.DYNAMIC: 6833 self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) 6834 # Ensure packed weights in lowered models are folded 6835 self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys()) 6836 6837 @skipIfNoFBGEMM 6838 def test_linear_dynamic_fp16(self): 6839 with override_quantized_engine('fbgemm'): 6840 class FuncLinear(torch.nn.Module): 6841 def __init__(self, use_bias, has_relu, f_relu): 6842 super().__init__() 6843 self.w = torch.randn(4, 30) 6844 self.b = torch.randn(4) 6845 self.use_bias = use_bias 6846 if has_relu: 6847 if f_relu: 6848 self.relu = F.relu 6849 else: 6850 self.relu = torch.nn.ReLU() 6851 else: 6852 self.relu = torch.nn.Identity() 6853 6854 def forward(self, x): 6855 if self.use_bias: 6856 x = F.linear(x, self.w, self.b) 6857 else: 6858 x = F.linear(x, self.w) 6859 x = self.relu(x) 6860 return x 6861 6862 data = (torch.rand((1, 30), dtype=torch.float),) 6863 options = itertools.product( 6864 (True, False), # use_bias 6865 (True, False), # has_relu 6866 (True, False), # functional relu 6867 (True, False), # is_reference 6868 ) 6869 for use_bias, has_relu, f_relu, is_reference in options: 6870 model = FuncLinear(use_bias, has_relu, f_relu) 6871 if is_reference: 6872 qlinear_fun = ns.call_function(torch.nn.functional.linear) 6873 else: 6874 if has_relu: 6875 qlinear_fun = ns.call_function(torch.ops.quantized.linear_relu_dynamic_fp16) 6876 else: 6877 qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16) 6878 prepare_node_occurrence = { 6879 # activation and weight 6880 ns.call_module(torch.ao.quantization.PlaceholderObserver): 2 6881 } 6882 convert_node_occurrence = { 6883 qlinear_fun: 1, 6884 # weight 6885 ns.call_method("to"): 1 if is_reference else 0 6886 } 6887 self.checkGraphModeFxOp( 6888 model, data, QuantType.DYNAMIC, qlinear_fun, 6889 is_reference=is_reference, 6890 custom_qconfig_dict={"": float16_dynamic_qconfig}, 6891 prepare_expected_node_occurrence=prepare_node_occurrence, 6892 expected_node_occurrence=convert_node_occurrence) 6893 6894 def test_linear_static_fp16(self): 6895 class FuncLinear(torch.nn.Module): 6896 def __init__(self, use_bias, has_relu, f_relu): 6897 super().__init__() 6898 self.w = torch.randn(4, 30) 6899 self.b = torch.randn(4) 6900 self.use_bias = use_bias 6901 if has_relu: 6902 if f_relu: 6903 self.relu = F.relu 6904 else: 6905 self.relu = torch.nn.ReLU() 6906 else: 6907 self.relu = torch.nn.Identity() 6908 6909 def forward(self, x): 6910 if self.use_bias: 6911 x = F.linear(x, self.w, self.b) 6912 else: 6913 x = F.linear(x, self.w) 6914 x = self.relu(x) 6915 return x 6916 6917 data = (torch.rand((1, 30), dtype=torch.float),) 6918 options = itertools.product( 6919 (True, False), # use_bias 6920 (True, False), # has_relu 6921 (True, False), # functional relu 6922 (True, False), # is_reference 6923 ) 6924 backend_config = get_test_only_legacy_native_backend_config() 6925 for use_bias, has_relu, f_relu, is_reference in options: 6926 model = FuncLinear(use_bias, has_relu, f_relu) 6927 linear_fun = ns.call_function(torch.nn.functional.linear) 6928 # when has_relu is False, we are using an nn.Identity and 6929 # we will insert observer/fake_quant for the output of nn.Identity since 6930 # it is a copy node, that's why we have extra observer/fake_quant 6931 # when has_relu is False 6932 prepare_node_occurrence = { 6933 # activation, weight, bias and output 6934 ns.call_module(torch.ao.quantization.PlaceholderObserver): 3 + int(use_bias) + int(not has_relu), 6935 } 6936 # We have extra to and dequantize when is_reference is True 6937 # and has_relu is False since when has_relu is False, we 6938 # have an nn.Identity in the model, which is a CopyNode 6939 # and we would add extra quant - dequant for CopyNode in 6940 # reference patterns 6941 convert_node_occurrence = { 6942 # we don't support static fp16 ops, so the linear function 6943 # is unfused 6944 linear_fun: 1, 6945 # activation, weight, bias and output 6946 ns.call_method("to"): 3 + int(use_bias) + int(not has_relu and is_reference), 6947 ns.call_method("dequantize"): 3 + int(use_bias) + int(not has_relu and is_reference) 6948 } 6949 self.checkGraphModeFxOp( 6950 model, data, QuantType.DYNAMIC, linear_fun, 6951 is_reference=is_reference, 6952 custom_qconfig_dict={"": float16_static_qconfig}, 6953 prepare_expected_node_occurrence=prepare_node_occurrence, 6954 expected_node_occurrence=convert_node_occurrence, 6955 backend_config=backend_config) 6956 6957 @skipIfNoFBGEMM 6958 def test_conv_module(self): 6959 conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} 6960 6961 class ConvWrapper(torch.nn.Module): 6962 def __init__(self, dim): 6963 super().__init__() 6964 self.conv = conv_module[dim](3, 3, 3).float() 6965 6966 def forward(self, x): 6967 return self.conv(x) 6968 6969 options = itertools.product([1, 2, 3], self.static_quant_types) 6970 quantized_nodes = { 6971 # dim 6972 1: ns.call_module(nnq.Conv1d), 6973 2: ns.call_module(nnq.Conv2d), 6974 3: ns.call_module(nnq.Conv3d), 6975 } 6976 for dim, quant_type in options: 6977 self.checkGraphModeFxOp( 6978 ConvWrapper(dim), self.img_data_dict[dim], quant_type, 6979 quantized_nodes[dim]) 6980 6981 @skipIfNoFBGEMM 6982 def test_functional_conv(self): 6983 with override_quantized_engine('fbgemm'): 6984 """ Test for function conv and functional conv + relu 6985 """ 6986 convs = { 6987 1: torch.nn.functional.conv1d, 6988 2: torch.nn.functional.conv2d, 6989 3: torch.nn.functional.conv3d, 6990 } 6991 6992 class FuncConv(torch.nn.Module): 6993 def __init__(self, dim, use_bias, has_relu, f_relu): 6994 super().__init__() 6995 self.dim = dim 6996 self.w = torch.randn(tuple([3] * (dim + 2))) 6997 self.b = torch.randn(3) if use_bias else None 6998 self.stride = tuple([1] * dim) 6999 self.padding = tuple([0] * dim) 7000 self.dilation = tuple([1] * dim) 7001 self.groups = 1 7002 self.use_bias = use_bias 7003 if has_relu: 7004 if f_relu: 7005 self.relu = F.relu 7006 else: 7007 self.relu = torch.nn.ReLU() 7008 else: 7009 self.relu = torch.nn.Identity() 7010 7011 def forward(self, x): 7012 x = convs[self.dim](x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups) 7013 x = self.relu(x) 7014 return x 7015 7016 quant_type_to_qconv_fun = { 7017 QuantType.STATIC: { 7018 1: ns.call_function(torch.ops.quantized.conv1d), 7019 2: ns.call_function(torch.ops.quantized.conv2d), 7020 3: ns.call_function(torch.ops.quantized.conv3d) 7021 }, 7022 QuantType.QAT: { 7023 1: ns.call_function(torch.ops.quantized.conv1d), 7024 2: ns.call_function(torch.ops.quantized.conv2d), 7025 3: ns.call_function(torch.ops.quantized.conv3d) 7026 }, 7027 } 7028 quant_type_to_qconv_relu_fun = { 7029 QuantType.STATIC: { 7030 1: ns.call_function(torch.ops.quantized.conv1d_relu), 7031 2: ns.call_function(torch.ops.quantized.conv2d_relu), 7032 3: ns.call_function(torch.ops.quantized.conv3d_relu) 7033 }, 7034 QuantType.QAT: { 7035 1: ns.call_function(torch.ops.quantized.conv1d_relu), 7036 2: ns.call_function(torch.ops.quantized.conv2d_relu), 7037 3: ns.call_function(torch.ops.quantized.conv3d_relu) 7038 }, 7039 } 7040 7041 options = itertools.product( 7042 [1, 2, 3], # dims 7043 self.static_quant_types, 7044 (True, False), # use_bias 7045 (True, False), # has_relu 7046 (True, False), # functional relu 7047 ) 7048 for dim, quant_type, use_bias, has_relu, f_relu in options: 7049 # when has_relu is False, we are using an nn.Identity and 7050 # we will insert observer/fake_quant for the output of nn.Identity since 7051 # it is a copy node, that's why we have extra observer/fake_quant 7052 # when has_relu is False 7053 quant_type_to_prepare_expected_node_occurrence = { 7054 QuantType.DYNAMIC: {}, 7055 # There should be 3 observers: after input, weight and activation. 7056 QuantType.STATIC: { 7057 ns.call_module(torch.ao.quantization.HistogramObserver): 2 if has_relu else 3, 7058 ns.call_module(torch.ao.quantization.PerChannelMinMaxObserver): 1, 7059 }, 7060 # There should be 3 observers: after input, weight and activation. 7061 QuantType.QAT: { 7062 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 3 if has_relu else 4, 7063 }, 7064 } 7065 data_dims = [2, 3] + [4] * dim 7066 data = (torch.randn(tuple(data_dims), dtype=torch.float),) 7067 model = FuncConv(dim, use_bias, has_relu, f_relu) 7068 if has_relu: 7069 qconv_fun = quant_type_to_qconv_relu_fun[quant_type][dim] 7070 else: 7071 qconv_fun = quant_type_to_qconv_fun[quant_type][dim] 7072 7073 convert_node_occurrence = { 7074 ns.call_function(torch.quantize_per_tensor): 1, 7075 qconv_fun: 1, 7076 ns.call_method("dequantize"): 1 7077 } 7078 prepare_expected_node_occurrence = \ 7079 quant_type_to_prepare_expected_node_occurrence[quant_type] 7080 result_dict = self.checkGraphModeFxOp( 7081 model, data, quant_type, qconv_fun, 7082 prepare_expected_node_occurrence=prepare_expected_node_occurrence, 7083 expected_node_occurrence=convert_node_occurrence) 7084 if quant_type != QuantType.DYNAMIC: 7085 self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"]) 7086 # Ensure packed weights in lowered models are folded 7087 self.assertIn("_packed_weight_0", result_dict["quantized"].state_dict().keys()) 7088 7089 @skipIfNoFBGEMM 7090 def test_quantized_conv_relu(self): 7091 """tests for conv1d_relu/conv2d_relu/conv3d_relu""" 7092 conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d} 7093 7094 class ConvNdRelu(torch.nn.Module): 7095 def __init__(self, dim, inplace): 7096 super().__init__() 7097 self.conv = conv_module[dim](3, 3, 3).float() 7098 self.relu = torch.nn.ReLU(inplace) 7099 7100 def forward(self, x): 7101 return self.relu(self.conv(x)) 7102 7103 class ConvNdFunctionalRelu(torch.nn.Module): 7104 def __init__(self, dim): 7105 super().__init__() 7106 self.conv = conv_module[dim](3, 3, 3).float() 7107 7108 def forward(self, x): 7109 return F.relu(self.conv(x)) 7110 7111 class ConvNdInplaceFunctionalRelu(torch.nn.Module): 7112 def __init__(self, dim): 7113 super().__init__() 7114 self.conv = conv_module[dim](3, 3, 3).float() 7115 7116 def forward(self, x): 7117 return F.relu(self.conv(x), True) 7118 7119 options = itertools.product([1, 2, 3], self.static_quant_types) 7120 quantized_nodes = { 7121 # dim 7122 1: ns.call_module(nniq.ConvReLU1d), 7123 2: ns.call_module(nniq.ConvReLU2d), 7124 3: ns.call_module(nniq.ConvReLU3d), 7125 } 7126 for dim, quant_type in options: 7127 for m in [ConvNdRelu(dim, True), 7128 ConvNdRelu(dim, False), 7129 ConvNdFunctionalRelu(dim), 7130 ConvNdInplaceFunctionalRelu(dim)]: 7131 self.checkGraphModeFxOp( 7132 m, self.img_data_dict[dim], quant_type, 7133 quantized_nodes[dim]) 7134 7135 7136 def _test_binary_op_int8_impl(self, binary_op, ibinary_op, quantized_op): 7137 data = (torch.randn(1, 1, 1, 1, dtype=torch.float), 7138 torch.randn(1, 1, 1, 1, dtype=torch.float)) 7139 options = itertools.product([True, False], [True, False], [True, False]) 7140 quant_type = QuantType.STATIC 7141 # testing for default int8 static quant 7142 for is_inplace, is_scalar, is_reference in options: 7143 if is_reference: 7144 node_list = [ 7145 ns.call_method("dequantize"), 7146 ns.call_function(binary_op), 7147 ns.call_function(torch.quantize_per_tensor) 7148 ] 7149 quantized_node = None 7150 else: 7151 node_list = None 7152 quantized_node = ns.call_function(quantized_op) 7153 7154 self.checkGraphModeFxOp( 7155 BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, 7156 quantized_node, expected_node_list=node_list, is_reference=is_reference) 7157 # This tests the binary op should be quantized even when it is not feed with a 7158 # quantized input 7159 self.checkGraphModeFxOp( 7160 BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), 7161 data, quant_type, quantized_node, 7162 expected_node_list=node_list, is_reference=is_reference) 7163 7164 7165 def _test_binary_op_float16_impl(self, binary_op, ibinary_op): 7166 data = (torch.randn(1, 1, 1, 1, dtype=torch.float), 7167 torch.randn(1, 1, 1, 1, dtype=torch.float)) 7168 quant_type = QuantType.STATIC 7169 # testing for fp16 static quant 7170 # we are producing fp16 patterns 7171 options = itertools.product([True, False], [True, False]) 7172 custom_qconfig_dict = { 7173 "object_type": [(binary_op, float16_static_qconfig)] 7174 } 7175 backend_config = get_test_only_legacy_native_backend_config() 7176 for is_inplace, is_scalar in options: 7177 node_occurrence = { 7178 # output_conv1, output_add1, output_add2 for scalar 7179 # output_conv1, output_conv2, output_add1, output_add2 for non-scalar 7180 ns.call_method("to"): 3 if is_scalar else 4 7181 } 7182 self.checkGraphModeFxOp( 7183 BinaryOp(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, 7184 expected_node_occurrence=node_occurrence, 7185 custom_qconfig_dict=custom_qconfig_dict, 7186 backend_config=backend_config) 7187 7188 node_occurrence = { 7189 # input_add, output_add for scalar 7190 # input_add1, input_add2, output_add for non-scalar 7191 ns.call_method("to"): 2 if is_scalar else 3 7192 } 7193 self.checkGraphModeFxOp( 7194 BinaryOpNonQuantizedInput(binary_op, ibinary_op, is_inplace, is_scalar), data, quant_type, 7195 expected_node_occurrence=node_occurrence, 7196 custom_qconfig_dict=custom_qconfig_dict, 7197 backend_config=backend_config) 7198 7199 def _test_binary_op_relu_int8_impl(self, binary_op, ibinary_op, quantized_op): 7200 data = (torch.rand((1, 1, 1, 1), dtype=torch.float), 7201 torch.rand((1, 1, 1, 1), dtype=torch.float)) 7202 quant_type = QuantType.STATIC 7203 quantized_node = ns.call_function(quantized_op) 7204 options = itertools.product( 7205 [True, False], [nn.ReLU, F.relu, torch.relu], [True, False]) 7206 for is_inplace_op, relu_callable, is_scalar in options: 7207 model = BinaryOpRelu( 7208 binary_op, ibinary_op, is_inplace_op, relu_callable, is_scalar) 7209 self.checkGraphModeFxOp( 7210 model, data, quant_type, quantized_node) 7211 7212 def _test_binary_op_relu_float16_impl(self, binary_op, ibinary_op): 7213 data = (torch.rand((1, 1, 1, 1), dtype=torch.float), 7214 torch.rand((1, 1, 1, 1), dtype=torch.float)) 7215 quant_type = QuantType.STATIC 7216 options = itertools.product( 7217 [True, False], [nn.ReLU, F.relu, torch.relu], [True, False]) 7218 custom_qconfig_dict = { 7219 "": float16_static_qconfig, 7220 "object_type": [(torch.nn.Conv2d, None)] 7221 } 7222 backend_config = get_test_only_legacy_native_backend_config() 7223 for is_inplace_op, is_functional_relu, is_scalar in options: 7224 node_occurrence = { 7225 ns.call_method("to"): 3 if is_scalar else 4 7226 } 7227 model = BinaryOpRelu( 7228 binary_op, ibinary_op, is_inplace_op, is_functional_relu, is_scalar) 7229 self.checkGraphModeFxOp( 7230 model, data, quant_type, custom_qconfig_dict=custom_qconfig_dict, 7231 expected_node_occurrence=node_occurrence, 7232 backend_config=backend_config) 7233 7234 7235 @skipIfNoFBGEMM 7236 def test_add(self): 7237 self._test_binary_op_int8_impl( 7238 operator.add, operator.iadd, torch.ops.quantized.add) 7239 self._test_binary_op_float16_impl( 7240 operator.add, operator.iadd) 7241 7242 @unittest.skip("This is no longer needed right now, can enable later with new api") 7243 def test_sub(self): 7244 self._test_binary_op_float16_impl(operator.sub, operator.isub) 7245 self._test_binary_op_float16_impl(torch.sub, None) 7246 7247 @unittest.skip("This is no longer needed right now, can enable later with new api") 7248 def test_div(self): 7249 self._test_binary_op_float16_impl(operator.truediv, operator.itruediv) 7250 self._test_binary_op_float16_impl(torch.div, None) 7251 7252 @skipIfNoFBGEMM 7253 def test_mul(self): 7254 self._test_binary_op_int8_impl( 7255 operator.mul, operator.imul, torch.ops.quantized.mul) 7256 self._test_binary_op_float16_impl(operator.mul, operator.imul) 7257 7258 @unittest.skip("This is no longer needed right now, can enable later with new api") 7259 def test_sum(self): 7260 class Sum(torch.nn.Module): 7261 def forward(self, x): 7262 x = torch.sum(x, [1], keepdim=True) 7263 x = torch.sum(x, [1]) 7264 return x 7265 7266 data = torch.randn(1, 2, 3, 4, dtype=torch.float) 7267 quant_type = QuantType.STATIC 7268 # testing for fp16 static quant 7269 # we are producing fp16 patterns 7270 custom_qconfig_dict = { 7271 "object_type": [(torch.sum, float16_static_qconfig)] 7272 } 7273 node_occurrence = { 7274 # input_sum1, output_sum1, output_sum2 7275 ns.call_method("to"): 3 7276 } 7277 self.checkGraphModeFxOp( 7278 Sum(), data, quant_type, 7279 expected_node_occurrence=node_occurrence, 7280 custom_qconfig_dict=custom_qconfig_dict) 7281 7282 @unittest.skip("This is no longer needed right now, can enable later with new api") 7283 def test_bmm(self): 7284 class BMMMethod(torch.nn.Module): 7285 def forward(self, x, y): 7286 return x.bmm(y) 7287 7288 data = (torch.randn(1, 1, 1, dtype=torch.float), 7289 torch.randn(1, 1, 1, dtype=torch.float)) 7290 quant_type = QuantType.STATIC 7291 # testing for fp16 static quant 7292 # we are producing fp16 patterns 7293 custom_qconfig_dict = { 7294 "object_type": [(torch.bmm, float16_static_qconfig), 7295 ("bmm", float16_static_qconfig)] 7296 } 7297 node_occurrence = { 7298 # input_bmm1, input_bmm2, output_bmm 7299 ns.call_method("to"): 3 7300 } 7301 self.checkGraphModeFxOp( 7302 BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type, 7303 expected_node_occurrence=node_occurrence, 7304 custom_qconfig_dict=custom_qconfig_dict) 7305 7306 # TODO: support call_method("bmm") 7307 # we can transform call_method("bmm") to call_function(torch.bmm) 7308 # self.checkGraphModeFxOp( 7309 # BMMMethod(), data, quant_type, 7310 # expected_node_occurrence=node_occurrence, 7311 # custom_qconfig_dict=custom_qconfig_dict, 7312 # print_debug_info=True) 7313 7314 @skipIfNoFBGEMM 7315 def test_add_relu(self): 7316 self._test_binary_op_relu_int8_impl( 7317 operator.add, operator.iadd, torch.ops.quantized.add_relu) 7318 self._test_binary_op_relu_float16_impl( 7319 operator.add, operator.iadd) 7320 7321 @skipIfNoFBGEMM 7322 def test_add_relu_multiple_uses_of_relu(self): 7323 class Sub(torch.nn.Module): 7324 def __init__(self) -> None: 7325 super().__init__() 7326 self.relu = torch.nn.ReLU(inplace=True) 7327 7328 class M(torch.nn.Module): 7329 def __init__(self) -> None: 7330 super().__init__() 7331 self.sub = Sub() 7332 7333 def forward(self, x, y): 7334 x = x + y 7335 x = self.sub.relu(x) 7336 x = x + y 7337 x = self.sub.relu(x) 7338 return x 7339 7340 m = M().eval() 7341 example_inputs = (torch.randn(3), torch.randn(3)) 7342 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 7343 m = convert_fx(m) 7344 node_occurrence = { 7345 ns.call_function(torch.quantize_per_tensor): 2, 7346 ns.call_function(torch.ops.quantized.add_relu): 2, 7347 ns.call_method("dequantize"): 1, 7348 } 7349 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 7350 # check the model is scriptable 7351 m = torch.jit.script(m) 7352 # check the model is runnable 7353 m(*example_inputs) 7354 7355 @skipIfNoFBGEMM 7356 def test_mul_relu(self): 7357 self._test_binary_op_relu_int8_impl( 7358 operator.mul, operator.imul, torch.ops.quantized.mul_relu) 7359 self._test_binary_op_relu_float16_impl( 7360 operator.mul, operator.imul) 7361 7362 # TODO(future PR): make more generic 7363 def _test_quantized_add_mul_qat(self, model, example_inputs, expected_node_occurrence): 7364 qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} 7365 mp = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 7366 self.checkGraphModuleNodes( 7367 mp, expected_node_occurrence=expected_node_occurrence) 7368 7369 @skipIfNoFBGEMM 7370 def test_quantized_add_qat(self): 7371 class M(torch.nn.Module): 7372 def __init__(self) -> None: 7373 super().__init__() 7374 self.conv1 = torch.nn.Conv2d(1, 1, 1) 7375 self.conv2 = torch.nn.Conv2d(1, 1, 1) 7376 7377 def forward(self, x): 7378 x = torch.add(x, 1.0) 7379 x = self.conv1(x) 7380 x = torch.add(x, 1.0) 7381 x = torch.relu(x) 7382 x = self.conv2(x) 7383 return x 7384 7385 m = M() 7386 example_inputs = (torch.randn(1, 1, 1, 1),) 7387 expected_node_occurrence = { 7388 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, 7389 } 7390 self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) 7391 7392 @skipIfNoFBGEMM 7393 def test_quantized_mul_qat(self): 7394 class M(torch.nn.Module): 7395 def __init__(self) -> None: 7396 super().__init__() 7397 self.conv1 = torch.nn.Conv2d(1, 1, 1) 7398 self.conv2 = torch.nn.Conv2d(1, 1, 1) 7399 7400 def forward(self, x): 7401 x = torch.mul(x, 1.0) 7402 x = self.conv1(x) 7403 x = torch.mul(x, 1.0) 7404 x = torch.relu(x) 7405 x = self.conv2(x) 7406 return x 7407 7408 m = M() 7409 example_inputs = (torch.randn(1, 1, 1, 1),) 7410 expected_node_occurrence = { 7411 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 5, 7412 } 7413 self._test_quantized_add_mul_qat(m, example_inputs, expected_node_occurrence) 7414 7415 def test_int8_input_no_unnecessary_fq(self): 7416 """ 7417 If the inputs to the graph are quantized and the only node 7418 does not need an activation observer, verifies that the 7419 activation observer is not inserted. 7420 """ 7421 class M(nn.Module): 7422 def __init__(self, scalar): 7423 super().__init__() 7424 self.scalar = scalar 7425 self.add_func = torch.ao.nn.quantized.FloatFunctional() 7426 7427 def forward(self, x): 7428 return self.add_func.add_scalar(x, self.scalar) 7429 7430 m = M(0.5) 7431 mp = torch.ao.quantization.quantize_fx.prepare_qat_fx( 7432 m, {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')}, 7433 example_inputs=(torch.randn(1),), 7434 prepare_custom_config={"input_quantized_idxs": [0]}) 7435 expected_node_occurrence = { 7436 ns.call_module(torch.ao.quantization.FusedMovingAvgObsFakeQuantize): 1, 7437 } 7438 self.checkGraphModuleNodes( 7439 mp, expected_node_occurrence=expected_node_occurrence) 7440 7441 @skipIfNoFBGEMM 7442 def test_cat(self): 7443 """ quantization of the output of cat will depend on the 7444 input of cat. we only quantize the output of cat when its inputs are quantized. 7445 """ 7446 class M(torch.nn.Module): 7447 def __init__(self) -> None: 7448 super().__init__() 7449 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 7450 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 7451 7452 def forward(self, x, y): 7453 x = self.conv1(x) 7454 y = self.conv2(y) 7455 return torch.cat([x, y], 1) 7456 7457 example_inputs = (torch.randn(1, 2, 5, 5, dtype=torch.float), 7458 torch.randn(1, 2, 5, 5, dtype=torch.float)) 7459 quantized_node = ns.call_function(torch.cat) 7460 options = itertools.product(self.static_quant_types, [True, False]) 7461 for quant_type, is_reference in options: 7462 if is_reference: 7463 converted_node_list = [ 7464 ns.call_method("dequantize"), 7465 ns.call_function(torch.cat), 7466 ns.call_function(torch.quantize_per_tensor) 7467 ] 7468 converted_node_occurrence = { 7469 # inputs and outputs of the two conv, and output of cat 7470 ns.call_method("dequantize"): 5, 7471 ns.call_function(torch.cat): 1, 7472 # inputs and outputs of the two conv, and output of cat 7473 ns.call_function(torch.quantize_per_tensor): 5, 7474 } 7475 else: 7476 converted_node_list = None 7477 converted_node_occurrence = { 7478 # output of cat 7479 ns.call_method("dequantize"): 1, 7480 ns.call_function(torch.cat): 1, 7481 # for two inputs 7482 ns.call_function(torch.quantize_per_tensor): 2, 7483 } 7484 7485 self.checkGraphModeFxOp( 7486 M(), 7487 example_inputs, 7488 quant_type, 7489 quantized_node, 7490 expected_node_list=converted_node_list, 7491 expected_node_occurrence=converted_node_occurrence, 7492 is_reference=is_reference) 7493 7494 # check cat is using the same observer for input and output 7495 m = M().eval() 7496 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 7497 # two inputs and one output of torch.cat are using same observer, so we have 7498 # 2 observers that's replicated 7499 all_observers = len(dict(m.named_modules(remove_duplicate=False))) 7500 distinct_observers = len(dict(m.named_modules())) 7501 self.assertEqual(all_observers, distinct_observers + 2) 7502 # make sure the converted model runs 7503 m = convert_fx(m) 7504 m(*example_inputs) 7505 7506 @skipIfNoFBGEMM 7507 def test_qbatch_norm(self): 7508 bn_module = { 7509 # TODO: quantized batchnorm 1d module is missing 7510 # 1 : torch.nn.BatchNorm1d, 7511 2 : torch.nn.BatchNorm2d, 7512 3 : torch.nn.BatchNorm3d, 7513 } 7514 7515 class M(torch.nn.Module): 7516 def __init__(self, dim): 7517 super().__init__() 7518 self.bn = bn_module[dim](3).to(torch.float) 7519 7520 def forward(self, x): 7521 return self.bn(x) 7522 7523 options = itertools.product(self.static_quant_types, [2, 3], [True, False]) 7524 quantized_nodes = { 7525 False: { 7526 # 1: ns.call_module(nnq.BatchNorm1d), 7527 2: ns.call_module(nnq.BatchNorm2d), 7528 3: ns.call_module(nnq.BatchNorm3d), 7529 }, 7530 True: { 7531 # 1: ns.call_module(nn.BatchNorm1d), 7532 2: ns.call_module(nn.BatchNorm2d), 7533 3: ns.call_module(nn.BatchNorm3d), 7534 } 7535 } 7536 for quant_type, dim, is_reference in options: 7537 self.checkGraphModeFxOp( 7538 M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[is_reference][dim], is_reference=is_reference) 7539 7540 @skipIfNoFBGEMM 7541 def test_qbatch_norm_relu(self): 7542 bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} 7543 7544 class BNRelu(torch.nn.Module): 7545 def __init__(self, dim, inplace): 7546 super().__init__() 7547 self.bn = bn_module[dim](3).to(torch.float) 7548 self.relu = torch.nn.ReLU(inplace=inplace) 7549 7550 def forward(self, x): 7551 return self.relu(self.bn(x)) 7552 7553 class BNFuncRelu(torch.nn.Module): 7554 def __init__(self, dim): 7555 super().__init__() 7556 self.bn = bn_module[dim](3).to(torch.float) 7557 7558 def forward(self, x): 7559 return F.relu(self.bn(x), False) 7560 7561 class BNFuncInplaceRelu(torch.nn.Module): 7562 def __init__(self, dim): 7563 super().__init__() 7564 self.bn = bn_module[dim](3).to(torch.float) 7565 7566 def forward(self, x): 7567 return F.relu(self.bn(x), True) 7568 7569 options = itertools.product(self.static_quant_types, [2, 3], [True, False]) 7570 quantized_nodes = { 7571 True: { 7572 2: ns.call_module(nni.BNReLU2d), 7573 3: ns.call_module(nni.BNReLU3d), 7574 }, 7575 False: { 7576 2: ns.call_module(nniq.BNReLU2d), 7577 3: ns.call_module(nniq.BNReLU3d), 7578 } 7579 } 7580 for quant_type, dim, is_reference in options: 7581 for instance in [BNRelu(dim, True), BNRelu(dim, False), 7582 BNFuncRelu(dim), BNFuncInplaceRelu(dim)]: 7583 self.checkGraphModeFxOp( 7584 instance, self.img_data_dict[dim], quant_type, 7585 quantized_nodes[is_reference][dim], is_reference=is_reference) 7586 7587 def _test_activation_impl( 7588 self, float_module, float_op, quantized_module, quantized_op): 7589 ''' Test for activation op(with inplace options), float_op can be 7590 torch op or functional op 7591 ''' 7592 class M(torch.nn.Module): 7593 def __init__(self, is_module, inplace): 7594 super().__init__() 7595 self.is_module = is_module 7596 self.inplace = inplace 7597 if self.is_module: 7598 self.op = float_module(self.inplace) 7599 else: 7600 self.op = float_op 7601 7602 def forward(self, input): 7603 if self.is_module: 7604 return self.op(input) 7605 else: 7606 return self.op(input, self.inplace) 7607 7608 options = itertools.product([True, False], [True, False], self.static_quant_types, [True, False]) 7609 quantized_nodes = { 7610 # is_module 7611 True: { 7612 # is_reference 7613 True: ns.call_module(float_module), 7614 False: ns.call_module(quantized_module), 7615 }, 7616 False: { 7617 True: ns.call_function(float_op), 7618 False: ns.call_function(quantized_op), 7619 } 7620 } 7621 7622 for is_module, is_inplace, quant_type, is_reference in options: 7623 self.checkGraphModeFxOp( 7624 M(is_module, is_inplace), self.img_data_2d, 7625 quant_type, quantized_nodes[is_module][is_reference], is_reference=is_reference) 7626 7627 def test_hardswish(self): 7628 self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish) 7629 7630 def test_elu(self): 7631 self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu) 7632 7633 def test_leaky_relu(self): 7634 self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu) 7635 7636 def test_prelu(self): 7637 class M(torch.nn.Module): 7638 def __init__(self, num_param: int): 7639 super().__init__() 7640 self.op = torch.nn.PReLU(num_parameters=num_param) 7641 7642 def forward(self, input): 7643 return self.op(input) 7644 7645 X = [[torch.randn(4, 4, 4, 4, dtype=torch.float)]] 7646 options = itertools.product([1, 4], self.static_quant_types, [True, False]) 7647 quantized_nodes = { 7648 # is_reference 7649 True: ns.call_module(torch.nn.PReLU), 7650 False: ns.call_module(torch.ao.nn.quantized.PReLU), 7651 } 7652 7653 for num_parameter, quant_type, is_reference in options: 7654 self.checkGraphModeFxOp( 7655 M(num_parameter), X, quant_type, quantized_nodes[is_reference], 7656 is_reference=is_reference) 7657 7658 def _test_norm_impl( 7659 self, float_module, float_op, op_args, data, quantized_module, quantized_op, 7660 skip_op_arg_for_functional=False): 7661 ''' Test for normalization op, float_op can be torch op or functional op, 7662 op_args is a list of positional argument for the module/op 7663 ''' 7664 class M(torch.nn.Module): 7665 def __init__(self, is_module): 7666 super().__init__() 7667 self.is_module = is_module 7668 if self.is_module: 7669 self.op = float_module(*op_args) 7670 else: 7671 self.op = float_op 7672 7673 def forward(self, input): 7674 if self.is_module: 7675 return self.op(input) 7676 else: 7677 args = [input] 7678 if not skip_op_arg_for_functional: 7679 args += op_args 7680 return self.op(*args) 7681 7682 options = itertools.product([True, False], self.static_quant_types) 7683 quantized_nodes = { 7684 # is_module 7685 True: ns.call_module(quantized_module), 7686 False: ns.call_function(quantized_op), 7687 } 7688 7689 for is_module, quant_type in options: 7690 self.checkGraphModeFxOp( 7691 M(is_module), data, quant_type, quantized_nodes[is_module]) 7692 7693 def _test_norm_float16_impl( 7694 self, float_module, float_op, op_args, data, 7695 skip_op_arg_for_functional=False): 7696 ''' Test for normalization op, float_op can be torch op or functional op, 7697 op_args is a list of positional argument for the module/op 7698 ''' 7699 class M(torch.nn.Module): 7700 def __init__(self, is_module): 7701 super().__init__() 7702 self.is_module = is_module 7703 if self.is_module: 7704 self.op = float_module(*op_args) 7705 else: 7706 self.op = float_op 7707 7708 def forward(self, input): 7709 if self.is_module: 7710 return self.op(input) 7711 else: 7712 args = [input] 7713 if not skip_op_arg_for_functional: 7714 args += op_args 7715 return self.op(*args) 7716 7717 options = itertools.product([True, False], self.static_quant_types) 7718 qconfig_dict = { 7719 "object_type": [ 7720 (float_module, float16_static_qconfig), 7721 (float_op, float16_static_qconfig) 7722 ] 7723 } 7724 node_occurrence = { 7725 ns.call_method("to"): 2 7726 } 7727 for is_module, quant_type in options: 7728 self.checkGraphModeFxOp( 7729 M(is_module), data, quant_type, custom_qconfig_dict=qconfig_dict, expected_node_occurrence=node_occurrence) 7730 7731 def test_layer_norm(self): 7732 data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) 7733 self._test_norm_impl( 7734 nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm) 7735 7736 def test_instance_norm(self): 7737 data_1d = (torch.rand((1, 4, 5), dtype=torch.float),) 7738 data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),) 7739 data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),) 7740 data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d} 7741 instance_norm_modules = {1 : nn.InstanceNorm1d, 7742 2 : nn.InstanceNorm2d, 7743 3 : nn.InstanceNorm3d} 7744 quantized_instance_norm_modules = { 7745 1 : nnq.InstanceNorm1d, 7746 2 : nnq.InstanceNorm2d, 7747 3 : nnq.InstanceNorm3d 7748 } 7749 for dim in [1, 2, 3]: 7750 data = data_dict[dim] 7751 module = instance_norm_modules[dim] 7752 quantized_module = quantized_instance_norm_modules[dim] 7753 self._test_norm_impl( 7754 module, F.instance_norm, [4], data, 7755 quantized_module, torch.ops.quantized.instance_norm, 7756 skip_op_arg_for_functional=True) 7757 7758 def test_norm_weight_bias(self): 7759 class Linear(torch.nn.Module): 7760 def __init__(self) -> None: 7761 super().__init__() 7762 self.w = torch.ones(5, 5) 7763 self.b = torch.zeros(5) 7764 7765 def forward(self, x): 7766 return torch.nn.functional.linear(x, self.w, self.b) 7767 7768 class M(torch.nn.Module): 7769 def __init__(self) -> None: 7770 super().__init__() 7771 self.mods1 = Linear() 7772 self.scale = torch.randn(5, 5) 7773 self.bias = torch.randn(5, 5) 7774 7775 def forward(self, x): 7776 x1 = self.mods1(x) 7777 y = F.layer_norm(x1, [5, 5], weight=self.scale, bias=self.bias) 7778 return y 7779 7780 model = M() 7781 expected_occurrence = { 7782 ns.call_function(torch.quantize_per_tensor): 1, 7783 ns.call_function(torch.ops.quantized.linear): 1, 7784 ns.call_function(torch.ops.quantized.layer_norm): 1, 7785 ns.call_method("dequantize"): 1, 7786 } 7787 7788 self.checkGraphModeFxOp( 7789 model, 7790 (torch.rand(5, 5),), 7791 QuantType.STATIC, 7792 expected_node_occurrence=expected_occurrence, 7793 custom_qconfig_dict=get_default_qconfig_mapping().to_dict() 7794 ) 7795 7796 def _test_default_node_quant_handler_ops( 7797 self, module, functional, qconfig, is_reference=True, node_list=None, additional_quant_pattern_dict=None 7798 ): 7799 class M(torch.nn.Module): 7800 def __init__(self, mod, func): 7801 super().__init__() 7802 self.module = mod() 7803 self.functional = func 7804 7805 def forward(self, x): 7806 x = self.module(x) 7807 x = self.functional(x) 7808 return x 7809 7810 if node_list is None: 7811 node_list = [] 7812 if additional_quant_pattern_dict is None: 7813 additional_quant_pattern_dict = {} 7814 7815 data = torch.randn((2, 2, 2, 2)) 7816 quant_type = QuantType.STATIC 7817 prepare_custom_qconfig_dict = {"additional_quant_pattern": additional_quant_pattern_dict} 7818 qconfig_dict = {"": qconfig} 7819 7820 m = M(module, functional).eval() 7821 m_prep = prepare_fx(m, qconfig_dict, prepare_custom_qconfig_dict) 7822 m_prep(data) 7823 convert_fn = convert_to_reference_fx if is_reference else convert_fx 7824 m_quant = convert_fn(m_prep, is_reference=is_reference) 7825 m_quant(data) 7826 7827 self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) 7828 7829 @unittest.skip("TODO: reenable with backend_config api") 7830 def test_gelu_normal(self): 7831 module = torch.nn.GELU 7832 functional = torch.nn.functional.gelu 7833 qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 7834 is_reference = False 7835 node_list = [ 7836 ns.call_module(module), 7837 ns.call_function(functional), 7838 ] 7839 self._test_default_node_quant_handler_ops( 7840 module, functional, qconfig, is_reference, node_list) 7841 7842 @unittest.skip("TODO: reenable with backend_config api") 7843 def test_softmax_normal(self): 7844 module = torch.nn.Softmax 7845 functional = torch.nn.functional.softmax 7846 qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 7847 is_reference = False 7848 node_list = [ 7849 ns.call_module(torch.ao.nn.quantized.Softmax), 7850 ns.call_function(functional), 7851 ] 7852 self._test_default_node_quant_handler_ops( 7853 module, functional, qconfig, is_reference, node_list) 7854 7855 @unittest.skip("This is no longer needed right now, can enable later with new api") 7856 def test_gelu_reference(self): 7857 module = torch.nn.GELU 7858 functional = torch.nn.functional.gelu 7859 qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 7860 is_reference = True 7861 node_list = [ 7862 ns.call_function(torch.quantize_per_tensor), 7863 ns.call_method("dequantize"), 7864 ns.call_module(module), 7865 ns.call_function(torch.quantize_per_tensor), 7866 ns.call_method('dequantize'), 7867 ns.call_function(functional), 7868 ns.call_function(torch.quantize_per_tensor), 7869 ns.call_method('dequantize') 7870 ] 7871 # TODO: change these to use backend_config 7872 additional_patterns = {torch.nn.GELU: DefaultNodeQuantizeHandler, 7873 torch.nn.functional.gelu: DefaultNodeQuantizeHandler} 7874 self._test_default_node_quant_handler_ops( 7875 module, functional, qconfig, is_reference, node_list, additional_patterns) 7876 7877 self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, 7878 additional_quant_pattern_dict=self.common_quant_patterns) 7879 7880 @unittest.skip("This is no longer needed right now, can enable later with new api") 7881 def test_softmax_reference(self): 7882 module = torch.nn.Softmax 7883 functional = torch.nn.functional.softmax 7884 qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 7885 is_reference = True 7886 node_list = [ 7887 ns.call_function(torch.quantize_per_tensor), 7888 ns.call_method("dequantize"), 7889 ns.call_module(module), 7890 ns.call_function(torch.quantize_per_tensor), 7891 ns.call_method('dequantize'), 7892 ns.call_function(functional), 7893 ns.call_function(torch.quantize_per_tensor), 7894 ns.call_method('dequantize') 7895 ] 7896 additional_patterns = {torch.nn.Softmax: DefaultNodeQuantizeHandler, 7897 torch.nn.functional.softmax: DefaultNodeQuantizeHandler} 7898 self._test_default_node_quant_handler_ops( 7899 module, functional, qconfig, is_reference, node_list, additional_patterns) 7900 7901 self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, 7902 additional_quant_pattern_dict=self.common_quant_patterns) 7903 7904 @unittest.skip("This is no longer needed right now, can enable later with new api") 7905 def test_silu_reference(self): 7906 module = torch.nn.SiLU 7907 functional = torch.nn.functional.silu 7908 qconfig = float16_static_qconfig 7909 is_reference = True 7910 node_list = [ 7911 ns.call_method("to"), 7912 ns.call_method("dequantize"), 7913 ns.call_module(module), 7914 ns.call_method("to"), 7915 ns.call_method('dequantize'), 7916 ns.call_function(functional), 7917 ns.call_method("to"), 7918 ns.call_method('dequantize') 7919 ] 7920 self._test_default_node_quant_handler_ops( 7921 module, functional, qconfig, is_reference, node_list) 7922 7923 node_list = [ 7924 ns.call_function(torch.quantize_per_tensor), 7925 ns.call_method("dequantize"), 7926 ns.call_module(module), 7927 ns.call_function(torch.quantize_per_tensor), 7928 ns.call_method("dequantize"), 7929 ns.call_function(functional), 7930 ns.call_function(torch.quantize_per_tensor), 7931 ns.call_method("dequantize") 7932 ] 7933 self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, 7934 additional_quant_pattern_dict=self.common_quant_patterns) 7935 7936 @unittest.skip("This is no longer needed right now, can enable later with new api") 7937 def test_mish_reference(self): 7938 module = torch.nn.Mish 7939 functional = torch.nn.functional.mish 7940 qconfig = float16_static_qconfig 7941 is_reference = True 7942 node_list = [ 7943 ns.call_method("to"), 7944 ns.call_method("dequantize"), 7945 ns.call_module(module), 7946 ns.call_method("to"), 7947 ns.call_method('dequantize'), 7948 ns.call_function(functional), 7949 ns.call_method("to"), 7950 ns.call_method('dequantize') 7951 ] 7952 self._test_default_node_quant_handler_ops( 7953 module, functional, qconfig, is_reference, node_list) 7954 7955 node_list = [ 7956 ns.call_function(torch.quantize_per_tensor), 7957 ns.call_method("dequantize"), 7958 ns.call_module(module), 7959 ns.call_function(torch.quantize_per_tensor), 7960 ns.call_method("dequantize"), 7961 ns.call_function(functional), 7962 ns.call_function(torch.quantize_per_tensor), 7963 ns.call_method("dequantize") 7964 ] 7965 self._test_default_node_quant_handler_ops(module, functional, self.custom_qconfig, is_reference, node_list, 7966 additional_quant_pattern_dict=self.common_quant_patterns) 7967 7968 def test_bmm_int_reference(self): 7969 """ int8 is not supported for bmm so we won't produce reference 7970 pattern for it 7971 """ 7972 class M(torch.nn.Module): 7973 def __init__(self) -> None: 7974 super().__init__() 7975 self.bmm = torch.bmm 7976 7977 def forward(self, x, y): 7978 out = self.bmm(x, y) 7979 return out 7980 7981 data_x = torch.randn((2, 2, 2,)) 7982 data_y = torch.randn((2, 2, 2,)) 7983 example_inputs = (data_x, data_y) 7984 qconfig_dict = {"": torch.ao.quantization.get_default_qconfig("fbgemm")} 7985 is_reference = True 7986 node_list = [ 7987 ns.call_function(torch.bmm), 7988 ] 7989 7990 m = M().eval() 7991 m_prep = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 7992 m_prep(*example_inputs) 7993 convert_fn = convert_to_reference_fx if is_reference else convert_fx 7994 m_quant = convert_fn(m_prep) 7995 m_quant(*example_inputs) 7996 7997 self.checkGraphModuleNodes(m_quant, expected_node_list=node_list) 7998 7999 @skipIfNoFBGEMM 8000 def test_clamp(self): 8001 class M(torch.nn.Module): 8002 def __init__(self) -> None: 8003 super().__init__() 8004 self.conv = torch.nn.Conv2d(2, 2, 2).float() 8005 self.relu6 = torch.nn.ReLU6() 8006 self.relu6_ = torch.nn.ReLU6(True) 8007 self.hardtanh = torch.nn.Hardtanh() 8008 self.hardtanh_ = torch.nn.Hardtanh(inplace=True) 8009 8010 def forward(self, x): 8011 x = self.conv(x) 8012 x = self.relu6(x) 8013 self.relu6_(x) 8014 x = F.relu6(x) 8015 x = torch.clamp(x, -3, 3) 8016 x = x.clamp(-2.5, 2.5) 8017 # x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready 8018 x = self.hardtanh(x) 8019 self.hardtanh_(x) 8020 x = F.hardtanh(x) 8021 return x 8022 8023 data = (torch.rand((1, 2, 5, 5), dtype=torch.float),) 8024 # list of node that should occur in order 8025 node_list = [ 8026 ns.call_function(torch.quantize_per_tensor), 8027 ns.call_module(nnq.Conv2d), 8028 ns.call_method('dequantize') 8029 ] 8030 for quant_type in self.static_quant_types: 8031 self.checkGraphModeFxOp( 8032 M(), data, quant_type, expected_node_list=node_list) 8033 8034 def test_fixed_qparams_ops_fp16(self): 8035 class M(torch.nn.Module): 8036 def __init__(self) -> None: 8037 super().__init__() 8038 self.sigmoid = torch.nn.Sigmoid() 8039 self.tanh = torch.nn.Tanh() 8040 8041 def forward(self, x): 8042 x = self.sigmoid(x) 8043 x = torch.sigmoid(x) 8044 x = x.sigmoid() 8045 x = self.tanh(x) 8046 x = torch.tanh(x) 8047 x = x.tanh() 8048 return x 8049 8050 data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) 8051 quant_type = QuantType.STATIC 8052 # TODO: use get_default_qconfig_mapping once it handles fp16 8053 qconfig_mapping = QConfigMapping().set_global(float16_static_qconfig) 8054 backend_config = get_test_only_legacy_native_backend_config() 8055 node_occurrence = { 8056 ns.call_method("to"): 7 8057 } 8058 self.checkGraphModeFxOp( 8059 M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, 8060 expected_node_occurrence=node_occurrence, 8061 backend_config=backend_config) 8062 8063 def test_fixed_qparams_ops_qint8(self): 8064 class M(torch.nn.Module): 8065 def __init__(self) -> None: 8066 super().__init__() 8067 self.sigmoid = torch.nn.Sigmoid() 8068 self.tanh = torch.nn.Tanh() 8069 8070 def forward(self, x): 8071 x = self.sigmoid(x) 8072 x = torch.sigmoid(x) 8073 x = x.sigmoid() 8074 x = self.tanh(x) 8075 x = torch.tanh(x) 8076 x = x.tanh() 8077 return x 8078 8079 data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) 8080 quant_type = QuantType.STATIC 8081 qconfig = torch.ao.quantization.QConfig( 8082 activation=HistogramObserver.with_args(qscheme=torch.per_tensor_symmetric, dtype=torch.quint8), 8083 weight=default_weight_observer) 8084 qconfig_mapping = get_default_qconfig_mapping().set_global(qconfig) 8085 node_occurrence = { 8086 ns.call_function(torch.quantize_per_tensor): 7, 8087 ns.call_method("dequantize"): 7 8088 } 8089 self.checkGraphModeFxOp( 8090 M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, 8091 expected_node_occurrence=node_occurrence, is_reference=True) 8092 8093 def test_fixed_qparams_ops_wrong_qconfig(self): 8094 """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized. 8095 """ 8096 class M(torch.nn.Module): 8097 def __init__(self) -> None: 8098 super().__init__() 8099 self.sigmoid = torch.nn.Sigmoid() 8100 self.tanh = torch.nn.Tanh() 8101 8102 def forward(self, x): 8103 x = self.sigmoid(x) 8104 x = torch.sigmoid(x) 8105 x = x.sigmoid() 8106 x = self.tanh(x) 8107 x = torch.tanh(x) 8108 x = x.tanh() 8109 return x 8110 8111 data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) 8112 qconfig_mapping = QConfigMapping().set_global(default_qconfig) 8113 m = M().eval() 8114 node_occurrence = { 8115 ns.call_function(torch.quantize_per_tensor): 0, 8116 ns.call_method("dequantize"): 0, 8117 } 8118 self.checkGraphModeFxOp( 8119 m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping, 8120 expected_node_occurrence=node_occurrence, is_reference=True) 8121 self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid)) 8122 self.assertTrue(isinstance(m.tanh, torch.nn.Tanh)) 8123 8124 @skipIfNoFBGEMM 8125 def test_general_shape_ops(self): 8126 """ A test that checks dequantize will be swapped for 8127 all supported general shape ops like aten::flatten 8128 without actually checking for execution of these ops 8129 """ 8130 class M(torch.nn.Module): 8131 def __init__(self) -> None: 8132 super().__init__() 8133 self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3) 8134 self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3) 8135 self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3) 8136 self.dropout = torch.nn.Dropout() 8137 self.conv1 = torch.nn.Conv2d(3, 3, 3) 8138 self.conv2 = torch.nn.Conv2d(3, 3, 3) 8139 self.relu = torch.nn.ReLU() 8140 8141 def forward(self, x): 8142 x = self.conv1(x) 8143 # add_scalar 8144 x = x + 3 8145 # mul_scalar 8146 x = x * 3 8147 # add_scalar_out 8148 x += 3 8149 # mul_scalar_out 8150 x *= 3 8151 # add_scalar_relu 8152 x = x + 3 8153 x = F.relu(x) 8154 # add_scalar_relu_out 8155 x += 3 8156 x = F.relu(x) 8157 # mul_scalar_relu 8158 x = x * 3 8159 x = F.relu(x) 8160 # mul_scalar_relu_out 8161 x *= 3 8162 x = F.relu(x) 8163 x = self.maxpool1d(x) 8164 x = self.maxpool2d(x) 8165 x = self.maxpool3d(x) 8166 x = torch.flatten(x) 8167 x = x.reshape([-1]) 8168 x = x.resize_(1, 1, x) 8169 x = x.view(-1) 8170 # prim::ListConstruct 8171 xs = [x, x] 8172 # prim::ListUnpack 8173 x, y = xs 8174 # prim::TupleConstruct 8175 xs = (x, x) 8176 # prim::TupleUnpack 8177 x, y = xs 8178 x = x.transpose(1, 2) 8179 x = x.contiguous() 8180 # chunk is not supported since observer only supports 8181 # observing single Tensor currently 8182 x, y = torch.chunk(x, 2) 8183 x = F.dropout(x) 8184 x = self.dropout(x) 8185 x = x.permute(0, 2, 3, 1) 8186 x = x.repeat_interleave(3, 1) 8187 x = torch.repeat_interleave(x, 3, 1) 8188 x = self.relu(x) 8189 x = F.relu(x) 8190 x = F.relu(x, inplace=True) 8191 x = x.relu() 8192 x.relu_() 8193 x = x.squeeze(0) 8194 x.squeeze_(0) 8195 x = torch.squeeze(x, 0) 8196 x = x.unsqueeze(0) 8197 x.unsqueeze_(0) 8198 x = torch.unsqueeze(x, 0) 8199 x = x.detach() 8200 x.detach_() 8201 x = x.repeat(4, 2) 8202 y = [] 8203 y.append(x) 8204 z = torch.stack(y, 0) 8205 z = [z, z] 8206 x, _ = z 8207 x = self.conv2(x) 8208 return x 8209 8210 example_inputs = (torch.rand(1, 3, 10, 10),) 8211 # This model is not executable since we just put all ops 8212 # in the same forward 8213 m = M().eval() 8214 qconfig_dict = {'': default_qconfig} 8215 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 8216 # not runnable 8217 quantized = convert_fx(prepared) 8218 8219 # This checks that the dequantize from the output of first conv 8220 # is being propagated to the end, so that we don't insert extra 8221 # observers and also successfully fused two quantized::conv2d 8222 # patterns 8223 # one quantize_per_tensor for input 8224 # check exact counts of quantize and dequantize 8225 count_check = { 8226 # input of conv and two outputs of getitem 8227 ns.call_function(torch.quantize_per_tensor) : 2, 8228 # output of the model and two outputs of getitem 8229 ns.call_method('dequantize') : 2 8230 } 8231 order_check = [ 8232 ns.call_function(torch.quantize_per_tensor), 8233 ns.call_module(nnq.Conv2d), 8234 ns.call_module(nnq.Conv2d), 8235 ns.call_method('dequantize'), 8236 ] 8237 self.checkGraphModuleNodes( 8238 quantized, 8239 expected_node_occurrence=count_check, 8240 expected_node_list=order_check) 8241 8242 8243 # Checking the is_reference output 8244 m = M().eval() 8245 qconfig_dict = {'': default_qconfig} 8246 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 8247 # not runnable 8248 quantized = convert_to_reference_fx(prepared) 8249 8250 8251 @skipIfNoFBGEMM 8252 def test_ave_pool_with_custom_cfg(self): 8253 """ A test that checks correct patterns are produced for 8254 avg_pool2d with customized config 8255 """ 8256 class M(torch.nn.Module): 8257 def __init__(self) -> None: 8258 super().__init__() 8259 self.avg_pool2d = torch.nn.AvgPool2d(3) 8260 8261 8262 def forward(self, x): 8263 x = self.avg_pool2d(x) 8264 return x 8265 8266 # This model is not executable since we just put all ops 8267 # in the same forward 8268 m = M().eval() 8269 # nothing to fuse so skipping the fuse step 8270 qconfig_dict = {'': default_qconfig} 8271 example_inputs = (torch.randn(1, 3, 3, 3),) 8272 prepared = prepare_fx( 8273 m, qconfig_dict, example_inputs=example_inputs, 8274 prepare_custom_config={"input_quantized_idxs": [0]}) 8275 8276 # not runnable 8277 quantized = convert_fx(prepared) 8278 8279 # This checks that the dequantize from the output of first conv 8280 # is being propagated to the end, so that we don't insert extra 8281 # observers 8282 # check exact counts of quantize and dequantize 8283 count_check = { 8284 ns.call_method('dequantize') : 1 8285 } 8286 order_check = [ 8287 ns.call_module(nn.AvgPool2d), 8288 ns.call_method('dequantize'), 8289 ] 8290 self.checkGraphModuleNodes( 8291 quantized, 8292 expected_node_occurrence=count_check, 8293 expected_node_list=order_check) 8294 8295 @skipIfNoFBGEMM 8296 def test_general_value_ops(self): 8297 """ A test that checks correct patterns are produced for 8298 all supported general value ops like aten::avg_pool2d \ 8299 without actually checking for execution of these ops 8300 """ 8301 class M(torch.nn.Module): 8302 def __init__(self) -> None: 8303 super().__init__() 8304 self.conv = torch.nn.Conv2d(3, 3, 3) 8305 self.avg_pool1d = torch.nn.AvgPool1d(3) 8306 self.avg_pool2d = torch.nn.AvgPool2d(3) 8307 self.avg_pool3d = torch.nn.AvgPool3d(3) 8308 self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1) 8309 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) 8310 self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) 8311 8312 def forward(self, x): 8313 x = self.conv(x) 8314 x = self.avg_pool1d(x) 8315 x = self.avg_pool2d(x) 8316 x = self.avg_pool3d(x) 8317 x = self.adaptive_avg_pool1d(x) 8318 x = self.adaptive_avg_pool2d(x) 8319 x = self.adaptive_avg_pool3d(x) 8320 x = F.avg_pool1d(x, 3) 8321 x = F.avg_pool2d(x, 3) 8322 x = F.avg_pool3d(x, 3) 8323 x = F.adaptive_avg_pool1d(x, (1)) 8324 x = F.adaptive_avg_pool2d(x, (1, 1)) 8325 x = F.adaptive_avg_pool3d(x, (1, 1, 1)) 8326 x = torch.mean(x) 8327 x = torch.mean(x, [2, 3], False) 8328 x = x.mean() 8329 x = x.mean([2, 3], True) 8330 x = F.interpolate(x, 4, mode='nearest') 8331 x = F.interpolate(x, 4, mode='linear') 8332 x = self.conv(x) 8333 return x 8334 8335 # This model is not executable since we just put all ops 8336 # in the same forward 8337 m = M().eval() 8338 # nothing to fuse so skipping the fuse step 8339 qconfig_dict = {'': default_qconfig} 8340 example_inputs = (torch.randn(1, 3, 3, 3),) 8341 prepared = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 8342 # not runnable 8343 quantized = convert_fx(prepared) 8344 8345 # This checks that the dequantize from the output of first conv 8346 # is being propagated to the end, so that we don't insert extra 8347 # observers 8348 # check exact counts of quantize and dequantize 8349 count_check = { 8350 ns.call_function(torch.quantize_per_tensor) : 1, 8351 ns.call_method('dequantize') : 1 8352 } 8353 order_check = [ 8354 ns.call_function(torch.quantize_per_tensor), 8355 ns.call_module(nnq.Conv2d), 8356 ns.call_module(nnq.Conv2d), 8357 ns.call_method('dequantize'), 8358 ] 8359 self.checkGraphModuleNodes( 8360 quantized, 8361 expected_node_occurrence=count_check, 8362 expected_node_list=order_check) 8363 8364 def test_copy_node_fp32_input(self): 8365 """ CopyNode works for both fp32 and int8 inputs, this is a test to make 8366 sure that a CopyNode can be successfully quantized in both cases 8367 """ 8368 class M(torch.nn.Module): 8369 def forward(self, x): 8370 x = x.relu() 8371 return x 8372 8373 m = M().eval() 8374 m = prepare_fx(m, {"": default_reuse_input_qconfig}, example_inputs=(torch.randn(1),)) 8375 m = convert_fx(m) 8376 # make sure it runs 8377 m(torch.rand(1)) 8378 8379 def test_getitem(self): 8380 """ Make sure we only insert observer for getitem if the following node is matched 8381 or needs to be quantized 8382 """ 8383 class M(torch.nn.Module): 8384 def forward(self, xs): 8385 x = xs[0] 8386 return x 8387 8388 m = M().eval() 8389 example_inputs = (torch.rand(1, 2),) 8390 qconfig_mapping = get_default_qconfig_mapping() 8391 m = prepare_fx(m, qconfig_mapping, example_inputs=example_inputs) 8392 self.checkGraphModuleNodes(m, expected_node_occurrence={ 8393 ns.call_module(torch.ao.quantization.MinMaxObserver): 0 8394 }) 8395 m = convert_fx(m) 8396 m(*example_inputs) 8397 8398 class M2(torch.nn.Module): 8399 def forward(self, xs): 8400 x = xs[0] 8401 x = torch.sigmoid(x) 8402 return x 8403 8404 m2 = M2().eval() 8405 example_inputs = ([torch.rand(1, 2)],) 8406 qconfig_mapping = get_default_qconfig_mapping() 8407 m2 = prepare_fx(m2, qconfig_mapping, example_inputs=example_inputs) 8408 self.checkGraphModuleNodes(m2, expected_node_occurrence={ 8409 ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 8410 }) 8411 m2 = convert_fx(m2) 8412 self.checkGraphModuleNodes(m2, expected_node_list=[ 8413 ns.call_function(torch.quantize_per_tensor), 8414 ns.call_method("dequantize") 8415 ]) 8416 m2(*example_inputs) 8417 8418 # testing prepare recognizes non-Tensor input for getitem 8419 class M3(torch.nn.Module): 8420 def forward(self, x): 8421 s = x.shape 8422 n, c = s[:2] 8423 x = torch.sigmoid(x) 8424 return x 8425 8426 m3 = M3().eval() 8427 example_inputs = (torch.rand(1, 2, 3, 4),) 8428 qconfig_mapping = get_default_qconfig_mapping() 8429 m3 = prepare_fx(m3, qconfig_mapping, example_inputs=example_inputs) 8430 self.checkGraphModuleNodes(m3, expected_node_occurrence={ 8431 ns.call_module(torch.ao.quantization.FixedQParamsObserver): 2 8432 }) 8433 m3 = convert_fx(m3) 8434 self.checkGraphModuleNodes(m3, expected_node_list=[ 8435 ns.call_function(torch.quantize_per_tensor), 8436 ns.call_method("dequantize") 8437 ]) 8438 m3(*example_inputs) 8439 8440 8441 @skipIfNoFBGEMM 8442 def test_fixed_qparams_ops(self): 8443 class M(torch.nn.Module): 8444 def __init__(self) -> None: 8445 super().__init__() 8446 self.conv = torch.nn.Conv2d(3, 3, 3) 8447 self.sigmoid = torch.nn.Sigmoid() 8448 self.hardsigmoid = torch.nn.Hardsigmoid() 8449 self.tanh = torch.nn.Tanh() 8450 self.softmax = torch.nn.Softmax(dim=0) 8451 8452 def forward(self, x): 8453 x = self.conv(x) 8454 # F.sigmoid is deprecated 8455 x = self.sigmoid(x) 8456 x = torch.sigmoid(x) 8457 x = x.sigmoid() 8458 x = self.hardsigmoid(x) 8459 x = F.hardsigmoid(x) 8460 x = F.hardsigmoid(x, inplace=True) 8461 x = self.tanh(x) 8462 # F.tanh is deprecated 8463 x = torch.tanh(x) 8464 x = x.tanh() 8465 # TODO(future PR): handle F.softmax 8466 x = self.softmax(x) 8467 return x 8468 8469 for eval_mode in [True, False]: 8470 # This model is not executable since we just put all ops 8471 # in the same forward 8472 m = M() 8473 if eval_mode: 8474 m.eval() 8475 qconfig_mapping = get_default_qconfig_mapping() 8476 prepare = prepare_fx 8477 fq_count = 10 8478 else: 8479 m.train() 8480 qconfig_mapping = get_default_qat_qconfig_mapping() 8481 prepare = prepare_qat_fx 8482 fq_count = 10 8483 # nothing to fuse so skipping the fuse step 8484 m_copy = copy.deepcopy(m) 8485 example_inputs = (torch.rand(3, 3, 3, 3),) 8486 prepared = prepare(m, qconfig_mapping, example_inputs=example_inputs) 8487 prepared_copy = copy.deepcopy(prepared) 8488 # check that prepare does not change model result 8489 if eval_mode: 8490 self.assertEqual(m_copy(*example_inputs), prepared_copy(*example_inputs)) 8491 # check the correct number of activation_post_process is inserted 8492 expected_activation_post_process = FixedQParamsObserver if eval_mode else FixedQParamsFakeQuantize 8493 count_check = { 8494 ns.call_module(expected_activation_post_process) : fq_count, 8495 } 8496 self.checkGraphModuleNodes( 8497 prepared, 8498 expected_node_occurrence=count_check) 8499 # not runnable 8500 quantized = convert_fx(prepared) 8501 quantized_reference = convert_to_reference_fx(prepared_copy) 8502 8503 # This checks that the dequantize from the output of first conv 8504 # is being propagated to the end, so that we don't insert extra 8505 # observers 8506 # check exact counts of quantize and dequantize 8507 count_check = { 8508 ns.call_function(torch.quantize_per_tensor) : 1, 8509 ns.call_method('dequantize') : 1 8510 } 8511 order_check = [ 8512 ns.call_function(torch.quantize_per_tensor), 8513 ns.call_module(nnq.Conv2d), 8514 ns.call_module(nn.Sigmoid), 8515 ns.call_module(nnq.Softmax), 8516 ns.call_method('dequantize'), 8517 ] 8518 self.checkGraphModuleNodes( 8519 quantized, 8520 expected_node_occurrence=count_check, 8521 expected_node_list=order_check) 8522 8523 reference_count_check = { 8524 ns.call_function(torch.quantize_per_tensor) : 12, 8525 ns.call_method('dequantize') : 12 8526 } 8527 reference_order_check = [ 8528 ns.call_function(torch.quantize_per_tensor), 8529 ns.call_method('dequantize'), 8530 ns.call_module(nnqr.Conv2d), 8531 ns.call_function(torch.quantize_per_tensor), 8532 ns.call_method('dequantize'), 8533 ns.call_module(nn.Sigmoid), 8534 ns.call_function(torch.quantize_per_tensor), 8535 ns.call_method('dequantize'), 8536 ns.call_module(nn.Softmax), 8537 ns.call_function(torch.quantize_per_tensor), 8538 ns.call_method('dequantize'), 8539 ] 8540 self.checkGraphModuleNodes( 8541 quantized_reference, 8542 expected_node_occurrence=reference_count_check, 8543 expected_node_list=reference_order_check) 8544 8545 # Verify that softmax scale and zero_point are correct 8546 self.assertTrue(quantized.softmax.scale - (1.0 / 256) <= 1e-8) 8547 self.assertTrue(quantized.softmax.zero_point == 0) 8548 8549 def test_float_functional(self): 8550 class TorchAdd(nn.Module): 8551 """Wrapper around torch.add so that all ops can be found at build""" 8552 def __init__(self) -> None: 8553 super().__init__() 8554 self.add_func = nnq.FloatFunctional() 8555 8556 def forward(self, x, y): 8557 return self.add_func.add(x, y) 8558 8559 class M(torch.nn.Module): 8560 def __init__(self) -> None: 8561 super().__init__() 8562 self.ff1 = TorchAdd() 8563 self.ff2 = nnq.FloatFunctional() 8564 self.ff3 = nnq.FloatFunctional() 8565 self.ff4 = nnq.FloatFunctional() 8566 self.ff5 = nnq.FloatFunctional() 8567 self.ff6 = nnq.FloatFunctional() 8568 8569 def forward(self, x): 8570 x = self.ff1(x, x) 8571 x = self.ff2.add_scalar(x, 3) 8572 x = self.ff3.mul(x, x) 8573 x = self.ff4.mul_scalar(x, 3) 8574 x = self.ff5.add_relu(x, x) 8575 x = self.ff6.cat([x]) 8576 return x 8577 8578 example_inputs = (torch.rand(3, 3),) 8579 # Note: QAT test succeeded by chance, to make it actually work 8580 # we need to fix eager mode FloatFunctional by removing 8581 # activation_post_process in add_scalar and mul_scalar 8582 for quant_type in self.static_quant_types: 8583 m = M() 8584 ref_m = torch.ao.quantization.QuantWrapper(M()) 8585 is_qat = quant_type == QuantType.QAT 8586 if is_qat: 8587 m.train() 8588 ref_m.train() 8589 qconfig = default_qat_qconfig 8590 expected_act_post_process = torch.ao.quantization.FakeQuantize 8591 else: 8592 m.eval() 8593 ref_m.eval() 8594 qconfig = default_qconfig 8595 expected_act_post_process = torch.ao.quantization.MinMaxObserver 8596 8597 prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx 8598 qconfig_dict = {"": qconfig} 8599 m = prepare_fx_function(m, qconfig_dict, example_inputs=example_inputs) 8600 node_occurrence = { 8601 ns.call_module(expected_act_post_process): 7, 8602 ns.call_module(torch.ao.nn.quantized.FloatFunctional): 0 8603 } 8604 self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) 8605 m(*example_inputs) 8606 node_list = [ 8607 ns.call_function(torch.quantize_per_tensor), 8608 ns.call_function(torch.ops.quantized.add), 8609 ns.call_function(torch.ops.quantized.add), 8610 ns.call_function(torch.ops.quantized.mul), 8611 ns.call_function(torch.ops.quantized.mul), 8612 ns.call_function(torch.ops.quantized.add_relu), 8613 ns.call_function(torch.cat), 8614 ns.call_method('dequantize') 8615 ] 8616 m = convert_fx(m) 8617 self.checkGraphModuleNodes(m, expected_node_list=node_list) 8618 8619 # make sure numerics match with eager mode 8620 ref_m.qconfig = qconfig 8621 prepare_function = prepare_qat if is_qat else prepare 8622 ref_m = prepare_function(ref_m) 8623 ref_m(*example_inputs) 8624 ref_m = convert(ref_m) 8625 # FX Graph Mode and Eager Mode now diverages in numerics of add_scalar and mul_scalar 8626 # self.assertEqual(m(data), ref_m(data)) 8627 8628 def test_embedding(self): 8629 class M(torch.nn.Module): 8630 def __init__(self) -> None: 8631 super().__init__() 8632 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) 8633 8634 def forward(self, indices): 8635 return self.emb(indices) 8636 8637 for qconfig_type in [float_qparams_weight_only_qconfig, float_qparams_weight_only_qconfig_4bit]: 8638 model = M().eval() 8639 indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) 8640 example_inputs = (indices,) 8641 quantized_node = ns.call_module(nnq.Embedding) 8642 8643 # check dynamic quant 8644 self.checkGraphModeFxOp( 8645 model, 8646 example_inputs, 8647 QuantType.DYNAMIC, 8648 quantized_node, 8649 custom_qconfig_dict={"": qconfig_type} 8650 ) 8651 model = M().eval() 8652 8653 configs = [ 8654 (qconfig_type, ns.call_module(nnq.Embedding)), 8655 (None, ns.call_module(nn.Embedding)), 8656 (default_qconfig, ns.call_module(nn.Embedding)), 8657 ] 8658 8659 # check static quantization 8660 for qconfig, node in configs: 8661 qconfig_dict = {"": qconfig} 8662 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 8663 self.checkGraphModuleNodes(m, expected_node_occurrence={ 8664 ns.call_module(torch.ao.quantization.MinMaxObserver): 0 8665 }) 8666 m = convert_fx(m) 8667 self.checkGraphModuleNodes(m, expected_node=node) 8668 # make sure it runs 8669 m(*example_inputs) 8670 8671 def test_embedding_bag(self): 8672 class M(torch.nn.Module): 8673 def __init__(self) -> None: 8674 super().__init__() 8675 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True) 8676 8677 def forward(self, indices, offsets): 8678 return self.emb(indices, offsets) 8679 8680 indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) 8681 offsets = torch.tensor([0, 19, 20, 28, 28, 32]) 8682 quantized_node = ns.call_module(nnq.EmbeddingBag) 8683 example_inputs = (indices, offsets) 8684 8685 for dtype in [torch.quint8, torch.quint4x2]: 8686 model = M().eval() 8687 float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, 8688 qscheme=torch.per_channel_affine_float_qparams, 8689 ch_axis=0) 8690 float_qparams_qconfig = QConfig(activation=default_placeholder_observer, 8691 weight=float_qparams_observer) 8692 self.checkGraphModeFxOp( 8693 model, 8694 example_inputs, 8695 QuantType.DYNAMIC, 8696 quantized_node, 8697 custom_qconfig_dict={"": float_qparams_qconfig} 8698 ) 8699 8700 # check it works in None and static qconfig 8701 for qconfig in [None, default_qconfig]: 8702 qconfig_dict = {"": default_qconfig} 8703 m = M().eval() 8704 m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 8705 self.checkGraphModuleNodes(m, expected_node_occurrence={ 8706 ns.call_module(torch.ao.quantization.MinMaxObserver): 0 8707 }) 8708 m = convert_fx(m) 8709 self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag)) 8710 # make sure it runs 8711 m(*example_inputs) 8712 8713 def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input): 8714 options = itertools.product(qconfigs, module_type_strs) 8715 for qconfig, module_type_str in options: 8716 model_eager = M(module_type_str).eval() 8717 model_graph = copy.deepcopy(model_eager) 8718 if torch.backends.quantized.engine == 'qnnpack' and \ 8719 qconfig is float16_dynamic_qconfig: 8720 continue 8721 # fp16 dynamic quant is not supported for qnnpack 8722 8723 eager_qconfig_dict = dict.fromkeys(module_types, qconfig) 8724 model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict) 8725 8726 graph_qconfig_dict = { 8727 "object_type": [ 8728 (x, qconfig) for x in module_types 8729 ] 8730 } 8731 model_graph = prepare_fx(model_graph, graph_qconfig_dict, example_inputs=(sample_input,)) 8732 model_graph = convert_fx(model_graph) 8733 self.assertEqual(model_eager(sample_input), model_graph(sample_input)) 8734 self.checkScriptable(model_graph, [[sample_input]], True) 8735 8736 @override_qengines 8737 def test_rnn_cell(self): 8738 if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): 8739 return 8740 qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] 8741 module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU'] 8742 module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell] 8743 sample_input = torch.tensor([[100, -155], 8744 [-155, 100], 8745 [100, -155]], dtype=torch.float) 8746 self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input) 8747 8748 @override_qengines 8749 def test_rnn(self): 8750 if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): 8751 return 8752 qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig] 8753 module_type_strs = ['LSTM', 'GRU'] 8754 module_types = [torch.nn.LSTM, torch.nn.GRU] 8755 niter = 10 8756 sample_input = torch.tensor([[100, -155], 8757 [-155, 100], 8758 [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) 8759 self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input) 8760 8761 def _test_conv_transpose_impl( 8762 self, float_cls: Callable, q_cls: Callable, data: torch.Tensor): 8763 with override_quantized_engine('qnnpack'): 8764 # Create fp32 versions of FX and Eager models 8765 m1 = torch.nn.Sequential(float_cls(1, 1, 1)) 8766 m2 = torch.nn.Sequential(float_cls(1, 1, 1)) 8767 m2.load_state_dict(m1.state_dict()) 8768 m2 = torch.ao.quantization.QuantWrapper(m2) 8769 # FX graph 8770 result_dict = self.checkGraphModeFxOp( 8771 m1, (data,), QuantType.STATIC, 8772 expected_node_occurrence={ 8773 ns.call_module(q_cls): 1, 8774 }) 8775 q_result1 = result_dict["quantized_output"] 8776 # Eager 8777 m2.qconfig = get_default_qconfig(torch.backends.quantized.engine) 8778 m2.eval() 8779 m2p = torch.ao.quantization.prepare(m2) 8780 m2p(data) 8781 m2q = torch.ao.quantization.convert(m2p) 8782 q_result2 = m2q(data) 8783 # verify results match 8784 self.assertEqual(q_result1, q_result2) 8785 8786 @unittest.skipUnless('qnnpack' in supported_qengines, 8787 "This Pytorch Build has not been built with or does not support QNNPACK") 8788 def test_conv_transpose_1d(self): 8789 self._test_conv_transpose_impl( 8790 torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4)) 8791 8792 @unittest.skipUnless('qnnpack' in supported_qengines, 8793 "This Pytorch Build has not been built with or does not support QNNPACK") 8794 def test_conv_transpose_2d(self): 8795 self._test_conv_transpose_impl( 8796 torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4)) 8797 8798 def test_reshape_fp16(self): 8799 class M(torch.nn.Module): 8800 def __init__(self, w, b): 8801 super().__init__() 8802 self.w = w 8803 self.b = b 8804 8805 def forward(self, x): 8806 x = torch.nn.functional.linear(x, self.w) 8807 x = x.reshape(-1, 4) 8808 x = torch.nn.functional.linear(x, self.w) 8809 return x 8810 8811 w = torch.randn(4, 4) 8812 b = torch.randn(4) 8813 m = M(w, b).eval() 8814 qconfig_dict = { 8815 # reshape will be quantized to fp16 as requested by this qconfig 8816 "": float16_static_qconfig, 8817 "object_type": [ 8818 (torch.nn.functional.linear, default_qconfig) 8819 ] 8820 } 8821 backend_config = get_test_only_legacy_native_backend_config() 8822 example_inputs = (torch.randn(1, 4),) 8823 m = prepare_fx( 8824 m, qconfig_dict, example_inputs=example_inputs, 8825 backend_config=backend_config) 8826 expected_occurrence = { 8827 # input and weight of first and second linear, output of first and second linear 8828 ns.call_module(torch.ao.quantization.MinMaxObserver): 6, 8829 # we insert placeholder observer for both input and output of reshape 8830 ns.call_module(torch.ao.quantization.PlaceholderObserver): 2 8831 } 8832 self.checkGraphModuleNodes( 8833 m, 8834 expected_node_occurrence=expected_occurrence 8835 ) 8836 m = convert_fx(m, backend_config=backend_config) 8837 expected_occurrence = { 8838 ns.call_function(torch.quantize_per_tensor): 2, 8839 # dequantize after first linear, before reshape and before output 8840 ns.call_method("dequantize"): 3, 8841 # before reshape, to(fp16) 8842 ns.call_method("to"): 1, 8843 ns.call_function(torch.ops.quantized.linear): 2 8844 } 8845 self.checkGraphModuleNodes( 8846 m, 8847 expected_node_occurrence=expected_occurrence 8848 ) 8849 # make sure it runs 8850 m(torch.randn(2, 4)) 8851 8852 def test_multiple_qconfigs_for_single_value(self): 8853 """ Test multiple qconfigs for a single value""" 8854 class M(torch.nn.Module): 8855 def __init__(self, w, b): 8856 super().__init__() 8857 self.w = w 8858 self.b = b 8859 8860 def forward(self, x): 8861 x = torch.nn.functional.linear(x, self.w) 8862 x = torch.sigmoid(x) 8863 return x 8864 8865 w = torch.randn(4, 4) 8866 b = torch.randn(4) 8867 m = M(w, b).eval() 8868 # TODO: use get_default_qconfig_mapping once it handles fp16 8869 qconfig_mapping = QConfigMapping() \ 8870 .set_global(float16_static_qconfig) \ 8871 .set_object_type(torch.nn.functional.linear, default_qconfig) 8872 example_inputs = (torch.randn(1, 4),) 8873 backend_config = get_test_only_legacy_native_backend_config() 8874 m = prepare_fx( 8875 m, qconfig_mapping, example_inputs=example_inputs, 8876 backend_config=backend_config) 8877 expected_occurrence = { 8878 # input and weight of linear, output of linear 8879 ns.call_module(torch.ao.quantization.MinMaxObserver): 3, 8880 # input and output of sigmoid 8881 ns.call_module(torch.ao.quantization.PlaceholderObserver): 2, 8882 } 8883 self.checkGraphModuleNodes( 8884 m, 8885 expected_node_occurrence=expected_occurrence 8886 ) 8887 # make sure it runs 8888 m = convert_fx(m) 8889 expected_occurrence = { 8890 ns.call_function(torch.quantize_per_tensor): 1, 8891 ns.call_method("dequantize"): 3, 8892 ns.call_method("to"): 2 8893 } 8894 self.checkGraphModuleNodes( 8895 m, 8896 expected_node_occurrence=expected_occurrence 8897 ) 8898 8899 def test_boolean_tensor(self): 8900 """ Make sure we don't insert observer for boolean Tensors """ 8901 class M(torch.nn.Module): 8902 def forward(self, x, mask): 8903 mask = mask.unsqueeze(0) 8904 mask = mask.unsqueeze(1) 8905 x = x.masked_fill(mask, 1) 8906 return x 8907 8908 m = M().eval() 8909 example_inputs = (torch.rand(1, 2, 3, 4), torch.rand(3, 4).bool()) 8910 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 8911 expected_occurrence = { 8912 ns.call_module(torch.ao.quantization.MinMaxObserver): 0 8913 } 8914 self.checkGraphModuleNodes( 8915 m, 8916 expected_node_occurrence=expected_occurrence) 8917 m = convert_fx(m) 8918 m(*example_inputs) 8919 8920 def test_chunk(self): 8921 class M(torch.nn.Module): 8922 def forward(self, x): 8923 x, y = torch.chunk(x, 2) 8924 x = x + y 8925 return x 8926 m = M().eval() 8927 example_inputs = (torch.rand(2, 2, 2, 2),) 8928 m = prepare_fx(m, {"": default_qconfig}, example_inputs=example_inputs) 8929 m(*example_inputs) 8930 m = convert_fx(m) 8931 m(*example_inputs) 8932 # make sure everything runs 8933 8934 def test_ref_pattern_multi_use(self): 8935 class M(torch.nn.Module): 8936 def __init__(self) -> None: 8937 super().__init__() 8938 self.linear = torch.nn.Linear(5, 5) 8939 self.linear1 = torch.nn.Linear(5, 5) 8940 8941 def forward(self, x): 8942 y = self.linear(x) 8943 z = self.linear1(x) 8944 a = torch.mul(z, 5) 8945 b = torch.add(z, 5) 8946 return (y, a, b) 8947 8948 m = M().eval() 8949 qconfig_dict = { 8950 "": None, 8951 "object_type": [ 8952 (torch.nn.Linear, get_default_qconfig("fbgemm")), 8953 (torch.nn.ReLU, get_default_qconfig("fbgemm")), 8954 ], 8955 } 8956 example_inputs = (torch.randn(1, 5),) 8957 m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 8958 m = convert_fx(m) 8959 expected_occurrence = { 8960 ns.call_function(torch.quantize_per_tensor): 1, 8961 ns.call_module(nnq.Linear): 2, 8962 ns.call_method("dequantize"): 2, 8963 ns.call_function(torch.add): 1, 8964 ns.call_function(torch.mul): 1, 8965 } 8966 self.checkGraphModuleNodes( 8967 m, 8968 expected_node_occurrence=expected_occurrence) 8969 8970 def test_qmatmul(self): 8971 class M(torch.nn.Module): 8972 def forward(self, x, y): 8973 z = torch.matmul(x, y) 8974 return z 8975 8976 m = M().eval() 8977 example_inputs = (torch.randn(2, 2), torch.randn(2, 2)) 8978 qconfig_dict = get_default_qconfig_mapping("fbgemm") 8979 mp = prepare_fx(m, qconfig_dict, example_inputs=example_inputs) 8980 mp(*example_inputs) 8981 mq = convert_fx(mp) 8982 expected_occurrence = { 8983 ns.call_function(torch.matmul): 0, 8984 ns.call_function(torch.ops.quantized.matmul): 1, 8985 } 8986 self.checkGraphModuleNodes( 8987 mq, 8988 expected_node_occurrence=expected_occurrence) 8989 # verify no crash 8990 res = mq(*example_inputs) 8991 8992 def test_pixel_shuffle(self): 8993 class MyBias(nn.Module): 8994 def __init__(self) -> None: 8995 super().__init__() 8996 self.bias = nn.Parameter(torch.randn(8)) 8997 8998 class MyModel(nn.Module): 8999 def __init__(self) -> None: 9000 super().__init__() 9001 self.conv = nn.Conv2d(8, 8, 1, bias=False) 9002 self.bias = MyBias() 9003 9004 def forward(self, x): 9005 x = self.conv(x) 9006 x = nn.functional.pixel_shuffle(x, 2) 9007 x = x.view(-1, 8, 2, 2) 9008 bias = self.bias.bias 9009 return x + bias 9010 9011 backend_config = get_qnnpack_backend_config() 9012 qconfig_mapping = get_default_qconfig_mapping("qnnpack") 9013 model = MyModel() 9014 m = prepare_fx( 9015 model, 9016 qconfig_mapping=qconfig_mapping, 9017 example_inputs=(torch.randn(1, 8, 3, 3),), 9018 backend_config=backend_config 9019 ) 9020 m = convert_fx(m) 9021 expected_occurrence = { 9022 ns.call_function(torch.quantize_per_tensor): 2, 9023 ns.call_method("dequantize"): 1, 9024 } 9025 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) 9026 9027 def test_pixel_shuffle_module(self) -> None: 9028 class MyBias(nn.Module): 9029 def __init__(self) -> None: 9030 super().__init__() 9031 self.bias = nn.Parameter(torch.randn(8)) 9032 9033 class MyModel(nn.Module): 9034 def __init__(self) -> None: 9035 super().__init__() 9036 self.conv = nn.Conv2d(8, 8, 1, bias=False) 9037 self.ps = nn.PixelShuffle(upscale_factor=2) 9038 self.bias = MyBias() 9039 9040 def forward(self, x): 9041 x = self.conv(x) 9042 x = self.ps(x) 9043 x = x.view(-1, 8, 2, 2) 9044 bias = self.bias.bias 9045 return x + bias 9046 9047 backend_config = get_qnnpack_backend_config() 9048 qconfig_mapping = get_default_qconfig_mapping("qnnpack") 9049 model = MyModel() 9050 m = prepare_fx( 9051 model, 9052 qconfig_mapping=qconfig_mapping, 9053 example_inputs=(torch.randn(1, 8, 3, 3),), 9054 backend_config=backend_config 9055 ) 9056 m = convert_fx(m) 9057 expected_occurrence = { 9058 ns.call_function(torch.quantize_per_tensor): 2, 9059 ns.call_method("dequantize"): 1, 9060 ns.call_module(nn.PixelShuffle): 1, 9061 } 9062 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) 9063 9064 def test_pixel_unshuffle(self): 9065 class MyBias(nn.Module): 9066 def __init__(self) -> None: 9067 super().__init__() 9068 self.bias = nn.Parameter(torch.randn(64)) 9069 9070 class MyModel(nn.Module): 9071 def __init__(self) -> None: 9072 super().__init__() 9073 self.conv = nn.Conv2d(8, 8, 1, bias=False) 9074 self.bias = MyBias() 9075 9076 def forward(self, x): 9077 x = self.conv(x) 9078 x = nn.functional.pixel_unshuffle(x, 2) 9079 bias = self.bias.bias 9080 return x + bias 9081 9082 for backend in ["fbgemm", "qnnpack"]: 9083 if backend == "fbgemm": 9084 backend_config = get_fbgemm_backend_config() 9085 else: 9086 backend_config = get_qnnpack_backend_config() 9087 qconfig_mapping = get_default_qconfig_mapping(backend) 9088 model = MyModel() 9089 m = prepare_fx( 9090 model, 9091 qconfig_mapping=qconfig_mapping, 9092 example_inputs=(torch.randn(1, 8, 6, 6),), 9093 backend_config=backend_config 9094 ) 9095 m = convert_fx(m) 9096 expected_occurrence = { 9097 ns.call_function(torch.quantize_per_tensor): 2, 9098 ns.call_method("dequantize"): 1, 9099 } 9100 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) 9101 9102 def test_pixel_unshuffle_module(self) -> None: 9103 class MyBias(nn.Module): 9104 def __init__(self) -> None: 9105 super().__init__() 9106 self.bias = nn.Parameter(torch.randn(64)) 9107 9108 class MyModel(nn.Module): 9109 def __init__(self) -> None: 9110 super().__init__() 9111 self.conv = nn.Conv2d(8, 8, 1, bias=False) 9112 self.unshuffle = nn.PixelUnshuffle(downscale_factor=2) 9113 self.bias = MyBias() 9114 9115 def forward(self, x): 9116 x = self.conv(x) 9117 x = self.unshuffle(x) 9118 bias = self.bias.bias 9119 return x + bias 9120 9121 for backend in ["fbgemm", "qnnpack"]: 9122 if backend == "fbgemm": 9123 backend_config = get_fbgemm_backend_config() 9124 else: 9125 backend_config = get_qnnpack_backend_config() 9126 qconfig_mapping = get_default_qconfig_mapping(backend) 9127 model = MyModel() 9128 m = prepare_fx( 9129 model, 9130 qconfig_mapping=qconfig_mapping, 9131 example_inputs=(torch.randn(1, 8, 6, 6),), 9132 backend_config=backend_config 9133 ) 9134 m = convert_fx(m) 9135 expected_occurrence = { 9136 ns.call_function(torch.quantize_per_tensor): 2, 9137 ns.call_method("dequantize"): 1, 9138 ns.call_module(nn.PixelUnshuffle): 1, 9139 } 9140 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) 9141 9142 9143 9144 def test_narrow(self): 9145 class MyBias(nn.Module): 9146 def __init__(self) -> None: 9147 super().__init__() 9148 self.bias = nn.Parameter(torch.randn(4)) 9149 9150 class MyModel(nn.Module): 9151 def __init__(self) -> None: 9152 super().__init__() 9153 self.conv = nn.Conv2d(8, 8, 1, bias=False) 9154 self.bias = MyBias() 9155 9156 def forward(self, x): 9157 x = self.conv(x) 9158 x = torch.narrow(x, 1, 0, 4) 9159 bias = self.bias.bias 9160 return x + bias 9161 9162 for backend in ["fbgemm", "qnnpack"]: 9163 if backend == "fbgemm": 9164 backend_config = get_fbgemm_backend_config() 9165 else: 9166 backend_config = get_qnnpack_backend_config() 9167 qconfig_mapping = get_default_qconfig_mapping(backend) 9168 model = MyModel() 9169 m = prepare_fx( 9170 model, 9171 qconfig_mapping=qconfig_mapping, 9172 example_inputs=(torch.randn(1, 8, 3, 3),), 9173 backend_config=backend_config 9174 ) 9175 m = convert_fx(m) 9176 expected_occurrence = { 9177 ns.call_function(torch.quantize_per_tensor): 2, 9178 ns.call_method("dequantize"): 1, 9179 } 9180 self.checkGraphModuleNodes(m, expected_node_occurrence=expected_occurrence) 9181 9182class TestQuantizeFxModels(QuantizationTestCase): 9183 @skipIfNoFBGEMM 9184 @unittest.skipIf(not TEST_CUDA, "gpu is not available.") 9185 def test_static_gpu_convert_basic(self): 9186 9187 class Net(nn.Module): 9188 def __init__(self) -> None: 9189 super().__init__() 9190 self.relu1 = nn.ReLU() 9191 self.conv1 = nn.Conv2d(1, 6, 5) 9192 self.linear1 = nn.Linear(120, 1) 9193 9194 def forward(self, x): 9195 x = self.relu1(self.conv1(x)) 9196 y = self.linear1(x.view(-1)) 9197 return y 9198 9199 input = torch.randn((5, 1, 6, 6)).to('cuda') 9200 example_inputs = (input,) 9201 model = Net().to('cuda').eval() 9202 qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} 9203 model_prepared = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) 9204 model_prepared(*example_inputs) 9205 model_quantized = convert_to_reference_fx(model_prepared) 9206 out = model_quantized(*example_inputs) 9207 self.assertEqual(out.device.type, 'cuda') 9208 9209 @skipIfNoFBGEMM 9210 @unittest.skipIf(not TEST_CUDA, "gpu is not available.") 9211 def test_switch_device_prepare_convert(self): 9212 9213 class Net(nn.Module): 9214 def __init__(self) -> None: 9215 super().__init__() 9216 self.relu1 = nn.ReLU() 9217 self.conv1 = nn.Conv2d(1, 6, 5) 9218 self.linear1 = nn.Linear(120, 1) 9219 9220 def forward(self, x): 9221 x = self.relu1(self.conv1(x)) 9222 y = self.linear1(x.view(-1)) 9223 return y 9224 9225 for device in ['cuda', 'cpu']: 9226 device_after = 'cuda' if device == 'cpu' else 'cpu' 9227 input = torch.randn((5, 1, 6, 6)).to(device) 9228 model = Net().to(device).eval() 9229 qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} 9230 model_prepared = prepare_fx(model, qconfig_dict, example_inputs=(input,)) 9231 model_prepared(input) 9232 model_prepared.to(device_after) 9233 model_quantized = convert_to_reference_fx(model_prepared) 9234 out = model_quantized(input.to(device_after)) 9235 self.assertEqual(out.device.type, device_after) 9236 9237 @skipIfNoFBGEMM 9238 @unittest.skipIf(not TEST_CUDA, "gpu is not available.") 9239 def test_prepare_serialize_switch_device_convert(self): 9240 class Net(nn.Module): 9241 def __init__(self) -> None: 9242 super().__init__() 9243 self.conv1 = nn.Conv2d(1, 6, 5) 9244 self.linear1 = nn.Linear(120, 1) 9245 9246 def forward(self, x): 9247 x = self.conv1(x) 9248 y = self.linear1(x.view(-1)) 9249 return y 9250 9251 for device in ['cuda', 'cpu']: 9252 for device_after in ['cuda', 'cpu']: 9253 input = torch.randn((5, 1, 6, 6)).to(device) 9254 model = Net().to(device).eval() 9255 qconfig_dict = {"": torch.ao.quantization.get_default_qconfig('fbgemm')} 9256 model_prepared_first = prepare_fx(model, qconfig_dict, example_inputs=(input,)) 9257 model_prepared_second = prepare_fx(model, qconfig_dict, example_inputs=(input,)) 9258 model_prepared_first(input) 9259 state_dict = model_prepared_first.state_dict() 9260 del model_prepared_first 9261 model_prepared_second.load_state_dict(state_dict) 9262 model_prepared_second.to(device_after) 9263 model_quantized = convert_to_reference_fx(model_prepared_second) 9264 out = model_quantized(input.to(device_after)) 9265 self.assertEqual(out.device.type, device_after) 9266 9267 @skipIfTorchDynamo("too slow") 9268 @skip_if_no_torchvision 9269 def test_model_dropout(self): 9270 from torchvision import models 9271 m = models.mobilenet_v3_small() 9272 qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping('fbgemm') 9273 example_inputs = (torch.randn(1, 3, 224, 224),) 9274 mp = prepare_qat_fx(m, qconfig_mapping, example_inputs=example_inputs) 9275 mp(*example_inputs) 9276 with override_quantized_engine("qnnpack") if IS_ARM64 else contextlib.nullcontext(): 9277 mq = convert_fx(mp) 9278 mq(*example_inputs) 9279 9280 def _test_model_impl( 9281 self, mode, name, model, eager_quantizable_model, 9282 check_with_eager=True, 9283 diff_of_quant=None, 9284 diff_from_eager=None): 9285 if diff_of_quant is None or diff_from_eager is None: 9286 diff_of_quant = {} 9287 diff_from_eager = {} 9288 9289 if mode not in diff_of_quant or mode not in diff_from_eager: 9290 diff_of_quant[mode] = {} 9291 diff_from_eager[mode] = {} 9292 9293 input_tensor = torch.rand(1, 3, 224, 224) 9294 input_tensor_inception = torch.rand(1, 3, 299, 299) 9295 output_value = torch.randint(0, 1, (1,)) 9296 9297 # print('quantizing:', name, ' mode:', mode) 9298 if name == 'inception_v3': 9299 input_value = input_tensor_inception 9300 else: 9301 input_value = input_tensor 9302 9303 qconfig = default_qconfig if mode == 'static' else default_qat_qconfig 9304 qconfig_dict = {'': qconfig} 9305 script = torch.jit.script(model) 9306 9307 # make sure graph module and script module are both runanble 9308 original_out = model(input_value) 9309 is_not_tuple_out = not isinstance(original_out, tuple) 9310 script_out = script(input_value) 9311 9312 # set to train just before quantization 9313 prepare_fx_fn = prepare_fx 9314 if mode != 'static': 9315 model.train() 9316 prepare_fx_fn = prepare_qat_fx 9317 9318 prepared = prepare_fx_fn(model, qconfig_dict) 9319 9320 if mode == 'ddp': 9321 mp.spawn(run_ddp, 9322 args=(world_size, prepared), # noqa: F821 9323 nprocs=world_size, # noqa: F821 9324 join=True) 9325 elif mode == 'qat': 9326 assert prepared.training, 'prepared must be in training mode for qat' 9327 optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) 9328 criterion = nn.CrossEntropyLoss() 9329 train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) 9330 else: 9331 for i in range(10): 9332 prepared(input_value) 9333 9334 # print('after observation root:', prepared.root) 9335 9336 qgraph = convert_fx(prepared) 9337 # print('after quantization root:', qgraph.root) 9338 # print('after quantization code:', qgraph.src) 9339 qgraph.eval() 9340 qgraph_script = torch.jit.script(qgraph) 9341 # print('quantized and scripted:', qgraph_script.graph) 9342 9343 qgraph_out = qgraph(input_value) 9344 qgraph_script = qgraph_script(input_value) 9345 9346 if is_not_tuple_out: 9347 diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max() 9348 assert torch.allclose(qgraph_out, qgraph_script), 'graph, scripted graph' 9349 else: 9350 print('tuple output') 9351 9352 if eager_quantizable_model is not None: 9353 # comparing to eager mode quantization 9354 qeager = eager_quantizable_model 9355 ref_out = qeager(input_value) 9356 qeager.qconfig = qconfig 9357 if mode == 'static': 9358 qeager.fuse_model() 9359 prepare(qeager, inplace=True) 9360 else: 9361 qeager.train() 9362 qeager.fuse_model() 9363 prepare_qat(qeager, inplace=True) 9364 9365 # calibration 9366 if mode == 'ddp': 9367 mp.spawn(run_ddp, 9368 args=(world_size, qeager), # noqa: F821 9369 nprocs=world_size, # noqa: F821 9370 join=True) 9371 elif mode == 'qat': 9372 assert qeager.training, 'qeager should be in training mode for qat' 9373 optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) 9374 train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) 9375 else: 9376 for i in range(10): 9377 qeager(input_value) 9378 9379 # print('ref after observation:', qeager) 9380 9381 convert(qeager, inplace=True) 9382 qeager.eval() 9383 9384 # print('ref after quantization:', qeager) 9385 qeager_out = qeager(input_value) 9386 qeager_script = torch.jit.script(qeager) 9387 qscript_out = qeager_script(input_value) 9388 if is_not_tuple_out: 9389 diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max() 9390 if check_with_eager: 9391 self.assertEqual(diff_from_eager[mode][name], 0, 9392 'Result of graph mode quantization and ' + 9393 'eager mode quantization on model: ' + name + 9394 ' should match. Mode: ' + mode + 9395 ' diff:' + str(diff_from_eager[mode][name])) 9396 9397 def _test_building_block(self, quant_type, BB): 9398 eager = BB().float() 9399 graph = copy.deepcopy(eager) 9400 9401 if quant_type == QuantType.STATIC: 9402 qconfig = default_qconfig 9403 eager_prepare = prepare 9404 graph_prepare = prepare_fx 9405 eager.eval() 9406 graph.eval() 9407 calibrate_or_train = test_only_eval_fn 9408 data = self.img_data_2d 9409 is_qat = False 9410 else: 9411 assert quant_type == QuantType.QAT 9412 qconfig = default_qat_qconfig 9413 eager_prepare = prepare_qat 9414 graph_prepare = prepare_qat_fx 9415 eager.train() 9416 graph.train() 9417 calibrate_or_train = test_only_train_fn 9418 data = self.img_data_2d_train 9419 is_qat = True 9420 9421 if hasattr(eager, "fuse_model"): 9422 eager.fuse_model() 9423 eager = QuantWrapper(eager) 9424 eager.qconfig = qconfig 9425 eager = eager_prepare(eager) 9426 9427 qconfig_dict = {"": qconfig} 9428 graph = graph_prepare(graph, qconfig_dict, example_inputs=(data[0][0],)) 9429 9430 eager_out = eager(data[0][0]) 9431 graph_out = graph(data[0][0]) 9432 # Eager Mode and FX Graph Mode QAT now differ in numerics both 9433 # in Post Training and QAT because FX Graph Mode uses same fake_quant instances 9434 # for input and output of CopyNode 9435 # self.assertEqual(eager_out, graph_out) 9436 9437 calibrate_or_train(eager, data) 9438 calibrate_or_train(graph, data) 9439 9440 eager = convert(eager) 9441 graph = convert_fx(graph) 9442 9443 eager_out = eager(data[0][0]) 9444 graph_out = graph(data[0][0]) 9445 9446 @override_qengines 9447 def test_resnet_base(self): 9448 models = [ResNetBase] 9449 options = itertools.product(self.static_quant_types, models) 9450 for quant_type, M in options: 9451 self._test_building_block(quant_type, M) 9452 9453 @skip_if_no_torchvision 9454 @skipIfNoFBGEMM 9455 @unittest.skip("skip for now since tbb failed") 9456 def test_torchvision(self): 9457 from torchvision import models 9458 from torchvision.models import quantization as quantized_models 9459 from torchvision.models.quantization.utils import _replace_relu 9460 9461 def get_available_classification_models(models): 9462 return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] 9463 9464 model_list = get_available_classification_models(models) 9465 quantized_model_list = get_available_classification_models(quantized_models) 9466 9467 quantized_model_list = set(quantized_model_list) 9468 # test eager and graph consistency 9469 model_list = quantized_model_list 9470 # mobilenet/inception_v3/googlenet qat is not working due to AdaptiveAveragePool qat 9471 # we might observe the output of AdaptiveAveragePool in the future 9472 # and re-enable the test 9473 fx_eager_not_matching = [ 9474 ("mobilenet_v2", "qat"), 9475 ("inception_v3", "qat"), 9476 ("googlenet", "qat") 9477 ] # because relu6 is replaced as relu in mobilenetv2 9478 9479 diff_of_quant = {} 9480 diff_from_eager = {} 9481 modes = ['static', 'qat'] 9482 options = itertools.product(modes, model_list) 9483 for mode, name in options: 9484 pretrained = name in quantized_model_list # load pretrained model to compare with quantized model 9485 kwargs = {} 9486 # turn off transform input for inception_v3 since 9487 # it's not quantized in eager mode and in fx graph 9488 # mode we can't skip quantizing a method right now 9489 # (might be supported in the future) 9490 if name in ["inception_v3", "googlenet"]: 9491 kwargs["transform_input"] = False 9492 eager_quantizable_model = None 9493 if name in quantized_model_list: 9494 eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False, **kwargs).eval().float() 9495 # compare with eager mode quantized model when it is available 9496 pretrained = eager_quantizable_model is not None 9497 model = models.__dict__[name](pretrained=pretrained, **kwargs).eval().float() 9498 if name == "mobilenet_v2": 9499 _replace_relu(model) 9500 # disable aux logits 9501 if hasattr(model, "aux_logits"): 9502 model.aux_logits = False 9503 model.AuxLogits = None 9504 if eager_quantizable_model: 9505 eager_quantizable_model.aux_logits = False 9506 eager_quantizable_model.AuxLogits = None 9507 9508 check_with_eager = (name, mode) not in fx_eager_not_matching 9509 self._test_model_impl( 9510 mode, name, model, eager_quantizable_model, 9511 check_with_eager, 9512 diff_of_quant, diff_from_eager) 9513 9514 def print_diffs(diffs): 9515 for mode, diffs_for_mode in diffs.items(): 9516 print('mode:', mode) 9517 for name, diff in diffs_for_mode.items(): 9518 print(name, ':', diff) 9519 9520 # print('differences between float and quantized') 9521 # print_diffs(diff_of_quant) 9522 # print('----------------------') 9523 # print('differences between graph mode and eager mode') 9524 # print_diffs(diff_from_eager) 9525 # print('----------------------') 9526 9527 @skip_if_no_torchvision 9528 @skipIfNoFBGEMM 9529 @unittest.skip("TODO: Test is always failing - https://github.com/pytorch/pytorch/issues/54979") 9530 def test_resnet18_ddp(self): 9531 from torchvision import models 9532 from torchvision.models import quantization as quantized_models 9533 eager_quantizable_model = quantized_models.__dict__[name](pretrained=False, quantize=False).eval().float() # noqa: F821 9534 model = models.__dict__[name](pretrained=False).eval().float() # noqa: F821 9535 self._test_model_impl( 9536 'ddp', 'resnet18', model, eager_quantizable_model) 9537 9538 @override_qengines 9539 def test_qat_embeddingbag_linear(self): 9540 for device in get_supported_device_types(): 9541 class EmbeddingBagLinear(torch.nn.Module): 9542 def __init__(self) -> None: 9543 super().__init__() 9544 self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') 9545 self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) 9546 9547 def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, 9548 per_sample_weights: Optional[torch.Tensor] = None): 9549 x = self.emb(input, offsets, per_sample_weights) 9550 x = self.linear(x) 9551 return x 9552 9553 qengine = torch.backends.quantized.engine 9554 qconfig_dict = QConfigMapping() \ 9555 .set_global(get_default_qat_qconfig(qengine)) \ 9556 .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) 9557 9558 train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] 9559 eval_output = [[torch.randint(0, 10, (12, 1))]] 9560 9561 model = EmbeddingBagLinear().train() 9562 prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) 9563 test_only_train_fn(prepared_fx_model, train_indices) 9564 quant_model = convert_fx(prepared_fx_model, 9565 qconfig_mapping=qconfig_dict) 9566 9567 def checkQuantized(model): 9568 # Make sure EmbeddingBag is now a quantized EmbeddingBag. 9569 self.assertTrue(type(model.emb), nn.quantized.EmbeddingBag) 9570 # Also test that Linear has been quantized. 9571 self.assertTrue(type(model.linear), nnq.Linear) 9572 9573 test_only_eval_fn(model, eval_output) 9574 self.checkScriptable(model, eval_output) 9575 self.checkNoQconfig(model) 9576 checkQuantized(quant_model) 9577 9578 9579 @override_qengines 9580 def test_qat_embedding_linear(self): 9581 for device in get_supported_device_types(): 9582 class EmbeddingLinear(torch.nn.Module): 9583 def __init__(self) -> None: 9584 super().__init__() 9585 self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12) 9586 self.linear = torch.nn.Linear(12, 1).to(dtype=torch.float) 9587 9588 def forward(self, input: torch.Tensor): 9589 x = torch.sum(self.emb(input), dim=1) 9590 x = self.linear(x) 9591 return x 9592 9593 qengine = torch.backends.quantized.engine 9594 qconfig_dict = {"": get_default_qat_qconfig(qengine), 9595 "object_type": [(torch.nn.Embedding, default_embedding_qat_qconfig)]} 9596 9597 9598 train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] 9599 eval_output = [[torch.randint(0, 10, (12, 1))]] 9600 9601 model = EmbeddingLinear().train() 9602 prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) 9603 test_only_train_fn(prepared_fx_model, train_indices) 9604 quant_model = convert_fx(prepared_fx_model, 9605 qconfig_mapping=qconfig_dict) 9606 9607 def checkQuantized(model): 9608 # Make sure EmbeddingBag is now a quantized EmbeddingBag. 9609 self.assertTrue(type(model.emb), nn.quantized.Embedding) 9610 # Also test that Linear has been quantized. 9611 self.assertTrue(type(model.linear), nnq.Linear) 9612 9613 test_only_eval_fn(model, eval_output) 9614 self.checkScriptable(model, eval_output) 9615 self.checkNoQconfig(model) 9616 checkQuantized(quant_model) 9617 9618 @given( 9619 device=st.sampled_from( 9620 ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] 9621 ) 9622 ) 9623 @settings(deadline=None) 9624 @override_qengines 9625 def test_qat_functional_linear(self, device): 9626 if torch.backends.quantized.engine not in ('fbgemm', 'qnnpack'): 9627 return 9628 9629 class Linear(torch.nn.Module): 9630 def __init__(self) -> None: 9631 super().__init__() 9632 self.w = torch.ones(5, 5) 9633 self.b = torch.zeros(5) 9634 9635 def forward(self, x): 9636 return torch.nn.functional.linear(x, self.w, self.b) 9637 9638 class M(torch.nn.Module): 9639 def __init__(self) -> None: 9640 super().__init__() 9641 self.mods1 = torch.nn.Sequential(Linear(), Linear()) 9642 self.mods2 = Linear() 9643 9644 def forward(self, x): 9645 x = self.mods1(x) 9646 x = self.mods2(x) 9647 return x 9648 9649 model = M().train() 9650 ref_fake_quant = FakeQuantize.with_args( 9651 observer=MovingAverageMinMaxObserver, 9652 quant_min=0, 9653 quant_max=255, 9654 dtype=torch.quint8, 9655 reduce_range=False, 9656 ) 9657 ref_weight_fake_quant = FakeQuantize.with_args( 9658 observer=MovingAverageMinMaxObserver, 9659 quant_min=-128, 9660 quant_max=127, 9661 dtype=torch.qint8, 9662 reduce_range=False, 9663 ) 9664 ref_qat_qconfig = QConfig( 9665 activation=ref_fake_quant, weight=ref_weight_fake_quant 9666 ) 9667 qconfig_dict = {"": ref_qat_qconfig} 9668 example_inputs = (torch.randn(1, 5),) 9669 prepared_ref = prepare_qat_fx(model, qconfig_dict, example_inputs=example_inputs) 9670 9671 custom_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( 9672 observer=MovingAverageMinMaxObserver, 9673 quant_min=0, 9674 quant_max=255, 9675 dtype=torch.quint8, 9676 reduce_range=False, 9677 ) 9678 custom_weight_fake_quant = FusedMovingAvgObsFakeQuantize.with_args( 9679 observer=MovingAverageMinMaxObserver, 9680 quant_min=-128, 9681 quant_max=127, 9682 dtype=torch.qint8, 9683 reduce_range=False, 9684 ) 9685 custom_qconfig = QConfig( 9686 activation=custom_fake_quant, weight=custom_weight_fake_quant 9687 ) 9688 custom_qconfig_dict = {"": custom_qconfig} 9689 prepared = prepare_qat_fx(model, custom_qconfig_dict, example_inputs=example_inputs) 9690 9691 prepared.to(device) 9692 prepared_ref.to(device) 9693 9694 prepared.apply(torch.ao.quantization.disable_fake_quant) 9695 prepared.apply(torch.ao.quantization.disable_observer) 9696 prepared_ref.apply(torch.ao.quantization.disable_fake_quant) 9697 prepared_ref.apply(torch.ao.quantization.disable_observer) 9698 9699 inp = torch.randn(5, 5, device=device, requires_grad=True) 9700 for i in range(10): 9701 if i == 2: 9702 prepared.apply(torch.ao.quantization.enable_observer) 9703 prepared_ref.apply(torch.ao.quantization.enable_observer) 9704 if i == 4: 9705 prepared.apply(torch.ao.quantization.enable_fake_quant) 9706 prepared_ref.apply(torch.ao.quantization.enable_fake_quant) 9707 9708 inp = torch.randn(5, 5, device=device, requires_grad=True) 9709 out_ref = prepared_ref(inp) 9710 out = prepared(inp) 9711 torch.testing.assert_close(out, out_ref) 9712 9713 # try backward pass 9714 labels = torch.randn(5, 5, device=device) 9715 loss = (out - labels).sum() 9716 grad = torch.autograd.grad(loss, [inp]) 9717 loss_ref = (out_ref - labels).sum() 9718 grad_ref = torch.autograd.grad(loss_ref, [inp]) 9719 torch.testing.assert_close(grad[0], grad_ref[0]) 9720 9721 if 'fbgemm' in torch.backends.quantized.supported_engines: 9722 # During the lowering step in convert, fold_weight calls quantized::linear_prepack 9723 # which doesn't support QuantizedCuda backend 9724 prepared.cpu() 9725 prepared_ref.cpu() 9726 converted = convert_fx(prepared) 9727 converted_ref = convert_fx(prepared_ref) 9728 inp = torch.rand(5, 5) 9729 out = converted(inp) 9730 out_ref = converted_ref(inp) 9731 9732 torch.testing.assert_close(out, out_ref) 9733if __name__ == '__main__': 9734 raise RuntimeError("This test file is not meant to be run directly, use:\n\n" 9735 "\tpython test/test_quantization.py TESTNAME\n\n" 9736 "instead.") 9737