# mypy: allow-untyped-defs
import copy
import operator
from collections import namedtuple
from typing import Callable, Dict, List, Union

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import (
    _sequential_wrapper2,
    fuse_conv_bn,
    fuse_conv_bn_relu,
    fuse_convtranspose_bn,
    fuse_linear_bn,
)

from .backend_config import (
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType,
)


__all__: List[str] = []

# TODO: rename to be more explicit, e.g. qat_conv_relu
_ConvMetadata = namedtuple(
    "_ConvMetadata",
    [
        "root",
        "transpose",
        "bn",
        "reference",
        "transpose_reference",
        "fused_conv_relu",
        "fused_conv_bn",
        "fused_conv_bn_relu",
        "qat",
        "relu_qat",
        "bn_qat",
        "bn_relu_qat",
        "func",
        "func_transpose",
    ],
)
_Conv1dMetadata = _ConvMetadata(
    nn.Conv1d,
    nn.ConvTranspose1d,
    nn.BatchNorm1d,
    nnqr.Conv1d,
    nnqr.ConvTranspose1d,
    nni.ConvReLU1d,
    nni.ConvBn1d,
    nni.ConvBnReLU1d,
    nnqat.Conv1d,
    nniqat.ConvReLU1d,
    nniqat.ConvBn1d,
    nniqat.ConvBnReLU1d,
    F.conv1d,
    F.conv_transpose1d,
)
_Conv2dMetadata = _ConvMetadata(
    nn.Conv2d,
    nn.ConvTranspose2d,
    nn.BatchNorm2d,
    nnqr.Conv2d,
    nnqr.ConvTranspose2d,
    nni.ConvReLU2d,
    nni.ConvBn2d,
    nni.ConvBnReLU2d,
    nnqat.Conv2d,
    nniqat.ConvReLU2d,
    nniqat.ConvBn2d,
    nniqat.ConvBnReLU2d,
    F.conv2d,
    F.conv_transpose2d,
)
_Conv3dMetadata = _ConvMetadata(
    nn.Conv3d,
    nn.ConvTranspose3d,
    nn.BatchNorm3d,
    nnqr.Conv3d,
    nnqr.ConvTranspose3d,
    nni.ConvReLU3d,
    nni.ConvBn3d,
    nni.ConvBnReLU3d,
    nnqat.Conv3d,
    nniqat.ConvReLU3d,
    nniqat.ConvBn3d,
    nniqat.ConvBnReLU3d,
    F.conv3d,
    F.conv_transpose3d,
)

# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values
# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh
_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints(
    dtype=torch.quint8,
    quant_min_lower_bound=0,
    quant_max_upper_bound=255,
    scale_exact_match=1.0 / 256.0,
    zero_point_exact_match=0,
)
_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints(
    dtype=torch.quint8,
    quant_min_lower_bound=0,
    quant_max_upper_bound=255,
    scale_exact_match=2.0 / 256.0,
    zero_point_exact_match=128,
)
_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = {
    torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
}


def _get_binary_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    binary_op_configs: List[BackendPatternConfig] = []
    num_tensor_args_to_observation_type_mapping = {
        # TODO: this is not used right now since we have an extra check in prepare;
        # we will need to change this to NO_OBSERVER later after we implement
        # Tensor dtype inference properly
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    for op_with_quantized_bop_scalar_variant in [
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
    ]:
        bop_patterns = [
            (op_with_quantized_bop_scalar_variant, nn.ReLU),
            (op_with_quantized_bop_scalar_variant, F.relu),
            (op_with_quantized_bop_scalar_variant, torch.relu),
            op_with_quantized_bop_scalar_variant,
        ]
        for bop_pattern in bop_patterns:
            binary_op_configs.append(
                BackendPatternConfig(bop_pattern)
                .set_dtype_configs(dtype_configs)  # noqa: E131
                ._set_num_tensor_args_to_observation_type(
                    num_tensor_args_to_observation_type_mapping
                )
            )
    # matmul
    binary_op_configs.append(
        BackendPatternConfig(torch.matmul).set_dtype_configs(
            dtype_configs
        )  # noqa: E131
    )
    return binary_op_configs
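

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream API): what the `dtype_configs`
# argument passed to the helpers in this file typically looks like. The
# quint8/qint8 combination below is an assumption for demonstration; real
# backends supply their own DTypeConfigs.
# ---------------------------------------------------------------------------
def _example_binary_op_configs() -> List[BackendPatternConfig]:
    example_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,   # quantized activations in
        output_dtype=torch.quint8,  # quantized activations out
        weight_dtype=torch.qint8,   # quantized weights
        bias_dtype=torch.float,     # bias stays in fp32
    )
    # Each returned config pairs a pattern such as (torch.add, nn.ReLU) with
    # the dtypes above and the num-tensor-args -> observation type mapping
    # defined in _get_binary_op_configs.
    return _get_binary_op_configs([example_dtype_config])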


def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    """
    Return all configs related to linear modules and ops.
    """
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    linear_configs: List[BackendPatternConfig] = []

    # (1) Single linear modules/functions
    # -------------------------------------
    # linear module
    linear_configs.append(
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nnqat.Linear)
    )
    # linear qat module
    linear_configs.append(
        BackendPatternConfig(nnqat.Linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    # functional linear
    linear_configs.append(
        BackendPatternConfig(torch.nn.functional.linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 1, "bias": 2})
    )

    # (2) Linear + relu
    # -------------------
    # 2.1 linear module + relu fusion config
    # linear relu, linear module + relu module
    linear_configs.append(
        BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
        .set_fused_module(nni.LinearReLU)
    )
    # linear relu, linear module + functional relu
    linear_configs.append(
        BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
        .set_fused_module(nni.LinearReLU)
    )

    # 2.2 linear module + relu, fused module configs
    # linear relu, fused module
    linear_configs.append(
        BackendPatternConfig(nni.LinearReLU)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nniqat.LinearReLU)
    )
    # linear relu, qat fused module
    linear_configs.append(
        BackendPatternConfig(nniqat.LinearReLU)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    # 2.3 functional linear + relu configs
    # linear relu, functional linear + relu module
    linear_configs.append(
        BackendPatternConfig((F.linear, torch.nn.ReLU))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    # linear relu, functional linear + functional relu
    linear_configs.append(
        BackendPatternConfig((F.linear, F.relu))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )

    # (3) Linear + batchnorm
    # ------------------------
    # 3.1 linear bn fusion
    linear_configs.append(
        BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuse_linear_bn)
        .set_fused_module(nni.LinearBn1d)
    )

    # 3.2 linear bn fused
    # linear bn, fused module
    linear_configs.append(
        BackendPatternConfig(nni.LinearBn1d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nniqat.LinearBn1d)
    )
    # linear bn, qat fused module
    linear_configs.append(
        BackendPatternConfig(nniqat.LinearBn1d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    return linear_configs
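

# ---------------------------------------------------------------------------
# Illustrative sketch (assumptions: a small float model with a 16-feature
# linear layer, and the default qconfig mapping being compatible with the
# example dtype config). Shows how the linear configs above can be packaged
# into a BackendConfig and handed to the FX graph mode quantization flow.
# ---------------------------------------------------------------------------
def _example_prepare_linear_model(model: torch.nn.Module):
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx

    from .backend_config import BackendConfig

    example_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    backend_config = BackendConfig("example_backend").set_backend_pattern_configs(
        _get_linear_configs([example_dtype_config])
    )
    example_inputs = (torch.randn(1, 16),)  # assumed input shape
    return prepare_fx(
        model,
        get_default_qconfig_mapping(),
        example_inputs,
        backend_config=backend_config,
    )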


def _get_conv_configs(dtype_configs):
    """
    Return all configs related to conv modules and ops.
    """
    conv_configs = []
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
        # (1) Single conv modules/functions
        # -----------------------------------
        # conv module
        conv_configs.append(
            BackendPatternConfig(convs.root)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.qat)
        )
        # conv qat module
        conv_configs.append(
            BackendPatternConfig(convs.qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # functional conv
        conv_configs.append(
            BackendPatternConfig(convs.func)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )

        # (2) Conv + relu
        # -----------------
        # 2.1 conv module + relu fusion configs
        # conv relu fusion, conv module + relu module
        conv_configs.append(
            BackendPatternConfig((convs.root, torch.nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # conv relu fusion, conv module + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.root, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # 2.2 conv module + relu fused module configs
        # conv relu, fused module
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.relu_qat)
        )
        # conv relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # 2.3 functional conv + relu configs
        # conv relu, functional conv + relu module
        conv_configs.append(
            BackendPatternConfig((convs.func, torch.nn.ReLU))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
        # conv relu, functional conv + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.func, F.relu))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

        # fused conv relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.relu_qat)
        )

        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )

        # (3) Conv + batchnorm (+ relu)
        # -------------------------------
        # 3.1 conv bn fusion configs
        # conv + bn fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn)
            .set_fused_module(convs.fused_conv_bn)
        )
        # conv + bn + relu module fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # conv + bn + relu functional fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # TODO: we can add fusion for torch.relu as well

        # 3.2 conv + bn (+ relu) fused module configs
        # fused conv bn
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_qat)
        )

        # fused conv bn relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_relu_qat)
        )

        # conv bn, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # conv bn relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )

        # (4) conv transpose and its fusion
        # 4.1 conv transpose config
        conv_configs.append(
            BackendPatternConfig(convs.transpose)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.transpose)
            .set_reference_quantized_module(convs.transpose_reference)
        )

        # 4.2 conv transpose + bn fusion
        conv_configs.append(
            BackendPatternConfig((convs.transpose, convs.bn))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_convtranspose_bn)
            .set_root_module(convs.transpose)
            .set_reference_quantized_module(convs.transpose_reference)
        )

        # 4.3 functional conv transpose
        conv_configs.append(
            BackendPatternConfig(convs.func_transpose)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )

    return conv_configs


def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
    return (
        BackendPatternConfig(torch.cat)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .set_dtype_configs(dtype_configs)
    )


def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    ln_configs = []
    ln_configs.append(
        BackendPatternConfig(torch.nn.LayerNorm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    ln_configs.append(
        BackendPatternConfig(torch.nn.functional.layer_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 3})
    )
    return ln_configs


def _get_default_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    configs = []
    default_ops = [
        torch.nn.ELU,
        torch.nn.LeakyReLU,
        torch.nn.Hardswish,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.Dropout,
        torch.nn.PReLU,
        torch.nn.functional.elu,
        torch.nn.functional.hardswish,
        torch.nn.functional.leaky_relu,
        torch.nn.functional.dropout,
    ]
    for op in default_ops:
        configs.append(
            BackendPatternConfig(op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

    configs.append(
        BackendPatternConfig(torch.nn.functional.group_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 3})
    )

    configs.append(
        BackendPatternConfig(torch.nn.functional.instance_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 3, "bias": 4})
    )
    return configs
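

# ---------------------------------------------------------------------------
# Note (illustrative): `_set_input_type_to_index` maps an input kind to the
# positional index of that argument in the functional op, so the quantization
# flow can treat those inputs as weights/biases rather than activations.
# Based on the functional signatures:
#   F.linear(input, weight, bias)                        -> {"weight": 1, "bias": 2}
#   F.layer_norm(input, normalized_shape, weight, bias)  -> {"weight": 2, "bias": 3}
#   F.group_norm(input, num_groups, weight, bias)        -> {"weight": 2, "bias": 3}
#   F.instance_norm(input, running_mean, running_var,
#                   weight, bias, ...)                   -> {"weight": 3, "bias": 4}
# ---------------------------------------------------------------------------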


def _add_fixed_qparams_to_dtype_configs(
    dtype_configs: List[DTypeConfig],
    constraints: DTypeWithConstraints,
) -> List[DTypeConfig]:
    """
    Return a copy of the list of DTypeConfigs where activations are subject to the specified
    constraints required for fixed qparams ops.

    If the data type doesn't match the one in the constraints, simply leave the corresponding
    DTypeConfig unchanged.

    If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations,
    throw an exception since these settings are incompatible with fixed qparams ops.
    """
    new_dtype_configs = []
    for dtype_config in dtype_configs:
        dc = copy.deepcopy(dtype_config)
        for orig_constraints in [
            dc.input_dtype_with_constraints,
            dc.output_dtype_with_constraints,
        ]:
            if orig_constraints.dtype != constraints.dtype:
                continue
            if orig_constraints.scale_min_lower_bound is not None:
                raise ValueError(
                    f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}"
                )
            if orig_constraints.scale_max_upper_bound is not None:
                raise ValueError(
                    f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}"
                )
            orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound
            orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound
            orig_constraints.scale_exact_match = constraints.scale_exact_match
            orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match
        new_dtype_configs.append(dc)
    return new_dtype_configs


def _get_fixed_qparams_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    fixed_qparams_op_configs = []
    for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
        new_dtype_configs = _add_fixed_qparams_to_dtype_configs(
            dtype_configs, constraints
        )
        fixed_qparams_op_configs.append(
            BackendPatternConfig(fixed_qparam_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(new_dtype_configs)
        )
    return fixed_qparams_op_configs
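

# ---------------------------------------------------------------------------
# Illustrative check (a sketch, not used by the config machinery): why the
# fixed qparams above match the op output ranges. With affine quantization,
# dequant(q) = (q - zero_point) * scale for q in [quant_min, quant_max], so:
#   scale=1/256, zero_point=0   spans [0, 255/256]  ~ [0, 1]   (sigmoid-like)
#   scale=2/256, zero_point=128 spans [-1, 127/128] ~ [-1, 1]  (tanh-like)
# ---------------------------------------------------------------------------
def _example_fixed_qparams_ranges():
    for constraints in (
        _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
        _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    ):
        scale = constraints.scale_exact_match
        zero_point = constraints.zero_point_exact_match
        lo = (constraints.quant_min_lower_bound - zero_point) * scale
        hi = (constraints.quant_max_upper_bound - zero_point) * scale
        print(f"scale={scale}, zero_point={zero_point} -> [{lo}, {hi}]")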


def _get_share_qparams_op_configs(dtype_configs):
    """Get the operator configs for operators that work on both float and quantized input.
    If the input is quantized, the output Tensor shares the same quantization parameters
    with the input.
    Example operators: avgpool2d, reshape, transpose, maxpool2d
    Example observed pattern:
      observer_0 - avgpool2d - observer_0 (same observer instance as the input)
    """

    def _get_share_qparams_op_backend_config(op):
        return (
            BackendPatternConfig(op)
            .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
            .set_dtype_configs(dtype_configs)
        )

    share_qparams_ops = [
        torch.nn.AdaptiveAvgPool1d,
        torch.nn.AdaptiveAvgPool2d,
        torch.nn.AdaptiveAvgPool3d,
        torch.nn.AvgPool1d,
        torch.nn.AvgPool2d,
        torch.nn.AvgPool3d,
        torch.nn.Hardtanh,
        torch.nn.Identity,
        torch.nn.MaxPool1d,
        torch.nn.MaxPool2d,
        torch.nn.MaxPool3d,
        torch.nn.PixelShuffle,
        torch.nn.PixelUnshuffle,
        torch.nn.ReLU,
        torch.nn.ReLU6,
        torch.adaptive_avg_pool1d,
        torch.nn.functional.adaptive_avg_pool2d,
        torch.nn.functional.adaptive_avg_pool3d,
        torch.nn.functional.hardtanh,
        torch.nn.functional.hardtanh_,
        torch.nn.functional.interpolate,
        torch.nn.functional.max_pool1d,
        torch.nn.functional.max_pool2d,
        torch.nn.functional.max_pool3d,
        torch.nn.functional.pixel_shuffle,
        torch.nn.functional.pixel_unshuffle,
        torch.nn.functional.relu,
        torch.nn.functional.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.clamp,
        torch.flatten,
        torch.mean,
        torch.narrow,
        torch.repeat_interleave,
        torch.transpose,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.floordiv,
        "contiguous",
        "clamp",
        "detach",
        "detach_",
        "mean",
        "permute",
        "repeat",
        "repeat_interleave",
        "reshape",
        "resize_",
        "relu",
        "relu_",
        "squeeze",
        "squeeze_",
        "transpose",
        "unsqueeze",
        "unsqueeze_",
        "view",
    ]
    return [_get_share_qparams_op_backend_config(op) for op in share_qparams_ops]


def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    """Get configs related to batchnorm."""
    bn_configs = []
    bn_to_fused_bn = {
        torch.nn.BatchNorm2d: nni.BNReLU2d,
        torch.nn.BatchNorm3d: nni.BNReLU3d,
    }
    for bn in bn_to_fused_bn.keys():
        fused_bn = bn_to_fused_bn[bn]
        # bn module + relu module fusion config
        bn_configs.append(
            BackendPatternConfig((bn, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(fused_bn))
            .set_fused_module(fused_bn)
        )
        # bn module + F.relu fusion config
        bn_configs.append(
            BackendPatternConfig((bn, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(fused_bn))
            .set_fused_module(fused_bn)
        )
        bn_configs.append(
            BackendPatternConfig(bn)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

    # fused bn configs
    for fused_bn in bn_to_fused_bn.values():
        bn_configs.append(
            BackendPatternConfig(fused_bn)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
    return bn_configs


def _get_rnn_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    rnn_op_configs = []
    for rnn_op, ref_rnn_op in [
        (nn.GRUCell, nnqr.GRUCell),
        (nn.LSTMCell, nnqr.LSTMCell),
        (nn.RNNCell, nnqr.RNNCell),
        (nn.LSTM, nnqr.LSTM),
        (nn.GRU, nnqr.GRU),
    ]:
        rnn_op_configs.append(
            BackendPatternConfig(rnn_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(rnn_op)
            .set_reference_quantized_module(ref_rnn_op)
        )
    return rnn_op_configs


def _get_embedding_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    embedding_op_configs = []
    for embedding_op, qat_embedding_op, ref_embedding_op in [
        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
    ]:
        embedding_op_configs.append(
            BackendPatternConfig(embedding_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_qat_module(qat_embedding_op)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )

        # config for qat op
        embedding_op_configs.append(
            BackendPatternConfig(qat_embedding_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )
    return embedding_op_configs


def _get_tensor_info_op_configs(dtype_configs):
    """
    These ops work on tensors of different dtypes but return non-tensors
    containing information about the input tensor.
    """

    def _get_config(op):
        return (
            BackendPatternConfig(op)
            .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED)
            .set_dtype_configs(dtype_configs)
        )

    return [_get_config(op) for op in ("shape", "size")]
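

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical backend, assumed dtypes): how the helpers
# in this file are typically stitched together into a single BackendConfig.
# Real backends pass their own, per-pattern dtype configs to each helper.
# ---------------------------------------------------------------------------
def _example_full_backend_config():
    from .backend_config import BackendConfig

    dtype_configs = [
        DTypeConfig(
            input_dtype=torch.quint8,
            output_dtype=torch.quint8,
            weight_dtype=torch.qint8,
            bias_dtype=torch.float,
        )
    ]
    return (
        BackendConfig("example_backend")
        .set_backend_pattern_configs(_get_conv_configs(dtype_configs))
        .set_backend_pattern_configs(_get_linear_configs(dtype_configs))
        .set_backend_pattern_configs(_get_binary_op_configs(dtype_configs))
        .set_backend_pattern_config(_get_cat_config(dtype_configs))
        .set_backend_pattern_configs(_get_share_qparams_op_configs(dtype_configs))
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(dtype_configs))
        .set_backend_pattern_configs(_get_bn_configs(dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(dtype_configs))
        .set_backend_pattern_configs(_get_embedding_op_configs(dtype_configs))
    )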