1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3from __future__ import annotations 4 5import functools 6import logging 7from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict 8 9import torch 10 11from .. import config, ir 12from ..lowering import ( 13 add_layout_constraint, 14 constrain_to_fx_strides, 15 lowerings as L, 16 register_lowering, 17) 18from ..select_algorithm import ( 19 autotune_select_algorithm, 20 ExternKernelChoice, 21 TritonTemplate, 22) 23from ..utils import ( 24 ceildiv, 25 is_ones, 26 is_zeros, 27 pad_listlike, 28 sympy_product, 29 use_triton_template, 30) 31from ..virtualized import V 32from .mm_common import filtered_configs 33 34 35if TYPE_CHECKING: 36 from ..ir import TensorBox 37 38log = logging.getLogger(__name__) 39 40 41aten = torch.ops.aten 42 43 44def conv2d_grid(n, c, h, w, meta): 45 return ( 46 ceildiv(n * h * w, meta["BLOCK_M"]), 47 ceildiv(c, meta["BLOCK_N"]), 48 meta["GROUPS"], 49 ) 50 51 52def conv3d_grid(n, c, d, h, w, meta): 53 return ( 54 ceildiv(n * d * h * w, meta["BLOCK_M"]), 55 ceildiv(c, meta["BLOCK_N"]), 56 meta["GROUPS"], 57 ) 58 59 60# List of dictionaries to store the kernel configs. Configs that evaluate to true 61# will be utilised on the target platform 62kernel_configs = [ 63 # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" 64 {"config": (64, 256, 16, 2, 4), "cond": True}, 65 {"config": (256, 64, 16, 2, 4), "cond": True}, 66 {"config": (1024, 16, 16, 1, 8), "cond": True}, 67 {"config": (128, 128, 32, 2, 8), "cond": True}, 68 {"config": (64, 64, 32, 2, 4), "cond": True}, 69 {"config": (64, 256, 32, 2, 8), "cond": True}, 70 {"config": (256, 64, 32, 2, 8), "cond": True}, 71] 72 73# Create filtered list of configs based on conv 74platform_configs = tuple( 75 cast(Tuple[int, int, int, int, int], config["config"]) 76 for config in kernel_configs 77 if config["cond"] 78) 79 80# On ROCm convert num_stages to 1 as pipelining provides no benefit 81if torch.version.hip: 82 platform_configs = tuple( 83 (config[0], config[1], config[2], 1, config[4]) for config in platform_configs 84 ) 85 86conv_configs = functools.partial( 87 filtered_configs, 88 configs=platform_configs, 89) 90 91LOOP_BODY_2D = """ 92 idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H 93 idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W 94 idx_x_c = tl.arange(0, BLOCK_K) + k 95 96 x_ptrs = x_base + ( 97 (idx_x_h * stride_xh)[:, None] 98 + (idx_x_w * stride_xw)[:, None] 99 + (idx_x_c * stride_xc)[None, :] 100 ) 101 mask_x = ( 102 (idx_n < BATCH)[:, None] 103 & (idx_x_h >= 0)[:, None] 104 & (idx_x_h < IN_H)[:, None] 105 & (idx_x_w >= 0)[:, None] 106 & (idx_x_w < IN_W)[:, None] 107 & (idx_x_c < GROUP_IN_C)[None, :] 108 ) 109 matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) 110 111 w_ptrs = w_base + ( 112 (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) 113 ) 114 mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) 115 matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) 116 acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) 117""" 118 119""" 120This is a relatively simple conv implementation that can likely be 121improved. Many alternate conv versions can be found here: 122https://github.com/pytorch/torchdynamo/pull/971 123""" 124conv2d_template = TritonTemplate( 125 name="convolution2d", 126 grid=conv2d_grid, 127 source=r""" 128{{def_kernel("X", "W")}} 129 # Tensor dimensions 130 BATCH = {{size("X", 0)}} 131 IN_C = {{size("X", 1)}} 132 IN_H = {{size("X", 2)}} 133 IN_W = {{size("X", 3)}} 134 OUT_C = {{size(None, 1)}} 135 OUT_H = {{size(None, 2)}} 136 OUT_W = {{size(None, 3)}} 137 138 # Strides: 139 stride_xn = {{stride("X", 0)}} 140 stride_xc = {{stride("X", 1)}} 141 stride_xh = {{stride("X", 2)}} 142 stride_xw = {{stride("X", 3)}} 143 stride_wc_out = {{stride("W", 0)}} 144 stride_wc_in = {{stride("W", 1)}} 145 stride_wh = {{stride("W", 2)}} 146 stride_ww = {{stride("W", 3)}} 147 148 nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 149 idx_y_w = nhw % OUT_W 150 nh = nhw // OUT_W 151 idx_y_h = nh % OUT_H 152 idx_n = nh // OUT_H 153 idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) 154 155{% if GROUPS == 1 %} 156 group = 0 157 GROUP_IN_C = IN_C 158 GROUP_OUT_C = OUT_C 159{% else %} 160 group = tl.program_id(2) 161 GROUP_IN_C = IN_C // GROUPS 162 GROUP_OUT_C = OUT_C // GROUPS 163{% endif %} 164 165 x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] 166 w_base = ( 167 W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] 168 ) 169 170 acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) 171 172{% if UNROLL %} 173{% for i in range(KERNEL_H) %} 174{% for j in range(KERNEL_W) %} 175 i = {{i}} 176 j = {{j}} 177 for k in range(0, GROUP_IN_C, BLOCK_K): 178 """ 179 + LOOP_BODY_2D 180 + """ 181{% endfor %} 182{% endfor %} 183{% else %} 184 # Could be simplified, but slightly slower: 185 # for i in range(KERNEL_H): 186 # for j in range(KERNEL_W): 187 # for k in range(0, GROUP_IN_C, BLOCK_K): 188 BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K 189 for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): 190 k = (ijk % BLOCK_K_COUNT) * BLOCK_K 191 ij = ijk // BLOCK_K_COUNT 192 i = ij // KERNEL_W 193 j = ij % KERNEL_W 194 """ 195 + LOOP_BODY_2D 196 + """ 197{% endif %} 198 199 mask = ( 200 (idx_n < BATCH)[:, None] 201 & (idx_y_h < OUT_H)[:, None] 202 & (idx_y_w < OUT_W)[:, None] 203 & (idx_y_c < GROUP_OUT_C)[None, :] 204 ) 205 idx_n = idx_n[:, None] 206 idx_c = idx_y_c[None, :] + group * GROUP_OUT_C 207 idx_h = idx_y_h[:, None] 208 idx_w = idx_y_w[:, None] 209 210 # inductor generates a suffix 211 {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}} 212""", 213) 214 215LOOP_BODY_3D = """ 216 idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D 217 idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H 218 idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W 219 idx_x_c = tl.arange(0, BLOCK_K) + k 220 221 x_ptrs = x_base + ( 222 (idx_x_d * stride_xd)[:, None] 223 + (idx_x_h * stride_xh)[:, None] 224 + (idx_x_w * stride_xw)[:, None] 225 + (idx_x_c * stride_xc)[None, :] 226 ) 227 mask_x = ( 228 (idx_n < BATCH)[:, None] 229 & (idx_x_d >= 0)[:, None] 230 & (idx_x_d < IN_D)[:, None] 231 & (idx_x_h >= 0)[:, None] 232 & (idx_x_h < IN_H)[:, None] 233 & (idx_x_w >= 0)[:, None] 234 & (idx_x_w < IN_W)[:, None] 235 & (idx_x_c < GROUP_IN_C)[None, :] 236 ) 237 matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) 238 239 w_ptrs = w_base + ( 240 (idx_x_c * stride_wc_in)[:, None] + 241 (d * stride_wd) + (i * stride_wh) + (j * stride_ww) 242 ) 243 mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) 244 matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) 245 acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) 246""" 247 248conv3d_template = TritonTemplate( 249 name="convolution3d", 250 grid=conv3d_grid, 251 source=r""" 252{{def_kernel("X", "W")}} 253 # Tensor dimensions 254 BATCH = {{size("X", 0)}} 255 IN_C = {{size("X", 1)}} 256 IN_D = {{size("X", 2)}} 257 IN_H = {{size("X", 3)}} 258 IN_W = {{size("X", 4)}} 259 OUT_C = {{size(None, 1)}} 260 OUT_D = {{size(None, 2)}} 261 OUT_H = {{size(None, 3)}} 262 OUT_W = {{size(None, 4)}} 263 264 # Strides: 265 stride_xn = {{stride("X", 0)}} 266 stride_xc = {{stride("X", 1)}} 267 stride_xd = {{stride("X", 2)}} 268 stride_xh = {{stride("X", 3)}} 269 stride_xw = {{stride("X", 4)}} 270 stride_wc_out = {{stride("W", 0)}} 271 stride_wc_in = {{stride("W", 1)}} 272 stride_wd = {{stride("W", 2)}} 273 stride_wh = {{stride("W", 3)}} 274 stride_ww = {{stride("W", 4)}} 275 276 ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) 277 idx_y_w = ndhw % OUT_W 278 ndh = ndhw // OUT_W 279 idx_y_h = ndh % OUT_H 280 nd = ndh // OUT_H 281 idx_y_d = nd % OUT_D 282 idx_n = nd // OUT_D 283 idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) 284 285{% if GROUPS == 1 %} 286 group = 0 287 GROUP_IN_C = IN_C 288 GROUP_OUT_C = OUT_C 289{% else %} 290 group = tl.program_id(2) 291 GROUP_IN_C = IN_C // GROUPS 292 GROUP_OUT_C = OUT_C // GROUPS 293{% endif %} 294 295 x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] 296 w_base = ( 297 W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] 298 ) 299 300 acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) 301 302{% if UNROLL %} 303{% for d in range(KERNEL_D) %} 304{% for i in range(KERNEL_H) %} 305{% for j in range(KERNEL_W) %} 306 d = {{d}} 307 i = {{i}} 308 j = {{j}} 309 for k in range(0, GROUP_IN_C, BLOCK_K): 310 """ 311 + LOOP_BODY_3D 312 + """ 313{% endfor %} 314{% endfor %} 315{% endfor %} 316{% else %} 317 # Could be simplified, but slightly slower: 318 # for d in range(KERNEL_D): 319 # for i in range(KERNEL_H): 320 # for j in range(KERNEL_W): 321 # for k in range(0, GROUP_IN_C, BLOCK_K): 322 BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K 323 for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT): 324 k = (dijk % BLOCK_K_COUNT) * BLOCK_K 325 dij = dijk // BLOCK_K_COUNT 326 j = dij % KERNEL_W 327 di = dij // KERNEL_W 328 i = di % KERNEL_H 329 d = di // KERNEL_H 330 """ 331 + LOOP_BODY_3D 332 + """ 333{% endif %} 334 335 mask = ( 336 (idx_n < BATCH)[:, None] 337 & (idx_y_d < OUT_D)[:, None] 338 & (idx_y_h < OUT_H)[:, None] 339 & (idx_y_w < OUT_W)[:, None] 340 & (idx_y_c < GROUP_OUT_C)[None, :] 341 ) 342 idx_n = idx_n[:, None] 343 idx_c = idx_y_c[None, :] + group * GROUP_OUT_C 344 idx_d = idx_y_d[:, None] 345 idx_h = idx_y_h[:, None] 346 idx_w = idx_y_w[:, None] 347 348 # inductor generates a suffix 349 {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}} 350""", 351) 352 353aten_convolution = ExternKernelChoice( 354 torch.convolution, 355 "at::convolution", 356 has_out_variant=False, 357 op_overload=aten.convolution.default, 358) 359 360 361def conv1x1_via_mm(x, w, *, out): 362 w = torch.squeeze(torch.squeeze(w, -1), -1) 363 return torch.matmul( 364 x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) 365 ) 366 367 368aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) 369 370 371class ConvLayoutParams(TypedDict): 372 stride: tuple[int, ...] 373 padding: tuple[int, ...] 374 dilation: tuple[int, ...] 375 transposed: bool 376 output_padding: tuple[int, ...] 377 groups: int 378 379 380def conv_layout( 381 x: TensorBox, 382 weight: TensorBox, 383 bias: Optional[TensorBox], 384 stride: Sequence[int], 385 padding: tuple[int, ...], 386 dilation: tuple[int, ...], 387 transposed: bool, 388 output_padding: tuple[int, ...], 389 groups: int, 390) -> ir.Layout: 391 """Determine output layout for a convolution""" 392 with V.graph.fake_mode: 393 output = torch.ops.aten.convolution( 394 ir.ir_node_to_tensor(x, guard_shape=True), 395 ir.ir_node_to_tensor(weight, guard_shape=True), 396 ir.ir_node_to_tensor(bias, guard_shape=True), 397 V.graph.sizevars.size_hints(stride), # type: ignore[arg-type] 398 V.graph.sizevars.size_hints(padding), # type: ignore[arg-type] 399 V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type] 400 transposed, 401 V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type] 402 groups, 403 ) 404 sizes = ir.convert_shape_to_inductor(output.size()) 405 stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] 406 407 return ir.FixedLayout( 408 x.get_device(), 409 x.get_dtype(), 410 sizes, 411 stride, 412 ) 413 414 415def channels_last_order(rank): 416 order = list(reversed(range(rank))) 417 order.insert(1, order.pop(-1)) 418 return order 419 420 421def convert_1x1_conv_to_mm(x, weight, bias): 422 # special case for 1x1 convolution, which is actually just a matmul 423 rank = len(weight.get_size()) 424 for _ in range(rank - 2): 425 weight = L[aten.squeeze](weight, dim=-1) 426 weight = L[aten.permute](weight, [1, 0]) 427 428 x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) 429 x_permute = list(range(rank)) 430 x_permute.append(x_permute.pop(1)) 431 x = L[aten.permute](x, x_permute) 432 *sizes, in_chan = x.get_size() 433 x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) 434 if bias is None: 435 result = L[aten.mm](x, weight) 436 else: 437 result = L[aten.addmm](bias, x, weight) 438 result = L[aten.reshape](result, [*sizes, -1]) 439 result_permute = list(range(rank)) 440 result_permute.insert(1, result_permute.pop(-1)) 441 return L[aten.permute](result, result_permute) 442 443 444@register_lowering(aten.convolution) 445def convolution( 446 x: TensorBox, 447 weight: TensorBox, 448 bias: TensorBox, 449 stride: List[int], 450 padding: List[int], 451 dilation: List[int], 452 transposed: bool, 453 output_padding: List[int], 454 groups: int, 455): 456 stride = tuple(stride) 457 padding = tuple(padding) 458 dilation = tuple(dilation) 459 output_padding = tuple(output_padding) 460 if not isinstance(groups, int): 461 groups = V.graph.sizevars.evaluate_static_shape(groups) 462 assert isinstance(groups, int) 463 464 # Need use hint for triton template since the template does not 465 # work with a dynamic shape. 466 # 467 # No need to evaluate_static_shape for dilation and output_padding 468 # since the template is only used when dilation is 1 and output_padding 469 # is 0. 470 stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) 471 padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) 472 473 kwargs: ConvLayoutParams = { 474 "stride": stride, 475 "padding": padding, 476 "dilation": dilation, 477 "transposed": transposed, 478 "output_padding": output_padding, 479 "groups": groups, 480 } 481 482 if len(x.get_size()) == len(weight.get_size()) - 1: 483 # add batch dimension to simplify rest of function 484 return L[aten.squeeze]( 485 convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), 486 dim=0, 487 ) 488 489 out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( 490 weight.get_size() 491 ) 492 ndim = len(kernel_shape) 493 stride = pad_listlike(stride, ndim) 494 padding = pad_listlike(padding, ndim) 495 dilation = pad_listlike(dilation, ndim) 496 output_padding = pad_listlike(output_padding, ndim) 497 498 def channels_last_conv(): 499 if V.graph.layout_opt and ndim == 2: 500 return True 501 502 layout = conv_layout(x, weight, None, **kwargs) 503 req_stride_order = ir.get_stride_order( 504 V.graph.sizevars.size_hints(layout.stride) 505 ) 506 return req_stride_order == ir.NHWC_STRIDE_ORDER 507 508 autotuning_gemm = config.max_autotune or config.max_autotune_gemm 509 510 if ( 511 (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) 512 and is_ones(kernel_shape) 513 and is_ones(stride) 514 and is_zeros(padding) 515 and is_ones(dilation) 516 and not transposed 517 and is_zeros(output_padding) 518 and groups == 1 519 and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0) 520 ): 521 return convert_1x1_conv_to_mm(x, weight, bias) 522 523 if bias is not None and ir.get_device_type(x) != "cpu": 524 # peel off the bias, cudnn is slower with it 525 result = convolution(x, weight, None, **kwargs) 526 return L[aten.add]( 527 result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) 528 ) 529 530 x.realize() 531 weight.realize() 532 533 # ndim can be 1 for convolution in models such as demucs 534 # TODO: check if it's beneficial to convert Conv1d to Conv2d and then 535 # apply channels last. 536 if V.graph.layout_opt and ndim == 2: 537 V.graph.num_channels_last_conv += 1 538 x = ir.ExternKernel.require_channels_last(x) 539 # TODO maybe we can convert weights to channels last just once before 540 # running the model. 541 weight = ir.ExternKernel.require_channels_last(weight) 542 layout = conv_layout(x, weight, None, **kwargs) 543 else: 544 layout = conv_layout(x, weight, None, **kwargs) 545 req_stride_order = ir.get_stride_order( 546 V.graph.sizevars.size_hints(layout.stride) 547 ) 548 x = ir.ExternKernel.require_stride_order(x, req_stride_order) 549 weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) 550 551 ordered_kwargs_for_cpp_kernel = [ 552 "stride", 553 "padding", 554 "dilation", 555 "transposed", 556 "output_padding", 557 "groups", 558 ] 559 if bias is None: 560 args = [x, weight] 561 kwargs["bias"] = None # type: ignore[typeddict-unknown-key] 562 ordered_kwargs_for_cpp_kernel.insert(0, "bias") 563 else: 564 args = [x, weight, bias] 565 bias.realize() 566 bias.freeze_layout() 567 V.graph.sizevars.evaluate_static_shapes(bias.get_size()) 568 569 choices = [] 570 if torch._inductor.utils._use_conv_autotune_backend("ATEN"): 571 choices = [ 572 aten_convolution.bind( 573 args, 574 layout, 575 ordered_kwargs_for_cpp_kernel, 576 **kwargs, 577 ) 578 ] 579 580 if ( 581 torch._inductor.utils._use_conv_autotune_backend("TRITON") 582 and use_triton_template(layout) 583 # templates only support these: 584 and is_ones(dilation) 585 and not transposed 586 and is_zeros(output_padding) 587 # there are some odd models where this check fails (e.g. shufflenet_v2_x1_0) 588 and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] 589 ): 590 if ( 591 is_ones(kernel_shape) 592 and is_ones(stride) 593 and is_zeros(padding) 594 and groups == 1 595 ): 596 choices.append(aten_conv1x1_via_mm.bind(args, layout)) 597 598 for cfg in conv_configs( 599 sympy_product([x.get_size()[0], *x.get_size()[2:]]), 600 out_chan, 601 in_chan, 602 ): 603 if ndim == 2: 604 conv2d_template.maybe_append_choice( 605 choices, 606 input_nodes=(x, weight), 607 layout=layout, 608 KERNEL_H=kernel_shape[0], 609 KERNEL_W=kernel_shape[1], 610 STRIDE_H=stride[0], 611 STRIDE_W=stride[1], 612 PADDING_H=padding[0], 613 PADDING_W=padding[1], 614 GROUPS=groups, 615 # TODO(jansel): try unroll for bigger kernels once fixed: 616 # https://github.com/openai/triton/issues/1254 617 UNROLL=is_ones(kernel_shape), 618 ALLOW_TF32=torch.backends.cudnn.allow_tf32, 619 num_stages=cfg.num_stages, 620 num_warps=cfg.num_warps, 621 **cfg.kwargs, 622 ) 623 elif ndim == 3: 624 conv3d_template.maybe_append_choice( 625 choices, 626 input_nodes=(x, weight), 627 layout=layout, 628 KERNEL_D=kernel_shape[0], 629 KERNEL_H=kernel_shape[1], 630 KERNEL_W=kernel_shape[2], 631 STRIDE_D=stride[0], 632 STRIDE_H=stride[1], 633 STRIDE_W=stride[2], 634 PADDING_D=padding[0], 635 PADDING_H=padding[1], 636 PADDING_W=padding[2], 637 GROUPS=groups, 638 # TODO(jansel): try unroll for bigger kernels once fixed: 639 # https://github.com/openai/triton/issues/1254 640 UNROLL=is_ones(kernel_shape), 641 ALLOW_TF32=torch.backends.cudnn.allow_tf32, 642 num_stages=cfg.num_stages, 643 num_warps=cfg.num_warps, 644 **cfg.kwargs, 645 ) 646 647 return autotune_select_algorithm("convolution", choices, args, layout) 648 649 650@register_lowering(aten._convolution) 651def _convolution( 652 x, 653 weight, 654 bias, 655 stride, 656 padding, 657 dilation, 658 transposed, 659 output_padding, 660 groups, 661 benchmark, 662 deterministic, 663 cudnn_enabled, 664 allow_tf32, 665): 666 return convolution( 667 x, weight, bias, stride, padding, dilation, transposed, output_padding, groups 668 ) 669 670 671def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): 672 assert fx_node.target == torch.ops.aten.convolution.default 673 if V.graph.layout_opt: 674 return args, kwargs 675 else: 676 return constrain_to_fx_strides(fx_node, *args, **kwargs) 677 678 679add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) 680