# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import cast, List, Tuple

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv


log = logging.getLogger(__name__)


def triton_config(num_stages, num_warps, **kwargs):
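    """Build a triton.Config from block-size kwargs plus staging/warp options."""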
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)


def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
34    """Heuristic to shrink configs when they are bigger than the input size"""

    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
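    # m, n, k are now power-of-2 size hints used only to prune the configs below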
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes a 16x16 tile = 256 elements
        num_warps = min(num_warps, block_m * block_n // 256)
        if torch.version.hip:
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be multiples of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )


# List of dictionaries to store the kernel configs. Configs whose "cond"
# evaluates to True will be used on the target platform. Each config is:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)

# these are only used in tuned_mm when AutoHeuristic is enabled
# the idea is that when AutoHeuristic collects data to learn a heuristic, more configs are autotuned
# when the learned heuristic is used, it reduces the number of configs down to 10,
# which saves compilation time (since fewer configs are autotuned) and can improve performance,
# because the learned heuristic might predict a config that is not part of mm_configs
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]

int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)

scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]


# Create filtered list of configs based on cond evaluation
mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
extra_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in extra_mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)
mixed_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mixed_mm_kernel_configs
    if config["cond"]
)
scaled_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in scaled_mm_kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )

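# Shape-aware generators: calling one of these partials with (m, n, k) yields
# the corresponding platform configs pruned by filtered_configs above.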
mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

extra_mm_configs = functools.partial(
    filtered_configs,
    configs=extra_mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
    filtered_configs,
    configs=mixed_mm_platform_configs,
)

scaled_mm_configs = functools.partial(
    filtered_configs,
    configs=scaled_mm_platform_configs,
)


def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates: a 1D launch over all
    BLOCK_M x BLOCK_N output tiles.
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
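    """Triton accumulator dtype string; fp16/bf16 accumulate in fp32 for accuracy."""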
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
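    # TF32 needs the global allow_tf32 switch; with force_same_precision, shapes
    # must additionally be aligned to multiples of (16, 16, 8) in (m, n, k)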
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )


def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
425    """
426    Common arg processing for mm,bmm,addmm,etc
427    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
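    # unify the batch dims of the two operands (this installs guards)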
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
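    # use_4x2_dim: mat2 presumably packs two logical k elements per physical one
    # (e.g. uint4x2), so double the physical k before unifying with k1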
    if use_4x2_dim:
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
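    """Return an epilogue computing alpha * acc + beta * bias, skipping no-op scalings."""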
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
467