# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import List, Optional

import torch
import torch.utils._pytree as pytree
from torch._inductor.kernel.mm_common import mm_args

from . import ir
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
    add,
    add_needs_realized_inputs,
    aten,
    permute,
    register_lowering,
    to_dtype,
    view,
)
from .select_algorithm import (
    autotune_select_algorithm,
    ChoiceCaller,
    ExternKernelChoice,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune
from .virtualized import ops, V


def register_onednn_fusion_ops():
    if torch._C._has_mkldnn:
        from . import mkldnn_ir

        aten_mkldnn_linear_unary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearUnary.create,
        )
        aten_mkldnn_linear_binary = ExternKernelChoice(
            torch.ops.mkldnn._linear_pointwise.binary,
            "mkldnn::_linear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.LinearBinary.create,
        )
        aten_mkldnn_qlinear_unary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create,
        )
        aten_mkldnn_qlinear_binary = ExternKernelChoice(
            torch.ops.onednn.qlinear_pointwise.binary,
            "onednn::qlinear_pointwise",
            has_out_variant=False,
            kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create,
        )
        cpu_needs_realized_inputs = [
            torch.ops.mkldnn._convolution_pointwise,
            torch.ops.mkldnn._convolution_pointwise_,
            torch.ops.mkldnn._convolution_transpose_pointwise,
            torch.ops.mkldnn._linear_pointwise,
            aten.mkldnn_rnn_layer.default,
            torch.ops.onednn.qconv2d_pointwise,
        ]

        @register_lowering(torch.ops.mkldnn._convolution_pointwise)
        def convolution_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary)
        def convolution_binary(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinary.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary)
        def convolution_binary_inplace(
            x: TensorBox,
            other: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            stride,
            dilation,
            groups,
            binary_attr,
            binary_alpha,
            unary_attr,
            unary_scalars,
            unary_algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionBinaryInplace.create(
                    x,
                    other,
                    weight,
                    bias,
                    padding,
                    stride,
                    dilation,
                    groups,
                    binary_attr,
                    binary_alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithm,
                )
            )

        @register_lowering(torch.ops.mkldnn._linear_pointwise)
        def linear_unary(
            x: TensorBox,
            w: TensorBox,
            b: TensorBox,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
                if use_cpp_packed_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(
                            buf, attr, scalars=scalars, algorithm=algorithm
                        )

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=None if attr == "none" else epilogue_creator,
                    )
                    if b is not None:
                        kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, w] if b is None else [x, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_unary.bind(
                        [x, w] if b is None else [x, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                1: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_unary",
                choices,
                [x, w] if b is None else [x, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._linear_pointwise.binary)
        def linear_binary(
            x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            y_size = y.get_size()
            if len(y_size) > 2:
                y = view(y, [-1, y_size[-1]])
            if b is not None:
                b = ir.ExternKernel.realize_input(b)
            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                transposed_w = permute(w, [1, 0])
                *_, layout, x, transposed_w, y = mm_args(
                    x, transposed_w, y, layout=layout
                )
                if use_cpp_packed_gemm_template(layout, x, transposed_w):

                    def epilogue_creator(buf):
                        return create_epilogue_with_attr(buf, attr, other=y)

                    kwargs = dict(
                        has_bias=b is not None,
                        trans_w=True,
                        epilogue_creator=epilogue_creator,
                    )
                    kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, y, w] if b is None else [x, y, w, b],
                        **kwargs,  # type: ignore[arg-type]
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(attr=attr)
                if b is None:
                    kwargs["B"] = None
                choices.append(
                    aten_mkldnn_linear_binary.bind(
                        [x, y, w] if b is None else [x, y, w, b],
                        layout,
                        **kwargs,
                    )
                )
            assert w.get_name() in V.graph.constants
            input_gen_fns = {
                2: lambda x: V.graph.constants[x.get_name()],
            }
            result = autotune_select_algorithm(
                "linear_binary",
                choices,
                [x, y, w] if b is None else [x, y, w, b],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
        def convolution_transpose_unary(
            x: TensorBox,
            weight: TensorBox,
            bias: TensorBox,
            padding,
            output_padding,
            stride,
            dilation,
            groups,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.ConvolutionTransposeUnary.create(
                    x,
                    weight,
                    bias,
                    padding,
                    output_padding,
                    stride,
                    dilation,
                    groups,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(aten.mkldnn_rnn_layer.default)
        def mkldnn_rnn_layer(
            x: TensorBox,
            w0: TensorBox,
            w1: TensorBox,
            w2: TensorBox,
            w3: TensorBox,
            hx: TensorBox,
            cx: TensorBox,
            reverse: bool,
            batch_sizes: List[int],
            mode: int,
            hidden_size: int,
            num_layers: int,
            has_biases: bool,
            bidirectional: bool,
            batch_first: bool,
            train: bool,
        ):
            return pytree.tree_map(
                TensorBox.create,
                mkldnn_ir.MkldnnRnnLayer.create(
                    x,
                    w0,
                    w1,
                    w2,
                    w3,
                    hx,
                    cx,
                    reverse,
                    batch_sizes,
                    mode,
                    hidden_size,
                    num_layers,
                    has_biases,
                    bidirectional,
                    batch_first,
                    train,
                ),
            )

        @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None)
        def qconvolution_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
        ):
            return TensorBox.create(
                mkldnn_ir.QConvPointWisePT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    attr,
                    scalars,
                    algorithm,
                )
            )

        @register_lowering(
            torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None
        )
        def qconvolution_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            accum: TensorBox,
            accum_scale,
            accum_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
        ):
            if (
                binary_attr == "sum"
                and output_dtype in [torch.float32, torch.bfloat16]
                and accum.get_dtype() in [torch.float32, torch.bfloat16]
                and accum.get_dtype() != output_dtype
            ):
                # For int8-mixed-bf16 quantization and in-place add,
                # there is a case where the accum dtype is float32 but the output dtype is bfloat16.
                # Since the accum is changed in place by the post-op sum,
                # we do the accum dtype conversion here.
                accum = to_dtype(accum, output_dtype)
            return TensorBox.create(
                mkldnn_ir.QConvPointWiseBinaryPT2E.create(
                    x,
                    x_scale,
                    x_zp,
                    accum,
                    accum_scale,
                    accum_zp,
                    packed_weight,
                    w_scale,
                    w_zp,
                    bias,
                    stride,
                    padding,
                    dilation,
                    groups,
                    o_inv_scale,
                    o_zero_point,
                    output_dtype,
                    binary_attr,
                    alpha,
                    unary_attr,
                    unary_scalars,
                    unary_algorithmm,
                )
            )

        @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None)
        def qlinear_unary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            attr,
            scalars,
            algorithm,
            layout=None,
        ):
            x_size = x.get_size()
            if len(x_size) > 2:
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When there are fewer than 8 channels, w_scale/w_zp is a Pointwise node instead of a ConstantBuffer
            # Refer to https://github.com/pytorch/pytorch/blob
            # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                # W_zp might be a ConstantBuffer with int64, convert it to int32
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )

            bias_dtype = None if bias is None else bias.get_dtype()

            choices: List[ChoiceCaller] = []
            if use_max_autotune():
                *_, layout, x, packed_weight = mm_args(
                    x, packed_weight, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
                    and use_cpp_packed_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()].to_dense()
                    weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
                    weight_compens = V.graph.add_tensor_constant(
                        weight_compens_tensor,
                        name=packed_weight.get_name() + "_BMatrixCompens",
                    )

                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                        ]
                        input_loader = input_buffer.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
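                        # The inner_fn below dequantizes the int32 micro-kernel accumulator.
                        # With a per-tensor activation zero point x_zp and a weight zero
                        # point of 0 (checked above), the fp32 value for output column n is
                        #   y[n] = x_scale * w_scale[n] * (acc_s32[n] - x_zp * sum_k W[k, n])
                        # where sum_k W[k, n] is the precomputed weight_compens constant.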
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            # The micro-kernel output is int32;
                            # convert to FP32 before applying the compensation
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)
                            _x_scale = x_scale_loader(())
                            _x_zp = x_zp_loader(())
                            _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compo = weight_compens_loader(weight_compens_index)
                            # Step 1: Apply the compensation to convert to fp32
                            temp = ops.mul(
                                ops.mul(
                                    input,
                                    _x_scale,
                                ),
                                _w_scale,
                            )
                            temp = ops.sub(
                                temp,
                                ops.mul(
                                    ops.mul(
                                        ops.mul(
                                            _x_scale,
                                            _w_scale,
                                        ),
                                        _x_zp,
                                    ),
                                    _weight_compo,
                                ),
                            )
                            # Step 2: Add the bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 3: Apply the unary post-op fusion
                        if attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf, attr, scalars=scalars, algorithm=algorithm
                            )

                        # Step 4: Cast the output to the target dtype
                        if output_dtype == torch.bfloat16:
                            output_cast_loader = output_buf.make_loader()

                            def inner_fn_cast_output_to_bf16(index):
                                input = output_cast_loader(index)
                                return ops.to_dtype(input, output_dtype)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=output_dtype,
                                inner_fn=inner_fn_cast_output_to_bf16,
                                ranges=output_buf.get_size(),
                            )
                        elif output_dtype == torch.uint8:
                            from .lowering import _create_constants

                            requant_input_loader = output_buf.make_loader()

                            def inner_fn_requant(index, scale, zero_point):
                                input = requant_input_loader(index)
                                inv_scale, zero_point = _create_constants(
                                    1.0 / scale, zero_point, dtype=torch.float32
                                )
                                val = ops.round(input * inv_scale) + zero_point
                                qmin, qmax = _create_constants(
                                    0, 255, dtype=torch.float32
                                )
                                clamped = ops.minimum(ops.maximum(val, qmin), qmax)
                                return ops.to_dtype(clamped, torch.uint8)

                            output_buf = ir.Pointwise(
                                device=output_buf.get_device(),
                                dtype=output_dtype,
                                inner_fn=functools.partial(
                                    inner_fn_requant,
                                    scale=float(o_scale),
                                    zero_point=int(o_zero_point),
                                ),
                                ranges=output_buf.get_size(),
                            )

                        return output_buf

                    assert x.get_dtype() == torch.uint8
                    CppPackedGemmTemplate.add_choices(
                        choices,
                        layout,
                        [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                        if bias is None
                        else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                        has_bias=bias is not None,
                        epilogue_creator=epilogue_creator,
                        input_indices=[0, 3, 1, 2, 4, 5]
                        if bias is None
                        else [6, 0, 3, 1, 2, 4, 5],
                    )
            if len(choices) == 0 or use_aten_gemm_kernels():
                kwargs = dict(
                    output_scale=o_scale,
                    output_zero_point=o_zero_point,
                    output_dtype=output_dtype,
                    post_op_name=attr,
                    post_op_args=scalars,
                    post_op_algorithm=algorithm,
                )
                if bias is None:
                    kwargs["bias"] = None
                choices.append(
                    aten_mkldnn_qlinear_unary.bind(
                        (x, x_scale, x_zp, packed_weight, w_scale, w_zp)
                        if bias is None
                        else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias),
                        layout,
                        **kwargs,
                    )
                )
            assert packed_weight.get_name() in V.graph.constants
            input_gen_fns = {
                3: lambda x: V.graph.constants[x.get_name()],
                4: lambda x: V.graph.constants[x.get_name()],
                5: lambda x: V.graph.constants[x.get_name()],
                6: lambda x: V.graph.constants[x.get_name()],  # For bias
            }
            result = autotune_select_algorithm(
                "qlinear_unary",
                choices,
                [x, x_scale, x_zp, packed_weight, w_scale, w_zp]
                if bias is None
                else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias],
                layout,
                input_gen_fns=input_gen_fns,
            )
            if len(x_size) > 2:
                result = view(result, (*x_size[:-1], result.get_size()[-1]))
            return result

        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None
        )
        @register_lowering(
            torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None
        )
        def qlinear_binary(
            x: TensorBox,
            x_scale,
            x_zp,
            packed_weight: TensorBox,
            w_scale: TensorBox,
            w_zp: TensorBox,
            x2: TensorBox,
            bias: TensorBox,
            o_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_attr,
            alpha,
            unary_attr,
            unary_scalars,
            unary_algorithmm,
            layout=None,
        ):
            x_size = x.get_size()
            x2_size = x2.get_size()
            assert len(x_size) == len(x2_size)
            if len(x_size) > 2 and binary_attr == "add":
                # GEMM template needs 2D input, normalize input shape here
                x = view(x, [-1, x_size[-1]])
                x2 = view(x2, [-1, x2_size[-1]])
            if not isinstance(x_scale, ir.TensorBox):
                assert type(x_scale) == float
                x_scale = V.graph.add_tensor_constant(
                    torch.tensor(x_scale, dtype=torch.float32), name="x_scale"
                )
            else:
                x_scale.realize()
            if not isinstance(x_zp, ir.TensorBox):
                assert type(x_zp) == int
                x_zp = V.graph.add_tensor_constant(
                    torch.tensor(x_zp, dtype=torch.int32), name="x_zp"
                )
            else:
                x_zp.realize()

            # When there are fewer than 8 channels, w_scale/w_zp is a Pointwise node instead of a ConstantBuffer
            # Refer to https://github.com/pytorch/pytorch/blob
            # /f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577
            w_scale.realize()
            w_zp.realize()
            if w_zp.get_dtype() != torch.int32 and isinstance(
                ir.InputsKernel.unwrap_storage_for_input(w_zp),
                ir.ConstantBuffer,
            ):
                w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32)
                w_zp = V.graph.add_tensor_constant(
                    torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name()
                )
            if binary_attr == "sum":
                if output_dtype in [
                    torch.float32,
                    torch.bfloat16,
                ] and x2.get_dtype() in [torch.float32, torch.bfloat16]:
                    if x2.get_dtype() != output_dtype:
                        # For int8-mixed-bf16 quantization and in-place add,
                        # there is a case where the accum dtype is float32 but the output dtype is bfloat16.
                        # Since the accum is changed in place by the post-op sum,
                        # we do the accum dtype conversion here.
                        x2 = to_dtype(x2, output_dtype)
                else:
                    assert (
                        x2.get_dtype() == output_dtype
                    ), "dtype of accum for qlinear post op sum should be the same as output"
            x2_dtype = x2.get_dtype()
            bias_dtype = bias.get_dtype() if bias is not None else None
            choices: List[ChoiceCaller] = []
            if (
                use_max_autotune() and binary_attr == "add"
            ):  # <TODO> Support inplace sum fusion
                *_, layout, x, packed_weight, x2 = mm_args(
                    x, packed_weight, x2, layout=layout, out_dtype=output_dtype
                )
                if (
                    isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(x_zp),
                        ir.ConstantBuffer,
                    )
                    and len(x_zp.get_layout().size) == 0  # Per tensor quant of act
                    and isinstance(
                        ir.InputsKernel.unwrap_storage_for_input(w_zp),
                        ir.ConstantBuffer,
                    )
                    and torch.equal(
                        torch.zeros_like(V.graph.constants[w_zp.get_name()]),
                        V.graph.constants[w_zp.get_name()],
                    )  # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA
                    and use_cpp_packed_gemm_template(layout, x, packed_weight)
                ):
                    W_tensor = V.graph.constants[packed_weight.get_name()]
                    W_tensor = W_tensor.to_dense()
                    weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0)
                    weight_compens = V.graph.add_tensor_constant(
                        weight_compens_tensor,
                        name=packed_weight.get_name() + "_BMatrixCompens",
                    )

                    def epilogue_creator(input_buffer):
                        # Epilogue to convert from s32 to f32 for u8s8f32
                        assert output_dtype in [
                            torch.float32,
                            torch.bfloat16,
                            torch.uint8,
                        ]

                        input_loader = input_buffer.make_loader()
                        x2_loader = x2.make_loader()
                        weight_compens_loader = weight_compens.make_loader()
                        x_scale_loader = x_scale.make_loader()
                        w_scale_loader = w_scale.make_loader()
                        x_zp_loader = x_zp.make_loader()
                        nonlocal bias
                        bias_loader = None
                        if bias is not None:
                            bias_loader = bias.make_loader()

                        def inner_fn(index):
                            nonlocal bias
                            input = input_loader(index)
                            _x2 = x2_loader(index)
                            _x_scale = x_scale_loader(())
                            _x_zp = x_zp_loader(())

                            # The micro-kernel output is int32;
                            # convert to FP32 before applying the compensation
                            input = ops.to_dtype(input, torch.float32)
                            weight_compens_index = (index[-1],)
                            _w_scale = w_scale_loader(weight_compens_index)
                            _weight_compens = weight_compens_loader(
                                weight_compens_index
                            )
                            # Step 1: Apply the compensation to convert to fp32
                            temp = ops.mul(
                                ops.mul(
                                    input,
                                    _x_scale,
                                ),
                                _w_scale,
                            )
                            temp = ops.sub(
                                temp,
                                ops.mul(
                                    ops.mul(
                                        ops.mul(
                                            _x_scale,
                                            _w_scale,
                                        ),
                                        _x_zp,
                                    ),
                                    _weight_compens,
                                ),
                            )

                            # Step 2: Add the bias if applicable
                            if bias is not None:
                                _bias = bias_loader(weight_compens_index)
                                nonlocal bias_dtype
                                assert bias_dtype in [torch.float32, torch.bfloat16]
                                if bias_dtype == torch.bfloat16:
                                    _bias = ops.to_dtype(_bias, torch.float32)
                                temp = ops.add(temp, _bias)

                            # Step 3: Binary add
                            nonlocal x2_dtype
                            assert x2_dtype in [torch.float32, torch.bfloat16]
                            if x2_dtype == torch.bfloat16:
                                _x2 = ops.to_dtype(_x2, torch.float32)
                            temp = ops.add(temp, _x2)

                            return temp

                        output_buf = ir.Pointwise(
                            device=input_buffer.get_device(),
                            dtype=torch.float32,  # Hardcode to FP32 for u8s8f32
                            inner_fn=inner_fn,
                            ranges=input_buffer.get_size(),
                        )

                        # Step 4: Apply the unary post-op if there is one
                        if unary_attr != "none":
                            output_buf = create_epilogue_with_attr(
                                output_buf,
917 unary_attr, 918 scalars=unary_scalars, 919 algorithm=unary_algorithmm, 920 ) 921 922 # Step 5: Cast output to Target Dtype 923 if output_dtype == torch.bfloat16: 924 output_cast_loader = output_buf.make_loader() 925 926 def inner_fn_cast_output_to_bf16(index): 927 input = output_cast_loader(index) 928 return ops.to_dtype(input, output_dtype) 929 930 output_buf = ir.Pointwise( 931 device=output_buf.get_device(), 932 dtype=output_dtype, 933 inner_fn=inner_fn_cast_output_to_bf16, 934 ranges=output_buf.get_size(), 935 ) 936 elif output_dtype == torch.uint8: 937 from .lowering import _create_constants 938 939 requant_input_loader = output_buf.make_loader() 940 941 def inner_fn_requant(index, scale, zero_point): 942 input = requant_input_loader(index) 943 inv_scale, zero_point = _create_constants( 944 1.0 / scale, zero_point, dtype=torch.float32 945 ) 946 val = ops.round(input * inv_scale) + zero_point 947 qmin, qmax = _create_constants( 948 0, 255, dtype=torch.float32 949 ) 950 clamped = ops.minimum(ops.maximum(val, qmin), qmax) 951 return ops.to_dtype(clamped, torch.uint8) 952 953 output_buf = ir.Pointwise( 954 device=output_buf.get_device(), 955 dtype=torch.uint8, 956 inner_fn=functools.partial( 957 inner_fn_requant, 958 scale=float(o_scale), 959 zero_point=int(o_zero_point), 960 ), 961 ranges=output_buf.get_size(), 962 ) 963 964 return output_buf 965 966 CppPackedGemmTemplate.add_choices( 967 choices, 968 layout, 969 [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] 970 if bias is None 971 else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], 972 has_bias=bias is not None, 973 epilogue_creator=epilogue_creator, 974 # Reorder bias and x2 975 input_indices=[0, 3, 1, 2, 4, 5, 6] 976 if bias is None 977 else [7, 0, 3, 1, 2, 4, 5, 6], 978 ) 979 980 if len(choices) == 0 or use_aten_gemm_kernels(): 981 kwargs = dict( 982 output_scale=o_scale, 983 output_zero_point=o_zero_point, 984 output_dtype=output_dtype, 985 other_scale=x2_scale, 986 other_zp=x2_zp, 987 binary_post_op=binary_attr, 988 binary_alpha=alpha, 989 unary_post_op=unary_attr, 990 unary_post_op_args=unary_scalars, 991 unary_post_op_algorithm=unary_algorithmm, 992 ) 993 if bias is None: 994 kwargs["bias"] = None 995 choices.append( 996 aten_mkldnn_qlinear_binary.bind( 997 (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2) 998 if bias is None 999 else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias), 1000 layout, 1001 **kwargs, 1002 ) 1003 ) 1004 assert packed_weight.get_name() in V.graph.constants 1005 input_gen_fns = { 1006 3: lambda x: V.graph.constants[x.get_name()], 1007 4: lambda x: V.graph.constants[x.get_name()], 1008 5: lambda x: V.graph.constants[x.get_name()], 1009 } 1010 if bias is not None: 1011 input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias 1012 result = autotune_select_algorithm( 1013 "qlinear_binary", 1014 choices, 1015 [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] 1016 if bias is None 1017 else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], 1018 layout, 1019 input_gen_fns=input_gen_fns, 1020 ) 1021 if len(x_size) > 2 and binary_attr == "add": 1022 result = view(result, (*x_size[:-1], result.get_size()[-1])) 1023 return result 1024 1025 if torch._C.has_mkl: 1026 aten_mkl_linear = ExternKernelChoice( 1027 torch.ops.mkl._mkl_linear, 1028 "mkl::_mkl_linear", 1029 has_out_variant=False, 1030 kernel_creator=mkldnn_ir.MKLPackedLinear.create, 1031 ) 1032 cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) 1033 1034 
            @register_lowering(torch.ops.mkl._mkl_linear)
            def mkl_packed_linear(
                x: TensorBox,
                packed_w: TensorBox,
                orig_w: TensorBox,
                b: Optional[TensorBox],
                batch_size,
                *,
                layout=None,
            ):
                choices: List[ChoiceCaller] = []
                if use_max_autotune():
                    transposed_w = permute(orig_w, [1, 0])
                    *_, layout, x, transposed_w = mm_args(
                        x, transposed_w, layout=layout
                    )
                    if use_cpp_packed_gemm_template(layout, x, transposed_w):
                        CppPackedGemmTemplate.add_choices(
                            choices,
                            layout,
                            [x, packed_w, orig_w],
                            trans_w=True,
                            input_indices=[0, 2],
                        )

                if len(choices) == 0 or use_aten_gemm_kernels():
                    choices.append(
                        aten_mkl_linear.bind(
                            (x, packed_w, orig_w), layout, B=None, batch_size=batch_size
                        )
                    )

                assert packed_w.get_name() in V.graph.constants
                assert orig_w.get_name() in V.graph.constants
                # packed_w is an mkldnn tensor that we can't generate directly,
                # so we use the weights from the original tensor for autotuning.
                input_gen_fns = {
                    1: lambda x: V.graph.constants[x.get_name()],
                    2: lambda x: V.graph.constants[x.get_name()],
                }
                result: TensorBox = autotune_select_algorithm(
                    "packed_linear",
                    choices,
                    [x, packed_w, orig_w],
                    layout,
                    input_gen_fns=input_gen_fns,
                )
                if b is not None:
                    result = add(result, b)
                return result

        add_needs_realized_inputs(cpu_needs_realized_inputs)
    else:
        pass
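
# A minimal sketch (not part of this module's API) of how these lowerings are typically
# exercised: on a CPU build with oneDNN and with weight freezing enabled, Inductor's
# fusion passes may rewrite eval-mode linear/conv patterns into the torch.ops.mkldnn /
# torch.ops.onednn ops registered above, which then hit the lowerings in this file.
#
#   import torch
#
#   torch._inductor.config.freezing = True
#   model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
#   compiled = torch.compile(model)
#   with torch.no_grad():
#       compiled(torch.randn(8, 64))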