# mypy: allow-untyped-defs
import logging

import torch

from .. import ir, lowering as L
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv as cdiv,
    use_aten_gemm_kernels,
    use_cutlass_template,
    use_triton_template,
)
from ..virtualized import V
from .mm import _is_static_problem
from .mm_common import addmm_epilogue, mm_args, mm_configs, mm_options


log = logging.getLogger(__name__)
aten = torch.ops.aten


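# Launch grid for the Triton bmm template: axis 0 enumerates the
# (BLOCK_M x BLOCK_N) output tiles, axis 1 enumerates the batch.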
def bmm_grid(b, m, n, meta):
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


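# Triton template for batched matmul; each program computes one
# BLOCK_M x BLOCK_N tile of the output for a single batch element.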
bmm_template = TritonTemplate(
    name="bmm",
    grid=bmm_grid,
    source=r"""
{{def_kernel("A", "B")}}
    M = {{size("A", -2)}}
    N = {{size("B", -1)}}
    K = {{size("A", -1)}}

    stride_aq = {{stride("A", 0)}}
    stride_am = {{stride("A", 1)}}
    stride_ak = {{stride("A", 2)}}

    stride_bq = {{stride("B", 0)}}
    stride_bk = {{stride("B", 1)}}
    stride_bn = {{stride("B", 2)}}

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
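    # if A (resp. B) is contiguous along one of its last two dims, hint that
    # the tile indices are contiguous multiples of the block size so the
    # compiler can emit coalesced loads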
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
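    # main reduction loop over K in steps of BLOCK_K; EVEN_K means K divides
    # evenly into BLOCK_K tiles, so the loads need no tail mask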
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}}
"""
)

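# ATen kernels the autotuner can fall back to (and benchmark against)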
aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out")
aten_baddbmm = ExternKernelChoice(torch.baddbmm, "at::baddbmm_out")


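# Lowering for aten.bmm: gather ATen, Triton-template, and (when available)
# CUTLASS candidates, then let the autotuner pick the fastest.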
@L.register_lowering(aten.bmm)
def tuned_bmm(mat1, mat2, *, layout=None):
    if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
        # decompose to small ops when memory bound
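        # when M == 1 or N == 1, rewrite (B, M, K) @ (B, K, N) as a
        # broadcasted multiply followed by a sum over the K dimension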
        if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
            mat1 = L.unsqueeze(mat1, -1)
            mat2 = L.unsqueeze(mat2, 1)
            return L.sum_(L.mul(mat1, mat2), axis=2)

        def is_valid_to_require_contiguous(t):
            if not ir.is_storage_and_layout(t):
                return True
            _, layout = ir.as_storage_and_layout(t, freeze=False)
            return isinstance(layout, ir.FlexibleLayout)

        def is_preferred_layout_as_bmm_input(sizes, strides):
            # contiguous on one of the last two dims
            return (
                strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
            ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))

        # Make the input of bmm contiguous if it is not contiguous on
        # either of the last two dims, because the CPU bmm implementation
        # would call contiguous() on it anyway. This avoids additional
        # copies inside bmm.
        def may_require_contiguous(t, meta_t):
            sizes = meta_t.meta["val"].size()
            strides = meta_t.meta["val"].stride()
            if not is_preferred_layout_as_bmm_input(sizes, strides):
                t = ir.ExternKernel.require_contiguous(t)
            return t

        if is_valid_to_require_contiguous(mat1):
            meta_mat1 = V.graph.current_node.args[0]
            mat1 = may_require_contiguous(mat1, meta_mat1)
        if is_valid_to_require_contiguous(mat2):
            meta_mat2 = V.graph.current_node.args[1]
            mat2 = may_require_contiguous(mat2, meta_mat2)

    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout)

    # options to tune from
    choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
            )
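    # CUTLASS templates are only considered for static, nonzero-sized problems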
    static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
    if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
        from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

    if len(choices) == 0:
        log.warning("No choices for GEMM, using ATen backend as fallback")
        choices.append(aten_bmm.bind((mat1, mat2), layout))

    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)


# Don't register this since it is slower than decomposing it
# @L.register_lowering(aten.baddbmm)
def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
    m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

    # options to tune from
    choices = (
        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
        if use_aten_gemm_kernels()
        else []
    )
    if use_triton_template(layout):
        for config in mm_configs(m, n, k):
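            # reuse the bmm template; `inp` is passed as a prefix argument and
            # only read by the addmm epilogue (alpha * (mat1 @ mat2) + beta * inp)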
            bmm_template.maybe_append_choice(
                choices,
                input_nodes=(inp, mat1, mat2),
                layout=layout,
                **mm_options(config, m, n, k, layout),
                prefix_args=1,
                epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
            )

    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
193