# mypy: allow-untyped-defs
import functools
import itertools
import operator
import typing
from typing import Callable, List, Optional, Union

import torch
import torch._inductor.runtime.runtime_utils
from torch import Tensor
from torch._dynamo.utils import counters
from torch._inductor import utils
from torch._inductor.autoheuristic.autoheuristic import (
    AHContext,
    AutoHeuristic,
    LocalFeedback,
)
from torch._inductor.autoheuristic.autoheuristic_utils import (
    context_add_strides,
    context_add_using_tf32,
    pad_mm_operations,
    pad_mm_precondition,
)
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._mode_utils import no_dispatch

from ...utils._triton import has_triton
from ..pattern_matcher import (
    fwd_only,
    gen_register_replacement,
    joint_fwd_bwd,
    Match,
    ReplaceFn,
    SearchFn,
)


aten = torch.ops.aten


# This flag is only used for testing purposes.
# Setting it to True skips the comparison of do_bench times
# between the original pattern and the padded one.
_skip_do_bench_times = False


def fetch_fake_tensors(match, kwarg_names) -> List[Tensor]:
    kwargs = match.kwargs
    return [kwargs[name].meta["val"] for name in kwarg_names]


def unwrap_fake_args(*arg_names):
    def decorator(func):
        def wrapper(match):
            fake_tensors = fetch_fake_tensors(match, arg_names)
            return func(*fake_tensors)

        return wrapper

    return decorator
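# Hypothetical usage of the decorator above (illustrative only, not part of the
# original source): a check written as
#
#     @unwrap_fake_args("mat1", "mat2")
#     def my_check(mat1: Tensor, mat2: Tensor) -> bool:
#         return mat1.shape[1] == mat2.shape[0]
#
# could then be called as my_check(match) and would receive the fake tensors
# stored in match.kwargs[...].meta["val"].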


def get_alignment_size(x: Tensor) -> int:
    return get_alignment_size_dtype(x.dtype)


def get_alignment_size_dtype(dtype: torch.dtype) -> int:
    if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16:
        return 8
    elif dtype == torch.float32 or dtype == torch.float:
        return 4
    else:
        return 0
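# Note on the values above (added for clarity): fp16/bf16 matmuls are padded to
# multiples of 8 and fp32 matmuls to multiples of 4, which matches the shapes that
# tensor cores / cuBLAS handle most efficiently; other dtypes return 0 and are never
# padded. torch.half and torch.float are aliases of float16 and float32, so the
# extra comparisons are redundant but harmless.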


def check_device(a: Tensor, b: Tensor) -> bool:
    return a.is_cuda and b.is_cuda


def check_dtype(a: Tensor, b: Tensor) -> bool:
    return a.is_floating_point() and b.is_floating_point()


def should_pad_common(
    mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None
) -> bool:
    # It's fine if we have symbolic shapes or strides as long as they
    # have hints. Later, we will make sure we only pad non-symbolic dimensions.
    def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
        if t is None:
            return True

        symbolic_cnt = 0
        for x in t.size():
            if isinstance(x, int):
                continue
            elif utils.is_symbolic(x):
                if not x.node.has_hint():
                    return False
                symbolic_cnt += 1
            else:
                return False
        # filter out cases where all dimensions are symbolic
        if symbolic_cnt == len(t.size()):
            return False
        return all(
            isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint())
            for x in t.stride()
        )

    return (
        torch._inductor.config.shape_padding
        and check_device(mat1, mat2)
        and check_dtype(mat1, mat2)
        and all(valid_shape_and_stride(t) for t in (mat1, mat2, input))
    )


def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
    # we don't pad x if it is symbolic
    if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
        return 0

    # ignore dims that can be squeezed away
    if x == 1:
        return 0

    return int((x // alignment_size + 1) * alignment_size) - x
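# Worked example (added for clarity, not part of the original source):
#   get_padded_length(30, 8) -> 2    (32 is the next multiple of 8)
#   get_padded_length(32, 8) -> 0    (already aligned)
#   get_padded_length(1, 8)  -> 0    (size-1 dims can be squeezed/broadcast)
#   get_padded_length(s, 8)  -> 0    (for any symbolic SymInt s)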


def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor:
    if padded_length == 0:
        return x
    pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
    return torch.cat([x, pad], dim=dim)
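# Shape example (added for clarity): pad_dim(x, 2, 0) on a (30, 60) tensor appends
# two rows of zeros and returns a (32, 60) tensor; the original data is untouched.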


def addmm_pattern(
    input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float
) -> Tensor:
    return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)


def should_pad_addmm(match: Match) -> bool:
    mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input"))
    return should_pad_common(mat1, mat2, input) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.addmm, input=input
    )


def pad_addmm(
    input: Optional[Tensor],
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    beta=1.0,
    alpha=1.0,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
):
    # for paddings, dim order is reversed for some reason
    # and for every dim, we need to specify left and right padding
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )

    # the add broadcasts, so we only pad if the dimension != 1
    if input is not None:
        if n_padded_length != 0:
            if input.dim() == 2 and input.shape[1] != 1:
                input = pad_dim(input, n_padded_length, 1)
            elif input.dim() == 1 and input.shape[0] != 1:
                input = pad_dim(input, n_padded_length, 0)
        if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1:
            input = pad_dim(input, m_padded_length, 0)

    res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)

    if m_padded_length != 0:
        res = res[:-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :-n_padded_length]
    return res
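# Shape sketch for pad_addmm (illustrative, not part of the original source): with
# fp16 inputs mat1 (30, 60) and mat2 (60, 100) and alignment 8, the padded matmul is
# (32, 64) @ (64, 104); a 1-D bias of length 100 is padded to 104 (size-1 dims are
# left to broadcast), and the result is sliced back to (30, 100) by dropping the
# trailing m/n padding.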


def addmm_replace(
    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    return pad_addmm(
        input,
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
        beta,
        alpha,
    )


def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
    denominator = M * K + N * K + M * N
    if denominator == 0:
        return False
    arithmetic_intensity = (M * N * K) / denominator

    # we have experienced some large perf hits in this case, even in bandwidth-bound regimes
    if (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and torch.cuda.get_device_capability() < (9, 0)
    ):  # doesn't repro on H100s
        return True

    # Fails on AMD
    try:
        machine_balance = (
            1000 * utils.get_device_tflops(dtype)
        ) / utils.get_gpu_dram_gbps()
    except Exception:
        return True

    # dram_gbps might be underestimating bandwidth because of the cache.
    # If we estimate machine balance too low, we might miss some speedups;
    # if we estimate it too high, compilation time increases unnecessarily.
    # TODO - fine-tune the coefficient here. As a reference point, the Triton mm model
    # assumes 80% of reads are in cache and cache is 4x faster than dram_gbps
    machine_balance = machine_balance * 0.5

    return arithmetic_intensity > machine_balance
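# Rough worked example (illustrative numbers, not from the original source): on a GPU
# with ~300 fp16 TFLOPS and ~1.5 TB/s of DRAM bandwidth, machine_balance is about
# (1000 * 300) / 1500 * 0.5 = 100. A square matmul with M = K = N = 1024 has
# arithmetic_intensity = 1024**3 / (3 * 1024**2) ~= 341 and is treated as compute
# bound; M = K = N = 128 gives ~= 43 and is not.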


@functools.lru_cache(None)
def get_pad_cache():
    return torch._inductor.codecache.LocalCache()


def get_cached_should_pad(key: str) -> bool:
    return get_pad_cache().lookup(key)


def set_cached_should_pad(key: str, value: bool):
    return get_pad_cache().set_value(key, value=value)


def get_cached_base_mm_benchmark_time(key: str) -> float:
    return get_pad_cache().lookup(key)


def set_cached_base_mm_benchmark_time(key: str, value: float):
    return get_pad_cache().set_value(key, value=value)


def should_pad_bench_key(
    match,
    mat1: Tensor,
    mat2: Tensor,
    op,
    input: Optional[Tensor] = None,
    is_base_time_key=False,
) -> str:
    def tensor_key(t):
        return (t.shape, t.stride(), t.dtype)

    tf32_key = (
        None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
    )

    def fmt_pad(name):
        if is_base_time_key:
            return None
        return f"exclude_pad:{should_exclude_padding_time(match, name)}"

    key = (
        tensor_key(mat1),
        tensor_key(mat2),
        fmt_pad("mat1"),
        fmt_pad("mat2"),
        op,
        input if input is None else tensor_key(input),
        tf32_key,
    )

    key = str(key)
    if is_base_time_key:
        key = f"base mm time: {key}"
    return key


def get_non_view_def(node):
    if node.op == "call_function" and node.target is operator.getitem:
        return get_non_view_def(node.args[0])

    if (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and utils.is_view(node.target)
    ):
        return get_non_view_def(node.all_input_nodes[0])

    return node


def should_exclude_padding_time(match, arg_name):
    node_def = get_non_view_def(match.kwargs[arg_name])

    # constant padding converts tensors to contiguous, so even if the input tensor
    # can be memory planned, the layout transform is not free. TODO - is there a way
    # to pad and preserve the layout?
    if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous():
        return False

    # TODO - see issue https://github.com/pytorch/pytorch/issues/128889
    # We would only be able to completely plan these out if we were only doing
    # first-dimension padding; for non-first dimensions we would still need a copy
    # because these outputs are fixed dense.
    cannot_plan_output = [
        aten.mm.default,
        aten.convolution.default,
        aten.convolution_backward.default,
        aten.bmm.default,
        aten.addmm.default,
        aten._scaled_dot_product_flash_attention.default,
        aten._scaled_dot_product_efficient_attention.default,
    ]

    if node_def.target in cannot_plan_output:
        return False

    if (
        node_def.target == aten.cat.default
        and len(node_def.all_input_nodes)
        > torch._inductor.config.max_pointwise_cat_inputs
    ):
        return False

    # optimistically assume we should be able to memory plan away
    # all non-inputs
    return node_def.op != "placeholder"
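# Example of the intent above (illustrative, not from the original source): if mat1
# is produced by an aten.mm followed by a .view(...), get_non_view_def walks back
# through the view to the mm, which is in cannot_plan_output, so the cost of
# zero-padding mat1 is included in the benchmark. If mat1 instead comes from a
# fusible pointwise op, the padding can optimistically be memory planned away and
# its cost is excluded.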


def should_pad(key: str, ori_time, pad_time) -> bool:
    multiplier = 1.1
    # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
    # tradeoff between performance improvement from shape padding and overhead from additional memory ops
    # TODO: Build a learned model which would be better than this heuristic
    if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options:
        multiplier = torch._inductor.config.post_grad_fusion_options[
            "shape_padding_multiplier"
        ].get("value", 1.1)
        counters["inductor"]["shape_padding_multiplier"] += 1
    should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier
    set_cached_should_pad(key, should_pad)
    return should_pad
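# Worked example of the threshold (illustrative, not from the original source): with
# ori_time = 1.00 ms and pad_time = 0.93 ms, 1.00 > 0.93 * 1.1 = 1.023 is False, so
# padding is rejected even though it is nominally faster; pad_time would have to drop
# below roughly 0.91 ms (1.00 / 1.1) before the extra memory ops are considered worth it.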


def should_pad_bench(
    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
    do_bench = functools.partial(
        torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
        warmup=5,
    )
    m_padded_length = 0
    n_padded_length = 0
    batchsize = 1
    with no_dispatch():
        if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
            m = mat1.shape[0]
            k = mat1.shape[1]
            n = mat2.shape[1]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
        elif op is torch.ops.aten.bmm:
            batchsize = mat1.shape[0]
            m = mat1.shape[1]
            k = mat1.shape[2]
            n = mat2.shape[2]
            k_padded_length = get_padded_length(k, get_alignment_size(mat1))
            m_padded_length = get_padded_length(m, get_alignment_size(mat1))
            n_padded_length = get_padded_length(n, get_alignment_size(mat2))
        else:
            return False

        if m_padded_length == k_padded_length == n_padded_length == 0:
            return False

        def realize_symbols(ds):
            return [d if isinstance(d, int) else d.node.hint for d in ds]

        if any(
            dim == 0
            for dim in itertools.chain(
                realize_symbols(mat1.shape), realize_symbols(mat2.shape)
            )
        ):
            return False

        if torch._inductor.config.force_shape_pad:
            return True

        if not has_triton():
            return False

        if not is_mm_compute_bound(m, k, n, mat1.dtype):
            return False

        # We don't want to look up the cache for cases that are trivially false
        # since it does file I/O
        key = should_pad_bench_key(match, mat1, mat2, op, input)

        cached_pad = get_cached_should_pad(key)
        if cached_pad is not None:
            return cached_pad

        def realize_tensor(t):
            if isinstance(t, FakeTensor):
                size_hints = realize_symbols(t.size())
                stride_hint = realize_symbols(t.stride())
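                # (added note) an as_strided view with these sizes/strides touches
                # linear offsets up to sum((d - 1) * s), so that many elements plus
                # one is enough backing storage for the real tensor created below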
                real_size = (
                    sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1
                )
                real_t = torch.randn(real_size, dtype=t.dtype, device=t.device)
                return torch.as_strided(real_t, size_hints, stride_hint)
            else:
                return torch.randn_like(t)

        mat1 = realize_tensor(mat1)
        mat2 = realize_tensor(mat2)

        # Since we key on whether or not the inputs can be memory planned, cache the
        # original time, which is unaffected by whether or not the input can be planned.
        ori_time_key = should_pad_bench_key(
            match, mat1, mat2, op, input, is_base_time_key=True
        )
        ori_time = get_cached_base_mm_benchmark_time(ori_time_key)
        if ori_time is None and op is torch.ops.aten.addmm and input is not None:
            # realize bias for addmm
            input = realize_tensor(input)

        mat1_pad = mat1
        mat2_pad = mat2

        is_bmm = op is torch.ops.aten.bmm

        mat1_pre_padded = should_exclude_padding_time(match, "mat1")
        fns = []
        if mat1_pre_padded and (m_padded_length or k_padded_length):
            mat1_pad = pad_mat1(
                mat1_pad,
                m_padded_length=m_padded_length,
                k_padded_length=k_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0)
                else:
                    mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0)

            fns.append(write_pad)

        mat2_pre_padded = should_exclude_padding_time(match, "mat2")
        if mat2_pre_padded and (k_padded_length or n_padded_length):
            mat2_pad = pad_mat2(
                mat2_pad,
                k_padded_length=k_padded_length,
                n_padded_length=n_padded_length,
                is_bmm=is_bmm,
            )

            def write_pad():
                if is_bmm:
                    mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0)
                else:
                    mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0)

            fns.append(write_pad)

        if op is torch.ops.aten.addmm:
            input_pad = None
            if input is not None and input.is_cuda:
                input_pad = torch.randn_like(input)
            fns.append(
                lambda: pad_addmm(
                    input_pad,
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        elif op is torch.ops.aten.mm:
            fns.append(
                lambda: pad_mm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )
        else:
            fns.append(
                lambda: pad_bmm(
                    mat1_pad,
                    mat2_pad,
                    m_padded_length,
                    k_padded_length,
                    n_padded_length,
                    mat1_pre_padded=mat1_pre_padded,
                    mat2_pre_padded=mat2_pre_padded,
                )
            )

        def orig_bench_fn():
            if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
                op(mat1, mat2)
            else:
                op(input, mat1, mat2)

        def pad_bench_fn():
            for fn in fns:
                fn()

        if (
            torch._inductor.config.run_autoheuristic("pad_mm")
            and op is torch.ops.aten.mm
        ):
            ah_should_pad = run_autoheuristic(
                mat1,
                mat2,
                orig_bench_fn,
                pad_bench_fn,
                m_padded_length,
                k_padded_length,
                n_padded_length,
                do_bench,
                mat1_pre_padded,
                mat2_pre_padded,
                ori_time,
                ori_time_key,
                key,
            )
            if ah_should_pad is not None:
                return ah_should_pad

        if ori_time is None:
            ori_time = do_bench(orig_bench_fn)
            set_cached_base_mm_benchmark_time(ori_time_key, ori_time)

        pad_time = do_bench(pad_bench_fn)
        return should_pad(key, ori_time, pad_time)
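# Summary of the decision pipeline above (added for clarity): compute the padded
# lengths, bail out early for already-aligned or zero-sized shapes, honor
# force_shape_pad, require Triton and a compute-bound matmul, consult the on-disk
# cache, materialize real tensors from the fake ones, then benchmark the original op
# against the padded variant (optionally delegating to AutoHeuristic) and cache the
# verdict.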


def get_context(
    mat1: Tensor,
    mat2: Tensor,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
):
    context = AHContext()

    context.add_feature("m", mat1.shape[0])
    context.add_feature("k", mat1.shape[1])
    context.add_feature("n", mat2.shape[1])

    context_add_strides(context, "mat1", mat1.stride())
    context_add_strides(context, "mat2", mat2.stride())

    context.add_feature("m_padded_length", m_padded_length)
    context.add_feature("k_padded_length", k_padded_length)
    context.add_feature("n_padded_length", n_padded_length)

    context.add_feature("mat1_align_size", get_alignment_size(mat1))
    context.add_feature("mat2_align_size", get_alignment_size(mat2))

    context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True)
    context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True)

    context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True)
    context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True)

    context_add_using_tf32(context, mat1.dtype)
    return context


def run_autoheuristic(
    mat1: Tensor,
    mat2: Tensor,
    orig_bench_fn: Callable[[], None],
    pad_bench_fn: Callable[[], None],
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    do_bench,
    mat1_pre_padded: bool,
    mat2_pre_padded: bool,
    ori_time,
    ori_time_key: str,
    key: str,
) -> Optional[bool]:
    def feedback_fn(choice: str):
        if choice == orig_choice:
            return do_bench(orig_bench_fn)
        elif choice == pad_choice:
            return do_bench(pad_bench_fn)
        return None

    def fallback() -> str:
        return "autotune"

    orig_choice = "orig"
    pad_choice = "pad"
    choices = [orig_choice, pad_choice]
    feedback = LocalFeedback(feedback_fn)
    context = get_context(
        mat1,
        mat2,
        mat1_pre_padded,
        mat2_pre_padded,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )
    name = "pad_mm"
    autoheuristic = AutoHeuristic(
        fallback=fallback,
        choices=choices,
        feedback=feedback,
        context=context,
        name=name,
        augment_context=pad_mm_operations(),
        precondition=pad_mm_precondition,
    )
    choice = autoheuristic.get_choice()
    choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None}
    ah_should_pad = choice2should_pad.get(choice, None)

    if torch._inductor.config.collect_autoheuristic(name):
        ah_ori_time = autoheuristic.get_collected_feedback(orig_choice)
        ah_pad_time = autoheuristic.get_collected_feedback(pad_choice)

        # if the precondition is not satisfied, autoheuristic does not collect data
        if ah_ori_time is not None and ah_pad_time is not None:
            if ori_time is None:
                set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time)
            return should_pad(key, ah_ori_time, ah_pad_time)
    if ah_should_pad is not None:
        set_cached_should_pad(key, ah_should_pad)
    return ah_should_pad


def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    return aten.mm(mat1, mat2)


def should_pad_mm(match: Match) -> bool:
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    return should_pad_common(mat1, mat2) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.mm
    )


def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
    if m_padded_length == 0 and k_padded_length == 0:
        return mat1
    elif k_padded_length != 0 and m_padded_length != 0:
        # dim order is reversed for constant_pad_nd; for every dim we specify left and right padding
        pad_arg = [0, k_padded_length, 0, m_padded_length]
        if is_bmm:
            pad_arg.extend((0, 0))
        return aten.constant_pad_nd(mat1, pad_arg)
    elif m_padded_length != 0:
        return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1)
    else:
        assert k_padded_length != 0
        return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2)
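# Illustration of the constant_pad_nd argument (added for clarity): the pad list is
# consumed in (left, right) pairs starting from the *last* dimension, so for a 2-D
# mat1 of shape (m, k), pad_arg = [0, k_padded_length, 0, m_padded_length] right-pads
# k first and then m; the extra trailing (0, 0) in the bmm case leaves the batch
# dimension untouched.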


def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
    if k_padded_length == 0 and n_padded_length == 0:
        return mat2
    elif k_padded_length != 0 and n_padded_length != 0:
        # dim order is reversed for constant_pad_nd; for every dim we specify left and right padding
        pad_arg = [0, n_padded_length, 0, k_padded_length]
        if is_bmm:
            pad_arg.extend((0, 0))
        return aten.constant_pad_nd(mat2, pad_arg)
    elif k_padded_length != 0:
        return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1)
    else:
        assert n_padded_length != 0
        return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2)


def pad_mm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length
        )
    res = aten.mm(mat1, mat2)
    if m_padded_length != 0:
        res = res[:-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :-n_padded_length]
    return res


def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
    return pad_mm(
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )


def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor:
    return aten.bmm(mat1, mat2)


def should_pad_bmm(match: Match) -> bool:
    mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2"))
    return should_pad_common(mat1, mat2) and should_pad_bench(
        match, mat1, mat2, torch.ops.aten.bmm
    )


def pad_bmm(
    mat1: Tensor,
    mat2: Tensor,
    m_padded_length: int,
    k_padded_length: int,
    n_padded_length: int,
    mat1_pre_padded: bool = False,
    mat2_pre_padded: bool = False,
) -> Tensor:
    if not mat1_pre_padded:
        mat1 = pad_mat1(
            mat1,
            m_padded_length=m_padded_length,
            k_padded_length=k_padded_length,
            is_bmm=True,
        )
    if not mat2_pre_padded:
        mat2 = pad_mat2(
            mat2,
            k_padded_length=k_padded_length,
            n_padded_length=n_padded_length,
            is_bmm=True,
        )
    res = aten.bmm(mat1, mat2)
    if m_padded_length != 0:
        res = res[:, :-m_padded_length, :]
    if n_padded_length != 0:
        res = res[:, :, :-n_padded_length]
    return res


def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
    k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
    n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
    m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
    return pad_bmm(
        mat1,
        mat2,
        m_padded_length,
        k_padded_length,
        n_padded_length,
    )


@functools.lru_cache(None)
def _pad_mm_init():
    from .joint_graph import patterns

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # sizes/values don't actually matter for the initial trace;
    # once we get a possible match we re-trace with the actual values and verify the match still holds

    dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)
    dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True)

    dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)
    dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True)

    dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True)

    # workaround https://github.com/pytorch/pytorch/issues/97894
    # 0.113377 is a "magic" value that lets us recover the lost input arg relationship
    rep = {"beta": 0.213377, "alpha": 0.113377}

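    # (added note) Each (pattern, replacement) pair below is traced twice and
    # registered into the joint-graph pattern matcher: once with joint_fwd_bwd for
    # training graphs and once with fwd_only for inference. extra_check gates the
    # rewrite at match time, and scalar_workaround carries the beta/alpha "magic"
    # scalars for addmm.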
    for pattern, replacement, args, workaround, extra_check in [
        (
            typing.cast(SearchFn, mm_pattern),
            typing.cast(ReplaceFn, mm_replace),
            [dim2a(), dim2b()],
            {},
            should_pad_mm,
        ),
        (
            typing.cast(SearchFn, bmm_pattern),
            typing.cast(ReplaceFn, bmm_replace),
            [dim3a(), dim3b()],
            {},
            should_pad_bmm,
        ),
        (
            typing.cast(SearchFn, addmm_pattern),
            typing.cast(ReplaceFn, addmm_replace),
            [dim1a(), dim2a(), dim2b()],
            rep,
            should_pad_addmm,
        ),
    ]:
        assert isinstance(workaround, dict)  # mypy is unable to infer the type properly
        name = pattern.__name__

        gen_register_replacement(
            f"{name}_training",
            pattern,
            replacement,
            args,
            joint_fwd_bwd,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

        gen_register_replacement(
            f"{name}_inference",
            pattern,
            replacement,
            args,
            fwd_only,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )