1# mypy: allow-untyped-defs 2import functools 3import operator 4from functools import reduce 5from typing import Any, Tuple 6 7import torch 8from torch.fx.experimental.symbolic_shapes import has_free_symbols 9 10from .. import ir 11from ..lowering import lowerings as L 12from ..pattern_matcher import ( 13 Arg, 14 CallFunction, 15 filter_nodes, 16 get_arg_value, 17 KeywordArg, 18 MULTIPLE, 19) 20from ..virtualized import ops, V 21from .freezing_patterns import register_freezing_graph_pattern 22from .post_grad import register_lowering_pattern 23from .quantization import ( 24 _register_quantization_lowerings, 25 _register_quantization_weight_pack_pass, 26 _register_woq_lowerings, 27) 28 29 30if torch._C._has_mkldnn: 31 aten = torch.ops.aten 32 mkldnn = torch.ops.mkldnn 33 prims = torch.ops.prims 34 35 _conv_args = [Arg() for _ in range(10)] 36 _linear_args = [Arg() for _ in range(6)] 37 _conv_transpose_args = [Arg() for _ in range(11)] 38 39 def _conv_call(users=1): 40 return CallFunction( 41 mkldnn._convolution_pointwise.default, *_conv_args, _users=users 42 ) 43 44 def _linear_call(users=1): 45 return CallFunction( 46 mkldnn._linear_pointwise.default, *_linear_args, _users=users 47 ) 48 49 def _conv_transpose_call(users=1): 50 return CallFunction( 51 mkldnn._convolution_transpose_pointwise.default, 52 *_conv_transpose_args, 53 _users=users, 54 ) 55 56 def _to_float(input_call, users=1): 57 return CallFunction( 58 prims.convert_element_type.default, 59 input_call, 60 KeywordArg("to_float"), 61 _users=users, 62 ) 63 64 def _to_bf16(input_call): 65 return CallFunction( 66 prims.convert_element_type.default, 67 input_call, 68 KeywordArg("to_bf16"), 69 _users=1, 70 ) 71 72 def _to_fp16(input_call): 73 return CallFunction( 74 prims.convert_element_type.default, 75 input_call, 76 KeywordArg("to_fp16"), 77 _users=1, 78 ) 79 80 def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): 81 # only insert to_dtype if lowp_dtype is True 82 computation_call = ( 83 _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) 84 ) 85 out = unary_fusion(computation_call) 86 if lowp_dtype == torch.bfloat16: 87 return _to_bf16(out) 88 elif lowp_dtype == torch.float16: 89 return _to_fp16(out) 90 else: 91 return out 92 93 def _gelu_fusion_1(computation_call): 94 return CallFunction( 95 aten.mul, 96 CallFunction(aten.mul, computation_call, 0.5), 97 CallFunction( 98 aten.add, 99 CallFunction( 100 aten.erf, 101 CallFunction(aten.mul, computation_call, 0.7071067811865476), 102 ), 103 1, 104 ), 105 ) 106 107 def _gelu_fusion_2(computation_call): 108 return CallFunction( 109 aten.mul, 110 CallFunction(aten.mul, computation_call, 0.5), 111 CallFunction( 112 aten.add, 113 CallFunction( 114 aten.tanh, 115 CallFunction( 116 aten.mul, 117 CallFunction( 118 aten.add, 119 computation_call, 120 CallFunction( 121 aten.mul, 122 CallFunction( 123 aten.mul, 124 CallFunction( 125 aten.mul, computation_call, computation_call 126 ), 127 computation_call, 128 ), 129 0.044715, 130 ), 131 ), 132 0.7978845608028654, 133 ), 134 ), 135 1, 136 ), 137 ) 138 139 def _hardswish_fusion(computation_call): 140 return CallFunction( 141 aten.div, 142 CallFunction( 143 aten.mul, 144 computation_call, 145 CallFunction( 146 aten.clamp_max, 147 CallFunction( 148 aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 149 ), 150 6, 151 ), 152 ), 153 6, 154 ) 155 156 def _silu_fusion(computation_call): 157 return CallFunction( 158 aten.mul, computation_call, CallFunction(aten.sigmoid, computation_call) 159 ) 160 161 def _hardsigmoid_fusion(computation_call): 162 return CallFunction( 163 aten.div, 164 CallFunction( 165 aten.clamp_max, 166 CallFunction( 167 aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 168 ), 169 6, 170 ), 171 6, 172 ) 173 174 def _leaky_relu_fusion(computation_call): 175 return CallFunction( 176 aten.where, 177 CallFunction(aten.gt, computation_call, 0), 178 computation_call, 179 CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), 180 ) 181 182 def _hardtanh_fusion(computation_call): 183 return CallFunction( 184 aten.clamp_max, 185 CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), 186 KeywordArg("max_value"), 187 ) 188 189 def _combined_fusion(computation_call, elementwise_op): 190 return CallFunction(elementwise_op, computation_call) 191 192 # binary_op(other, computation_op) 193 def _binary_fusion_v1(computation_call, binary_fn): 194 return CallFunction(binary_fn, KeywordArg("other"), computation_call) 195 196 # binary_op(computation_op, other) 197 def _binary_fusion_v2(computation_call, binary_fn): 198 return CallFunction(binary_fn, computation_call, KeywordArg("other")) 199 200 def _is_single_computation_op(computation_op, lowp_dtype=None): 201 def fn(match): 202 computation_nodes = filter_nodes(match.nodes, computation_op) 203 204 if lowp_dtype: 205 output_node_meta = match.output_node().meta.get("val") 206 if output_node_meta.dtype != lowp_dtype: 207 return False 208 209 if len(computation_nodes) < 1: 210 return False 211 if any(n.args[-3] != "none" for n in computation_nodes): 212 return False 213 return True 214 215 return fn 216 217 def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): 218 def fn(match): 219 matched = _is_single_computation_op(computation_op, lowp_dtype)(match) 220 computation_node = filter_nodes(match.nodes, computation_op)[0] 221 if lowp_dtype: 222 conversion_dtype_nodes = filter_nodes( 223 match.nodes, prims.convert_element_type.default 224 ) 225 if len(conversion_dtype_nodes) != 2: 226 return False 227 # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 228 if computation_node == conversion_dtype_nodes[0].args[0]: 229 to_float = conversion_dtype_nodes[0].args[1] 230 to_lp = conversion_dtype_nodes[1].args[1] 231 else: 232 to_float = conversion_dtype_nodes[1].args[1] 233 to_lp = conversion_dtype_nodes[0].args[1] 234 matched = matched and to_float == torch.float and to_lp == lowp_dtype 235 return matched 236 237 return fn 238 239 def _register_unary_fusion_lowering( 240 pattern, unary_attr, computation_op, lowp_dtype=None 241 ): 242 @register_lowering_pattern( 243 pattern, 244 extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), 245 ) 246 def fn(match, *args, **kwargs): 247 computation_args = list(args)[:-3] + [ 248 unary_attr.op_name, 249 unary_attr.scalars_attr, 250 unary_attr.algorithm_attr, 251 ] 252 return L[computation_op](*computation_args) 253 254 return fn 255 256 def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): 257 @register_lowering_pattern( 258 pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) 259 ) 260 def fn(match, *args, **kwargs): 261 negative_slope = kwargs.get("negative_slope") 262 if isinstance(negative_slope, ir.TensorBox): 263 matched = False 264 else: # inp is a Number 265 matched = True 266 if lowp_dtype: 267 dtype1 = kwargs.get("to_float") 268 dtype2 = ( 269 kwargs.get("to_bf16") 270 if lowp_dtype == torch.bfloat16 271 else kwargs.get("to_fp16") 272 ) 273 matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype 274 computation_args = list(args) 275 if matched: 276 computation_args = computation_args[:-3] + [ 277 "leaky_relu", 278 [negative_slope], 279 "", 280 ] 281 return L[computation_op](*computation_args) 282 else: 283 # computation_args += ["none", [], ""] 284 out = L[computation_op](*computation_args) 285 if lowp_dtype: 286 out = L[prims.convert_element_type.default](out, dtype=torch.float) 287 out = L[aten.where]( 288 L[aten.gt](out, 0), 289 out, 290 L[aten.mul](out, negative_slope), 291 ) 292 if lowp_dtype: 293 out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] 294 return out 295 296 return fn 297 298 def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): 299 @register_lowering_pattern( 300 pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) 301 ) 302 def fn(match, *args, **kwargs): 303 min_value = kwargs.get("min_value") 304 max_value = kwargs.get("max_value") 305 if isinstance(min_value, ir.TensorBox) or isinstance( 306 max_value, ir.TensorBox 307 ): 308 matched = False 309 else: # inp is a Number 310 assert max_value is not None 311 matched = min_value <= max_value 312 if lowp_dtype: 313 dtype1 = kwargs.get("to_float") 314 dtype2 = ( 315 kwargs.get("to_bf16") 316 if lowp_dtype == torch.bfloat16 317 else kwargs.get("to_fp16") 318 ) 319 matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype 320 computation_args = list(args) 321 if matched: 322 computation_args = computation_args[:-3] + [ 323 "hardtanh", 324 [min_value, max_value], 325 "", 326 ] 327 return L[computation_op](*computation_args) 328 else: 329 out = L[computation_op](*computation_args) 330 if lowp_dtype: 331 out = L[prims.convert_element_type.default](out, dtype=torch.float) 332 out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) 333 if lowp_dtype: 334 out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] 335 return out 336 337 return fn 338 339 _binary_attr = { 340 aten.add: "add", 341 ops.add: "add", 342 aten.sub: "sub", 343 ops.sub: "sub", 344 } 345 346 def _is_valid_binary(match, fn): 347 binary_nodes = filter_nodes(match.nodes, fn) 348 if len(binary_nodes) < 1: 349 return False 350 351 def get_meta_value(argument: torch.fx.node.Argument): 352 # Only torch.fx.Node is expected to have meta. 353 if isinstance(argument, torch.fx.Node): 354 return argument.meta.get("val", None) 355 return None 356 357 if any( 358 not isinstance(get_meta_value(n.args[0]), torch.Tensor) 359 or not isinstance(get_meta_value(n.args[1]), torch.Tensor) 360 for n in binary_nodes 361 ): 362 return False 363 # check alpha is one. 364 if any( 365 get_arg_value(n, 2, kwarg_name="alpha") != 1.0 366 and get_arg_value(n, 2, kwarg_name="alpha") is not None 367 for n in binary_nodes 368 ): 369 return False 370 if any( 371 get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size() 372 or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device 373 or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype 374 for n in binary_nodes 375 ): 376 return False 377 # check args[0] and args[1] is not same 378 if any(n.args[0] == n.args[1] for n in binary_nodes): 379 return False 380 return True 381 382 def _is_valid_computation_binary(computation_op, binary_op, other_index=None): 383 def fn(match): 384 if not _is_single_computation_op(computation_op)(match): 385 return False 386 if not _is_valid_binary(match, binary_op): 387 return False 388 return True 389 390 return fn 391 392 def _get_remaining_users(extra_input_node, compute_node): 393 # Think about this pattern: 394 # ReLU 395 # / \ 396 # Conv1 397 # / \ 398 # Conv2 399 # \ / 400 # Add 401 # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. 402 # The Conv1 is the ancestor node of the current compute node (Conv2). 403 # This indicates that the buffer of ReLU has completed all its usage, 404 # So we can safely make changes to it now by doing Conv2->Add inplace fusion. 405 # Take above case as example: 406 # * extra_input_node: ReLU 407 # * compute_node: Conv2 408 # _get_remaining_users will return the users of extra_input_node which are not 409 # ancestor node of compute_node. 410 def _is_ancestor_node(_current_node, _ancestor_node): 411 # Check whether _ancestor_node is the ancestor node of _current_node 412 _node_list = [_current_node] 413 _visited_nodes = set() 414 while len(_node_list) != 0: 415 _current_node = _node_list.pop(0) 416 if _current_node not in _visited_nodes: 417 _visited_nodes.add(_current_node) 418 if _current_node == _ancestor_node: 419 return True 420 elif isinstance( 421 _current_node, torch.fx.Node 422 ) and _current_node.op not in ["placeholder", "output", "get_attr"]: 423 for input in _current_node.all_input_nodes: 424 _node_list.append(input) # noqa: PERF402 425 return False 426 427 return [ 428 user 429 for user in list(extra_input_node.users) 430 if not _is_ancestor_node(compute_node, user) 431 ] 432 433 def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): 434 def fn(match): 435 if not _is_valid_computation_binary(computation_op, binary_op)(match): 436 return False 437 binary_nodes = filter_nodes(match.nodes, binary_op) 438 439 def _get_compute_node(_binary_node, _other_index): 440 assert ( 441 len(_binary_node.all_input_nodes) == 2 442 ), "Binary node should have 2 input nodes." 443 _compute_index = 1 if (_other_index == 0) else 0 444 return _binary_node.args[_compute_index] 445 446 def _other_input_not_inplaceable(_binary_node, _other_index): 447 _compute_node = _get_compute_node(_binary_node, _other_index) 448 return ( 449 len( 450 _get_remaining_users( 451 _binary_node.args[_other_index], _compute_node 452 ) 453 ) 454 > 1 455 or _binary_node.args[_other_index] == _compute_node.args[0] 456 ) 457 458 if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): 459 return False 460 if any( 461 n.args[other_index].op in ["placeholder", "output"] 462 for n in binary_nodes 463 ): 464 return False 465 return True 466 467 return fn 468 469 def _register_binary_unary_fusion_lowering( 470 pattern, 471 computation_op, 472 binary_op, 473 fusion_op, 474 unary_attr=None, 475 ): 476 @register_lowering_pattern( 477 pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) 478 ) 479 def fn(match, *args, **kwargs): 480 other = kwargs.get("other") 481 assert isinstance(other, ir.TensorBox) 482 binary_attr = _binary_attr[binary_op] 483 args_list = list(args) 484 computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] 485 if len(args_list) > 6: 486 if unary_attr is not None: 487 computation_args += [ 488 1.0, 489 unary_attr.op_name, 490 unary_attr.scalars_attr, 491 unary_attr.algorithm_attr, 492 ] 493 else: 494 computation_args += [1.0, None, [], None] 495 return L[fusion_op](*computation_args) 496 497 return fn 498 499 def _can_be_inplace(_other): 500 if isinstance(_other.data, ir.View): 501 return _can_be_inplace(_other.data) 502 else: 503 return not ( 504 isinstance(_other.data, ir.ReinterpretView) 505 or len(_other.get_inputs_that_alias_output()) > 0 506 ) 507 508 def _register_binary_unary_maybe_inplace_fusion_lowering( 509 pattern, 510 computation_op, 511 binary_op, 512 inplace_fusion_op, 513 outplace_fusion_op, 514 unary_attr=None, 515 other_index=None, 516 ): 517 @register_lowering_pattern( 518 pattern, 519 extra_check=_is_valid_computation_binary_inplace( 520 computation_op, binary_op, other_index 521 ), 522 ) 523 def fn(match, *args, **kwargs): 524 other = kwargs.get("other") 525 assert isinstance(other, ir.TensorBox) 526 binary_attr = _binary_attr[binary_op] 527 args_list = list(args) 528 computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] 529 if len(args_list) > 6: 530 if unary_attr is not None: 531 computation_args += [ 532 1.0, 533 unary_attr.op_name, 534 unary_attr.scalars_attr, 535 unary_attr.algorithm_attr, 536 ] 537 else: 538 computation_args += [1.0, None, [], None] 539 # Make sure the other is not an alias or mutation(fx side doesn't has such info). 540 other.realize() 541 if not _can_be_inplace(other): 542 return L[outplace_fusion_op](*computation_args) 543 return L[inplace_fusion_op](*computation_args) 544 545 return fn 546 547 computation_ops = [ 548 mkldnn._convolution_pointwise.default, 549 mkldnn._linear_pointwise.default, 550 mkldnn._convolution_transpose_pointwise.default, 551 ] 552 553 class UnaryAttr: 554 def __init__( 555 self, op_name: str, scalars_attr=None, algorithm_attr=None 556 ) -> None: 557 self.op_name = op_name 558 self.scalars_attr = scalars_attr if scalars_attr else [] 559 self.algorithm_attr = algorithm_attr if algorithm_attr else "" 560 561 def _register_unary_fusion(): 562 computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] 563 564 def _unary_fusion_patterns(lowp_dtype): 565 replacement_unary_fusion_patterns = { 566 UnaryAttr("gelu", algorithm_attr="tanh"): [ 567 _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) 568 for call_fn in computation_call_fns 569 ], 570 UnaryAttr("gelu", algorithm_attr="none"): [ 571 _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) 572 for call_fn in computation_call_fns 573 ], 574 UnaryAttr("hardswish"): [ 575 _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) 576 for call_fn in computation_call_fns 577 ], 578 UnaryAttr("hardsigmoid"): [ 579 _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) 580 for call_fn in computation_call_fns 581 ], 582 UnaryAttr("swish"): [ 583 _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) 584 for call_fn in computation_call_fns 585 ], 586 } 587 if not lowp_dtype: 588 call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] 589 replacement_unary_fusion_patterns.update( 590 { 591 UnaryAttr("relu"): [ 592 _combined_fusion(u, aten.relu) for u in call_user1 593 ], 594 UnaryAttr("sigmoid"): [ 595 _combined_fusion(u, aten.sigmoid) for u in call_user1 596 ], 597 UnaryAttr("tanh"): [ 598 _combined_fusion(u, aten.tanh) for u in call_user1 599 ], 600 } 601 ) 602 603 return replacement_unary_fusion_patterns 604 605 for lowp_dtype in [torch.bfloat16, torch.float16, None]: 606 replace_patterns = _unary_fusion_patterns(lowp_dtype) 607 for unary_attr, patterns in replace_patterns.items(): 608 _register_unary_fusion_lowering( 609 patterns[0], unary_attr, computation_ops[0], lowp_dtype 610 ) 611 _register_unary_fusion_lowering( 612 patterns[1], unary_attr, computation_ops[1], lowp_dtype 613 ) 614 _register_unary_fusion_lowering( 615 patterns[2], unary_attr, computation_ops[2], lowp_dtype 616 ) 617 _leaky_relu_patterns = [ 618 _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) 619 for call_fn in computation_call_fns 620 ] 621 for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): 622 _register_leaky_relu_fusion_lowering( 623 pattern, computation_op, lowp_dtype 624 ) 625 hardtanh_patterns = [ 626 _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) 627 for call_fn in computation_call_fns 628 ] 629 for pattern, computation_op in zip(hardtanh_patterns, computation_ops): 630 _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) 631 632 def _register_inplace_fusion(): 633 binary_ops = [aten.add, ops.add] 634 inplace_fusion_op = mkldnn._convolution_pointwise_.binary 635 outplace_fusion_op = mkldnn._convolution_pointwise.binary 636 conv_call = _conv_call(users=1) 637 conv_op = computation_ops[0] 638 for binary_op in binary_ops: 639 binary_v1 = _binary_fusion_v1(conv_call, binary_op) 640 binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) 641 _register_binary_unary_maybe_inplace_fusion_lowering( 642 binary_unary_v1, 643 conv_op, 644 binary_op, 645 inplace_fusion_op, 646 outplace_fusion_op, 647 other_index=0, 648 unary_attr=UnaryAttr("relu"), 649 ) 650 _register_binary_unary_maybe_inplace_fusion_lowering( 651 binary_v1, 652 conv_op, 653 binary_op, 654 inplace_fusion_op, 655 outplace_fusion_op, 656 other_index=0, 657 ) 658 binary_v2 = _binary_fusion_v2(conv_call, binary_op) 659 binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) 660 _register_binary_unary_maybe_inplace_fusion_lowering( 661 binary_unary_v2, 662 conv_op, 663 binary_op, 664 inplace_fusion_op, 665 outplace_fusion_op, 666 other_index=1, 667 unary_attr=UnaryAttr("relu"), 668 ) 669 _register_binary_unary_maybe_inplace_fusion_lowering( 670 binary_v2, 671 conv_op, 672 binary_op, 673 inplace_fusion_op, 674 outplace_fusion_op, 675 other_index=1, 676 ) 677 678 def _register_binary_fusion(): 679 binary_ops = [aten.add, ops.add, aten.sub, ops.sub] 680 fusion_ops = [ 681 mkldnn._convolution_pointwise.binary, 682 mkldnn._linear_pointwise.binary, 683 ] 684 _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] 685 for computation_call, computation_op, fusion_op in zip( 686 _computation_user_1, computation_ops[:-1], fusion_ops 687 ): 688 for binary_op in binary_ops: 689 pattern = _binary_fusion_v2(computation_call, binary_op) 690 _register_binary_unary_fusion_lowering( 691 pattern, computation_op, binary_op, fusion_op 692 ) 693 694 for binary_op in [aten.add, ops.add]: 695 pattern = _binary_fusion_v1(computation_call, binary_op) 696 _register_binary_unary_fusion_lowering( 697 pattern, computation_op, binary_op, fusion_op 698 ) 699 700 def _register_binary_unary_fusion(): 701 binary_ops = [aten.add, ops.add, aten.sub, ops.sub] 702 fusion_ops = [mkldnn._convolution_pointwise.binary] 703 _computation_user_1 = [_conv_call(users=1)] 704 for computation_call, computation_op, fusion_op in zip( 705 _computation_user_1, computation_ops[:-1], fusion_ops 706 ): 707 for binary_op in binary_ops: 708 pattern_v1 = _combined_fusion( 709 _binary_fusion_v2(computation_call, binary_op), aten.relu 710 ) 711 _register_binary_unary_fusion_lowering( 712 pattern_v1, 713 computation_op, 714 binary_op, 715 fusion_op, 716 unary_attr=UnaryAttr("relu"), 717 ) 718 for binary_op in [aten.add, ops.add]: 719 pattern_v2 = _combined_fusion( 720 _binary_fusion_v1(computation_call, binary_op), aten.relu 721 ) 722 _register_binary_unary_fusion_lowering( 723 pattern_v2, 724 computation_op, 725 binary_op, 726 fusion_op, 727 unary_attr=UnaryAttr("relu"), 728 ) 729 730 def _recover_linear(): 731 # convert reshape+linear+reshape to a single linear for applying fusion path. 732 @register_freezing_graph_pattern( 733 CallFunction( 734 aten.reshape.default, 735 CallFunction( 736 mkldnn._linear_pointwise.default, 737 CallFunction( 738 aten.reshape.default, 739 Arg(), 740 KeywordArg("reshape_1"), 741 _users=MULTIPLE, 742 ), 743 Arg(), 744 Arg(), 745 Arg(), 746 Arg(), 747 Arg(), 748 ), 749 KeywordArg("reshape_2"), 750 ), 751 pass_number=1, 752 ) 753 def reshape_linear_reshape_pattern(match, *args, **kwargs): 754 def get_val(val): 755 return val if isinstance(val, int) else val.meta.get("val") 756 757 reshape_1 = kwargs.get("reshape_1") 758 reshape_2 = kwargs.get("reshape_2") 759 assert isinstance(reshape_1, list) 760 assert isinstance(reshape_2, list) 761 assert len(reshape_1) == 2 762 763 graph = match.graph 764 reshape_2_node = match.output_node() 765 linear_input_node = reshape_2_node.args[0].args[0].args[0] 766 # check linear's input's shape[:-1] == reshape_2[:-1] 767 # and check product(reshape_2[:-1]) == reshape_1[0] 768 can_remove_reshape = linear_input_node.meta.get("val").shape[ 769 :-1 770 ] == torch.Size([get_val(val) for val in reshape_2[:-1]]) 771 can_remove_reshape = can_remove_reshape and ( 772 reduce( 773 operator.mul, 774 [get_val(val) for val in reshape_2[:-1]], 775 ) 776 == get_val(reshape_1[0]) 777 ) 778 779 if can_remove_reshape: 780 repl = graph.call_function(mkldnn._linear_pointwise.default, args) 781 repl.meta.update(reshape_2_node.meta) 782 reshape_2_node.replace_all_uses_with(repl) 783 old_linear_node = reshape_2_node.args[0] 784 reshape_1_node = old_linear_node.args[0] 785 graph.erase_node(reshape_2_node) 786 graph.erase_node(old_linear_node) 787 if len(reshape_1_node.users) == 0: 788 graph.erase_node(reshape_1_node) 789 790 def is_linear_add_bias(match): 791 add_node = match.output_node() 792 linear_node = add_node.args[0] 793 packed_weight_node = linear_node.args[1] 794 assert packed_weight_node.target == mkldnn._reorder_linear_weight 795 transpose_weight_node = packed_weight_node.args[0] 796 assert transpose_weight_node.target == aten.permute.default 797 weight_meta = transpose_weight_node.args[0].meta.get("val") 798 bias_node = add_node.args[1] 799 if isinstance(bias_node, int): 800 # we only folding bias if it is a constant 801 return False 802 bias_meta = add_node.args[1].meta.get("val") 803 if weight_meta is None or bias_meta is None: 804 return False 805 assert weight_meta.dtype in ( 806 torch.bfloat16, 807 torch.float16, 808 ) 809 if bias_meta.dtype != weight_meta.dtype: 810 return False 811 return ( 812 linear_node.args[2] is None 813 and bias_meta.dim() == 1 814 and bias_meta.size(0) == weight_meta.size(1) 815 ) 816 817 # convert linear+bias to a single linear for applying fusion path. 818 @register_freezing_graph_pattern( 819 CallFunction( 820 aten.add.Tensor, 821 CallFunction(mkldnn._linear_pointwise.default, *_linear_args), 822 Arg(), 823 ), 824 pass_number=1, 825 extra_check=is_linear_add_bias, 826 ) 827 def linear_bias_pattern(match, *args): 828 graph = match.graph 829 add_node = match.output_node() 830 linear_node = add_node.args[0] 831 new_args = list(linear_node.args) 832 new_args[2] = add_node.args[1] 833 repl = graph.call_function( 834 mkldnn._linear_pointwise.default, tuple(new_args) 835 ) 836 repl.meta.update(add_node.meta) 837 add_node.replace_all_uses_with(repl) 838 match.erase_nodes() 839 840 def _is_packable_mkldnn_rnn_layer(match): 841 lstm_node = match.output_node() 842 POS_WEIGHTS = [1, 2] 843 POS_INPUTS = [0, 5, 6] 844 POS_ARGS = POS_WEIGHTS + POS_INPUTS 845 # Weights should be Constant 846 if any( 847 lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS 848 ): 849 return False 850 851 # Meta info for weights and inputs should be available 852 if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): 853 return False 854 855 # Check device 856 if any( 857 lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" 858 for POS_ARG in POS_ARGS 859 ): 860 return False 861 862 # Check dtype 863 if any( 864 lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 865 and not mkldnn._is_mkldnn_bf16_supported() 866 for POS_ARG in POS_ARGS 867 ): 868 return False 869 if any( 870 lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 871 and not mkldnn._is_mkldnn_fp16_supported() 872 for POS_ARG in POS_ARGS 873 ): 874 return False 875 876 return True 877 878 def _is_packable_convolution(match): 879 """ 880 Check if the node is supported for MKLDNN convolution. 881 """ 882 conv_node = match.output_node() 883 input_meta_value = conv_node.args[0].meta.get("val") 884 weight_meta_value = conv_node.args[1].meta.get("val") 885 if input_meta_value is None or weight_meta_value is None: 886 return False 887 input_size = input_meta_value.shape 888 if conv_node.args[1].op != "get_attr": 889 return False 890 for meta_value in [input_meta_value, weight_meta_value]: 891 if ( 892 meta_value is None 893 or meta_value.device.type != "cpu" 894 or (meta_value.dim() != 4 and meta_value.dim() != 5) 895 ): 896 return False 897 if ( 898 input_meta_value.dtype == torch.bfloat16 899 or weight_meta_value.dtype == torch.bfloat16 900 ): 901 if not mkldnn._is_mkldnn_bf16_supported(): 902 return False 903 if ( 904 input_meta_value.dtype == torch.float16 905 or weight_meta_value.dtype == torch.float16 906 ): 907 if not mkldnn._is_mkldnn_fp16_supported(): 908 return False 909 is_transposed = conv_node.args[-3] 910 if is_transposed: 911 # TODO: Support dynamic shape case for MKLDNN conv transpose. 912 if has_free_symbols(input_size): 913 return False 914 groups = conv_node.args[-1] 915 in_channels = weight_meta_value.size(0) 916 # doesn't support group_depthwise_conv_transpose. 917 if groups > 1 and groups == in_channels: 918 return False 919 # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big 920 output_paddings = conv_node.args[-2] 921 strides = conv_node.args[3] 922 if any( 923 output_padding >= stride 924 for output_padding, stride in zip(output_paddings, strides) 925 ): 926 return False 927 return True 928 929 def _is_packable_linear(match): 930 """ 931 Check if the node is supported for MKLDNN linear. 932 """ 933 linear_node = match.output_node() 934 # mkldnn linear only supports beta=1or0 and alpha=1 935 if linear_node.target == aten.addmm.default: 936 alpha = linear_node.kwargs.get("alpha", 1.0) 937 beta = linear_node.kwargs.get("beta", 1.0) 938 if (beta != 0.0 and beta != 1.0) or alpha != 1.0: 939 return False 940 # weight_idx is 1 for aten.mm and is 2 for aten.addmm 941 weight_idx = 2 if linear_node.target == aten.addmm.default else 1 942 if linear_node.args[weight_idx].op != "get_attr": 943 return False 944 input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") 945 weight_meta_value = linear_node.args[weight_idx].meta.get("val") 946 if input_meta_value is None or weight_meta_value is None: 947 return False 948 batch_size = input_meta_value.shape[0] 949 if ( 950 input_meta_value.dtype == torch.float64 951 or weight_meta_value.dtype == torch.float64 952 ): 953 return False 954 is_lp_weight = weight_meta_value.dtype in ( 955 torch.bfloat16, 956 torch.float16, 957 ) 958 # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. 959 # on aarch64, use mkldnn op for fp32 as well if acl is enabled 960 if ( 961 not is_lp_weight 962 and not mkldnn._is_mkldnn_acl_supported() 963 and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) 964 ): 965 return False 966 for meta_value in [input_meta_value, weight_meta_value]: 967 if ( 968 meta_value is None 969 or meta_value.device.type != "cpu" 970 or meta_value.dim() != 2 971 ): 972 return False 973 if weight_idx == 2: 974 bias_meta_value = linear_node.args[0].meta.get("val") 975 if ( 976 bias_meta_value is None 977 or meta_value.device.type != "cpu" 978 or bias_meta_value.dim() != 1 979 or bias_meta_value.size(0) != weight_meta_value.size(1) 980 ): 981 return False 982 983 if ( 984 input_meta_value.dtype == torch.bfloat16 985 or weight_meta_value.dtype == torch.bfloat16 986 ): 987 if not mkldnn._is_mkldnn_bf16_supported(): 988 return False 989 if ( 990 input_meta_value.dtype == torch.float16 991 or weight_meta_value.dtype == torch.float16 992 ): 993 if not mkldnn._is_mkldnn_fp16_supported(): 994 return False 995 return True 996 997 _aten_conv_args = ( 998 Arg(), 999 Arg(), 1000 Arg(), 1001 Arg(), 1002 Arg(), 1003 Arg(), 1004 KeywordArg("is_transposed"), 1005 Arg(), 1006 Arg(), 1007 ) 1008 1009 _aten_mkldnn_rnn_layer_args = ( 1010 Arg(), # input 1011 Arg(), # weight0 1012 Arg(), # weight1 1013 Arg(), # weight2 1014 Arg(), # weight3 1015 Arg(), # hx_ 1016 Arg(), # cx_ 1017 KeywordArg("reverse"), # reverse 1018 Arg(), # batch_sizes 1019 Arg(), # mode 1020 Arg(), # hidden_size 1021 Arg(), # num_layers 1022 Arg(), # has_biases 1023 Arg(), # bidirectional 1024 Arg(), # batch_first 1025 Arg(), # train 1026 ) 1027 1028 def _register_weight_pack_pass(): 1029 @register_freezing_graph_pattern( 1030 CallFunction(aten.convolution.default, *_aten_conv_args), 1031 extra_check=_is_packable_convolution, 1032 ) 1033 def convolution(match, *args, **kwargs): 1034 is_transposed = kwargs.get("is_transposed") 1035 assert isinstance(is_transposed, bool) 1036 graph = match.graph 1037 conv_node = match.output_node() 1038 input_size = conv_node.args[0].meta.get("val").shape 1039 with graph.inserting_before(conv_node): 1040 constant_args = [args[4], args[3], args[5], args[-1]] 1041 packed_weight_op = mkldnn._reorder_convolution_weight 1042 packed_conv_op = mkldnn._convolution_pointwise.default 1043 if is_transposed: 1044 constant_args.insert(1, args[-2]) # output_padding 1045 packed_weight_op = mkldnn._reorder_convolution_transpose_weight 1046 packed_conv_op = mkldnn._convolution_transpose_pointwise.default 1047 if not has_free_symbols(input_size): 1048 packed_weight_inputs = ( 1049 (args[1],) + tuple(constant_args) + (input_size,) 1050 ) 1051 packed_weight_node = graph.create_node( 1052 "call_function", packed_weight_op, args=packed_weight_inputs 1053 ) 1054 else: 1055 assert not is_transposed 1056 # For dynamic shape case, we need to pack weight in runtime. 1057 packed_weight_node = args[1] 1058 packed_conv_inputs = ( 1059 (args[0], packed_weight_node, args[2]) 1060 + tuple(constant_args) 1061 + ("none", [], "") 1062 ) 1063 packed_conv_node = graph.create_node( 1064 "call_function", packed_conv_op, tuple(packed_conv_inputs) 1065 ) 1066 conv_node.replace_all_uses_with(packed_conv_node) 1067 packed_conv_node.meta.update(conv_node.meta) 1068 graph.erase_node(conv_node) 1069 1070 @register_freezing_graph_pattern( 1071 CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), 1072 extra_check=_is_packable_mkldnn_rnn_layer, 1073 ) 1074 def mkldnn_rnn_layer(match, *args, **kwargs): 1075 def get_item(graph, node, index): 1076 return graph.call_function(operator.getitem, (node, index)) 1077 1078 graph = match.graph 1079 lstm_node = match.output_node() 1080 input = args[0] 1081 weight0, weight1 = args[1:3] 1082 reverse = kwargs.get("reverse") 1083 packed_lstm_op = aten.mkldnn_rnn_layer.default 1084 hidden_size = args[9] 1085 has_biases = args[11] 1086 batch_first = args[13] 1087 with graph.inserting_before(lstm_node): 1088 packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default 1089 packed_weight_inputs = ( 1090 weight0, 1091 weight1, 1092 hidden_size, 1093 reverse, 1094 has_biases, 1095 batch_first, 1096 ) 1097 packed_weight_node = graph.create_node( 1098 "call_function", packed_weight_op, packed_weight_inputs, {}, "name" 1099 ) 1100 packed_weight_items = [ 1101 get_item(graph, packed_weight_node, i) for i in range(2) 1102 ] 1103 pack_lstm_inputs = ( 1104 args[0], 1105 *packed_weight_items, 1106 args[3], 1107 args[4], 1108 args[5], 1109 args[6], 1110 reverse, 1111 *args[7:], 1112 ) 1113 1114 packed_lstm_node = graph.create_node( 1115 "call_function", packed_lstm_op, args=pack_lstm_inputs 1116 ) 1117 lstm_node.replace_all_uses_with(packed_lstm_node) 1118 packed_lstm_node.meta.update(lstm_node.meta) 1119 graph.erase_node(lstm_node) 1120 1121 @register_freezing_graph_pattern( 1122 CallFunction( 1123 aten.addmm.default, 1124 Arg(), 1125 Arg(), 1126 Arg(), 1127 beta=KeywordArg("beta"), 1128 alpha=KeywordArg("alpha"), 1129 ), 1130 extra_check=_is_packable_linear, 1131 ) 1132 @register_freezing_graph_pattern( 1133 CallFunction(aten.mm.default, Arg(), Arg()), 1134 extra_check=_is_packable_linear, 1135 ) 1136 def linear(match, *args, **kwargs): 1137 graph = match.graph 1138 linear_node = match.output_node() 1139 input = args[0] if linear_node.target == aten.mm.default else args[1] 1140 bias = ( 1141 None 1142 if linear_node.target == aten.mm.default 1143 or ( 1144 linear_node.target == aten.addmm.default 1145 and linear_node.kwargs.get("beta", 1.0) == 0.0 1146 ) 1147 else args[0] 1148 ) 1149 weight = args[1] if linear_node.target == aten.mm.default else args[2] 1150 with graph.inserting_before(linear_node): 1151 transpose_weight_node = graph.create_node( 1152 "call_function", aten.permute.default, (weight, (1, 0)) 1153 ) 1154 weight_dtype = weight.meta.get("val").dtype 1155 is_lp_weight = weight_dtype in ( 1156 torch.bfloat16, 1157 torch.float16, 1158 ) 1159 batch_size = input.meta.get("val").shape[0] 1160 if has_free_symbols(batch_size): 1161 assert ( 1162 is_lp_weight or mkldnn._is_mkldnn_acl_supported() 1163 ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" 1164 # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. 1165 packed_weight_inputs = ( 1166 transpose_weight_node, 1167 batch_size.node.shape_env.size_hint(batch_size.node.expr) 1168 if has_free_symbols(batch_size) 1169 else batch_size, 1170 ) 1171 # MKL packed matrix can't be copied to a different address because the internal implementation 1172 # depends on the alignment of internally-stored metadata. 1173 # In aot mode, we need to firstly save the packed weight, when loading it, 1174 # it will be in a different address which doesn't work. 1175 # Disable MKL prepack linear in AOT mode 1176 packed_weight_op = ( 1177 mkldnn._reorder_linear_weight 1178 if ( 1179 is_lp_weight 1180 or mkldnn._is_mkldnn_acl_supported() 1181 or V.aot_compilation is True 1182 ) 1183 else torch.ops.mkl._mkl_reorder_linear_weight 1184 ) 1185 packed_weight_node = graph.create_node( 1186 "call_function", packed_weight_op, args=packed_weight_inputs 1187 ) 1188 1189 packed_linear_inputs: Tuple[Any, ...] = (input, packed_weight_node) 1190 if ( 1191 is_lp_weight 1192 or mkldnn._is_mkldnn_acl_supported() 1193 or V.aot_compilation is True 1194 ): 1195 packed_linear_inputs += (bias, "none", [], "") 1196 packed_linear_op = mkldnn._linear_pointwise.default 1197 else: 1198 packed_linear_inputs += (transpose_weight_node, bias, batch_size) 1199 packed_linear_op = torch.ops.mkl._mkl_linear 1200 packed_linear_node = graph.create_node( 1201 "call_function", packed_linear_op, packed_linear_inputs 1202 ) 1203 linear_node.replace_all_uses_with(packed_linear_node) 1204 packed_linear_node.meta.update(linear_node.meta) 1205 graph.erase_node(linear_node) 1206 1207 def _eliminate_duplicate_packed_nodes(gm): 1208 """ 1209 Combine packed weight nodes with the same inputs to reduce memory usage. 1210 for example: 1211 class Model(nn.Module): 1212 def __init__(self) -> None: 1213 super().__init__() 1214 self.linear = nn.Linear(32, 32, bias=True) 1215 1216 def forward(self, x): 1217 return self.linear(self.linear(x)) 1218 1219 the above's packed weight nodes are duplicate if two linear calls have same input size. 1220 """ 1221 if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): 1222 return gm 1223 1224 packed_weight_ops = [ 1225 torch._C._nn.mkldnn_reorder_conv2d_weight, 1226 torch._C._nn.mkldnn_reorder_conv3d_weight, 1227 mkldnn._reorder_convolution_transpose_weight, 1228 mkldnn._reorder_linear_weight, 1229 mkldnn._reorder_mkldnn_rnn_layer_weight, 1230 ] 1231 if torch._C.has_mkl: 1232 packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) 1233 1234 for node in gm.graph.nodes: 1235 if node.target in packed_weight_ops and len(node.args[0].users) > 1: 1236 for user_node in list(node.args[0].users.keys()): 1237 if ( 1238 user_node.target == node.target 1239 and user_node != node 1240 and user_node.args == node.args 1241 ): 1242 user_node.replace_all_uses_with(node) 1243 gm.graph.erase_node(user_node) 1244 1245 @functools.lru_cache(None) 1246 def _mkldnn_fusion_init(): 1247 # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. 1248 # Otherwise even the matmul or innerproduct can not be accelerated with acl 1249 if ( 1250 torch.backends.mkldnn.enabled 1251 and torch.backends.mkldnn.is_available() 1252 and not torch.ops.mkldnn._is_mkldnn_acl_supported() 1253 ): 1254 _register_unary_fusion() 1255 _register_inplace_fusion() 1256 _register_binary_unary_fusion() 1257 _register_binary_fusion() 1258 _register_quantization_lowerings() 1259 _register_woq_lowerings() 1260 1261 @functools.lru_cache(None) 1262 def _mkldnn_weight_pack_init(): 1263 if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): 1264 _register_weight_pack_pass() 1265 _recover_linear() 1266 _register_quantization_weight_pack_pass() 1267