1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import copy 4import functools 5import itertools 6import math 7import operator 8from typing import Any, Tuple 9 10import torch 11from torch._dynamo.utils import counters 12from torch.fx.experimental.symbolic_shapes import has_free_symbols 13from torch.fx.node import map_arg 14 15from ..lowering import lowerings as L, require_channels_last 16from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match 17from ..utils import pad_listlike 18from .freezing_patterns import register_freezing_graph_pattern 19from .post_grad import register_lowering_pattern 20 21 22aten = torch.ops.aten 23prims = torch.ops.prims 24quantized_decomposed = torch.ops.quantized_decomposed 25quantized = torch.ops.quantized 26 27# Only for per tensor quant since permute may changes the channel idx 28_PER_TENSOR_QUANTIZE_OPS = [ 29 quantized_decomposed.quantize_per_tensor.default, 30 quantized_decomposed.quantize_per_tensor.tensor, 31] 32 33_VIEW_OPS = [ 34 aten.transpose.int, 35 aten.permute.default, 36 aten.view.default, 37] 38 39""" 40The quantization.py file primarily incorporates passes related to quantization fusion 41in inductor, includes: 421. Dequant Promotion; 432. Conv/GEMM weight prepack with oneDNN Library; 443. Conv/GEMM quantization fusion with output quant node (if have); 454. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; 46 47It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference 48of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is 491. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. 502. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. 51Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 52quantization. 53""" 54 55 56def _get_pattern_output_dtype(match: Match): 57 """ 58 Get the pattern's output dtype from node's meta 59 Assume only 1 output node in this matched pattern. 60 """ 61 pattern_output_nodes = match.output_nodes() 62 assert len(pattern_output_nodes) == 1 63 output_node = pattern_output_nodes[0] 64 assert isinstance(output_node, torch.fx.Node) 65 output_dtype = output_node.meta["val"].dtype 66 assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16] 67 return output_dtype 68 69 70def _may_generate_pattern_with_dtype_convert( 71 pattern, dtype=Arg(), with_dtype_convert=True, users=1 72): 73 if with_dtype_convert: 74 return CallFunction( 75 prims.convert_element_type.default, 76 pattern, 77 dtype, 78 _users=users, 79 ) 80 else: 81 return pattern 82 83 84def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): 85 if with_reshape: 86 return CallFunction( 87 torch.ops.aten.reshape.default, 88 pattern, 89 reshape_size, 90 ) 91 else: 92 return pattern 93 94 95def _generate_linear_t_pattern( 96 _dequant_per_channel_pattern, 97 dtype, 98): 99 assert dtype in [torch.float32, torch.bfloat16] 100 t_pattern = CallFunction( 101 aten.permute.default, 102 _may_generate_pattern_with_dtype_convert( 103 _dequant_per_channel_pattern, 104 KeywordArg("autocast_wgt_dtype"), 105 dtype == torch.bfloat16, 106 ), 107 KeywordArg("permute_axes"), 108 ) 109 return t_pattern 110 111 112def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): 113 # only insert to_dtype if is_bf16 is True 114 computation_call = _may_generate_pattern_with_dtype_convert( 115 call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users 116 ) 117 return unary_fusion(computation_call) 118 119 120def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): 121 dequantize_per_tensor_activation_pattern = CallFunction( 122 quantized_decomposed.dequantize_per_tensor.tensor 123 if is_tensor_overload 124 else quantized_decomposed.dequantize_per_tensor.default, 125 KeywordArg("x"), 126 KeywordArg("x_scale"), 127 KeywordArg("x_zp"), 128 KeywordArg("x_quant_min"), 129 KeywordArg("x_quant_max"), 130 KeywordArg("x_dq_dtype"), 131 ) 132 return dequantize_per_tensor_activation_pattern 133 134 135dequantize_per_channel_weight_pattern = CallFunction( 136 quantized_decomposed.dequantize_per_channel.default, 137 KeywordArg("q_weight"), 138 KeywordArg("w_scale"), 139 KeywordArg("w_zp"), 140 KeywordArg("w_axis"), 141 KeywordArg("w_quant_min"), 142 KeywordArg("w_quant_max"), 143 KeywordArg("w_dtype"), 144) 145 146dequantize_per_channel_to_bf16_weight_pattern = ( 147 _may_generate_pattern_with_dtype_convert( 148 dequantize_per_channel_weight_pattern, 149 KeywordArg("autocast_wgt_dtype"), 150 ) 151) 152 153dequantize_per_channel_clone_weight_pattern = CallFunction( 154 aten.clone.default, 155 dequantize_per_channel_weight_pattern, 156 memory_format=KeywordArg("memory_format"), 157) 158 159dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( 160 aten.clone.default, 161 dequantize_per_channel_to_bf16_weight_pattern, 162 memory_format=KeywordArg("memory_format"), 163) 164 165 166def get_dequantize_qconv_pt2e_pattern(users=1): 167 return CallFunction( 168 torch.ops.onednn.qconv2d_pointwise.default, 169 KeywordArg("x"), 170 KeywordArg("x_scale"), # x_scale 171 KeywordArg("x_zp"), # x_zp 172 KeywordArg("packed_weight"), # packed_weight 173 KeywordArg("w_scale"), # w_scale 174 KeywordArg("w_zp"), # w_zp 175 KeywordArg("b"), # bias 176 KeywordArg("stride"), 177 KeywordArg("padding"), 178 KeywordArg("dilation"), 179 KeywordArg("groups"), 180 KeywordArg("output_scale"), # output_scale = 1.0 181 KeywordArg("output_zero_point"), # output_zero_point = 0 182 KeywordArg("output_dtype"), # output_dtype = None 183 KeywordArg("attr"), # attr = "none" 184 Arg(), # scalars 185 Arg(), # algorithm 186 _users=users, 187 ) 188 189 190def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): 191 qlinear_op = ( 192 torch.ops.onednn.qlinear_pointwise.tensor 193 if x_scale_zp_are_tensors 194 else torch.ops.onednn.qlinear_pointwise.default 195 ) 196 return CallFunction( 197 qlinear_op, 198 KeywordArg("x"), 199 KeywordArg("x_scale"), 200 KeywordArg("x_zp"), 201 KeywordArg("packed_weight"), 202 KeywordArg("w_scale"), 203 KeywordArg("w_zp"), 204 KeywordArg("b"), 205 KeywordArg("output_scale"), 206 KeywordArg("output_zero_point"), 207 KeywordArg("output_dtype"), 208 KeywordArg("postop_name"), 209 KeywordArg("postop_args"), 210 KeywordArg("postop_algorithm"), 211 _users=users, 212 ) 213 214 215dequantize_accum_pattern = CallFunction( 216 quantized_decomposed.dequantize_per_tensor.default, 217 KeywordArg("accum"), 218 KeywordArg("accum_scale"), 219 KeywordArg("accum_zp"), 220 Arg(), 221 Arg(), 222 KeywordArg("accum_dq_dtype"), 223) 224 225 226def generate_pattern_with_binary( 227 binary_post_op, 228 computation_call, 229 extra_input_pattern, 230 dtype_convert=False, 231 swap_inputs=False, 232): 233 binary_pattern = ( 234 CallFunction( 235 binary_post_op, 236 extra_input_pattern, 237 computation_call, 238 ) 239 if swap_inputs 240 else CallFunction( 241 binary_post_op, 242 computation_call, 243 extra_input_pattern, 244 ) 245 ) 246 return _may_generate_pattern_with_dtype_convert( 247 binary_pattern, 248 KeywordArg("convert_dtype_after_inplace_add"), 249 dtype_convert, 250 ) 251 252 253def generate_pattern_with_unary(computation_call, unary_post_op): 254 if unary_post_op is not None: 255 return CallFunction( 256 unary_post_op, 257 computation_call, 258 ) 259 return computation_call 260 261 262def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): 263 quantized_op_output_pattern_pt2e = CallFunction( 264 quantized_decomposed.quantize_per_tensor.default, 265 _may_generate_pattern_with_dtype_convert( 266 computation_call, 267 Arg(), 268 with_dtype_convert, 269 ), 270 KeywordArg("o_inv_scale"), 271 KeywordArg("o_zp"), 272 KeywordArg("o_qmin"), 273 KeywordArg("o_qmax"), 274 KeywordArg("o_dtype"), 275 ) 276 return quantized_op_output_pattern_pt2e 277 278 279def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): 280 if kwarg_name in check_node.kwargs: 281 actual_value = check_node.kwargs[kwarg_name] 282 return actual_value == expected_value 283 else: 284 assert len(check_node.args) >= (args_index + 1) 285 actual_value = check_node.args[args_index] 286 return actual_value == expected_value 287 288 289def _is_valid_quantized_conv2d_optimization_pattern(): 290 def fn(match): 291 output_dtype = _get_pattern_output_dtype(match) 292 if output_dtype in [torch.float32, torch.bfloat16]: 293 # Only keep matched pattern with same output_dtype 294 qconv_node_after_weight_prepack = filter_nodes( 295 match.nodes, torch.ops.onednn.qconv2d_pointwise 296 )[0] 297 return _check_node_kwarg_arg_value( 298 qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype 299 ) 300 return True 301 302 return fn 303 304 305def _register_quantized_conv_lowering( 306 pattern, 307 pass_number, 308 computation_op, 309 unary_attr, 310): 311 @register_lowering_pattern( 312 pattern, 313 extra_check=_is_valid_quantized_conv2d_optimization_pattern(), 314 pass_number=pass_number, 315 ) 316 def qconv(match: Match, *args, **kwargs): 317 # Activation QParams 318 x, x_scale, x_zp = ( 319 kwargs["x"], 320 kwargs["x_scale"], 321 kwargs["x_zp"], 322 ) 323 # Weight QParams 324 packed_weight, w_scale, w_zp = ( 325 kwargs["packed_weight"], 326 kwargs["w_scale"], 327 kwargs["w_zp"], 328 ) 329 # Conv Params 330 b, stride, padding, dilation, groups = ( 331 kwargs["b"], 332 kwargs["stride"], 333 kwargs["padding"], 334 kwargs["dilation"], 335 kwargs["groups"], 336 ) 337 output_dtype = _get_pattern_output_dtype(match) 338 assert output_dtype in [torch.uint8, torch.float32, torch.bfloat16] 339 # Output QParams 340 o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 341 o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 342 assert ( 343 kwargs["attr"] == "none" 344 ) # Expected no post op fused in weight prepack phase 345 if unary_attr.op_name == "hardtanh": 346 min_value = kwargs.get("min_value") 347 max_value = kwargs.get("max_value") 348 unary_attr.scalars_attr = [min_value, max_value] 349 350 computation_args = ( 351 x, 352 x_scale, 353 x_zp, 354 packed_weight, 355 w_scale, 356 w_zp, 357 b, 358 stride, 359 padding, 360 dilation, 361 groups, 362 o_inv_scale, 363 o_zero_point, 364 output_dtype, 365 unary_attr.op_name, 366 unary_attr.scalars_attr, 367 unary_attr.algorithm_attr, 368 ) 369 counters["inductor"]["qconv2d_unary_matcher_count"] += 1 370 counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes) 371 return L[computation_op](*computation_args) 372 373 return qconv 374 375 376def _is_valid_quantized_linear_optimization_pattern(): 377 def fn(match): 378 output_dtype = _get_pattern_output_dtype(match) 379 if output_dtype in [torch.float32, torch.bfloat16]: 380 # Only keep matched pattern with same output_dtype 381 qlinear_node_after_weight_prepack = filter_nodes( 382 match.nodes, torch.ops.onednn.qlinear_pointwise 383 )[0] 384 return _check_node_kwarg_arg_value( 385 qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype 386 ) 387 return True 388 389 return fn 390 391 392def _register_quantized_linear_lowering( 393 pattern, 394 pass_number, 395 computation_op, 396 unary_attr, 397): 398 @register_lowering_pattern( 399 pattern, 400 extra_check=_is_valid_quantized_linear_optimization_pattern(), 401 pass_number=pass_number, 402 ) 403 def qlinear(match: Match, *args, **kwargs): 404 output_dtype = _get_pattern_output_dtype(match) 405 # Activation QParams 406 x, x_scale, x_zp = ( 407 kwargs["x"], 408 kwargs["x_scale"], 409 kwargs["x_zp"], 410 ) 411 # Weight QParams 412 packed_weight, w_scale, w_zp = ( 413 kwargs["packed_weight"], 414 kwargs["w_scale"], 415 kwargs["w_zp"], 416 ) 417 418 # bias 419 b = kwargs["b"] if "b" in kwargs else None 420 421 # Output QParams 422 o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 423 o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 424 assert ( 425 kwargs["postop_name"] == "none" 426 ) # Expected no post op fused in weight prepack phase 427 428 computation_args = ( 429 x, 430 x_scale, 431 x_zp, 432 packed_weight, 433 w_scale, 434 w_zp, 435 b, 436 o_inv_scale, 437 o_zero_point, 438 output_dtype, 439 unary_attr.op_name, 440 unary_attr.scalars_attr, 441 unary_attr.algorithm_attr, 442 ) 443 counters["inductor"]["qlinear_unary_matcher_count"] += 1 444 counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes) 445 return L[computation_op](*computation_args) 446 447 return qlinear 448 449 450def _register_quantized_linear_binary_lowering( 451 pattern, 452 pass_number, 453 computation_op, 454 binary_unary_attr, 455): 456 @register_lowering_pattern( 457 pattern, 458 extra_check=_is_valid_qlinear_binary_optimization_pattern(), 459 pass_number=pass_number, 460 ) 461 def qlinear_binary(match: Match, *args, **kwargs): 462 output_dtype = _get_pattern_output_dtype(match) 463 assert output_dtype is not None 464 # Activation QParams 465 x, x_scale, x_zp = ( 466 kwargs["x"], 467 kwargs["x_scale"], 468 kwargs["x_zp"], 469 ) 470 x2 = ( 471 kwargs["accum"] 472 if binary_unary_attr.binary_op_name == "sum" 473 else kwargs["other"] 474 ) 475 x2_scale = 1.0 476 x2_zp = 0 477 # Weight QParams 478 packed_weight, w_scale, w_zp = ( 479 kwargs["packed_weight"], 480 kwargs["w_scale"], 481 kwargs["w_zp"], 482 ) 483 # bias 484 b = kwargs["b"] if "b" in kwargs else None 485 # Output QParams 486 o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 487 o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 488 489 x2.realize() 490 from .mkldnn_fusion import _can_be_inplace 491 492 binary_op_name = binary_unary_attr.binary_op_name 493 494 if binary_op_name == "sum" and not _can_be_inplace(x2): 495 # When we enable the GEMM Template, the output of QLinear 496 # will be reshaped from 2D back to 3D if the input is 3D. 497 # This causes _can_be_inplace(x2) to return False if x2 happens 498 # to be the output of QLinear in this scenario. 499 # Change the post op from sum to binary add for this case. 500 # Refer to test case: 501 # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 502 binary_op_name = "add" 503 504 computation_args = ( 505 x, 506 x_scale, 507 x_zp, 508 packed_weight, 509 w_scale, 510 w_zp, 511 x2, 512 b, 513 o_inv_scale, 514 o_zero_point, 515 output_dtype, 516 x2_scale, 517 x2_zp, 518 binary_op_name, 519 binary_unary_attr.alpha, 520 binary_unary_attr.unary_op_name, 521 binary_unary_attr.scalars_attr, 522 binary_unary_attr.algorithm_attr, 523 ) 524 counters["inductor"]["qlinear_binary_matcher_count"] += 1 525 counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes) 526 return L[computation_op](*computation_args) 527 528 return qlinear_binary 529 530 531def _is_valid_qconv_binary_optimization_pattern(): 532 return _is_valid_quantized_op_binary_optimization_pattern( 533 torch.ops.onednn.qconv2d_pointwise 534 ) 535 536 537def _is_valid_qlinear_binary_optimization_pattern(): 538 return _is_valid_quantized_op_binary_optimization_pattern( 539 torch.ops.onednn.qlinear_pointwise, 540 # we don't insert q-dq for extra input due to accuracy issues 541 extra_input_from_dequant=False, 542 ) 543 544 545def _is_valid_quantized_op_binary_optimization_pattern( 546 qop, extra_input_from_dequant=True 547): 548 # Check if it's a valid Binary Pattern for qconv2d and qlinear: 549 # * qop_pointwise should only has one users 550 # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern 551 # * the two inputs of binary node should have attribute "meta" and should be tensors 552 # * the two inputs of binary node should have the same shape 553 # * All users of the extra input in this pattern should be 554 # ancestor nodes of the compute node, except for the binary node 555 # connected to the compute node. 556 def fn(match): 557 output_dtype = _get_pattern_output_dtype(match) 558 compute_node = filter_nodes(match.nodes, qop)[0] 559 # qop_pointwise should only have one user 560 if len(compute_node.users) != 1: 561 return False 562 binary_node_inputs = next(iter(compute_node.users)).args 563 assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" 564 if output_dtype in [torch.float32, torch.bfloat16]: 565 extra_input_of_binary_node = None 566 for arg in binary_node_inputs: 567 if arg != compute_node: 568 extra_input_of_binary_node = arg 569 break 570 assert extra_input_of_binary_node is not None 571 # Extra input of binary node comes from dequant pattern 572 if extra_input_from_dequant and ( 573 (not isinstance(extra_input_of_binary_node, torch.fx.Node)) 574 or ( 575 extra_input_of_binary_node.target 576 != quantized_decomposed.dequantize_per_tensor.default 577 ) 578 ): 579 return False 580 581 # the two inputs of binary node should have attribute "meta" and should be tensors 582 if not ( 583 hasattr(binary_node_inputs[0], "meta") 584 and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] 585 ) or not ( 586 hasattr(binary_node_inputs[1], "meta") 587 and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] 588 ): 589 return False 590 # the two inputs of binary node should have the same shape 591 if ( 592 binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] 593 != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] 594 ): 595 return False 596 597 # All users of the extra input in this pattern should be 598 # ancestor nodes of the compute node, except for the binary node 599 # connected to the compute node. 600 601 from .mkldnn_fusion import _get_remaining_users 602 603 extra_input_of_pattern = ( 604 match.kwargs["other"] 605 if "other" in match.kwargs 606 else ( 607 match.kwargs["accum"] 608 if output_dtype == torch.uint8 or (not extra_input_from_dequant) 609 else match.kwargs["accum_after_dequant"] 610 ) 611 ) 612 if ( 613 len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 614 or extra_input_of_pattern == compute_node.args[0] 615 ): 616 return False 617 return True 618 619 return fn 620 621 622def _register_quantized_conv_binary_lowering( 623 pattern, 624 pass_number, 625 computation_op, 626 binary_unary_attr, 627): 628 @register_lowering_pattern( 629 pattern, 630 extra_check=_is_valid_qconv_binary_optimization_pattern(), 631 pass_number=pass_number, 632 ) 633 def qconv_binary(match: Match, *args, **kwargs): 634 output_dtype = _get_pattern_output_dtype(match) 635 assert output_dtype is not None 636 x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] 637 accum = ( 638 kwargs["accum"] 639 if output_dtype == torch.uint8 640 else kwargs["accum_after_dequant"] 641 ) 642 accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0 643 accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0 644 packed_weight, w_scale, w_zp = ( 645 kwargs["packed_weight"], 646 kwargs["w_scale"], 647 kwargs["w_zp"], 648 ) 649 b, stride, padding, dilation, groups = ( 650 kwargs["b"], 651 kwargs["stride"], 652 kwargs["padding"], 653 kwargs["dilation"], 654 kwargs["groups"], 655 ) 656 # Output QParams 657 o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0 658 o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0 659 660 accum.realize() 661 from .mkldnn_fusion import _can_be_inplace 662 663 assert _can_be_inplace( 664 accum 665 ), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." 666 667 computation_args = ( 668 x, 669 x_scale, 670 x_zp, 671 accum, 672 accum_scale, 673 accum_zp, 674 packed_weight, 675 w_scale, 676 w_zp, 677 b, 678 stride, 679 padding, 680 dilation, 681 groups, 682 o_inv_scale, 683 o_zero_point, 684 output_dtype, 685 binary_unary_attr.binary_op_name, 686 binary_unary_attr.alpha, 687 binary_unary_attr.unary_op_name, 688 binary_unary_attr.scalars_attr, 689 binary_unary_attr.algorithm_attr, 690 ) 691 counters["inductor"]["qconv2d_binary_matcher_count"] += 1 692 counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes) 693 return L[computation_op](*computation_args) 694 695 return qconv_binary 696 697 698def _register_quantization_unary_fusion(): 699 from .mkldnn_fusion import ( 700 _gelu_fusion_1 as _gelu_fusion_erf, 701 _gelu_fusion_2 as _gelu_fusion_tanh, 702 _hardswish_fusion, 703 _hardtanh_fusion, 704 _silu_fusion, 705 ) 706 707 class UnaryAttr: 708 def __init__( 709 self, op_name: str, scalars_attr=None, algorithm_attr=None 710 ) -> None: 711 self.op_name = op_name 712 self.scalars_attr = scalars_attr if scalars_attr else [] 713 self.algorithm_attr = algorithm_attr if algorithm_attr else "" 714 715 for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: 716 # QConv2d 717 # Priority 1 to match: QConv2d Unary pattern with int8 output 718 # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly. 719 # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant 720 is_bf16 = original_pattern_output_dtype == torch.bfloat16 721 conv_unary_replace_patterns = { 722 UnaryAttr("none", [], ""): generate_pattern_with_output_quant( 723 get_dequantize_qconv_pt2e_pattern(1), 724 ), 725 UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( 726 generate_pattern_with_unary( 727 get_dequantize_qconv_pt2e_pattern(1), aten.relu.default 728 ), 729 ), 730 UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant( 731 _unary_fusion_pattern( 732 _hardtanh_fusion, 733 get_dequantize_qconv_pt2e_pattern(1), 734 1, 735 is_bf16, 736 ), 737 with_dtype_convert=is_bf16, 738 ), 739 UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant( 740 _unary_fusion_pattern( 741 _hardswish_fusion, 742 get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), 743 2, 744 is_bf16, 745 ), 746 with_dtype_convert=is_bf16, 747 ), 748 UnaryAttr("swish", [], ""): generate_pattern_with_output_quant( 749 _unary_fusion_pattern( 750 _silu_fusion, 751 get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), 752 2, 753 is_bf16, 754 ), 755 with_dtype_convert=is_bf16, 756 ), 757 } 758 759 for unary_attr, patterns in conv_unary_replace_patterns.items(): 760 # Register qconv2d pattern for ExternKernel Lowering 761 _register_quantized_conv_lowering( 762 patterns, 763 1, # pass_number 764 torch.ops.onednn.qconv2d_pointwise, # computation_op 765 unary_attr, # unary_attr 766 ) 767 768 # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output 769 conv_unary_replace_float_out_patterns = { 770 UnaryAttr("relu", [], ""): generate_pattern_with_unary( 771 get_dequantize_qconv_pt2e_pattern(1), aten.relu.default 772 ), 773 UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert( 774 _unary_fusion_pattern( 775 _hardtanh_fusion, 776 get_dequantize_qconv_pt2e_pattern(1), 777 1, 778 is_bf16, 779 ), 780 Arg(), 781 is_bf16, 782 ), 783 UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert( 784 _unary_fusion_pattern( 785 _hardswish_fusion, 786 get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), 787 2, 788 is_bf16, 789 ), 790 Arg(), 791 is_bf16, 792 ), 793 UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert( 794 _unary_fusion_pattern( 795 _silu_fusion, 796 get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2), 797 2, 798 is_bf16, 799 ), 800 Arg(), 801 is_bf16, 802 ), 803 } 804 805 for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): 806 # Register qconv2d pattern for ExternKernel Lowering 807 _register_quantized_conv_lowering( 808 patterns, 809 2, # pass_number 810 torch.ops.onednn.qconv2d_pointwise, # computation_op 811 unary_attr, # unary_attr 812 ) 813 814 # QLinear 815 for x_scale_zp_are_tensors in (False, True): 816 qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) 817 # Priority 1 to match: QLinear Unary pattern with int8 output 818 linear_unary_replace_patterns = { 819 UnaryAttr("none", [], ""): generate_pattern_with_output_quant( 820 qlinear_pattern, 821 ), 822 UnaryAttr("relu", [], ""): generate_pattern_with_output_quant( 823 generate_pattern_with_unary(qlinear_pattern, aten.relu.default), 824 ), 825 UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant( 826 _unary_fusion_pattern( 827 _gelu_fusion_erf, 828 get_qlinear_pt2e_pattern( 829 x_scale_zp_are_tensors, 1 if is_bf16 else 2 830 ), 831 2, 832 is_bf16, 833 ), 834 with_dtype_convert=is_bf16, 835 ), 836 UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant( 837 _unary_fusion_pattern( 838 _gelu_fusion_tanh, 839 get_qlinear_pt2e_pattern( 840 x_scale_zp_are_tensors, 1 if is_bf16 else 4 841 ), 842 4, 843 is_bf16, 844 ), 845 with_dtype_convert=is_bf16, 846 ), 847 } 848 849 for unary_attr, patterns in linear_unary_replace_patterns.items(): 850 _register_quantized_linear_lowering( 851 patterns, 852 1, # pass_number 853 torch.ops.onednn.qlinear_pointwise, # computation_op 854 unary_attr, # unary_attr 855 ) 856 857 # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output 858 linear_unary_replace_float_out_patterns = { 859 UnaryAttr("relu", [], ""): generate_pattern_with_unary( 860 qlinear_pattern, aten.relu.default 861 ), 862 UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert( 863 _unary_fusion_pattern( 864 _gelu_fusion_erf, 865 get_qlinear_pt2e_pattern( 866 x_scale_zp_are_tensors, 1 if is_bf16 else 2 867 ), 868 2, 869 is_bf16, 870 ), 871 Arg(), 872 is_bf16, 873 ), 874 UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert( 875 _unary_fusion_pattern( 876 _gelu_fusion_tanh, 877 get_qlinear_pt2e_pattern( 878 x_scale_zp_are_tensors, 1 if is_bf16 else 4 879 ), 880 4, 881 is_bf16, 882 ), 883 Arg(), 884 is_bf16, 885 ), 886 } 887 888 for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): 889 _register_quantized_linear_lowering( 890 patterns, 891 2, # pass_number 892 torch.ops.onednn.qlinear_pointwise, # computation_op 893 unary_attr, # unary_attr 894 ) 895 896 897def _register_quantization_binary_fusion(): 898 class BinaryUnaryAttr: 899 def __init__( 900 self, 901 binary_op_name: str, 902 alpha=None, 903 unary_op_name: str = "none", 904 scalars_attr=None, 905 algorithm_attr=None, 906 ) -> None: 907 self.binary_op_name = binary_op_name 908 self.alpha = alpha if alpha else 1.0 909 self.unary_op_name = unary_op_name 910 self.scalars_attr = scalars_attr if scalars_attr else [] 911 self.algorithm_attr = algorithm_attr if algorithm_attr else "" 912 913 for int8_mixed_bf16_with_inplace_add in [False, True]: 914 # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output 915 binary_replace_patterns = { 916 BinaryUnaryAttr( 917 "sum", 1.0, "none", [], "" 918 ): generate_pattern_with_output_quant( 919 generate_pattern_with_binary( 920 aten.add.Tensor, 921 get_dequantize_qconv_pt2e_pattern(1), 922 dequantize_accum_pattern, 923 int8_mixed_bf16_with_inplace_add, 924 ), 925 ), 926 BinaryUnaryAttr( 927 "sum", 1.0, "relu", [], "" 928 ): generate_pattern_with_output_quant( 929 generate_pattern_with_unary( 930 generate_pattern_with_binary( 931 aten.add.Tensor, 932 get_dequantize_qconv_pt2e_pattern(1), 933 dequantize_accum_pattern, 934 int8_mixed_bf16_with_inplace_add, 935 ), 936 aten.relu.default, 937 ), 938 ), 939 } 940 941 for binary_unary_attr, patterns in binary_replace_patterns.items(): 942 _register_quantized_conv_binary_lowering( 943 patterns, 944 0, # pass_number 945 torch.ops.onednn.qconv2d_pointwise.binary, # computation_op 946 binary_unary_attr, # binary_unary_attr 947 ) 948 949 # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output 950 binary_replace_float_out_patterns = { 951 BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( 952 generate_pattern_with_binary( 953 aten.add.Tensor, 954 get_dequantize_qconv_pt2e_pattern(1), 955 KeywordArg("accum_after_dequant"), 956 int8_mixed_bf16_with_inplace_add, 957 ), 958 aten.relu.default, 959 ), 960 } 961 962 for ( 963 binary_unary_attr, 964 patterns, 965 ) in binary_replace_float_out_patterns.items(): 966 if int8_mixed_bf16_with_inplace_add: 967 _register_quantized_conv_binary_lowering( 968 patterns, 969 0, # pass_number 970 torch.ops.onednn.qconv2d_pointwise.binary, # computation_op 971 binary_unary_attr, # binary_unary_attr 972 ) 973 else: 974 _register_quantized_conv_binary_lowering( 975 patterns, 976 1, # pass_number 977 torch.ops.onednn.qconv2d_pointwise.binary, # computation_op 978 binary_unary_attr, # binary_unary_attr 979 ) 980 981 # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output 982 binary_replace_float_out_patterns = { 983 BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary( 984 aten.add.Tensor, 985 get_dequantize_qconv_pt2e_pattern(1), 986 KeywordArg("accum_after_dequant"), 987 int8_mixed_bf16_with_inplace_add, 988 ), 989 } 990 991 for ( 992 binary_unary_attr, 993 patterns, 994 ) in binary_replace_float_out_patterns.items(): 995 _register_quantized_conv_binary_lowering( 996 patterns, 997 1 if int8_mixed_bf16_with_inplace_add else 2, # pass_number 998 torch.ops.onednn.qconv2d_pointwise.binary, # computation_op 999 binary_unary_attr, # binary_unary_attr 1000 ) 1001 1002 # QLinear 1003 r""" 1004 Supported linear-binary(-unary) patterns 1005 1006 linear(X) extra input 1007 \ / 1008 Add 1009 | 1010 Optional(relu) 1011 | 1012 Y 1013 1014 1. int8-mixed-fp32 1015 +---+---------------+-----------+------------------------------+---------+ 1016 | # | Add type | Quant out | Pattern | Post op | 1017 +---+---------------+-----------+------------------------------+---------+ 1018 | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | 1019 +---+---------------+-----------+------------------------------+---------+ 1020 | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | 1021 +---+---------------+-----------+------------------------------+---------+ 1022 1023 2. int8-mixed-bf16 1024 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1025 | # | X2 dtype | Add type | Quant out | Pattern | Post op | 1026 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1027 | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | 1028 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1029 | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | 1030 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1031 | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | 1032 | | | In-place right| | | | 1033 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1034 | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | 1035 | | | In-place right| | | | 1036 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1037 | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | 1038 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1039 | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | 1040 +---+----------+---------------+-----------+-----------------------------------------+---------+ 1041 1042 Note 1043 (1) The positions of linear and the extra input can be swapped. 1044 (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the 1045 extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. 1046 """ 1047 for x_scale_zp_are_tensors in (False, True): 1048 qlinear_binary_op = ( 1049 torch.ops.onednn.qlinear_pointwise.binary_tensor 1050 if x_scale_zp_are_tensors 1051 else torch.ops.onednn.qlinear_pointwise.binary 1052 ) 1053 unary_postop_list = ["none", "relu"] 1054 unary_postop_dict = { 1055 "none": None, 1056 "relu": aten.relu.default, 1057 } 1058 convert_dtype_after_binary_list = [False, True] 1059 1060 # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output 1061 # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, 1062 # totally 3 patterns (2 are identical) 1063 swap_binary_inputs_list = [False, True] 1064 int8_mixed_bf16_list = [False, True] 1065 combinations = itertools.product( 1066 unary_postop_list, 1067 int8_mixed_bf16_list, 1068 swap_binary_inputs_list, 1069 convert_dtype_after_binary_list, 1070 ) 1071 qlinear_binary_replace_patterns = {} 1072 for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: 1073 if not int8_mixed_bf16 and cvt_dtype_binary: 1074 # No convert node after binary node if dtypes are all fp32 1075 continue 1076 qlinear_binary_replace_patterns.update( 1077 { 1078 BinaryUnaryAttr( 1079 "add", 1.0, unary_op, [], "" 1080 ): generate_pattern_with_output_quant( 1081 generate_pattern_with_unary( 1082 generate_pattern_with_binary( 1083 aten.add.Tensor, 1084 get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), 1085 KeywordArg("other"), 1086 # If fp32 extra input is inplace added to bf16 linear output, 1087 # a to_bf16 node is inserted after binary 1088 dtype_convert=cvt_dtype_binary, 1089 swap_inputs=swap_inputs, 1090 ), 1091 unary_postop_dict[unary_op], 1092 ), 1093 ) 1094 } 1095 ) 1096 for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): 1097 _register_quantized_linear_binary_lowering( 1098 patterns, 1099 0, # pass_number 1100 qlinear_binary_op, # computation_op 1101 binary_unary_attr, # binary_unary_attr 1102 ) 1103 1104 # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output 1105 # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, 1106 # totally 2 patterns (2 are identical) 1107 binary_replace_float_out_patterns = {} 1108 for swap_binary_inputs in swap_binary_inputs_list: 1109 binary_replace_float_out_patterns.update( 1110 { 1111 BinaryUnaryAttr( 1112 "sum", 1.0, "relu", [], "" 1113 ): generate_pattern_with_unary( 1114 generate_pattern_with_binary( 1115 aten.add.Tensor, 1116 get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), 1117 KeywordArg("accum"), 1118 dtype_convert=False, 1119 swap_inputs=swap_binary_inputs, 1120 ), 1121 aten.relu.default, 1122 ), 1123 } 1124 ) 1125 for ( 1126 binary_unary_attr, 1127 patterns, 1128 ) in binary_replace_float_out_patterns.items(): 1129 _register_quantized_linear_binary_lowering( 1130 patterns, 1131 1, # pass_number 1132 qlinear_binary_op, # computation_op 1133 binary_unary_attr, 1134 ) 1135 # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output 1136 # Covers case (6) of int8-mixed-bf16 1137 binary_replace_float_out_patterns = {} 1138 for swap_binary_inputs in swap_binary_inputs_list: 1139 binary_replace_float_out_patterns.update( 1140 { 1141 BinaryUnaryAttr( 1142 "add", 1.0, "relu", [], "" 1143 ): generate_pattern_with_unary( 1144 generate_pattern_with_binary( 1145 aten.add.Tensor, 1146 get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), 1147 KeywordArg("other"), 1148 dtype_convert=True, 1149 swap_inputs=swap_binary_inputs, 1150 ), 1151 aten.relu.default, 1152 ), 1153 } 1154 ) 1155 for ( 1156 binary_unary_attr, 1157 patterns, 1158 ) in binary_replace_float_out_patterns.items(): 1159 _register_quantized_linear_binary_lowering( 1160 patterns, 1161 1, # pass_number 1162 qlinear_binary_op, # computation_op 1163 binary_unary_attr, 1164 ) 1165 1166 # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output 1167 # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, 1168 # totally 2 patterns (2 are identical) 1169 binary_replace_float_out_patterns = {} 1170 for swap_binary_inputs in swap_binary_inputs_list: 1171 binary_replace_float_out_patterns.update( 1172 { 1173 BinaryUnaryAttr( 1174 "sum", 1.0, "none", [], "" 1175 ): generate_pattern_with_binary( 1176 aten.add.Tensor, 1177 get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), 1178 KeywordArg("accum"), 1179 dtype_convert=False, 1180 swap_inputs=swap_binary_inputs, 1181 ), 1182 } 1183 ) 1184 for ( 1185 binary_unary_attr, 1186 patterns, 1187 ) in binary_replace_float_out_patterns.items(): 1188 _register_quantized_linear_binary_lowering( 1189 patterns, 1190 2, # pass_number 1191 qlinear_binary_op, # computation_op 1192 binary_unary_attr, 1193 ) 1194 # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output 1195 # Covers (6) of int8-mixed-bf16 1196 binary_replace_float_out_patterns = {} 1197 for swap_binary_inputs in swap_binary_inputs_list: 1198 binary_replace_float_out_patterns.update( 1199 { 1200 BinaryUnaryAttr( 1201 "add", 1.0, "none", [], "" 1202 ): generate_pattern_with_binary( 1203 aten.add.Tensor, 1204 get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), 1205 KeywordArg("other"), 1206 dtype_convert=True, 1207 swap_inputs=swap_binary_inputs, 1208 ), 1209 } 1210 ) 1211 for ( 1212 binary_unary_attr, 1213 patterns, 1214 ) in binary_replace_float_out_patterns.items(): 1215 _register_quantized_linear_binary_lowering( 1216 patterns, 1217 2, # pass_number 1218 qlinear_binary_op, # computation_op 1219 binary_unary_attr, 1220 ) 1221 1222 1223def _is_valid_quantized_maxpool2d_optimization_pattern(): 1224 def fn(match): 1225 # Only match the pattern which max_pool2d_with_indices returns value 1226 # instead of indices. 1227 get_item_node = filter_nodes(match.nodes, operator.getitem)[0] 1228 return get_item_node.args[1] == 0 1229 1230 return fn 1231 1232 1233def _register_quantized_maxpool2d_lowering( 1234 pattern, 1235 computation_op, 1236): 1237 @register_lowering_pattern( 1238 pattern, 1239 extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), 1240 ) 1241 def qmaxpool2d(match: Match, *args, **kwargs): 1242 x = kwargs["x"] 1243 kernel_size = kwargs["kernel_size"] 1244 stride = kwargs["stride"] if ("stride" in kwargs) else None 1245 padding = kwargs["padding"] if ("padding" in kwargs) else 0 1246 dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 1247 ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False 1248 1249 if padding == 0: 1250 padding = [0, 0] 1251 if dilation == 1: 1252 dilation = [1, 1] 1253 if not stride: 1254 stride = kernel_size 1255 kernel_size = pad_listlike(kernel_size, 2) 1256 stride = pad_listlike(stride, 2) 1257 padding = pad_listlike(padding, 2) 1258 dilation = pad_listlike(dilation, 2) 1259 1260 assert len(kernel_size) == 2 1261 assert len(stride) == 2 1262 assert len(padding) == 2 1263 assert len(dilation) == 2 1264 1265 computation_args = ( 1266 x, 1267 kernel_size, 1268 stride, 1269 padding, 1270 dilation, 1271 ceil_mode, 1272 ) 1273 computation_args, _ = require_channels_last(computation_op, *computation_args) 1274 counters["inductor"]["qmaxpool2d_matcher_count"] += 1 1275 counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) 1276 return L[computation_op](*computation_args) 1277 1278 return qmaxpool2d 1279 1280 1281def _register_quantization_maxpool2d(): 1282 # Currently, the default parameters are not in FX Graph generated by Dynamo export. 1283 # So, if user defines nn.MaxPool2d with different assignment of default parameter, 1284 # it will generate graph with different number of input nodes and hence 1285 # different pattern to be matched. 1286 # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 1287 max_pool2d_args_list = [ 1288 [ 1289 KeywordArg("stride"), 1290 ], 1291 [ 1292 KeywordArg("stride"), 1293 KeywordArg("padding"), 1294 ], 1295 [ 1296 KeywordArg("stride"), 1297 KeywordArg("padding"), 1298 KeywordArg("dilation"), 1299 ], 1300 [ 1301 KeywordArg("stride"), 1302 KeywordArg("padding"), 1303 KeywordArg("dilation"), 1304 KeywordArg("ceil_mode"), 1305 ], 1306 ] 1307 for max_pool2d_args in max_pool2d_args_list: 1308 dequantize_maxpool2d_pattern = CallFunction( 1309 aten.max_pool2d_with_indices.default, 1310 get_dequantize_per_tensor_activation_pattern(), 1311 KeywordArg("kernel_size"), 1312 *max_pool2d_args, 1313 ) 1314 dequantize_lowmem_maxpool2d_pattern = CallFunction( 1315 prims._low_memory_max_pool2d_with_offsets.default, 1316 get_dequantize_per_tensor_activation_pattern(), 1317 KeywordArg("kernel_size"), 1318 *max_pool2d_args, 1319 KeywordArg("offset_dtype"), 1320 ) 1321 dequantize_maxpool2d_get_item_pattern = CallFunction( 1322 operator.getitem, 1323 dequantize_maxpool2d_pattern, 1324 Arg(), 1325 ) 1326 dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( 1327 operator.getitem, 1328 dequantize_lowmem_maxpool2d_pattern, 1329 Arg(), 1330 ) 1331 _register_quantized_maxpool2d_lowering( 1332 generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), 1333 quantized.max_pool2d.default, 1334 ) 1335 _register_quantized_maxpool2d_lowering( 1336 generate_pattern_with_output_quant( 1337 dequantize_lowmem_maxpool2d_get_item_pattern 1338 ), 1339 quantized.max_pool2d.default, 1340 ) 1341 1342 1343def _is_input_output_same_scale_zp(check_node): 1344 def fn(match): 1345 # Ensure all the inputs and output has same scale and zero point 1346 # Step 1: Check inputs/output zero point 1347 # Get dequant nodes at input 1348 dequant_nodes = filter_nodes( 1349 match.nodes, quantized_decomposed.dequantize_per_tensor.default 1350 ) 1351 zero_points = [node.args[2] for node in dequant_nodes] 1352 # Get quant nodes at output 1353 quant_nodes = filter_nodes( 1354 match.nodes, quantized_decomposed.quantize_per_tensor.default 1355 ) 1356 assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" 1357 zero_points.append(quant_nodes[0].args[2]) 1358 if not all(zero_point == zero_points[0] for zero_point in zero_points): 1359 return False 1360 1361 # Step 2: Check inputs/output scale 1362 scales = [node.args[1] for node in dequant_nodes] 1363 scales.append(quant_nodes[0].args[1]) 1364 if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] 1365 return False 1366 1367 return True 1368 1369 return fn 1370 1371 1372def _register_quantized_cat_lowering( 1373 pattern, 1374 computation_op, 1375): 1376 @register_lowering_pattern( 1377 pattern, 1378 extra_check=_is_input_output_same_scale_zp(aten.cat.default), 1379 ) 1380 def qcat(match: Match, inputs, dim, **kwargs): 1381 # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 1382 uint8_inputs = [input[0] for input in inputs] 1383 counters["inductor"]["qcat_matcher_count"] += 1 1384 counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) 1385 return L[computation_op](uint8_inputs, dim) 1386 1387 return qcat 1388 1389 1390_raw_dequantize_per_tensor_activation_pattern = CallFunction( 1391 quantized_decomposed.dequantize_per_tensor.default, 1392 Arg(), 1393 Arg(), 1394 Arg(), 1395 Arg(), 1396 Arg(), 1397 Arg(), 1398) 1399 1400 1401def _register_quantization_cat(): 1402 dequantize_cat_pattern = CallFunction( 1403 aten.cat.default, 1404 ListOf(_raw_dequantize_per_tensor_activation_pattern), 1405 KeywordArg("dim"), 1406 ) 1407 _register_quantized_cat_lowering( 1408 generate_pattern_with_output_quant(dequantize_cat_pattern), 1409 aten.cat, 1410 ) 1411 1412 1413def _register_quantized_reshape_lowering( 1414 pattern, 1415 computation_op, 1416): 1417 @register_lowering_pattern( 1418 pattern, 1419 extra_check=_is_input_output_same_scale_zp(aten.reshape.default), 1420 ) 1421 def qreshape(match: Match, *args, **kwargs): 1422 qx = kwargs["x"] 1423 shape = kwargs["shape"] 1424 counters["inductor"]["qreshape_matcher_count"] += 1 1425 counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) 1426 return L[computation_op](qx, shape) 1427 1428 return qreshape 1429 1430 1431def _register_quantization_reshape(): 1432 dequantize_reshape_pattern = CallFunction( 1433 torch.ops.aten.reshape.default, 1434 get_dequantize_per_tensor_activation_pattern(), 1435 KeywordArg("shape"), 1436 ) 1437 _register_quantized_reshape_lowering( 1438 generate_pattern_with_output_quant(dequantize_reshape_pattern), 1439 aten.reshape, 1440 ) 1441 1442 1443def _is_valid_woq_optimization_pattern(): 1444 def fn(match): 1445 assert all(k in match.kwargs for k in ("x", "weight", "scales")) 1446 x = match.kwargs["x"].meta["val"] 1447 weight = match.kwargs["weight"].meta["val"] 1448 scales = match.kwargs["scales"].meta["val"] 1449 return ( 1450 # For now, we only support woq mm kernels 1451 # with x.type=bfloat16 and w.type=int8 1452 x.dtype == torch.bfloat16 1453 and weight.dtype == torch.int8 1454 and scales.dtype == torch.bfloat16 1455 # _weight_int8pack_mm kernel only supports cpu now 1456 # TODO: add cuda kernel support instead of calling mul+sum 1457 and x.device.type == "cpu" 1458 and x.device == weight.device 1459 and x.device == scales.device 1460 ) 1461 1462 return fn 1463 1464 1465def _register_woq_lowering(pattern, computation_woq, computation_reshape): 1466 @register_lowering_pattern( 1467 pattern, 1468 extra_check=_is_valid_woq_optimization_pattern(), 1469 ) 1470 def woq(match: Match, *args, **kwargs): 1471 x = kwargs["x"] 1472 weight = kwargs["weight"] 1473 scales = kwargs["scales"] 1474 counters["inductor"]["woq_matcher_count"] += 1 1475 counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) 1476 out_features = weight.get_size()[0] 1477 origin_x_size = x.get_size() 1478 x_shape = [-1, origin_x_size[-1]] 1479 out_shape = origin_x_size[:-1] + [ 1480 out_features, 1481 ] 1482 func1 = L[computation_reshape](x, x_shape) 1483 func2 = L[computation_woq](func1, weight, scales) 1484 return L[computation_reshape](func2, out_shape) 1485 1486 return woq 1487 1488 1489def _register_woq_mm_int8_pattern1(): 1490 # F.linear(x, weight.to(dtype=x.dtype)) * scales 1491 # case of dispatching to mm, with x reshape 1492 _woq_pattern = CallFunction( 1493 aten.mul.Tensor, 1494 CallFunction( 1495 aten.reshape.default, 1496 CallFunction( 1497 aten.mm.default, 1498 CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), 1499 CallFunction( 1500 aten.permute.default, 1501 CallFunction( 1502 prims.convert_element_type.default, KeywordArg("weight"), Arg() 1503 ), 1504 Arg(), 1505 ), 1506 ), 1507 Arg(), 1508 ), 1509 KeywordArg("scales"), 1510 ) 1511 _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) 1512 1513 1514def _register_woq_mm_int8_pattern2(): 1515 # F.linear(x, weight.to(dtype=x.dtype)) * scales 1516 # case of dispatching to mm, w/o x reshape 1517 _woq_pattern = CallFunction( 1518 aten.mul.Tensor, 1519 CallFunction( 1520 aten.reshape.default, 1521 CallFunction( 1522 aten.mm.default, 1523 KeywordArg("x"), 1524 CallFunction( 1525 aten.permute.default, 1526 CallFunction( 1527 prims.convert_element_type.default, KeywordArg("weight"), Arg() 1528 ), 1529 Arg(), 1530 ), 1531 ), 1532 Arg(), 1533 ), 1534 KeywordArg("scales"), 1535 ) 1536 _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) 1537 1538 1539def _register_woq_mm_int8_pattern3(): 1540 # F.linear(x, weight.to(dtype=x.dtype)) * scales 1541 # case of dispatching to bmm 1542 _woq_pattern = CallFunction( 1543 aten.mul.Tensor, 1544 CallFunction( 1545 aten.bmm.default, 1546 CallFunction(aten.expand.default, KeywordArg("x"), Arg()), 1547 CallFunction( 1548 aten.expand.default, 1549 CallFunction( 1550 aten.permute.default, 1551 CallFunction( 1552 prims.convert_element_type.default, KeywordArg("weight"), Arg() 1553 ), 1554 Arg(), 1555 ), 1556 Arg(), 1557 ), 1558 ), 1559 KeywordArg("scales"), 1560 ) 1561 _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) 1562 1563 1564def _register_quantization_lowerings(): 1565 _register_quantization_unary_fusion() 1566 _register_quantization_binary_fusion() 1567 _register_quantization_maxpool2d() 1568 _register_quantization_cat() 1569 _register_quantization_reshape() 1570 1571 1572def _register_woq_lowerings(): 1573 _register_woq_mm_int8_pattern1() 1574 _register_woq_mm_int8_pattern2() 1575 _register_woq_mm_int8_pattern3() 1576 1577 1578def _is_valid_dequant_promotion_pattern(dtype=torch.float32): 1579 def _inner(match): 1580 assert dtype in [torch.float32, torch.bfloat16] 1581 dequant_pattern_end_node = match.output_node() 1582 if dequant_pattern_end_node.target not in [ 1583 quantized_decomposed.dequantize_per_tensor.default, 1584 quantized_decomposed.dequantize_per_tensor.tensor, 1585 prims.convert_element_type.default, 1586 aten.reshape.default, 1587 ]: 1588 return False 1589 1590 if dequant_pattern_end_node.target is aten.reshape.default: 1591 dequant_node = ( 1592 dequant_pattern_end_node.args[ 1593 0 1594 ] # pattern: linear <- reshape <- dequant 1595 if dtype == torch.float32 1596 else dequant_pattern_end_node.args[0].args[ 1597 0 1598 ] # pattern: linear <- reshape <- to_bf16 <- dequant 1599 ) 1600 else: 1601 dequant_node = ( 1602 dequant_pattern_end_node # pattern: linear <- dequant 1603 if dtype == torch.float32 1604 else dequant_pattern_end_node.args[ 1605 0 1606 ] # pattern: linear <- to_bf16 <- dequant 1607 ) 1608 1609 if ( 1610 dequant_node.target 1611 in [ 1612 quantized_decomposed.dequantize_per_tensor.default, 1613 quantized_decomposed.dequantize_per_tensor.tensor, 1614 ] 1615 and len(list(dequant_pattern_end_node.users)) > 1 1616 ): 1617 # If dequant pattern has more than 1 users, then do dequant promoted 1618 return True 1619 return False 1620 1621 return _inner 1622 1623 1624def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): 1625 @register_freezing_graph_pattern( 1626 pattern, 1627 extra_check=_is_valid_dequant_promotion_pattern(dtype), 1628 pass_number=pass_number, 1629 ) 1630 def dequant_promotion(match: Match, *args, **kwargs): 1631 # Dequant_promotion will transform 1632 # graph 1: 1633 # quant 1634 # + - - - | - - - + 1635 # | dequant | 1636 # | / \ | 1637 # | node1 node2 | 1638 # + - | - - - | - + 1639 # quant quant 1640 # into: 1641 # graph 2: 1642 # quant 1643 # + - - / - \ - - + 1644 # |dequant dequant| 1645 # | | | | 1646 # | node1 node2 | 1647 # + - | - - - | - + 1648 # quant quant 1649 # In graph 1, the dequant node is shared by node1 and node2, 1650 # as a result, neither node1 nor node2 could form an int8 1651 # fusion pattern. 1652 # After this transformation, the graph 2 could hit the int8 1653 # fusion pattern: dequant-node-quant, respectively for 1654 # node1 and node2. 1655 assert dtype in [torch.float32, torch.bfloat16] 1656 1657 def clone_to_new_node(graph, source_node, user_node): 1658 # Clone the source_node to a new node 1659 # Replace user_node's input from source_node to new_node 1660 assert ( 1661 source_node.op == "call_function" 1662 ), "clone_to_new_node only support node.op call_function" 1663 with graph.inserting_before(user_node): 1664 new_node = graph.call_function( 1665 source_node.target, 1666 args=source_node.args, 1667 kwargs=source_node.kwargs, 1668 ) 1669 new_node.meta = copy.copy(source_node.meta) 1670 user_node.replace_input_with(source_node, new_node) 1671 return new_node 1672 1673 # Find the start node and end node of a dequant pattern 1674 # * End node should be the match.output_node() 1675 # * Start node should be the node of dequantize_per_tensor 1676 dequant_pattern_end_node = match.output_node() 1677 assert dequant_pattern_end_node.target in [ 1678 quantized_decomposed.dequantize_per_tensor.default, 1679 quantized_decomposed.dequantize_per_tensor.tensor, 1680 prims.convert_element_type.default, 1681 aten.reshape.default, 1682 ] 1683 1684 # For a dequant pattern, we should expect see the node list as: 1685 # * OPT(aten.reshape.default) 1686 # * OPT(prims.convert_element_type.default) (to_bf16) 1687 # * dequantize_per_tensor 1688 def _find_first_node_in_dequant_pattern(_node): 1689 if _node.target in [ 1690 quantized_decomposed.dequantize_per_tensor.default, 1691 quantized_decomposed.dequantize_per_tensor.tensor, 1692 ]: 1693 # For a dequant pattern, we expect the start node is a dequantize_per_tensor node 1694 return _node 1695 else: 1696 assert ( 1697 len(_node.args) >= 1 1698 ), "In in dequant pattern, each node should have more than 1 arg." 1699 return _find_first_node_in_dequant_pattern(_node.args[0]) 1700 1701 dequant_pattern_start_node = _find_first_node_in_dequant_pattern( 1702 dequant_pattern_end_node 1703 ) 1704 1705 assert dequant_pattern_start_node.target in [ 1706 quantized_decomposed.dequantize_per_tensor.default, 1707 quantized_decomposed.dequantize_per_tensor.tensor, 1708 ] 1709 1710 # Clone the dequant pattern for each user node 1711 graph = match.graph 1712 user_node_list = list(dequant_pattern_end_node.users) 1713 for user_node in user_node_list[1:]: 1714 _source_node = dequant_pattern_end_node 1715 _user_node = user_node 1716 while _source_node != dequant_pattern_start_node.args[0]: 1717 _user_node = clone_to_new_node(graph, _source_node, _user_node) 1718 _source_node = _source_node.args[0] # type: ignore[assignment] 1719 1720 counters["inductor"]["dequant_promotion_matcher_count"] += 1 1721 counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) 1722 1723 1724def _is_valid_dequant_conv2d_pattern(dtype): 1725 def _inner(match): 1726 # Here we do some further check to ensure: 1727 # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. 1728 # 2. The dequant pattern has only 1 user of conv2d node. 1729 # If these conditions don't meet, we will not 1730 # insert weight prepack node into the matched pattern. 1731 conv_node = match.output_node() 1732 assert conv_node.target is aten.convolution.default 1733 input_meta_value = conv_node.args[0].meta.get("val") 1734 weight_meta_value = conv_node.args[1].meta.get("val") 1735 for meta_value in [input_meta_value, weight_meta_value]: 1736 if ( 1737 meta_value is None 1738 or meta_value.device.type != "cpu" 1739 or meta_value.dim() != 4 1740 ): 1741 # Only support conv2d now 1742 return False 1743 1744 assert dtype in [torch.float32, torch.bfloat16] 1745 1746 if dtype == torch.float32: 1747 dequant_node = conv_node.args[0] 1748 else: 1749 convert_to_bf16 = conv_node.args[0] 1750 dequant_node = convert_to_bf16.args[0] 1751 1752 if len(list(dequant_node.users)) != 1: 1753 # Ensure the dequant pattern only has 1 user 1754 # since we will delete the dequant pattern here 1755 return False 1756 return True 1757 1758 return _inner 1759 1760 1761def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): 1762 @register_freezing_graph_pattern( 1763 pattern, 1764 extra_check=_is_valid_dequant_conv2d_pattern(dtype), 1765 pass_number=pass_number, 1766 ) 1767 def qconv_weight_prepack(match: Match, *args, **kwargs): 1768 """ 1769 Match the pattern: 1770 int8 activation 1771 | 1772 dequant_per_tensor 1773 | 1774 Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight 1775 1776 Insert weight prepack node and change the pattern to: 1777 int8 activation 1778 | 1779 onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight 1780 """ 1781 assert dtype in [torch.float32, torch.bfloat16] 1782 conv_node = match.output_node() 1783 assert conv_node.target is aten.convolution.default 1784 if dtype == torch.float32: 1785 dequant_node = conv_node.args[0] 1786 else: 1787 convert_to_bf16 = conv_node.args[0] 1788 dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] 1789 has_clone_to_channel_last_node_in_pattern = ( 1790 conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] 1791 ) 1792 clone_node = ( 1793 conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None 1794 ) 1795 1796 if dtype == torch.float32: 1797 dequant_per_channel = ( 1798 clone_node.args[0] # type: ignore[union-attr] 1799 if has_clone_to_channel_last_node_in_pattern 1800 else conv_node.args[1] 1801 ) 1802 else: 1803 weight_to_bf16_node = ( 1804 clone_node.args[0] # type: ignore[union-attr] 1805 if has_clone_to_channel_last_node_in_pattern 1806 else conv_node.args[1] 1807 ) 1808 dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] 1809 1810 assert ( 1811 dequant_per_channel.target # type: ignore[union-attr] 1812 is quantized_decomposed.dequantize_per_channel.default 1813 ) 1814 1815 # Activation QParams 1816 qx, x_zp, x_scale = ( 1817 kwargs["x"], 1818 kwargs["x_zp"], 1819 kwargs["x_scale"], 1820 ) 1821 1822 # Weight QParams 1823 qw, w_scale, w_zp = ( 1824 kwargs["q_weight"], 1825 kwargs["w_scale"], 1826 kwargs["w_zp"], 1827 ) 1828 1829 # Conv Params 1830 bias, stride, padding, dilation, groups = ( 1831 kwargs["b"], 1832 kwargs["stride"], 1833 kwargs["padding"], 1834 kwargs["dilation"], 1835 kwargs["groups"], 1836 ) 1837 1838 x_shape = qx.meta.get("tensor_meta").shape 1839 if has_free_symbols(x_shape): 1840 # For dynamic shape case, we can't get activation shape ahead of runtime. 1841 x_shape = None 1842 graph = match.graph 1843 with graph.inserting_before(conv_node): 1844 # Insert weight prepack node and the QConv node 1845 packed_weight_inputs = ( 1846 qw, 1847 w_scale, 1848 x_scale, 1849 x_zp, 1850 stride, 1851 padding, 1852 dilation, 1853 groups, 1854 x_shape, 1855 ) 1856 packed_weight_op = torch.ops.onednn.qconv_prepack 1857 prepack_weight_node = graph.call_function( 1858 packed_weight_op, args=packed_weight_inputs 1859 ) 1860 1861 new_args: Tuple[Any, ...] = ( 1862 qx, 1863 x_scale, 1864 x_zp, 1865 prepack_weight_node, 1866 w_scale, 1867 w_zp, 1868 bias, 1869 stride, 1870 padding, 1871 dilation, 1872 groups, 1873 1.0, # output_scale 1874 0, # output_zero_point 1875 dtype, # output_dtype 1876 "none", # attr 1877 [], # scalars 1878 "", # algorithm 1879 ) 1880 new_conv_node = graph.call_function( 1881 torch.ops.onednn.qconv2d_pointwise.default, args=new_args 1882 ) 1883 conv_node.replace_all_uses_with(new_conv_node) 1884 new_conv_node.meta.update(conv_node.meta) 1885 1886 # Erase the original conv node 1887 graph.erase_node(conv_node) 1888 # Erase the dequant pattern 1889 if dtype == torch.bfloat16: 1890 graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] 1891 graph.erase_node(dequant_node) # type: ignore[arg-type] 1892 # Erase the dequant per channel pattern 1893 if clone_node is not None: 1894 graph.erase_node(clone_node) # type: ignore[arg-type] 1895 if dtype == torch.bfloat16: 1896 graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] 1897 graph.erase_node(dequant_per_channel) # type: ignore[arg-type] 1898 counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 1899 counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( 1900 match.nodes 1901 ) 1902 1903 1904def _generate_dequant_convolution_node_pattern( 1905 _dequant_per_channel_pattern, dtype=torch.float32 1906): 1907 assert dtype in [torch.float32, torch.bfloat16] 1908 dequant_convolution_node_pattern = CallFunction( 1909 aten.convolution.default, 1910 _may_generate_pattern_with_dtype_convert( 1911 get_dequantize_per_tensor_activation_pattern(), 1912 KeywordArg("autocast_act_dtype"), 1913 dtype == torch.bfloat16, 1914 ), 1915 _dequant_per_channel_pattern, 1916 KeywordArg("b"), 1917 KeywordArg("stride"), 1918 KeywordArg("padding"), 1919 KeywordArg("dilation"), 1920 KeywordArg("is_transposed"), 1921 KeywordArg("out_padding"), 1922 KeywordArg("groups"), 1923 ) 1924 return dequant_convolution_node_pattern 1925 1926 1927def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): 1928 assert dtype in [torch.float32, torch.bfloat16] 1929 return ( 1930 _generate_dequant_convolution_node_pattern( 1931 dequantize_per_channel_weight_pattern 1932 if dtype == torch.float32 1933 else dequantize_per_channel_to_bf16_weight_pattern, 1934 dtype, 1935 ), 1936 # There is another pattern due to the pass of convert_conv_weights_to_channels_last 1937 # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 1938 # Depend on some heuristics, it may or may not insert to(channel_last) node 1939 # between convolution and dequant_per_channel node 1940 _generate_dequant_convolution_node_pattern( 1941 dequantize_per_channel_clone_weight_pattern 1942 if dtype == torch.float32 1943 else dequantize_per_channel_to_bf16_clone_weight_pattern, 1944 dtype, 1945 ), 1946 ) 1947 1948 1949def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): 1950 output_reshape_node = None 1951 if input_dim_exceeds_two: 1952 if input_contiguous: 1953 output_reshape_node = match.output_node() 1954 assert output_reshape_node.target is aten.reshape.default 1955 linear_node = output_reshape_node.args[0] 1956 else: 1957 linear_nodes = filter_nodes(match.nodes, aten.bmm.default) 1958 assert len(linear_nodes) == 1 1959 linear_node = linear_nodes[0] 1960 else: 1961 linear_node = match.output_node() 1962 1963 assert linear_node.target in ( 1964 aten.addmm.default, 1965 aten.mm.default, 1966 aten.bmm.default, 1967 ) 1968 return linear_node, output_reshape_node 1969 1970 1971def _get_linear_dq_node( 1972 linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous 1973): 1974 act_reshape_node = None 1975 activation_to_bf16_node = None 1976 act_expand_node = None 1977 if input_dim_exceeds_two: 1978 if input_contiguous: 1979 act_reshape_node = linear_node.args[input_index] 1980 assert act_reshape_node.target is aten.reshape.default 1981 if dtype == torch.float32: 1982 # pattern: linear -> reshape -> dequant 1983 dequant_node = act_reshape_node.args[0] 1984 else: 1985 # pattern: linear -> reshape -> to_bf16 -> dequant 1986 activation_to_bf16_node = act_reshape_node.args[0] 1987 dequant_node = activation_to_bf16_node.args[0] 1988 else: 1989 # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous 1990 act_expand_node = linear_node.args[input_index] 1991 assert act_expand_node.target is aten.expand.default 1992 if dtype == torch.float32: 1993 dequant_node = act_expand_node.args[0] 1994 else: 1995 activation_to_bf16_node = act_expand_node.args[0] 1996 dequant_node = activation_to_bf16_node.args[0] 1997 else: 1998 if dtype == torch.float32: 1999 # pattern: linear -> dequant 2000 dequant_node = linear_node.args[input_index] 2001 else: 2002 # pattern: linear -> to_bf16 -> dequant 2003 activation_to_bf16_node = linear_node.args[input_index] 2004 dequant_node = activation_to_bf16_node.args[0] 2005 return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node 2006 2007 2008def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): 2009 def _inner(match): 2010 # Check dequant pattern has only 1 user. 2011 ( 2012 linear_node, 2013 _, 2014 ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) 2015 2016 input_index = 1 if linear_node.target is aten.addmm.default else 0 2017 assert dtype in [torch.float32, torch.bfloat16] 2018 ( 2019 dequant_node, 2020 _, 2021 _, 2022 _, 2023 ) = _get_linear_dq_node( 2024 linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous 2025 ) 2026 2027 assert dequant_node.target in [ 2028 quantized_decomposed.dequantize_per_tensor.default, 2029 quantized_decomposed.dequantize_per_tensor.tensor, 2030 ] 2031 2032 if len(list(dequant_node.users)) != 1: 2033 # Ensure the dequant pattern only has 1 user 2034 # since we will delete the dequant pattern here 2035 return False 2036 2037 # Extra check for bmm pattern 2038 if input_dim_exceeds_two and not input_contiguous: 2039 # Check for act 2040 # Act expand size should be exactly same as act size 2041 act_expand_size = match.kwargs["act_expand_size"] 2042 act_node = match.kwargs["x"] 2043 if not ( 2044 hasattr(act_node, "meta") 2045 and isinstance(act_node.meta.get("val", None), torch.Tensor) 2046 and (act_node.meta["val"].size() == torch.Size(act_expand_size)) 2047 ): 2048 return False 2049 2050 # Check for wgt 2051 # wgt permute dims should be [1, 0] 2052 wgt_permute_dims = match.kwargs["permute_axes"] 2053 if wgt_permute_dims != [1, 0]: 2054 return False 2055 2056 # Check below wgt size items: 2057 # wgt before expand should with dim 2 2058 # Expand size should with dim 3 2059 # Expand size[0] should same as act size[0] 2060 # Expand size[1] should same as wgt size[1] 2061 # Expand size[2] should same as wgt size[0] 2062 qweight_node = match.kwargs["q_weight"] 2063 wgt_expand_size = match.kwargs["wgt_expand_size"] 2064 if not ( 2065 hasattr(qweight_node, "meta") 2066 and isinstance(qweight_node.meta.get("val", None), torch.Tensor) 2067 and len(qweight_node.meta["val"].size()) == 2 2068 and len(wgt_expand_size) == 3 2069 and wgt_expand_size[0] == act_node.meta["val"].size()[0] 2070 and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] 2071 and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] 2072 ): 2073 return False 2074 2075 return True 2076 2077 return _inner 2078 2079 2080def _register_qlinear_weight_prepack_pass( 2081 pattern, 2082 pass_number, 2083 dtype=torch.float32, 2084 input_dim_exceeds_two=False, 2085 input_contiguous=True, 2086): 2087 @register_freezing_graph_pattern( 2088 pattern, 2089 extra_check=_is_valid_dequant_linear_pattern( 2090 dtype, input_dim_exceeds_two, input_contiguous 2091 ), 2092 pass_number=pass_number, 2093 ) 2094 def qlinear_weight_prepack(match: Match, *args, **kwargs): 2095 """ 2096 Match the pattern: 2097 int8 activation 2098 | 2099 dequant_per_tensor 2100 | 2101 mm/addmm <- t <- dequant_per_channel <- int8_weight 2102 2103 Insert weight prepack node and change the pattern to: 2104 int8 activation 2105 | 2106 onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight 2107 """ 2108 assert dtype in [torch.float32, torch.bfloat16] 2109 ( 2110 linear_node, 2111 output_reshape_node, 2112 ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) 2113 input_index = 1 if linear_node.target is aten.addmm.default else 0 2114 weight_index = input_index + 1 2115 2116 ( 2117 dequant_node, 2118 act_reshape_node, 2119 activation_to_bf16_node, 2120 act_expand_node, 2121 ) = _get_linear_dq_node( 2122 linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous 2123 ) 2124 2125 if input_dim_exceeds_two and not input_contiguous: 2126 wgt_expand_node = linear_node.args[weight_index] 2127 assert wgt_expand_node.target is aten.expand.default 2128 t_node = wgt_expand_node.args[0] 2129 else: 2130 t_node = linear_node.args[weight_index] 2131 2132 if dtype == torch.float32: 2133 dequant_per_channel = t_node.args[0] 2134 else: 2135 weight_to_bf16_node = t_node.args[0] 2136 dequant_per_channel = weight_to_bf16_node.args[0] 2137 assert ( 2138 dequant_per_channel.target 2139 is quantized_decomposed.dequantize_per_channel.default 2140 ) 2141 2142 # Activation QParams 2143 qx, x_zp, x_scale = ( 2144 kwargs["x"], 2145 kwargs["x_zp"], 2146 kwargs["x_scale"], 2147 ) 2148 2149 # Weight QParams 2150 qw, w_scale, w_zp = ( 2151 kwargs["q_weight"], 2152 kwargs["w_scale"], 2153 kwargs["w_zp"], 2154 ) 2155 2156 # Params 2157 bias = kwargs["b"] if "b" in kwargs else None 2158 2159 x_shape = qx.meta.get("tensor_meta").shape 2160 if has_free_symbols(x_shape): 2161 # For dynamic shape case, we can't get activation shape ahead of runtime. 2162 x_shape = None 2163 graph = match.graph 2164 with graph.inserting_before(linear_node): 2165 # Insert weight prepack node and the qlinear node 2166 packed_weight_inputs = ( 2167 qw, 2168 x_shape, 2169 ) 2170 packed_weight_op = torch.ops.onednn.qlinear_prepack 2171 prepack_weight_node = graph.call_function( 2172 packed_weight_op, args=packed_weight_inputs 2173 ) 2174 2175 new_args: Tuple[Any, ...] = ( 2176 qx, 2177 x_scale, 2178 x_zp, 2179 prepack_weight_node, 2180 w_scale, 2181 w_zp, 2182 bias, 2183 1.0, # output_scale 2184 0, # output_zero_point 2185 dtype, # output_dtype 2186 "none", # post op name 2187 [], # post op args 2188 "", # post op algorithm 2189 ) 2190 Node = torch.fx.node.Node 2191 if isinstance(x_scale, Node) and isinstance(x_zp, Node): 2192 new_linear_node = graph.call_function( 2193 torch.ops.onednn.qlinear_pointwise.tensor, args=new_args 2194 ) 2195 else: 2196 new_linear_node = graph.call_function( 2197 torch.ops.onednn.qlinear_pointwise.default, args=new_args 2198 ) 2199 if input_dim_exceeds_two: 2200 if input_contiguous: 2201 output_reshape_node.replace_all_uses_with(new_linear_node) 2202 new_linear_node.meta.update(output_reshape_node.meta) 2203 else: 2204 if bias: 2205 output_add_node_for_bias = match.output_node() 2206 assert output_add_node_for_bias.target is aten.add.Tensor 2207 output_add_node_for_bias.replace_all_uses_with(new_linear_node) 2208 new_linear_node.meta.update(output_add_node_for_bias.meta) 2209 else: 2210 linear_node.replace_all_uses_with(new_linear_node) 2211 new_linear_node.meta.update(linear_node.meta) 2212 else: 2213 linear_node.replace_all_uses_with(new_linear_node) 2214 new_linear_node.meta.update(linear_node.meta) 2215 2216 # Erase the original linear node 2217 if input_dim_exceeds_two: 2218 if input_contiguous: 2219 graph.erase_node(output_reshape_node) 2220 elif not input_contiguous and bias: 2221 graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] 2222 graph.erase_node(linear_node) 2223 if input_dim_exceeds_two: 2224 if input_contiguous: 2225 graph.erase_node(act_reshape_node) 2226 else: 2227 graph.erase_node(act_expand_node) 2228 graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] 2229 if dtype == torch.bfloat16: 2230 graph.erase_node(activation_to_bf16_node) 2231 # Erase the dequant pattern 2232 graph.erase_node(dequant_node) 2233 # Erase the dequant per channel pattern 2234 graph.erase_node(t_node) 2235 if dtype == torch.bfloat16: 2236 graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] 2237 graph.erase_node(dequant_per_channel) 2238 2239 counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 2240 counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( 2241 match.nodes 2242 ) 2243 2244 2245def _generate_dequant_linear_node_pattern( 2246 _dequant_per_channel_pattern, 2247 dtype=torch.float32, 2248 input_dim_exceeds_two=False, 2249 is_tensor_overload=False, 2250): 2251 assert dtype in [torch.float32, torch.bfloat16] 2252 t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) 2253 dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( 2254 CallFunction( 2255 aten.addmm.default, 2256 KeywordArg("b"), 2257 _may_generate_pattern_with_reshape( 2258 _may_generate_pattern_with_dtype_convert( 2259 get_dequantize_per_tensor_activation_pattern(is_tensor_overload), 2260 KeywordArg("autocast_act_dtype"), 2261 dtype == torch.bfloat16, 2262 ), 2263 KeywordArg("act_reshape_size"), 2264 input_dim_exceeds_two, 2265 ), 2266 t_pattern, 2267 ), 2268 KeywordArg("output_reshape_size"), 2269 input_dim_exceeds_two, 2270 ) 2271 dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( 2272 CallFunction( 2273 aten.mm.default, 2274 _may_generate_pattern_with_reshape( 2275 _may_generate_pattern_with_dtype_convert( 2276 get_dequantize_per_tensor_activation_pattern(is_tensor_overload), 2277 KeywordArg("autocast_act_dtype"), 2278 dtype == torch.bfloat16, 2279 ), 2280 KeywordArg("act_reshape_size"), 2281 input_dim_exceeds_two, 2282 ), 2283 t_pattern, 2284 ), 2285 KeywordArg("output_reshape_size"), 2286 input_dim_exceeds_two, 2287 ) 2288 return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern 2289 2290 2291def _generate_dequant_bmm_node_pattern( 2292 _dequant_per_channel_pattern, 2293 dtype=torch.float32, 2294 with_bias=False, 2295 is_tensor_overload=False, 2296): 2297 # When activation of linear dim exceed 2 and not contiguous 2298 t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) 2299 2300 assert dtype in [torch.float32, torch.bfloat16] 2301 dequant_bmm_pattern = CallFunction( 2302 aten.bmm.default, 2303 CallFunction( 2304 aten.expand.default, 2305 _may_generate_pattern_with_dtype_convert( 2306 get_dequantize_per_tensor_activation_pattern(is_tensor_overload), 2307 KeywordArg("autocast_act_dtype"), 2308 dtype == torch.bfloat16, 2309 ), 2310 KeywordArg("act_expand_size"), 2311 ), 2312 CallFunction( 2313 aten.expand.default, 2314 t_pattern, 2315 KeywordArg("wgt_expand_size"), 2316 ), 2317 ) 2318 2319 def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): 2320 if _with_bias: 2321 return CallFunction( 2322 aten.add.Tensor, 2323 _dequant_bmm_pattern, 2324 KeywordArg("b"), 2325 ) 2326 else: 2327 return _dequant_bmm_pattern 2328 2329 return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) 2330 2331 2332def _generate_qlinear_weight_prepack_patterns( 2333 dtype=torch.float32, 2334 input_dim_exceeds_two=False, 2335 input_contiguous=True, 2336 with_bias=False, 2337 is_tensor_overload=False, 2338): 2339 if input_dim_exceeds_two and not input_contiguous: 2340 return _generate_dequant_bmm_node_pattern( 2341 dequantize_per_channel_weight_pattern, 2342 dtype, 2343 with_bias, 2344 is_tensor_overload, 2345 ) 2346 else: 2347 return _generate_dequant_linear_node_pattern( 2348 dequantize_per_channel_weight_pattern, 2349 dtype, 2350 input_dim_exceeds_two, 2351 is_tensor_overload, 2352 ) 2353 2354 2355def _register_dequant_promotion(): 2356 dequant_pattern_cases = itertools.product( 2357 [torch.float32, torch.bfloat16], [True, False], [True, False] 2358 ) 2359 for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: 2360 # 4 dequantization patterns will be matched based on the dtype and input dimension size. 2361 # Case 1: int8-mixed-fp32, input dim size is 2 2362 # Case 2: int8-mixed-fp32, input dim size exceeds 2 2363 # Case 3: int8-mixed-bf16, input dim size is 2 2364 # Case 4: int8-mixed-bf16, input dim size exceeds 2 2365 # quant 2366 # + - - - - | - - - - + 2367 # | dequant | 2368 # | | | 2369 # | OPT(to_bf16) | 2370 # | | | 2371 # | OPT(reshape) | 2372 # | / \ | 2373 # | node1 node2 | 2374 # + - - | - - - | - - + 2375 # OPT(reshape) OPT(reshape) 2376 # + - - | - - - | - - + 2377 # OPT(to_fp32) OPT(to_fp32) 2378 # + - - | - - - | - - + 2379 # quant quant 2380 _register_dequant_promotion_pass( 2381 _may_generate_pattern_with_reshape( 2382 _may_generate_pattern_with_dtype_convert( 2383 get_dequantize_per_tensor_activation_pattern( 2384 is_tensor_overload=is_tensor_overload 2385 ), 2386 KeywordArg("autocast_act_dtype"), 2387 dtype == torch.bfloat16, 2388 ), 2389 KeywordArg("act_reshape_size"), 2390 with_reshape=input_dim_exceeds_two, 2391 ), 2392 pass_number=0, 2393 dtype=dtype, 2394 ) # pass_number=0 to run before weight prepack 2395 2396 2397def _register_qconv_weight_prepack(): 2398 for dtype in [torch.float32, torch.bfloat16]: 2399 weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) 2400 for weight_prepack_pattern in weight_prepack_patterns: 2401 # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 2402 _register_qconv_weight_prepack_pass( 2403 weight_prepack_pattern, pass_number=1, dtype=dtype 2404 ) 2405 2406 2407def _register_qlinear_weight_prepack(): 2408 # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. 2409 # Then convert the pattern into a QLinear node with int8_fp32/bf16. 2410 # Case 1: int8-mixed-fp32, input dim size is 2 2411 # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous 2412 # Case 3: int8-mixed-bf16, input dim size is 2 2413 # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous 2414 2415 # + - - - - | - - - - - - | - - - - - + 2416 # | dq_per_tensor dq_per_channel | 2417 # | | | | 2418 # | OPT(to_bf16) OPT(to_bf16) | 2419 # | | | | 2420 # | OPT(reshape) permute | 2421 # | \ / | 2422 # | addmm/mm | 2423 # | | | 2424 # | OPT(reshape) | 2425 2426 # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous 2427 # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous 2428 2429 # + - - - - | - - - - - - | - - - - - + 2430 # | dq_per_tensor dq_per_channel | 2431 # | | | | 2432 # | OPT(to_bf16) OPT(to_bf16) | 2433 # | | | | 2434 # | expand permute | 2435 # | \ | | 2436 # | expand | 2437 # | / | 2438 # | bmm | 2439 # | | | 2440 # | OPT(add) | 2441 2442 linear_weight_prepack_cases = itertools.product( 2443 [torch.float32, torch.bfloat16], [True, False], [True, False] 2444 ) 2445 2446 # Step 1: register patterns from mm and addmm 2447 for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: 2448 weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( 2449 dtype, 2450 input_dim_exceeds_two, 2451 is_tensor_overload=is_tensor_overload, 2452 ) 2453 for weight_prepack_pattern in weight_prepack_patterns: 2454 # Register to pass_number 1, so we can do dequant promotion in pass_number 0. 2455 _register_qlinear_weight_prepack_pass( 2456 weight_prepack_pattern, 2457 pass_number=1, 2458 dtype=dtype, 2459 input_dim_exceeds_two=input_dim_exceeds_two, 2460 ) 2461 2462 # Step 2: register patterns from bmm 2463 # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous 2464 # refer to: 2465 # https://github.com/pytorch/pytorch/blob/ 2466 # 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 2467 # in this case, we can convert it back to qlinear 2468 for dtype, with_bias, is_tensor_overload in itertools.product( 2469 [torch.float32, torch.bfloat16], [True, False], [True, False] 2470 ): 2471 bmm_pattern = _generate_qlinear_weight_prepack_patterns( 2472 dtype=dtype, 2473 input_dim_exceeds_two=True, 2474 input_contiguous=False, 2475 with_bias=with_bias, 2476 is_tensor_overload=is_tensor_overload, 2477 ) 2478 _register_qlinear_weight_prepack_pass( 2479 bmm_pattern, 2480 pass_number=1 2481 if with_bias 2482 else 2, # if with_bias, there is an output add, so we should try to match it firstly 2483 dtype=dtype, 2484 input_dim_exceeds_two=True, 2485 input_contiguous=False, 2486 ) 2487 2488 2489@functools.lru_cache(None) 2490def _register_quantization_weight_pack_pass(): 2491 # Step 1: Dequant promotion for int8-mixed-fp32/bf16 2492 _register_dequant_promotion() 2493 2494 # Step 2: QConv weight prepack 2495 _register_qconv_weight_prepack() 2496 2497 # Step 3: QLinear weight prepack 2498 _register_qlinear_weight_prepack() 2499 2500 2501def quant_lift_up(graph_module: torch.fx.GraphModule): 2502 """ 2503 Lift up the quant node before view like nodes. It can benefit performance 2504 of Attention like block. For example, we have the pattern as: 2505 2506 DQ 2507 DQ LINEAR 2508 LINEAR VIEW 2509 VIEW PERMUTE 2510 PERMUTE TRANSPOSE 2511 Q Q 2512 DQ DQ 2513 Matmul 2514 DIV 2515 ADD 2516 SOFTMAX 2517 2518 We want to lift up the the quant nodes from matmul before view like nodes 2519 as the output of Linear node. 2520 2521 DQ 2522 DQ LINEAR 2523 LINEAR Q 2524 Q VIEW 2525 VIEW PERMUTE 2526 PERMUTE TRANSPOSE 2527 DQ DQ 2528 Matmul 2529 DIV 2530 ADD 2531 SOFTMAX 2532 2533 It produces a DQ->LINEAR->Q pattern which can be fused by backend. 2534 """ 2535 2536 def is_view_op(node): 2537 return node.op == "call_function" and node.target in _VIEW_OPS 2538 2539 for node in graph_module.graph.nodes: 2540 # <TODO> Leslie: Here we verify that the quant node has exactly 2541 # one input FX node, with constant scalar value for scale and zero point. 2542 # For the case input of quant node has more than one input FX nodes, 2543 # extend the implementation to lift up all the connected nodes 2544 # before the view nodes to keep the topological order. 2545 if ( 2546 node.op == "call_function" 2547 and node.target in _PER_TENSOR_QUANTIZE_OPS 2548 and len(node.all_input_nodes) == 1 2549 and is_view_op(node.all_input_nodes[0]) 2550 ): 2551 quant_node = node 2552 input_node_of_quant = quant_node.args[0] 2553 2554 # Check the nodes along lift up path has only 1 user node 2555 # Propagate view like node to find where to insert the new quant node 2556 could_lift_up = True 2557 current_node = quant_node 2558 input_node = current_node.args[0] 2559 while is_view_op(input_node): 2560 if len(input_node.users) != 1: 2561 could_lift_up = False 2562 break 2563 current_node = input_node 2564 input_node = current_node.args[0] 2565 2566 # Further check the input node of the first view node has only 1 user node 2567 if could_lift_up and len(input_node.users) == 1: 2568 # Replace dequant's input from quant to quant's input 2569 quant_node.replace_all_uses_with(input_node_of_quant) 2570 # Insert the new quant node 2571 with graph_module.graph.inserting_before(current_node): 2572 new_quant_node = graph_module.graph.node_copy(quant_node) 2573 input_node.replace_all_uses_with(new_quant_node) 2574 2575 # Update inputs of new_quant_node 2576 def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: 2577 if n == input_node_of_quant: 2578 return input_node 2579 else: 2580 return n 2581 2582 new_args = map_arg(new_quant_node.args, maybe_replace_node) 2583 new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) 2584 new_quant_node.args = new_args # type: ignore[assignment] 2585 new_quant_node.kwargs = new_kwargs # type: ignore[assignment] 2586 graph_module.graph.erase_node(quant_node) 2587 2588 graph_module.graph.lint() 2589 graph_module.recompile() 2590