# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

import torch

from .. import config, ir
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
    lowerings as L,
    register_lowering,
)
from ..select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    TritonTemplate,
)
from ..utils import (
    ceildiv,
    is_ones,
    is_zeros,
    pad_listlike,
    sympy_product,
    use_triton_template,
)
from ..virtualized import V
from .mm_common import filtered_configs


if TYPE_CHECKING:
    from ..ir import TensorBox

log = logging.getLogger(__name__)


aten = torch.ops.aten


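# Launch grids for the Triton conv templates below: axis 0 tiles the flattened
# output positions (batch * spatial) by BLOCK_M, axis 1 tiles the output channels
# by BLOCK_N, and axis 2 runs one program per convolution group.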
def conv2d_grid(n, c, h, w, meta):
    return (
        ceildiv(n * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )


def conv3d_grid(n, c, d, h, w, meta):
    return (
        ceildiv(n * d * h * w, meta["BLOCK_M"]),
        ceildiv(c, meta["BLOCK_N"]),
        meta["GROUPS"],
    )


# List of dictionaries storing the kernel configs. Configs whose "cond" evaluates
# to true will be used on the target platform.
kernel_configs = [
    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
    {"config": (64, 256, 16, 2, 4), "cond": True},
    {"config": (256, 64, 16, 2, 4), "cond": True},
    {"config": (1024, 16, 16, 1, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 256, 32, 2, 8), "cond": True},
    {"config": (256, 64, 32, 2, 8), "cond": True},
]

# Create a filtered tuple of the configs above whose condition holds
platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 1 as pipelining provides no benefit
if torch.version.hip:
    platform_configs = tuple(
        (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
    )

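# conv_configs(m, n, k) filters the candidate configs above to the problem size;
# at the call site in convolution() below, m comes from the input's batch and
# spatial dims, n from the output channels and k from the input channels.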
conv_configs = functools.partial(
    filtered_configs,
    configs=platform_configs,
)

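# Inner loop of the implicit-GEMM 2D conv template: for a fixed kernel offset
# (i, j) and input-channel block k, gather a (BLOCK_M, BLOCK_K) tile of the input,
# load the matching (BLOCK_K, BLOCK_N) tile of the weight, and accumulate their
# product into acc with tl.dot.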
LOOP_BODY_2D = """
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""

"""
This is a relatively simple conv implementation that can likely be
improved.  Many alternate conv versions can be found here:
https://github.com/pytorch/torchdynamo/pull/971
"""
conv2d_template = TritonTemplate(
    name="convolution2d",
    grid=conv2d_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_H = {{size("X", 2)}}
    IN_W = {{size("X", 3)}}
    OUT_C = {{size(None, 1)}}
    OUT_H = {{size(None, 2)}}
    OUT_W = {{size(None, 3)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xh = {{stride("X", 2)}}
    stride_xw = {{stride("X", 3)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wh = {{stride("W", 2)}}
    stride_ww = {{stride("W", 3)}}

    nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = nhw % OUT_W
    nh = nhw // OUT_W
    idx_y_h = nh % OUT_H
    idx_n = nh // OUT_H
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
        """
    + LOOP_BODY_2D
    + """
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (ijk % BLOCK_K_COUNT) * BLOCK_K
        ij = ijk // BLOCK_K_COUNT
        i = ij // KERNEL_W
        j = ij % KERNEL_W
        """
    + LOOP_BODY_2D
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}}
""",
)

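# Same implicit-GEMM inner loop as LOOP_BODY_2D, extended with the depth dimension
# (d, KERNEL_D, STRIDE_D, PADDING_D, stride_xd, stride_wd).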
LOOP_BODY_3D = """
        idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D
        idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
        idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
        idx_x_c = tl.arange(0, BLOCK_K) + k

        x_ptrs = x_base + (
            (idx_x_d * stride_xd)[:, None]
            + (idx_x_h * stride_xh)[:, None]
            + (idx_x_w * stride_xw)[:, None]
            + (idx_x_c * stride_xc)[None, :]
        )
        mask_x = (
            (idx_n < BATCH)[:, None]
            & (idx_x_d >= 0)[:, None]
            & (idx_x_d < IN_D)[:, None]
            & (idx_x_h >= 0)[:, None]
            & (idx_x_h < IN_H)[:, None]
            & (idx_x_w >= 0)[:, None]
            & (idx_x_w < IN_W)[:, None]
            & (idx_x_c < GROUP_IN_C)[None, :]
        )
        matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

        w_ptrs = w_base + (
            (idx_x_c * stride_wc_in)[:, None] +
            (d * stride_wd) + (i * stride_wh) + (j * stride_ww)
        )
        mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C)
        matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
        acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32)
"""

conv3d_template = TritonTemplate(
    name="convolution3d",
    grid=conv3d_grid,
    source=r"""
{{def_kernel("X", "W")}}
    # Tensor dimensions
    BATCH = {{size("X", 0)}}
    IN_C = {{size("X", 1)}}
    IN_D = {{size("X", 2)}}
    IN_H = {{size("X", 3)}}
    IN_W = {{size("X", 4)}}
    OUT_C = {{size(None, 1)}}
    OUT_D = {{size(None, 2)}}
    OUT_H = {{size(None, 3)}}
    OUT_W = {{size(None, 4)}}

    # Strides:
    stride_xn = {{stride("X", 0)}}
    stride_xc = {{stride("X", 1)}}
    stride_xd = {{stride("X", 2)}}
    stride_xh = {{stride("X", 3)}}
    stride_xw = {{stride("X", 4)}}
    stride_wc_out = {{stride("W", 0)}}
    stride_wc_in = {{stride("W", 1)}}
    stride_wd = {{stride("W", 2)}}
    stride_wh = {{stride("W", 3)}}
    stride_ww = {{stride("W", 4)}}

    ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_y_w = ndhw % OUT_W
    ndh = ndhw // OUT_W
    idx_y_h = ndh % OUT_H
    nd = ndh // OUT_H
    idx_y_d = nd % OUT_D
    idx_n = nd // OUT_D
    idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

{% if GROUPS == 1 %}
    group = 0
    GROUP_IN_C = IN_C
    GROUP_OUT_C = OUT_C
{% else %}
    group = tl.program_id(2)
    GROUP_IN_C = IN_C // GROUPS
    GROUP_OUT_C = OUT_C // GROUPS
{% endif %}

    x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None]
    w_base = (
        W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :]
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

{% if UNROLL %}
{% for d in range(KERNEL_D) %}
{% for i in range(KERNEL_H) %}
{% for j in range(KERNEL_W) %}
    d = {{d}}
    i = {{i}}
    j = {{j}}
    for k in range(0, GROUP_IN_C, BLOCK_K):
        """
    + LOOP_BODY_3D
    + """
{% endfor %}
{% endfor %}
{% endfor %}
{% else %}
    # Could be simplified, but slightly slower:
    # for d in range(KERNEL_D):
    #   for i in range(KERNEL_H):
    #     for j in range(KERNEL_W):
    #         for k in range(0, GROUP_IN_C, BLOCK_K):
    BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K
    for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT):
        k = (dijk % BLOCK_K_COUNT) * BLOCK_K
        dij = dijk // BLOCK_K_COUNT
        j = dij % KERNEL_W
        di = dij // KERNEL_W
        i = di % KERNEL_H
        d = di // KERNEL_H
        """
    + LOOP_BODY_3D
    + """
{% endif %}

    mask = (
        (idx_n < BATCH)[:, None]
        & (idx_y_d < OUT_D)[:, None]
        & (idx_y_h < OUT_H)[:, None]
        & (idx_y_w < OUT_W)[:, None]
        & (idx_y_c < GROUP_OUT_C)[None, :]
    )
    idx_n = idx_n[:, None]
    idx_c = idx_y_c[None, :] + group * GROUP_OUT_C
    idx_d = idx_y_d[:, None]
    idx_h = idx_y_h[:, None]
    idx_w = idx_y_w[:, None]

    # inductor generates a suffix
    {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}}
""",
)

aten_convolution = ExternKernelChoice(
    torch.convolution,
    "at::convolution",
    has_out_variant=False,
    op_overload=aten.convolution.default,
)


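# A 1x1 convolution with unit stride, no padding and a single group is just a
# matmul over the channel dimension: out[n, h, w, :] = x[n, h, w, :] @ w.T,
# computed here on channels-last views of the tensors.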
def conv1x1_via_mm(x, w, *, out):
    w = torch.squeeze(torch.squeeze(w, -1), -1)
    return torch.matmul(
        x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1)
    )


aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None)


class ConvLayoutParams(TypedDict):
    stride: tuple[int, ...]
    padding: tuple[int, ...]
    dilation: tuple[int, ...]
    transposed: bool
    output_padding: tuple[int, ...]
    groups: int


def conv_layout(
    x: TensorBox,
    weight: TensorBox,
    bias: Optional[TensorBox],
    stride: Sequence[int],
    padding: tuple[int, ...],
    dilation: tuple[int, ...],
    transposed: bool,
    output_padding: tuple[int, ...],
    groups: int,
) -> ir.Layout:
    """Determine output layout for a convolution"""
    with V.graph.fake_mode:
        output = torch.ops.aten.convolution(
            ir.ir_node_to_tensor(x, guard_shape=True),
            ir.ir_node_to_tensor(weight, guard_shape=True),
            ir.ir_node_to_tensor(bias, guard_shape=True),
            V.graph.sizevars.size_hints(stride),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(padding),  # type: ignore[arg-type]
            V.graph.sizevars.size_hints(dilation),  # type: ignore[arg-type]
            transposed,
            V.graph.sizevars.size_hints(output_padding),  # type: ignore[arg-type]
            groups,
        )
        sizes = ir.convert_shape_to_inductor(output.size())
        stride = ir.convert_shape_to_inductor(output.stride())  # type: ignore[assignment]

    return ir.FixedLayout(
        x.get_device(),
        x.get_dtype(),
        sizes,
        stride,
    )


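# Stride order that makes the channel dimension innermost (channels-last), e.g.
# channels_last_order(4) == [3, 0, 2, 1] and channels_last_order(5) == [4, 0, 3, 2, 1].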
def channels_last_order(rank):
    order = list(reversed(range(rank)))
    order.insert(1, order.pop(-1))
    return order


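# Shape flow: x (N, C_in, *spatial) is forced to channels-last, permuted to
# (N, *spatial, C_in), flattened to (N * prod(spatial), C_in), multiplied by the
# squeezed weight (C_in, C_out) via mm/addmm, then reshaped and permuted back to
# (N, C_out, *spatial).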
def convert_1x1_conv_to_mm(x, weight, bias):
    # special case for 1x1 convolution, which is actually just a matmul
    rank = len(weight.get_size())
    for _ in range(rank - 2):
        weight = L[aten.squeeze](weight, dim=-1)
    weight = L[aten.permute](weight, [1, 0])

    x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank))
    x_permute = list(range(rank))
    x_permute.append(x_permute.pop(1))
    x = L[aten.permute](x, x_permute)
    *sizes, in_chan = x.get_size()
    x = L[aten.reshape](x, [sympy_product(sizes), in_chan])
    if bias is None:
        result = L[aten.mm](x, weight)
    else:
        result = L[aten.addmm](bias, x, weight)
    result = L[aten.reshape](result, [*sizes, -1])
    result_permute = list(range(rank))
    result_permute.insert(1, result_permute.pop(-1))
    return L[aten.permute](result, result_permute)


@register_lowering(aten.convolution)
def convolution(
    x: TensorBox,
    weight: TensorBox,
    bias: TensorBox,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    transposed: bool,
    output_padding: List[int],
    groups: int,
):
    stride = tuple(stride)
    padding = tuple(padding)
    dilation = tuple(dilation)
    output_padding = tuple(output_padding)
    if not isinstance(groups, int):
        groups = V.graph.sizevars.evaluate_static_shape(groups)
    assert isinstance(groups, int)

    # We need hints for the Triton template since the template does not
    # work with dynamic shapes.
    #
    # No need to evaluate_static_shape for dilation and output_padding
    # since the template is only used when dilation is 1 and output_padding
    # is 0.
    stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride))
    padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding))

    kwargs: ConvLayoutParams = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "transposed": transposed,
        "output_padding": output_padding,
        "groups": groups,
    }

    if len(x.get_size()) == len(weight.get_size()) - 1:
        # add batch dimension to simplify rest of function
        return L[aten.squeeze](
            convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs),
            dim=0,
        )

    out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes(
        weight.get_size()
    )
    ndim = len(kernel_shape)
    stride = pad_listlike(stride, ndim)
    padding = pad_listlike(padding, ndim)
    dilation = pad_listlike(dilation, ndim)
    output_padding = pad_listlike(output_padding, ndim)

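    # channels_last_conv() reports whether this conv should run in channels-last:
    # always for 2D convs when layout_opt is enabled, otherwise when the layout
    # chosen by conv_layout() already has an NHWC stride order.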
    def channels_last_conv():
        if V.graph.layout_opt and ndim == 2:
            return True

        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        return req_stride_order == ir.NHWC_STRIDE_ORDER

    autotuning_gemm = config.max_autotune or config.max_autotune_gemm

    if (
        (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv()))
        and is_ones(kernel_shape)
        and is_ones(stride)
        and is_zeros(padding)
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        and groups == 1
        and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0)
    ):
        return convert_1x1_conv_to_mm(x, weight, bias)

    if bias is not None and ir.get_device_type(x) != "cpu":
        # peel off the bias, cudnn is slower with it
        result = convolution(x, weight, None, **kwargs)
        return L[aten.add](
            result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1])
        )

    x.realize()
    weight.realize()

    # ndim can be 1 for convolution in models such as demucs
    # TODO: check if it's beneficial to convert Conv1d to Conv2d and then
    # apply channels last.
    if V.graph.layout_opt and ndim == 2:
        V.graph.num_channels_last_conv += 1
        x = ir.ExternKernel.require_channels_last(x)
        # TODO maybe we can convert weights to channels last just once before
        # running the model.
        weight = ir.ExternKernel.require_channels_last(weight)
        layout = conv_layout(x, weight, None, **kwargs)
    else:
        layout = conv_layout(x, weight, None, **kwargs)
        req_stride_order = ir.get_stride_order(
            V.graph.sizevars.size_hints(layout.stride)
        )
        x = ir.ExternKernel.require_stride_order(x, req_stride_order)
        weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)

    ordered_kwargs_for_cpp_kernel = [
        "stride",
        "padding",
        "dilation",
        "transposed",
        "output_padding",
        "groups",
    ]
    if bias is None:
        args = [x, weight]
        kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
        ordered_kwargs_for_cpp_kernel.insert(0, "bias")
    else:
        args = [x, weight, bias]
        bias.realize()
        bias.freeze_layout()
        V.graph.sizevars.evaluate_static_shapes(bias.get_size())

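    # Gather candidate implementations (the ATen extern kernel, the 1x1-conv-as-mm
    # fallback and the Triton templates) and let autotuning pick the fastest one.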
    choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
        choices = [
            aten_convolution.bind(
                args,
                layout,
                ordered_kwargs_for_cpp_kernel,
                **kwargs,
            )
        ]

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
        and use_triton_template(layout)
        # templates only support these:
        and is_ones(dilation)
        and not transposed
        and is_zeros(output_padding)
        # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0)
        and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1])  # type: ignore[arg-type]
    ):
        if (
            is_ones(kernel_shape)
            and is_ones(stride)
            and is_zeros(padding)
            and groups == 1
        ):
            choices.append(aten_conv1x1_via_mm.bind(args, layout))

        for cfg in conv_configs(
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_H=kernel_shape[0],
                    KERNEL_W=kernel_shape[1],
                    STRIDE_H=stride[0],
                    STRIDE_W=stride[1],
                    PADDING_H=padding[0],
                    PADDING_W=padding[1],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    #               https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )
            elif ndim == 3:
                conv3d_template.maybe_append_choice(
                    choices,
                    input_nodes=(x, weight),
                    layout=layout,
                    KERNEL_D=kernel_shape[0],
                    KERNEL_H=kernel_shape[1],
                    KERNEL_W=kernel_shape[2],
                    STRIDE_D=stride[0],
                    STRIDE_H=stride[1],
                    STRIDE_W=stride[2],
                    PADDING_D=padding[0],
                    PADDING_H=padding[1],
                    PADDING_W=padding[2],
                    GROUPS=groups,
                    # TODO(jansel): try unroll for bigger kernels once fixed:
                    #               https://github.com/openai/triton/issues/1254
                    UNROLL=is_ones(kernel_shape),
                    ALLOW_TF32=torch.backends.cudnn.allow_tf32,
                    num_stages=cfg.num_stages,
                    num_warps=cfg.num_warps,
                    **cfg.kwargs,
                )

    return autotune_select_algorithm("convolution", choices, args, layout)


@register_lowering(aten._convolution)
def _convolution(
    x,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32,
):
    return convolution(
        x, weight, bias, stride, padding, dilation, transposed, output_padding, groups
    )


def constrain_conv_to_fx_strides(fx_node, *args, **kwargs):
    assert fx_node.target == torch.ops.aten.convolution.default
    if V.graph.layout_opt:
        return args, kwargs
    else:
        return constrain_to_fx_strides(fx_node, *args, **kwargs)


add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides)