# mypy: allow-untyped-defs
import functools
import itertools
import operator
import typing
from typing import Callable, List, Optional, Union

import torch
import torch._inductor.runtime.runtime_utils
from torch import Tensor
from torch._dynamo.utils import counters
from torch._inductor import utils
from torch._inductor.autoheuristic.autoheuristic import (
    AHContext,
    AutoHeuristic,
    LocalFeedback,
)
from torch._inductor.autoheuristic.autoheuristic_utils import (
    context_add_strides,
    context_add_using_tf32,
    pad_mm_operations,
    pad_mm_precondition,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch

from ...utils._triton import has_triton
from ..pattern_matcher import (
    fwd_only,
    gen_register_replacement,
    joint_fwd_bwd,
    Match,
    ReplaceFn,
    SearchFn,
)


aten = torch.ops.aten


# This flag is only used for testing purposes.
# Setting it to True skips the comparison of do_bench times
# between the original pattern and the padded one.
_skip_do_bench_times = False


def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    kwargs = match.kwargs
    return [kwargs[name].meta["val"] for name in kwarg_names]


def unwrap_fake_args(*arg_names):
    def decorator(func):
        def wrapper(match):
            fake_tensors = fetch_fake_tensors(match, arg_names)
            return func(*fake_tensors)

        return wrapper

    return decorator


def get_alignment_size(x: Tensor) -> int:
    return get_alignment_size_dtype(x.dtype)


def get_alignment_size_dtype(dtype: torch.dtype) -> int:
    if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16:
        return 8
    elif dtype == torch.float32 or dtype == torch.float:
        return 4
    else:
        return 0


def check_device(a: Tensor, b: Tensor) -> bool:
    return a.is_cuda and b.is_cuda


def check_dtype(a: Tensor, b: Tensor) -> bool:
    return a.is_floating_point() and b.is_floating_point()


def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    # It's fine if we have symbolic shapes or strides as long as they
    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
        if t is None:
            return True

        symbolic_cnt = 0
        for x in t.size():
            if isinstance(x, int):
                continue
            elif utils.is_symbolic(x):
                if not x.node.has_hint():
                    return False
                symbolic_cnt += 1
            else:
                return False
        # filter out cases where all dimensions are symbolic
        if symbolic_cnt == len(t.size()):
            return False
        return all(
            isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint())
            for x in t.stride()
        )

    return (
        torch._inductor.config.shape_padding
        and check_device(mat1, mat2)
        and check_dtype(mat1, mat2)
        and all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
    )


def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    # we don't pad x if it is symbolic
    if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
        return 0

    # ignore dims that can be squeezed away
    if x == 1:
        return 0

    return int((x // alignment_size + 1) * alignment_size) - x


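# Worked example of the padding arithmetic above: float16/bfloat16 inputs use an
# alignment of 8, so a dimension of length 33 is padded by
# get_padded_length(33, 8) == (33 // 8 + 1) * 8 - 33 == 7 elements up to 40, while
# dimensions that are already multiples of the alignment, symbolic, or of size 1
# are left unpadded.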
def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    if padded_length == 0:
        return x
    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, pad], dim=dim)


def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)


def should_pad_addmm(match: Match) -> bool:
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    return should_pad_common(mat1, mat2, input) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.addmm, input=input
    )


def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
):
    # for padding, the dim order is reversed, and for every dim
    # we need to specify left and right padding
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )

    # the add broadcasts, so we only pad if the dimension != 1
    if input is not None:
        if n_padded_length != 0:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)

    if m_padded_length != 0:
        res = res[:-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :-n_padded_length]
    return res


def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    return pad_addmm(
        input,
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
        beta,
        alpha,
    )


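# Roofline-style check used below: an [M, K] @ [K, N] matmul performs M * N * K
# multiply-adds while touching roughly M*K + K*N + M*N elements, so its arithmetic
# intensity is M*N*K / (M*K + K*N + M*N). Padding is only worth benchmarking when the
# kernel is compute bound, i.e. when that intensity exceeds the (discounted) machine
# balance of peak FLOPs over DRAM bandwidth.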
def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    denominator = M * K + N * K + M * N
    if denominator == 0:
        return False
    arithmetic_intensity = (M * N * K) / denominator

    # we have experienced some large perf hits in this case, even in bandwidth bound regimes
    if (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and torch.cuda.get_device_capability() < (9, 0)
    ):  # doesn't repro on H100s
        return True

    # Fails with AMD
    try:
        machine_balance = (
            1000 * utils.get_device_tflops(dtype)
        ) / utils.get_gpu_dram_gbps()
    except Exception:
        return True

    # dram_gbps might be underestimating bandwidth because of cache.
    # if we estimate machine balance too low, we might miss some speedups;
    # if we estimate it too high, there will be an unnecessary compilation time increase.
    # TODO - fine-tune the coefficient here. As a reference point, the Triton mm model assumes
    # 80% of reads are in cache and cache is 4x faster than dram_gbps
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance


@functools.lru_cache(None)
def get_pad_cache():
    return torch._inductor.codecache.LocalCache()


def get_cached_should_pad(key: str) -> bool:
    return get_pad_cache().lookup(key)


def set_cached_should_pad(key: str, value: bool):
    return get_pad_cache().set_value(key, value=value)


def get_cached_base_mm_benchmark_time(key: str) -> float:
    return get_pad_cache().lookup(key)


def set_cached_base_mm_benchmark_time(key: str, value: float):
    return get_pad_cache().set_value(key, value=value)


def should_pad_bench_key(
    match,
    mat1: Tensor,
    mat2: Tensor,
    op,
    input: Optional[Tensor] = None,
    is_base_time_key=False,
) -> str:
    def tensor_key(t):
        return (t.shape, t.stride(), t.dtype)

    tf32_key = (
        None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
    )

    def fmt_pad(name):
        if is_base_time_key:
            return None
        return f"exclude_pad:{should_exclude_padding_time(match, name)}"

    key = (
        tensor_key(mat1),
        tensor_key(mat2),
        fmt_pad("mat1"),
        fmt_pad("mat2"),
        op,
        input if input is None else tensor_key(input),
        tf32_key,
    )

    key = str(key)
    if is_base_time_key:
        key = f"base mm time: {key}"
    return key


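# The next two helpers decide whether the cost of padding an input should count
# against the padded benchmark: get_non_view_def walks backwards through getitem and
# view-only ops to the node that actually produces the data, and
# should_exclude_padding_time returns True when that producer is not a graph input
# and is expected to be memory-planned away, so the pad itself would be effectively free.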
def get_non_view_def(node):
    if node.op == operator.getitem:
        return get_non_view_def(node.args[0])

    if (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and utils.is_view(node.target)
    ):
        return get_non_view_def(node.all_input_nodes[0])

    return node


def should_exclude_padding_time(match, arg_name):
    node_def = get_non_view_def(match.kwargs[arg_name])

    # constant padding converts tensors to contiguous, so even if the input tensor
    # can be planned, the layout transform is not free. TODO - is there a way to pad and preserve the layout?
    if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
        return False

    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
    # We would only be able to completely plan these out if we were only doing
    # first-dimension padding. For non-first dimensions we would still need a copy
    # because these outputs are fixed dense.
    cannot_plan_output = [
        aten.mm.default,
        aten.convolution.default,
        aten.convolution_backward.default,
        aten.bmm.default,
        aten.addmm.default,
        aten._scaled_dot_product_flash_attention.default,
        aten._scaled_dot_product_efficient_attention.default,
    ]

    if node_def.target in cannot_plan_output:
        return False

    if (
        node_def.target == aten.cat.default
        and len(node_def.all_input_nodes)
        > torch._inductor.config.max_pointwise_cat_inputs
    ):
        return False

    # optimistically assume we should be able to memory plan away
    # all non inputs
    return node_def.op != "placeholder"


def should_pad(key: str, ori_time, pad_time) -> bool:
    multiplier = 1.1
    # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
    # tradeoff between the performance improvement from shape padding and the overhead of the additional memory ops
    # TODO: Build a learned model which would be better than this heuristic
    if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options:
        multiplier = torch._inductor.config.post_grad_fusion_options[
            "shape_padding_multiplier"
        ].get("value", 1.1)
        counters["inductor"]["shape_padding_multiplier"] += 1
    should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier
    set_cached_should_pad(key, should_pad)
    return should_pad


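# With the default 1.1x multiplier above, the original kernel has to take more than
# 1.1x the padded kernel's time before padding is considered worthwhile; the multiplier
# is read from the "value" entry of
# torch._inductor.config.post_grad_fusion_options["shape_padding_multiplier"] when set.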
def should_pad_bench(
    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    do_bench = functools.partial(
        torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
        warmup=5,
    )
    m_padded_length = 0
    n_padded_length = 0
    batchsize = 1
    with no_dispatch():
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m = mat1.shape[0]
            k = mat1.shape[1]
            n = mat2.shape[1]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
        elif op is torch.ops.aten.bmm:
            batchsize = mat1.shape[0]
            m = mat1.shape[1]
            k = mat1.shape[2]
            n = mat2.shape[2]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        else:
            return False

        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        def realize_symbols(ds):
            return [d if isinstance(d, int) else d.node.hint for d in ds]

        if any(
            dim == 0
            for dim in itertools.chain(
                realize_symbols(mat1.shape), realize_symbols(mat2.shape)
            )
        ):
            return False

        if torch._inductor.config.force_shape_pad:
            return True

        if not has_triton():
            return False

        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially false
        # since it does file io
        key = should_pad_bench_key(match, mat1, mat2, op, input)

        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_tensor(t):
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            else:
                return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)

        # since we key on whether or not the inputs can be memory planned, set cache for the
        # original time which is unaffected by whether or not the input can be planned
        ori_time_key = should_pad_bench_key(
            match, mat1, mat2, op, input, is_base_time_key=True
        )
        ori_time = get_cached_base_mm_benchmark_time(ori_time_key)
        if ori_time is None and op is torch.ops.aten.addmm and input is not None:
            # realize bias for addmm
            input = realize_tensor(input)

        mat1_pad = mat1
        mat2_pad = mat2

        is_bmm = op is torch.ops.aten.bmm

        mat1_pre_padded = should_exclude_padding_time(match, "mat1")
        fns = []
        if mat1_pre_padded and (m_padded_length or k_padded_length):
            mat1_pad = pad_mat1(
                mat1_pad,
                m_padded_length=m_padded_length,
                k_padded_length=k_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0)
                else:
                    mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0)

            fns.append(write_pad)

        mat2_pre_padded = should_exclude_padding_time(match, "mat2")
        if mat2_pre_padded and (k_padded_length or n_padded_length):
            mat2_pad = pad_mat2(
                mat2_pad,
                k_padded_length=k_padded_length,
                n_padded_length=n_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0)
                else:
                    mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0)

            fns.append(write_pad)

        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            fns.append(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        elif op is torch.ops.aten.mm:
            fns.append(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        else:
            fns.append(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )

        def orig_bench_fn():
            if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
                op(mat1, mat2)
            else:
                op(input, mat1, mat2)

        def pad_bench_fn():
            for fn in fns:
                fn()

        if (
            torch._inductor.config.run_autoheuristic("pad_mm")
            and op is torch.ops.aten.mm
        ):
            ah_should_pad = run_autoheuristic(
                mat1,
                mat2,
                orig_bench_fn,
                pad_bench_fn,
                m_padded_length,
                k_padded_length,
                n_padded_length,
                do_bench,
                mat1_pre_padded,
                mat2_pre_padded,
                ori_time,
                ori_time_key,
                key,
            )
            if ah_should_pad is not None:
                return ah_should_pad

        if ori_time is None:
            ori_time = do_bench(orig_bench_fn)
            set_cached_base_mm_benchmark_time(ori_time_key, ori_time)

        pad_time = do_bench(pad_bench_fn)
        return should_pad(key, ori_time, pad_time)


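# get_context gathers the features that AutoHeuristic conditions on when choosing
# between the original and padded kernels: problem sizes, strides, padded lengths,
# alignment sizes, dtypes, whether an input is already padded for free, and tf32 usage.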
def get_context(
    mat1: Tensor,
    mat2: Tensor,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
):
    context = AHContext()

    context.add_feature("m", mat1.shape[0])
    context.add_feature("k", mat1.shape[1])
    context.add_feature("n", mat2.shape[1])

    context_add_strides(context, "mat1", mat1.stride())
    context_add_strides(context, "mat2", mat2.stride())

    context.add_feature("m_padded_length", m_padded_length)
    context.add_feature("k_padded_length", k_padded_length)
    context.add_feature("n_padded_length", n_padded_length)

    context.add_feature("mat1_align_size", get_alignment_size(mat1))
    context.add_feature("mat2_align_size", get_alignment_size(mat2))

    context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True)
    context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True)

    context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True)
    context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True)

    context_add_using_tf32(context, mat1.dtype)
    return context


def run_autoheuristic(
    mat1: Tensor,
    mat2: Tensor,
    orig_bench_fn: Callable[[], None],
    pad_bench_fn: Callable[[], None],
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    do_bench,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    ori_time,
    ori_time_key: str,
    key: str,
) -> Optional[bool]:
    def feedback_fn(choice: str):
        if choice == orig_choice:
            return do_bench(orig_bench_fn)
        elif choice == pad_choice:
            return do_bench(pad_bench_fn)
        return None

    def fallback() -> str:
        return "autotune"

    orig_choice = "orig"
    pad_choice = "pad"
    choices = [orig_choice, pad_choice]
    feedback = LocalFeedback(feedback_fn)
    context = get_context(
        mat1,
        mat2,
        mat1_pre_padded,
        mat2_pre_padded,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )
    name = "pad_mm"
    autoheuristic = AutoHeuristic(
        fallback=fallback,
        choices=choices,
        feedback=feedback,
        context=context,
        name=name,
        augment_context=pad_mm_operations(),
        precondition=pad_mm_precondition,
    )
    choice = autoheuristic.get_choice()
    choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
    ah_should_pad = choice2should_pad.get(choice, None)

    if torch._inductor.config.collect_autoheuristic(name):
        ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
        ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)

        # if the precondition is not satisfied, autoheuristic does not collect data
        if ah_ori_time is not None and ah_pad_time is not None:
            if ori_time is None:
                set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time)
            return should_pad(key, ah_ori_time, ah_pad_time)
    if ah_should_pad is not None:
        set_cached_should_pad(key, ah_should_pad)
    return ah_should_pad


def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    return aten.mm(mat1, mat2)


def should_pad_mm(match: Match) -> bool:
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    return should_pad_common(mat1, mat2) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.mm
    )


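# constant_pad_nd takes (left, right) pairs starting from the *last* dimension, which
# is why the pad lists below look reversed: for a 2D mat1 of shape [M, K],
# pad_arg = [0, k_padded_length, 0, m_padded_length] pads K on the right by
# k_padded_length and M on the right by m_padded_length; for bmm the trailing (0, 0)
# leaves the batch dimension untouched.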
def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
    if m_padded_length == 0 and k_padded_length == 0:
        return mat1
    elif k_padded_length != 0 and m_padded_length != 0:
        # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
        pad_arg = [0, k_padded_length, 0, m_padded_length]
        if is_bmm:
            pad_arg.extend((0, 0))
        return aten.constant_pad_nd(mat1, pad_arg)
    elif m_padded_length != 0:
        return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
    else:
        assert k_padded_length != 0
        return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)


def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
    if k_padded_length == 0 and n_padded_length == 0:
        return mat2
    elif k_padded_length != 0 and n_padded_length != 0:
        # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
        pad_arg = [0, n_padded_length, 0, k_padded_length]
        if is_bmm:
            pad_arg.extend((0, 0))
        return aten.constant_pad_nd(mat2, pad_arg)
    elif k_padded_length != 0:
        return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
    else:
        assert n_padded_length != 0
        return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)


def pad_mm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )
    res = aten.mm(mat1, mat2)
    if m_padded_length != 0:
        res = res[:-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :-n_padded_length]
    return res


def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
    return pad_mm(
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )


def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    return aten.bmm(mat1, mat2)


def should_pad_bmm(match: Match) -> bool:
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    return should_pad_common(mat1, mat2) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.bmm
    )


def pad_bmm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1,
            m_padded_length=m_padded_length,
            k_padded_length=k_padded_length,
            is_bmm=True,
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2,
            k_padded_length=k_padded_length,
            n_padded_length=n_padded_length,
            is_bmm=True,
        )
    res = aten.bmm(mat1, mat2)
    if m_padded_length != 0:
        res = res[:, :-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :, :-n_padded_length]
    return res


def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
    m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    return pad_bmm(
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )


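# _pad_mm_init registers each pattern/replacement pair twice with the joint-graph
# pattern matcher: once traced via joint_fwd_bwd for training graphs and once via
# fwd_only for inference graphs. The should_pad_* extra_check callbacks gate every
# match, so a replacement only fires when benchmarking (or the learned heuristic)
# indicates that padding is profitable.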
@functools.lru_cache(None)
def _pad_mm_init():
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for the initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds

    dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
    dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)

    dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
    dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)

    dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)

    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

    for pattern, replacement, args, workaround, extra_check in [
        (
            typing.cast(SearchFn, mm_pattern),
            typing.cast(ReplaceFn, mm_replace),
            [dim2a(), dim2b()],
            {},
            should_pad_mm,
        ),
        (
            typing.cast(SearchFn, bmm_pattern),
            typing.cast(ReplaceFn, bmm_replace),
            [dim3a(), dim3b()],
            {},
            should_pad_bmm,
        ),
        (
            typing.cast(SearchFn, addmm_pattern),
            typing.cast(ReplaceFn, addmm_replace),
            [dim1a(), dim2a(), dim2b()],
            rep,
            should_pad_addmm,
        ),
    ]:
        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
        name = pattern.__name__

        gen_register_replacement(
            f"{name}_training",
            pattern,
            replacement,
            args,
            joint_fwd_bwd,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

        gen_register_replacement(
            f"{name}_inference",
            pattern,
            replacement,
            args,
            fwd_only,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )
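

# Usage sketch (illustrative, assuming a CUDA device with Triton available): the pass
# is driven entirely by torch.compile once shape padding is enabled, e.g.
#
#   import torch
#
#   torch._inductor.config.shape_padding = True
#   compiled_mm = torch.compile(lambda a, b: a @ b)
#   a = torch.randn(1022, 1022, device="cuda", dtype=torch.float16)
#   b = torch.randn(1022, 1022, device="cuda", dtype=torch.float16)
#   out = compiled_mm(a, b)
#
# Since 1022 is not a multiple of the float16 alignment of 8, should_pad_mm may decide
# to pad each dimension up to 1024 before running aten.mm and then slice the result
# back to [1022, 1022].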