# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from enum import Enum
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch._prims_common as utils
from torch import SymBool, SymFloat, Tensor
from torch._decomp import (
    _add_op_to_registry,
    _convert_out_params,
    global_decomposition_table,
    meta_table,
)
from torch._ops import OpOverload
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
from torch._prims_common import (
    corresponding_complex_dtype,
    corresponding_real_dtype,
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    IntLike,
    make_contiguous_strides_for,
    Number,
    TensorLike,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _resize_output_check,
    _safe_copy_out,
    out_wrapper,
)
from torch._refs import _broadcast_shapes, _maybe_broadcast
from torch.utils import _pytree as pytree


aten = torch.ops.aten

_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")


def register_meta(op):
    def wrapper(fn):
        fn = _convert_out_params(fn)

        def register(op):
            _add_op_to_registry(meta_table, op, fn)

        pytree.tree_map_(register, op)
        return fn

    return wrapper
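

# Illustrative usage (mirroring the registrations below): a meta kernel only
# computes output metadata, so its body typically just allocates an empty
# tensor with the right shape and dtype, e.g.
#
#     @register_meta([aten.take.default, aten.take.out])
#     @out_wrapper()
#     def meta_take(self, index):
#         return self.new_empty(index.shape)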


def elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
):
    # Perform type promotion first, as prim metafunctions expect
    # already-promoted inputs
    _, result_dtype = utils.elementwise_dtypes(
        *args,
        type_promotion_kind=type_promotion,
    )
    args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]

    # Broadcast
    args = _maybe_broadcast(*args)

    # Perform prim checks
    return _prim_elementwise_meta(
        *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )
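

# For example, with INT_TO_FLOAT promotion an integer input is promoted to the
# default floating-point dtype before broadcasting, so inputs of shapes (3, 1)
# and (4,) yield a floating-point meta tensor of shape (3, 4).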


def toRealValueType(dtype):
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    return from_complex.get(dtype, dtype)
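

# e.g. toRealValueType(torch.cfloat) -> torch.float; non-complex dtypes pass
# through unchanged.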


def check_inplace_broadcast(self_shape, *args_shape):
    broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
    torch._check(
        broadcasted_shape == self_shape,
        lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
    )


@register_meta([aten.linspace, aten.logspace])
@out_wrapper()
def meta_linspace_logspace(
    start,
    end,
    steps,
    base=None,
    dtype=None,
    device=None,
    layout=torch.strided,
    pin_memory=False,
    requires_grad=False,
):
    if isinstance(start, torch.Tensor):
        torch._check(
            start.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )
    if isinstance(end, torch.Tensor):
        torch._check(
            end.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )

    if any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    else:
        dtype = dtype or torch.get_default_dtype()
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: f"received an invalid combination of arguments - got \
({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    return torch.empty(
        (steps,),  # type: ignore[arg-type]
        dtype=dtype,
        layout=layout,
        device="meta",
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_meta([aten.take.default, aten.take.out])
@out_wrapper()
def meta_take(self, index):
    # Type and device checks
    torch._check(
        index.dtype == torch.long,
        lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
    )
    # Index checks
    torch._check_index(
        not (self.numel() == 0 and index.numel() != 0),
        lambda: "take(): tried to take from an empty tensor",
    )
    return self.new_empty(index.shape)


@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@out_wrapper()
def linalg_cross(self, other, *, dim=-1):
    x_d = self.ndim
    y_d = other.ndim
    torch._check(
        x_d == y_d,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    torch._check(
        self.size(dim) == 3 and other.size(dim) == 3,
        lambda: (
            f"linalg.cross: inputs dimension {dim} must have length 3. "
            f"Got {self.size(dim)} and {other.size(dim)}"
        ),
    )
    out_shape = _broadcast_shapes(self.shape, other.shape)
    return self.new_empty(out_shape)


@register_meta(aten.linalg_matrix_exp)
@out_wrapper()
def linalg_matrix_exp(self):
    squareCheckInputs(self, "linalg.matrix_exp")
    checkFloatingOrComplex(self, "linalg.matrix_exp")
    return torch.empty_like(self, memory_format=torch.contiguous_format)


@register_meta(
    [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
)
@out_wrapper("values", "indices")
def cummaxmin(self, dim):
    values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
    indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
    if self.numel() != 0 and self.ndim != 0:
        # Checks that dim is within bounds
        maybe_wrap_dim(dim, self.ndim)
    return values, indices


@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
@out_wrapper()
def logcumsumexp(self, dim):
    # Checks that dim is within bounds
    maybe_wrap_dim(dim, self.ndim)
    return torch.empty_like(self).contiguous()


# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
def _exec_fft(out, self, out_sizes, dim, forward):
    ndim = self.ndim
    signal_ndim = len(dim)
    batch_dims = ndim - signal_ndim

    # Permute dimensions so batch dimensions come first, and in stride order
    dim_permute = list(range(ndim))

    is_transformed_dim = [False for _ in range(ndim)]
    for d in dim:
        is_transformed_dim[d] = True

    # std::partition
    left, right = [], []
    for d in dim_permute:
        if not is_transformed_dim[d]:
            left.append(d)
        else:
            right.append(d)
    dim_permute = left + right
    batch_end = len(left)

    self_strides = self.stride()
    tmp = dim_permute[:batch_end]
    tmp.sort(key=lambda x: self_strides[x], reverse=True)
    dim_permute = tmp + dim_permute[batch_end:]
    input = self.permute(dim_permute)

    # Collapse batch dimensions into a single dimension
    batched_sizes = [-1] + list(input.shape[batch_dims:])
    input = input.reshape(batched_sizes)

    batch_size = input.size(0)
    batched_sizes[0] = batch_size
    batched_out_sizes = batched_sizes
    for i in range(len(dim)):
        batched_out_sizes[i + 1] = out_sizes[dim[i]]
    out = out.reshape(batched_out_sizes)

    # Reshaping to original batch shape and inverting the dimension permutation
    out_strides = [0 for _ in range(ndim)]
    batch_numel = 1
    i = batch_dims - 1
    while i >= 0:
        out_strides[dim_permute[i]] = batch_numel * out.stride(0)
        batch_numel *= out_sizes[dim_permute[i]]
        i -= 1
    for i in range(batch_dims, ndim):
        out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
    return out.as_strided(out_sizes, out_strides, out.storage_offset())
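

# For example, for self of shape (2, 3, 8) with dim=[2], the two batch
# dimensions are collapsed into one leading dimension of size 6 for the
# transform, and the final as_strided restores the original (2, 3, out) layout.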


# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    assert self.dtype.is_complex

    out_sizes = self.shape
    output = self.new_empty(out_sizes)

    if not dim:
        return output

    sorted_dims = dim[:]
    self_strides = self.stride()
    sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
    output = _exec_fft(output, self, out_sizes, sorted_dims, forward)

    return output


@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    assert self.dtype.is_floating_point
    output_sizes = list(self.size())

    if onesided:
        last_dim = dim[-1]
        last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
        output_sizes[last_dim] = last_dim_halfsize

    return self.new_empty(
        output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
    )
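

# With onesided=True the last transformed dimension is halved (Hermitian
# symmetry): a float input of shape (4, 10) produces a complex output of
# shape (4, 6), since 10 // 2 + 1 == 6.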


@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    return _maybe_resize_out(out, torch.Size([n]))


@register_meta(aten.randperm.default)
def meta_randperm_default(
    n,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.randint.default, aten.randint.out])
@out_wrapper()
def meta_randint(
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.randint.low, aten.randint.low_out])
@out_wrapper()
def meta_randint_low(
    low,
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten.rand.default, aten.rand.out])
@out_wrapper()
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.empty(
        size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )


@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    assert self.dtype.is_complex
    output_sizes = list(self.size())
    output_sizes[dim[-1]] = lastdim
    return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    # This code simulates the original decomp from inductor,
    # which runs most of the meta checks that we care about.
    # In theory, we should make this more robust by carefully
    # auditing our C++ copy_() kernel and copying the checks here.
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are
    # calling an actual copy_, you'll get that automatically
    # https://github.com/pytorch/pytorch/issues/122477
    if (
        not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1
    ):  # 1 == MemOverlap::Yes
        raise RuntimeError(
            "more than one element of the written-to tensor refers to a single memory location"
        )

    if isinstance(src, Tensor):
        intermediate = src.to(self, non_blocking)
        if self.size() != intermediate.size():
            aten.expand_copy.default(intermediate, self.size())
    return self


def inferUnsqueezeGeometry(tensor, dim):
    result_sizes = list(tensor.size())
    result_strides = list(tensor.stride())
    new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
    result_sizes.insert(dim, 1)
    result_strides.insert(dim, new_stride)
    return result_sizes, result_strides
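

# For example, a contiguous (2, 3) tensor has strides (3, 1);
# inferUnsqueezeGeometry(t, 1) yields sizes [2, 1, 3] and strides [3, 3, 1],
# matching what torch.unsqueeze produces.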


@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    dim = maybe_wrap_dim(dim, self.dim() + 1)
    g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
    self.as_strided_(g_sizes, g_strides)
    return self


@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
    input: Tensor,
    weight: Tensor,
    _meta: Tensor,
    bias: Optional[Tensor] = None,
    _activation_opt: Optional[str] = None,
    out_dtype: Optional[torch.dtype] = None,
):
    output_sizes = list(input.shape)
    if bias is not None:
        assert weight.size(0) == bias.size(0), "output size mismatch"
    assert weight.size(1) == input.size(-1) / 2
    output_sizes[-1] = weight.size(0)

    # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
    # We assume that we have already squashed the inputs into a 2-D tensor
    # Then, as the output is transposed, we need to propagate the transposed
    # stride information to the output tensor
    assert len(input.shape) == 2, "we can only handle the squashed input case"
    transposed_strides = (1, input.size(0))

    if out_dtype is not None:
        assert (
            input.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = input.new_empty(
        output_sizes,
        dtype=input.dtype if out_dtype is None else out_dtype,
    ).as_strided(output_sizes, transposed_strides)

    return output


@register_meta(aten._sparse_semi_structured_mm)
def meta_sparse_structured_mm(
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    out_dtype: Optional[torch.dtype] = None,
):
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    assert mat1.size(1) == mat2.size(0) / 2
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output


@register_meta(aten._sparse_semi_structured_addmm)
def meta_sparse_structured_addmm(
    input: Tensor,
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    *,
    alpha=1,
    beta=1,
    out_dtype: Optional[torch.dtype] = None,
):
    assert (
        len(input.shape) == 1
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    assert input.size(0) == mat1.size(
        0
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    assert mat1.size(1) == mat2.size(0) / 2
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output
@register_meta(aten._cslt_sparse_mm)
def meta__cslt_sparse_mm(
    compressed_A: torch.Tensor,
    dense_B: torch.Tensor,
    bias: Optional[Tensor] = None,
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
):
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
    }, "_cslt_sparse_mm only supports fp32, fp16, bf16, and int8"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

    is_int8_input_type = compressed_A.dtype == torch.int8
    compression_factor = 10 if is_int8_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
    if bias is not None:
        assert m == bias.size(0)

    if out_dtype is not None:
        assert is_int8_input_type and out_dtype in {
            torch.float16,
            torch.bfloat16,
            torch.int32,
        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
    output_shape = (n, m) if transpose_result else (m, n)
    result = dense_B.new_empty(output_shape, dtype=out_dtype)
    return result
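

# Note: m is recovered from the size of the compressed buffer. The 2:4
# compressed format appears to store the kept half of A's values plus sparsity
# metadata, so compressed_A.numel() equals m * k * compression_factor / 16;
# the expression above inverts this relation to solve for the row count m.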


@register_meta(aten.index_reduce.default)
def meta_index_reduce(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    return torch.empty_like(self, memory_format=torch.contiguous_format)


@register_meta(aten.index_reduce_.default)
def meta_index_reduce_(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    return self


# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@out_wrapper()
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)


@register_meta(aten.segment_reduce.default)
def meta_segment_reduce(
    data: Tensor,
    reduce: str,
    *,
    lengths: Optional[Tensor] = None,
    indices: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
    axis: int = 0,
    unsafe: bool = False,
    initial=None,
) -> Tensor:
    if indices is not None:
        raise NotImplementedError(
            "segment_reduce(): indices-based reduction is not supported yet."
        )

    def segment_reduce_lengths_tensor(lengths_shape):
        return torch.empty(
            lengths_shape + data.shape[axis + 1 :],
            dtype=data.dtype,
            device="meta",
            memory_format=torch.contiguous_format,
        )

    if lengths is not None:
        return segment_reduce_lengths_tensor(lengths.shape)
    # FIXME should probably check that lengths and offsets aren't both set, but
    # the ATen implementation neglects this too
    if offsets is not None:
        # lengths == torch.diff(offsets)
        lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
        return segment_reduce_lengths_tensor(lengths_shape)
    raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
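

# For example, with data of shape (5, 3), axis=0 and lengths of shape (2,),
# the result has shape (2, 3); offsets of shape (3,) imply two segments and
# give the same (2, 3) result.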


@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta([aten.min.default, aten.min.unary_out])
@out_wrapper()
def meta_min(self):
    return self.new_empty(())


@register_meta(aten.min.dim)
def meta_min_dim(self, dim, keepdim=False):
    dim = utils.reduction_dims(self.shape, (dim,))
    output_shape = _compute_reduction_shape(self, dim, keepdim)
    return (
        self.new_empty(output_shape),
        self.new_empty(output_shape, dtype=torch.long),
    )


@register_meta(aten.angle.default)
def meta_angle(self):
    if self.is_complex():
        result_dtype = corresponding_real_dtype(self.dtype)
    else:
        _, result_dtype = elementwise_dtypes(
            self,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )
    return torch.empty_like(self, dtype=result_dtype)


@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    torch._resize_output_(out, self.size(), self.device)
    return out.copy_(torch.angle(self))


@register_meta(aten._assert_async.default)
def assert_async(val):
    return


@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    return


@register_meta(aten._print.default)
def print_meta(s):
    return


@register_meta(aten._make_dep_token.default)
def make_dep_token(
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    return torch.empty(0, device="meta")


@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import constrain_range

    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or SymBool is not yet implemented")
    constrain_range(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range.default)
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
    aten.sym_constrain_range(size, min=min, max=max)
    return dep_token


@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    if isinstance(size, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat or SymBool is not yet implemented")
    _constrain_range_for_size(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    aten.sym_constrain_range_for_size(size, min=min, max=max)
    return dep_token


@register_meta(aten._functional_assert_async.msg)
def functional_assert_async_meta(val, assert_msg, dep_token):
    return dep_token


# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert (
        self.size(-1) == self.size(-2)
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"


# Validates input shapes and devices
# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
# From aten/src/ATen/native/LinearAlgebraUtils.h
def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
    torch._check(
        self.device == A.device,
        lambda: (
            f"Expected b and A to be on the same device, but found b on "
            f"{self.device} and A on {A.device} instead."
        ),
    )

    torch._check(
        self.dtype == A.dtype,
        lambda: (
            f"Expected b and A to have the same dtype, but found b of type "
            f"{self.dtype} and A of type {A.dtype} instead."
        ),
    )

    torch._check(
        A.size(-1) == A.size(-2),
        lambda: (
            f"A must be batches of square matrices, "
            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
        ),
    )

    torch._check(
        A.size(-1) == self.size(-2),
        lambda: (
            f"Incompatible matrix sizes for {name}: each A "
            f"matrix is {A.size(-1)} by {A.size(-1)}"
            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
        ),
    )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkFloatingOrComplex(
    t: Tensor,
    f_name: str,
    allow_low_precision_dtypes: bool = True,
):
    dtype = t.dtype
    torch._check(
        t.is_floating_point() or t.is_complex(),
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    if not allow_low_precision_dtypes:
        torch._check(
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
        )


# From aten/src/ATen/native/LinearAlgebraUtils.h
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
    torch._check(
        A.dim() >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
    squareCheckInputs(A, f_name)
    checkIsMatrix(B, f_name)
    torch._check(
        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{'AX = B' if left else 'XA = B'}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )


def checkSameDevice(
    fn_name: str,
    result: Tensor,
    input: Tensor,
    result_name: str = "result",
):
    torch._check(
        result.device == input.device,
        lambda: (
            f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
            f"{result_name} on {result.device} and input on {input.device}"
        ),
    )


def checkUplo(UPLO: str):
    UPLO_uppercase = UPLO.upper()
    torch._check(
        len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
        lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
    )


@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
@out_wrapper("eigenvalues", "eigenvectors")
def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
    squareCheckInputs(A, "linalg.eigh")
    checkUplo(UPLO)

    shape = list(A.shape)
    if compute_v:
        vecs = A.new_empty(shape)
        vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
    else:
        vecs = A.new_empty([0])

    shape.pop()
    vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))

    return vals, vecs


@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
@out_wrapper()
def meta__linalg_eigvals(input: Tensor) -> Tensor:
    squareCheckInputs(input, "linalg.eigvals")
    complex_dtype = (
        input.dtype
        if utils.is_complex_dtype(input.dtype)
        else utils.corresponding_complex_dtype(input.dtype)
    )
    return input.new_empty(input.shape[:-1], dtype=complex_dtype)


@register_meta([aten.linalg_eig])
@out_wrapper("eigenvalues", "eigenvectors")
def meta_linalg_eig(input: Tensor):
    squareCheckInputs(input, "linalg.eig")
    complex_dtype = (
        input.dtype
        if utils.is_complex_dtype(input.dtype)
        else utils.corresponding_complex_dtype(input.dtype)
    )
    values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
    vectors = input.new_empty(input.shape, dtype=complex_dtype)
    return values, vectors


def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
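

# Returns a copy of src laid out in Fortran (column-major) order: e.g. for a
# (B, m, n) batch the clone has strides (m * n, 1, m), which is the layout
# LAPACK-style solvers expect.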


@register_meta(aten._cholesky_solve_helper)
@out_wrapper()
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
    return cloneBatchedColumnMajor(self)


@register_meta(aten.cholesky_solve)
@out_wrapper()
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
    torch._check(
        self.ndim >= 2,
        lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
    )
    torch._check(
        A.ndim >= 2,
        lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
    )
    self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
        self, A, "cholesky_solve"
    )
    return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)


@register_meta(aten.cholesky)
@out_wrapper()
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
    if self.numel() == 0:
        return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
    squareCheckInputs(self, "cholesky")
    return cloneBatchedColumnMajor(self)


@register_meta(aten.cholesky_inverse)
@out_wrapper()
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
    squareCheckInputs(self, "cholesky_inverse")
    return cloneBatchedColumnMajor(self)


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    A_shape = A.shape
    ndim = len(A_shape)

    # L
    L_strides = make_contiguous_strides_for(A_shape, False)
    L = A.new_empty(A_shape)
    L.as_strided_(A_shape, L_strides)

    # infos
    infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
    return L, infos


@register_meta(
    [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
)
@out_wrapper()
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
    torch._check(
        input.ndim >= 2,
        lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
    )
    torch._check(
        input.size(-2) >= input.size(-1),
        lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
    )
    torch._check(
        input.size(-1) >= tau.size(-1),
        lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    if input.ndim > 2:
        expected_batch_tau_shape = input.shape[:-2]
        actual_batch_tau_shape = tau.shape[:-1]
        torch._check(
            actual_batch_tau_shape == expected_batch_tau_shape,
            lambda: (
                f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.linalg.householder_product: tau dtype {tau.dtype}"
            f" does not match input dtype {input.dtype}"
        ),
    )
    checkSameDevice("torch.linalg.householder_product", tau, input, "tau")

    return torch.empty_strided(
        size=input.shape,
        stride=make_contiguous_strides_for(input.shape, row_major=False),
        dtype=input.dtype,
        device=input.device,
    )


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    L = A.new_empty(A.shape)
    L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
    return L, infos


@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
@out_wrapper("LD", "pivots", "info")
def linalg_ldl_factor_ex_meta(
    self: Tensor,
    *,
    hermitian: bool = False,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
    checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
    LD = torch.empty_strided(
        size=self.shape,
        stride=make_contiguous_strides_for(self.shape, row_major=False),
        dtype=self.dtype,
        device=self.device,
    )
    pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
    info = self.new_empty(self.shape[:-2], dtype=torch.int)
    return LD, pivots, info


@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
@out_wrapper()
def linalg_ldl_solve_meta(
    LD: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    hermitian: bool = False,
) -> Tensor:
    squareCheckInputs(LD, "torch.linalg.ldl_solve")
    checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
    linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
    torch._check(
        B.ndim >= 2,
        lambda: (
            f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
            f"but it has {B.ndim} dimensions instead"
        ),
    )
    expected_pivots_shape = LD.shape[:-1]
    torch._check(
        expected_pivots_shape == pivots.shape,
        lambda: (
            f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )
    torch._check(
        utils.is_integer_dtype(pivots.dtype),
        lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
    )
    torch._check(
        LD.dtype == B.dtype,
        lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
    )
    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
    return torch.empty_strided(
        size=B_broadcast_size,
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
        dtype=B.dtype,
        device=B.device,
    )


@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
@out_wrapper("P", "L", "U")
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        A.ndim >= 2,
        lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    sizes = list(A.shape)
    m = sizes[-2]
    n = sizes[-1]
    k = min(m, n)

    sizes[-1] = m
    if pivot:
        P = A.new_empty(sizes)
    else:
        P = A.new_empty([0])

    sizes[-1] = k
    L = A.new_empty(sizes)

    sizes[-2] = k
    sizes[-1] = n
    U = A.new_empty(sizes)
    return P, L, U
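

# Shapes for A of shape (..., m, n) with k = min(m, n): P is (..., m, m)
# (empty when pivot=False), L is (..., m, k) and U is (..., k, n), so that
# with pivot=True, A == P @ L @ U.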


@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
@out_wrapper("LU", "pivots", "info")
def linalg_lu_factor_ex_meta(
    A: Tensor,
    *,
    pivot: bool = True,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        A.ndim >= 2,
        lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    sizes = list(A.shape)
    m = sizes[-2]
    n = sizes[-1]

    LU = torch.empty_strided(
        size=sizes,
        stride=make_contiguous_strides_for(sizes, row_major=False),
        dtype=A.dtype,
        device=A.device,
    )

    # Sets sizes to the size of pivots
    sizes.pop()
    sizes[-1] = min(m, n)
    pivots = A.new_empty(sizes, dtype=torch.int)

    # Sets sizes to the size of info
    sizes.pop()
    info = A.new_empty(sizes, dtype=torch.int)

    return LU, pivots, info


@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
@out_wrapper()
def linalg_lu_solve_meta(
    LU: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    adjoint: bool = False,
) -> Tensor:
    # dtype
    checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
    torch._check(
        LU.dtype == B.dtype,
        lambda: (
            f"linalg.lu_solve: Expected LU and B to have the same dtype, "
            f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
        ),
    )
    torch._check(
        pivots.dtype == torch.int,
        lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
    )

    # matrix shapes
    squareCheckInputs(LU, "torch.linalg.lu_solve")
    checkInputsSolver(LU, B, left, "linalg.lu_solve")
    torch._check(
        LU.size(-1) == pivots.size(-1),
        lambda: "linalg.lu_solve: Number of pivots per batch should be the same as the dimension of the matrix",
    )

    # batches
    torch._check(
        LU.shape[:-1] == pivots.shape,
        lambda: (
            f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )

    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)

    result = torch.empty_strided(
        size=B_broadcast_size,
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
        dtype=B.dtype,
        device=B.device,
    )

    if result.numel() != 0 and not left:
        if result.is_complex():
            result = result.conj()

    return result


@register_meta(aten.lu_unpack)
@out_wrapper("P", "L", "U")
def lu_unpack_meta(
    LU: Tensor,
    pivots: Tensor,
    unpack_data: bool = True,
    unpack_pivots: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    torch._check(
        LU.ndim >= 2,
        lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
    )
    if unpack_pivots:
        torch._check(
            pivots.dtype == torch.int32,
            lambda: (
                "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
                "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
            ),
        )
    sizes = list(LU.shape)
    m = sizes[-2]
    n = sizes[-1]
    k = min(m, n)
    sizes[-1] = m
    if unpack_pivots:
        P = LU.new_empty(sizes)
    else:
        P = LU.new_empty([0])
    if unpack_data:
        sizes[-1] = k
        L = LU.new_empty(sizes)
        sizes[-2] = k
        sizes[-1] = n
        U = LU.new_empty(sizes)
    else:
        L = LU.new_empty([0])
        U = LU.new_empty([0])
    return P, L, U


# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
    if mode == "reduced":
        compute_q = True
        reduced = True
    elif mode == "complete":
        compute_q = True
        reduced = False
    elif mode == "r":
        compute_q = False
        reduced = True  # this is actually irrelevant in this mode
    else:
        torch._check(
            False,
            lambda: (
                f"qr received unrecognized mode '{mode}' "
                f"but expected one of 'reduced' (default), 'r', or 'complete'"
            ),
        )
    return compute_q, reduced  # type: ignore[possibly-undefined]


@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@out_wrapper("Q", "R")
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
    checkIsMatrix(A, "linalg.qr")
    checkFloatingOrComplex(A, "linalg.qr")

    compute_q, reduced_mode = _parse_qr_mode(mode)

    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_q:
        Q_shape = list(A.shape)
        Q_shape[-1] = k if reduced_mode else m
        Q = A.new_empty(Q_shape)
        Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
    else:
        Q = A.new_empty([0])

    # For readability
    R_shape = list(A.shape)
    R_shape[-2] = k if reduced_mode or not compute_q else m
    R = A.new_empty(R_shape)
    R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
    return Q, R
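

# Shape summary for A of shape (..., m, n) with k = min(m, n):
# mode="reduced" gives Q (..., m, k) and R (..., k, n); mode="complete" gives
# Q (..., m, m) and R (..., m, n); mode="r" gives an empty Q and R (..., k, n).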


@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
@out_wrapper("sign", "logabsdet", "LU", "pivots")
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    squareCheckInputs(A, "linalg.slogdet")
    checkFloatingOrComplex(A, "linalg.slogdet", False)
    shape = A.shape
    sign = A.new_empty(shape[:-2])
    logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
    LU = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return sign, logabsdet, LU, pivots


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
):
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch_dims = list(A.shape[:-2])
    m = A.shape[-2]
    n = A.shape[-1]
    k = min(m, n)

    if compute_uv:
        U_shape = batch_dims + [m, m if full_matrices else k]
        U = A.new_empty(U_shape)
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))

        V_shape = batch_dims + [n if full_matrices else k, n]
        V = A.new_empty(V_shape)
        # NB: This checks for CUDA since there is no way to check for cuSolver.
        # Also, this might not work correctly on CPU when fake_device is not
        # available as device_hint just defaults to CUDA in that case. See
        # _linalg_svd meta in core.
        is_cuda = device_hint(A) == "cuda"
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
    else:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])

    # S is always real, even when A is complex.
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
    return U, S, V
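

# Shape summary for A of shape (..., m, n) with k = min(m, n) and
# compute_uv=True: U is (..., m, k) (or (..., m, m) for full_matrices=True),
# S is the real-valued (..., k), and V is (..., k, n) (or (..., n, n)).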


def _linalg_broadcast_batch_dims(
    arg1: Tensor,
    arg2: Tensor,
) -> Tuple[List[int], List[int]]:
    # broadcast the batch dimensions of arg1 and arg2.
    arg1_batch_sizes = arg1.shape[:-2]
    arg2_batch_sizes = arg2.shape[:-2]
    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)

    arg1_expand_size = list(expand_batch_portion)
    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]

    arg2_expand_size = list(expand_batch_portion)
    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
    return arg1_expand_size, arg2_expand_size
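

# Only the batch dimensions broadcast; the trailing matrix dimensions are kept.
# e.g. arg1 of shape (2, 1, 3, 3) and arg2 of shape (5, 3, 2) expand to
# (2, 5, 3, 3) and (2, 5, 3, 2) respectively.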


def _linalg_broadcast_batch_dims_name(
    arg1: Tensor,
    arg2: Tensor,
    name: Optional[str],
) -> Tuple[Tensor, Tensor]:
    # If there's no name we assume we don't want to check the errors
    if name:
        linearSolveCheckInputs(arg1, arg2, name)

    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)

    arg1_broadcasted = (
        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
    )
    arg2_broadcasted = (
        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
    )
    return arg1_broadcasted, arg2_broadcasted


def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
    expected_batched_rhs_shape = input.shape[:-1]
    vector_case = other.ndim == 1 or (
        input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
    )
    return vector_case
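

# e.g. for A (input) of shape (2, 3, 3), a B (other) of shape (3,) or (2, 3)
# is treated as a batch of vector right-hand sides, while B of shape (2, 3, 1)
# is already a matrix right-hand side.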
1396
1397
1398@register_meta(aten._linalg_solve_ex)
1399def _linalg_solve_ex(
1400    A: Tensor,
1401    B: Tensor,
1402    *,
1403    left: bool = True,
1404    check_errors: bool = False,
1405    result: Optional[Tensor] = None,
1406    LU: Optional[Tensor] = None,
1407    pivots: Optional[Tensor] = None,
1408    info: Optional[Tensor] = None,
1409) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1410    checkFloatingOrComplex(A, "linalg.solve")
1411    torch._check(
1412        A.dtype == B.dtype,
1413        lambda: (
1414            f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
1415            f"{A.dtype} and B of type {B.dtype} instead"
1416        ),
1417    )
1418    vector_case = linalg_solve_is_vector_rhs(A, B)
1419    B_ = B.unsqueeze(-1) if vector_case else B
1420    checkInputsSolver(A, B_, left, "linalg.solve")
1421    B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
1422    torch._check(
1423        left or not vector_case,
1424        lambda: (
1425            "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
1426            "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
1427        ),
1428    )
1429    result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
1430    result_ = torch.empty_strided(
1431        size=result_shape,
1432        stride=make_contiguous_strides_for(result_shape, not left),
1433        dtype=B.dtype,
1434        device=B.device,
1435    )
1436    shape = A.shape
1437    ndim = A.ndim
1438    LU_ = torch.empty_strided(
1439        size=shape,
1440        stride=make_contiguous_strides_for(shape, False),
1441        dtype=A.dtype,
1442        device=A.device,
1443    )
1444    pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
1445    info_ = A.new_empty(shape[:-2], dtype=torch.int32)
1446    out = (result, LU, pivots, info)
1447    res = (result_, LU_, pivots_, info_)
1448    if all(x is not None for x in out):
1449        for r, o in zip(res, out):
1450            # resize and copy operations are done in-place
1451            _maybe_resize_out(o, r.shape)  # type: ignore[arg-type]
1452            # strides are not copied in out_wrapper
1453            o.as_strided_(r.shape, r.stride())  # type: ignore[union-attr]
1454            _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)  # type: ignore[arg-type]
1455    return res
1456
1457
1458@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
1459def linalg_solve_triangular_meta(
1460    A: Tensor,
1461    B: Tensor,
1462    *,
1463    upper: bool,
1464    left: bool = True,
1465    unitriangular: bool = False,
1466    out: Optional[Tensor] = None,
1467) -> Tensor:
1468    if out is None:
1469        out = A.new_empty([0])
1470    assert isinstance(out, TensorLike)
1471    checkInputsSolver(A, B, left, "linalg.solve_triangular")
1472    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
1473    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
1474    if avoid_copy_A:
1475        out = _maybe_resize_out(out, B_.shape)
1476    else:
1477        # reimplementation of resize_output with result F-contig
1478        if _resize_output_check(out, B_.shape):
1479            out.resize_(B_.transpose(-2, -1).shape)
1480            out.transpose_(-2, -1)
1481    return out  # type: ignore[return-value]
1482
1483
1484@register_meta(aten.triangular_solve)
1485@out_wrapper("solution", "cloned_coefficient")
1486def triangular_solve_meta(
1487    self: Tensor,
1488    A: Tensor,
1489    upper: bool = True,
1490    transpose: bool = False,
1491    unitriangular: bool = False,
1492) -> Tuple[Tensor, Tensor]:
1493    torch._check(
1494        self.ndim >= 2,
1495        lambda: (
1496            f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
1497            f"but it has {self.ndim} dimensions instead"
1498        ),
1499    )
1500    torch._check(
1501        A.ndim >= 2,
1502        lambda: (
1503            f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
1504            f"but it has {A.ndim} dimensions instead"
1505        ),
1506    )
1507
1508    linearSolveCheckInputs(self, A, "triangular_solve")
1509
1510    if A.layout == torch.strided:
1511        self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
1512        solution = torch.empty_strided(
1513            size=self_broadcast_size,
1514            stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
1515            dtype=self.dtype,
1516            device=self.device,
1517        )
1518        cloned_coefficient = torch.empty_strided(
1519            size=A_broadcast_size,
1520            stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
1521            dtype=A.dtype,
1522            device=A.device,
1523        )
1524    elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
1525        solution = torch.empty_like(self)
1526        cloned_coefficient = self.new_empty([0])
1527    else:
1528        torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
1529    return solution, cloned_coefficient  # type: ignore[possibly-undefined]
1530
1531
1532# From aten/src/ATen/native/LinearAlgebra.cpp
1533@register_meta(aten._linalg_det.default)
1534def _linalg_det_meta(A):
1535    squareCheckInputs(A, "linalg.det")
1536    checkFloatingOrComplex(A, "linalg.det")
1537
1538    det = A.new_empty(A.shape[:-2])
1539
1540    LU = A.new_empty(A.shape)
1541    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
1542
1543    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
1544    return det, LU, pivots
1545
1546
1547@register_meta(aten.ormqr)
1548@out_wrapper()
1549def ormqr(
1550    input: Tensor,
1551    tau: Tensor,
1552    other: Tensor,
1553    left: bool = True,
1554    transpose: bool = False,
1555) -> Tensor:
1556    torch._check(
1557        input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
1558    )
1559    torch._check(
1560        other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
1561    )
1562
1563    left_size_condition = -2 if left else -1
1564    torch._check(
1565        other.shape[left_size_condition] >= tau.shape[-1],
1566        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
1567    )
1568    torch._check(
1569        other.shape[left_size_condition] == input.shape[-2],
1570        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
1571    )
1572
1573    torch._check(
1574        tau.shape[-1] <= input.shape[-1],
1575        lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
1576    )
1577
1578    torch._check(
1579        input.ndim - tau.ndim == 1,
1580        lambda: (
1581            f"torch.ormqr: Expected tau to have one dimension less than input, "
1582            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
1583        ),
1584    )
1585    torch._check(
1586        input.ndim == other.ndim,
1587        lambda: (
1588            f"torch.ormqr: Expected other to have the same number of dimensions as input, "
1589            f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
1590        ),
1591    )
1592
1593    if input.ndim > 2:
1594        expected_batch_shape = input.shape[:-2]
1595        actual_batch_tau_shape = tau.shape[:-1]
1596        torch._check(
1597            actual_batch_tau_shape == expected_batch_shape,
1598            lambda: (
1599                f"torch.ormqr: Expected batch dimensions of tau to be "
1600                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
1601            ),
1602        )
1603
1604        actual_batch_other_shape = other.shape[:-2]
1605        torch._check(
1606            actual_batch_other_shape == expected_batch_shape,
1607            lambda: (
1608                f"torch.ormqr: Expected batch dimensions of other to be "
1609                f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
1610            ),
1611        )
1612
1613    torch._check(
1614        tau.dtype == input.dtype,
1615        lambda: (
1616            f"torch.ormqr: Expected input and tau to have the same dtype, "
1617            f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
1618        ),
1619    )
1620    torch._check(
1621        other.dtype == input.dtype,
1622        lambda: (
1623            f"torch.ormqr: Expected input and other to have the same dtype, "
1624            f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
1625        ),
1626    )
1627
1628    checkSameDevice("torch.ormqr", tau, input, "tau")
1629    checkSameDevice("torch.ormqr", other, input, "other")
1630
1631    return torch.empty_strided(
1632        size=other.shape,
1633        stride=make_contiguous_strides_for(other.shape, row_major=False),
1634        dtype=other.dtype,
1635        device=other.device,
1636    )
1637
1638
1639def _padding_check_valid_input(input, padding, *, dim):
1640    torch._check(
1641        len(padding) == 2 * dim,
1642        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
1643    )
1644
1645    input_dim = input.ndim
1646
1647    is_batch_mode = input_dim == (dim + 2)
1648
1649    valid_batch_mode = is_batch_mode
1650    valid_non_batch_mode = not is_batch_mode
1651
1652    if is_batch_mode:
1653        # allow a batch size of 0; all non-batch dims must be non-zero.
1654        for d in range(1, input_dim):
1655            valid_batch_mode = valid_batch_mode and input.size(d) != 0
1656    else:
1657        for d in range(0, input_dim):
1658            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
1659
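    # Illustrative example for dim=1 (so a 2D or 3D input is expected): a 3D
    # input of shape (0, 3, 5) is accepted since only the batch may be empty,
    # while (3, 0, 5) in batch mode or (0, 5) in non-batch mode is rejected.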
1660    # allow empty batch size but not other dimensions.
1661    torch._check(
1662        valid_batch_mode or valid_non_batch_mode,
1663        lambda: (
1664            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
1665            f"and other non-zero dimensions for input, but got: {input.shape}"
1666        ),
1667    )
1668
1669
1670def _pad1d_common(input, padding, *, is_reflection):
1671    dim_plane = 0
1672    dim_w = 1
1673    nbatch = 1
1674
1675    if input.ndim == 3:
1676        nbatch = input.size(0)
1677        dim_w += 1
1678        dim_plane += 1
1679
1680    _padding_check_valid_input(input, padding, dim=1)
1681
1682    pad_l, pad_r = padding
1683
1684    nplane = input.size(dim_plane)
1685    input_w = input.size(dim_w)
1686    output_w = input_w + pad_l + pad_r
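    # e.g. input_w = 5 with padding (pad_l, pad_r) = (2, 3) gives
    # output_w = 5 + 2 + 3 = 10; reflection padding additionally requires
    # each side to be strictly less than input_w (checked below).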
1687
1688    if is_reflection:
1689        torch._check(
1690            pad_l < input_w and pad_r < input_w,
1691            lambda: (
1692                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1693                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1694            ),
1695        )
1696
1697    torch._check(
1698        output_w >= 1,
1699        lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
1700    )
1701
1702    if input.ndim == 2:
1703        return input.new_empty((nplane, output_w))
1704    else:
1705        return input.new_empty((nbatch, nplane, output_w))
1706
1707
1708@register_meta(aten.reflection_pad1d)
1709@out_wrapper()
1710def meta_reflection_pad1d(input, padding):
1711    return _pad1d_common(input, padding, is_reflection=True)
1712
1713
1714@register_meta(aten.replication_pad1d)
1715@out_wrapper()
1716def meta_replication_pad1d(input, padding):
1717    return _pad1d_common(input, padding, is_reflection=False)
1718
1719
1720def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
1721    dim_w = 1
1722    if not is_reflection:
1723        torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
1724
1725    if input.ndim == 3:
1726        dim_w += 1
1727
1728    pad_l, pad_r = padding
1729
1730    input_w = input.size(dim_w)
1731    output_w = input_w + pad_l + pad_r
1732
1733    if is_reflection:
1734        torch._check(
1735            pad_l < input_w and pad_r < input_w,
1736            lambda: (
1737                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1738                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1739            ),
1740        )
1741
1742    torch._check(
1743        output_w == grad_output.size(dim_w),
1744        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1745    )
1746
1747    return input.new_empty(input.shape)
1748
1749
1750@register_meta(aten.reflection_pad1d_backward)
1751@out_wrapper("grad_input")
1752def meta_reflection_pad1d_backward(grad_output, input, padding):
1753    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
1754
1755
1756@register_meta(aten.replication_pad1d_backward)
1757@out_wrapper("grad_input")
1758def meta_replication_pad1d_backward(grad_output, input, padding):
1759    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
1760
1761
1762def _pad2d_common(input, padding, *, is_reflection):
1763    dim_w = 2
1764    dim_h = 1
1765    dim_slices = 0
1766    nbatch = 1
1767
1768    _padding_check_valid_input(input, padding, dim=2)
1769
1770    ndim = input.ndim
1771    if ndim == 4:
1772        nbatch = input.size(0)
1773        dim_w += 1
1774        dim_h += 1
1775        dim_slices += 1
1776
1777    pad_l, pad_r, pad_t, pad_b = padding
1778
1779    nplane = input.size(dim_slices)
1780    input_h = input.size(dim_h)
1781    input_w = input.size(dim_w)
1782    output_h = input_h + pad_t + pad_b
1783    output_w = input_w + pad_l + pad_r
1784
1785    if is_reflection:
1786        torch._check(
1787            pad_l < input_w and pad_r < input_w,
1788            lambda: (
1789                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1790                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1791            ),
1792        )
1793        torch._check(
1794            pad_t < input_h and pad_b < input_h,
1795            lambda: (
1796                f"Argument #6: Padding size should be less than the corresponding input dimension, "
1797                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1798            ),
1799        )
1800
1801    torch._check(
1802        output_w >= 1 or output_h >= 1,
1803        lambda: (
1804            f"input (H: {input_h} W: {input_w}) is too small. "
1805            f"Calculated output H: {output_h} W: {output_w}"
1806        ),
1807    )
1808
1809    if input.ndim == 3:
1810        return input.new_empty((nplane, output_h, output_w))
1811    else:
1812        return input.new_empty((nbatch, nplane, output_h, output_w))
1813
1814
1815@register_meta(aten.reflection_pad2d)
1816@out_wrapper()
1817def meta_reflection_pad2d(input, padding):
1818    return _pad2d_common(input, padding, is_reflection=True)
1819
1820
1821@register_meta(aten.replication_pad2d)
1822@out_wrapper()
1823def meta_replication_pad2d(input, padding):
1824    return _pad2d_common(input, padding, is_reflection=False)
1825
1826
1827@register_meta(
1828    [
1829        aten.reflection_pad2d_backward.default,
1830        aten.reflection_pad2d_backward.grad_input,
1831        aten.replication_pad2d_backward.default,
1832        aten.replication_pad2d_backward.grad_input,
1833    ]
1834)
1835@out_wrapper("grad_input")
1836def meta_pad2d_backward(grad_output, self, padding):
1837    dim_w = 2
1838    dim_h = 1
1839    dim_plane = 0
1840    nbatch = 1
1841
1842    self_shape = self.shape
1843    if self.dim() == 4:
1844        nbatch = self_shape[0]
1845        dim_w += 1
1846        dim_h += 1
1847        dim_plane += 1
1848
1849    pad_l, pad_r, pad_t, pad_b = padding
1850
1851    nplane = self_shape[dim_plane]
1852    input_h = self_shape[dim_h]
1853    input_w = self_shape[dim_w]
1854    output_h = input_h + pad_t + pad_b
1855    output_w = input_w + pad_l + pad_r
1856
1857    torch._check(
1858        output_w == grad_output.size(dim_w),
1859        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1860    )
1861    torch._check(
1862        output_h == grad_output.size(dim_h),
1863        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1864    )
1865    return self.new_empty(self.shape)
1866
1867
1868def _pad3d_common(input, padding, *, is_reflection):
1869    dim_w = 3
1870    dim_h = 2
1871    dim_d = 1
1872    dim_plane = 0
1873
1874    _padding_check_valid_input(input, padding, dim=3)
1875
1876    batch_mode = input.ndim == 5
1877    if batch_mode:
1878        nbatch = input.size(0)
1879        dim_w += 1
1880        dim_h += 1
1881        dim_d += 1
1882        dim_plane += 1
1883
1884    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1885
1886    nplane = input.size(dim_plane)
1887    input_d = input.size(dim_d)
1888    input_h = input.size(dim_h)
1889    input_w = input.size(dim_w)
1890    output_d = input_d + pad_f + pad_bk
1891    output_h = input_h + pad_t + pad_b
1892    output_w = input_w + pad_l + pad_r
1893
1894    if is_reflection:
1895        torch._check(
1896            pad_l < input_w and pad_r < input_w,
1897            lambda: (
1898                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1899                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1900            ),
1901        )
1902        torch._check(
1903            pad_t < input_h and pad_b < input_h,
1904            lambda: (
1905                f"Argument #6: Padding size should be less than the corresponding input dimension, "
1906                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1907            ),
1908        )
1909        torch._check(
1910            pad_f < input_d and pad_bk < input_d,
1911            lambda: (
1912                f"Argument #8: Padding size should be less than the corresponding input dimension, "
1913                f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
1914            ),
1915        )
1916
1917    torch._check(
1918        output_w >= 1 or output_h >= 1 or output_d >= 1,
1919        lambda: (
1920            f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
1921            f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
1922        ),
1923    )
1924
1925    if batch_mode:
1926        return input.new_empty((nbatch, nplane, output_d, output_h, output_w))  # type: ignore[possibly-undefined]
1927    else:
1928        return input.new_empty((nplane, output_d, output_h, output_w))
1929
1930
1931@register_meta(aten.reflection_pad3d)
1932@out_wrapper()
1933def meta_reflection_pad3d(input, padding):
1934    return _pad3d_common(input, padding, is_reflection=True)
1935
1936
1937@register_meta(aten.replication_pad3d)
1938@out_wrapper()
1939def meta_replication_pad3d(input, padding):
1940    return _pad3d_common(input, padding, is_reflection=False)
1941
1942
1943@register_meta(
1944    [
1945        aten.reflection_pad3d_backward.default,
1946        aten.reflection_pad3d_backward.grad_input,
1947        aten.replication_pad3d_backward.default,
1948        aten.replication_pad3d_backward.grad_input,
1949    ]
1950)
1951@out_wrapper("grad_input")
1952def meta_pad3d_backward(grad_output, input, padding):
1953    torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
1954    assert input.ndim > 3
1955    assert grad_output.ndim == input.ndim
1956
1957    dim_w = 3
1958    dim_h = 2
1959    dim_d = 1
1960
1961    if input.ndim == 5:
1962        dim_w += 1
1963        dim_h += 1
1964        dim_d += 1
1965
1966    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1967
1968    input_d = input.size(dim_d)
1969    input_h = input.size(dim_h)
1970    input_w = input.size(dim_w)
1971    output_d = input_d + pad_f + pad_bk
1972    output_h = input_h + pad_t + pad_b
1973    output_w = input_w + pad_l + pad_r
1974
1975    torch._check(
1976        output_w == grad_output.size(dim_w),
1977        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1978    )
1979    torch._check(
1980        output_h == grad_output.size(dim_h),
1981        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1982    )
1983    torch._check(
1984        output_d == grad_output.size(dim_d),
1985        lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
1986    )
1987
1988    return input.new_empty(input.shape)
1989
1990
1991@register_meta(aten._pdist_forward)
1992@out_wrapper()
1993def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
1994    torch._check(
1995        self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
1996    )
1997    n = self.size(0)
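    # pdist produces the flattened upper triangle of the pairwise-distance
    # matrix: n * (n - 1) / 2 entries, e.g. n = 4 rows give 6 distances.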
1998    if n <= 1:
1999        return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format)  # type: ignore[call-overload]
2000    else:
2001        return self.new_empty((n * (n - 1) // 2,)).to(
2002            memory_format=torch.legacy_contiguous_format
2003        )  # type: ignore[call-overload]
2004
2005
2006@register_meta(aten._pdist_backward)
2007@out_wrapper()
2008def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
2009    torch._check(
2010        self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous"
2011    )
2012    torch._check(
2013        pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous"
2014    )
2015    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2016
2017
2018@register_meta([aten.baddbmm.default, aten.baddbmm.out])
2019@out_wrapper()
2020def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
2021    dim1 = batch1.size(0)
2022    dim2 = batch1.size(1)
2023    dim3 = batch2.size(2)
2024    self = self.expand((dim1, dim2, dim3))
2025    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
2026    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
2027    torch._check(
2028        self.dtype == batch1.dtype == batch2.dtype,
2029        lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
2030    )
2031    batch1_sizes = batch1.shape
2032    batch2_sizes = batch2.shape
2033    bs = batch1_sizes[0]
2034    contraction_size = batch1_sizes[2]
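    # e.g. batch1: (B, N, M) and batch2: (B, M, P) give an output of shape
    # (B, N, P); self must broadcast to that shape.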
2035    torch._check(
2036        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
2037        lambda: (
2038            f"Expected size for first two dimensions of batch2 tensor to be: "
2039            f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
2040        ),
2041    )
2042    return self.new_empty(self.size())
2043
2044
2045@register_meta([aten.bernoulli.default, aten.bernoulli.out])
2046@out_wrapper()
2047def meta_bernoulli(self, *, generator=None):
2048    # https://github.com/pytorch/pytorch/issues/88612
2049    return torch.empty_like(self).contiguous()
2050
2051
2052@register_meta(aten.bernoulli_.float)
2053def meta_bernoulli_(self, p=0.5, generator=None):
2054    return self
2055
2056
2057@register_meta(aten.bernoulli.p)
2058def meta_bernoulli_p(self, p=0.5, generator=None):
2059    # https://github.com/pytorch/pytorch/issues/88612
2060    return torch.empty_like(self).contiguous()
2061
2062
2063@register_meta([aten.poisson.default, aten.poisson.out])
2064@out_wrapper()
2065def meta_poisson(self, generator=None):
2066    return torch.empty_like(self)
2067
2068
2069@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
2070def meta__fused_moving_avg_obs_fq_helper(
2071    self,
2072    observer_on,
2073    fake_quant_on,
2074    running_min,
2075    running_max,
2076    scale,
2077    zero_point,
2078    averaging_const,
2079    quant_min,
2080    quant_max,
2081    ch_axis,
2082    per_row_fake_quant=False,
2083    symmetric_quant=False,
2084):
2085    torch._check(
2086        ch_axis < self.dim(),
2087        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
2088    )
2089    mask = torch.empty_like(self, dtype=torch.bool)
2090    return (torch.empty_like(self), mask)
2091
2092
2093@register_meta(aten.mm)
2094@out_wrapper()
2095def meta_mm(a, b):
2096    torch._check(a.dim() == 2, lambda: "a must be 2D")
2097    torch._check(b.dim() == 2, lambda: "b must be 2D")
2098    N, M1 = a.shape
2099    M2, P = b.shape
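    # e.g. a: (N, M1) = (4, 5) and b: (M2, P) = (5, 3) yield a (4, 3) output;
    # the reduction dims M1 and M2 must match.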
2100    torch._check(
2101        M1 == M2,
2102        lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
2103    )
2104    return a.new_empty(N, P)
2105
2106
2107def _compute_reduction_shape(self, dims, keepdim):
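    # Illustrative example: shape (2, 3, 4) reduced over dims = (1,) yields
    # (2, 1, 4) with keepdim=True and (2, 4) with keepdim=False.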
2108    if keepdim:
2109        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
2110
2111    return utils.compute_reduction_output_shape(self.shape, dims)
2112
2113
2114# FakeTensors (meta tensors with a device) will report device as meta
2115# when running meta kernels. Here, access the "fake device" of FakeTensor if it
2116# exists so meta kernels whose behavior diverges per device will be more
2117# accurate when run with FakeTensors.
2118def device_hint(tensor) -> str:
2119    if isinstance(tensor, torch._subclasses.FakeTensor):
2120        return tensor.fake_device.type
2121    else:
2122        return "cuda"  # default to cuda
2123
2124
2125def calc_conv_nd_return_shape(
2126    input_tensor: torch.Tensor,
2127    weight: torch.Tensor,
2128    stride: Union[List[int], int],
2129    padding: Union[List[int], int],
2130    dilation: Union[List[int], int],
2131    is_transposed: bool,
2132    groups: int,
2133    output_padding: Optional[Union[List[int], int]] = None,
2134):
2135    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
2136        """
2137        Formula for the output length along one spatial dimension.
2138
2139        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
2140
2141        Args:
2142            ln: length of the dimension
2143            p: padding in that dim
2144            d: dilation in that dim
2145            k: kernel size in that dim
2146            s: stride in that dim
2147        Returns:
2148            The output length
2149        """
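        # Worked example (illustrative): ln=32, p=1, d=1, k=3, s=2 gives
        # (32 + 2 - 2 - 1) // 2 + 1 = 31 // 2 + 1 = 16.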
2150        return (ln + 2 * p - d * (k - 1) - 1) // s + 1
2151
2152    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
2153        """
2154        Formula for the output length along one spatial dimension when
2155        transposed convolution is used.
2156        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
2157
2158        Args:
2159            ln: length of the dimension
2160            p: padding in that dim
2161            d: dilation in that dim
2162            k: kernel size in that dim
2163            s: stride in that dim
2164            op: output padding in that dim
2165
2166        Returns:
2167            The output length
2168        """
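        # Worked example (illustrative): ln=16, p=1, d=1, k=3, s=2, op=1
        # gives (16 - 1) * 2 - 2 + 2 + 1 + 1 = 32, inverting the strided
        # example above.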
2169        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
2170
2171    kernel_size = weight.shape[2:]
2172    dims = input_tensor.shape[2:]
2173    if is_transposed:
2174        out_channels = groups * weight.shape[1]
2175    else:
2176        out_channels = weight.shape[0]
2177        if weight.shape[1] * groups != input_tensor.shape[1]:
2178            raise RuntimeError("Invalid channel dimensions")
2179
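    # e.g. weight (16, 4, kH, kW) with groups = 2 expects 4 * 2 = 8 input
    # channels and yields 16 output channels; a transposed conv instead
    # yields groups * weight.shape[1] output channels.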
2180    ret_shape = [input_tensor.shape[0], out_channels]
2181    if isinstance(stride, IntLike):
2182        stride = [stride] * len(dims)
2183    elif len(stride) == 1:
2184        stride = [stride[0]] * len(dims)
2185
2186    if isinstance(padding, IntLike):
2187        padding = [padding] * len(dims)
2188    elif len(padding) == 1:
2189        padding = [padding[0]] * len(dims)
2190
2191    if isinstance(dilation, IntLike):
2192        dilation = [dilation] * len(dims)
2193    elif len(dilation) == 1:
2194        dilation = [dilation[0]] * len(dims)
2195
2196    output_padding_list: Optional[List[int]] = None
2197    if output_padding:
2198        if isinstance(output_padding, IntLike):
2199            output_padding_list = [output_padding] * len(dims)
2200        elif len(output_padding) == 1:
2201            output_padding_list = [output_padding[0]] * len(dims)
2202        else:
2203            output_padding_list = output_padding
2204
2205    for i in range(len(dims)):
2206        # If output_padding is present, we are dealing with a transposed convolution
2207        if output_padding_list:
2208            ret_shape.append(
2209                _formula_transposed(
2210                    dims[i],
2211                    padding[i],
2212                    dilation[i],
2213                    kernel_size[i],
2214                    stride[i],
2215                    output_padding_list[i],
2216                )
2217            )
2218        else:
2219            ret_shape.append(
2220                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
2221            )
2222
2223    return ret_shape
2224
2225
2226def is_channels_last(ten):
2227    return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
2228
2229
2230@register_meta(aten.convolution.default)
2231def meta_conv(
2232    input_tensor: torch.Tensor,
2233    weight: torch.Tensor,
2234    bias: torch.Tensor,
2235    stride: List[int],
2236    padding: List[int],
2237    dilation: List[int],
2238    is_transposed: bool,
2239    output_padding: List[int],
2240    groups: int,
2241):
2242    def pick_memory_format():
2243        if device_hint(input_tensor) == "cuda":
2244            if is_channels_last(input_tensor) or is_channels_last(weight):
2245                return torch.channels_last
2246        else:
2247            if is_channels_last(input_tensor):
2248                return torch.channels_last
2249        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
2250            return torch.contiguous_format
2251        elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
2252            return torch.preserve_format
2253
2254    shape_out = calc_conv_nd_return_shape(
2255        input_tensor,
2256        weight,
2257        stride,
2258        padding,
2259        dilation,
2260        is_transposed,
2261        groups,
2262        output_padding if is_transposed else None,
2263    )
2264
2265    input_channels_dim = 1
2266    output_channels_dim = 1
2267    if input_tensor.size(input_channels_dim) == 0:
2268        shape_out[output_channels_dim] = 0
2269
2270    out = input_tensor.new_empty(shape_out)
2271    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
2272    return out
2273
2274
2275if torch._C._has_mkldnn:
2276    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
2277        "mkldnn", "IMPL", "Meta"
2278    )
2279
2280    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
2281    def meta_mkldnn_convolution_default(
2282        input_tensor,
2283        weight,
2284        bias,
2285        padding,
2286        stride,
2287        dilation,
2288        groups,
2289        attr,
2290        scalars,
2291        algorithm,
2292    ):
2293        shape_out = calc_conv_nd_return_shape(
2294            input_tensor, weight, stride, padding, dilation, False, groups, []
2295        )
2296        out = input_tensor.new_empty(shape_out)
2297        out_memory_format = torch.channels_last
2298        if input_tensor.dim() == 5:
2299            out_memory_format = torch.channels_last_3d
2300        out = out.to(memory_format=out_memory_format)  # type: ignore[call-overload]
2301        return out
2302
2303    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
2304    def meta_linear_pointwise_default(
2305        input_tensor, weight, bias, attr, scalars, algorithm
2306    ):
2307        return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
2308
2309    if torch._C.has_mkl:
2310        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
2311            "mkl", "IMPL", "Meta"
2312        )
2313
2314        @register_meta(torch.ops.mkl._mkl_linear)
2315        def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size):
2316            return input_tensor.new_empty(
2317                (*input_tensor.shape[:-1], orig_weight.shape[0])
2318            )
2319
2320    _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
2321        "onednn", "IMPL", "Meta"
2322    )
2323
2324    @register_meta(torch.ops.onednn.qconv2d_pointwise.default)
2325    def meta_qconv2d_pointwise(
2326        x,
2327        x_scale,
2328        x_zp,
2329        w,  # prepacked_weight
2330        w_scale,
2331        w_zp,
2332        bias,
2333        stride,
2334        padding,
2335        dilation,
2336        groups,
2337        output_scale,
2338        output_zero_point,
2339        output_dtype,
2340        attr,
2341        scalars,
2342        algorithm,
2343    ):
2344        shape_out = calc_conv_nd_return_shape(
2345            x,
2346            w,
2347            stride,
2348            padding,
2349            dilation,
2350            False,
2351            groups,
2352            None,
2353        )
2354        assert output_dtype in [torch.float32, torch.bfloat16]
2355        out = x.new_empty(shape_out, dtype=output_dtype)
2356        out = out.to(memory_format=torch.channels_last)
2357        return out
2358
2359    @register_meta(torch.ops.onednn.qlinear_pointwise.default)
2360    @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
2361    def meta_qlinear_pointwise(
2362        x,
2363        x_scale,
2364        x_zp,
2365        w,
2366        w_scale,
2367        w_zp,
2368        bias,
2369        output_scale,
2370        output_zero_point,
2371        output_dtype,
2372        post_op_name,
2373        post_op_args,
2374        post_op_algorithm,
2375    ):
2376        output_shape = list(x.shape)
2377        # The weight has been transposed during the qlinear weight prepack process.
2378        output_shape[-1] = w.shape[1]
2379        assert output_dtype in [torch.float32, torch.bfloat16]
2380        out = x.new_empty(output_shape, dtype=output_dtype)
2381        return out
2382
2383    _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
2384        "quantized", "IMPL", "Meta"
2385    )
2386
2387    @register_meta(torch.ops.quantized.max_pool2d)
2388    def meta_quantized_max_pool2d(
2389        input,
2390        kernel_size,
2391        stride=(),
2392        padding=(0,),
2393        dilation=(1,),
2394        ceil_mode=False,
2395    ):
2396        (
2397            nInputPlane,
2398            outputHeight,
2399            outputWidth,
2400        ) = max_pool2d_checks_and_compute_shape(
2401            input, kernel_size, stride, padding, dilation, ceil_mode
2402        )
2403        nbatch = input.size(-4) if input.dim() == 4 else 1
2404        memory_format = torch.channels_last
2405        if input.dim() == 3:
2406            size = [nInputPlane, outputHeight, outputWidth]
2407        else:
2408            size = [nbatch, nInputPlane, outputHeight, outputWidth]
2409        return torch.empty(
2410            size,
2411            dtype=input.dtype,
2412            device=input.device,
2413            memory_format=memory_format,
2414        )
2415
2416
2417# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
2418def check_dim_size(tensor, dim, dim_size, size):
2419    torch._check(
2420        tensor.dim() == dim and tensor.shape[dim_size] == size,
2421        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
2422        + f"but got: dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
2423    )
2424
2425
2426@register_meta(aten.avg_pool2d.default)
2427def meta_avg_pool2d(
2428    input,
2429    kernel_size,
2430    stride=(),
2431    padding=(0,),
2432    ceil_mode=False,
2433    count_include_pad=True,
2434    divisor_override=None,
2435):
2436    def unpack(name, val):
2437        torch._check(
2438            len(val) in [1, 2],
2439            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
2440        )
2441        H = val[0]
2442        W = H if len(val) == 1 else val[1]
2443        return H, W
2444
2445    kH, kW = unpack("kernel_size", kernel_size)
2446    torch._check(
2447        len(stride) in [0, 1, 2],
2448        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2449    )
2450    if len(stride) == 0:
2451        dH, dW = kH, kW
2452    elif len(stride) == 1:
2453        dH, dW = stride[0], stride[0]
2454    else:
2455        dH, dW = unpack("stride", stride)
2456
2457    padH, padW = unpack("padding", padding)
2458
2459    torch._check(
2460        divisor_override is None or divisor_override != 0,
2461        lambda: "divisor must not be zero",
2462    )
2463
2464    nbatch = input.size(-4) if input.dim() == 4 else 1
2465    nInputPlane = input.size(-3)
2466    inputHeight = input.size(-2)
2467    inputWidth = input.size(-1)
2468
2469    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2470    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
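    # Illustrative example, assuming the usual floor-mode pooling formula
    # (L + 2 * pad - k) // stride + 1: inputHeight = 7, kH = 3, padH = 0,
    # dH = 2, ceil_mode = False gives outputHeight = (7 - 3) // 2 + 1 = 3.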
2471
2472    memory_format = utils.suggest_memory_format(input)
2473    pool2d_shape_check(
2474        input,
2475        kH,
2476        kW,
2477        dH,
2478        dW,
2479        padH,
2480        padW,
2481        1,
2482        1,
2483        nInputPlane,
2484        inputHeight,
2485        inputWidth,
2486        outputHeight,
2487        outputWidth,
2488        memory_format,
2489    )
2490
2491    if input.dim() == 3:
2492        size = [nInputPlane, outputHeight, outputWidth]
2493    else:
2494        size = [nbatch, nInputPlane, outputHeight, outputWidth]
2495    return torch.empty(
2496        size,
2497        dtype=input.dtype,
2498        device=input.device,
2499        memory_format=memory_format,
2500    )
2501
2502
2503# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
2504def avg_pool2d_backward_shape_check(
2505    input,
2506    gradOutput,
2507    nbatch,
2508    kH,
2509    kW,
2510    dH,
2511    dW,
2512    padH,
2513    padW,
2514    nInputPlane,
2515    inputHeight,
2516    inputWidth,
2517    outputHeight,
2518    outputWidth,
2519    mem_format,
2520):
2521    pool2d_shape_check(
2522        input,
2523        kH,
2524        kW,
2525        dH,
2526        dW,
2527        padH,
2528        padW,
2529        1,
2530        1,
2531        nInputPlane,
2532        inputHeight,
2533        inputWidth,
2534        outputHeight,
2535        outputWidth,
2536        mem_format,
2537    )
2538
2539    ndim = input.dim()
2540    nOutputPlane = nInputPlane
2541
2542    check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
2543    check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
2544    check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
2545
2546
2547# Don't override the C++ registration.
2548@register_meta(aten.avg_pool2d_backward.default)
2549def meta_avg_pool2d_backward(
2550    gradOutput_,
2551    input,
2552    kernel_size,
2553    stride,
2554    padding,
2555    ceil_mode,
2556    count_include_pad,
2557    divisor_override,
2558):
2559    # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
2560    torch._check(
2561        len(kernel_size) == 1 or len(kernel_size) == 2,
2562        lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
2563    )
2564    kH = kernel_size[0]
2565    kW = kH if len(kernel_size) == 1 else kernel_size[1]
2566    torch._check(
2567        len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
2568        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2569    )
2570    dH = kH if len(stride) == 0 else stride[0]
2571    dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
2572    torch._check(
2573        len(padding) == 1 or len(padding) == 2,
2574        lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
2575    )
2576    padH = padding[0]
2577    padW = padH if len(padding) == 1 else padding[1]
2578
2579    torch._check(
2580        divisor_override is None or divisor_override != 0,
2581        lambda: "divisor must not be zero",
2582    )
2583
2584    input_size = input.shape
2585    nbatch = input_size[-4] if input.dim() == 4 else 1
2586    nInputPlane = input_size[-3]
2587    inputHeight = input_size[-2]
2588    inputWidth = input_size[-1]
2589
2590    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2591    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
2592
2593    mem_format = utils.suggest_memory_format(input)
2594
2595    avg_pool2d_backward_shape_check(
2596        input,
2597        gradOutput_,
2598        nbatch,
2599        kH,
2600        kW,
2601        dH,
2602        dW,
2603        padH,
2604        padW,
2605        nInputPlane,
2606        inputHeight,
2607        inputWidth,
2608        outputHeight,
2609        outputWidth,
2610        mem_format,
2611    )
2612
2613    return torch.empty(
2614        input_size,
2615        dtype=input.dtype,
2616        device=input.device,
2617        memory_format=mem_format,
2618    )
2619
2620
2621@register_meta(aten.avg_pool3d)
2622@out_wrapper()
2623def meta_avg_pool3d(
2624    input,
2625    kernel_size,
2626    stride=(),
2627    padding=(0,),
2628    ceil_mode=False,
2629    count_include_pad=True,
2630    divisor_override=None,
2631):
2632    torch._check(
2633        len(kernel_size) in (1, 3),
2634        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2635    )
2636    kT = kernel_size[0]
2637    kH = kT if len(kernel_size) == 1 else kernel_size[1]
2638    kW = kT if len(kernel_size) == 1 else kernel_size[2]
2639
2640    torch._check(
2641        not stride or len(stride) in (1, 3),
2642        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2643    )
2644    dT = kT if not stride else stride[0]
2645    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2646    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2647
2648    torch._check(
2649        len(padding) in (1, 3),
2650        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2651    )
2652    padT = padding[0]
2653    padH = padT if len(padding) == 1 else padding[1]
2654    padW = padT if len(padding) == 1 else padding[2]
2655
2656    torch._check(
2657        input.ndim in (4, 5),
2658        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2659    )
2660
2661    torch._check(
2662        not divisor_override or divisor_override != 0,
2663        lambda: "divisor must not be zero",
2664    )
2665
2666    nbatch = input.size(0)
2667    nslices = input.size(-4)
2668    itime = input.size(-3)
2669    iheight = input.size(-2)
2670    iwidth = input.size(-1)
2671
2672    otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2673    oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2674    owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2675
2676    pool3d_shape_check(
2677        input,
2678        nslices,
2679        kT,
2680        kH,
2681        kW,
2682        dT,
2683        dH,
2684        dW,
2685        padT,
2686        padH,
2687        padW,
2688        1,
2689        1,
2690        1,
2691        itime,
2692        iheight,
2693        iwidth,
2694        otime,
2695        oheight,
2696        owidth,
2697        "avg_pool3d()",
2698        check_input_size=True,
2699    )
2700
2701    if input.ndim == 4:
2702        return input.new_empty((nslices, otime, oheight, owidth))
2703    else:
2704        return input.new_empty((nbatch, nslices, otime, oheight, owidth))
2705
2706
2707@register_meta(aten.avg_pool3d_backward)
2708@out_wrapper("grad_input")
2709def meta_avg_pool3d_backward(
2710    grad_output,
2711    input,
2712    kernel_size,
2713    stride,
2714    padding,
2715    ceil_mode,
2716    count_include_pad,
2717    divisor_override,
2718):
2719    torch._check(
2720        len(kernel_size) in (1, 3),
2721        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2722    )
2723    kT = kernel_size[0]
2724    kH = kT if len(kernel_size) == 1 else kernel_size[1]
2725    kW = kT if len(kernel_size) == 1 else kernel_size[2]
2726
2727    torch._check(
2728        not stride or len(stride) in (1, 3),
2729        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2730    )
2731    dT = kT if not stride else stride[0]
2732    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2733    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2734
2735    torch._check(
2736        len(padding) in (1, 3),
2737        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2738    )
2739    padT = padding[0]
2740    padH = padT if len(padding) == 1 else padding[1]
2741    padW = padT if len(padding) == 1 else padding[2]
2742
2743    torch._check(
2744        input.ndim in (4, 5),
2745        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2746    )
2747
2748    torch._check(
2749        not divisor_override or divisor_override != 0,
2750        lambda: "divisor must not be zero",
2751    )
2752
2753    nslices = input.size(-4)
2754    itime = input.size(-3)
2755    iheight = input.size(-2)
2756    iwidth = input.size(-1)
2757
2758    otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2759    oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2760    owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2761
2762    avg_pool3d_backward_shape_check(
2763        input,
2764        grad_output,
2765        nslices,
2766        kT,
2767        kH,
2768        kW,
2769        dT,
2770        dH,
2771        dW,
2772        padT,
2773        padH,
2774        padW,
2775        itime,
2776        iheight,
2777        iwidth,
2778        otime_for_shape_check,
2779        oheight_for_shape_check,
2780        owidth_for_shape_check,
2781        "avg_pool3d_backward()",
2782    )
2783
2784    return input.new_empty(input.shape)
2785
2786
2787@register_meta(aten._adaptive_avg_pool2d.default)
2788def meta_adaptive_avg_pool2d(self, output_size):
2789    torch._check(
2790        self.ndim == 3 or self.ndim == 4,
2791        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
2792    )
2793    output_shape = self.shape[:-2] + tuple(output_size)
2794    memory_format = utils.suggest_memory_format(self)
2795    # need to set memory_format to preserve the memory format of the input
2796    # channel last input should have channel last output
2797    return torch.empty(
2798        output_shape,
2799        dtype=self.dtype,
2800        device=self.device,
2801        memory_format=memory_format,
2802    )
2803
2804
2805@register_meta(aten._adaptive_avg_pool3d.default)
2806def meta_adaptive_avg_pool3d(self, output_size):
2807    torch._check(
2808        self.ndim == 4 or self.ndim == 5,
2809        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
2810    )
2811    return self.new_empty(self.shape[:-3] + tuple(output_size))
2812
2813
2814@register_meta(aten._adaptive_avg_pool2d_backward.default)
2815def meta__adaptive_avg_pool2d_backward(grad_out, self):
2816    ndim = grad_out.ndim
2817    for i in range(1, ndim):
2818        torch._check(
2819            grad_out.size(i) > 0,
2820            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero "
2821            f"size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
2822        )
2823    torch._check(
2824        ndim == 3 or ndim == 4,
2825        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
2826    )
2827    torch._check(
2828        self.dtype == grad_out.dtype,
2829        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
2830    )
2831    memory_format = torch.contiguous_format
2832    if is_channels_last(self):
2833        memory_format = torch.channels_last
2834    return self.new_empty(self.shape).to(memory_format=memory_format)
2835
2836
2837@register_meta(aten._adaptive_avg_pool3d_backward)
2838@out_wrapper("grad_input")
2839def meta__adaptive_avg_pool3d_backward(grad_output, self):
2840    _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
2841    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2842
2843
2844def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
2845    ndim = grad_output.ndim
2846    for i in range(1, ndim):
2847        torch._check(
2848            grad_output.size(i) > 0,
2849            lambda: (
2850                f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
2851                f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
2852            ),
2853        )
2854
2855
2856@register_meta(aten.adaptive_max_pool2d)
2857@out_wrapper("out", "indices")
2858def meta_adaptive_max_pool2d(input, output_size):
2859    ndim = input.ndim
2860    torch._check(
2861        ndim in (3, 4),
2862        lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
2863    )
2864    for i in range(1, ndim):
2865        torch._check(
2866            input.size(i) > 0,
2867            lambda: (
2868                f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
2869                f"but input has sizes {input.shape} with dimension {i} being empty"
2870            ),
2871        )
2872
2873    torch._check(
2874        len(output_size) == 2,
2875        lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
2876    )
2877
2878    dimH = 1
2879    sizeB = 1
2880    sizeD = 0
2881
2882    if input.ndim == 4:
2883        sizeB = input.size(0)
2884        dimH += 1
2885
2886    sizeD = input.size(dimH - 1)
2887    osizeH, osizeW = output_size
2888
2889    if input.ndim == 3:
2890        out_shape = (sizeD, osizeH, osizeW)
2891        out = input.new_empty(out_shape)
2892        indices = input.new_empty(out_shape, dtype=torch.int64)
2893        return out, indices
2894    else:
2895        out_shape = (sizeB, sizeD, osizeH, osizeW)  # type: ignore[assignment]
2896        memory_format = utils.suggest_memory_format(input)
2897        out = input.new_empty(out_shape).to(memory_format=memory_format)
2898        indices = input.new_empty(out_shape, dtype=torch.int64).to(
2899            memory_format=memory_format
2900        )
2901        return out, indices
2902
2903
2904@register_meta(aten.adaptive_max_pool2d_backward)
2905@out_wrapper("grad_input")
2906def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
2907    ndim = grad_output.ndim
2908    torch._check(
2909        ndim in (3, 4),
2910        lambda: f"adaptive_max_pool2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
2911    )
2912
2913    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
2914
2915    torch._check(
2916        input.dtype == grad_output.dtype,
2917        lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
2918    )
2919
2920    memory_format = utils.suggest_memory_format(input)
2921    return input.new_empty(input.shape).to(memory_format=memory_format)
2922
2923
2924@register_meta(aten.adaptive_max_pool3d)
2925@out_wrapper("out", "indices")
2926def meta_adaptive_max_pool3d(input, output_size):
2927    ndim = input.ndim
2928    torch._check(
2929        ndim in (4, 5),
2930        lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
2931    )
2932    for i in range(1, ndim):
2933        torch._check(
2934            input.size(i) > 0,
2935            lambda: (
2936                f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
2937                f"but input has sizes {input.shape} with dimension {i} being empty"
2938            ),
2939        )
2940
2941    torch._check(
2942        len(output_size) == 3,
2943        lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
2944    )
2945
2946    dimD = 0
2947    sizeB = 1
2948    sizeD = 0
2949
2950    if ndim == 5:
2951        sizeB = input.size(0)
2952        dimD += 1
2953
2954    sizeD = input.size(dimD)
2955    osizeT, osizeH, osizeW = output_size
2956
2957    if ndim == 4:
2958        out_shape = (sizeD, osizeT, osizeH, osizeW)
2959    else:
2960        out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW)  # type: ignore[assignment]
2961
2962    out = input.new_empty(out_shape)
2963    indices = input.new_empty(out_shape, dtype=torch.int64)
2964
2965    return out, indices
2966
2967
2968@register_meta(aten.adaptive_max_pool3d_backward)
2969@out_wrapper("grad_input")
2970def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
2971    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
2972    return input.new_empty(input.shape)
2973
2974
2975@register_meta(aten.repeat_interleave.Tensor)
2976def meta_repeat_interleave_Tensor(repeats, output_size=None):
2977    if output_size is None:
2978        raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
2979    return repeats.new_empty(output_size)
2980
2981
2982@register_meta([aten.complex.default, aten.complex.out])
2983@out_wrapper()
2984def meta_complex(real, imag):
2985    assert real.dtype.is_floating_point
2986    assert imag.dtype.is_floating_point
2987    out_shape = _broadcast_shapes(real.shape, imag.shape)
2988    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
2989
2990
2991@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
2992@out_wrapper()
2993def nonzero_static(self, *, size: int, fill_value: int = -1):
2994    return self.new_empty((size, self.dim()), dtype=torch.long)
2995
2996
2997@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
2998def meta_index_Tensor(self, indices):
2999    torch._check(bool(indices), lambda: "at least one index must be provided")
3000    # aten::index is the internal advanced indexing implementation
3001    # checkIndexTensorTypes and expandTensors
3002    result: List[Optional[Tensor]] = []
3003    for i, index in enumerate(indices):
3004        if index is not None:
3005            torch._check(
3006                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
3007                lambda: "tensors used as indices must be long, int, byte or bool tensors",
3008            )
3009            if index.dtype in [torch.int8, torch.bool]:
3010                nonzero = index.nonzero()
3011                k = len(result)
3012                torch._check_index(
3013                    k + index.ndim <= self.ndim,
3014                    lambda: f"too many indices for tensor of dimension {self.ndim}",
3015                )
3016                for j in range(index.ndim):
3017                    torch._check_index(
3018                        index.shape[j] == self.shape[k + j],
3019                        lambda: f"The shape of the mask {index.shape} at index {i} "
3020                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
3021                    )
3022                    result.append(nonzero.select(1, j))
3023            else:
3024                result.append(index)
3025        else:
3026            result.append(index)
3027    indices = result
3028    torch._check(
3029        len(indices) <= self.ndim,
3030        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
3031    )
3032    # expand_outplace
3033    import torch._refs as refs  # avoid import cycle in mypy
3034
3035    indices = list(refs._maybe_broadcast(*indices))
3036    # add missing null tensors
3037    while len(indices) < self.ndim:
3038        indices.append(None)
3039
3040    # hasContiguousSubspace
3041    #   true if all non-null tensors are adjacent
3042    # See:
3043    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
3044    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
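    # Illustrative example: indices [None, idx, idx, None] have all advanced
    # indices adjacent (contiguous subspace), whereas [idx, None, idx] do
    # not, in which case the advanced dims are transposed to the front below.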
3045    state = 0
3046    has_contiguous_subspace = False
3047    for index in indices:
3048        if state == 0:
3049            if index is not None:
3050                state = 1
3051        elif state == 1:
3052            if index is None:
3053                state = 2
3054        else:
3055            if index is not None:
3056                break
3057    else:
3058        has_contiguous_subspace = True
3059
3060    # transposeToFront
3061    # This is the logic that causes the newly inserted dimensions to show up
3062    # at the beginning of the tensor, if they're not contiguous
3063    if not has_contiguous_subspace:
3064        dims = []
3065        transposed_indices = []
3066        for i, index in enumerate(indices):
3067            if index is not None:
3068                dims.append(i)
3069                transposed_indices.append(index)
3070        for i, index in enumerate(indices):
3071            if index is None:
3072                dims.append(i)
3073                transposed_indices.append(index)
3074        self = self.permute(dims)
3075        indices = transposed_indices
3076
3077    # AdvancedIndex::AdvancedIndex
3078    # Now we can assume the indices have contiguous subspace
3079    # This is simplified from AdvancedIndex which goes to more effort
3080    # to put the input and indices in a form so that TensorIterator can
3081    # take them.  If we write a ref for this, probably that logic should
3082    # get implemented
3083    before_shape: List[int] = []
3084    after_shape: List[int] = []
3085    replacement_shape: List[int] = []
3086    for dim, index in enumerate(indices):
3087        if index is None:
3088            if replacement_shape:
3089                after_shape.append(self.shape[dim])
3090            else:
3091                before_shape.append(self.shape[dim])
3092        else:
3093            replacement_shape = list(index.shape)
3094    return self.new_empty(before_shape + replacement_shape + after_shape)
3095
3096
3097@register_meta([aten.convolution_backward.default])
3098def meta_convolution_backward(
3099    grad_output_,
3100    input_,
3101    weight_,
3102    bias_sizes_opt,
3103    stride,
3104    padding,
3105    dilation,
3106    transposed,
3107    output_padding,
3108    groups,
3109    output_mask,
3110):
3111    # High level logic taken from slow_conv3d_backward_cpu which should
3112    # be representative of all convolution_backward impls
3113    backend_grad_input = None
3114    backend_grad_weight = None
3115    backend_grad_bias = None
3116
3117    if output_mask[0]:
3118        backend_grad_input = grad_output_.new_empty(input_.size())
3119    if output_mask[1]:
3120        backend_grad_weight = grad_output_.new_empty(weight_.size())
3121    if output_mask[2]:
3122        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
3123
3124    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
3125
3126
3127@register_meta([aten.addbmm.default, aten.addbmm.out])
3128@out_wrapper()
3129def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
3130    dim1 = batch1.size(1)
3131    dim2 = batch2.size(2)
3132    self = self.expand((dim1, dim2))
3133    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3134    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3135    torch._check(
3136        batch1.size(0) == batch2.size(0),
3137        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
3138    )
3139    torch._check(
3140        batch1.size(2) == batch2.size(1),
3141        lambda: (
3142            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
3143            f"and {batch2.size(1)}x{batch2.size(2)})"
3144        ),
3145    )
3146    torch._check(
3147        self.size(0) == dim1 and self.size(1) == dim2,
3148        lambda: "self tensor does not match matmul output shape",
3149    )
3150    return self.new_empty(self.size())
3151
3152
3153@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
3154def meta__fused_adam_(
3155    self,
3156    grads,
3157    exp_avgs,
3158    exp_avg_sqs,
3159    max_exp_avg_sqs,
3160    state_steps,
3161    *,
3162    lr,
3163    beta1,
3164    beta2,
3165    weight_decay,
3166    eps,
3167    amsgrad,
3168    maximize,
3169    grad_scale=None,
3170    found_inf=None,
3171):
3172    for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3173        torch._check(
3174            isinstance(l, List),
3175            lambda: f"expected a tensor list but got {type(l)}",
3176        )
3177
3178
3179@register_meta([aten._fused_adam.default])
3180def meta__fused_adam(
3181    self,
3182    grads,
3183    exp_avgs,
3184    exp_avg_sqs,
3185    max_exp_avg_sqs,
3186    state_steps,
3187    *,
3188    lr,
3189    beta1,
3190    beta2,
3191    weight_decay,
3192    eps,
3193    amsgrad,
3194    maximize,
3195    grad_scale=None,
3196    found_inf=None,
3197):
3198    for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3199        torch._check(
3200            isinstance(l, List),
3201            lambda: f"expected a tensor list but got {type(l)}",
3202        )
3203
3204    def empty_like_list(tensor_list):
3205        return [torch.empty_like(t) for t in tensor_list]
3206
3207    return (
3208        empty_like_list(self),
3209        empty_like_list(grads),
3210        empty_like_list(exp_avgs),
3211        empty_like_list(exp_avg_sqs),
3212        empty_like_list(max_exp_avg_sqs),
3213    )
3214
3215
3216@register_meta([aten._int_mm])
3217@out_wrapper()
3218def meta__int_mm(a, b):
3219    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
3220    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
3221    torch._check(
3222        a.dtype is torch.int8,
3223        lambda: f"expected self to be int8, got {a.dtype}",
3224    )
3225    torch._check(
3226        b.dtype is torch.int8,
3227        lambda: f"expected mat2 to be int8, got {b.dtype}",
3228    )
3229    torch._check(
3230        a.size(1) == b.size(0),
3231        lambda: (
3232            f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
3233            f"and {b.size(0)}x{b.size(1)})"
3234        ),
3235    )
3236    return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
3237
3238
3239@register_meta([aten._convert_weight_to_int4pack])
3240def meta__convert_weight_to_int4pack(w, inner_k_tiles):
3241    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3242    torch._check(
3243        w.dtype is torch.uint8,
3244        lambda: f"expected w to be uint8, got {w.dtype}",
3245    )
3246    n = w.size(0)
3247    k = w.size(1) * 2  # w is [n][k / 2] uint8
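    # Illustrative shape math: n = 64, k = 256, inner_k_tiles = 2 packs into
    # (64 // 8, 256 // (2 * 16), 32, 2 // 2) = (8, 8, 32, 1).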
3248    return w.new_empty(
3249        (
3250            n // 8,
3251            k // (inner_k_tiles * 16),
3252            32,
3253            inner_k_tiles // 2,
3254        ),
3255        dtype=torch.int32,
3256    )
3257
3258
3259@register_meta([aten._weight_int4pack_mm])
3260def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
3261    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3262    torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
3263    torch._check(
3264        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
3265        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
3266    )
3267    torch._check(
3268        w.dtype is torch.int32,
3269        lambda: f"expected w to be int32, got {w.dtype}",
3270    )
3271    return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
3272
3273
3274@register_meta([aten._weight_int8pack_mm])
3275def meta__weight_int8pack_mm(x, w, q_scales):
3276    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3277    torch._check(
3278        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
3279        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
3280    )
3281    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3282    torch._check(
3283        w.dtype is torch.int8,
3284        lambda: f"expected w to be int8, got {w.dtype}",
3285    )
3286    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
3287
3288
3289@register_meta(aten._cdist_forward.default)
3290def meta_cdist_forward(x1, x2, p, compute_mode):
3291    torch._check(
3292        x1.dim() >= 2,
3293        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
3294    )
3295    torch._check(
3296        x2.dim() >= 2,
3297        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
3298    )
3299    torch._check(
3300        x1.size(-1) == x2.size(-1),
3301        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
3302    )
3303    torch._check(
3304        utils.is_float_dtype(x1.dtype),
3305        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
3306    )
3307    torch._check(
3308        utils.is_float_dtype(x2.dtype),
3309        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
3310    )
3311    torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
3312    torch._check(
3313        compute_mode in (None, 1, 2),
3314        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
3315    )
3316    r1 = x1.size(-2)
3317    r2 = x2.size(-2)
3318    batch_tensor1 = x1.shape[:-2]
3319    batch_tensor2 = x2.shape[:-2]
3320    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3321    output_shape.extend([r1, r2])
3322    return x1.new_empty(output_shape)
3323
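# Batch-broadcast sketch (hypothetical sizes): x1 of shape (2, 1, 5, 3) and
# x2 of shape (4, 7, 3) broadcast their batch dims (2, 1) and (4,) to (2, 4),
# so the output shape is (2, 4, 5, 7).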
3324
3325@register_meta(aten._cdist_backward)
3326@out_wrapper()
3327def meta_cdist_backward(grad, x1, x2, p, cdist):
3328    c1 = x1.shape[-1]
3329    r1 = x1.shape[-2]
3330    r2 = x2.shape[-2]
3331    batch_tensor1 = x1.shape[:-2]
3332    batch_tensor2 = x2.shape[:-2]
3333    expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3334    tensor1_expand_size = expand_batch_portion.copy()
3335    tensor1_expand_size.extend([r1, c1])
3336    batch_product = math.prod(expand_batch_portion)
3337    if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
3338        return torch.zeros_like(x1)
3339    if tensor1_expand_size != list(x1.shape):
3340        x1 = x1.expand(tensor1_expand_size)
3341    return torch.empty_like(x1, memory_format=torch.contiguous_format)
3342
3343
3344# NB: This meta function accepts non-meta arguments!  When this behavior
3345# was originally introduced it was accidental, but it is now load-bearing:
3346# people rely on it to conveniently test code involving embeddings
3347# (feeding CPU tensor inputs to a meta-device EmbeddingBag module).
3348@register_meta(aten._embedding_bag.default)
3349def meta_embedding_bag(
3350    weight,
3351    indices,
3352    offsets,
3353    scale_grad_by_freq=False,
3354    mode=0,
3355    sparse=False,
3356    per_sample_weights=None,
3357    include_last_offset=False,
3358    padding_idx=-1,
3359):
3360    torch._check(
3361        indices.dtype in (torch.long, torch.int),
3362        lambda: f"expected indices to be long or int, got {indices.dtype}",
3363    )
3364    torch._check(
3365        offsets.dtype in (torch.long, torch.int),
3366        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
3367    )
3368    torch._check(
3369        utils.is_float_dtype(weight.dtype),
3370        lambda: f"expected weight to be floating point type, got {weight.dtype}",
3371    )
3372
3373    num_bags = offsets.size(0)
3374    if include_last_offset:
3375        torch._check(
3376            num_bags >= 1,
3377            lambda: "include_last_offset: numBags should be at least 1",
3378        )
3379        num_bags -= 1
3380
3381    output = weight.new_empty(num_bags, weight.size(1))
3382    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
3383
3384    if per_sample_weights is not None:
3385        torch._check(
3386            mode == MODE_SUM,
3387            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
3388        )
3389        torch._check(
3390            per_sample_weights.dtype == weight.dtype,
3391            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
3392        )
3393        torch._check(
3394            per_sample_weights.ndim == 1,
3395            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
3396        )
3397        torch._check(
3398            per_sample_weights.numel() == indices.numel(),
3399            lambda: (
3400                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
3401                f"to be the same as indices.numel() ({indices.numel()})"
3402            ),
3403        )
3404
3405    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
3406        return (
3407            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
3408        )
3409
3410    def is_fast_path_index_select(src, output, padding_idx):
3411        return (
3412            (src.dtype == torch.float or src.dtype == torch.half)
3413            and src.stride(1) == 1
3414            and output.stride(1) == 1
3415            and padding_idx < 0
3416        )
3417
3418    def is_fast_path(src, scale, output, padding_idx):
3419        if scale is not None:
3420            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
3421        else:
3422            return is_fast_path_index_select(src, output, padding_idx)
3423
3424    if device_hint(offsets) != "cpu":
3425        offset2bag = indices.new_empty(indices.size(0))
3426        bag_size = indices.new_empty(offsets.size())
3427        if mode == MODE_MAX:
3428            max_indices = indices.new_empty(num_bags, weight.size(1))
3429        else:
3430            max_indices = indices.new_empty(0)
3431    else:
3432        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
3433        if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
3434            offset2bag = offsets.new_empty(indices.size(0))
3435        else:
3436            offset2bag = offsets.new_empty(0)
3437        bag_size = offsets.new_empty(num_bags)
3438        # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
3439        numBags = offsets.shape[0]
3440        if mode == MODE_MAX:
3441            if include_last_offset:
3442                torch._check(
3443                    numBags >= 1,
3444                    lambda: "include_last_offset: numBags should be at least 1",
3445                )
3446                numBags -= 1
3447            max_indices = offsets.new_empty(numBags, weight.shape[1])
3448        else:
3449            max_indices = offsets.new_empty(bag_size.size())
3450    return output, offset2bag, bag_size, max_indices
3451
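# Rough output-shape summary (a sketch, assuming include_last_offset=False):
# for weight (V, D), indices (N,), offsets (B,), this returns
#   output: (B, D)
#   offset2bag: (N,) (or (0,) on the CPU sum fast path)
#   bag_size: (B,)
#   max_indices: (B, D) for MODE_MAX, otherwise (0,) on non-CPU devices or a
#   tensor matching bag_size on CPU.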
3452
3453@register_meta(aten._embedding_bag_forward_only.default)
3454def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
3455    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
3456        weight, indices, offsets, *args
3457    )
3458    if device_hint(offsets) == "cpu":
3459        bag_size = offsets.new_empty(offsets.size())
3460    return output, offset2bag, bag_size, max_indices
3461
3462
3463def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
3464    # if specified, dtype takes precedence
3465    if dtype:
3466        return dtype
3467
3468    if input.dtype.is_floating_point or input.dtype.is_complex:
3469        return input.dtype
3470    elif promote_int_to_long:
3471        return torch.long
3472
3473    return input.dtype
3474
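# e.g. with dtype=None, an int32 input promotes to torch.long (matching
# sum/nansum semantics), while a float16 input keeps torch.float16; an
# explicit dtype always wins.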
3475
3476@register_meta([aten.nansum.default, aten.nansum.out])
3477@out_wrapper()
3478def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
3479    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
3480    dims = utils.reduction_dims(input.shape, dims)
3481    output_shape = _compute_reduction_shape(input, dims, keepdim)
3482    return input.new_empty(output_shape, dtype=output_dtype)
3483
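# Shape sketch (hypothetical sizes): nansum over dims=(1,) of a (2, 3) int8
# meta tensor yields a (2,) tensor of dtype torch.long; keepdim=True would
# give (2, 1) instead.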
3484
3485@register_meta([aten.median.default, aten.nanmedian.default])
3486def meta_median(input):
3487    output_shape = utils.compute_reduction_output_shape(
3488        input.shape, tuple(range(input.dim()))
3489    )
3490    return input.new_empty(output_shape)
3491
3492
3493@register_meta(
3494    [
3495        aten.median.dim,
3496        aten.median.dim_values,
3497        aten.nanmedian.dim,
3498        aten.nanmedian.dim_values,
3499        aten.mode.default,
3500        aten.mode.values,
3501    ]
3502)
3503@out_wrapper("values", "indices")
3504def meta_median_mode_dim(input, dim=-1, keepdim=False):
3505    if device_hint(input) == "cuda":
3506        utils.alert_not_deterministic("median CUDA with indices output")
3507    dim = utils.reduction_dims(input.shape, (dim,))
3508    output_shape = _compute_reduction_shape(input, dim, keepdim)
3509    return (
3510        input.new_empty(output_shape),
3511        input.new_empty(output_shape, dtype=torch.long),
3512    )
3513
3514
3515@register_meta(aten.logical_not_.default)
3516def meta_logical_not_(self):
3517    return self
3518
3519
3520@register_meta(aten.repeat.default)
3521def meta_repeat(self, repeats):
3522    torch._check(
3523        len(repeats) >= self.dim(),
3524        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
3525    )
3526    # Add new leading dimensions to the tensor if the
3527    # number of target dimensions is larger than the
3528    # number of source dimensions.
3529    num_new_dimensions = len(repeats) - self.dim()
3530    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
3531    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
3532    return self.new_empty(target_size)
3533
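# e.g. a (2, 3) tensor with repeats (4, 1, 2) is first padded to (1, 2, 3)
# and then scaled elementwise, giving an output of shape (4, 2, 6).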
3534
3535@register_meta(aten.zero_.default)
3536def meta_zero_(self):
3537    return self
3538
3539
3540@register_meta(
3541    [
3542        aten.mul_.Scalar,
3543        aten.div_.Scalar,
3544        aten.mul_.Tensor,
3545        aten.div_.Tensor,
3546        aten.logical_and_.default,
3547        aten.logical_or_.default,
3548        aten.logical_xor_.default,
3549    ],
3550)
3551def meta_binop_inplace(self, other):
3552    if isinstance(other, torch.Tensor):
3553        check_inplace_broadcast(self.shape, other.shape)
3554    return self
3555
3556
3557@register_meta(
3558    [
3559        aten.add_.Scalar,
3560        aten.sub_.Scalar,
3561        aten.add_.Tensor,
3562        aten.sub_.Tensor,
3563    ],
3564)
3565def meta_binop_inplace_alpha(self, other, alpha=1):
3566    if isinstance(other, torch.Tensor):
3567        check_inplace_broadcast(self.shape, other.shape)
3568    return self
3569
3570
3571@register_meta([aten.round.default, aten.round.decimals])
3572def meta_round(self, **kwargs):
3573    return elementwise_meta(
3574        self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3575    )
3576
3577
3578def shift_dtype_check(fn_name, self, val):
3579    torch._check(
3580        utils.is_integer_dtype(self.dtype),
3581        lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
3582    )
3583    if isinstance(val, torch.Tensor):
3584        torch._check(
3585            utils.is_integer_dtype(val.dtype),
3586            lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
3587        )
3588    else:
3589        torch._check(
3590            isinstance(val, IntLike),
3591            lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
3592        )
3593
3594
3595@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
3596def meta_rshifts(self, other):
3597    shift_dtype_check("rshift", self, other)
3598    return elementwise_meta(
3599        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3600    )
3601
3602
3603@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
3604def meta_lshifts(self, other):
3605    shift_dtype_check("lshift", self, other)
3606    return elementwise_meta(
3607        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3608    )
3609
3610
3611@register_meta(aten.zero.default)
3612def meta_zero(self):
3613    return self.new_empty(self.shape)
3614
3615
3616@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
3617def meta_fill_(self, val):
3618    return self
3619
3620
3621@register_meta([aten.fill.Tensor, aten.fill.Scalar])
3622def meta_fill(self, val):
3623    return torch.empty_like(self)
3624
3625
3626@register_meta(aten.relu_.default)
3627def meta_relu_(self):
3628    return self
3629
3630
3631@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
3632def meta_index_put(self, indices, values, accumulate=False):
3633    return torch.empty_like(self)
3634
3635
3636@register_meta(aten.masked_fill_.Scalar)
3637def meta_masked_fill_(self, mask, value):
3638    check_inplace_broadcast(self.shape, mask.shape)
3639    return self
3640
3641
3642@register_meta(aten._masked_scale.default)
3643def meta__masked_scale(self, mask, scale):
3644    masked_scale = self.new_empty(self.size()).to(
3645        memory_format=utils.suggest_memory_format(self)
3646    )
3647    return masked_scale
3648
3649
3650@register_meta(aten.masked_scatter_)
3651def meta_masked_scatter_(self, mask, source):
3652    torch._check(
3653        mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
3654    )
3655    torch._check(
3656        self.dtype == source.dtype,
3657        lambda: "masked_scatter: expected self and source to have same "
3658        f"dtypes but got {self.dtype} and {source.dtype}",
3659    )
3660    return self
3661
3662
3663@register_meta(aten.masked_scatter)
3664@out_wrapper()
3665def meta_masked_scatter(self, mask, source):
3666    self, mask = _maybe_broadcast(self, mask)
3667    output = torch.empty_like(self, memory_format=torch.contiguous_format)
3668    return meta_masked_scatter_(output, mask, source)
3669
3670
3671@register_meta(aten.masked_scatter_backward)
3672def meta_masked_scatter_backward(self, mask, sizes):
3673    return self.new_empty(sizes)
3674
3675
3676@register_meta(aten.index_put_.default)
3677def meta_index_put_(self, indices, values, accumulate=False):
3678    return self
3679
3680
3681@register_meta(aten.alias.default)
3682def meta_alias(self):
3683    return self.view(self.shape)
3684
3685
3686def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
3687    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3688    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3689
3690    batch1_sizes = batch1.size()
3691    batch2_sizes = batch2.size()
3692
3693    bs = batch1_sizes[0]
3694    contraction_size = batch1_sizes[2]
3695    res_rows = batch1_sizes[1]
3696    res_cols = batch2_sizes[2]
3697    output_size = (bs, res_rows, res_cols)
3698
3699    torch._check(
3700        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
3701        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
3702        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
3703    )
3704
3705    # TODO: handle out
3706
3707    output = batch2.new_empty(output_size)
3708
3709    if not is_bmm and self_baddbmm is not None:
3710        torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
3711        torch._check(
3712            self_baddbmm.size() == output_size,
3713            lambda: f"Expected an input tensor with shape {output_size} but got shape: {self_baddbmm.size()}",
3714        )
3715
3716    return output
3717
3718
3719@register_meta(aten.bmm.default)
3720def meta_bmm(self, mat2):
3721    return common_meta_baddbmm_bmm(self, mat2, True)
3722
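# e.g. bmm of (10, 3, 4) and (10, 4, 5) meta tensors yields a (10, 3, 5)
# meta tensor; a (10, 2, 5) second operand would fail the size check above.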
3723
3724def div_rtn(x, y):
3725    q = x // y
3726    r = x % y
3727    # WARNING: explicit bool conversion here is necessary;
3728    # would be fixed by SymBool
3729    if r != 0 and (bool(r < 0) != bool(y < 0)):
3730        q -= 1
3731    return q
3732
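# div_rtn rounds the quotient toward negative infinity, e.g. (plain ints):
#   div_rtn(7, 2) == 3, div_rtn(-7, 2) == -4, div_rtn(7, -2) == -4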
3733
3734def pooling_output_shape_pad_lr(
3735    inputSize,
3736    kernelSize,
3737    pad_l,
3738    pad_r,
3739    stride,
3740    dilation,
3741    ceil_mode,
3742):
3743    outputSize = (
3744        div_rtn(
3745            inputSize
3746            + pad_l
3747            + pad_r
3748            - dilation * (kernelSize - 1)
3749            - 1
3750            + (stride - 1 if ceil_mode else 0),
3751            stride,
3752        )
3753        + 1
3754    )
3755    if ceil_mode:
3756        if (outputSize - 1) * stride >= inputSize + pad_l:
3757            outputSize -= 1
3758    return outputSize
3759
3760
3761def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
3762    torch._check(stride != 0, lambda: "stride should not be zero")
3763    torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
3764    torch._check(
3765        pad <= ((kernelSize - 1) * dilation + 1) // 2,
3766        lambda: (
3767            f"pad should be at most half of effective kernel size, but got pad={pad}, "
3768            f"kernel_size={kernelSize} and dilation={dilation}"
3769        ),
3770    )
3771    return pooling_output_shape_pad_lr(
3772        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
3773    )
3774
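# Worked example (plain ints): inputSize=7, kernelSize=3, pad=1, stride=2,
# dilation=1, ceil_mode=False gives
#   div_rtn(7 + 1 + 1 - 1 * (3 - 1) - 1, 2) + 1 == div_rtn(6, 2) + 1 == 4,
# i.e. the familiar floor((in + 2 * pad - kernel) / stride) + 1.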
3775
3776def pool2d_shape_check(
3777    input,
3778    kH,
3779    kW,
3780    dH,
3781    dW,
3782    padH,
3783    padW,
3784    dilationH,
3785    dilationW,
3786    nInputPlane,
3787    inputHeight,
3788    inputWidth,
3789    outputHeight,
3790    outputWidth,
3791    memory_format,
3792):
3793    ndim = input.dim()
3794    nOutputPlane = nInputPlane
3795
3796    torch._check(
3797        kW > 0 and kH > 0,
3798        lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
3799    )
3800    torch._check(
3801        dW > 0 and dH > 0,
3802        lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
3803    )
3804    torch._check(
3805        dilationH > 0 and dilationW > 0,
3806        lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
3807    )
3808
3809    valid_dims = input.size(1) != 0 and input.size(2) != 0
3810
3811    if memory_format == torch.channels_last:
3812        torch._check(
3813            ndim == 4 and valid_dims and input.size(3) != 0,
3814            lambda: "Expected 4D (batch mode) tensor for input with channels_last layout"
3815            f" with optional 0 dim batch size for input, but got: {input.size()}",
3816        )
3817    else:
3818        torch._check(
3819            (ndim == 3 and input.size(0) != 0 and valid_dims)
3820            or (ndim == 4 and valid_dims and input.size(3) != 0),
3821            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
3822        )
3823
3824    torch._check(
3825        kW // 2 >= padW and kH // 2 >= padH,
3826        lambda: "pad should be smaller than or equal to half of kernel size, but got "
3827        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
3828    )
3829
3830    torch._check(
3831        outputWidth >= 1 and outputHeight >= 1,
3832        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
3833        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
3834        "Output size is too small",
3835    )
3836
3837
3838def pool3d_shape_check(
3839    input: Tensor,
3840    nslices: int,
3841    kT: int,
3842    kH: int,
3843    kW: int,
3844    dT: int,
3845    dH: int,
3846    dW: int,
3847    pT: int,
3848    pH: int,
3849    pW: int,
3850    dilationT: int,
3851    dilationH: int,
3852    dilationW: int,
3853    itime: int,
3854    iheight: int,
3855    iwidth: int,
3856    otime: int,
3857    oheight: int,
3858    owidth: int,
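# e.g. maybe_wrap_dim(-1, 4) == 3 and maybe_wrap_dim(2, 4) == 2; for a 0-dim
# tensor (dim_post_expr == 0) only -1 and 0 are accepted, both mapping to 0.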
3859    fn_name: str,
3860    check_input_size: bool = False,
3861):
3862    ndim = input.ndim
3863
3864    torch._check(
3865        kT > 0 and kW > 0 and kH > 0,
3866        lambda: (
3867            f"kernel size should be greater than zero, but got "
3868            f"kT: {kT}, kH: {kH}, kW: {kW}"
3869        ),
3870    )
3871    torch._check(
3872        dT > 0 and dW > 0 and dH > 0,
3873        lambda: (
3874            f"stride should be greater than zero, but got "
3875            f"dT: {dT}, dH: {dH}, dW: {dW}"
3876        ),
3877    )
3878    torch._check(
3879        dilationT > 0 and dilationW > 0 and dilationH > 0,
3880        lambda: (
3881            f"dilation should be greater than zero, but got "
3882            f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
3883        ),
3884    )
3885
3886    torch._check(
3887        ndim in (4, 5),
3888        lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
3889    )
3890
3891    for i in range(ndim):
3892        if ndim == 5 and i == 0:
3893            # size of batch-dim can be 0.
3894            continue
3895        torch._check(
3896            input.size(i) > 0,
3897            lambda: (
3898                f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
3899                f" but input has a shape of {input.shape}"
3900                f" and non-batch dimension {input.size(i)} has length zero!"
3901            ),
3902        )
3903
3904    if check_input_size:  # AveragePool3d
3905        torch._check(
3906            itime >= kT and iheight >= kH and iwidth >= kW,
3907            lambda: (
3908                f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
3909                f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
3910            ),
3911        )
3912
3913    torch._check(
3914        kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
3915        lambda: (
3916            f"pad should be smaller than or equal to half of kernel size, but got "
3917            f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
3918        ),
3919    )
3920
3921    torch._check(
3922        otime >= 1 and owidth >= 1 and oheight >= 1,
3923        lambda: (
3924            f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
3925            f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
3926            f"Output size is too small"
3927        ),
3928    )
3929
3930
3931def max_pool3d_backward_shape_check(
3932    input,
3933    grad_output,
3934    indices,
3935    nslices,
3936    kT,
3937    kH,
3938    kW,
3939    dT,
3940    dH,
3941    dW,
3942    pT,
3943    pH,
3944    pW,
3945    dilationT,
3946    dilationH,
3947    dilationW,
3948    itime,
3949    iheight,
3950    iwidth,
3951    otime,
3952    oheight,
3953    owidth,
3954    fn_name,
3955):
3956    ndim = input.ndim
3957
3958    pool3d_shape_check(
3959        input,
3960        nslices,
3961        kT,
3962        kH,
3963        kW,
3964        dT,
3965        dH,
3966        dW,
3967        pT,
3968        pH,
3969        pW,
3970        dilationT,
3971        dilationH,
3972        dilationW,
3973        itime,
3974        iheight,
3975        iwidth,
3976        otime,
3977        oheight,
3978        owidth,
3979        fn_name,
3980    )
3981
3982    check_dim_size(grad_output, ndim, ndim - 4, nslices)
3983    check_dim_size(grad_output, ndim, ndim - 3, otime)
3984    check_dim_size(grad_output, ndim, ndim - 2, oheight)
3985    check_dim_size(grad_output, ndim, ndim - 1, owidth)
3986
3987    check_dim_size(indices, ndim, ndim - 4, nslices)
3988    check_dim_size(indices, ndim, ndim - 3, otime)
3989    check_dim_size(indices, ndim, ndim - 2, oheight)
3990    check_dim_size(indices, ndim, ndim - 1, owidth)
3991
3992
3993def avg_pool3d_backward_shape_check(
3994    input: Tensor,
3995    grad_output: Tensor,
3996    nslices: int,
3997    kT: int,
3998    kH: int,
3999    kW: int,
4000    dT: int,
4001    dH: int,
4002    dW: int,
4003    pT: int,
4004    pH: int,
4005    pW: int,
4006    itime: int,
4007    iheight: int,
4008    iwidth: int,
4009    otime: int,
4010    oheight: int,
4011    owidth: int,
4012    fn_name: str,
4013):
4014    ndim = input.ndim
4015
4016    pool3d_shape_check(
4017        input,
4018        nslices,
4019        kT,
4020        kH,
4021        kW,
4022        dT,
4023        dH,
4024        dW,
4025        pT,
4026        pH,
4027        pW,
4028        1,
4029        1,
4030        1,
4031        itime,
4032        iheight,
4033        iwidth,
4034        otime,
4035        oheight,
4036        owidth,
4037        fn_name,
4038        True,
4039    )
4040
4041    check_dim_size(grad_output, ndim, ndim - 4, nslices)
4042    check_dim_size(grad_output, ndim, ndim - 3, otime)
4043    check_dim_size(grad_output, ndim, ndim - 2, oheight)
4044    check_dim_size(grad_output, ndim, ndim - 1, owidth)
4045
4046
4047def max_pool2d_checks_and_compute_shape(
4048    input,
4049    kernel_size,
4050    stride,
4051    padding,
4052    dilation,
4053    ceil_mode,
4054):
4055    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
4056    def unpack(name, val):
4057        torch._check(
4058            len(val) in [1, 2],
4059            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
4060        )
4061        H = val[0]
4062        W = H if len(val) == 1 else val[1]
4063        return H, W
4064
4065    kH, kW = unpack("kernel_size", kernel_size)
4066
4067    torch._check(
4068        len(stride) in [0, 1, 2],
4069        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
4070    )
4071    if len(stride) == 0:
4072        dH, dW = kH, kW
4073    else:
4074        dH, dW = unpack("stride", stride)
4075
4076    padH, padW = unpack("padding", padding)
4077    dilationH, dilationW = unpack("dilation", dilation)
4078    nInputPlane = input.size(-3)
4079    inputHeight = input.size(-2)
4080    inputWidth = input.size(-1)
4081
4082    memory_format = utils.suggest_memory_format(input)
4083    if memory_format == torch.channels_last:
4084        torch._check(
4085            input.dim() == 4,
4086            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
4087        )
4088    elif memory_format == torch.contiguous_format:
4089        torch._check(
4090            input.dim() in [3, 4],
4091            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
4092        )
4093    else:
4094        torch._check(
4095            False,
4096            lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous",
4097        )
4098
4099    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
4100    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
4101
4102    pool2d_shape_check(
4103        input,
4104        kH,
4105        kW,
4106        dH,
4107        dW,
4108        padH,
4109        padW,
4110        dilationH,
4111        dilationW,
4112        nInputPlane,
4113        inputHeight,
4114        inputWidth,
4115        outputHeight,
4116        outputWidth,
4117        memory_format,
4118    )
4119
4120    return nInputPlane, outputHeight, outputWidth
4121
4122
4123@register_meta(aten.max_pool2d_with_indices_backward.default)
4124def meta_max_pool2d_with_indices_backward(
4125    grad_output,
4126    self,
4127    kernel_size,
4128    stride,
4129    padding,
4130    dilation,
4131    ceil_mode,
4132    indices,
4133):
4134    (
4135        nInputPlane,
4136        outputHeight,
4137        outputWidth,
4138    ) = max_pool2d_checks_and_compute_shape(
4139        self, kernel_size, stride, padding, dilation, ceil_mode
4140    )
4141
4142    torch._check(
4143        self.dtype == grad_output.dtype,
4144        lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
4145    )
4146
4147    nOutputPlane = nInputPlane
4148    ndim = self.ndim
4149
4150    def _check_dim_size(t):
4151        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
4152        check_dim_size(t, ndim, ndim - 2, outputHeight)
4153        check_dim_size(t, ndim, ndim - 1, outputWidth)
4154
4155    _check_dim_size(grad_output)
4156    _check_dim_size(indices)
4157
4158    memory_format = utils.suggest_memory_format(self)
4159    return torch.empty(
4160        self.shape,
4161        dtype=self.dtype,
4162        device=self.device,
4163        memory_format=memory_format,
4164    )
4165
4166
4167@register_meta(aten.max_pool2d_with_indices.default)
4168def meta_max_pool2d_with_indices(
4169    input,
4170    kernel_size,
4171    stride=(),
4172    padding=(0,),
4173    dilation=(1,),
4174    ceil_mode=False,
4175):
4176    (
4177        nInputPlane,
4178        outputHeight,
4179        outputWidth,
4180    ) = max_pool2d_checks_and_compute_shape(
4181        input, kernel_size, stride, padding, dilation, ceil_mode
4182    )
4183
4184    nbatch = input.size(-4) if input.dim() == 4 else 1
4185    memory_format = utils.suggest_memory_format(input)
4186    if input.dim() == 3:
4187        size = [nInputPlane, outputHeight, outputWidth]
4188    else:
4189        size = [nbatch, nInputPlane, outputHeight, outputWidth]
4190    return (
4191        torch.empty(
4192            size,
4193            dtype=input.dtype,
4194            device=input.device,
4195            memory_format=memory_format,
4196        ),
4197        torch.empty(
4198            size,
4199            dtype=torch.int64,
4200            device=input.device,
4201            memory_format=memory_format,
4202        ),
4203    )
4204
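# Shape sketch (hypothetical sizes): a (1, 3, 8, 8) input with
# kernel_size=(2,) and empty stride (which defaults to the kernel size)
# pools to two (1, 3, 4, 4) outputs: values and int64 indices.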
4205
4206@register_meta(aten.fractional_max_pool2d.default)
4207def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
4208    torch._check(
4209        self.ndim in (3, 4),
4210        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
4211    )
4212    ndim = self.ndim
4213
4214    for d in range(ndim - 3, ndim):
4215        torch._check(
4216            self.size(d) > 0,
4217            lambda: f"fractional_max_pool2d: Expected input to have non-zero "
4218            f"size for non-batch dimensions, but got {self.size()} with dimension {d} empty",
4219        )
4220
4221    # the check and message are out of sync, but this matches the structured meta
4222    torch._check(
4223        len(kernel_size) == 2,
4224        lambda: "fractional_max_pool2d: kernel_size must "
4225        "either be a single int or tuple of Ints",
4226    )
4227    torch._check(
4228        len(output_size) == 2,
4229        lambda: "fractional_max_pool2d: output_size must "
4230        "either be a single int or tuple of Ints",
4231    )
4232
4233    input_channels = self.size(-3)
4234    input_height = self.size(-2)
4235    input_width = self.size(-1)
4236    if ndim == 4:
4237        input_batch = self.size(0)
4238    else:
4239        input_batch = 1
4240
4241    torch._check(
4242        self.dtype == random_samples.dtype,
4243        lambda: "Expect _random_samples to have the same dtype as input",
4244    )
4245    torch._check(
4246        random_samples.ndim == 3,
4247        lambda: f"Expect _random_samples to have 3 dimensions, got {random_samples.ndim}",
4248    )
4249
4250    n = random_samples.size(0)
4251    c = random_samples.size(1)
4252    d = random_samples.size(2)
4253    torch._check(
4254        n >= input_batch,
4255        lambda: "Expect _random_samples.size(0) to be no less than input batch size.",
4256    )
4257    torch._check(
4258        c == input_channels,
4259        lambda: "Expect _random_samples.size(1) to equal the input channel size.",
4260    )
4261    torch._check(d == 2, lambda: f"Expect _random_samples.size(2) to equal 2, got {d}.")
4262
4263    torch._check(
4264        output_size[0] + kernel_size[0] - 1 <= input_height,
4265        lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
4266    )
4267    torch._check(
4268        output_size[1] + kernel_size[1] - 1 <= input_width,
4269        lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
4270    )
4271
4272    if self.dim() == 4:
4273        size = [input_batch, input_channels, output_size[0], output_size[1]]
4274    else:
4275        size = [input_channels, output_size[0], output_size[1]]
4276
4277    return (
4278        torch.empty(
4279            size,
4280            dtype=self.dtype,
4281            device=self.device,
4282        ),
4283        torch.empty(
4284            size,
4285            dtype=torch.int64,
4286            device=self.device,
4287        ),
4288    )
4289
4290
4291@register_meta(aten.max_unpool2d)
4292@out_wrapper()
4293def meta_max_unpool2d(self, indices, output_size):
4294    utils.alert_not_deterministic("max_unpooling2d_forward_out")
4295
4296    torch._check(
4297        indices.dtype == torch.int64,
4298        lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4299    )
4300    torch._check(
4301        len(output_size) == 2,
4302        lambda: (
4303            f"There should be exactly two elements (height, width) in output_size, "
4304            f"but got {len(output_size)} elements."
4305        ),
4306    )
4307
4308    oheight, owidth = output_size
4309
4310    torch._check(
4311        self.ndim in (3, 4),
4312        lambda: (
4313            f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4314            f"but got a tensor with {self.ndim} dimensions."
4315        ),
4316    )
4317    torch._check(
4318        self.shape == indices.shape,
4319        lambda: (
4320            f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
4321            f"but got indices tensor with shape: {indices.shape}"
4322        ),
4323    )
4324
4325    for i in range(1, self.ndim):
4326        torch._check(
4327            self.size(i) > 0,
4328            lambda: (
4329                f"max_unpooling2d(): "
4330                f"Expected input to have non-zero size for non-batch dimensions, "
4331                f"but got {self.shape} with dimension {i} being empty."
4332            ),
4333        )
4334
4335    self = self.contiguous()
4336
4337    if self.ndim == 3:
4338        nchannels = self.size(0)
4339        result = self.new_empty((nchannels, oheight, owidth))
4340    else:
4341        nbatch = self.size(0)
4342        nchannels = self.size(1)
4343        result = self.new_empty((nbatch, nchannels, oheight, owidth))
4344
4345    return result
4346
4347
4348def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4349    torch._check(
4350        indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4351    )
4352    torch._check(
4353        input.ndim in (4, 5),
4354        lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4355    )
4356    torch._check(
4357        len(output_size) == 3,
4358        lambda: (
4359            f"There should be exactly three elements (depth, height, width) in output_size, "
4360            f"but got {len(output_size)} elements."
4361        ),
4362    )
4363    torch._check(
4364        len(stride) == 3,
4365        lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4366    )
4367    torch._check(
4368        len(padding) == 3,
4369        lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4370    )
4371    torch._check(
4372        input.shape == indices.shape,
4373        lambda: (
4374            f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4375            f"but got indices tensor with shape: {indices.shape}"
4376        ),
4377    )
4378
4379    for i in range(1, input.ndim):
4380        torch._check(
4381            input.size(i) > 0,
4382            lambda: (
4383                f"{fn_name}: "
4384                f"Expected input to have non-zero size for non-batch dimensions, "
4385                f"but got {input.shape} with dimension {i} being empty."
4386            ),
4387        )
4388
4389    torch._check(
4390        stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4391        lambda: f"strides should be greater than zero, but got stride: {stride}",
4392    )
4393
4394
4395@register_meta(aten.max_unpool3d)
4396@out_wrapper()
4397def meta_max_unpool3d(self, indices, output_size, stride, padding):
4398    utils.alert_not_deterministic("max_unpooling3d_forward_out")
4399
4400    _max_unpooling3d_shape_check(
4401        self, indices, output_size, stride, padding, "max_unpooling3d()"
4402    )
4403
4404    self = self.contiguous()
4405
4406    odepth, oheight, owidth = output_size
4407
4408    if self.ndim == 4:
4409        nchannels = self.size(0)
4410        result = self.new_empty((nchannels, odepth, oheight, owidth))
4411    else:
4412        nbatch = self.size(0)
4413        nchannels = self.size(1)
4414        result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))
4415
4416    return result
4417
4418
4419@register_meta(aten.max_pool3d_with_indices)
4420@out_wrapper("out", "indices")
4421def meta_max_pool3d_with_indices(
4422    input,
4423    kernel_size,
4424    stride=(),
4425    padding=(0,),
4426    dilation=(1,),
4427    ceil_mode=False,
4428):
4429    torch._check(
4430        len(kernel_size) in (1, 3),
4431        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4432    )
4433    kT = kernel_size[0]
4434    kH = kT if len(kernel_size) == 1 else kernel_size[1]
4435    kW = kT if len(kernel_size) == 1 else kernel_size[2]
4436
4437    torch._check(
4438        not stride or len(stride) in (1, 3),
4439        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4440    )
4441    dT = kT if not stride else stride[0]
4442    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4443    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4444
4445    torch._check(
4446        len(padding) in (1, 3),
4447        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4448    )
4449    pT = padding[0]
4450    pH = pT if len(padding) == 1 else padding[1]
4451    pW = pT if len(padding) == 1 else padding[2]
4452
4453    torch._check(
4454        len(dilation) in (1, 3),
4455        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4456    )
4457    dilationT = dilation[0]
4458    dilationH = dilationT if len(dilation) == 1 else dilation[1]
4459    dilationW = dilationT if len(dilation) == 1 else dilation[2]
4460
4461    torch._check(
4462        input.ndim in (4, 5),
4463        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4464    )
4465
4466    nbatch = input.size(-5) if input.ndim == 5 else 1
4467    nslices = input.size(-4)
4468    itime = input.size(-3)
4469    iheight = input.size(-2)
4470    iwidth = input.size(-1)
4471
4472    otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
4473    oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
4474    owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)
4475
4476    pool3d_shape_check(
4477        input,
4478        nslices,
4479        kT,
4480        kH,
4481        kW,
4482        dT,
4483        dH,
4484        dW,
4485        pT,
4486        pH,
4487        pW,
4488        dilationT,
4489        dilationH,
4490        dilationW,
4491        itime,
4492        iheight,
4493        iwidth,
4494        otime,
4495        oheight,
4496        owidth,
4497        "max_pool3d_with_indices()",
4498    )
4499
4500    channels_last = (
4501        input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4502    )
4503    if input.ndim == 4:
4504        input_channels_last_check = input.unsqueeze(0)
4505        channels_last = (
4506            not input_channels_last_check.is_contiguous()
4507        ) and input_channels_last_check.is_contiguous(
4508            memory_format=torch.channels_last_3d
4509        )
4510        out_shape = (nslices, otime, oheight, owidth)
4511    else:
4512        out_shape = (nbatch, nslices, otime, oheight, owidth)  # type: ignore[assignment]
4513
4514    out = input.new_empty(out_shape)
4515    indices = input.new_empty(out_shape, dtype=torch.int64)
4516
4517    if channels_last:
4518        out = out.to(memory_format=torch.channels_last_3d)
4519        indices = indices.to(memory_format=torch.channels_last_3d)
4520
4521    return out, indices
4522
4523
4524@register_meta(aten.max_pool3d_with_indices_backward)
4525@out_wrapper("grad_input")
4526def meta_max_pool3d_with_indices_backward(
4527    grad_output,
4528    input,
4529    kernel_size,
4530    stride,
4531    padding,
4532    dilation,
4533    ceil_mode,
4534    indices,
4535):
4536    torch._check(
4537        len(kernel_size) in (1, 3),
4538        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4539    )
4540    kT = kernel_size[0]
4541    kH = kT if len(kernel_size) == 1 else kernel_size[1]
4542    kW = kT if len(kernel_size) == 1 else kernel_size[2]
4543
4544    torch._check(
4545        not stride or len(stride) in (1, 3),
4546        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4547    )
4548    dT = kT if not stride else stride[0]
4549    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4550    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4551
4552    torch._check(
4553        len(padding) in (1, 3),
4554        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4555    )
4556    pT = padding[0]
4557    pH = pT if len(padding) == 1 else padding[1]
4558    pW = pT if len(padding) == 1 else padding[2]
4559
4560    torch._check(
4561        len(dilation) in (1, 3),
4562        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4563    )
4564    dilationT = dilation[0]
4565    dilationH = dilationT if len(dilation) == 1 else dilation[1]
4566    dilationW = dilationT if len(dilation) == 1 else dilation[2]
4567
4568    torch._check(
4569        input.ndim in (4, 5),
4570        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4571    )
4572
4573    nslices = input.size(-4)
4574    itime = input.size(-3)
4575    iheight = input.size(-2)
4576    iwidth = input.size(-1)
4577
4578    otime = grad_output.size(-3)
4579    oheight = grad_output.size(-2)
4580    owidth = grad_output.size(-1)
4581
4582    max_pool3d_backward_shape_check(
4583        input,
4584        grad_output,
4585        indices,
4586        nslices,
4587        kT,
4588        kH,
4589        kW,
4590        dT,
4591        dH,
4592        dW,
4593        pT,
4594        pH,
4595        pW,
4596        dilationT,
4597        dilationH,
4598        dilationW,
4599        itime,
4600        iheight,
4601        iwidth,
4602        otime,
4603        oheight,
4604        owidth,
4605        "max_pool3d_with_indices_backward()",
4606    )
4607
4608    channels_last = (
4609        input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4610    )
4611    if input.ndim == 4:
4612        input_channels_last_check = input.unsqueeze(0)
4613        channels_last = (
4614            not input_channels_last_check.is_contiguous()
4615        ) and input_channels_last_check.is_contiguous(
4616            memory_format=torch.channels_last_3d
4617        )
4618
4619    grad_input = input.new_empty(input.shape)
4620
4621    if channels_last:
4622        grad_input = grad_input.to(memory_format=torch.channels_last_3d)
4623
4624    return grad_input
4625
4626
4627def check_grid_sampler_common(input: Tensor, grid: Tensor):
4628    torch._check(
4629        input.device == grid.device,
4630        lambda: (
4631            f"grid_sampler(): expected input and grid to be on same device, but input "
4632            f"is on {input.device} and grid is on {grid.device}"
4633        ),
4634    )
4635    torch._check(
4636        input.layout == torch.strided and grid.layout == torch.strided,
4637        lambda: (
4638            f"grid_sampler(): expected input and grid to have torch.strided layout, but "
4639            f"input has {input.layout} and grid has {grid.layout}"
4640        ),
4641    )
4642    torch._check(
4643        input.shape[0] == grid.shape[0],
4644        lambda: (
4645            f"grid_sampler(): expected grid and input to have same batch size, but got "
4646            f"input with sizes {input.shape} and grid with sizes {grid.shape}"
4647        ),
4648    )
4649    torch._check(
4650        grid.shape[-1] == input.ndim - 2,
4651        lambda: (
4652            f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
4653            f"dimension, but got grid with sizes {grid.shape}"
4654        ),
4655    )
4656
4657    for i in range(2, input.ndim):
4658        torch._check(
4659            input.shape[i] > 0,
4660            lambda: (
4661                f"grid_sampler(): expected input to have non-empty spatial dimensions, "
4662                f"but input has sizes {input.shape} with dimension {i} being empty"
4663            ),
4664        )
4665
4666
4667class GridSamplerInterpolation(Enum):
4668    BILINEAR = 0
4669    NEAREST = 1
4670    BICUBIC = 2
4671
4672
4673def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
4674    torch._check(
4675        input.ndim == 5 and input.ndim == grid.ndim,
4676        lambda: (
4677            f"grid_sampler(): expected 5D input and grid with same number of "
4678            f"dimensions, but got input with sizes {input.shape}"
4679            f" and grid with sizes {grid.shape}"
4680        ),
4681    )
4682    torch._check(
4683        not (
4684            input.ndim == 5
4685            and interpolation_mode == GridSamplerInterpolation.BICUBIC.value
4686        ),
4687        lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
4688    )
4689
4690
4691@register_meta(aten.grid_sampler_2d_backward.default)
4692def grid_sampler_2d_backward_meta(
4693    grad_output,
4694    input,
4695    grid,
4696    interpolation_mode,
4697    padding_mode,
4698    align_corners,
4699    output_mask,
4700):
4701    input_requires_grad = output_mask[0]
4702    if input_requires_grad:
4703        grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
4704    else:
4705        grad_input = None
4706    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
4707    return (grad_input, grad_grid)
4708
4709
4710@register_meta(aten.grid_sampler_3d)
4711@out_wrapper()
4712def grid_sampler_3d(
4713    input,
4714    grid,
4715    interpolation_mode,
4716    padding_mode,
4717    align_corners,
4718):
4719    check_grid_sampler_common(input, grid)
4720    check_grid_sampler_3d(input, grid, interpolation_mode)
4721    N = input.shape[0]
4722    C = input.shape[1]
4723    out_D = grid.shape[1]
4724    out_H = grid.shape[2]
4725    out_W = grid.shape[3]
4726    return input.new_empty((N, C, out_D, out_H, out_W))
4727
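# Shape sketch (hypothetical sizes): input (N, C, D_in, H_in, W_in) with
# grid (N, D_out, H_out, W_out, 3) yields (N, C, D_out, H_out, W_out),
# e.g. (2, 4, 8, 8, 8) sampled with a (2, 5, 6, 7, 3) grid gives
# (2, 4, 5, 6, 7).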
4728
4729@register_meta(aten.grid_sampler_3d_backward)
4730@out_wrapper("grad_input", "grad_grid")
4731def grid_sampler_3d_backward(
4732    grad_output,
4733    input,
4734    grid,
4735    interpolation_mode,
4736    padding_mode,
4737    align_corners,
4738    output_mask,
4739):
4740    check_grid_sampler_common(input, grid)
4741    check_grid_sampler_3d(input, grid, interpolation_mode)
4742    input_requires_grad = output_mask[0]
4743    if input_requires_grad:
4744        grad_input = torch.zeros_like(
4745            input, memory_format=torch.legacy_contiguous_format
4746        )
4747    else:
4748        grad_input = None
4749    grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
4750    return grad_input, grad_grid
4751
4752
4753@register_meta([aten.full.default])
4754def full(size, fill_value, *args, **kwargs):
4755    dtype = kwargs.get("dtype", None)
4756    if not dtype:
4757        dtype = utils.get_dtype(fill_value)
4758    kwargs["dtype"] = dtype
4759    return torch.empty(size, *args, **kwargs)
4760
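# e.g. torch.full((2, 3), 1.5, device="meta") infers a floating dtype from
# the Python scalar fill_value (the default float dtype), while an explicit
# dtype=torch.float64 is passed through unchanged.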
4761
4762# zeros_like is special cased to work for sparse
4763@register_meta(aten.zeros_like.default)
4764def zeros_like(
4765    self,
4766    dtype=None,
4767    layout=None,
4768    device=None,
4769    pin_memory=None,
4770    memory_format=None,
4771):
4772    if layout == torch.sparse_coo:
4773        torch._check(
4774            memory_format is None,
4775            lambda: "memory format option is only supported by strided tensors",
4776        )
4777
4778        res = torch.empty(
4779            0,
4780            dtype=self.dtype if dtype is None else dtype,
4781            layout=layout,
4782            device=self.device if device is None else device,
4783            pin_memory=pin_memory,
4784        )
4785
4786        if self.is_sparse:
4787            res.sparse_resize_and_clear_(
4788                self.size(), self.sparse_dim(), self.dense_dim()
4789            )
4790        else:
4791            res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
4792
4793        res._coalesced_(True)
4794        return res
4795    res = aten.empty_like.default(
4796        self,
4797        dtype=dtype,
4798        layout=layout,
4799        device=device,
4800        pin_memory=pin_memory,
4801        memory_format=memory_format,
4802    )
4803    # the resulting device may not be "meta", so actually fill with zeros
4804    res.fill_(0)
4805    return res
4806
4807
4808@register_meta(aten.select.int)
4809def meta_select(self, dim, index):
4810    ndim = self.dim()
4811    torch._check_index(
4812        ndim != 0,
4813        lambda: "select() cannot be applied to a 0-dim tensor.",
4814    )
4815
4816    dim = dim if dim >= 0 else dim + ndim
4817    size = self.size(dim)
4818
4819    torch._check_index(
4820        not (-index > size or index >= size),
4821        lambda: f"select(): index {index} out of range for tensor of size "
4822        f"{self.size()} at dimension {dim}",
4823    )
4824
4825    index = index if index >= 0 else index + size
4826
4827    new_size = list(self.size())
4828    new_stride = list(self.stride())
4829
4830    new_storage_offset = self.storage_offset() + index * new_stride[dim]
4831    del new_size[dim]
4832    del new_stride[dim]
4833
4834    return self.as_strided(new_size, new_stride, new_storage_offset)
4835
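# Stride arithmetic sketch: for a contiguous (3, 4) tensor (strides (4, 1)),
# select(dim=0, index=1) drops dim 0 and offsets storage by 1 * 4 == 4,
# returning a (4,)-shaped view with stride (1,) over row 1.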
4836
4837@register_meta(aten.select_scatter.default)
4838def meta_select_scatter(self, src, dim, index):
4839    return utils.clone_preserve_strides(self)
4840
4841
4842@register_meta(aten.slice_scatter.default)
4843def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
4844    return utils.clone_preserve_strides(self)
4845
4846
4847# TODO: Deduplicate this with canonicalize_dim
4848def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
4849    if dim_post_expr <= 0:
4850        assert wrap_scalar
4851        dim_post_expr = 1
4852    min = -dim_post_expr
4853    max = dim_post_expr - 1
4854    assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
4855    if dim < 0:
4856        dim += dim_post_expr
4857    return dim
4858
4859
4860def ensure_nonempty_size(t, dim):
4861    return 1 if t.dim() == 0 else t.shape[dim]
4862
4863
4864# From aten/src/ATen/native/ScatterGatherChecks.h
4865def gather_shape_check(self, dim, index):
4866    self_dims = max(self.dim(), 1)
4867    index_dims = max(index.dim(), 1)
4868    torch._check(
4869        self_dims == index_dims,
4870        lambda: "Index tensor must have the same number of dimensions as input tensor",
4871    )
4872    for i in range(self_dims):
4873        if i != dim:
4874            torch._check(
4875                ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
4876                lambda: f"Size does not match at dimension {i}; expected index {index.shape}"
4877                + f" to be smaller than self {self.shape} apart from dimension {dim}",
4878            )
4879
4880
4881@register_meta(aten.gather.default)
4882def meta_gather(self, dim, index, sparse_grad=False):
4883    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4884
4885    wrapped_dim = maybe_wrap_dim(dim, self.dim())
4886    is_index_empty = guard_size_oblivious(index.numel() == 0)
4887    if not is_index_empty:
4888        torch._check(
4889            index.dtype == torch.long,
4890            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
4891        )
4892        gather_shape_check(self, wrapped_dim, index)
4893    return self.new_empty(index.shape)
4894
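# Shape sketch: the gather output always matches index, e.g. self (4, 5)
# gathered along dim=1 with an int64 index of shape (4, 2) yields a (4, 2)
# meta tensor.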
4895
4896# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
4897def get_operator_enum(reduce_, use_new_options=False):
4898    if use_new_options:
4899        if reduce_ == "sum":
4900            return "REDUCE_ADD"
4901        elif reduce_ == "prod":
4902            return "REDUCE_MULTIPLY"
4903        elif reduce_ == "mean":
4904            return "REDUCE_MEAN"
4905        elif reduce_ == "amax":
4906            return "REDUCE_MAXIMUM"
4907        elif reduce_ == "amin":
4908            return "REDUCE_MINIMUM"
4909        torch._check(
4910            False,
4911            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
4912        )
4913        return
4914    else:
4915        if reduce_ == "add":
4916            return "REDUCE_ADD"
4917        elif reduce_ == "multiply":
4918            return "REDUCE_MULTIPLY"
4919        torch._check(False, lambda: "reduce argument must be either add or multiply.")
4920        return
4921
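# e.g. get_operator_enum("amax", use_new_options=True) returns
# "REDUCE_MAXIMUM", while the legacy path accepts only "add" / "multiply".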
4922
4923# From aten/src/ATen/native/ScatterGatherChecks.h
4924def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
4925    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4926
4927    if guard_size_oblivious(index.numel() != 0):
4928        torch._check(
4929            index.dtype == torch.long,
4930            lambda: f"{method_name}(): Expected dtype int64 for index",
4931        )
4932
4933    if src_opt is not None:
4934        torch._check(
4935            self.dtype == src_opt.dtype,
4936            lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
4937        )
4938
4939
4940def ensure_nonempty_dim(dim):
4941    return max(dim, 1)
4942
4943
4944# From aten/src/ATen/native/ScatterGatherChecks.h
4945def scatter_shape_check(self, dim, index, src_opt=None):
4946    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4947
4948    if guard_size_oblivious(index.numel() == 0):
4949        return
4950    torch._check(
4951        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
4952        lambda: "Index tensor must have the same number of dimensions as self tensor",
4953    )
4954
4955    is_wrong_shape = False
4956    self_dims = ensure_nonempty_dim(self.dim())
4957
4958    # Check: index.size(d) <= self.size(d) for all d != dim
4959    for d in range(self_dims):
4960        index_d_size = ensure_nonempty_size(index, d)
4961        if d == dim:
4962            continue
4963        if index_d_size > ensure_nonempty_size(self, d):
4964            is_wrong_shape = True
4965            break
4966
4967    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
4968    if not is_wrong_shape and src_opt is not None:
4969        for d in range(self_dims):
4970            index_d_size = ensure_nonempty_size(index, d)
4971            if index_d_size > ensure_nonempty_size(src_opt, d):
4972                is_wrong_shape = True
4973                break
4974
4975    if src_opt is not None:
4976        torch._check(
4977            ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
4978            lambda: "Index tensor must have the same number of dimensions as self tensor",
4979        )
4980        torch._check(
4981            not is_wrong_shape,
4982            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
4983            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
4984        )
4985    else:
4986        torch._check(
4987            not is_wrong_shape,
4988            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
4989            + f" apart from dimension {dim}",
4990        )
4991
4992
4993# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
4994def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
4995    wrapped_dim = maybe_wrap_dim(dim, self.dim())
4996    scatter_gather_dtype_check("scatter", self, index, src)
4997    scatter_shape_check(self, wrapped_dim, index, src)
4998    if reduce_ is not None:
4999        # Check if we have a valid reduce operator.
5000        get_operator_enum(reduce_, use_new_options)
5001
5002
5003@register_meta(aten.scatter_add.default)
5004def meta_scatter_add(self, dim, index, src):
5005    scatter_meta_impl(self, dim, index, src, "add")
5006    return self.new_empty(self.shape)
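
# A minimal usage sketch (editor's note, not upstream code): on the meta device,
# scatter_add dispatches to meta_scatter_add above and returns a fresh tensor of
# self's shape without allocating real storage.
#
#   self_ = torch.empty(4, 8, device="meta")
#   index = torch.zeros(4, 8, dtype=torch.long, device="meta")
#   src = torch.empty(4, 8, device="meta")
#   out = torch.scatter_add(self_, 0, index, src)
#   assert out.shape == (4, 8) and out.device.type == "meta"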
5007
5008
5009@register_meta(aten.scatter_add_)
5010def meta_scatter_add_(self, dim, index, src):
5011    scatter_meta_impl(self, dim, index, src, "add")
5012    return self
5013
5014
5015@register_meta(
5016    [
5017        aten.scatter.src,
5018        aten.scatter.value,
5019        aten.scatter.reduce,
5020        aten.scatter.value_reduce,
5021    ]
5022)
5023@out_wrapper()
5024def meta_scatter(self, dim, index, src_or_value, reduce=None):
5025    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
5026    scatter_meta_impl(self, dim, index, src, reduce)
5027    return self.new_empty(self.shape)
5028
5029
5030@register_meta(
5031    [
5032        aten.scatter_.src,
5033        aten.scatter_.value,
5034        aten.scatter_.reduce,
5035        aten.scatter_.value_reduce,
5036    ]
5037)
5038def meta_scatter_(self, dim, index, src_or_value, reduce=None):
5039    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
5040    scatter_meta_impl(self, dim, index, src, reduce)
5041    return self
5042
5043
5044@register_meta([aten._scaled_dot_product_flash_attention])
5045def meta__scaled_dot_product_flash_attention(
5046    query: Tensor,
5047    key: Tensor,
5048    value: Tensor,
5049    dropout_p: float = 0.0,
5050    is_causal: bool = False,
5051    return_debug_mask: bool = False,
5052    scale: Optional[float] = None,
5053):
5054    batch_size = query.size(0)
5055    num_heads = query.size(1)
5056    max_seqlen_batch_q = query.size(2)
5057    head_dim = query.size(3)
5058    max_seqlen_batch_k = key.size(2)
5059
5060    query_t = query.transpose(1, 2)
5061    attention = torch.empty_like(query_t).transpose(1, 2)
5062    logsumexp = torch.empty(
5063        (batch_size, num_heads, max_seqlen_batch_q),
5064        dtype=torch.float,
5065        device=query.device,
5066    )
5067
5068    if return_debug_mask:
5069        blocksize_c = 128 if head_dim > 64 else 256
5070        max_seqlen_k = math.ceil(max_seqlen_batch_k / blocksize_c)
5071        if max_seqlen_batch_k <= 128:
5072            max_seqlen_k = 128
5073        elif max_seqlen_batch_k <= 256:
5074            max_seqlen_k = 256
5075        debug_mask = torch.empty(
5076            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
5077            dtype=query.dtype,
5078            device=query.device,
5079        )
5080    else:
5081        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
5082
5083    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
5084    # capturing or not, but at the time of tracing we don't know if we
5085    # are going to use cudagraphs or not, so we return meta tensors here.
5086    # It's possible we'll need to have some special handling in inductor for sdpa.
5087
5088    return (
5089        attention,
5090        logsumexp,
5091        None,
5092        None,
5093        max_seqlen_batch_q,
5094        max_seqlen_batch_k,
5095        torch.empty((), dtype=torch.long, device="meta"),
5096        torch.empty((), dtype=torch.long, device="meta"),
5097        debug_mask,
5098    )
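
# Shape sketch (editor's note): for (batch, heads, seq, head_dim) inputs on the
# meta device, the registration above returns attention with query's shape and a
# float32 logsumexp of shape (batch, heads, seq_q); seed/offset are 0-dim meta
# tensors per the note above.
#
#   q = torch.empty(2, 4, 16, 64, dtype=torch.float16, device="meta")
#   out = torch.ops.aten._scaled_dot_product_flash_attention(q, q, q)
#   assert out[0].shape == q.shape and out[1].shape == (2, 4, 16)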
5099
5100
5101@register_meta([aten._scaled_dot_product_cudnn_attention])
5102def meta__scaled_dot_product_cudnn_attention(
5103    query: Tensor,
5104    key: Tensor,
5105    value: Tensor,
5106    attn_bias: Optional[Tensor],
5107    compute_log_sumexp: bool,
5108    dropout_p: float = 0.0,
5109    is_causal: bool = False,
5110    return_debug_mask: bool = False,
5111    scale: Optional[float] = None,
5112):
5113    B = query.size(0)
5114    H = query.size(1)
5115    S_Q = query.size(2)
5116    S_KV = key.size(2)
5117    D_QK = query.size(-1)
5118    D_V = value.size(-1)
5119
5120    res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
5121    logsum_exp = torch.empty(
5122        (B, H, S_Q),
5123        dtype=torch.float,
5124        device=query.device,
5125    )
5126
5127    # See Note [Seed and Offset]
5128    seed = torch.empty((), dtype=torch.long, device="meta")
5129    offset = torch.empty((), dtype=torch.long, device="meta")
5130
5131    return (
5132        res,
5133        logsum_exp,
5134        None,
5135        None,
5136        S_Q,
5137        S_KV,
5138        seed,
5139        offset,
5140        None,
5141    )
5142
5143
5144@register_meta(
5145    [
5146        aten._scaled_dot_product_flash_attention_backward,
5147    ]
5148)
5149def meta__scaled_dot_product_flash_backward(
5150    grad_out: Tensor,
5151    query: Tensor,
5152    key: Tensor,
5153    value: Tensor,
5154    out: Tensor,
5155    logsumexp: Tensor,
5156    cum_seq_q: Tensor,
5157    cum_seq_k: Tensor,
5158    max_q: int,
5159    max_k: int,
5160    dropout_p: float,
5161    is_causal: bool,
5162    philox_seed: Tensor,
5163    philox_offset: Tensor,
5164    scale: Optional[float] = None,
5165):
5166    grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
5167    grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
5168    grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
5169    return grad_q, grad_k, grad_v
5170
5171
5172@register_meta(
5173    [
5174        aten._scaled_dot_product_flash_attention_for_cpu,
5175    ]
5176)
5177def meta__scaled_dot_product_flash_attention_for_cpu(
5178    query: Tensor,
5179    key: Tensor,
5180    value: Tensor,
5181    dropout_p: float = 0.0,
5182    is_causal: bool = False,
5183    attn_mask: Optional[Tensor] = None,
5184    scale: Optional[float] = None,
5185):
5186    batch_size = query.size(0)
5187    num_heads = query.size(1)
5188    max_seqlen_batch_q = query.size(2)
5189    head_dim = query.size(3)
5190
5191    attention = torch.empty_like(query)
5192    logsumexp = torch.empty(
5193        (
5194            batch_size,
5195            max_seqlen_batch_q,
5196            num_heads,
5197        ),
5198        dtype=torch.float,
5199        device=query.device,
5200    ).transpose(1, 2)
5201    return (
5202        attention,
5203        logsumexp,
5204    )
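
# Editor's note: logsumexp above is allocated as (batch, seq_q, heads) and then
# viewed through transpose(1, 2), so callers see (batch, heads, seq_q) while the
# strides still match the CPU kernel's (batch, seq_q, heads) layout.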
5205
5206
5207@register_meta(
5208    [
5209        aten._scaled_dot_product_flash_attention_for_cpu_backward,
5210    ]
5211)
5212def meta__scaled_dot_product_flash_attention_for_cpu_backward(
5213    grad_out: Tensor,
5214    query: Tensor,
5215    key: Tensor,
5216    value: Tensor,
5217    out: Tensor,
5218    logsumexp: Tensor,
5219    dropout_p: float,
5220    is_causal: bool,
5221    attn_mask: Optional[Tensor] = None,
5222    scale: Optional[float] = None,
5223):
5224    # CPU's grad layout is different from CUDA's,
5225    # i.e. (batch_size, seq_len, num_heads, head_dim)
5226    batch_size = query.size(0)
5227    num_heads = query.size(1)
5228    head_dim = query.size(3)
5229    len_q = query.size(2)
5230    len_k = key.size(2)
5231
5232    grad_q = torch.empty_permuted(
5233        (batch_size, num_heads, len_q, head_dim),
5234        (0, 2, 1, 3),
5235        dtype=query.dtype,
5236        device=query.device,
5237    )
5238    grad_k = torch.empty_permuted(
5239        (batch_size, num_heads, len_k, head_dim),
5240        (0, 2, 1, 3),
5241        dtype=key.dtype,
5242        device=key.device,
5243    )
5244    grad_v = torch.empty_permuted(
5245        (batch_size, num_heads, len_k, head_dim),
5246        (0, 2, 1, 3),
5247        dtype=value.dtype,
5248        device=value.device,
5249    )
5250
5251    return grad_q, grad_k, grad_v
5252
5253
5254@register_meta([aten._scaled_dot_product_efficient_attention])
5255def meta__scaled_dot_product_efficient_attention(
5256    query: Tensor,
5257    key: Tensor,
5258    value: Tensor,
5259    attn_bias: Optional[Tensor],
5260    compute_log_sumexp: bool,
5261    dropout_p=0.0,
5262    is_causal: bool = False,
5263    scale: Optional[float] = None,
5264):
5265    query = query.transpose(1, 2)
5266    key = key.transpose(1, 2)
5267    value = value.transpose(1, 2)
5268
5269    B = query.size(0)
5270    M = query.size(1)
5271    N = key.size(1)
5272    num_heads = query.size(-2)
5273    K = query.size(-1)
5274    Kv = value.size(-1)
5275
5276    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
5277
5278    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
5279    logsum_exp = torch.empty(
5280        (B, num_heads, logsumexp_dim),
5281        dtype=torch.float,
5282        device=query.device,
5283    )
5284
5285    res = res.transpose(1, 2)
5286
5287    # See Note [Seed and Offset]:
5288    seed = torch.empty((), dtype=torch.long, device="meta")
5289    offset = torch.empty((), dtype=torch.long, device="meta")
5290
5291    return res, logsum_exp, seed, offset
5292
5293
5294@register_meta(
5295    [
5296        aten._scaled_dot_product_efficient_attention_backward,
5297    ]
5298)
5299def meta__scaled_dot_product_efficient_backward(
5300    grad_out: Tensor,
5301    query: Tensor,
5302    key: Tensor,
5303    value: Tensor,
5304    attn_bias: Optional[Tensor],
5305    out: Tensor,
5306    logsumexp: Tensor,
5307    philox_seed: Tensor,
5308    philox_offset: Tensor,
5309    dropout_p: float,
5310    grad_input_mask: List[bool],
5311    is_causal: bool = False,
5312    scale: Optional[float] = None,
5313):
5314    batch_size = query.size(0)
5315    num_heads = query.size(1)
5316    max_q = query.size(2)
5317    head_dim = query.size(3)
5318    head_dim_v = value.size(3)
5319
5320    max_k = key.size(2)
5321
5322    grad_q = torch.empty_permuted(
5323        (batch_size, num_heads, max_q, head_dim),
5324        (0, 2, 1, 3),
5325        dtype=query.dtype,
5326        device=query.device,
5327    )
5328    grad_k = torch.empty_permuted(
5329        (batch_size, num_heads, max_k, head_dim),
5330        (0, 2, 1, 3),
5331        dtype=key.dtype,
5332        device=key.device,
5333    )
5334    grad_v = torch.empty_permuted(
5335        (batch_size, num_heads, max_k, head_dim_v),
5336        (0, 2, 1, 3),
5337        dtype=value.dtype,
5338        device=value.device,
5339    )
5340    grad_bias = None
5341    if attn_bias is not None and grad_input_mask[3]:
5342        lastDim = attn_bias.size(-1)
5343        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5344        new_sizes = list(attn_bias.size())
5345        new_sizes[-1] = lastDimAligned
5346        grad_bias = torch.empty(
5347            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
5348        )
5349        grad_bias = grad_bias[..., :lastDim]
5350
5351    return grad_q, grad_k, grad_v, grad_bias
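
# Worked example (editor's note): the 16-lane alignment above pads the last dim
# and then slices back, e.g. lastDim = 37 gives
#   lastDimAligned = 37 + 16 - 37 % 16 = 48
# so grad_bias is allocated with last dim 48 but returned as a width-37 view.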
5352
5353
5354@register_meta(
5355    [
5356        aten._scaled_dot_product_cudnn_attention_backward,
5357    ]
5358)
5359def meta__scaled_dot_product_cudnn_backward(
5360    grad_out: Tensor,
5361    query: Tensor,
5362    key: Tensor,
5363    value: Tensor,
5364    out: Tensor,
5365    logsumexp: Tensor,
5366    philox_seed: Tensor,
5367    philox_offset: Tensor,
5368    attn_bias: Tensor,
5369    cum_seq_q: Tensor,
5370    cum_seq_k: Tensor,
5371    max_q: int,
5372    max_k: int,
5373    dropout_p: float,
5374    is_causal: bool,
5375    scale: Optional[float] = None,
5376):
5377    grad_q = torch.empty_like(query)
5378    grad_k = torch.empty_like(key)
5379    grad_v = torch.empty_like(value)
5380    return grad_q, grad_k, grad_v
5381
5382
5383@register_meta(
5384    [
5385        aten._flash_attention_forward,
5386    ]
5387)
5388def meta__flash_attention_forward(
5389    query: Tensor,
5390    key: Tensor,
5391    value: Tensor,
5392    cum_seq_q: Optional[Tensor],
5393    cum_seq_k: Optional[Tensor],
5394    max_q: int,
5395    max_k: int,
5396    dropout_p: float,
5397    is_causal: bool,
5398    return_debug_mask: bool,
5399    scale: Optional[float] = None,
5400    window_size_left: Optional[int] = None,
5401    window_size_right: Optional[int] = None,
5402    seqused_k: Optional[Tensor] = None,
5403    alibi_slopes: Optional[Tensor] = None,
5404):
5405    # NB: there are two underlying paths:
5406    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
5407    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
5408    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
5409    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
5410    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
5411    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
5412    num_heads = query.size(-2)
5413    head_dim = query.size(-1)
5414
5415    # CUDA path
5416    attention = torch.empty_like(query)
5417    logsumexp = torch.empty(
5418        (batch_size, num_heads, max_seqlen_batch_q),
5419        dtype=torch.float,
5420        device=query.device,
5421    )
5422
5423    if return_debug_mask:
5424        blocksize_c = 128 if head_dim > 64 else 256
5425        max_seqlen_k = math.ceil(max_seqlen_batch_k / blocksize_c)
5426        if max_seqlen_batch_k <= 128:
5427            max_seqlen_k = 128
5428        elif max_seqlen_batch_k <= 256:
5429            max_seqlen_k = 256
5430        debug_mask = torch.empty(
5431            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
5432            dtype=query.dtype,
5433            device=query.device,
5434        )
5435    else:
5436        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
5437
5438    # See Note [Seed and Offset]:
5439    return (
5440        attention,
5441        logsumexp,
5442        torch.empty((), dtype=torch.long, device="meta"),
5443        torch.empty((), dtype=torch.long, device="meta"),
5444        debug_mask,
5445    )
5446
5447
5448@register_meta(
5449    [
5450        aten._flash_attention_backward,
5451    ]
5452)
5453def meta__flash_attention_backward(
5454    grad_out: Tensor,
5455    query: Tensor,
5456    key: Tensor,
5457    value: Tensor,
5458    out: Tensor,
5459    logsumexp: Tensor,
5460    cum_seq_q: Tensor,
5461    cum_seq_k: Tensor,
5462    max_q: int,
5463    max_k: int,
5464    dropout_p: float,
5465    is_causal: bool,
5466    philox_seed: Tensor,
5467    philox_offset: Tensor,
5468    scale: Optional[float] = None,
5469    window_size_left: Optional[int] = None,
5470    window_size_right: Optional[int] = None,
5471):
5472    grad_query = torch.empty_like(query)
5473    grad_key = torch.empty_like(key)
5474    grad_value = torch.empty_like(value)
5475
5476    return grad_query, grad_key, grad_value
5477
5478
5479@register_meta(
5480    [
5481        aten._efficient_attention_forward,
5482    ]
5483)
5484def meta__efficient_attention_forward(
5485    query: Tensor,
5486    key: Tensor,
5487    value: Tensor,
5488    bias: Optional[Tensor],
5489    cu_seqlens_q: Optional[Tensor],
5490    cu_seqlens_k: Optional[Tensor],
5491    max_seqlen_q: Optional[int],
5492    max_seqlen_k: Optional[int],
5493    dropout_p: float,
5494    custom_mask_type: int,
5495    compute_log_sumexp: bool = False,
5496    scale: Optional[float] = None,
5497    causal_diagonal: Optional[Tensor] = None,
5498    seqlen_k: Optional[Tensor] = None,
5499    window_size: Optional[int] = None,
5500):
5501    B = query.size(0)
5502    M = query.size(1)
5503    N = key.size(1)
5504    num_heads = query.size(-2)
5505    K = query.size(-1)
5506    Kv = value.size(-1)
5507
5508    res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
5509
5510    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
5511    actual_max_seqlen_q = M
5512    if cu_seqlens_q is not None:
5513        assert max_seqlen_q is not None
5514        actual_max_seqlen_q = max_seqlen_q
5515    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
5516    logsumexp_dim = (
5517        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
5518    )
5519    logsum_exp = torch.empty(
5520        (logsumexp_batch_dim, num_heads, logsumexp_dim),
5521        dtype=torch.float,
5522        device=query.device,
5523    )
5524
5525    # See Note [Seed and Offset]:
5526    seed = torch.empty((), dtype=torch.long, device="meta")
5527    offset = torch.empty((), dtype=torch.long, device="meta")
5528
5529    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
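
# Worked example (editor's note): with compute_log_sumexp=True and an effective
# max query length of 50, logsumexp_dim = math.ceil(50 / 32) * 32 = 64; with
# compute_log_sumexp=False it is 0, making logsumexp an empty placeholder.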
5530
5531
5532@register_meta(
5533    [
5534        aten._efficient_attention_backward,
5535    ]
5536)
5537def meta__efficient_attention_backward(
5538    grad_out: Tensor,
5539    query: Tensor,
5540    key: Tensor,
5541    value: Tensor,
5542    bias: Optional[Tensor],
5543    cu_seqlens_q: Optional[Tensor],
5544    cu_seqlens_k: Optional[Tensor],
5545    max_seqlen_q: torch.SymInt,
5546    max_seqlen_k: torch.SymInt,
5547    logsumexp: Tensor,
5548    dropout_p: float,
5549    philox_seed: Tensor,
5550    philox_offset: Tensor,
5551    custom_mask_type: int,
5552    bias_requires_grad: bool,
5553    scale: Optional[float] = None,
5554    num_splits_key: Optional[int] = None,
5555    shared_storage_dqdkdv: bool = False,
5556):
5557    if shared_storage_dqdkdv:
5558        torch._check(
5559            query.shape[1] == key.shape[1],
5560            lambda: "seqlen must match for `shared_storage_dqdkdv",
5561        )
5562        torch._check(
5563            query.shape[3] == key.shape[3],
5564            lambda: "embedding dim must match for `shared_storage_dqdkdv",
5565        )
5566        chunk = torch.empty(
5567            (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
5568            dtype=query.dtype,
5569            device=query.device,
5570        )
5571        grad_query = chunk.select(-3, 0)
5572        grad_key = chunk.select(-3, 1)
5573        grad_value = chunk.select(-3, 2)
5574    else:
5575        grad_query = torch.empty_like(query)
5576        grad_key = torch.empty_like(key)
5577        grad_value = torch.empty_like(value)
5578
5579    if bias is not None:
5580        lastDim = bias.size(-1)
5581        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5582        new_sizes = list(bias.size())
5583        new_sizes[-1] = lastDimAligned
5584        grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
5585        grad_bias = grad_bias[..., :lastDim]
5586    else:
5587        grad_bias = torch.empty((), device=query.device)
5588
5589    return grad_query, grad_key, grad_value, grad_bias
5590
5591
5592@register_meta([aten._scaled_mm.default])
5593def meta_scaled_mm(
5594    self: torch.Tensor,
5595    mat2: torch.Tensor,
5596    scale_a: torch.Tensor,
5597    scale_b: torch.Tensor,
5598    bias: Optional[torch.Tensor] = None,
5599    scale_result: Optional[torch.Tensor] = None,
5600    out_dtype: Optional[torch.dtype] = None,
5601    use_fast_accum: bool = False,
5602):
5603    def is_row_major(stride):
5604        return stride[0] > stride[1] and stride[1] == 1
5605
5606    def is_col_major(stride):
5607        return stride[0] == 1 and stride[1] > 1
5608
5609    def is_fp8_type(dtype):
5610        return dtype in (
5611            torch.float8_e4m3fn,
5612            torch.float8_e5m2,
5613            torch.float8_e4m3fnuz,
5614            torch.float8_e5m2fnuz,
5615        )
5616
5617    torch._check(
5618        self.dim() == 2 and mat2.dim() == 2,
5619        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
5620    )
5621    torch._check(
5622        is_row_major(self.stride()),
5623        lambda: "self must be row_major",
5624    )
5625    torch._check(
5626        is_col_major(mat2.stride()),
5627        lambda: "mat2 must be col_major",
5628    )
5629    torch._check(
5630        self.size(1) % 16 == 0,
5631        lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
5632    )
5633    torch._check(
5634        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
5635        lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
5636    )
5637    torch._check(
5638        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
5639        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
5640    )
5641
5642    # determine scaling type and check input dimensions (refer to Blas.cpp op)
5643    torch._check(
5644        scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
5645        lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
5646    )
5647    m, k = self.shape
5648    n = mat2.size(1)
5649    if scale_a.numel() == 1 and scale_b.numel() == 1:
5650        # tensorwise scaling
5651        pass
5652    else:
5653        # for non-tensorwise scaling, enforce 2D input tensors
5654        torch._check(
5655            scale_a.dim() == 2 and scale_b.dim() == 2,
5656            lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}",
5657        )
5658
5659        if (
5660            scale_a.size(0) == m
5661            and scale_a.size(1) == 1
5662            and scale_b.size(0) == 1
5663            and scale_b.size(1) == n
5664        ):
5665            # rowwise scaling
5666            torch._check(
5667                scale_a.is_contiguous() and scale_b.is_contiguous(),
5668                lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
5669            )
5670        else:
5671            # does not match any valid scaling type
5672            torch._check(
5673                False,
5674                lambda: (
5675                    "Invalid scaling configuration. "
5676                    "For tensorwise scaling, both scales should be scalar. "
5677                    f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
5678                    f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
5679                    f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
5680                ),
5681            )
5682
5683    _out_dtype = out_dtype if out_dtype is not None else self.dtype
5684    return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)
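
# A minimal usage sketch (editor's note): the checks above require fp8 inputs, a
# row-major self, a col-major mat2 (e.g. a transposed contiguous tensor), and
# fp32 scales; with scalar scales this is tensorwise scaling.
#
#   a = torch.empty(32, 64, dtype=torch.float8_e4m3fn, device="meta")
#   b = torch.empty(48, 64, dtype=torch.float8_e4m3fn, device="meta").t()
#   scale = torch.empty((), dtype=torch.float32, device="meta")
#   out = torch._scaled_mm(a, b, scale, scale)
#   assert out.shape == (32, 48) and out.dtype == torch.float8_e4m3fn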
5685
5686
5687@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
5688@out_wrapper()
5689def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
5690    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
5691    return self.new_empty(self.shape)
5692
5693
5694@register_meta(aten.scatter_reduce_.two)
5695def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
5696    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
5697    return self
5698
5699
5700@register_meta([aten.multinomial.default, aten.multinomial.out])
5701@out_wrapper()
5702def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
5703    torch._check(
5704        0 < input.dim() <= 2,
5705        lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
5706    )
5707    if input.dim() == 1:
5708        return torch.empty(num_samples, dtype=torch.long, device=input.device)
5709    return torch.empty(
5710        input.size(0), num_samples, dtype=torch.long, device=input.device
5711    )
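
# Shape sketch (editor's note): a 1D probability input yields (num_samples,),
# while a 2D input of shape (rows, categories) yields (rows, num_samples), int64:
#
#   probs = torch.empty(3, 5, device="meta")
#   out = torch.multinomial(probs, 2)
#   assert out.shape == (3, 2) and out.dtype == torch.int64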
5712
5713
5714def multiply_integers(vs):
5715    r = 1
5716    for v in vs:
5717        r *= v
5718    return r
5719
5720
5721def upsample_common_check(input_size, output_size, num_spatial_dims):
5722    torch._check(
5723        len(output_size) == num_spatial_dims,
5724        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
5725    )
5726    expected_input_dims = num_spatial_dims + 2  # N, C, ...
5727    torch._check(
5728        len(input_size) == expected_input_dims,
5729        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
5730    )
5731
5732    torch._check(
5733        all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
5734        lambda: f"Input and output sizes should be greater than 0, but got "
5735        f"input size {input_size} and output size {output_size}",
5736    )
5737
5738    nbatch, channels = input_size[:2]
5739    return (nbatch, channels, *output_size)
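
# Worked example (editor's note): for an NCHW input of size (2, 3, 32, 32) and
# output_size (64, 64) with num_spatial_dims=2, the checks pass and the returned
# full size is (2, 3, 64, 64).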
5740
5741
5742@register_meta(
5743    [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
5744)
5745def upsample_nearest1d(input, output_size, scales=None):
5746    torch._check(
5747        input.numel() != 0 or multiply_integers(input.size()[1:]),
5748        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
5749    )
5750    full_output_size = upsample_common_check(
5751        input.size(), output_size, num_spatial_dims=1
5752    )
5753    return input.new_empty(full_output_size).to(
5754        memory_format=utils.suggest_memory_format(input)
5755    )
5756
5757
5758@register_meta(
5759    [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
5760)
5761def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
5762    torch._check(
5763        input.numel() != 0 or multiply_integers(input.size()[1:]),
5764        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
5765    )
5766    full_output_size = upsample_common_check(
5767        input.size(), output_size, num_spatial_dims=2
5768    )
5769    output = input.new_empty(full_output_size)
5770
5771    # convert output to correct memory format, if necessary
5772    memory_format = utils.suggest_memory_format(input)
5773
5774    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
5775    _, n_channels, _, _ = input.shape
5776    if input.device.type == "cuda" and n_channels < 4:
5777        memory_format = torch.contiguous_format
5778
5779    output = output.contiguous(memory_format=memory_format)
5780
5781    return output
5782
5783
5784@register_meta(
5785    [
5786        aten.upsample_nearest2d_backward.default,
5787        aten._upsample_nearest_exact2d_backward.default,
5788    ]
5789)
5790def upsample_nearest2d_backward(
5791    grad_output: Tensor,
5792    output_size: Sequence[Union[int, torch.SymInt]],
5793    input_size: Sequence[Union[int, torch.SymInt]],
5794    scales_h: Optional[float] = None,
5795    scales_w: Optional[float] = None,
5796):
5797    full_output_size = upsample_common_check(
5798        input_size, output_size, num_spatial_dims=2
5799    )
5800    torch._check(
5801        grad_output.ndim == 4,
5802        lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
5803    )
5804    for i in range(4):
5805        torch._check(
5806            grad_output.size(i) == full_output_size[i],
5807            lambda: (
5808                f"Expected grad_output to have the same shape as output;"
5809                f" output.size({i}) = {full_output_size[i]}"
5810                f" but got grad_output.size({i}) = {grad_output.size(i)}"
5811            ),
5812        )
5813
5814    return grad_output.new_empty(input_size).to(
5815        memory_format=utils.suggest_memory_format(grad_output)
5816    )  # type: ignore[call-overload]
5817
5818
5819@register_meta(
5820    [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
5821)
5822def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
5823    torch._check(
5824        input.numel() != 0 or multiply_integers(input.size()[1:]),
5825        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
5826    )
5827    full_output_size = upsample_common_check(
5828        input.size(), output_size, num_spatial_dims=3
5829    )
5830    return input.new_empty(full_output_size).to(
5831        memory_format=utils.suggest_memory_format(input)
5832    )
5833
5834
5835@register_meta(
5836    [
5837        aten.sort.default,
5838        aten.sort.stable,
5839        aten.sort.values,
5840        aten.sort.values_stable,
5841    ]
5842)
5843def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
5844    v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
5845    if values is not None and indices is not None:
5846        assert isinstance(values, TensorLike)
5847        assert isinstance(indices, TensorLike)
5848        # Makes sure values and indices have the same strides. This handles
5849        # cases where these have different shapes, like (5, 10, 5) and (0) in msort.
5850        out_shape = v.shape
5851        out_stride = v.stride()
5852        values = _maybe_resize_out(values, out_shape)
5853        indices = _maybe_resize_out(indices, out_shape)
5854        values.as_strided_(out_shape, out_stride)
5855        indices.as_strided_(out_shape, out_stride)
5856        _safe_copy_out(copy_from=v, copy_to=values)  # type: ignore[arg-type]
5857        _safe_copy_out(copy_from=i, copy_to=indices)  # type: ignore[arg-type]
5858        return values, indices
5859    return v, i
5860
5861
5862def rnn_cell_checkSizes(
5863    input_gates,
5864    hidden_gates,
5865    input_bias,
5866    hidden_bias,
5867    factor,
5868    prev_hidden,
5869):
5870    torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
5871    torch._check(
5872        input_gates.shape == hidden_gates.shape,
5873        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
5874    )
5875    gates_size = input_gates.size(1)
5876    if input_bias is not None:
5877        torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
5878        torch._check(
5879            input_bias.numel() == gates_size,
5880            lambda: f"{input_bias.numel()} != {gates_size}",
5881        )
5882        torch._check(
5883            input_bias.shape == hidden_bias.shape,
5884            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
5885        )
5886    torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
5887    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
5888    torch._check(
5889        prev_hidden.numel() == expected_prev_hidden_numel,
5890        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
5891    )
5892    torch._check(
5893        all(
5894            x.device == input_gates.device
5895            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
5896        ),
5897        lambda: "expected all inputs to be same device",
5898    )
5899
5900
5901@register_meta(aten._thnn_fused_lstm_cell.default)
5902def _thnn_fused_lstm_cell_meta(
5903    input_gates,
5904    hidden_gates,
5905    cx,
5906    input_bias=None,
5907    hidden_bias=None,
5908):
5909    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
5910    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
5911    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5912    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5913    return (hy, cy, workspace)
5914
5915
5916@register_meta(aten._cudnn_rnn.default)
5917def _cudnn_rnn(
5918    input,
5919    weight,
5920    weight_stride0,
5921    weight_buf,
5922    hx,
5923    cx,
5924    mode,
5925    hidden_size,
5926    proj_size,
5927    num_layers,
5928    batch_first,
5929    dropout,
5930    train,
5931    bidirectional,
5932    batch_sizes,
5933    dropout_state,
5934):
5935    is_input_packed = len(batch_sizes) != 0
5936    if is_input_packed:
5937        seq_length = len(batch_sizes)
5938        mini_batch = batch_sizes[0]
5939        batch_sizes_sum = input.shape[0]
5940    else:
5941        seq_length = input.shape[1] if batch_first else input.shape[0]
5942        mini_batch = input.shape[0] if batch_first else input.shape[1]
5943        batch_sizes_sum = -1
5944
5945    num_directions = 2 if bidirectional else 1
5946    out_size = proj_size if proj_size != 0 else hidden_size
5947    if is_input_packed:
5948        out_shape = [batch_sizes_sum, out_size * num_directions]
5949    else:
5950        out_shape = (
5951            [mini_batch, seq_length, out_size * num_directions]
5952            if batch_first
5953            else [seq_length, mini_batch, out_size * num_directions]
5954        )
5955    output = input.new_empty(out_shape)
5956
5957    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
5958    if cx is None:
5959        cy = torch.empty(0, device=input.device)
5960    else:
5961        cy = cx.new_empty(cell_shape)
5962
5963    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
5964
5965    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
5966    reserve_shape = 0
5967    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
5968
5969    return output, hy, cy, reserve, weight_buf
5970
5971
5972@register_meta(aten.mkldnn_rnn_layer.default)
5973def mkldnn_rnn_layer(
5974    input,
5975    w0,
5976    w1,
5977    w2,
5978    w3,
5979    hx_,
5980    cx_,
5981    reverse,
5982    batch_sizes,
5983    mode,
5984    hidden_size,
5985    num_layers,
5986    has_biases,
5987    bidirectional,
5988    batch_first,
5989    train,
5990):
5991    seq_length = input.shape[1] if batch_first else input.shape[0]
5992    mini_batch = input.shape[0] if batch_first else input.shape[1]
5993    output_channels = hidden_size
5994    out_shape = (
5995        [mini_batch, seq_length, output_channels]
5996        if batch_first
5997        else [seq_length, mini_batch, output_channels]
5998    )
5999    output = input.new_empty(out_shape)
6000    if hx_ is None:
6001        hy = torch.empty(0, device=input.device)
6002    else:
6003        hy = hx_.new_empty(hx_.shape)
6004    if cx_ is None:
6005        cy = torch.empty(0, device=input.device)
6006    else:
6007        cy = cx_.new_empty(cx_.shape)
6008    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
6009    return output, hy, cy, workspace
6010
6011
6012def zero_numel_check_dims(self, dim, fn_name):
6013    if self.ndim == 0:
6014        torch._check_index(
6015            dim == 0 or dim == -1,
6016            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
6017        )
6018    else:
6019        torch._check_index(
6020            self.size(dim) != 0,
6021            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
6022        )
6023
6024
6025# From aten/src/ATen/native/ReduceOps.cpp
6026def check_argmax_argmin(name, self, dim):
6027    if dim is not None:
6028        dim = maybe_wrap_dim(dim, self.dim())
6029        zero_numel_check_dims(self, dim, name)
6030    else:
6031        torch._check(
6032            self.numel() != 0,
6033            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
6034        )
6035
6036
6037@register_meta([aten.argmax.default, aten.argmin.default])
6038def argmax_argmin_meta(self, dim=None, keepdim=False):
6039    check_argmax_argmin("argmax", self, dim)
6040    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
6041    shape = _compute_reduction_shape(self, dims, keepdim)
6042    return self.new_empty(shape, dtype=torch.int64)
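
# Shape sketch (editor's note): reducing a (3, 4) input over dim=1 yields shape
# (3,) with keepdim=False and (3, 1) with keepdim=True, always int64:
#
#   x = torch.empty(3, 4, device="meta")
#   assert torch.argmax(x, dim=1).shape == (3,)
#   assert torch.argmax(x, dim=1, keepdim=True).shape == (3, 1)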
6043
6044
6045@register_meta(aten.scalar_tensor.default)
6046def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
6047    return torch.empty(
6048        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
6049    )
6050
6051
6052@register_meta(aten.topk.default)
6053def topk_meta(self, k, dim=-1, largest=True, sorted=True):
6054    # From aten/src/ATen/native/Sorting.cpp
6055    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
6056    sliceSize = 1 if self.dim() == 0 else self.size(dim)
6057    torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
6058
6059    topKSize = list(self.shape)
6060    if len(topKSize) > 0:
6061        topKSize[dim] = k
6062    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
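
# Shape sketch (editor's note): topk keeps self's shape except along dim, which
# becomes k; indices are int64:
#
#   x = torch.empty(5, 10, device="meta")
#   values, indices = torch.topk(x, 3)
#   assert values.shape == (5, 3) and indices.dtype == torch.int64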
6063
6064
6065@register_meta([aten.kthvalue.default, aten.kthvalue.values])
6066@out_wrapper("values", "indices")
6067def kthvalue_meta(self, k, dim=-1, keepdim=False):
6068    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
6069    dimSize = self.size(dim) if self.dim() > 0 else 1
6070    torch._check(
6071        k >= 1 and k <= dimSize,
6072        lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
6073    )
6074
6075    shape = list(self.shape[:dim] + self.shape[dim + 1 :])
6076    if keepdim and self.dim() > 0:
6077        shape.insert(dim, 1)
6078    return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
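
# Shape sketch (editor's note): kthvalue drops the reduced dim unless
# keepdim=True re-inserts it with size 1:
#
#   x = torch.empty(4, 7, device="meta")
#   values, indices = torch.kthvalue(x, 2, dim=1, keepdim=True)
#   assert values.shape == (4, 1)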
6079
6080
6081legacy_contiguous_memory_format = torch.contiguous_format
6082
6083
6084# From aten/src/ATen/native/cuda/RNN.cu
6085def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
6086    defined_grad = grad_hy if grad_hy is not None else grad_cy
6087    torch._check(defined_grad.dim() == 2, lambda: "")
6088    exp_size = defined_grad.size()
6089    if grad_hy is not None:
6090        torch._check(grad_hy.size() == exp_size, lambda: "")
6091    if grad_cy is not None:
6092        torch._check(grad_cy.size() == exp_size, lambda: "")
6093    torch._check(cx.size() == exp_size, lambda: "")
6094    torch._check(cy.size() == exp_size, lambda: "")
6095    torch._check(workspace.dim() == 2, lambda: "")
6096    torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
6097
6098
6099# From aten/src/ATen/native/cuda/RNN.cu
6100@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
6101def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
6102    if grad_hy is None and grad_cy is None:
6103        return None, None, None
6104    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
6105    grad_gates = torch.empty_like(
6106        workspace, memory_format=legacy_contiguous_memory_format
6107    )
6108    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
6109    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
6110    return grad_gates, grad_cx, grad_bias
6111
6112
6113# From aten/src/ATen/native/mps/operations/Linear.mm
6114@register_meta(aten.linear_backward.default)
6115def linear_backward(input_, grad_output_, weight_, output_mask):
6116    grad_input = None
6117    grad_weight = None
6118    grad_bias = None
6119    if output_mask[0]:
6120        grad_input = grad_output_.new_empty(input_.size())
6121    if output_mask[1] or output_mask[2]:
6122        grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
6123        grad_bias = grad_output_.new_empty(grad_output_.size(-1))
6124    return (grad_input, grad_weight, grad_bias)
6125
6126
6127@register_meta(aten.pixel_shuffle.default)
6128def meta_pixel_shuffle(self, upscale_factor):
6129    assert (
6130        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
6131    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
6132
6133    def is_channels_last(ten):
6134        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
6135
6136    def pick_memory_format():
6137        if is_channels_last(self):
6138            if device_hint(self) == "cuda":
6139                return torch.contiguous_format
6140            else:
6141                return torch.channels_last
6142        elif self.is_contiguous(memory_format=torch.contiguous_format):
6143            return torch.contiguous_format
6144        elif self.is_contiguous(memory_format=torch.preserve_format):
6145            return torch.preserve_format
6146
6147    C = self.shape[-3] // (upscale_factor * upscale_factor)
6148    Hr = self.shape[-2] * upscale_factor
6149    Wr = self.shape[-1] * upscale_factor
6150    out_shape = (*self.shape[:-3], C, Hr, Wr)
6151
6152    out = self.new_empty(out_shape)
6153    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
6154    return out
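
# Worked example (editor's note): with upscale_factor=3, a (1, 18, 4, 5) input
# has C = 18 // 9 = 2, Hr = 4 * 3 = 12, Wr = 5 * 3 = 15, so the output shape is
# (1, 2, 12, 15).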
6155
6156
6157@register_meta(aten.mkldnn_rnn_layer_backward.default)
6158def mkldnn_rnn_layer_backward(
6159    input,
6160    weight0,
6161    weight1,
6162    weight2,
6163    weight3,
6164    hx_,
6165    cx_tmp,
6166    output,
6167    hy_,
6168    cy_,
6169    grad_output_r_opt,
6170    grad_hy_r_opt,
6171    grad_cy_r_opt,
6172    reverse,
6173    mode,
6174    hidden_size,
6175    num_layers,
6176    has_biases,
6177    train,
6178    bidirectional,
6179    batch_sizes,
6180    batch_first,
6181    workspace,
6182):
6183    diff_x = input.new_empty(input.shape)
6184    diff_hx = hx_.new_empty(hx_.shape)
6185    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
6186    diff_w1 = weight0.new_empty(weight0.shape)
6187    diff_w2 = weight1.new_empty(weight1.shape)
6188    diff_b = weight2.new_empty(weight2.shape)
6189    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
6190
6191
6192@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
6193@out_wrapper()
6194def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
6195    return torch.empty_like(
6196        self, dtype=torch.int32 if out_int32 else torch.int64
6197    ).contiguous()
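
# Shape sketch (editor's note): bucketize is elementwise over self, so the
# output matches self's shape, with int32 or int64 indices per out_int32:
#
#   boundaries = torch.empty(4, device="meta")
#   x = torch.empty(2, 3, device="meta")
#   assert torch.bucketize(x, boundaries).dtype == torch.int64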
6198
6199
6200@register_meta([aten.histc])
6201@out_wrapper()
6202def meta_histc(input, bins=100, min=0, max=0):
6203    fn_name = "histc()"
6204    if device_hint(input) == "cpu":
6205        torch._check(
6206            input.is_floating_point(),
6207            lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'",
6208        )
6209    torch._check(
6210        isinstance(bins, IntLike),
6211        lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}",
6212    )
6213    torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}")
6214    torch._check(
6215        isinstance(min, Number),
6216        lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}",
6217    )
6218    torch._check(
6219        isinstance(max, Number),
6220        lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
6221    )
6222    torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
6223    return torch.empty(bins, device=input.device, dtype=input.dtype)
6224
6225
6226@register_meta(
6227    [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
6228)
6229def meta_upsample_bimode2d_aa(
6230    input,
6231    output_size,
6232    align_corners,
6233    scales_h=None,
6234    scales_w=None,
6235):
6236    full_output_size = upsample_common_check(
6237        input.size(), output_size, num_spatial_dims=2
6238    )
6239    torch._check(
6240        input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
6241        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
6242    )
6243    return input.new_empty(full_output_size).to(
6244        memory_format=utils.suggest_memory_format(input)
6245    )
6246
6247
6248# From aten/src/ATen/native/cuda/AmpKernels.cu
6249@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
6250def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
6251    torch._check(
6252        found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
6253    )
6254    torch._check(
6255        inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
6256    )
6257    torch._check(
6258        found_inf.dtype.is_floating_point,
6259        lambda: "found_inf must be a float tensor.",
6260    )
6261    torch._check(
6262        inv_scale.dtype.is_floating_point,
6263        lambda: "inv_scale must be a float tensor.",
6264    )
6265
6266
6267# From aten/src/ATen/native/UnaryOps.cpp
6268@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
6269@out_wrapper()
6270def nan_to_num(self, nan=None, posinf=None, neginf=None):
6271    result_size = list(self.size())
6272    return self.new_empty(result_size)
6273
6274
6275@register_meta(torch.ops.aten.transpose_)
6276def transpose_(self, dim0, dim1):
6277    assert (
6278        self.layout
6279        not in {
6280            torch.sparse_csr,
6281            torch.sparse_csc,
6282            torch.sparse_bsr,
6283            torch.sparse_bsc,
6284        }
6285    ), f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
6286
6287    ndims = self.ndim
6288
6289    dim0 = maybe_wrap_dim(dim0, ndims)
6290    dim1 = maybe_wrap_dim(dim1, ndims)
6291
6292    if dim0 == dim1:
6293        return self
6294
6295    size = list(self.size())
6296    stride = list(self.stride())
6297
6298    stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
6299    size[dim0], size[dim1] = size[dim1], size[dim0]
6300
6301    self.as_strided_(size, stride)
6302    return self
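
# Worked example (editor's note): transpose_ only swaps metadata; a contiguous
# (2, 3) tensor with strides (3, 1) becomes size (3, 2) with strides (1, 3)
# after x.transpose_(0, 1), with no data movement.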
6303
6304
6305@register_meta(torch.ops.aten.t_)
6306def t_(self):
6307    ndims = self.ndim
6308
6309    if self.is_sparse:
6310        sparse_dim = self.sparse_dim()
6311        dense_dim = self.dense_dim()
6312        assert (
6313            sparse_dim <= 2 and dense_dim == 0
6314        ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions"  # noqa: B950
6315    else:
6316        assert (
6317            self.dim() <= 2
6318        ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
6319
6320    return transpose_(self, 0, 0 if ndims < 2 else 1)
6321
6322
6323@register_meta(aten.searchsorted)
6324@out_wrapper()
6325def meta_searchsorted(
6326    sorted_sequence,
6327    self,
6328    *,
6329    out_int32=False,
6330    right=False,
6331    side=None,
6332    sorter=None,
6333):
6334    dtype = torch.int32 if out_int32 else torch.int64
6335    if isinstance(self, torch.Tensor):
6336        return torch.empty_like(self, dtype=dtype).contiguous()
6337    else:  # Scalar
6338        return torch.empty((), dtype=dtype, device=sorted_sequence.device)
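
# Shape sketch (editor's note): a Tensor self yields an output of self's shape,
# while a Python scalar self yields a 0-dim tensor on sorted_sequence's device:
#
#   seq = torch.empty(8, device="meta")
#   assert torch.searchsorted(seq, 3.0).shape == ()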
6339
6340
6341def _check_for_unsupported_isin_dtype(dtype):
6342    torch._check(
6343        dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
6344        lambda: f"Unsupported input type encountered for isin(): {dtype}",
6345    )
6346
6347
6348@register_meta(aten._embedding_bag_backward)
6349def meta_embedding_bag_backward(
6350    grad,
6351    indices,
6352    offsets,
6353    offset2bag,
6354    bag_size,
6355    maximum_indices,
6356    num_weights,
6357    scale_grad_by_freq,
6358    mode,
6359    sparse,
6360    per_sample_weights,
6361    padding_idx=-1,
6362):
6363    if sparse:
6364        return aten._embedding_bag_sparse_backward(
6365            grad,
6366            indices,
6367            offsets,
6368            offset2bag,
6369            bag_size,
6370            num_weights,
6371            scale_grad_by_freq,
6372            mode,
6373            per_sample_weights,
6374            padding_idx,
6375        )
6376    else:
6377        return meta_embedding_bag_dense_backward(
6378            grad,
6379            indices,
6380            offset2bag,
6381            bag_size,
6382            maximum_indices,
6383            num_weights,
6384            scale_grad_by_freq,
6385            mode,
6386            per_sample_weights,
6387            padding_idx,
6388        )
6389
6390
6391@register_meta(aten._embedding_bag_dense_backward)
6392def meta_embedding_bag_dense_backward(
6393    grad,
6394    indices,
6395    offset2bag,
6396    bag_size,
6397    maximum_indices,
6398    num_weights,
6399    scale_grad_by_freq,
6400    mode,
6401    per_sample_weights,
6402    padding_idx=-1,
6403):
6404    torch._check(
6405        grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64],
6406        lambda: f"Unsupported input type encountered: {grad.dtype}",
6407    )
6408    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
6409    if mode == MODE_MAX:
6410        torch._check(maximum_indices is not None)
6411    index_grad_weight = grad.new_empty((num_weights, grad.size(1)))
6412    return index_grad_weight
6413
6414
6415@register_meta(aten._embedding_bag_per_sample_weights_backward)
6416def meta_embedding_bag_per_sample_weights_backward(
6417    grad,
6418    weight,
6419    indices,
6420    offsets,
6421    offset2bag,
6422    mode,
6423    padding_idx=-1,
6424):
6425    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
6426    embedding_features = grad.size(1)
6427    torch._check(
6428        mode == MODE_SUM,
6429        "embedding_bag_backward: per_sample_weights only supported for mode='sum'",
6430    )
6431    torch._check(grad.dim() == 2)
6432    torch._check(indices.dim() == 1)
6433    num_samples = indices.size(0)
6434    torch._check(weight.dim() == 2)
6435    torch._check(weight.size(1) == embedding_features)
6436    output = grad.new_empty((num_samples,))
6437    return output
6438
6439
6440@register_meta(aten.isin)
6441@out_wrapper()
6442def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
6443    torch._check(
6444        isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
6445        lambda: "At least one of elements and test_elements must be a Tensor.",
6446    )
6447    if not isinstance(elements, Tensor):
6448        elements = torch.tensor(elements, device=test_elements.device)
6449
6450    if not isinstance(test_elements, Tensor):
6451        test_elements = torch.tensor(test_elements, device=elements.device)
6452
6453    _check_for_unsupported_isin_dtype(elements.dtype)
6454    _check_for_unsupported_isin_dtype(test_elements.dtype)
6455    return torch.empty_like(elements, dtype=torch.bool)
6456
6457
6458@register_meta(aten.polygamma)
6459@out_wrapper()
6460def meta_polygamma(n: int, self: Tensor) -> Tensor:
6461    torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
6462    _, result_dtype = elementwise_dtypes(
6463        self,
6464        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
6465    )
6466    return torch.empty_like(self, dtype=result_dtype)
6467
6468
6469@register_meta(aten._local_scalar_dense)
6470def meta_local_scalar_dense(self: Tensor):
6471    raise RuntimeError("Tensor.item() cannot be called on meta tensors")
6472
6473
6474@register_meta(aten._jagged_to_padded_dense_forward.default)
6475def meta__jagged_to_padded_dense_forward(
6476    values: Tensor,
6477    offsets: List[Tensor],
6478    max_lengths: List[int],
6479    padding_value: float = 0.0,
6480):
6481    # only one jagged dim is supported for now
6482    assert len(offsets) == 1
6483    assert len(max_lengths) == 1
6484
6485    B = offsets[0].shape[0] - 1
6486    S = max_lengths[0]
6487    output_shape = (B, S, *values.shape[1:])
6488    return values.new_empty(output_shape)
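
# Worked example (editor's note): offsets [0, 2, 5] describe B = 2 bags, so for
# values of shape (5, 16) and max_lengths [4] the padded output is (2, 4, 16).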
6489
6490
6491@register_meta(aten._padded_dense_to_jagged_forward.default)
6492def meta__padded_dense_to_jagged_forward(
6493    padded: Tensor,
6494    offsets: List[Tensor],
6495    total_L: Optional[int] = None,
6496):
6497    # only one jagged dim is supported for now
6498    assert len(offsets) == 1
6499
6500    if not total_L:
6501        assert isinstance(padded, torch._subclasses.FakeTensor)
6502        shape_env = padded.fake_mode.shape_env
6503        assert shape_env is not None
6504        total_L = shape_env.create_unbacked_symint()
6505        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
6506            total_L, min=0, max=None
6507        )
6508
6509    output_shape = (total_L, *padded.shape[2:])
6510    return padded.new_empty(output_shape)
6511
6512
6513def _create_unary_float_meta_func(func):
6514    @register_meta(func)
6515    @out_wrapper()
6516    def _f(x):
6517        return elementwise_meta(
6518            x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6519        )
6520
6521    return _f
6522
6523
6524def _create_binary_float_meta_func(func):
6525    @register_meta(func)
6526    @out_wrapper()
6527    def _f(x, y):
6528        return elementwise_meta(
6529            x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6530        )
6531
6532    return _f
6533
6534
6535_create_unary_float_meta_func(aten.special_airy_ai)
6536_create_unary_float_meta_func(aten.special_bessel_y0)
6537_create_unary_float_meta_func(aten.special_bessel_y1)
6538_create_unary_float_meta_func(aten.special_modified_bessel_i0)
6539_create_unary_float_meta_func(aten.special_modified_bessel_i1)
6540_create_unary_float_meta_func(aten.special_modified_bessel_k0)
6541_create_unary_float_meta_func(aten.special_modified_bessel_k1)
6542_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
6543_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)
6544
6545
6546_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
6547_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
6548_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
6549_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
6550_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
6551_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
6552_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
6553_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
6554_create_binary_float_meta_func(aten.special_hermite_polynomial_h)
6555_create_binary_float_meta_func(aten.special_hermite_polynomial_he)
6556_create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
6557_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
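
# Promotion sketch (editor's note): INT_TO_FLOAT promotes integral inputs to the
# default float dtype, so e.g. with torch.float32 as the default:
#
#   x = torch.empty(4, dtype=torch.int64, device="meta")
#   assert torch.special.bessel_y0(x).dtype == torch.float32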
6558
6559
6560# We must also trigger meta registrations from PrimTorch ref
6561# decompositions
6562import torch._refs
6563import torch._refs.nn.functional
6564import torch._refs.special
6565
6566
6567def activate_meta():
6568    activate_meta_table = {}
6569
6570    # For a given op, we pick the most specific decomp function from
6571    # global_decomposition_table in the precedence order of meta > post_autograd > pre_autograd
6572    for type in ["meta", "post_autograd", "pre_autograd"]:
6573        registry = global_decomposition_table[type]
6574
6575        for opo in registry:
6576            if opo not in activate_meta_table:
6577                activate_meta_table[opo] = registry[opo]
6578
6579    for op_overload, fn in activate_meta_table.items():
6580        # Don't register meta for HigherOrderOp's decomp.
6581        # We can reconsider this in the future, but in general,
6582        # the way you do a meta for a HigherOrderOp is different from
6583        # OpOverload.
6584        if isinstance(op_overload, torch._ops.HigherOrderOperator):
6585            continue
6586        assert isinstance(op_overload, OpOverload)
6587
6588        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
6589
6590        if torch._C._dispatch_has_kernel_for_dispatch_key(
6591            op_overload.name(), "CompositeImplicitAutograd"
6592        ):
6593            # Internally, we shouldn't be registering meta kernels for any operators that
6594            # have CompositeImplicitAutograd kernels.
6595            # Instead, we should be letting those decompositions run, and writing meta kernels
6596            # only for the base operators.
6597            if op_overload in global_decomposition_table["meta"]:
6598                raise RuntimeError(
6599                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
6600                    "register meta function for it. Instead, we should let the decomposition run and write "
6601                    "meta kernels for the base operators."
6602                )
6603        elif op_overload.is_view:
6604            # Attempting to register a python meta kernel for a view operator.
6605            # We shouldn't do this, because the output will report as not having aliased storages.
6606            # All view ops have meta kernels in C++ today, so we should use those instead.
6607            pass
6608        elif (
6609            op_overload.name()
6610            in {
6611                "aten::empty_strided",  # causing infinite recursion, test_meta.py
6612                "aten::clone",  # causing infinite recursion
6613                "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
6614                "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
6615                "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
6616                "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
6617                "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
6618            }
6619        ):
6620            pass
6621        else:
6622            if "mkldnn::" in op_overload.name():
6623                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
6624            elif "mkl::" in op_overload.name():
6625                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
6626            elif "onednn::" in op_overload.name():
6627                _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
6628            elif "quantized::" in op_overload.name():
6629                _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
6630                    op_overload, fn
6631                )
6632            else:
6633                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
6634
6635
6636activate_meta()
6637