# mypy: allow-untyped-defs
import itertools
import operator

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import _sequential_wrapper2
from torch.ao.quantization.utils import MatchAllNode

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)


# ===================
# |  DTYPE CONFIGS  |
# ===================

onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

onednn_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

onednn_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

onednn_weight_only_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
)

onednn_input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)

# ===================
# |  FUSER METHODS  |
# ===================


def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
    r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module
    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
                or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer
        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer
    Examples::
        >>> # xdoctest: +SKIP(failing)
        >>> m1 = nn.Linear(20, 10)
        >>> b1 = nn.BatchNorm1d(10)
        >>> lr = nn.LeakyReLU(0.01)
        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
    """
    assert (
        linear.training == bn.training and bn.training == leaky_relu.training
    ), "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError(
            f"Cannot fuse train modules: {(linear, bn, leaky_relu)}"
        )
    else:
        map_to_fused_module_eval = {
            nn.Linear: nni.LinearLeakyReLU,
        }
        fused_module = map_to_fused_module_eval.get(type(linear), None)
        if fused_module is not None:
            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
            fm = fused_module(fused_linear, leaky_relu)
            return fm
        else:
            raise NotImplementedError(
                f"Cannot fuse eval modules: {(linear, bn, leaky_relu)}"
            )


# ======================
# |  CONFIGS FOR CONV  |
# ======================
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT

conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
conv_configs = _get_conv_configs(conv_dtype_configs)
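
# The conv + add patterns below are registered through
# BackendPatternConfig._set_pattern_complex_format, which takes patterns in
# the "reversed nested tuple" format: the last op in the pattern comes first,
# followed by patterns for its inputs. A sketch of how the formats used below
# are read (illustrative only, not an additional config):
#
#   (torch.add, nn.Conv2d, MatchAllNode)
#       matches add(conv2d(x), y), where y may come from any node
#   (nn.ReLU, (torch.add, nn.Conv2d, MatchAllNode))
#       matches relu(add(conv2d(x), y))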

# (1) Conv2d + Add

# conv2d  Y
#      \  /
#       add

# include:
# conv2d  conv2d
#      \   /
#       add


def _fuse_conv_add_left(is_qat, add, conv, _):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_left(pattern):
    _, conv, _ = pattern
    return conv


def _conv_add_extra_inputs_getter_left(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, conv, extra_input = pattern
    return [extra_input]


# conv2d
#   \
#    bn  Y
#     \  /
#      add


def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_left(add_pattern):
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_left(add_pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_left)
            ._set_root_node_getter(_conv_bn_add_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAdd2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, nn.Conv2d, MatchAllNode)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_left)
            ._set_root_node_getter(_conv_add_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAdd2d)
        )

#  Y  conv2d
#   \  /
#    add


def _fuse_conv_add_right(is_qat, add, _, conv):
    return nni.ConvAdd2d(conv, add)


def _conv_add_root_node_getter_right(pattern):
    add, _, conv = pattern
    return conv


def _conv_add_extra_inputs_getter_right(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, extra_input, conv = pattern
    return [extra_input]


#     conv2d
#      /
#  Y  bn
#   \  /
#    add


def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)


def _conv_bn_add_root_node_getter_right(pattern):
    add, _, bn_conv = pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_extra_inputs_getter_right(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    _, extra_input, bn_conv = pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_right)
            ._set_root_node_getter(_conv_bn_add_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAdd2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (add_op, MatchAllNode, nn.Conv2d)
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_right)
            ._set_root_node_getter(_conv_add_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAdd2d)
        )

conv_configs.append(
    BackendPatternConfig(nni.ConvAdd2d)
    .set_observation_type(observation_type)  # noqa: E131
    .set_dtype_configs(conv_dtype_configs)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
)
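
# For each conv + add pattern above, the conv is the "root" node: the fuser
# method folds the add into nni.ConvAdd2d around it, the root node getter
# identifies the node the fused module replaces, and the extra inputs getter
# exposes the other add input as an additional input of the fused node.
# A sketch of how one matched pattern maps to these hooks (illustrative only):
#
#   pattern = (add_node, conv_node, extra_input_node)
#   _conv_add_root_node_getter_left(pattern)     # -> conv_node
#   _conv_add_extra_inputs_getter_left(pattern)  # -> [extra_input_node]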

# (2) Conv2d + Add + Relu

# conv2d  Y
#      \  /
#       add
#        \
#         relu


def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
    add, conv, _ = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, conv, _ = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_left(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, conv, extra_input = add_pattern
    return [extra_input]


# conv2d
#   \
#    bn  Y
#     \  /
#      add
#       \
#        relu


def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
    add, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_relu_left)
            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAddReLU2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_relu_left)
            ._set_root_node_getter(_conv_add_relu_root_node_getter_left)
            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left)
            .set_fused_module(nni.ConvAddReLU2d)
        )

#  Y  conv2d
#   \  /
#    add
#     \
#      relu


def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
    add, _, conv = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)


def _conv_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, conv = add_pattern
    return conv


def _conv_add_relu_extra_inputs_getter_right(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, extra_input, conv = add_pattern
    return [extra_input]


#     conv2d
#      /
#  Y  bn
#   \  /
#    add
#     \
#      relu


def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
    add, _, bn_conv = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError(f"Cannot fuse train modules: {(conv, bn, add, relu)}")
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)


def _conv_bn_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, bn_conv = add_pattern
    bn, conv = bn_conv
    return conv


def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
    """Get the pattern for extra inputs; inputs for the root node
    are assumed to be copied over from the root node to the fused node.
    """
    relu, add_pattern = pattern
    _, extra_input, bn_conv = add_pattern
    bn, conv = bn_conv
    return [extra_input]


conv_add_relu_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_bn_add_relu_right)
            ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAddReLU2d)
        )
    else:
        conv_configs.append(
            BackendPatternConfig()
            ._set_pattern_complex_format(
                (nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))
            )  # noqa: E131
            .set_observation_type(observation_type)
            .set_dtype_configs(conv_dtype_configs)
            .set_fuser_method(_fuse_conv_add_relu_right)
            ._set_root_node_getter(_conv_add_relu_root_node_getter_right)
            ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right)
            .set_fused_module(nni.ConvAddReLU2d)
        )

conv_configs.append(
    BackendPatternConfig(nni.ConvAddReLU2d)
    .set_observation_type(observation_type)  # noqa: E131
    .set_dtype_configs(conv_dtype_configs)
    .set_root_module(nn.Conv2d)
    .set_reference_quantized_module(nnqr.Conv2d)
)

# ========================
# |  CONFIGS FOR LINEAR  |
# ========================

linear_dtype_configs = [
    onednn_weighted_op_int8_dtype_config,
    onednn_dynamic_int8_dtype_config,
]
linear_configs = _get_linear_configs(linear_dtype_configs)


def _add_eltwise_fusion_configs(
    configs,
    root_module,
    root_op,
    post_module,
    post_op,
    dtype_configs,
    fuser_method,
    fused_module,
    observation_type,
    ref_quant_module,
):
    # 1 base module + op module fusion config
    configs.append(
        BackendPatternConfig((root_module, post_module))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )
    # base module + functional post op
    configs.append(
        BackendPatternConfig((root_module, post_op))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuser_method)
        .set_fused_module(fused_module)
    )

    # 2 fused module configs
    configs.append(
        BackendPatternConfig(fused_module)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(root_module)
        .set_reference_quantized_module(ref_quant_module)
    )

    # 3 functional base op + post op configs
    configs.append(
        BackendPatternConfig((root_op, post_module))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    configs.append(
        BackendPatternConfig((root_op, post_op))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )


# Configs for linear + leaky_relu fusion
_add_eltwise_fusion_configs(
    linear_configs,
    nn.Linear,
    F.linear,
    nn.LeakyReLU,
    F.leaky_relu,
    linear_dtype_configs,
    _sequential_wrapper2(nni.LinearLeakyReLU),
    nni.LinearLeakyReLU,
    observation_type,
    nnqr.Linear,
)

# Configs for linear module + batchnorm + leaky_relu
linear_configs.append(
    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
    .set_dtype_configs(linear_dtype_configs)  # noqa: E131
    .set_fuser_method(_fuse_linear_bn_leaky_relu)
    .set_fused_module(nni.LinearLeakyReLU)
)

# Configs for linear + tanh fusion
_add_eltwise_fusion_configs(
    linear_configs,
    nn.Linear,
    F.linear,
    nn.Tanh,
    torch.tanh,
    linear_dtype_configs,
    _sequential_wrapper2(nni.LinearTanh),
    nni.LinearTanh,
    observation_type,
    nnqr.Linear,
)
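
# Taken together, each _add_eltwise_fusion_configs call above registers the
# module/functional combinations of the base op and the post op, plus a config
# for the already-fused intrinsic module. For leaky_relu, for example, the
# registered pattern keys are roughly (a sketch of what is generated above,
# not an extra config):
#
#   (nn.Linear, nn.LeakyReLU), (nn.Linear, F.leaky_relu),
#   (F.linear, nn.LeakyReLU), (F.linear, F.leaky_relu), nni.LinearLeakyReLU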

# ===========================
# |  CONFIGS FOR OTHER OPS  |
# ===========================

binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
default_op_dtype_configs = [onednn_op_quint8_dtype_config]
fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]
embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]

# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_onednn_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native ONEDNN backend.
    """
    return (
        BackendConfig("onednn")
        .set_backend_pattern_configs(conv_configs)
        .set_backend_pattern_configs(linear_configs)
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs))
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(
            _get_share_qparams_op_configs(share_qparams_op_dtype_configs)
        )
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs))
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs))
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs))
        .set_backend_pattern_configs(
            _get_embedding_op_configs(embedding_op_dtype_configs)
        )
    )


__all__ = [
    "get_onednn_backend_config",
]
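
# A minimal usage sketch for FX graph mode quantization with this backend
# config (assumes the torch.ao.quantization.quantize_fx APIs; `model` and
# `example_inputs` are user-provided):
#
#   from torch.ao.quantization import get_default_qconfig_mapping
#   from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
#
#   backend_config = get_onednn_backend_config()
#   qconfig_mapping = get_default_qconfig_mapping("onednn")
#   prepared = prepare_fx(
#       model, qconfig_mapping, example_inputs, backend_config=backend_config
#   )
#   # ... run calibration data through `prepared` ...
#   quantized = convert_fx(prepared, backend_config=backend_config)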