import operator
from typing import Callable, Dict, List, Optional, Set, Tuple

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.quantization.fx._lower_to_native_backend as _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.backend_config import get_native_backend_config

from .ns_types import NSNodeTargetType


# shorthand for the quantized op namespace
toq = torch.ops.quantized


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
    """Return a mapping from a base name to a set of related op targets.

    Each set groups node targets (module classes, functions, and method-name
    strings) that are considered equivalent for matching purposes: the sets
    are seeded with hand-written groups of fp32 ops below, then extended with
    the fused / QAT / reference / quantized counterparts discovered from the
    native backend config and from the default lowering maps.

    The base names are arbitrary: each set is keyed by a stringified counter
    ("0", "1", ...), assigned in list order at the end of this function.
    """
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            "relu",
            "relu_",
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            "sigmoid",
            "sigmoid_",
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # functional transposed conv
        {
            F.conv_transpose1d,
        },
        {
            F.conv_transpose2d,
        },
        {
            F.conv_transpose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            "tanh_",
            "tanh",
        },
        # F.hardsigmoid
        {
            "hardsigmoid_",
            "hardsigmoid",
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
        # pixel shuffle
        {
            nn.PixelShuffle,
        },
        {
            F.pixel_shuffle,
        },
        # pixel unshuffle
        {
            nn.PixelUnshuffle,
        },
        {
            F.pixel_unshuffle,
        },
        # narrow
        {
            torch.narrow,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    # (fp32_op, related_op) pairs to be merged into the sets above
    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    # fused-module maps point at a pair of targets; connect the source to both
    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from default lowering path
    #

    for source, (
        target1,
        target2,
    ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps, ideally in the future this could be removed
    # after the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    # add the new connections from backend_config
    # note: a pair is merged into the first set which contains either member;
    # pairs matching no existing set are dropped
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break

    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}

    # assign each set an arbitrary, stable base name: its index as a string
    counter = 0
    for set_of_related_ops in sets_of_related_ops:
        base_name = str(counter)
        counter += 1
        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops

    return base_name_to_sets_of_related_ops


def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
    """Return the base name whose related-op set contains `op`, or None
    if `op` is not present in any set.
    """
    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
        if op in set_of_related_ops:
            return base_name
    return None


def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
    """Add `op` to the mapping, in place.

    If `related_op` is given, `op` is added to the set containing
    `related_op`; an AssertionError is raised if no set contains it.
    If `related_op` is None, `op` is placed in a new singleton set under
    the first unused stringified-counter key.
    """
    if related_op is not None:
        for set_of_related_ops in base_name_to_sets_of_related_ops.values():
            if related_op in set_of_related_ops:
                set_of_related_ops.add(op)
                return
        # if we got here, related_op was not found
        raise AssertionError(f"{related_op} was not found")
    else:
        # find the first unused counter-style key for the new set
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}


# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return named sets classifying node targets by their IO dtype.

    Keys follow the pattern `{funs,mods,meths}_io_type_{fp32,fp16,int8,
    fp32_or_int8}`; each value is a set of functions, module classes, or
    method-name strings whose inputs/outputs use that dtype.
    """
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        F.pixel_shuffle,
        F.pixel_unshuffle,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.narrow,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Dropout,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.PixelShuffle,
        nn.PixelUnshuffle,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        "sigmoid_",
        "sigmoid",
        "tanh_",
        "tanh",
        "hardsigmoid_",
        "hardsigmoid",
        "relu_",
        "relu",
    }

    return {
        "funs_io_type_fp32": FUNS_IO_TYPE_FP32,
        "funs_io_type_fp16": FUNS_IO_TYPE_FP16,
        "funs_io_type_int8": FUNS_IO_TYPE_INT8,
        "funs_io_type_fp32_or_int8": FUNS_IO_TYPE_FP32_OR_INT8,
        "mods_io_type_fp32": MODS_IO_TYPE_FP32,
        "mods_io_type_int8": MODS_IO_TYPE_INT8,
        "mods_io_type_fp32_or_int8": MODS_IO_TYPE_FP32_OR_INT8,
        "meths_io_type_fp32_or_int8": METHS_IO_TYPE_FP32_OR_INT8,
    }


def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
    """Return sets of node targets which should never be matched, keyed by
    target kind: "funs_unmatchable", "mods_unmatchable", "meths_unmatchable".
    """
    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }

    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
        nn.Identity,
    }

    # method names with no numerics to compare (views, metadata, dtype moves)
    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
        "to",
        "dequantize",
        "reshape",
        "view",
        "unsqueeze_",
        "unsqueeze",
        "transpose",
        "squeeze_",
        "squeeze",
        "size",
        "shape",
        "resize_",
        "repeat_interleave",
        "repeat",
        "permute",
        "numel",
        "mean",
        "detach_",
        "detach",
        "contiguous",
        "clamp",
        "chunk",
    }

    return {
        "funs_unmatchable": FUNS_UNMATCHABLE,
        "mods_unmatchable": MODS_UNMATCHABLE,
        "meths_unmatchable": METHS_UNMATCHABLE,
    }