1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import builtins
4import collections
5import inspect
6import itertools
7import math
8import operator
9import warnings
10from collections.abc import Iterable
11from enum import Enum
12from functools import partial, reduce, singledispatch, wraps
13from typing import Any, Callable, Dict, List, Optional, overload, Sequence, Tuple, Union
14
15import torch
16import torch._prims as prims
17import torch._prims_common as utils
18import torch.utils._pytree as pytree
19from torch import sym_float, sym_int
20from torch._prims_common import (
21    BoolLike,
22    DeviceLikeType,
23    Dim,
24    DimsSequenceType,
25    DimsType,
26    dtype_to_type,
27    ELEMENTWISE_TYPE_PROMOTION_KIND,
28    FloatLike,
29    FloatWithoutSymFloat,
30    IntLike,
31    is_weakly_lesser_type,
32    Number,
33    NumberType,
34    RealNumberType,
35    REDUCTION_OUTPUT_TYPE_KIND,
36    ShapeType,
37    StrideType,
38    TensorLike,
39    TensorLikeType,
40    TensorOrNumberLikeType,
41    TensorSequenceType,
42)
43from torch._prims_common.wrappers import (
44    _maybe_convert_to_dtype,
45    _maybe_resize_out,
46    _safe_copy_out,
47    elementwise_type_promotion_wrapper,
48    elementwise_unary_scalar_wrapper,
49    out_wrapper,
50)
51
52
53# Experimental module containing prototype Python references for existing
54#   PyTorch operations.
55
56__all__ = [
57    #
58    # Elementwise Unary References
59    #
60    "abs",
61    "acos",
62    "acosh",
63    "asinh",
64    "asin",
65    "atan",
66    "atanh",
67    "bitwise_not",
68    # "cbrt",  # No corresponding torch operation
69    "ceil",
70    "conj_physical",
71    "cos",
72    "cosh",
73    "count_nonzero",
74    "deg2rad",
75    "digamma",
76    "erf",
77    "erfinv",
78    "erfc",
79    "exp",
80    "expm1",
81    "exponential",
82    "exp2",
83    "fill",
84    "fill_",
85    "floor",
86    "frac",
87    "geometric",
88    "index_add",
89    "index_copy",
90    "index_copy_",
91    "index_select",
92    "index_fill",
93    "index_fill_",
94    "isfinite",
95    "isinf",
96    "isposinf",
97    "isneginf",
98    "isnan",
99    "isreal",
100    "i0",
101    "lerp",
102    "lgamma",
103    "log",
104    "log1p",
105    "log2",
106    "log10",
107    "log_normal",
108    "log_softmax",
109    "mvlgamma",
110    "norm",
111    "normal",
112    "nan_to_num",
113    "neg",
114    "positive",
115    "rad2deg",
116    "reciprocal",
117    "round",  # TODO: model kwargs
118    "sigmoid",
119    "sgn",
120    "sign",
121    "signbit",
122    "sin",
123    "sinc",
124    "sinh",
125    "softmax",
126    "sqrt",
127    "square",
128    "tan",
129    "tanh",
130    "trace",
131    "trunc",
132    #
133    # Elementwise Binary References
134    #
135    "add",
136    "atan2",
137    "bitwise_and",
138    "bitwise_left_shift",
139    "bitwise_or",
140    "bitwise_right_shift",
141    "bitwise_xor",
142    "clamp_min",
143    "clamp_max",
144    "copysign",
145    "div",
146    "eq",
147    "float_power",
148    "floor_divide",
149    "fmax",
150    "fmin",
151    "fmod",
152    "gcd",
153    "ge",
154    "gt",
155    "heaviside",
156    "hypot",
157    "igamma",
158    "igammac",
159    "imag",
160    "isclose",
161    "lcm",
162    # 'ldexp',
163    "le",
164    "logaddexp",
165    "logaddexp2",
166    "logical_and",
167    "logical_not",
168    "logical_or",
169    "logical_xor",
170    "logsumexp",
171    "lt",
172    # 'max', # implement with reductions
173    "maximum",
174    # 'min', # implement with reductions
175    "minimum",
176    "mul",
177    "ne",
178    "nextafter",
179    # 'polar',  # abs, cos, sin
180    "pow",
181    "real",
182    "rpow",
183    "remainder",
184    "rsub",
185    "rtruediv",
186    "rfloordiv",
187    "sub",
188    "true_divide",
189    "trunc_divide",
190    "xlogy",
191    #
192    # Elementwise Ternary References
193    #
194    "addcdiv",
195    "addcmul",
196    "clamp",
197    #
198    # Conditional references
199    #
200    "masked_fill",
201    "masked_fill_",
202    "where",
203    #
204    # Data conversion and movement references
205    #
206    "clone",
207    "copy_to",  # TODO: add OpInfo (or implement .to)
208    "item",
209    "to",
210    #
211    # Reduction ops
212    #
213    "all",
214    "amax",
215    "amin",
216    "any",
217    "cumsum",
218    "cumprod",
219    "mean",
220    "dot",
221    "vdot",
222    "std",
223    "std_mean",
224    "sum",
225    "sum_to_size",
226    "prod",
227    "var",
228    "var_mean",
229    #
230    # Linear algebra ops
231    #
232    "addr",
233    #
234    # View & Shape Ops
235    #
236    "alias",
237    "alias_copy",
238    "atleast_1d",
239    "atleast_2d",
240    "atleast_3d",
241    "as_strided",
242    "as_strided_copy",
243    "as_strided_scatter",
244    "block_diag",
245    "broadcast_shapes",
246    "broadcast_tensors",
247    "broadcast_to",
248    "cat",
249    "chunk",
250    "column_stack",
251    "conj",
252    "constant_pad_nd",
253    "contiguous",
254    "diag_embed",
255    "diag",
256    "diagonal",
257    "diagonal_copy",
258    "diagonal_scatter",
259    "dsplit",
260    "dstack",
261    "expand",
262    "expand_as",
263    "expand_copy",
264    "flatten",
265    "flip",
266    "fliplr",
267    "flipud",
268    "hsplit",
269    "hstack",
270    "meshgrid",
271    "movedim",
272    "narrow",
273    "narrow_copy",
274    "native_group_norm",
275    "native_layer_norm",
276    "permute",
277    "ravel",
278    "repeat",
279    "reshape",
280    "reshape_as",
281    "roll",
282    "rot90",
283    "rsqrt",
284    "stack",
285    "swap_axes",  # alias for transpose
286    "squeeze",
287    "t",
288    "t_copy",
289    "T",
290    "take_along_dim",
291    "tensor_split",
292    "transpose",
293    "unfold",
294    "unfold_copy",
295    "unsqueeze",
296    "unsqueeze_copy",
297    "view",
298    "view_as",
299    "view_copy",
300    "vsplit",
301    "vstack",
302    "view_as_complex",
303    "unflatten",
304    "unbind",
305    "triu",
306    "tril",
307    "triu_indices",
308    "tril_indices",
309    #
310    # Tensor Creation
311    #
312    "arange",
313    "cauchy",
314    "empty",
315    "empty_like",
316    "empty_permuted",
317    "empty_strided",
318    "eye",
319    "full",
320    "full_like",
321    "linspace",
322    "logspace",
323    "new_empty",
324    "new_empty_strided",
325    "new_full",
326    "new_ones",
327    "new_zeros",
328    "ones",
329    "ones_like",
330    "randn",
331    "scalar_tensor",
332    "zero",
333    "zeros",
334    "zeros_like",
335    #
336    # Test-related functions
337    #
338    "allclose",
339    "equal",
340    #
341    # Statistical operations
342    #
343    "bucketize",
344    #
345    # Misc
346    #
347    "is_complex",
348    "renorm",
349    "stft",
350    "istft",
351]
352
353Tensor = torch.Tensor
354DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
355aten = torch._ops.ops.aten
356
357# Note that the docstrings for the public methods from this file are in
358# torch/_torch_docs.py
359
360
361def is_noncontiguous_supported(device):
362    return device is None or device.type != "hpu"
363
364
365def handle_noncontiguous_outputs(input_tlist, output):
366    device = None
367    from torch._subclasses.fake_tensor import FakeTensor
368
369    for t in input_tlist:
370        if isinstance(t, FakeTensor):
371            device = t.fake_device
372            break
373
374    if not is_noncontiguous_supported(device):
375        output = output.contiguous()
376
377    return output
378
379
380def _broadcast_shapes(*_shapes):
381    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
382
383    shapes = tuple(
384        (x,) if isinstance(x, IntLike) else x
385        for x in filter(lambda x: x is not None, _shapes)
386    )
387
388    # Short-circuits on no input
389    if len(shapes) == 0:
390        return None
391
392    # Type checking
393    # TODO: make common validations available as utils
394    for shape in shapes:
395        assert isinstance(shape, Sequence)
396
397    # Computes common shape
398    common_shape = [
399        1,
400    ] * reduce(max, (len(shape) for shape in shapes))
401    for arg_idx, shape in enumerate(shapes):
402        for idx in range(-1, -1 - len(shape), -1):
403            if guard_size_oblivious(common_shape[idx] == 1):
404                if shape[idx] < 0:
405                    raise ValueError(
406                        "Attempting to broadcast a dimension with negative length!"
407                    )
408                common_shape[idx] = shape[idx]
409            elif guard_size_oblivious(shape[idx] != 1):
410                if common_shape[idx] != shape[idx]:
411                    raise RuntimeError(
412                        f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
413                        f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
414                        f"should be broadcastable to {common_shape}"
415                    )
416
417    return common_shape
418
419
420def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
421    # Computes common shape
422    common_shape = _broadcast_shapes(
423        *(t.shape if isinstance(t, TensorLike) else None for t in args)
424    )
425
426    def __maybe_broadcast(x, shape):
427        if x is None:
428            return None
429        elif isinstance(x, Number):
430            return x
431        elif isinstance(x, TensorLike):
432            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
433                return x
434
435            if not utils.same_shape(x.shape, common_shape):
436                return x.expand(common_shape)
437
438            return x
439        else:
440            raise RuntimeError(
441                "Unexpected type when broadcasting: " + str(type(x)) + "!"
442            )
443
444    return tuple(__maybe_broadcast(x, common_shape) for x in args)
445
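# Illustrative examples of the helpers above (hypothetical inputs): _broadcast_shapes
# treats plain ints as one-dimensional shapes and skips None entries, while
# _maybe_broadcast expands tensor arguments to the common shape and passes Numbers
# (and, by default, CPU scalar tensors) through unchanged.
#
# >>> _broadcast_shapes((3, 1), (4,))
# [3, 4]
# >>> _broadcast_shapes((3, 1), None, 5)
# [3, 5]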
446
447# Utilities should come BEFORE this import
448from torch._decomp import register_decomposition
449
450
451#
452# Elementwise unary references
453#
454
455infer_aten_op = object()
456
457
458# TODO: add type promotion support
459def _make_elementwise_unary_reference(
460    type_promotion_kind,
461    *,
462    aten_op=infer_aten_op,
463    extra_meta=None,
464) -> Callable:
465    def inner(prim: Callable):
466        nonlocal aten_op
467
468        @wraps(prim)
469        @out_wrapper()
470        @elementwise_unary_scalar_wrapper
471        @elementwise_type_promotion_wrapper(
472            type_promoting_args=("a",),
473            type_promotion_kind=type_promotion_kind,
474        )
475        def _ref(a: TensorLikeType) -> TensorLikeType:
476            if extra_meta is not None:
477                extra_meta(a)
478
479            output = prim(a)
480            return handle_noncontiguous_outputs([a], output)
481
482        if aten_op is infer_aten_op:
483            aten_op = utils.get_aten_op(prim, prim.__name__)
484        if aten_op is not None:
485            register_decomposition(aten_op)(_ref)
486
487        return _ref
488
489    return inner
490
491
492def _make_alias(fn, name):
493    """
494    This function defines an alias of another function and sets its __name__ attribute.
495    It also sets its __module__ attribute to the module of the caller.
496    Note that when naively doing `alias = fn`, we have that `alias.__name__ == "fn"`, and
497    `alias.__module__ == fn.__module__`.
498    """
499
500    def _fn(*args, **kwargs):
501        return fn(*args, **kwargs)
502
503    _fn.__name__ = name
504    _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"]  # type: ignore[union-attr]
505    return _fn
506
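# Illustrative usage of _make_alias (the alias name here is hypothetical):
#
# >>> my_log = _make_alias(log, "my_log")
# >>> my_log.__name__
# 'my_log'
# >>> my_log.__module__ == __name__  # the caller's module, not log's
# True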
507
508def _make_inplace(fn):
509    """
510    Given a function with an out variant (i.e. one using `out_wrapper()`), it returns its in-place variant.
511    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
512    """
513
514    # nb. We name the first argument `a`, matching the name used by the unary references
515    @wraps(fn)
516    def _fn(a, *args, **kwargs):
517        return fn(a, *args, out=a, **kwargs)
518
519    inplace_name = f"{fn.__name__}_"
520    _fn.__name__ = inplace_name
521    _fn = register_decomposition(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]
522
523    # We access the __all__ attribute of the module where fn is defined
524    # There may be a cleaner way of doing this...
525    from inspect import getmodule
526
527    _all = getmodule(fn).__all__  # type: ignore[union-attr]
528    if inplace_name not in _all:
529        _all.append(inplace_name)
530    return _fn
531
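# Illustrative sketch of _make_inplace (the names below are hypothetical): given an
# out-wrapped reference `some_ref` whose ATen namespace has a matching in-place
# overload `aten.some_ref_`, the generated `some_ref_(a, ...)` simply calls
# `some_ref(a, ..., out=a)` and is appended to this module's __all__ as "some_ref_".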
532
533@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
534def abs(a):
535    return prims.abs(a)
536
537
538@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
539def acos(a):
540    return prims.acos(a)
541
542
543@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
544def acosh(a):
545    return prims.acosh(a)
546
547
548@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
549def asin(a):
550    return prims.asin(a)
551
552
553@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
554def asinh(a):
555    return prims.asinh(a)
556
557
558@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
559def atan(a):
560    return prims.atan(a)
561
562
563@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
564def atanh(a):
565    return prims.atanh(a)
566
567
568@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
569def bitwise_not(a):
570    return prims.bitwise_not(a)
571
572
573@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
574def ceil(a):
575    return prims.ceil(a)
576
577
578@register_decomposition(aten.is_complex)
579def is_complex(input: TensorLikeType):
580    return utils.is_complex_dtype(input.dtype)
581
582
583@register_decomposition(aten.conj_physical)
584@out_wrapper()
585def conj_physical(input: TensorLikeType):
586    if not utils.is_complex_dtype(input.dtype):
587        return input
588    return prims.conj_physical(input)
589
590
591@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
592def cos(a):
593    return prims.cos(a)
594
595
596@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
597def cosh(a):
598    return prims.cosh(a)
599
600
601@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
602def digamma(a):
603    return prims.digamma(a)
604
605
606@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
607def erf(a):
608    return prims.erf(a)
609
610
611@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
612def erfinv(a):
613    return prims.erf_inv(a)
614
615
616@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
617def erfc(a):
618    return prims.erfc(a)
619
620
621@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
622def exp(a):
623    return prims.exp(a)
624
625
626@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
627def expm1(a):
628    return prims.expm1(a)
629
630
631@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
632def exp2(a):
633    return prims.exp2(a)
634
635
636# Fill has its own implementation because it has a value parameter
637# CompositeImplicitAutograd - don't register decomp
638@out_wrapper()
639@elementwise_type_promotion_wrapper(
640    type_promoting_args=("a",),
641    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
642)
643def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
644    assert isinstance(a, TensorLike)
645    assert isinstance(value, Number)
646
647    python_type = utils.dtype_to_type(a.dtype)
648    if not utils.is_weakly_lesser_type(type(value), python_type):
649        msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
650        raise ValueError(msg)
651
652    return prims.fill(a, value)
653
654
655def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
656    r = prims.fill(a, value)
657    prims.copy_to(a, r)
658    return a
659
660
661@register_decomposition(aten.zero)
662@out_wrapper()
663def zero(input: TensorLikeType) -> TensorLikeType:
664    return torch.zeros_like(input)
665
666
667@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
668def floor(a):
669    return prims.floor(a)
670
671
672@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
673def frac(x: TensorLikeType) -> TensorLikeType:
674    trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
675    return torch.sub(x, trunc_x)
676
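# Worked example for frac (illustrative): frac(x) == x - trunc(x), with trunc
# expressed above as floor(|x|) * sign(x), so for x == -1.5 the truncation is
# -1.0 and the fractional part is -0.5:
#
# >>> torch.frac(torch.tensor(-1.5))
# tensor(-0.5000)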
677
678# imag does not use _make_elementwise_unary_reference because it does not support out
679def imag(a: TensorLikeType) -> TensorLikeType:
680    assert isinstance(a, TensorLike)
681    torch._check(
682        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
683    )
684    return prims.imag(a)
685
686
687@_make_elementwise_unary_reference(
688    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
689    aten_op=None,  # CompositeImplicitAutograd
690)
691def isfinite(a: TensorLikeType) -> TensorLikeType:
692    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
693        return prims.isfinite(a)
694
695    return ones_like(a, dtype=torch.bool)
696
697
698@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
699def isinf(a: TensorLikeType) -> TensorLikeType:
700    if utils.is_complex_dtype(a.dtype):
701        return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
702    if utils.is_float_dtype(a.dtype):
703        return torch.abs(a) == float("inf")
704    return torch.zeros_like(a, dtype=torch.bool)
705
706
707@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
708def isposinf(a: TensorLikeType) -> TensorLikeType:
709    torch._check(
710        not utils.is_complex_dtype(a.dtype),
711        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
712    )
713    if utils.is_float_dtype(a.dtype):
714        return a == float("inf")
715    return torch.zeros_like(a, dtype=torch.bool)
716
717
718@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
719def isneginf(a: TensorLikeType) -> TensorLikeType:
720    torch._check(
721        not utils.is_complex_dtype(a.dtype),
722        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
723    )
724    if utils.is_float_dtype(a.dtype):
725        return a == float("-inf")
726    return torch.zeros_like(a, dtype=torch.bool)
727
728
729@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
730def isnan(a: TensorLikeType) -> TensorLikeType:
731    return prims.ne(a, a)
732
733
734# alias
735mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]
736
737
738@_make_elementwise_unary_reference(
739    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
740    aten_op=None,  # CompositeImplicitAutograd
741)
742def isreal(a: TensorLikeType) -> TensorLikeType:
743    if utils.is_complex_dtype(a.dtype):
744        return torch.imag(a) == 0
745    return torch.ones_like(a, dtype=torch.bool)
746
747
748# TODO: if this belongs to torch.special, maybe it should be defined there and imported here?
749@_make_elementwise_unary_reference(
750    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
751)
752def i0(a):
753    return prims.bessel_i0(a)
754
755
756@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
757def lgamma(a):
758    return prims.lgamma(a)
759
760
761@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
762def log(a):
763    return prims.log(a)
764
765
766@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
767def log1p(a):
768    return prims.log1p(a)
769
770
771@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
772def log2(a):
773    return prims.log2(a)
774
775
776@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
777def log10(a):
778    return prims.log10(a)
779
780
781# CompositeImplicitAutograd - don't register decomp
782@out_wrapper()
783def log_softmax(
784    a: TensorLikeType,
785    dim: int,
786    dtype: Optional[torch.dtype] = None,
787) -> TensorLikeType:
788    result_dtype = dtype or a.dtype
789    computation_dtype = utils.get_computation_dtype(result_dtype)
790    a_ = _maybe_convert_to_dtype(a, computation_dtype)
791    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]
792
793
794@register_decomposition(aten.logsumexp)
795@out_wrapper()
796@elementwise_type_promotion_wrapper(
797    type_promoting_args=("self",),
798    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
799)
800def logsumexp(
801    self: TensorLikeType, dim: DimsType, keepdim: bool = False
802) -> TensorLikeType:
803    if not isinstance(dim, Iterable):
804        dim = (dim,)
805    if self.numel() == 0:
806        return torch.sum(torch.exp(self), dim, keepdim).log()
807    maxes = torch.amax(torch.real(self), dim, keepdim=True)
808    maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
809    maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
810    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
811    return result.log().add(maxes_squeezed)
812
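# nb. the reduction above relies on the standard log-sum-exp stabilization
#     logsumexp(x, dim) == m + log(sum(exp(x - m), dim)),  where m = amax(x, dim),
# which keeps exp() from overflowing for large inputs; infinite maxes are zeroed
# out first so the subtraction stays well defined.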
813
814@register_decomposition(aten.nan_to_num)
815@out_wrapper()
816def nan_to_num(
817    a: TensorLikeType,
818    nan: Optional[NumberType] = 0.0,
819    posinf: Optional[NumberType] = None,
820    neginf: Optional[NumberType] = None,
821) -> TensorLikeType:
822    assert isinstance(a, TensorLike)
823
824    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
825        return a.clone()
826
827    if nan is None:
828        nan = 0.0
829
830    if posinf is None:
831        posinf = torch.finfo(a.dtype).max
832
833    if neginf is None:
834        neginf = torch.finfo(a.dtype).min
835
836    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
837    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
838    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
839    return result
840
841
842def _neg_meta(a: TensorLikeType):
843    torch._check(
844        a.dtype is not torch.bool,
845        lambda: (
846            "Negation, the `-` operator, on a bool tensor is not supported. "
847            "If you are trying to invert a mask, use the `~` or `logical_not()` "
848            "operator instead."
849        ),
850    )
851
852
853@_make_elementwise_unary_reference(
854    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
855)
856def neg(a):
857    return prims.neg(a)
858
859
860# positive does not use _make_elementwise_unary_reference because it does not support out
861# CompositeImplicitAutograd - don't register decomp
862def positive(a: TensorLikeType) -> TensorLikeType:
863    assert isinstance(a, TensorLike)
864    if a.dtype is torch.bool:
865        msg = "positive does not support bool tensors."
866        raise RuntimeError(msg)
867    return a
868
869
870# real does not use _make_elementwise_unary_reference because it does not support out
871def real(a: TensorLikeType) -> TensorLikeType:
872    assert isinstance(a, TensorLike)
873    if utils.is_complex_dtype(a.dtype):
874        return prims.real(a)
875    return a
876
877
878@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
879def reciprocal(a):
880    return prims.reciprocal(a)
881
882
883@register_decomposition(aten.round)
884@out_wrapper()
885@elementwise_type_promotion_wrapper(
886    type_promoting_args=("a",),
887    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
888)
889def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
890    if decimals == 0:
891        return prims.round(a)
892    else:
893        ten_pow = 10**decimals
894        ten_neg_pow = 10 ** (-decimals)
895        return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
896
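# Worked example for the decimals path above (illustrative): rounding to
# decimals=2 scales by 10**2, rounds to the nearest integer, then scales back:
#
# >>> torch.round(torch.tensor(1.2345), decimals=2)  # round(123.45) / 100
# tensor(1.2300)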
897
898@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
899def rsqrt(a):
900    return prims.rsqrt(a)
901
902
903@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
904def sigmoid(a: TensorLikeType) -> TensorLikeType:
905    return true_divide(1, add(1, exp(neg(a))))
906
907
908@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
909def sgn(a):
910    if utils.is_complex_dtype(a.dtype):
911        a_abs = a.abs()
912        return torch.where(a_abs == 0, 0, a / a_abs)
913    else:
914        return a.sign()
915
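# Worked example for the complex branch of sgn (illustrative): sgn(z) == z / |z|
# for nonzero z, and 0 when z == 0:
#
# >>> torch.sgn(torch.tensor(3 + 4j))  # |z| == 5
# tensor(0.6000+0.8000j)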
916
917@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
918def sign(a):
919    return prims.sign(a)
920
921
922@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
923def signbit(a):
924    return prims.signbit(a)
925
926
927@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
928def sin(a):
929    return prims.sin(a)
930
931
932# Autograd note: This will give the right first derivative at zero (by chance),
933# but not the right second derivative
934@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
935def sinc(a):
936    a = math.pi * a
937    return torch.where(a == 0, 1, torch.sin(a) / a)
938
939
940@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
941def sinh(a):
942    return prims.sinh(a)
943
944
945@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
946def sqrt(a):
947    return prims.sqrt(a)
948
949
950@_make_elementwise_unary_reference(
951    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
952    aten_op=None,  # CompositeImplicitAutograd,
953)
954def square(a: TensorLikeType) -> TensorLikeType:
955    return mul(a, a)
956
957
958@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
959def tan(a):
960    return prims.tan(a)
961
962
963@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
964def tanh(a):
965    return prims.tanh(a)
966
967
968@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
969def trunc(a):
970    return prims.trunc(a)
971
972
973# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
974def view_as_complex(self: TensorLikeType) -> TensorLikeType:
975    input_dtype = self.dtype
976    torch._check(
977        utils.is_float_dtype(input_dtype),
978        lambda: f"view_as_complex is only supported for floating point "
979        f"tensors, but got a tensor of scalar type: {input_dtype}",
980    )
981    sizes = self.size()
982    torch._check(
983        len(sizes) != 0,
984        lambda: "Input tensor must have one or more dimensions",
985    )
986    torch._check(
987        sizes[-1] == 2,
988        lambda: "Tensor must have a last dimension of size 2",
989    )
990
991    old_strides = self.stride()
992    torch._check(
993        old_strides[-1] == 1,
994        lambda: "Tensor must have a last dimension with stride 1",
995    )
996    dims = old_strides[:-1]
997    torch._check(
998        builtins.all(stride % 2 == 0 for stride in dims),
999        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
1000    )
1001    torch._check(
1002        self.storage_offset() % 2 == 0,
1003        lambda: "Tensor must have a storage_offset divisible by 2",
1004    )
1005    return prims.view_element_type(
1006        self, utils.corresponding_complex_dtype(input_dtype)
1007    ).squeeze(-1)
1008
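# Illustrative example for view_as_complex: a real tensor whose last dimension
# has size 2 and stride 1 is reinterpreted as a complex tensor with that
# dimension dropped, e.g. a float32 tensor of shape (3, 2) becomes a complex64
# tensor of shape (3,), pairing each row as (real, imag).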
1009
1010def _make_elementwise_binary_reference(
1011    type_promotion_kind,
1012    aten_op=infer_aten_op,
1013    name=None,
1014    has_out=True,
1015    supports_lhs_python_scalar=True,
1016    supports_rhs_python_scalar=True,
1017    supports_two_python_scalars=False,
1018    should_register_decomposition=True,
1019) -> Callable:
1020    def inner(prim: Callable):
1021        nonlocal aten_op, name
1022        if name is None:
1023            name = prim.__name__
1024
1025        @wraps(prim)
1026        @elementwise_type_promotion_wrapper(
1027            type_promoting_args=("a", "b"),
1028            type_promotion_kind=type_promotion_kind,
1029        )
1030        def _ref(
1031            a: Union[Tensor, NumberType],
1032            b: Union[Tensor, NumberType],
1033        ) -> Tensor:
1034            torch._check_value(
1035                supports_lhs_python_scalar or not isinstance(a, Number),
1036                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
1037                "operation that does not accept lhs scalars!",
1038            )
1039            torch._check_value(
1040                supports_rhs_python_scalar or not isinstance(b, Number),
1041                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
1042                "operation that does not accept rhs scalars!",
1043            )
1044            torch._check_value(
1045                supports_two_python_scalars
1046                or not (isinstance(a, Number) and isinstance(b, Number)),
1047                lambda: f"{name}: Received two Number inputs to an elementwise binary operation!",
1048            )
1049            a, b = _maybe_broadcast(a, b)
1050            output = prim(a, b)
1051            return handle_noncontiguous_outputs([a, b], output)
1052
1053        if has_out:
1054            _ref = out_wrapper()(_ref)  # type: ignore[assignment]
1055
1056        _ref.__name__ = name
1057        if aten_op is infer_aten_op:
1058            aten_op = utils.get_aten_op(prim, name)
1059        if aten_op is not None and should_register_decomposition:
1060            register_decomposition(aten_op)(_ref)
1061
1062        return _ref
1063
1064    return inner
1065
1066
1067# Add has its own implementation because it has an alpha argument
1068@register_decomposition(aten.add)
1069@out_wrapper()
1070@elementwise_type_promotion_wrapper(
1071    type_promoting_args=("a", "b"),
1072    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1073)
1074def add(
1075    a: Union[TensorLikeType, NumberType],
1076    b: Union[TensorLikeType, NumberType],
1077    *,
1078    alpha: Optional[NumberType] = None,
1079):
1080    """
1081    Reference implementation of torch.add
1082    """
1083
1084    a, b = _maybe_broadcast(a, b)
1085
1086    if alpha is not None:
1087        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
1088        python_type = utils.dtype_to_type(dtype)
1089        if python_type != bool and not utils.is_weakly_lesser_type(
1090            type(alpha), python_type
1091        ):
1092            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
1093            raise ValueError(msg)
1094        if isinstance(b, TensorLike):
1095            b = prims.mul(b, alpha)
1096        else:
1097            b = b * alpha
1098
1099    output = prims.add(a, b)
1100    return handle_noncontiguous_outputs([a, b], output)
1101
1102
1103@_make_elementwise_binary_reference(
1104    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1105    supports_lhs_python_scalar=False,
1106    supports_rhs_python_scalar=False,
1107)
1108def atan2(a, b):
1109    return prims.atan2(a, b)
1110
1111
1112@_make_elementwise_binary_reference(
1113    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1114)
1115def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1116    return prims.bitwise_and(a, b)
1117
1118
1119@_make_elementwise_binary_reference(
1120    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1121)
1122def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1123    return prims.shift_left(a, b)
1124
1125
1126@_make_elementwise_binary_reference(
1127    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1128)
1129def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1130    return prims.bitwise_or(a, b)
1131
1132
1133@_make_elementwise_binary_reference(
1134    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1135)
1136def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1137    return prims.shift_right_arithmetic(a, b)
1138
1139
1140@_make_elementwise_binary_reference(
1141    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1142)
1143def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1144    return prims.bitwise_xor(a, b)
1145
1146
1147@_make_elementwise_binary_reference(
1148    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1149    supports_lhs_python_scalar=False,
1150)
1151def copysign(
1152    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1153):
1154    if isinstance(b, Number) and isinstance(a, Tensor):
1155        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1156    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
1157        msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
1158        raise RuntimeError(msg)
1159    return where(signbit(b), neg(abs(a)), abs(a))
1160
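# Worked example for copysign (illustrative): the sign is taken from b via
# signbit, so signed zeros are respected:
#
# >>> torch.copysign(torch.tensor(2.0), torch.tensor(-0.0))
# tensor(-2.)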
1161
1162# complex =  _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
1163
1164
1165@register_decomposition(aten.div)
1166@out_wrapper()
1167def div(
1168    a: Union[TensorLikeType, NumberType],
1169    b: Union[TensorLikeType, NumberType],
1170    *,
1171    rounding_mode: Optional[str] = None,
1172):
1173    """
1174    Reference implementation of torch.div
1175    """
1176    if rounding_mode is None:
1177        return true_divide(a, b)
1178    elif rounding_mode == "trunc":
1179        return trunc_divide(a, b)
1180    elif rounding_mode == "floor":
1181        return floor_divide(a, b)
1182    else:
1183        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
1184        raise ValueError(msg)
1185
1186
1187@_make_elementwise_binary_reference(
1188    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1189    supports_lhs_python_scalar=False,
1190)
1191def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1192    return prims.eq(a, b)
1193
1194
1195@_make_elementwise_binary_reference(
1196    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
1197)
1198def pow(
1199    a: Union[TensorLikeType, NumberType],
1200    b: Union[TensorLikeType, NumberType],
1201) -> TensorLikeType:
1202    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)
1203
1204    if isinstance(b, Number):
1205        if b == 1.0:
1206            return a.clone()  # type: ignore[return-value,union-attr]
1207        elif b == 2.0:
1208            return a * a  # type: ignore[return-value]
1209        elif b == 0.5:
1210            return torch.sqrt(a)  # type: ignore[arg-type]
1211    elif isinstance(a, Number):
1212        if a == 1.0:
1213            return torch.fill(b, True)
1214        if a == 2.0 and (
1215            utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
1216        ):
1217            return torch.exp2(b)
1218
1219    return prims.pow(a, b)
1220
1221
1222# Float power has its own implementation because it has unique type promotion.
1223# CompositeImplicitAutograd - don't register decomp
1224@out_wrapper()
1225def float_power(
1226    a: Union[TensorLikeType, NumberType],
1227    b: Union[TensorLikeType, NumberType],
1228) -> Tensor:
1229    if isinstance(a, Number) and isinstance(b, Number):
1230        raise ValueError(
1231            "Received two Number inputs to an elementwise binary operation!"
1232        )
1233
1234    # Handles type promotion
1235    dtype = utils.get_higher_dtype(a, b)
1236    assert dtype is not None
1237    if utils.is_complex_dtype(dtype):
1238        dtype = torch.complex128
1239    else:
1240        dtype = torch.float64
1241
1242    # Float power always casts its inputs to double (or complex double) to be
1243    # consistent with its C++ impl
1244    a = _maybe_convert_to_dtype(a, dtype)
1245    b = _maybe_convert_to_dtype(b, dtype)
1246
1247    a, b = _maybe_broadcast(a, b)
1248    return pow(a, b)
1249
1250
1251# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
1252# tensor(-0.250000000000000, dtype=torch.float64)
1253#
1254# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
1255# tensor(-0.001000000000000, dtype=torch.float64)
1256#
1257# Note: In this case, casting float to double will expand the float mantissa with zeros,
1258# while creating a double generates a distinct mantissa.
1259# >>> torch.tensor(-0.001).to(dtype=torch.float64)
1260# tensor(-0.001000000047497, dtype=torch.float64)
1261#
1262# Floor Division
1263# The difference is caused because torch.remainder(a, b) = -0.001.
1264#
1265# >>> torch.floor(torch.true_divide(a, b))
1266# tensor(250., dtype=torch.float64)
1267#
1268# >>> torch.div(a, b, rounding_mode='floor')
1269# tensor(249., dtype=torch.float64)
1270#
1271# Definition: a // b = (a - remainder(a, b)) / b
1272# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
1273# tensor(249., dtype=torch.float64)
1274#
1275# For reference, see CPython's implementation:
1276# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
1277
1278
1279@_make_elementwise_binary_reference(
1280    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1281    supports_two_python_scalars=True,
1282    should_register_decomposition=False,
1283)
1284def floor_divide(
1285    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1286):
1287    # Wrap scalars because some references only accept tensor arguments.
1288    if isinstance(a, Number) and isinstance(b, Number):
1289        a = scalar_tensor(a)
1290        b = scalar_tensor(b)
1291    elif isinstance(b, Number) and isinstance(a, Tensor):
1292        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1293    elif isinstance(a, Number) and isinstance(b, Tensor):
1294        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1295    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
1296        if a.device == torch.device("cpu"):
1297            msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
1298            raise RuntimeError(msg)
1299        else:
1300            b = prims.device_put(b, device=a.device)
1301
1302    assert isinstance(a, Tensor) and isinstance(b, Tensor)
1303    dtype = a.dtype
1304    if utils.is_float_dtype(dtype):
1305        return _floor_divide_float(a, b)
1306    elif utils.is_integer_dtype(dtype):
1307        return _floor_divide_integer(a, b)
1308    else:
1309        torch._check(False, lambda: f"{dtype} not supported for floor_divide")
1310
1311
1312def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
1313    a, b = _maybe_broadcast(a, b)
1314
1315    if not a.dtype.is_signed:
1316        return prims.div(a, b)
1317
1318    # Convert truncation to flooring:
1319    offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
1320    return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
1321
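# Worked example for the offset above (illustrative): for -7 // 2 the truncating
# prims.div gives -3; the operands' signs differ and fmod(-7, 2) == -1 != 0, so
# the offset of 1 corrects the result to the floored quotient -4.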
1322
1323def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
1324    mod = fmod(a, b)
1325    div = true_divide(sub(a, mod), b)
1326
1327    # Ensure that the remainder has the same sign as the denominator
1328    different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
1329    non_zero_remainder = ne(mod, 0)
1330    mask = bitwise_and(non_zero_remainder, different_signed_inputs)
1331    div = where(mask, sub(div, 1), div)
1332
1333    # Map quotient to nearest integer value
1334    floor_div = floor(div)
1335    mask = gt(sub(div, floor_div), 0.5)
1336    floor_div = where(mask, add(floor_div, 1), floor_div)
1337
1338    basic_div = true_divide(a, b)
1339    zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)
1340
1341    # If quotient is zero, copy signbit from true_divide quotient
1342    floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))
1343
1344    # If denominator is zero, then follow true_divide behavior
1345    return where(ne(b, 0), floor_div, basic_div)
1346
1347
1348@_make_elementwise_binary_reference(
1349    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1350    supports_lhs_python_scalar=False,
1351    supports_rhs_python_scalar=False,
1352)
1353def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1354    return prims.fmax(a, b)
1355
1356
1357@_make_elementwise_binary_reference(
1358    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1359    supports_lhs_python_scalar=False,
1360    supports_rhs_python_scalar=False,
1361)
1362def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1363    return prims.fmin(a, b)
1364
1365
1366@_make_elementwise_binary_reference(
1367    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1368    supports_lhs_python_scalar=False,
1369    supports_rhs_python_scalar=True,
1370)
1371def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1372    return prims.fmod(a, b)
1373
1374
1375@register_decomposition(aten.frexp)
1376@out_wrapper("mantissa", "exponent")
1377def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
1378    return torch.return_types.frexp(prims.frexp(self))
1379
1380
1381@_make_elementwise_binary_reference(
1382    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1383    supports_lhs_python_scalar=False,
1384    supports_rhs_python_scalar=False,
1385)
1386def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1387    return prims.gcd(a, b)
1388
1389
1390@_make_elementwise_binary_reference(
1391    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1392    supports_lhs_python_scalar=False,
1393)
1394def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1395    return prims.ge(a, b)
1396
1397
1398@_make_elementwise_binary_reference(
1399    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1400    supports_lhs_python_scalar=False,
1401)
1402def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1403    return prims.gt(a, b)
1404
1405
1406@_make_elementwise_binary_reference(
1407    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1408    supports_lhs_python_scalar=False,
1409    supports_rhs_python_scalar=False,
1410)
1411def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
1412    input_eq_zero = torch.eq(input, 0)
1413    input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
1414    zeros_and_ones = torch.where(input_lt_zero, 0, 1)
1415    output = torch.where(input_eq_zero, values, zeros_and_ones)
1416    return output
1417
1418
1419@_make_elementwise_binary_reference(
1420    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1421    supports_lhs_python_scalar=False,
1422    supports_rhs_python_scalar=False,
1423)
1424def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1425    return prims.hypot(a, b)
1426
1427
1428@_make_elementwise_binary_reference(
1429    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1430    supports_lhs_python_scalar=False,
1431    supports_rhs_python_scalar=False,
1432)
1433def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1434    return prims.igamma(a, b)
1435
1436
1437@_make_elementwise_binary_reference(
1438    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1439    supports_lhs_python_scalar=False,
1440    supports_rhs_python_scalar=False,
1441)
1442def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1443    return prims.igammac(a, b)
1444
1445
1446def _check_close_args(
1447    name: str,
1448    a: TensorLikeType,
1449    b: TensorLikeType,
1450    rtol: float,
1451    atol: float,
1452) -> None:
1453    torch._check_value(
1454        a.dtype == b.dtype,
1455        lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
1456    )
1457    torch._check(
1458        rtol >= 0,
1459        lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
1460    )
1461    torch._check(
1462        atol >= 0,
1463        lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
1464    )
1465
1466
1467# CompositeImplicitAutograd - don't register decomp
1468def isclose(
1469    a: TensorLikeType,
1470    b: TensorLikeType,
1471    rtol: float = 1e-05,
1472    atol: float = 1e-08,
1473    equal_nan: bool = False,
1474) -> TensorLikeType:
1475    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)
1476
1477    close = eq(a, b)
1478    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
1479        close = logical_or(close, logical_and(isnan(a), isnan(b)))
1480
1481    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
1482    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
1483    if atol == 0 and rtol == 0:
1484        return close
1485
1486    # Note [closeness error computation]
1487    # atol and rtol are provided as doubles, so the computation
1488    # rtol * other will produce a float or complex tensor.
1489    # When the difference (self - other) is compared to it then the
1490    # tensor representing the difference will also be cast to float or complex.
1491    # However, since (self - other) in uint8 is very likely to produce a
1492    # negative value, this moves the cast forward so the difference is
1493    # always computed in a float or complex type.
1494    # If the values of the integer tensors cannot be exactly represented
1495    # by the default scalar type then this may cause an incorrect result.
1496    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
1497        a = prims.convert_element_type(a, torch.get_default_dtype())
1498        b = prims.convert_element_type(b, torch.get_default_dtype())
1499
1500    allowed_error = add(atol, abs(mul(b, rtol)))
1501    actual_error = abs(sub(a, b))
1502
1503    # Computes finite closeness
1504    result = logical_or(
1505        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
1506    )
1507
1508    return result
1509
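# nb. the finite-closeness test above implements the documented torch.isclose
# criterion
#     |a - b| <= atol + rtol * |b|
# with exactly-equal values (including infinities) and, when equal_nan is set,
# matching NaNs handled separately via the exact-equality check beforehand.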
1510
1511@_make_elementwise_binary_reference(
1512    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1513    supports_lhs_python_scalar=False,
1514    supports_rhs_python_scalar=False,
1515)
1516def lcm(a: TensorLikeType, b: TensorLikeType):
1517    dtype = a.dtype
1518    # promoting to int32 to maintain 100% consistency with C++ and to
1519    # prevent overflow in case of int8 and int16
1520    promote_to_int = dtype in (torch.int8, torch.int16)
1521    if promote_to_int:
1522        a = prims.convert_element_type(a, torch.int32)
1523        b = prims.convert_element_type(b, torch.int32)
1524
1525    g = torch.gcd(a, b)
1526    # Avoid division by zero in case gcd(0, 0) == 0
1527    g = torch.where(g == 0, 1, g)
1528    res = torch.abs(prims.div(a, g) * b)
1529    return res if not promote_to_int else prims.convert_element_type(res, dtype)
1530
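# Worked example for lcm (illustrative):
#
# >>> torch.lcm(torch.tensor(4), torch.tensor(6))  # gcd == 2, so |4 / 2 * 6| == 12
# tensor(12)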
1531
1532@_make_elementwise_binary_reference(
1533    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1534    supports_lhs_python_scalar=False,
1535)
1536def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1537    return prims.le(a, b)
1538
1539
1540@_make_elementwise_binary_reference(
1541    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1542    supports_lhs_python_scalar=False,
1543    supports_rhs_python_scalar=False,
1544)
1545def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1546    # Nb. this implementation does not distribute the gradients evenly when a == b
1547    mask = torch.real(a) >= torch.real(b)
1548    max_ = torch.where(mask, a, b)
1549    min_ = torch.where(mask, b, a)
1550    inf_mask = torch.logical_and(
1551        torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
1552    )
1553    if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype):
1554        # are you wondering what this bunch of code is for? edge cases!
1555        neg_min_mask = torch.real(min_) < 0
1556        inf_vals = torch.where(
1557            neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_))
1558        )
1559        non_nan_vals = torch.where(
1560            inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_))
1561        )
1562        # the complex nan scalar below is not covered by torch.where's type hints yet
1563        nan_mask = torch.isnan(min_)
1564        return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals)  # type: ignore[call-overload]
1565    else:
1566        return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))
1567
1568
1569@_make_elementwise_binary_reference(
1570    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1571    supports_lhs_python_scalar=False,
1572    supports_rhs_python_scalar=False,
1573)
1574def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1575    torch._check(
1576        not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)),
1577        lambda: "logaddexp2 doesn't support complex dtypes",
1578    )
1579    # Nb. this implementation does not distribute the gradients evenly when a == b
1580    mask = a >= b
1581    max_ = torch.where(mask, a, b)
1582    min_ = torch.where(mask, b, a)
1583    inf_mask = torch.logical_and(torch.isinf(a), a == b)
1584    inv_log_2 = 1.0 / math.log(2)
1585    result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2
1586    return torch.where(inf_mask, a, result)
1587
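# nb. the expression above uses
#     log2(2**a + 2**b) == max + log2(1 + 2**(min - max))
# together with log2(y) == log(y) / log(2), which is where the 1 / log(2)
# factor comes from.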
1588
1589@_make_elementwise_binary_reference(
1590    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1591)
1592def logical_and(a: TensorLikeType, b: TensorLikeType):
1593    if not utils.is_boolean_dtype(a.dtype):
1594        a = a != 0
1595    if not utils.is_boolean_dtype(b.dtype):
1596        b = b != 0
1597    return a & b
1598
1599
1600@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
1601def logical_not(a: TensorLikeType):
1602    if not utils.is_boolean_dtype(a.dtype):
1603        return a == 0
1604    return ~a
1605
1606
1607@_make_elementwise_binary_reference(
1608    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1609)
1610def logical_or(a: TensorLikeType, b: TensorLikeType):
1611    if not utils.is_boolean_dtype(a.dtype):
1612        a = a != 0
1613    if not utils.is_boolean_dtype(b.dtype):
1614        b = b != 0
1615    return bitwise_or(a, b)
1616
1617
1618# TODO: skip unnecessary conversion of long to float
1619@_make_elementwise_binary_reference(
1620    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1621)
1622def logical_xor(a: TensorLikeType, b: TensorLikeType):
1623    if not utils.is_boolean_dtype(a.dtype):
1624        a = a != 0
1625    if not utils.is_boolean_dtype(b.dtype):
1626        b = b != 0
1627    return a ^ b
1628
1629
1630@_make_elementwise_binary_reference(
1631    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1632    supports_lhs_python_scalar=False,
1633)
1634def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1635    return prims.lt(a, b)
1636
1637
1638@_make_elementwise_binary_reference(
1639    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1640)
1641def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1642    return prims.maximum(a, b)
1643
1644
1645@_make_elementwise_binary_reference(
1646    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1647)
1648def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1649    return prims.minimum(a, b)
1650
1651
1652@_make_elementwise_binary_reference(
1653    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1654    supports_two_python_scalars=True,
1655)
1656def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1657    return prims.mul(a, b)
1658
1659
1660@_make_elementwise_binary_reference(
1661    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1662    supports_lhs_python_scalar=False,
1663)
1664def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1665    return prims.ne(a, b)
1666
1667
1668@_make_elementwise_binary_reference(
1669    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
1670    supports_lhs_python_scalar=False,
1671    supports_rhs_python_scalar=False,
1672)
1673def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1674    return prims.nextafter(a, b)
1675
1676
1677@_make_elementwise_binary_reference(
1678    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1679)
1680def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1681    return prims.remainder(a, b)
1682
1683
1684# reverse sub
1685@register_decomposition(aten.rsub)
1686@out_wrapper()
1687def rsub(
1688    a: Union[TensorLikeType, NumberType],
1689    b: Union[TensorLikeType, NumberType],
1690    alpha: NumberType = 1,
1691):
1692    if isinstance(a, Number):
1693        msg = "Received a Number for the first argument, but expected a Tensor"
1694        raise ValueError(msg)
1695
1696    return torch.sub(b, a, alpha=alpha)
1697
1698
1699# TODO: consider refactoring this with add impl
1700# sub has its own implementation because it has an alpha argument
1701@register_decomposition(aten.sub)
1702@out_wrapper()
1703@elementwise_type_promotion_wrapper(
1704    type_promoting_args=("a", "b"),
1705    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1706)
1707def sub(
1708    a: Union[TensorLikeType, NumberType],
1709    b: Union[TensorLikeType, NumberType],
1710    *,
1711    alpha: NumberType = 1,
1712):
1713    """
1714    Reference implementation of torch.sub
1715    """
1716
1717    a, b = _maybe_broadcast(a, b)
1718
1719    if isinstance(a, TensorLike) and isinstance(b, TensorLike):
1720        torch._check(
1721            not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype),
1722            lambda: (
1723                "Subtraction, the `-` operator, with two bool tensors is not supported. "
1724                "Use the `^` or `logical_xor()` operator instead."
1725            ),
1726        )
1727
1728    if alpha != 1:
1729        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
1730        python_type = utils.dtype_to_type(dtype)
1731        if not utils.is_weakly_lesser_type(type(alpha), python_type):
1732            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
1733            raise ValueError(msg)
1734        if isinstance(b, torch.Tensor):
1735            b = prims.mul(b, alpha)
1736        else:
1737            # Be careful not to use prims.mul if b is a scalar / symint.
1738            # prims.mul always returns a tensor,
1739            # which will mess with type promotion.
1740            b = b * alpha
1741
1742    output = prims.sub(a, b)
1743    return handle_noncontiguous_outputs([a, b], output)
1744
1745
1746@_make_elementwise_binary_reference(
1747    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1748    name="true_divide",
1749    aten_op=None,  # CompositeImplicitAutograd
1750    supports_two_python_scalars=True,
1751)
1752def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
1753    return prims.div(a, b)
1754
1755
1756@register_decomposition(aten.xlogy)
1757@out_wrapper()
1758@elementwise_type_promotion_wrapper(
1759    type_promoting_args=("a", "b"),
1760    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1761)
1762def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
1763    torch._check(
1764        isinstance(a, TensorLike) or isinstance(b, TensorLike),
1765        lambda: "Expected either argument a or b to be a Tensor",
1766    )
1767
1768    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
1769    if isinstance(b, TensorLike) and isinstance(a, Number):
1770        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1771    elif isinstance(a, TensorLike) and isinstance(b, Number):
1772        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1773
1774    # mypy: expected "Tensor"
1775    assert isinstance(a, TensorLike)
1776    assert isinstance(b, TensorLike)
1777    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
1778    return torch.where(torch.isnan(b), float("nan"), rhs)
1779
1780
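# Illustrative sketch, added for exposition (not part of the original file):
# `_example_xlogy` is a hypothetical helper that is never invoked here. The two
# `where`s above give xlogy its special cases: the result is 0 wherever a == 0
# (even when log(b) would be -inf), and NaN wherever b is NaN.
def _example_xlogy():
    a = torch.tensor([0.0, 1.0, 2.0])
    b = torch.tensor([0.0, 2.0, float("nan")])
    out = torch.xlogy(a, b)
    assert out[0] == 0.0  # 0 * log(0) -> 0, not NaN
    assert torch.isclose(out[1], torch.log(torch.tensor(2.0)))
    assert torch.isnan(out[2])  # NaN in b propagates

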
1781@_make_elementwise_binary_reference(
1782    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1783    aten_op=None,  # CompositeImplicitAutograd
1784    supports_two_python_scalars=True,
1785)
1786def trunc_divide(
1787    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1788):
1789    dtype = utils.get_dtype(a)
1790    if utils.is_integer_dtype(dtype):
1791        return prims.div(a, b)
1792
1793    return trunc(prims.div(a, b))
1794
1795
1796#
1797# Elementwise Ternary References
1798#
1799
1800
1801@register_decomposition(aten.addcdiv)
1802@out_wrapper()
1803@elementwise_type_promotion_wrapper(
1804    type_promoting_args=("self", "tensor1", "tensor2"),
1805    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
1806)
1807def addcdiv(
1808    self: TensorLikeType,
1809    tensor1: TensorLikeType,
1810    tensor2: TensorLikeType,
1811    *,
1812    value: NumberType = 1,
1813) -> TensorLikeType:
1814    """
1815    Reference implementation of torch.addcdiv
1816    """
1817    if value is not None:
1818        dtype = self.dtype  # no scalars allowed, see add
1819        python_type = utils.dtype_to_type(dtype)
1820        torch._check_value(
1821            utils.is_weakly_lesser_type(type(value), python_type),
1822            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
1823        )
1824
1825    return self + value * tensor1 / tensor2
1826
1827
1828@register_decomposition(aten.addcmul)
1829@out_wrapper()
1830@elementwise_type_promotion_wrapper(
1831    type_promoting_args=("self", "tensor1", "tensor2"),
1832    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1833)
1834def addcmul(
1835    self: TensorLikeType,
1836    tensor1: TensorLikeType,
1837    tensor2: TensorLikeType,
1838    *,
1839    value: NumberType = 1,
1840) -> TensorLikeType:
1841    """
1842    Reference implementation of torch.addcmul
1843    """
1844    if value is not None:
1845        dtype = self.dtype  # no scalars allowed, see add
1846        python_type = utils.dtype_to_type(dtype)
1847        torch._check_value(
1848            utils.is_weakly_lesser_type(type(value), python_type),
1849            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
1850        )
1851
1852    return self + value * tensor1 * tensor2
1853
1854
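# Illustrative sketch, added for exposition (not part of the original file):
# `_example_addcdiv_addcmul` is a hypothetical helper that is never invoked
# here. Both refs above fuse a scaled elementwise quotient/product into an add:
# addcdiv computes self + value * tensor1 / tensor2 and addcmul computes
# self + value * tensor1 * tensor2, with the same safe-cast check on `value`.
def _example_addcdiv_addcmul():
    self_ = torch.zeros(3)
    t1 = torch.tensor([1.0, 2.0, 3.0])
    t2 = torch.tensor([2.0, 2.0, 2.0])
    assert torch.allclose(torch.addcdiv(self_, t1, t2, value=2), torch.tensor([1.0, 2.0, 3.0]))
    assert torch.allclose(torch.addcmul(self_, t1, t2, value=2), torch.tensor([4.0, 8.0, 12.0]))

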
1855@register_decomposition(aten.clamp)
1856@out_wrapper()
1857@elementwise_type_promotion_wrapper(
1858    type_promoting_args=("a", "min", "max"),
1859    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
1860)
1861def clamp(
1862    a: TensorLikeType,
1863    min: Optional[TensorOrNumberLikeType] = None,
1864    max: Optional[TensorOrNumberLikeType] = None,
1865) -> TensorLikeType:
1866    # NOTE: the grad behavior of this `where`-based implementation is not consistent on `nan`
1867    if min is None and max is None:
1868        msg = "clamp called but both min and max are None!"
1869        raise ValueError(msg)
1870    if min is not None:
1871        a_isnan = torch.isnan(a)
1872        condition = torch.bitwise_or(torch.ge(a, min), a_isnan)  # type: ignore[arg-type]
1873        # we should also propagate `nan` coming from the boundaries. However, that's
1874        # not necessary since `ge` would already be `False` when either operand has
1875        # a `nan`. So the line below would be redundant:
1876        #   `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
1877        a = torch.where(condition, a, min)  # type: ignore[arg-type]
1878    if max is not None:
1879        a_isnan = torch.isnan(a)
1880        # same as above, no need to adjust `nan` from `max`
1881        condition = torch.bitwise_or(torch.le(a, max), a_isnan)  # type: ignore[arg-type]
1882        a = torch.where(condition, a, max)  # type: ignore[arg-type]
1883
1884    return a
1885
1886
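# Illustrative sketch, added for exposition (not part of the original file):
# `_example_clamp_nan` is a hypothetical helper that is never invoked here.
# The bitwise_or with isnan above makes NaNs in the input win over both
# bounds, matching eager clamp's NaN propagation.
def _example_clamp_nan():
    x = torch.tensor([float("nan"), -1.0, 5.0])
    out = torch.clamp(x, min=0.0, max=2.0)
    assert torch.isnan(out[0])  # NaN is propagated, not clamped
    assert torch.equal(out[1:], torch.tensor([0.0, 2.0]))

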
1887@register_decomposition(aten.clamp_min)
1888@out_wrapper()
1889def clamp_min(
1890    self: TensorLikeType,
1891    min: Optional[TensorOrNumberLikeType] = None,
1892) -> TensorLikeType:
1893    return torch.clamp(self, min=min)  # type: ignore[arg-type]
1894
1895
1896@register_decomposition(aten.clamp_max)
1897@out_wrapper()
1898def clamp_max(
1899    self: TensorLikeType,
1900    max: Optional[TensorOrNumberLikeType] = None,
1901) -> TensorLikeType:
1902    return torch.clamp(self, max=max)  # type: ignore[arg-type]
1903
1904
1905#
1906# Conditional references
1907#
1908
1909
1910# https://pytorch.org/docs/stable/generated/torch.where.html
1911# TODO: implement alternate where
1912@register_decomposition(aten.where)
1913@out_wrapper()
1914@elementwise_type_promotion_wrapper(
1915    type_promoting_args=("a", "b"),
1916    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
1917)
1918def where(
1919    pred: Tensor,
1920    a: Optional[TensorOrNumberLikeType] = None,
1921    b: Optional[TensorOrNumberLikeType] = None,
1922):
1923    """Reference implementation of torch.where"""
1924
1925    if a is None or b is None:
1926        raise NotImplementedError
1927
1928    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
1929    torch._check(
1930        pred.dtype is torch.bool,
1931        lambda: f"expected predicate to be bool, got {pred.dtype}",
1932    )
1933
1934    pred, a, b = _maybe_broadcast(pred, a, b)
1935    return prims.where(pred, a, b)
1936
1937
1938#
1939# Data Movement References
1940#
1941@register_decomposition(aten.clone)
1942@out_wrapper()
1943def clone(
1944    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
1945) -> TensorLikeType:
1946    result = prims.clone(a, memory_format=memory_format)
1947    return result
1948
1949
1950def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
1951    if not allow_cross_device and a.device != b.device:
1952        msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!"
1953        raise RuntimeError(msg)
1954
1955    return prims.copy_to(a, b)
1956
1957
1958@register_decomposition(aten.item)
1959def item(a: TensorLikeType) -> NumberType:
1960    if a.numel() != 1:
1961        msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
1962        raise ValueError(msg)
1963
1964    # NOTE: explicit conversion is necessary for bool!
1965    # See https://github.com/pytorch/pytorch/issues/78071
1966    number_type = utils.dtype_to_type(a.dtype)
1967    return number_type(prims.item(a))
1968
1969
1970# fast path when `to` returns an alias to input. This mimics the same function in aten
1971def _to_will_alias(
1972    a: TensorLikeType,
1973    device: Optional[DeviceLikeType] = None,
1974    dtype: Optional[torch.dtype] = None,
1975    copy: Optional[bool] = None,
1976    layout: Optional[torch.layout] = None,
1977    memory_format: Optional[torch.memory_format] = None,
1978    pin_memory: Optional[bool] = False,
1979    non_blocking: bool = False,  # not using non_blocking
1980) -> bool:
1981    return (
1982        not copy
1983        and (device is None or a.device == device)
1984        and (dtype is None or a.dtype == dtype)
1985        and (layout is None or a.layout == layout)
1986        # is_pinned issue #84925
1987        # and (pin_memory is None or pin_memory == a.is_pinned())
1988        and (
1989            memory_format is None
1990            or memory_format == torch.preserve_format
1991            or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
1992        )
1993    )
1994
1995
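# Illustrative sketch, added for exposition (not part of the original file):
# `_example_to_alias` is a hypothetical helper that is never invoked here.
# When nothing would change (same dtype/device/layout and no copy requested),
# `to` can return its input unchanged, which is the fast path the predicate
# above encodes; eager Tensor.to behaves the same way.
def _example_to_alias():
    x = torch.ones(2, 2, dtype=torch.float32)
    assert x.to(torch.float32) is x  # no-op: the input itself is returned
    y = x.to(torch.float64)  # dtype change: a new tensor is produced
    assert y.dtype == torch.float64 and y is not x

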
1996@singledispatch
1997def _to_dispatch(*args, **kwargs):
1998    raise NotImplementedError
1999
2000
2001@_to_dispatch.register
2002def _to_device(
2003    device: torch.device,
2004    dtype: torch.dtype,
2005    non_blocking: bool = False,
2006    copy: bool = False,
2007    memory_format: Optional[torch.memory_format] = None,
2008) -> Dict[str, Any]:
2009    kwargs = {
2010        "device": device,
2011        "dtype": dtype,
2012        "non_blocking": non_blocking,
2013        "copy": copy,
2014        "memory_format": memory_format,
2015    }
2016    return kwargs
2017
2018
2019@_to_dispatch.register
2020def _to_device_str(
2021    device: str,
2022    dtype: torch.dtype,
2023    non_blocking: bool = False,
2024    copy: bool = False,
2025    memory_format: Optional[torch.memory_format] = None,
2026) -> Dict[str, Any]:
2027    kwargs = {
2028        "device": torch.device(device),
2029        "dtype": dtype,
2030        "non_blocking": non_blocking,
2031        "copy": copy,
2032        "memory_format": memory_format,
2033    }
2034    return kwargs
2035
2036
2037@_to_dispatch.register
2038def _to_dtype(
2039    dtype: torch.dtype,
2040    non_blocking: bool = False,
2041    copy: bool = False,
2042    memory_format: Optional[torch.memory_format] = None,
2043) -> Dict[str, Any]:
2044    kwargs = {
2045        "dtype": dtype,
2046        "non_blocking": non_blocking,
2047        "copy": copy,
2048        "memory_format": memory_format,
2049    }
2050    return kwargs
2051
2052
2053@_to_dispatch.register
2054def _to_other(
2055    other: Tensor,
2056    non_blocking: bool = False,
2057    copy: bool = False,
2058    memory_format: Optional[torch.memory_format] = None,
2059) -> Dict[str, Any]:
2060    device = other.device
2061    dtype = other.dtype
2062    layout = other.layout
2063    # is_pinned issue #84925
2064    # pin_memory = other.is_pinned()
2065    kwargs = {
2066        "device": device,
2067        "dtype": dtype,
2068        "layout": layout,
2069        "non_blocking": non_blocking,
2070        "copy": copy,
2071        "memory_format": memory_format,
2072    }
2073    return kwargs
2074
2075
2076# remove to_kwargs that is already present in `a`
2077def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
2078    options_to_check = ["dtype", "device", "layout", "memory_format"]
2079    # the "device" option could be passed a str instead of a torch.device
2080    if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
2081        to_kwargs["device"] = torch.device(to_kwargs["device"])
2082
2083    for kw in options_to_check:
2084        if kw in to_kwargs:
2085            if (
2086                (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
2087                or (
2088                    kw == "device"
2089                    and to_kwargs[kw].type == a.device.type
2090                    and (
2091                        not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
2092                    )
2093                )
2094                or (
2095                    getattr(a, kw, None) == to_kwargs[kw]
2096                )  # this also handles {"memory_format": None}
2097            ):
2098                to_kwargs.pop(kw)
2099
2100
2101def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
2102    # handle dispatch via positional arguments
2103    if len(args) != 0:
2104        kwargs = _to_dispatch(*args, **kwargs)
2105
2106    # TODO: is_pinned is not currently supported in refs or fake_tensor
2107    # https://github.com/pytorch/pytorch/issues/84925
2108    assert "pin_memory" not in kwargs
2109    _canonicalize_to_arguments(a, kwargs)
2110
2111    if _to_will_alias(a, **kwargs):
2112        return a
2113
2114    copy = kwargs.pop("copy") if "copy" in kwargs else False
2115    non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False
2116
2117    # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
2118    if (
2119        (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
2120        and (not non_blocking)
2121        and ("memory_format" not in kwargs)
2122        and ("device" not in kwargs)
2123        and ("layout" not in kwargs)
2124        # is_pinned issue #84925
2125        # and ("pin_memory" not in kwargs)
2126    ):
2127        return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))
2128
2129    result = torch.empty_like(a, **kwargs)
2130    # TODO: non_blocking should be handled by `copy_to`
2131    copy_to(result, a)
2132    return result
2133
2134
2135#
2136# Reduction references
2137#
2138
2139
2140def _reduction(
2141    a: TensorLikeType,
2142    prim: Callable,
2143    *,
2144    has_identity: bool = True,
2145    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
2146    dims: Optional[DimsType] = None,
2147    keepdims: bool = False,
2148    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
2149    out: Optional[Tensor] = None,
2150    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
2151) -> TensorLikeType:  # it is usually SAME, but I want
2152    # ref writers to actually think about what to put here
2153    assert isinstance(a, TensorLike)
2154    if a.ndim > 64:
2155        raise RuntimeError(
2156            f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!"
2157        )
2158
2159    if out is not None:
2160        assert isinstance(out, TensorLike)
2161        if dtype is not None:
2162            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
2163            if dtype != out.dtype:
2164                raise RuntimeError(
2165                    "dtype argument and out dtype must match in reduction"
2166                )
2167    if not accepts_dim_tuple:
2168        assert dims is None or isinstance(dims, Dim)
2169    if isinstance(dims, Dim):
2170        dims = (dims,)  # type: ignore[assignment]
2171    dims = utils.reduction_dims(a.shape, dims)
2172    if not has_identity:
2173        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
2174        if not valid_shape:
2175            raise RuntimeError(
2176                "reducing over zero-size dimension for reduction operation without identity"
2177            )
2178    computation_dtype, result_dtype = utils.reduction_dtypes(
2179        a, output_dtype_kind, dtype
2180    )
2181    a = _maybe_convert_to_dtype(a, computation_dtype)  # type: ignore[method-assign]
2182    result = prim(a, dims)
2183    if keepdims:
2184        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
2185        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
2186        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
2187
2188    if out is not None:
2189        assert result_dtype is not None
2190        if dtype is not None and result_dtype != out.dtype:
2191            raise RuntimeError(
2192                "Expected the dtype of reduction result and out to match"
2193            )
2194        out = _maybe_resize_out(out, result.shape)
2195        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
2196
2197    if result.dtype != result_dtype and result_dtype is not None:
2198        result = prims.convert_element_type(result, result_dtype)
2199
2200    return result
2201
2202
2203def _make_copy_from_view(fn):
2204    """
2205    Given a view function (e.g. torch.diagonal), generates its copy variant (e.g. torch.diagonal_copy).
2206    """
2207    aten_fn = getattr(aten, fn.__name__)
2208    annotations = getattr(fn, "__annotations__", {})
2209    fn = out_wrapper()(aten_fn)
2210
2211    @wraps(fn)
2212    def _fn(*args, out=None, **kwargs):
2213        result = fn(*args, out=out, **kwargs)
2214        if out is not None:
2215            return result
2216
2217        return pytree.tree_map(
2218            lambda x: x.clone(memory_format=torch.contiguous_format),
2219            result,
2220        )
2221
2222    copy_name = f"{fn.__name__}_copy"
2223    _fn.__name__ = copy_name
2224    _fn.__annotations__.update(annotations)
2225    register_decomposition(getattr(aten, copy_name))(_fn)
2226    return _fn
2227
2228
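# Illustrative sketch, added for exposition (not part of the original file):
# `_example_view_vs_copy` is a hypothetical helper that is never invoked here.
# The generated *_copy variant materializes a contiguous clone instead of a
# view, so it does not alias the input the way the view op does.
def _example_view_vs_copy():
    x = torch.arange(9.0).reshape(3, 3)
    view = torch.diagonal(x)  # shares storage with x
    copy = torch.diagonal_copy(x)  # independent, contiguous result
    copy.add_(100.0)  # mutating the copy...
    assert view[0] == 0.0  # ...does not touch the original
    assert copy.is_contiguous()

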
2229@register_decomposition(aten.all)
2230@out_wrapper()
2231def all(
2232    a: TensorLikeType,
2233    dim: Optional[DimsType] = None,
2234    keepdim: bool = False,
2235) -> TensorLikeType:
2236    result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
2237
2238    if a.dtype == torch.uint8:
2239        result = result.to(dtype=torch.uint8)
2240
2241    return result
2242
2243
2244@register_decomposition(aten.any)
2245@out_wrapper()
2246def any(
2247    a: TensorLikeType,
2248    dim: Optional[DimsType] = None,
2249    keepdim: bool = False,
2250) -> TensorLikeType:
2251    a_ = _maybe_convert_to_dtype(a, torch.bool)
2252    if isinstance(dim, (list, tuple)) and len(dim) == 0:
2253        result = a_.clone()
2254    else:
2255        result = a_.sum(dim=dim, keepdim=keepdim).ne(False)
2256
2257    # Preserves uint8 -- probably a legacy mask thing
2258    if a.dtype is torch.uint8:
2259        return prims.convert_element_type(result, torch.uint8)
2260
2261    return result
2262
2263
2264@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
2265def sum(
2266    a: TensorLikeType,
2267    dim: Union[Optional[int], Optional[List[int]]] = None,
2268    keepdim: bool = False,
2269    *,
2270    dtype: Optional[torch.dtype] = None,
2271    out: Optional[Tensor] = None,
2272) -> TensorLikeType:
2273    if dtype is None:
2274        if out is not None:
2275            dtype = out.dtype
2276        elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
2277            dtype = torch.int64
2278        else:
2279            dtype = a.dtype
2280    # reduces over all dimensions if dim=() is passed
2281    if dim == () or dim == []:
2282        dim = None
2283    return _reduction(
2284        a,
2285        prims.sum,
2286        dims=dim,
2287        keepdims=keepdim,
2288        dtype=dtype,
2289        out=out,
2290        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
2291    )
2292
2293
2294def sum_to_size(
2295    a: Tensor,
2296    *shape,
2297) -> Tensor:
2298    shape = utils.extract_shape_from_varargs(shape, validate=False)
2299    torch._check(
2300        utils.is_expandable_to(shape, a.shape),
2301        lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
2302    )
2303    # In ATen, scalar tensors are sent through sum and the result is returned
2304    # type promoted
2305    if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
2306        return prims.view_of(a)
2307    leading_dims = a.ndim - len(shape)
2308    reduce_dims = tuple(range(leading_dims)) + tuple(
2309        i
2310        for i in range(leading_dims, len(shape))
2311        if shape[i - leading_dims] == 1 and a.shape[i] != 1
2312    )
2313    return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
2314
2315
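# Illustrative sketch, added for exposition (not part of the original file):
# `_example_sum_to_size` is a hypothetical helper that is never invoked here.
# Dimensions of the target shape that are 1 (and missing leading dimensions)
# are summed out, which is the reduce_dims computation above.
def _example_sum_to_size():
    x = torch.arange(6.0).reshape(2, 3)
    out = x.sum_to_size(1, 3)  # sum over dim 0, kept as size 1
    assert out.shape == (1, 3)
    assert torch.equal(out, torch.tensor([[3.0, 5.0, 7.0]]))

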
2316@register_decomposition(aten.prod)
2317def prod(
2318    a: TensorLikeType,
2319    dim: Union[Optional[int], Optional[List[int]]] = None,
2320    keepdim: bool = False,
2321    *,
2322    dtype=None,
2323    out: Optional[Tensor] = None,
2324) -> TensorLikeType:
2325    if dtype is None:
2326        if out is not None:
2327            dtype = out.dtype
2328        elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
2329            dtype = torch.int64
2330        else:
2331            dtype = a.dtype
2332    # reduces over all dimensions if dim=() is passed
2333    if dim == () or dim == []:
2334        dim = None
2335    return _reduction(
2336        a,
2337        prims.prod,
2338        dims=dim,
2339        keepdims=keepdim,
2340        dtype=dtype,
2341        out=out,
2342        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
2343    )
2344
2345
2346@register_decomposition(aten.amin)
2347def amin(
2348    a: TensorLikeType,
2349    dim: Optional[DimsType] = None,
2350    keepdim: bool = False,
2351    *,
2352    out: Optional[Tensor] = None,
2353) -> TensorLikeType:
2354    # reduces over all dimensions if dim=() is passed
2355    if dim == () or dim == []:
2356        dim = None
2357
2358    return _reduction(
2359        a,
2360        prims.amin,
2361        dims=dim,
2362        keepdims=keepdim,
2363        dtype=None,
2364        out=out,
2365        has_identity=False,
2366        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
2367    )
2368
2369
2370@register_decomposition(aten.amax)
2371def amax(
2372    a: TensorLikeType,
2373    dim: Optional[DimsType] = None,
2374    keepdim: bool = False,
2375    *,
2376    out: Optional[Tensor] = None,
2377) -> TensorLikeType:
2378    # reduces over all dimensions if dim=() is passed
2379    if dim == () or dim == []:
2380        dim = None
2381
2382    return _reduction(
2383        a,
2384        prims.amax,
2385        dims=dim,
2386        keepdims=keepdim,
2387        dtype=None,
2388        out=out,
2389        has_identity=False,
2390        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
2391    )
2392
2393
2394def _dim_var_dispatch(dim=None, unbiased=None):
2395    # There's the following overload of torch.var:
2396    # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
2397    # We need to explicitly convert bool dims to unbiased arg
2398    if unbiased is None and isinstance(dim, bool):
2399        unbiased = dim
2400        dim = None
2401    return dim, unbiased
2402
2403
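# Illustrative sketch, added for exposition (not part of the original file):
# `_example_var_bool_overload` is a hypothetical helper that is never invoked
# here. A bare bool passed positionally where `dim` would go selects the
# var(Tensor, bool unbiased) overload, which is what the dispatch helper above
# untangles.
def _example_var_bool_overload():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    assert torch.isclose(torch.var(x, False), torch.tensor(1.25))  # biased
    assert torch.isclose(torch.var(x, True), torch.tensor(5.0 / 3.0))  # unbiased

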
2404@register_decomposition(aten.var)
2405@out_wrapper()
2406def var(
2407    a: TensorLikeType,
2408    dim: Optional[DimsType] = None,
2409    unbiased: Optional[bool] = None,
2410    keepdim: bool = False,
2411    *,
2412    correction: Optional[NumberType] = None,
2413) -> TensorLikeType:
2414    dim, unbiased = _dim_var_dispatch(dim, unbiased)
2415    correction = utils.set_correction(unbiased, correction)
2416    # reduces over all dimensions if dim=() is passed
2417    if dim == () or dim == []:
2418        dim = None
2419
2420    result = _reduction(
2421        a,
2422        partial(prims.var, correction=correction),
2423        dims=dim,
2424        keepdims=keepdim,
2425        dtype=None,
2426        out=None,
2427        has_identity=True,
2428        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
2429    )
2430    return result
2431
2432
2433@register_decomposition(aten.std)
2434@out_wrapper()
2435def std(
2436    a: TensorLikeType,
2437    dim: Union[Optional[int], Optional[List[int]]] = None,
2438    unbiased: Optional[bool] = None,
2439    keepdim: bool = False,
2440    *,
2441    correction: Optional[NumberType] = None,
2442) -> TensorLikeType:
2443    dim, unbiased = _dim_var_dispatch(dim, unbiased)
2444    correction = utils.set_correction(unbiased, correction)
2445
2446    opmath_dtype, dtype = utils.reduction_dtypes(
2447        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
2448    )
2449    a = _maybe_convert_to_dtype(a, opmath_dtype)
2450    a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
2451    a_std = torch.sqrt(a_var)
2452    assert dtype is not None
2453    return _maybe_convert_to_dtype(a_std, dtype)
2454
2455
2456@register_decomposition(aten.mean)
2457def mean(
2458    a: TensorLikeType,
2459    dim: Optional[DimsType] = None,
2460    keepdim: bool = False,
2461    *,
2462    dtype=None,
2463    out=None,
2464) -> TensorLikeType:
2465    # reduces over all dimensions if dim=() is passed
2466    if dim == () or dim == []:
2467        dim = None
2468    orig_dtype = dtype
2469    if dtype is None:
2470        dtype = a.dtype
2471    # can't use out wrapper because of this argument
2472    torch._check(
2473        out is None or out.dtype == dtype,
2474        lambda: f"Expected out tensor to have dtype {dtype}, but got {out.dtype} instead",
2475    )
2476    result = _reduction(
2477        a,
2478        prims.sum,
2479        dims=dim,
2480        keepdims=keepdim,
2481        dtype=dtype,
2482        out=None,
2483        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
2484    )
2485    torch._check(
2486        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
2487        lambda: (
2488            f"mean(): could not infer output dtype. "
2489            f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
2490            f"a floating point or complex dtype. Got: {dtype}"
2491        ),
2492    )
2493    if isinstance(dim, Dim):
2494        dim = (dim,)  # type: ignore[assignment]
2495    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
2496    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
2497    result = true_divide(result, nelem)
2498    result_dtype = a.dtype if dtype is None else dtype
2499    result = _maybe_convert_to_dtype(result, result_dtype)  # type: ignore[method-assign]
2500    if out is not None:
2501        assert isinstance(out, TensorLike)
2502        out = _maybe_resize_out(out, result.shape)
2503        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
2504    return result
2505
2506
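# Illustrative sketch, added for exposition (not part of the original file):
# `_example_mean_dtype` is a hypothetical helper that is never invoked here.
# Unlike sum, mean refuses to silently promote integral inputs: the check
# above requires a floating point or complex output dtype.
def _example_mean_dtype():
    x = torch.tensor([1, 2, 3])
    try:
        torch.mean(x)  # integral input and no dtype: rejected
        raise AssertionError("expected mean() to reject integral input")
    except RuntimeError:
        pass
    assert torch.isclose(torch.mean(x, dtype=torch.float32), torch.tensor(2.0))

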
2507@register_decomposition(aten.std_mean)
2508@out_wrapper("out0", "out1")
2509def std_mean(
2510    a: TensorLikeType,
2511    dim: Optional[DimsType] = None,
2512    *,
2513    unbiased: Optional[bool] = None,
2514    keepdim: bool = False,
2515    correction: Optional[NumberType] = None,
2516):
2517    dim, unbiased = _dim_var_dispatch(dim, unbiased)
2518    correction = utils.set_correction(unbiased, correction)
2519    opmath_dtype, dtype = utils.reduction_dtypes(
2520        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
2521    )
2522    original_dtype = a.dtype
2523    a = _maybe_convert_to_dtype(a, opmath_dtype)
2524    a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
2525    a_std = torch.sqrt(a_var)
2526    assert dtype is not None
2527    return (
2528        _maybe_convert_to_dtype(a_std, dtype),
2529        _maybe_convert_to_dtype(a_mean, original_dtype),
2530    )
2531
2532
2533@register_decomposition(aten.var_mean)
2534@out_wrapper("out0", "out1")
2535def var_mean(
2536    a: TensorLikeType,
2537    dim: Optional[DimsType] = None,
2538    unbiased: Optional[bool] = None,
2539    keepdim: bool = False,
2540    *,
2541    correction: Optional[NumberType] = None,
2542):
2543    dim, unbiased = _dim_var_dispatch(dim, unbiased)
2544    v = var(a, dim, unbiased, keepdim, correction=correction)
2545    m = mean(a, dim, keepdim)
2546    return v, m
2547
2548
2549@register_decomposition(aten.addr)
2550@out_wrapper()
2551@elementwise_type_promotion_wrapper(
2552    type_promoting_args=("self", "vec1", "vec2"),
2553    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2554)
2555def addr(
2556    self: TensorLikeType,
2557    vec1: TensorLikeType,
2558    vec2: TensorLikeType,
2559    *,
2560    beta: NumberType = 1,
2561    alpha: NumberType = 1,
2562) -> TensorLikeType:
2563    torch._check(
2564        vec1.ndim == 1,
2565        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
2566    )
2567    torch._check(
2568        vec2.ndim == 1,
2569        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
2570    )
2571    for arg, arg_name in ((alpha, "alpha"), (beta, "beta")):
2572        if isinstance(arg, bool):
2573            torch._check(
2574                utils.is_boolean_dtype(self.dtype)
2575                and utils.is_boolean_dtype(vec1.dtype)
2576                and utils.is_boolean_dtype(vec2.dtype),
2577                lambda: f"Boolean {arg_name} only supported for Boolean results.",
2578            )
2579    self = self.expand(vec1.shape[0], vec2.shape[0])
2580    if utils.is_boolean_dtype(self.dtype):
2581        # Integers are accepted for booleans
2582        torch._check(
2583            is_weakly_lesser_type(type(beta), int),
2584            lambda: f"expected bool/int beta but got {type(beta)}",
2585        )
2586        torch._check(
2587            is_weakly_lesser_type(type(alpha), int),
2588            lambda: f"expected bool/int alpha but got {type(alpha)}",
2589        )
2590        if not beta:
2591            return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
2592        else:
2593            return torch.logical_or(
2594                self,
2595                torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
2596            )
2597    else:
2598        torch._check(
2599            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
2600            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
2601        )
2602        torch._check(
2603            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
2604            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
2605        )
2606        if beta == 0:
2607            # This means NaNs from self are dropped if beta is zero
2608            return alpha * torch.outer(vec1, vec2)
2609        else:
2610            return beta * self + alpha * torch.outer(vec1, vec2)
2611
2612
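# Illustrative sketch, added for exposition (not part of the original file):
# `_example_addr_beta_zero` is a hypothetical helper that is never invoked
# here. As the comment above notes, beta == 0 drops `self` entirely, so NaNs
# or infs in `self` do not propagate into the result.
def _example_addr_beta_zero():
    self_ = torch.full((2, 2), float("nan"))
    v1 = torch.tensor([1.0, 2.0])
    v2 = torch.tensor([3.0, 4.0])
    out = torch.addr(self_, v1, v2, beta=0, alpha=1)
    assert torch.equal(out, torch.outer(v1, v2))  # NaNs in self_ are ignored

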
2613# CompositeImplicitAutograd - don't register decomp
2614def atleast_1d(
2615    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
2616) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
2617    """Reference implementation of :func:`torch.atleast_1d`."""
2618    if not args and isinstance(arg, collections.abc.Sequence):
2619        args_ = arg
2620    else:
2621        assert not isinstance(arg, collections.abc.Sequence)
2622        args_ = (arg,) + args
2623    res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
2624    return res if len(res) > 1 else res[0]
2625
2626
2627# Helper function with assert to avoid MyPy error
2628# of incompatible type passed to unsqueeze
2629def _unsqueeze_atleast(
2630    at_least_fn: Callable, dim: int, arg: TensorLikeType
2631) -> TensorLikeType:
2632    arg_ = at_least_fn(arg)
2633    assert isinstance(arg_, TensorLike)
2634    return unsqueeze(arg_, dim)
2635
2636
2637# CompositeImplicitAutograd - don't register decomp
2638def atleast_2d(
2639    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
2640) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
2641    """Reference implementation of :func:`torch.atleast_2d`."""
2642    if not args and isinstance(arg, collections.abc.Sequence):
2643        args_ = arg
2644    else:
2645        assert not isinstance(arg, collections.abc.Sequence)
2646        args_ = (arg,) + args
2647    unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
2648    res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
2649    return res if len(res) > 1 else res[0]
2650
2651
2652# CompositeImplicitAutograd - don't register decomp
2653def atleast_3d(
2654    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
2655) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
2656    """Reference implementation of :func:`torch.atleast_3d`."""
2657    if not args and isinstance(arg, collections.abc.Sequence):
2658        args_ = arg
2659    else:
2660        assert not isinstance(arg, collections.abc.Sequence)
2661        args_ = (arg,) + args
2662    unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
2663    res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
2664    return res if len(res) > 1 else res[0]
2665
2666
2667def as_strided(
2668    a: TensorLikeType,
2669    size: ShapeType,
2670    stride: StrideType,
2671    storage_offset: Optional[int] = None,
2672) -> TensorLikeType:
2673    storage_offset_int = (
2674        storage_offset if storage_offset is not None else a.storage_offset()
2675    )
2676    return prims.as_strided(a, size, stride, storage_offset_int)
2677
2678
2679@register_decomposition(aten.as_strided_scatter)
2680@out_wrapper()
2681def as_strided_scatter(
2682    input: TensorLikeType,
2683    src: TensorLikeType,
2684    size: ShapeType,
2685    stride: StrideType,
2686    storage_offset: Optional[int] = None,
2687) -> TensorLikeType:
2688    storage_offset_int = 0 if storage_offset is None else storage_offset
2689    return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)
2690
2691
2692def broadcast_shapes(*shapes) -> ShapeType:
2693    return torch.Size(_broadcast_shapes(*shapes))
2694
2695
2696@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
2697@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
2698def broadcast_tensors(*tensors) -> List[TensorLikeType]:
2699    if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
2700        tensors = tensors[0]
2701    return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
2702
2703
2704# CompositeImplicitAutograd - don't register decomp
2705def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
2706    start = len(size) - len(a.shape)
2707    dims = tuple(range(start, len(a.shape) + start))
2708    return prims.broadcast_in_dim(a, size, dims)
2709
2710
2711@register_decomposition(aten.cat)
2712@out_wrapper()
2713@elementwise_type_promotion_wrapper(
2714    type_promoting_args=("tensors",),
2715    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
2716)
2717def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
2718    def cat_compute_output_memory_format(inputs):
2719        format = None
2720        for t in inputs:
2721            f = utils.suggest_memory_format(t)
2722            if f == torch.contiguous_format:
2723                return f
2724            if format is not None and format != f:
2725                return torch.contiguous_format
2726            format = f
2727        assert format is not None
2728        return format
2729
2730    if len(tensors) == 0:
2731        msg = "cat expects at least one tensor, but received zero!"
2732        raise ValueError(msg)
2733
2734    for tensor in tensors:
2735        assert isinstance(tensor, TensorLike)
2736
2737    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
2738
2739    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
2740
2741    # This is a bit tricky.  Naively, you would expect to just pick one
2742    # arbitrary tensor and check that all tensors match this tensor.  However,
2743    # there is legacy behavior which says that if you have a 1-D empty tensor
2744    # (0,), this is permissible.  So you can't assume that all the tensors
2745    # have same dimensionality, and you can't assume that the first tensor is
2746    # the correct stencil.
2747    #
2748    # We'll implement this in a few passes.  First, we will try to infer the
2749    # ndim of the cat output.  If this ndim != 1, then we know that all ndim =
2750    # 1 inputs must be empty, or are errors.  If this ndim == 1, then life
2751    # is easy (the legacy special case coincides with regular handling).
2752    #
2753    # NB: The regular implementation of cat just filters out empty inputs,
2754    # but we do it slightly different here for better handling for unbacked
2755    # SymInts
2756
2757    example = None
2758    for i, t in enumerate(tensors):
2759        if example is None:
2760            if t.ndim != 1:
2761                example = t
2762        else:
2763            if t.ndim != 1:
2764                torch._check(
2765                    t.ndim == example.ndim,
2766                    lambda: "Number of dimensions of tensors must match.  "
2767                    f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for "
2768                    f"tensor number {i} in the list",
2769                )
2770
2771    if example is None:
2772        # example is None if everything is 1-D.  If so, just arbitrarily pick
2773        # the first one
2774        example = tensors[0]
2775
2776    shape = example.shape
2777    filtered = []
2778    for tensor_idx, tensor in enumerate(tensors):
2779        if len(shape) != len(tensor.shape):
2780            assert tensor.ndim == 1  # we've already checked this above
2781            # Don't suggest the legacy behavior in the error message
2782            torch._check(
2783                # NB: it is not enough to simply assert that tensor.shape[0] == 0;
2784                # this MUST be true even under guard size oblivious.
2785                # Effectively, we must actually know that the shape is zero,
2786                # passing an unbacked SymInt which we will defer a runtime
2787                # assert on won't cut it.  This is a policy decision (size
2788                # oblivious semantics say that u0 tensors never are inferred
2789                # to be zero size, even if they must be that for the cat to go
2790                # through), and is load bearing for our Inductor lowerings
2791                # (which assume that size oblivious tests are OK to determine
2792                # if a shape is permissibly zero.)
2793                guard_size_oblivious(tensor.shape[0] == 0),
2794                lambda: f"Number of dimensions of tensors must match.  "
2795                f"Expected {example.ndim}-D tensors, but got 1-D for "
2796                f"tensor number {tensor_idx} in the list",
2797            )
2798        else:
2799            # Remove inputs that are 1-D, zero size
2800            if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0):
2801                continue
2802            # Don't bother checking size match, prims.cat will handle it
2803            filtered.append(tensor)
2804
2805    memory_format = cat_compute_output_memory_format(tensors)
2806
2807    if len(filtered) == 0:
2808        t = tensors[0]
2809
2810        # TODO: fix this to work with meta tensors
2811        try:
2812            # BUG? This looks like it wants to call builtins.any() but is
2813            # actually calling .any() (in this file). Changing to builtins.any()
2814            # causes tests to fail:
2815            # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \
2816            #   TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32
2817            requires_grad = bool(any(x.requires_grad for x in tensors))  # type: ignore[arg-type]
2818        except Exception:
2819            requires_grad = False  # type: ignore[assignment]
2820
2821        return empty(
2822            (0,),
2823            dtype=t.dtype,
2824            device=t.device,
2825            requires_grad=requires_grad,
2826            memory_format=memory_format,
2827        )
2828
2829    dim = utils.canonicalize_dim(filtered[0].ndim, dim)
2830    utils.validate_idx(filtered[0].ndim, dim)
2831
2832    return prims.cat(filtered, dim).clone(memory_format=memory_format)
2833
2834
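# Illustrative sketch, added for exposition (not part of the original file):
# `_example_cat_legacy_empty` is a hypothetical helper that is never invoked
# here. The legacy behavior documented above lets a 1-D tensor of shape (0,)
# ride along with higher-dimensional inputs; it is simply dropped.
def _example_cat_legacy_empty():
    a = torch.ones(2, 3)
    b = torch.empty(0)  # 1-D, zero elements
    out = torch.cat([a, b], dim=0)
    assert out.shape == (2, 3)
    assert torch.equal(out, a)

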
2835# CompositeImplicitAutograd - don't register decomp
2836@out_wrapper()
2837def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
2838    aligned_tensors = tuple(
2839        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
2840    )
2841    return cat(aligned_tensors, 1)
2842
2843
2844def conj(input: TensorLikeType) -> TensorLikeType:
2845    if not utils.is_complex_dtype(input.dtype):
2846        return input
2847    if input.is_sparse:
2848        return torch.conj_physical(input)
2849    return prims.conj(input)
2850
2851
2852# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
2853@register_decomposition(aten.constant_pad_nd)
2854@out_wrapper()
2855def constant_pad_nd(
2856    input: TensorLikeType, pad: List[int], value: NumberType = 0
2857) -> TensorLikeType:
2858    torch._check(
2859        len(pad) % 2 == 0,
2860        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
2861    )
2862
2863    input_sizes = input.shape
2864    l_inp = len(input_sizes)
2865
2866    l_pad = len(pad) // 2
2867    l_diff = l_inp - l_pad
2868
2869    torch._check(
2870        l_inp >= l_pad,
2871        lambda: "Length of pad should be no more than twice the number of "
2872        f"dimensions of the input. Pad length is {len(pad)} while the input has "
2873        f"{l_inp} dimensions.",
2874    )
2875
2876    c_input = input
2877    for i in range(l_diff, l_inp):
2878        pad_idx = 2 * (l_inp - i - 1)
2879        if pad[pad_idx] < 0:
2880            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])
2881
2882        if pad[pad_idx + 1] < 0:
2883            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
2884
2885    # If all the pads are negative we can return the result.
2886    # Avoid early exiting if all pads = 0 to prevent specialization on export.
2887    # During export, raw if statements are specialized on the input, meaning
2888    # that we lose a branch depending on the example input used to export.
2889    # Here, this is either the case where all pads = 0, or the case where at
2890    # least one pad > 0 and the rest are >= 0.
2891    # Avoiding the early exit when all pads = 0 ensures we can export
2892    # constant_pad_nd for cases when all pads >= 0.
2893    # Note: if any pads are negative, this code specializes due to the if statements above.
2894    if builtins.all(p < 0 for p in pad):
2895        return c_input.clone()
2896
2897    new_shape = list(input_sizes[:l_diff])
2898
2899    for i in range(l_pad):
2900        pad_idx = len(pad) - ((i + 1) * 2)
2901        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
2902        torch._check(
2903            new_dim > 0,
2904            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
2905            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
2906            f"which is invalid. Check dimension {l_diff + i} of your input.",
2907        )
2908        new_shape.append(new_dim)
2909
2910    memory_format = utils.suggest_memory_format(input)
2911    output = torch.empty(
2912        new_shape,
2913        dtype=input.dtype,
2914        device=input.device,
2915        requires_grad=input.requires_grad,
2916        memory_format=memory_format,
2917    )
2918
2919    if value == 0 and input.dtype == torch.bool:
2920        value = False
2921    # torch.fill isn't typed to allow complex values
2922    output = torch.fill(output, value)  # type: ignore[arg-type]
2923
2924    c_output = output
2925    for i in range(l_diff, l_inp):
2926        pad_idx = 2 * (l_inp - i - 1)
2927        if pad[pad_idx] >= 0:
2928            c_output = c_output.narrow(
2929                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
2930            )
2931        if pad[pad_idx + 1] >= 0:
2932            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])
2933
2934    prims.copy_to(c_output, c_input)
2935    return output
2936
2937
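# Illustrative sketch, added for exposition (not part of the original file):
# `_example_constant_pad_nd` is a hypothetical helper that is never invoked
# here. Pad entries come in (left, right) pairs starting from the last
# dimension, and negative entries narrow instead of pad, which is what the
# narrow() calls above implement.
def _example_constant_pad_nd():
    x = torch.arange(6.0).reshape(2, 3)
    # last dim: pad one zero on the left, drop one element on the right
    out = torch.constant_pad_nd(x, [1, -1], 0.0)
    assert out.shape == (2, 3)
    assert torch.equal(out, torch.tensor([[0.0, 0.0, 1.0], [0.0, 3.0, 4.0]]))

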
2938def contiguous(
2939    a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
2940) -> Tensor:
2941    torch._check(
2942        memory_format != torch.preserve_format,
2943        lambda: "preserve memory format is unsupported by the contiguous operator",
2944    )
2945
2946    if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
2947        return a
2948
2949    return torch.clone(a, memory_format=memory_format)
2950
2951
2952@out_wrapper()
2953def dstack(tensors: TensorSequenceType) -> TensorLikeType:
2954    torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
2955    aligned_tensors = atleast_3d(*tensors)
2956    return cat(aligned_tensors, 2)
2957
2958
2959@register_decomposition(aten.expand)
2960def expand(a: Tensor, *shape) -> Tensor:
2961    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
2962
2963    # NOTE: cannot use utils.extract_shape_from_varargs here
2964    # because that also validates the shape, but the shape
2965    # given to expand may be "invalid"
2966    if len(shape) == 1 and isinstance(shape[0], Sequence):
2967        shape = tuple(shape[0])
2968
2969    torch._check(
2970        len(shape) >= len(a.shape),
2971        lambda: "expand: the requested shape has too few dimensions!",
2972    )
2973
2974    offset = len(shape) - len(a.shape)
2975    shape_ = list(shape)
2976    for idx, x in enumerate(a.shape):
2977        offset_idx = idx + offset
2978        requested_length = shape[offset_idx]
2979        torch._check(
2980            guard_size_oblivious(requested_length == x)
2981            or guard_size_oblivious(x == 1)
2982            or requested_length == -1,
2983            lambda: f"expand: attempting to expand a dimension of length {x}!",
2984        )
2985
2986        shape_[offset_idx] = requested_length if requested_length != -1 else x
2987
2988    # At this point shape must be valid
2989    utils.validate_shape(shape_)
2990
2991    return prims.broadcast_in_dim(
2992        a, shape_, tuple(range(offset, len(a.shape) + offset))
2993    )
2994
2995
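# Illustrative sketch, added for exposition (not part of the original file):
# `_example_expand` is a hypothetical helper that is never invoked here. Only
# size-1 dimensions can be expanded, -1 keeps an existing size, and new
# leading dimensions may be prepended, per the checks above.
def _example_expand():
    x = torch.ones(3, 1)
    assert x.expand(-1, 4).shape == (3, 4)  # -1 keeps the existing size
    assert x.expand(2, 3, 4).shape == (2, 3, 4)  # leading dim prepended

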
2996# CompositeImplicitAutograd - don't register decomp
2997def expand_as(a: Tensor, b: Tensor) -> Tensor:
2998    return a.expand(b.shape)
2999
3000
3001def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
3002    if chunks <= 0:
3003        msg = f"Expected at least one chunk, but got {chunks}!"
3004        raise ValueError(msg)
3005
3006    dim = utils.canonicalize_dim(a.ndim, dim)
3007    length = a.shape[dim]
3008    chunk_size = math.ceil(length / chunks)
3009    full_chunks = math.floor(length / chunk_size)
3010    tail_chunk_size = length % chunk_size
3011
3012    result = []
3013    for i in range(full_chunks):
3014        result.append(narrow(a, dim, i * chunk_size, chunk_size))
3015
3016    if tail_chunk_size != 0:
3017        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))
3018
3019    return tuple(result)
3020
3021
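# Illustrative sketch, added for exposition (not part of the original file):
# `_example_chunk` is a hypothetical helper that is never invoked here. The
# chunk size is ceil(length / chunks), so the last chunk may be smaller and
# fewer than `chunks` pieces may be returned.
def _example_chunk():
    pieces = torch.arange(5).chunk(2)  # ceil(5 / 2) = 3
    assert [len(p) for p in pieces] == [3, 2]
    pieces = torch.arange(6).chunk(4)  # ceil(6 / 4) = 2, so only 3 chunks
    assert [len(p) for p in pieces] == [2, 2, 2]

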
3022# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
3023# a 0D tensor is flattened, in which case it's returned in 1D)
3024# CompositeImplicitAutograd - don't register decomp
3025def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
3026    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
3027    end_dim = utils.canonicalize_dim(a.ndim, end_dim)
3028
3029    # Short-circuits on no-op
3030    if start_dim == end_dim and a.ndim != 0:
3031        return a
3032
3033    # Tries to take a view
3034    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
3035    new_shape, new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
3036    if new_shape is not None:
3037        return prims.collapse_view(a, start_dim, end_dim)
3038
3039    # Makes a copy if it can't make a view
3040    return prims.collapse(a, start_dim, end_dim)
3041
3042
3043@register_decomposition(aten.flip)
3044@out_wrapper()
3045def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
3046    if not isinstance(dims, tuple) and not isinstance(dims, list):
3047        raise ValueError("dims has to be a sequence of ints")
3048    dims = utils.canonicalize_dims(a.ndim, dims)  # type: ignore[assignment]
3049    utils.validate_no_repeating_dims(dims)
3050    return prims.rev(a, dims)
3051
3052
3053# CompositeImplicitAutograd - don't register decomp
3054def fliplr(a: TensorLikeType) -> TensorLikeType:
3055    if a.ndim < 2:
3056        raise RuntimeError("Input must be >= 2-d.")
3057
3058    return flip(a, (1,))
3059
3060
3061# CompositeImplicitAutograd - don't register decomp
3062def flipud(a: TensorLikeType) -> TensorLikeType:
3063    if a.ndim < 1:
3064        raise RuntimeError("Input must be >= 1-d.")
3065
3066    return flip(a, (0,))
3067
3068
3069# CompositeImplicitAutograd - don't register decomp
3070def narrow(
3071    a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
3072) -> TensorLikeType:
3073    # Supports Tensor overload that was added for XLA:
3074    # https://github.com/pytorch/pytorch/issues/31558
3075    if isinstance(start, TensorLike):
3076        torch._check(
3077            start.dim() == 0 and utils.is_integer_dtype(start.dtype),
3078            lambda: "start must be a 0-dim integral Tensor.",
3079        )
3080        start = start.item()  # type: ignore[assignment]
3081    torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
3082    torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
3083    dim = utils.canonicalize_dim(a.ndim, dim)
3084    dim_length = a.size(dim)
3085    torch._check_with(
3086        IndexError,
3087        -dim_length <= start and start <= dim_length,  # type: ignore[arg-type]
3088        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
3089    )
3090    if start < 0:
3091        start = start + dim_length
3092    torch._check(
3093        start <= dim_length - length,  # type: ignore[arg-type]
3094        lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
3095    )
3096    return prims.slice_in_dim(a, start, start + length, axis=dim)
3097
3098
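# Illustrative sketch, added for exposition (not part of the original file):
# `_example_narrow` is a hypothetical helper that is never invoked here. A
# negative start wraps around like an index, and a 0-dim integer tensor is
# accepted for start, as handled above.
def _example_narrow():
    x = torch.arange(10)
    assert torch.equal(torch.narrow(x, 0, -3, 2), torch.tensor([7, 8]))
    assert torch.equal(torch.narrow(x, 0, torch.tensor(2), 3), torch.tensor([2, 3, 4]))

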
3099def _normalize(
3100    a: Tensor, norm_dims: DimsType, eps: float
3101) -> Tuple[Tensor, Tensor, Tensor]:
3102    """Computes mean and 1/std of a tensor along norm_dims.
3103
3104    Used as a helper function for normalization layers.
3105
3106    Args:
3107        a (Tensor): input tensor
3108        norm_dims (DimsType): dimensions to normalize over
3109        eps (float): epsilon for numerical stability
3110
3111    Returns:
3112        out (Tensor): normalized tensor.
3113        mean (Tensor): mean of the tensor along norm_dims.
3114        rstd (Tensor): 1/std of the tensor along norm_dims.
3115    """
3116    norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
3117    computation_dtype = utils.get_computation_dtype(a.dtype)
3118    a_acc = _maybe_convert_to_dtype(a, computation_dtype)
3119    assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
3120    biased_var, mean = torch.var_mean(
3121        a_acc, dim=norm_dims, unbiased=False, keepdim=True
3122    )
3123    rstd = torch.rsqrt(biased_var + eps)
3124    out = (a - mean) * rstd
3125    return out, mean, rstd
3126
3127
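# Illustrative sketch, added for exposition (not part of the original file):
# `_example_normalize` is a hypothetical helper that is never invoked here and
# calls the private helper above. Along the normalized dimensions the output
# has roughly zero mean and unit (biased) variance, up to eps.
def _example_normalize():
    x = torch.randn(4, 8)
    out, mean, rstd = _normalize(x, norm_dims=-1, eps=1e-5)
    assert mean.shape == (4, 1) and rstd.shape == (4, 1)
    assert torch.allclose(out.mean(dim=-1), torch.zeros(4), atol=1e-4)
    assert torch.allclose(out.var(dim=-1, unbiased=False), torch.ones(4), atol=1e-3)

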
3128# add all specified dimensions
3129def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
3130    for dim in sorted(dimensions):
3131        x = torch.unsqueeze(x, dim)
3132    return x
3133
3134
3135@register_decomposition(aten.native_group_norm.default)
3136def native_group_norm(
3137    input: Tensor,
3138    weight: Optional[Tensor],
3139    bias: Optional[Tensor],
3140    batch_size: int,
3141    num_channels: int,
3142    flattened_inner_size: int,
3143    num_groups: int,
3144    eps: float,
3145) -> Tuple[Tensor, Tensor, Tensor]:
3146    torch._check(
3147        input.ndim >= 2,
3148        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
3149    )
3150    torch._check(
3151        num_channels % num_groups == 0,
3152        lambda: "Expected number of channels in input to be divisible by num_groups, "
3153        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
3154    )
3155
3156    # num_channels / num_groups and flattened inner dimension are the reduction axes
3157    reduction_dims = [2, 3]
3158    input_reshaped = torch.reshape(
3159        input,
3160        [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
3161    )
3162    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
3163    out = out.view(input.shape)
3164
3165    broadcast_dims = [0] + list(range(2, input.ndim))
3166    unsqueeze_bias = None
3167    if bias is not None:
3168        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
3169    unsqueeze_weight = None
3170    if weight is not None:
3171        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
3172
3173    if unsqueeze_weight is not None:
3174        out = out * unsqueeze_weight
3175    if unsqueeze_bias is not None:
3176        out = out + unsqueeze_bias
3177
3178    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
3179    mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
3180    rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]
3181
3182    # remove broadcast dimensions from mean and rstd
3183    mean = torch.squeeze(mean, reduction_dims)
3184    rstd = torch.squeeze(rstd, reduction_dims)
3185    return (out, mean, rstd)
3186
3187
3188@register_decomposition(aten.native_layer_norm)
3189@out_wrapper("out0", "out1", "out2")
3190def native_layer_norm(
3191    input: Tensor,
3192    normalized_shape: ShapeType,
3193    weight: Optional[Tensor],
3194    bias: Optional[Tensor],
3195    eps: float,
3196) -> Tuple[Tensor, Tensor, Tensor]:
3197    normalized_ndim = len(normalized_shape)
3198    torch._check(
3199        normalized_ndim >= 1,
3200        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
3201        + "containing at least one element, but got normalized_shape = "
3202        + str(normalized_shape),
3203    )
3204    # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
3205    # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
3206    # therefore we use tuple(normalized_shape)
3207    torch._check(
3208        weight is None or weight.shape == tuple(normalized_shape),
3209        lambda: "Expected weight to be of same shape as normalized_shape, but got "
3210        + "weight of shape "
3211        + str(weight.shape)  # type: ignore[union-attr]
3212        + " and normalized_shape = "
3213        + str(normalized_shape),
3214    )
3215    torch._check(
3216        bias is None or bias.shape == tuple(normalized_shape),
3217        lambda: "Expected bias to be of same shape as normalized_shape, but got "
3218        + "bias of shape "
3219        + str(bias.shape)  # type: ignore[union-attr]
3220        + " and normalized_shape = "
3221        + str(normalized_shape),
3222    )
3223    torch._check(
3224        input.ndim >= normalized_ndim
3225        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
3226        lambda: "Given normalized_shape="
3227        + str(normalized_shape)
3228        + ", expected input with shape "
3229        + str(normalized_shape)
3230        + ", but got input of size "
3231        + str(input.shape),
3232    )
3233
3234    input = input.contiguous()
3235    if weight is not None:
3236        weight = weight.contiguous()
3237    if bias is not None:
3238        bias = bias.contiguous()
3239
3240    axis = input.ndim - normalized_ndim
3241    reduction_dims = list(range(axis, input.ndim))
3242    out, mean, rstd = _normalize(input, reduction_dims, eps)
3243
3244    if weight is None and bias is not None:
3245        out = out + bias
3246    elif weight is not None and bias is None:
3247        out = out * weight
3248    elif weight is not None and bias is not None:
3249        out = out * weight + bias
3250
3251    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
3252    if input.device.type == "cpu":
3253        mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
3254        rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]
3255    return (out, mean, rstd)
3256
3257
3258# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
3259# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
3260@register_decomposition(aten.permute)
3261def permute(a: TensorLikeType, *dims) -> TensorLikeType:
3262    _permutation = utils.canonicalize_dims(
3263        a.ndim, utils.extract_dims_from_varargs(dims)
3264    )
3265    return prims.transpose(a, _permutation)
3266
3267
3268@register_decomposition(aten.renorm)
3269@out_wrapper()
3270def renorm(
3271    input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType
3272) -> TensorLikeType:
3273    torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued")
3274    torch._check(p > 0, lambda: "renorm: non-positive norm not supported")
3275    torch._check(
3276        not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued"
3277    )
3278    torch._check(
3279        maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}"
3280    )
3281    ndim = input.ndim
3282    torch._check(
3283        ndim > 1,
3284        lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions",
3285    )
3286
3287    dim = utils.canonicalize_dim(ndim, dim)
3288    reduce_dims = list(range(ndim))
3289    del reduce_dims[dim]
3290
3291    # For half and bfloat16, calculate norm in float precision then cast
3292    # normalization factor to half
3293    acc_type = utils.get_computation_dtype(input.dtype)
3294    if acc_type != input.dtype:
3295        norm = torch.linalg.vector_norm(
3296            input, p, reduce_dims, keepdim=True, dtype=acc_type
3297        )
3298    else:
3299        norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True)
3300
3301    eps = 1e-7
3302    norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
3303    if acc_type != input.dtype:
3304        norm_factor = prims.convert_element_type(norm_factor, input.dtype)
3305    return (input * norm_factor).contiguous()
3306
3307
3308# CompositeImplicitAutograd - don't register decomp
3309@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
3310def stft(
3311    input: Tensor,
3312    n_fft: int,
3313    hop_length: Optional[int] = None,
3314    win_length: Optional[int] = None,
3315    window: Optional[Tensor] = None,
3316    center: bool = True,
3317    pad_mode: str = "reflect",
3318    normalized: bool = False,
3319    onesided: Optional[bool] = None,
3320    return_complex: Optional[bool] = None,
3321) -> Tensor:
3322    torch._check(
3323        window is None or window.device == input.device,
3324        lambda: (
3325            f"stft input and window must be on the same device but got self on {input.device}"
3326            + f" and window on {window.device}"  # type: ignore[union-attr]
3327        ),
3328    )
3329
3330    hop_length_ = hop_length if hop_length is not None else n_fft // 4
3331    win_length_ = win_length if win_length is not None else n_fft
3332
3333    if return_complex is None:
3334        return_complex_ = input.is_complex() or (
3335            window is not None and utils.is_complex_dtype(window.dtype)
3336        )
3337        torch._check(
3338            return_complex_,
3339            lambda: (
3340                "stft requires the return_complex parameter be given for real inputs, "
3341                + "and will further require that return_complex=True in a future PyTorch release."
3342            ),
3343        )
3344    else:
3345        return_complex_ = return_complex
3346
3347    torch._check(
3348        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
3349        lambda: "stft expected a tensor of floating point or complex values",
3350    )
3351    torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")
3352
3353    original_ndim = input.ndim
3354    if original_ndim == 1:
3355        input = input.unsqueeze(0)
3356
3357    if center:
3358        extra_dims = 3 - input.ndim
3359        pad_amount = n_fft // 2
3360        extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
3361        input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
3362        input = input.view(input.size()[extra_dims:])
3363
3364    batch = input.size(0)
3365    length = input.size(1)
3366    torch._check(
3367        0 < n_fft <= length,
3368        lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
3369    )
3370    torch._check(
3371        hop_length_ > 0,
3372        lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
3373    )
3374    torch._check(
3375        0 < win_length_ <= n_fft,
3376        lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
3377    )
3378    torch._check(
3379        window is None or window.shape == (win_length_,),
3380        lambda: (
3381            f"expected a 1D window tensor of size equal to win_length={win_length_}, "
3382            + f"but got window with size {window.shape}"  # type: ignore[union-attr]
3383        ),
3384    )
3385
3386    if win_length_ < n_fft:
3387        if window is None:
3388            window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
3389        left = (n_fft - win_length_) // 2
3390        window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])
3391
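    # Illustrative note (editor's addition, not in the original source): the unfold
    # below frames the (batch, length) signal into shape (batch, n_frames, n_fft),
    # where n_frames = (length - n_fft) // hop_length_ + 1 for the current (possibly
    # center-padded) length; each frame is then multiplied elementwise by the
    # (possibly padded) window.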
3392    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
3393    if window is not None:
3394        input = input * window
3395
3396    complex_fft = utils.is_complex_dtype(input.dtype)
3397    onesided = onesided if onesided is not None else not complex_fft
3398    norm = "ortho" if normalized else None
3399    if onesided:
3400        torch._check(
3401            not complex_fft,
3402            lambda: "Cannot have onesided output if window or input is complex",
3403        )
3404        out = torch.fft.rfft(input, dim=-1, norm=norm)
3405    else:
3406        out = torch.fft.fft(input, dim=-1, norm=norm)
3407
3408    out.transpose_(1, 2)
3409
3410    if original_ndim == 1:
3411        out = out.squeeze_(0)
3412
3413    return out if return_complex_ else torch.view_as_real(out)
3414
3415
3416# CompositeImplicitAutograd - don't register decomp
3417@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
3418def istft(
3419    input: Tensor,
3420    n_fft: int,
3421    hop_length: Optional[int] = None,
3422    win_length: Optional[int] = None,
3423    window: Optional[Tensor] = None,
3424    center: bool = True,
3425    normalized: bool = False,
3426    onesided: Optional[bool] = None,
3427    length: Optional[int] = None,
3428    return_complex=False,
3429) -> Tensor:
3430    torch._check(
3431        window is None or window.device == input.device,
3432        lambda: (
3433            f"istft input and window must be on the same device but got self on {input.device}"
3434            + f" and window on {window.device}"  # type: ignore[union-attr]
3435        ),
3436    )
3437
3438    hop_length_ = hop_length if hop_length is not None else n_fft // 4
3439    win_length_ = win_length if win_length is not None else n_fft
3440
3441    torch._check(
3442        utils.is_complex_dtype(input.dtype),
3443        lambda: (
3444            "istft input and window must be on the same device but got self on "
3445            + f"{input.device} and window on {window.device}"  # type: ignore[union-attr]
3446        ),
3447    )
3448    n_frames = input.size(-1)
3449    fft_size = input.size(-2)
3450
3451    expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
3452    torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
3453    torch._check(
3454        2 <= input.ndim <= 3,
3455        lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
3456    )
3457    onesided_ = onesided if onesided is not None else fft_size != n_fft
3458
3459    if onesided_:
3460        torch._check(
3461            n_fft // 2 + 1 == fft_size,
3462            lambda: (
3463                "istft expected the frequency dimension (3rd to the last) of the input tensor "
3464                + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
3465            ),
3466        )
3467    else:
3468        torch._check(
3469            n_fft == fft_size,
3470            lambda: (
3471                "istft expected the frequency dimension (3rd to the last) of the input tensor "
3472                + f"to match n_fft when onesided=False, but got {fft_size}"
3473            ),
3474        )
3475
3476    torch._check(
3477        0 < hop_length_ <= win_length_,
3478        lambda: "istft expected 0 < hop_length <= win_length",
3479    )
3480    torch._check(
3481        0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
3482    )
3483    torch._check(
3484        window is None or window.shape == (win_length_,),
3485        lambda: "Invalid window shape. window has to be 1D and length of `win_length`",
3486    )
3487
3488    if window is None:
3489        real_dtype = utils.corresponding_real_dtype(input.dtype)
3490        window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
3491    else:
3492        window_ = window
3493
3494    if win_length_ != n_fft:
3495        left = (n_fft - win_length_) // 2
3496        window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)
3497
3498    original_ndim = input.ndim
3499    if input.ndim == 2:
3500        input = input.unsqueeze(0)
3501
3502    input = input.transpose(1, 2)
3503    norm = "ortho" if normalized else None
3504    if return_complex:
3505        torch._check(
3506            not onesided_,
3507            lambda: "cannot have onesided output if window or input is complex",
3508        )
3509        input = torch.fft.ifft(input, dim=-1, norm=norm)
3510    else:
3511        torch._check(
3512            window is None or not utils.is_complex_dtype(window.dtype),
3513            lambda: "Complex windows are incompatible with return_complex=False",
3514        )
3515        if not onesided_:
3516            input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
3517        input = torch.fft.irfft(input, dim=-1, norm=norm)
3518
3519    assert input.size(2) == n_fft
3520
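    # Illustrative note (editor's addition, not in the original source): at this
    # point input has shape (batch, n_frames, n_fft). The two unfold_backward calls
    # below overlap-add the windowed frames and the squared-window envelope back to a
    # signal of length expected_output_signal_len; dividing by that envelope undoes
    # the analysis/synthesis windowing.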
3521    y_tmp = input * window_.view([1, 1, n_fft])
3522    y = aten.unfold_backward(
3523        y_tmp,
3524        input_sizes=(y_tmp.size(0), expected_output_signal_len),
3525        dim=1,
3526        size=n_fft,
3527        step=hop_length_,
3528    )
3529    window_envelop = aten.unfold_backward(
3530        window_.pow(2).expand((1, n_frames, n_fft)),
3531        input_sizes=(y_tmp.size(0), expected_output_signal_len),
3532        dim=1,
3533        size=n_fft,
3534        step=hop_length_,
3535    )
3536
3537    assert expected_output_signal_len == y.size(1)
3538    assert expected_output_signal_len == window_envelop.size(1)
3539
3540    start = n_fft // 2 if center else 0
3541    if length is not None:
3542        end = start + length
3543    elif center:
3544        end = expected_output_signal_len - n_fft // 2
3545    else:
3546        end = expected_output_signal_len
3547
3548    length = max(0, end - start)
3549    y = y.narrow(dim=1, start=start, length=length)
3550    window_envelop = window_envelop.narrow(dim=1, start=start, length=length)
3551
3552    y = y / window_envelop
3553    if original_ndim == 2:
3554        y = y.squeeze(0)
3555
3556    if end > expected_output_signal_len:
3557        warnings.warn(
3558            "The length of signal is shorter than the length parameter. Result is being "
3559            + "padded with zeros in the tail. Please check your center and hop_length settings"
3560        )
3561        y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
3562    return y
3563
3564
3565# Get the new shape and stride after applying unfold to an input tensor
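# Illustrative example (editor's addition, not in the original source): for
# a_shape=(10,), a_stride=(1,), dimension=0, size=3, step=2 this returns
# shape [4, 3] and strides [2, 1], matching torch.arange(10).unfold(0, 3, 2).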
3566def _get_unfold_shape_stride(
3567    a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
3568):
3569    a_ndim = len(a_shape)
3570    dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
3571    max_size = 1 if a_ndim == 0 else a_shape[dim]
3572    last_stride = 1 if a_ndim == 0 else a_stride[dim]
3573
3574    torch._check(
3575        size <= max_size,
3576        lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
3577    )
3578
3579    torch._check(
3580        step > 0,
3581        lambda: f"Step is {step} but must be > 0",
3582    )
3583
3584    shape = list(a_shape)
3585    strides = list(a_stride)
3586    shape.append(size)
3587    strides.append(last_stride)
3588    if dim < a_ndim:
3589        shape[dim] = (shape[dim] - size) // step + 1
3590        strides[dim] *= step
3591    return shape, strides
3592
3593
3594@register_decomposition(aten.repeat)
3595@out_wrapper()
3596def repeat(a: Tensor, *repeat_shape) -> Tensor:
3597    repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
3598    torch._check(
3599        len(repeat_shape) >= len(a.shape),
3600        lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
3601    )
3602
3603    if len(repeat_shape) == 0:
3604        return torch.clone(a)
3605
3606    num_new_dimensions = len(repeat_shape) - a.ndim
3607    padded_shape = [1] * num_new_dimensions
3608    for dim_size in a.shape:
3609        padded_shape.append(dim_size)
3610
3611    target_shape = tuple(
3612        padded_size * repeat_size
3613        for padded_size, repeat_size in zip(padded_shape, repeat_shape)
3614    )
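    # Illustrative example (editor's addition, not in the original source): for a
    # tensor of shape (2, 3) and repeat_shape (2, 1, 2), padded_shape is [1, 2, 3]
    # and target_shape is (2, 2, 6), matching the result shape of Tensor.repeat.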
3615
3616    # return an empty tensor if one of the repeat_shape dimensions is zero
3617    if 0 in repeat_shape:
3618        return torch.empty(
3619            target_shape,
3620            dtype=a.dtype,
3621            device=a.device,
3622            requires_grad=a.requires_grad,
3623            memory_format=utils.suggest_memory_format(a),
3624        )
3625
3626    urtensor_shape = target_shape
3627    urtensor_stride = utils.make_contiguous_strides_for(target_shape)
3628    for dim, dim_size in enumerate(padded_shape):
3629        # repeat each dimension by using unfold_copy operation
3630        urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
3631            urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
3632        )
3633
3634    # derive permute order by sorting urtensor strides
3635    enumerated_stride = list(enumerate(urtensor_stride))
3636    enumerated_stride.sort(key=operator.itemgetter(1), reverse=True)
3637    permute_order, sorted_stride = zip(*enumerated_stride)
3638
3639    # add new and expand dimensions according to urtensor
3640    repeat_xtensor = a.expand(urtensor_shape)
3641
3642    # clone tensor to concretize expanded dimensions
3643    cloned_result = torch.clone(repeat_xtensor)
3644
3645    # transpose axis so strides are in sorted order
3646    permuted_result = cloned_result.permute(permute_order)
3647
3648    # reshape to get contiguous tensor with correct target shape
3649    return permuted_result.reshape(target_shape)
3650
3651
3652def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
3653    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
3654
3655    # Creates a valid shape
3656    shape = utils.extract_shape_from_varargs(shape, validate=False)
3657    # Reshape may be given a shape with a -1 length
3658    # This indicates that the dimension's length should be inferred
3659    shape = utils.infer_size(shape, a.numel())
3660
3661    # Special-cases tensors with no elements
3662    if guard_size_oblivious(a.numel() == 0):
3663        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
3664
3665    # Special-cases reshaping zero dim tensors
3666    if a.ndim == 0:
3667        _a = a
3668        for length in shape:
3669            assert length == 1
3670            _a = unsqueeze(_a, -1)
3671        if _a is a:
3672            return prims.view_of(a)
3673        else:
3674            return _a
3675
3676    # Special-cases reshaping to zero dim tensors
3677    if len(shape) == 0:
3678        _a = a
3679        for length in a.shape:
3680            assert length == 1
3681            _a = squeeze(_a, -1)
3682        if _a is a:
3683            return prims.view_of(a)
3684        else:
3685            return _a
3686
3687    if a.is_contiguous():
3688        # Special-cases for nd_to_1d
3689        if len(shape) == 1 and a.ndim > 1:
3690            return torch.as_strided(a, [a.numel()], [1])
3691        # Special-cases for 1d_to_2d
3692        if len(shape) == 2 and a.ndim == 1:
3693            dim0 = shape[0]
3694            dim1 = shape[1]
3695            return torch.as_strided(a, [dim0, dim1], [dim1, 1])
3696
3697    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
3698
3699    # NOTE [Reshape Algorithm]
3700    # This algorithm works by attempting to greedily construct the desired dimensions in
3701    # the output shape, left to right. It does this by, conceptually, accumulating
3702    # dimensions of the original tensor, also left to right, until the dimension
3703    # can be constructed using prims.split_dim.
3704    # The algorithm also has special handling for tail squeezes/unsqueezes, like
3705    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
3706    #
3707    # This algorithm does not flatten the original tensor and then split dims as appropriate
3708    # because that would create copies more often than this algorithm. flatten is the only
3709    # operation below which can create a view or a copy, and while it prefers creating
3710    # views it may sometimes create a copy if the tensor's strides do not permit a view.
3711    # As a result, this algorithm tries to minimize flattening.
3712    #
3713    # Note that a better version of this algorithm may exist. Regions which could be
3714    # flattened without creating a copy can be identified in advance, and that might
3715    # allow fewer flatten calls or faster short-circuiting to make a copy.
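    # Illustrative walk-through (editor's addition, not in the original source):
    # reshaping a contiguous (6, 4) tensor to (3, 8) proceeds as follows. For the
    # first target length 3: 6 % 3 == 0, so no accumulation is needed and split_dim
    # turns (6, 4) into (3, 2, 4). For the second target length 8: the remaining
    # dims of lengths 2 and 4 are accumulated (2 * 4 == 8) and flattened, giving
    # (3, 8) with no further split required.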
3716    idx = 0
3717    a_ = a
3718    for length in shape:
3719        # Handles tail unsqueezes
3720        if idx >= a_.ndim:
3721            assert length == 1
3722            last_dim = a_.ndim - 1
3723            # NOTE: using split_dim instead of unsqueeze may seem silly here,
3724            # but it's necessary to get the strides correct
3725            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
3726            idx = idx + 1
3727            continue
3728
3729        # Skips dimensions that are already the correct length
3730        if guard_size_oblivious(length == a_.shape[idx]):
3731            idx = idx + 1
3732            continue
3733
3734        # Gathers enough original dimensions such that this new dimension can be created
3735        # Note that this accumulation will terminate because we've verified a and the shape
3736        # specify the same number of elements above
3737        accum = a_.shape[idx]
3738        end = idx
3739        while guard_size_oblivious(accum % length != 0):
3740            end = end + 1
3741            accum = accum * a_.shape[end]
3742        if end != idx:
3743            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
3744            # This flattening is why reshape sometimes creates a copy -- because flattening
3745            # may return a view of a copy
3746
3747            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
3748            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end)
3749            if new_shape is None:
3750                if allow_copy:
3751                    return prims.reshape(a, shape)
3752
3753                msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
3754                raise ValueError(msg)
3755
3756            a_ = flatten(a_, idx, end)
3757
3758        # Splits the (possibly flattened) dimension to create the desired dim length
3759        if guard_size_oblivious(accum != length):
3760            a_ = prims.split_dim(a_, idx, length)
3761
3762        idx = idx + 1
3763
3764    # Squeezes tail
3765    while idx < a_.ndim:
3766        torch._check(
3767            a_.shape[idx] == 1,
3768            lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
3769        )
3770        a_ = squeeze(a_, idx)
3771
3772    if a_ is a:
3773        return prims.view_of(a)
3774    else:
3775        return a_
3776
3777
3778# CompositeImplicitAutograd - don't register decomp
3779# NOTE: shape is a vararg because Tensor.reshape can be called as
3780# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)); the function call
3781# torch.reshape doesn't support unpacked shapes
3782def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
3783    return _reshape_view_helper(a, *shape, allow_copy=True)
3784
3785
3786# CompositeImplicitAutograd - don't register decomp
3787def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
3788    return self.reshape(other.size())
3789
3790
3791@register_decomposition(aten.roll)
3792@out_wrapper()
3793def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
3794    """Reference implementation of :func:`torch.roll`."""
3795    dims = utils.canonicalize_dims(a.ndim, dims)
3796    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
3797    if not isinstance(shifts, Iterable):
3798        shifts = (shifts,)
3799    if not isinstance(dims, Iterable):
3800        dims = (dims,)
3801
3802    # Avoid modulo by zero
3803    if a.numel() == 0:
3804        # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
3805        return a.clone()
3806
3807    if a.dim() == 0 and len(dims) > 0:
3808        raise IndexError(
3809            f"Dimension specified as {dims[0]} but tensor has no dimensions"
3810        )
3811
3812    len_shifts = len(shifts)
3813    len_dims = len(dims)
3814    if len_shifts != 1 or len_dims != 1:
3815        if len_shifts == 0:
3816            raise RuntimeError("`shifts` required")
3817        # Takes care of the case when dims is not specified (default)
3818        # By default, the tensor is flattened before shifting, after which the original shape is restored
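        # Illustrative example (editor's addition, not in the original source):
        # torch.roll(torch.arange(6).view(2, 3), 1) flattens to [0, 1, 2, 3, 4, 5],
        # rolls to [5, 0, 1, 2, 3, 4], and restores the shape, giving
        # [[5, 0, 1], [2, 3, 4]].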
3819        if len_dims == 0 and len_shifts == 1:
3820            return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
3821        if len_shifts != len_dims:
3822            raise RuntimeError(
3823                f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
3824            )
3825        assert len_dims > 1
3826        tail_shifts = shifts[1:]
3827        tail_dims = dims[1:]
3828        first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
3829        return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
3830
3831    # This path is taken when only one dimension is rolled
3832    # For example to get `first_dim_rolled` above
3833    dim = dims[0]
3834    size = a.shape[dim]
3835    start = (size - shifts[0]) % size
3836    idx = torch.arange(size, device=a.device)
3837    return a.index_select(dim, torch.fmod(start + idx, size))
3838
3839
3840@register_decomposition(aten.rot90)
3841@out_wrapper()
3842def rot90(
3843    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
3844) -> TensorLikeType:
3845    """Reference implementation of :func:`torch.rot90`."""
3846    if len(dims) != 2:
3847        raise RuntimeError(
3848            f"expected total rotation dims == 2, but got dims = {len(dims)}"
3849        )
3850    if a.ndim < 2:
3851        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")
3852
3853    # Do this after the initial checks to be compatible with the behavior in
3854    # core.
3855    dims = utils.canonicalize_dims(a.ndim, dims)
3856
3857    if dims[0] == dims[1]:
3858        raise RuntimeError(
3859            f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
3860        )
3861    k = k % 4  # Rotation direction is from the second towards the first axis for k < 0
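    # Illustrative example (editor's addition, not in the original source): for
    # a = [[1, 2], [3, 4]], k=1, dims=(0, 1), flipping along dims[1] gives
    # [[2, 1], [4, 3]] and the transpose yields [[2, 4], [1, 3]], i.e. a 90-degree
    # rotation from the first axis towards the second.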
3862    if k == 1:
3863        return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
3864    elif k == 2:
3865        return torch.flip(a, dims)
3866    elif k == 3:
3867        return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
3868    else:
3869        return a.clone(memory_format=torch.contiguous_format)
3870
3871
3872def _check_stack_inputs(tensors: TensorSequenceType) -> None:
3873    entry_shape = tensors[0].shape
3874    for i in range(1, len(tensors)):
3875        assert tensors[i].shape == entry_shape, (
3876            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
3877            f"and {tensors[i].shape} at entry {i}"
3878        )
3879
3880
3881@register_decomposition(aten.stack)
3882@out_wrapper()
3883def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
3884    assert len(tensors) > 0, "stack expects a non-empty TensorList"
3885    wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
3886    # Refs need sparse support to check other condition
3887    if wrapped_dim < tensors[0].ndim:  # and not tensors[0].is_sparse:
3888        _check_stack_inputs(tensors)
3889        result_sizes = list(tensors[0].shape)
3890        result_sizes.insert(wrapped_dim, len(tensors))
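        # Illustrative example (editor's addition, not in the original source):
        # stacking three (2, 3) tensors at dim=1 concatenates them into shape (2, 9),
        # which the view below reinterprets as (2, 3, 3), matching torch.stack.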
3891        out = torch.cat(tensors, wrapped_dim)
3892        return out.view(result_sizes)
3893
3894    # If dim == tensors[0].ndim, view cannot efficiently handle it
3895    return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)
3896
3897
3898# CompositeImplicitAutograd - don't register decomp
3899@out_wrapper()
3900def softmax(
3901    a: TensorLikeType,
3902    dim: int,
3903    dtype: Optional[torch.dtype] = None,
3904) -> TensorLikeType:
3905    result_dtype = dtype or a.dtype
3906    computation_dtype = utils.get_computation_dtype(result_dtype)
3907    a_ = _maybe_convert_to_dtype(a, computation_dtype)
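    # Editor's note (descriptive comment, not in the original source): subtracting
    # the per-dim max below is the standard numerically stable softmax trick; it does
    # not change the result, since exp(a - m) / sum(exp(a - m)) == exp(a) / sum(exp(a)).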
3908    if a.numel() == 0:
3909        a_exp = exp(a_)
3910    else:
3911        a_max = amax(a_, dim, keepdim=True)
3912        a_exp = exp(a_ - a_max)
3913    return _maybe_convert_to_dtype(
3914        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
3915    )  # type: ignore[return-value]
3916
3917
3918# CompositeImplicitAutograd - don't register decomp
3919@out_wrapper()
3920def hstack(tensors: TensorSequenceType) -> TensorLikeType:
3921    torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
3922    aligned_tensors = atleast_1d(*tensors)
3923    if aligned_tensors[0].ndim == 1:
3924        return cat(aligned_tensors, 0)
3925    return cat(aligned_tensors, 1)
3926
3927
3928# CompositeImplicitAutograd - don't register decomp
3929@out_wrapper()
3930def vstack(tensors: TensorSequenceType) -> TensorLikeType:
3931    torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
3932    aligned_tensors = atleast_2d(*tensors)
3933    return cat(aligned_tensors, 0)
3934
3935
3936# CompositeImplicitAutograd - don't register decomp
3937def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
3938    dim = utils.canonicalize_dim(a.ndim, dim)
3939    torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
3940    return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
3941
3942
3943@register_decomposition(aten.unbind)
3944def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
3945    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
3946
3947    dim = utils.canonicalize_dim(t.ndim, dim)
3948    torch._check_index(
3949        len(t.shape) > 0,
3950        lambda: "Dimension specified as 0 but tensor has no dimensions",
3951    )
3952    if guard_size_oblivious(t.shape[dim] == 0):
3953        return ()
3954    else:
3955        return tuple(
3956            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
3957        )
3958
3959
3960@out_wrapper()
3961def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
3962    return x.clone(memory_format=torch.contiguous_format).index_copy_(
3963        dim, index, tensor
3964    )
3965
3966
3967def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
3968    dim = utils.canonicalize_dims(x.ndim, dim)
3969    torch._check(
3970        index.ndim <= 1,
3971        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
3972    )
3973    # Treat scalars as elements of \R^1
3974    y = x.unsqueeze(0) if x.ndim == 0 else x
3975    idx = (slice(None),) * dim + (index,)
3976    y[idx] = tensor
3977    return x
3978
3979
3980@register_decomposition(aten.index_fill)
3981@out_wrapper()
3982def index_fill(
3983    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
3984):
3985    return _index_fill(x, dim, index, value, inplace=False)
3986
3987
3988@register_decomposition(aten.index_fill_)
3989def index_fill_(
3990    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
3991):
3992    return _index_fill(x, dim, index, value, inplace=True)
3993
3994
3995def _index_fill(
3996    x: TensorLike,
3997    dim: int,
3998    index: TensorLike,
3999    value: Union[NumberType, TensorLike],
4000    *,
4001    inplace: bool,
4002):
4003    torch._check(
4004        index.ndim <= 1,
4005        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
4006    )
4007    if isinstance(value, TensorLike):
4008        torch._check(
4009            value.ndim == 0,
4010            lambda: "Only supports 0-dimensional value tensor. "  # type: ignore[union-attr]
4011            f"Got a tensor with {value.ndim} dimensions.",
4012        )  # type: ignore[arg-type]
4013    else:
4014        value = torch.scalar_tensor(
4015            value, dtype=x.dtype, layout=x.layout, device=x.device  # type: ignore[arg-type]
4016        )
4017
4018    # index_copy has some unnecessary preconditions when x is a scalar. We do this to work around them
4019    zero_dim = x.ndim == 0
4020    y = x.unsqueeze(0) if zero_dim else x
4021    # index_copy does not broadcast on value so we have to do it manually
4022    shape = list(y.shape)
4023    shape[dim] = index.numel()
4024    value = value.expand(shape)
4025    index_copy = Tensor.index_copy_ if inplace else torch.index_copy
4026    out = index_copy(y, dim, index, value)  # type: ignore[operator]
4027    if inplace:
4028        return x
4029    else:
4030        if zero_dim:
4031            # The clone is necessary so that it returns a fresh tensor rather than a view
4032            out = out.squeeze(0).clone()
4033        # index_fill preserves the strides. index_copy always returns contiguous tensors
4034        if out.stride() != x.stride():
4035            new_out = torch.empty_like(x)
4036            new_out.copy_(out)
4037            out = new_out
4038        return out
4039
4040
4041@out_wrapper()
4042def index_add(
4043    x: TensorLike,
4044    dim: int,
4045    index: TensorLike,
4046    tensor: TensorLike,
4047    *,
4048    alpha: NumberType = 1,
4049):
4050    # index_add always returns a new contiguous tensor
4051    return x.clone(memory_format=torch.contiguous_format).index_add_(
4052        dim, index, tensor, alpha=alpha  # type: ignore[arg-type]
4053    )
4054
4055
4056@register_decomposition(aten.index_select)
4057@out_wrapper()
4058def index_select(x: TensorLike, dim: int, index: TensorLike):
4059    dim = utils.canonicalize_dims(x.ndim, dim)
4060    torch._check(
4061        index.ndim <= 1,
4062        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
4063    )
4064    if index.ndim == 0:
4065        index = index.unsqueeze(0)
4066    if x.ndim == 0:
4067        # Treat scalars as elements of \R^1
4068        # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
4069        return torch.empty_like(x).index_copy(0, index, x.expand_as(index))
4070
4071    idx = (slice(None),) * dim + (index,)
4072    return x[idx]
4073
4074
4075@register_decomposition(aten.squeeze.dims)
4076def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
4077    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4078
4079    if dim is None:
4080        dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
4081        return prims.squeeze(a, dims) if dims else prims.view_of(a)
4082
4083    ndim = a.ndim
4084    dim = utils.canonicalize_dims(ndim, dim)
4085    dims = (dim,) if isinstance(dim, Dim) else dim
4086    # Short-circuits if the tensor has no dimensions
4087    if ndim == 0:
4088        assert len(dims) == 0 or dims == (0,)
4089        return prims.view_of(a)
4090
4091    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
4092    dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
4093    if len(dims) == 0:
4094        return prims.view_of(a)
4095    if len(dims) == 1:
4096        return prims.squeeze(a, dims)
4097    dims_list = list(dims)
4098    dims_list = sorted(dims_list, reverse=True)
4099    for i in dims_list:
4100        a = squeeze(a, i)
4101    return a
4102
4103
4104# Note: does not work with TensorMetas because of data-dependent control-flow
4105# CompositeImplicitAutograd - don't register decomp
4106def tensor_split(
4107    a: TensorLikeType,
4108    indices_or_sections: Union[Tensor, DimsType],
4109    dim: int = 0,
4110) -> Tuple[TensorLikeType, ...]:
4111    _dim = utils.canonicalize_dim(a.ndim, dim)
4112    if a.ndim == 0:
4113        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
4114        raise ValueError(msg)
4115
4116    # If indices_or_sections is a tensor, it must be a CPU Long tensor
4117    if isinstance(indices_or_sections, TensorLike):
4118        if not indices_or_sections.device.type == "cpu":
4119            msg = (
4120                f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
4121                f"but received one on {indices_or_sections.device}"
4122            )
4123            raise ValueError(msg)
4124        if indices_or_sections.dtype != torch.long:
4125            msg = ("tensor_split: if indices_or_sections is a tensor it must have long dtype, "
4126                   f"but received one with dtype {indices_or_sections.dtype}")
4127            raise ValueError(msg)
4128
4129    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
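    # Illustrative example (editor's addition, not in the original source): splitting
    # a dimension of size 7 into 3 sections yields sub-tensors of sizes 3, 2 and 2
    # (min_split_size=2, with one leading section receiving an extra element).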
4130    if isinstance(indices_or_sections, IntLike) or (
4131        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
4132    ):
4133        sections: int = (
4134            indices_or_sections  # type: ignore[assignment]
4135            if isinstance(indices_or_sections, Number)
4136            else indices_or_sections.item()
4137        )
4138
4139        if sections <= 0:
4140            msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
4141            raise ValueError(msg)
4142
4143        splits = []
4144        dim_size = a.shape[_dim]
4145        min_split_size = math.floor(dim_size / sections)
4146        num_splits_one_extra = dim_size % sections
4147        start_idx = 0
4148        for split_idx in range(sections):
4149            split_size = (
4150                min_split_size + 1
4151                if (split_idx < num_splits_one_extra)
4152                else min_split_size
4153            )
4154            s = prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
4155            splits.append(s)
4156            start_idx = start_idx + split_size
4157
4158        return tuple(splits)
4159    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
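    # Illustrative example (editor's addition, not in the original source): indices
    # (2, 5) on a dimension of size 7 produce the slices [0:2], [2:5] and [5:7].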
4160    else:
4161        indices = indices_or_sections
4162        if isinstance(indices_or_sections, TensorLike):
4163            if indices_or_sections.ndim != 1:
4164                msg = ("tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
4165                       f"but received a tensor with {indices_or_sections.ndim} dimensions")
4166                raise ValueError(msg)
4167
4168            indices = indices_or_sections.tolist()
4169
4170        splits = []
4171        start_idx = 0
4172        for x in indices:
4173            splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
4174            start_idx = x
4175        splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
4176        return tuple(splits)
4177
4178
4179# CompositeImplicitAutograd - don't register decomp
4180def hsplit(
4181    a: TensorLikeType, indices_or_sections: DimsType
4182) -> Tuple[TensorLikeType, ...]:
4183    torch._check(
4184        a.ndim >= 1,
4185        lambda: (
4186            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
4187            + str(a.ndim)
4188            + " dimensions!"
4189        ),
4190    )
4191    dim = 0 if a.ndim == 1 else 1
4192    if isinstance(indices_or_sections, IntLike):
4193        split_size = indices_or_sections
4194        torch._check(
4195            (split_size != 0 and a.shape[dim] % split_size == 0),
4196            lambda: (
4197                "torch.hsplit attempted to split along dimension "
4198                + str(dim)
4199                + ", but the size of the dimension "
4200                + str(a.shape[dim])
4201                + " is not divisible by the split_size "
4202                + str(split_size)
4203                + "!"
4204            ),
4205        )
4206        return tensor_split(a, split_size, dim)
4207
4208    torch._check_type(
4209        isinstance(indices_or_sections, (list, tuple)),
4210        lambda: (
4211            "hsplit(): received an invalid combination of arguments. "
4212            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
4213            f"but got type {type(indices_or_sections)}"
4214        ),
4215    )
4216
4217    split_sizes = indices_or_sections
4218    return tensor_split(a, split_sizes, dim)
4219
4220
4221# CompositeImplicitAutograd - don't register decomp
4222def vsplit(
4223    a: TensorLikeType, indices_or_sections: DimsType
4224) -> Tuple[TensorLikeType, ...]:
4225    torch._check(
4226        a.ndim >= 2,
4227        lambda: (
4228            "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
4229            + str(a.ndim)
4230            + " dimensions!"
4231        ),
4232    )
4233    if isinstance(indices_or_sections, IntLike):
4234        split_size = indices_or_sections
4235        torch._check(
4236            (split_size != 0 and a.shape[0] % split_size == 0),
4237            lambda: (
4238                f"torch.vsplit attempted to split along dimension 0"
4239                f", but the size of the dimension "
4240                f"{a.shape[0]}"
4241                f" is not divisible by the split_size "
4242                f"{split_size}"
4243                f"!"
4244            ),
4245        )
4246        return tensor_split(a, split_size, 0)
4247
4248    torch._check_type(
4249        isinstance(indices_or_sections, (list, tuple)),
4250        lambda: (
4251            "vsplit(): received an invalid combination of arguments. "
4252            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
4253            f"but got type {type(indices_or_sections)}"
4254        ),
4255    )
4256
4257    split_sizes = indices_or_sections
4258    return tensor_split(a, split_sizes, 0)
4259
4260
4261@register_decomposition(aten.diag.out)
4262@out_wrapper()
4263def diag(
4264    self: TensorLikeType,
4265    offset: int = 0,
4266) -> TensorLikeType:
4267    ndim = self.dim()
4268    torch._check(
4269        ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
4270    )
4271    if ndim == 1:
4272        return torch.diag_embed(self, offset)
4273    else:
4274        return torch.diagonal_copy(self, offset)
4275
4276
4277@register_decomposition(aten.diagonal_scatter)
4278@out_wrapper()
4279def diagonal_scatter(
4280    input: TensorLikeType,
4281    src: TensorLikeType,
4282    offset: int = 0,
4283    dim1: int = 0,
4284    dim2: int = 1,
4285) -> TensorLikeType:
4286    out = utils.clone_preserve_strides(input)
4287    diag = out.diagonal(offset, dim1, dim2)
4288    torch._check(
4289        diag.shape == src.shape,
4290        lambda: "expected src to have a size equal to the diagonal of the input."
4291        f"Got {src.shape} for a diagonal of shape {diag.shape}",
4292    )
4293    copy_to(diag, src)
4294    return out
4295
4296
4297@register_decomposition(aten.diagonal)
4298def diagonal(
4299    self: TensorLikeType,
4300    offset: int = 0,
4301    dim1: int = 0,
4302    dim2: int = 1,
4303) -> TensorLikeType:
4304    """
4305    Reference implementation of torch.diagonal
4306    """
4307    num_dims = self.dim()
4308    dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
4309    dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)
4310
4311    torch._check(
4312        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
4313    )
4314
4315    storage_offset = self.storage_offset()
4316
4317    if offset >= 0:
4318        diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
4319    else:
4320        diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)
4321
4322    if diag_size > 0:
4323        if offset >= 0:
4324            storage_offset += offset * self.stride()[dim2]
4325        else:
4326            storage_offset -= offset * self.stride()[dim1]
4327
4328    sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
4329    sizes.append(diag_size)
4330
4331    strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
4332    strides.append(self.stride()[dim1] + self.stride()[dim2])
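    # Illustrative example (editor's addition, not in the original source): for a
    # contiguous (3, 4) matrix with strides (4, 1) and offset=1, diag_size is 3, the
    # storage offset advances by 1, and the resulting view has size [3] with stride
    # [5], selecting elements a[0, 1], a[1, 2], a[2, 3].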
4333
4334    result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)
4335
4336    return result
4337
4338
4339@register_decomposition(aten.diag_embed)
4340@out_wrapper()
4341def diag_embed(
4342    t: TensorLikeType,
4343    offset: int = 0,
4344    dim1: int = -2,
4345    dim2: int = -1,
4346) -> TensorLikeType:
4347    """
4348    Reference implementation of torch.diag_embed
4349    """
4350    # convert from negative dims
4351    rank = t.ndim + 1
4352    dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
4353    dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)
4354
4355    # as per the docs, exchanging dims is equivalent to changing the sign of
4356    # offset
4357    if dim1 > dim2:
4358        dim1, dim2 = dim2, dim1
4359        offset = -offset
4360
4361    torch._check(
4362        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
4363    )
4364
4365    # as per the docs, the size of last dim is placed at dim1 and dim2
4366    last_dim = t.size(-1)
4367
4368    if offset != 0:
4369        # add padding to match the new size
4370        t_shape = list(t.shape)
4371        t_shape[-1] = builtins.abs(offset)
4372        z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
4373        pair = (z, t) if offset > 0 else (t, z)
4374        t = torch.cat(pair, dim=-1)
4375        # make sure the diagonal always has the same size
4376        last_dim += builtins.abs(offset)
4377
4378    # preserve original data, but place 1 at dim1 and move last dim to dim2
4379    t = t.unsqueeze(dim1).movedim(-1, dim2)
4380
4381    # generate ranges shifting indices based on offset
4382    a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
4383    b_range = torch.arange(
4384        offset, last_dim + offset, device=t.device, dtype=torch.int64
4385    )
4386
4387    # broadcast
4388    cond = a_range == b_range.unsqueeze(-1)
4389    cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
4390    cond = cond.reshape(cond_shape)
4391
4392    # aten.diag_embed always returns a new contiguous tensor
4393    # contiguous() is needed to correctly model the output stride
4394    return utils.mask_tensor(cond, t).contiguous()
4395
4396
4397@register_decomposition(aten.block_diag)
4398@out_wrapper()
4399def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
4400    """
4401    Reference implementation of torch.block_diag
4402    """
4403    tensors_2d = [
4404        tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors
4405    ]
4406
4407    ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d)
4408    device = tensors_2d[0].device
4409
4410    result = []
4411
4412    col_start = 0
4413    for i, tensor in enumerate(tensors_2d):
4414        torch._check(
4415            tensor.dim() == 2,
4416            lambda: "Input tensors must have 2 or fewer dimensions. "
4417            f"Input {i} has {tensor.dim()} dimensions",
4418        )
4419        torch._check(
4420            tensor.device == device,
4421            lambda: "Input tensors must all be on the same device. "
4422            f"Input 0 is on device {device} and input {i} is on device {tensor.device}.",
4423        )
4424        row, col = tensor.shape
4425        left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype)
4426        right = torch.zeros(
4427            (row, ncols - col_start - col), device=device, dtype=tensor.dtype
4428        )
4429        result += [torch.cat((left, tensor, right), dim=1)]
4430        col_start += col
4431
4432    return torch.cat(result, dim=0)
4433
4434
4435def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
4436    """
4437    This is used as an input to PythonRefInfo. `torch.block_diag`
4438    expects arguments splatted, but `aten.block_diag` expects only
4439    one argument that is a list of Tensors.
4440    """
4441    return _block_diag_iterable(tensors)  # type: ignore[arg-type]
4442
4443
4444# CompositeImplicitAutograd - don't register decomp
4445def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
4446    if a.ndim < 3:
4447        raise RuntimeError(
4448            f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
4449        )
4450    if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
4451        raise RuntimeError(
4452            "torch.dsplit attempted to split along dimension 2, "
4453            + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
4454        )
4455    return tensor_split(a, sections, 2)
4456
4457
4458@register_decomposition(aten.t.default)
4459def t(a: TensorLikeType):
4460    # TODO: Add sparse support
4461    # if a.is_sparse:
4462    #     sparse_dim = a.sparse_dim()
4463    #     dense_dim = a.dense_dim()
4464    #     if not (sparse_dim <= 2 and dense_dim == 0):
4465    #         raise RuntimeError(
4466    #             f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
4467    #             f"{dense_dim} dense dimensions"
4468    #         )
4469    if a.ndim > 2:
4470        raise RuntimeError(
4471            f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
4472        )
4473    return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)
4474
4475
4476# CompositeImplicitAutograd - don't register decomp
4477def T(a: TensorLikeType) -> TensorLikeType:
4478    # n != 2 && n != 0 is deprecated in regular PyTorch.
4479    torch._check(
4480        a.ndim in (0, 2),
4481        lambda: (
4482            "The use of `x.T` on tensors of dimension other than 0 or 2 "
4483            "to reverse their shape is not supported."
4484        ),
4485    )
4486    return a.t()
4487
4488
4489@register_decomposition(aten.alias)
4490def alias(a: TensorLikeType) -> TensorLikeType:
4491    return prims.view_of(a)
4492
4493
4494@register_decomposition(aten.transpose)
4495def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
4496    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]
4497
4498    if a.ndim <= 1 or dim0 == dim1:
4499        return aten.alias.default(a)
4500
4501    _permutation = list(range(0, a.ndim))
4502    _permutation[_dim0] = _dim1
4503    _permutation[_dim1] = _dim0
4504    return torch.permute(a, _permutation)
4505
4506
4507# Aliases for transpose
4508swap_axes = transpose
4509
4510
4511@register_decomposition(aten.unfold)
4512def unfold(
4513    self: TensorLikeType, dimension: int, size: int, step: int
4514) -> TensorLikeType:
4515    shape, strides = _get_unfold_shape_stride(
4516        self.shape, self.stride(), dimension, size, step
4517    )
4518    return self.as_strided(shape, strides)
4519
4520
4521@register_decomposition(aten.unfold_copy)
4522@out_wrapper()
4523def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
4524    return self.unfold(dimension, size, step).clone(
4525        memory_format=torch.contiguous_format
4526    )
4527
4528
4529def _cumsumprod_common(
4530    func,
4531    init,
4532    a: TensorLikeType,
4533    dim: int,
4534    *,
4535    dtype: Optional[torch.dtype] = None,
4536    out: Optional[Tensor] = None,
4537) -> TensorLikeType:
4538    # We implement all the kwargs of a reduction. ATen just handles dtype
4539    # nb. This decomposition may not be as efficient as a backend-specific implementation
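    # Illustrative example (editor's addition, not in the original source): for
    # cumsum over a 1-D input [a0, a1, a2], the input is unsqueezed to shape (3, 1),
    # the (3, 3) mask rg.unsqueeze(1) <= rg keeps a[i] in column j only when i <= j,
    # and summing the masked matrix over dim 0 yields [a0, a0 + a1, a0 + a1 + a2].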
4540    ndim = a.ndim
4541    dim = utils.canonicalize_dim(ndim, dim)
4542    if ndim == 0:
4543        return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out)
4544    a = a.unsqueeze(dim + 1)
4545    rg = torch.arange(a.shape[dim], device=a.device)
4546    mask = rg.unsqueeze(1) <= rg
4547    for _ in range(ndim - dim - 1):
4548        mask = mask.unsqueeze(-1)
4549    masked_a = torch.where(mask, a, init)
4550    return func(masked_a, dim=dim, dtype=dtype, out=out)
4551
4552
4553@register_decomposition(aten.cumsum)
4554def cumsum(
4555    a: TensorLikeType,
4556    dim: int,
4557    *,
4558    dtype: Optional[torch.dtype] = None,
4559    out: Optional[Tensor] = None,
4560) -> TensorLikeType:
4561    return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out)
4562
4563
4564@register_decomposition(aten.cumprod)
4565def cumprod(
4566    a: TensorLikeType,
4567    dim: int,
4568    *,
4569    dtype: Optional[torch.dtype] = None,
4570    out: Optional[Tensor] = None,
4571) -> TensorLikeType:
4572    return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out)
4573
4574
4575# Note: although squeeze is documented as having the out= kwarg it doesn't
4576@register_decomposition(aten.unsqueeze)
4577def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
4578    # Note that unsqueeze canonicalizes with rank + 1 because it allows
4579    # a new innermost dimension to be specified
4580    ndim = a.ndim + 1
4581    dim = utils.canonicalize_dim(ndim, dim)
4582    return prims.expand_dims(a, (dim,), ndim=ndim)
4583
4584
4585# NOTE: shape is a vararg because Tensor.view can be called as
4586# Tensor.view(a, b, c) or Tensor.view((a, b, c)); the function call torch.view
4587# doesn't support unpacked shapes
4588# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
4589@register_decomposition(aten.view.default)
4590def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
4591    return _reshape_view_helper(a, *shape, allow_copy=False)
4592
4593
4594# CompositeImplicitAutograd - don't register decomp
4595def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
4596    return self.view(other.size())
4597
4598
4599# CompositeImplicitAutograd - don't register decomp
4600def ravel(a: TensorLikeType) -> TensorLikeType:
4601    return reshape(a, (-1,))
4602
4603
4604# CompositeImplicitAutograd - don't register decomp
4605# missing ref impl. for aten.gather
4606@out_wrapper()
4607def take_along_dim(
4608    a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
4609) -> torch.Tensor:
4610    torch._check(
4611        a.ndim == indices.ndim,
4612        lambda: (
4613            "torch.take_along_dim(): input and indices should have the same "
4614            f"number of dimensions, but got {a.ndim} dimensions for input, and "
4615            f"{indices.ndim} dimensions for indices"
4616        ),
4617    )
4618
4619    torch._check(
4620        utils.is_integer_dtype(indices.dtype),
4621        lambda: (
4622            "torch.take_along_dim(): dtype of indices should be int but got "
4623            f"{indices.dtype} instead"
4624        ),
4625    )
4626
4627    if dim is None:
4628        return torch.gather(a.view(-1), 0, indices.view(-1))
4629    else:
4630        self_sizes = list(a.shape)
4631        self_sizes[dim] = indices.size(dim)
4632        broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size())
4633        indices_broadcast = broadcast_to(indices, broadcast_shape)
4634
4635        indices_sizes = list(indices.shape)
4636        indices_sizes[dim] = a.size(dim)
4637        broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size())
4638        self_broadcast = broadcast_to(a, broadcast_shape)
4639
4640        return torch.gather(self_broadcast, dim, indices_broadcast)
4641
4642
4643@out_wrapper()
4644def empty(
4645    *shape,
4646    dtype: Optional[torch.dtype] = None,
4647    layout: torch.layout = torch.strided,
4648    device: Optional[DeviceLikeType] = None,
4649    requires_grad: bool = False,
4650    pin_memory: bool = False,
4651    memory_format: torch.memory_format = torch.contiguous_format,
4652) -> TensorLikeType:
4653    torch._check(
4654        memory_format != torch.preserve_format,
4655        lambda: "torch.empty: the Preserve memory format is not supported",
4656    )
4657
4658    shape = utils.extract_shape_from_varargs(shape)
4659
4660    if memory_format == torch.contiguous_format:
4661        strides = utils.make_contiguous_strides_for(shape)
4662    elif memory_format == torch.channels_last_3d:
4663        strides = utils.make_channels_last_3d_strides_for(shape)
4664    else:  # memory_format == torch.channels_last
4665        torch._check(
4666            memory_format == torch.channels_last,
4667            lambda: f"torch.empty: received an unknown memory format {memory_format}!",
4668        )
4669        strides = utils.make_channels_last_2d_strides_for(shape)
4670
4671    return torch.empty_strided(
4672        shape,
4673        strides,
4674        dtype=dtype,
4675        layout=layout,
4676        device=device,
4677        pin_memory=pin_memory,
4678        requires_grad=requires_grad,
4679    )
4680
4681
4682@out_wrapper()
4683def empty_permuted(
4684    shape,
4685    physical_layout,
4686    dtype: Optional[torch.dtype] = None,
4687    layout: torch.layout = torch.strided,
4688    device: Optional[DeviceLikeType] = None,
4689    requires_grad: bool = False,
4690    pin_memory: bool = False,
4691) -> TensorLikeType:
4692    return prims.empty_permuted(
4693        shape,
4694        physical_layout,
4695        dtype=dtype,
4696        device=device,
4697        requires_grad=requires_grad,
4698    )
4699
4700
4701@register_decomposition(aten.new_empty)
4702@out_wrapper()
4703def new_empty(
4704    a: TensorLikeType,
4705    size: ShapeType,
4706    *,
4707    dtype: Optional[torch.dtype] = None,
4708    layout: Optional[torch.layout] = None,
4709    device: Optional[DeviceLikeType] = None,
4710    pin_memory: bool = False,
4711) -> TensorLikeType:
4712    dtype = a.dtype if dtype is None else dtype
4713    layout = a.layout if layout is None else layout
4714    device = a.device if device is None else device
4715
4716    return torch.empty(
4717        size,
4718        dtype=dtype,
4719        device=device,
4720        pin_memory=pin_memory,
4721        layout=layout,
4722    )
4723
4724
4725@register_decomposition(aten.new_empty_strided)
4726@out_wrapper()
4727def new_empty_strided(
4728    a: TensorLikeType,
4729    size: ShapeType,
4730    stride: StrideType,
4731    *,
4732    dtype: Optional[torch.dtype] = None,
4733    layout: Optional[torch.layout] = None,
4734    device: Optional[DeviceLikeType] = None,
4735    pin_memory: bool = False,
4736) -> TensorLikeType:
4737    """
4738    Reference implementation of torch.Tensor.new_empty_strided
4739    """
4740
4741    dtype = a.dtype if dtype is None else dtype
4742    layout = a.layout if layout is None else layout
4743    device = a.device if device is None else device
4744
4745    return torch.empty_strided(
4746        size,
4747        stride,
4748        dtype=dtype,
4749        device=device,
4750        pin_memory=pin_memory,
4751        layout=layout,
4752    )
4753
4754
4755@register_decomposition(aten.zeros.default)
4756@out_wrapper()
4757def zeros(
4758    *size,
4759    dtype: Optional[torch.dtype] = None,
4760    layout: torch.layout = torch.strided,
4761    device: Optional[DeviceLikeType] = None,
4762    pin_memory: bool = False,
4763    requires_grad: bool = False,
4764) -> TensorLikeType:
4765    size = utils.extract_shape_from_varargs(size)
4766
4767    if dtype is None:
4768        dtype = torch.get_default_dtype()
4769
4770    return torch.full(
4771        size,
4772        False if dtype == torch.bool else 0,
4773        dtype=dtype,
4774        layout=layout,
4775        device=device,
4776        pin_memory=pin_memory,
4777        requires_grad=requires_grad,
4778    )
4779
4780
4781@register_decomposition(aten.new_zeros)
4782@out_wrapper()
4783def new_zeros(
4784    a: TensorLikeType,
4785    size: ShapeType,
4786    *,
4787    dtype: Optional[torch.dtype] = None,
4788    layout: Optional[torch.layout] = None,
4789    device: Optional[DeviceLikeType] = None,
4790    pin_memory: bool = False,
4791    requires_grad: bool = False,
4792) -> TensorLikeType:
4793    dtype = a.dtype if dtype is None else dtype
4794    layout = a.layout if layout is None else layout
4795    device = a.device if device is None else device
4796
4797    return torch.full(
4798        size,
4799        False if (dtype or a.dtype) == torch.bool else 0,
4800        dtype=dtype,
4801        layout=layout,
4802        device=device,
4803        pin_memory=pin_memory,
4804        requires_grad=requires_grad,
4805    )
4806
4807
4808@register_decomposition(aten.ones.default)
4809@out_wrapper()
4810def ones(
4811    *size,
4812    dtype: Optional[torch.dtype] = None,
4813    layout: torch.layout = torch.strided,
4814    device: Optional[DeviceLikeType] = None,
4815    pin_memory: bool = False,
4816    requires_grad: bool = False,
4817) -> TensorLikeType:
4818    size = utils.extract_shape_from_varargs(size)
4819
4820    if dtype is None:
4821        dtype = torch.get_default_dtype()
4822
4823    return torch.full(
4824        size,
4825        True if dtype == torch.bool else 1,
4826        dtype=dtype,
4827        layout=layout,
4828        device=device,
4829        pin_memory=pin_memory,
4830        requires_grad=requires_grad,
4831    )
4832
4833
4834@register_decomposition(aten.new_ones)
4835@out_wrapper()
4836def new_ones(
4837    a: TensorLikeType,
4838    size: ShapeType,
4839    *,
4840    dtype: Optional[torch.dtype] = None,
4841    layout: Optional[torch.layout] = None,
4842    device: Optional[DeviceLikeType] = None,
4843    pin_memory: bool = False,
4844    requires_grad: bool = False,
4845) -> TensorLikeType:
4846    dtype = a.dtype if dtype is None else dtype
4847    layout = a.layout if layout is None else layout
4848    device = a.device if device is None else device
4849
4850    return torch.full(
4851        size,
4852        True if (dtype or a.dtype) == torch.bool else 1,
4853        dtype=dtype,
4854        layout=layout,
4855        device=device,
4856        pin_memory=pin_memory,
4857        requires_grad=requires_grad,
4858    )
4859
4860
4861@register_decomposition(aten.new_full)
4862@out_wrapper()
4863def new_full(
4864    a: TensorLikeType,
4865    size: ShapeType,
4866    fill_value: NumberType,
4867    *,
4868    dtype: Optional[torch.dtype] = None,
4869    layout: Optional[torch.layout] = None,
4870    device: Optional[DeviceLikeType] = None,
4871    pin_memory: bool = False,
4872) -> TensorLikeType:
4873    dtype = a.dtype if dtype is None else dtype
4874    layout = a.layout if layout is None else layout
4875    device = a.device if device is None else device
4876
4877    return torch.full(
4878        size,
4879        fill_value,
4880        dtype=dtype,
4881        layout=layout,
4882        device=device,
4883        pin_memory=pin_memory,
4884    )
4885
4886
4887@register_decomposition(aten.empty_like)
4888@out_wrapper()
4889def empty_like(
4890    a: TensorLikeType,
4891    *,
4892    dtype: Optional[torch.dtype] = None,
4893    device: Optional[DeviceLikeType] = None,
4894    layout: Optional[torch.layout] = None,
4895    pin_memory: bool = False,
4896    requires_grad: bool = False,
4897    memory_format: torch.memory_format = torch.preserve_format,
4898) -> TensorLikeType:
4899    dtype = a.dtype if dtype is None else dtype
4900    layout = a.layout if layout is None else layout
4901    device = a.device if device is None else device
4902
4903    if memory_format != torch.preserve_format:
4904        return torch.empty(
4905            a.shape,
4906            dtype=dtype,
4907            layout=layout,
4908            device=device,
4909            requires_grad=requires_grad,
4910            pin_memory=pin_memory,
4911            memory_format=memory_format,
4912        )
4913
4914    # memory_format == torch.preserve_format
4915    logical_to_physical_perm = (
4916        utils.compute_elementwise_output_logical_to_physical_perm(a)
4917    )
4918    # identity perm is [2, 1, 0]
4919    return torch.empty_permuted(
4920        a.shape,
4921        logical_to_physical_perm,
4922        dtype=dtype,
4923        layout=layout,
4924        device=device,
4925        pin_memory=pin_memory,
4926        requires_grad=requires_grad,
4927    )
4928
4929
4930@register_decomposition([aten.arange.start_step, aten.arange.start_out])
4931@out_wrapper()
4932def arange(
4933    start: NumberType = 0,
4934    end: Optional[NumberType] = None,
4935    step: NumberType = 1,
4936    *,
4937    dtype: Optional[torch.dtype] = None,
4938    layout: torch.layout = torch.strided,
4939    device: Optional[DeviceLikeType] = None,
4940    pin_memory: bool = False,
4941    requires_grad: bool = False,
4942) -> TensorLikeType:
4943    utils.check_layout(layout)
4944    utils.check_pin_memory(pin_memory)
4945    device = torch.device(utils.device_or_default(device))
4946
4947    assert not isinstance(start, complex)
4948    assert not isinstance(end, complex)
4949    assert not isinstance(step, complex)
4950
4951    # Case: torch.arange(5)
4952    if end is None:
4953        end = start
4954        start = 0
4955    torch._check(step != 0, lambda: "step must be nonzero")
4956    if step > 0:
4957        torch._check(
4958            end >= start,
4959            lambda: "upper bound and lower bound inconsistent with step sign",
4960        )
4961    elif step < 0:
4962        torch._check(
4963            end <= start,
4964            lambda: "upper bound and lower bound inconsistent with step sign",
4965        )
4966
4967    def is_finite(x):
4968        return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)
4969
4970    torch._check(
4971        is_finite(start) and is_finite(end),
4972        lambda: f"unsupported range: {start} -> {end}",
4973    )
4974    torch._check(
4975        is_finite(step),
4976        lambda: f"step must be finite but got {step}",
4977    )
4978
4979    args = (start, end, step)
4980    integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)
4981
4982    if dtype is None:
4983        dtype = torch.int64 if integer_args else torch.get_default_dtype()
4984
4985    is_integer = utils.is_integer_dtype(dtype)
4986    if is_integer or integer_args:
4987        xstart = sym_int(start)
4988        xend = sym_int(end)
4989        xstep = sym_int(step)
4990
4991    # For int64 we truncate arguments to int before calculating length, but for
4992    # other integral dtypes we don't. Weird... but needed to match ATen shapes.
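    # For illustration: with start=0, end=10, step=3 we get sgn=1 and
    # length = (10 - 0 + 3 - 1) // 3 = 4, matching torch.arange(0, 10, 3) == [0, 3, 6, 9].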
4993    if dtype == torch.int64 or integer_args:
4994        # Uses floordiv to avoid ceil in inductor.
4995        sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
4996        length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
4997    else:
4998        length = math.ceil((end - start) / step)
4999
5000    if is_integer:
5001        return prims.iota(
5002            length,
5003            start=xstart,  # type: ignore[possibly-undefined]
5004            step=xstep,  # type: ignore[possibly-undefined]
5005            dtype=dtype,
5006            device=device,
5007            requires_grad=requires_grad,
5008        )
5009
5010    index = prims.iota(
5011        length,
5012        start=0,
5013        step=1,
5014        dtype=torch.int64,
5015        device=device,
5016        requires_grad=False,
5017    )
5018
5019    computation_dtype = (
5020        torch.long if integer_args else utils.get_acc_type(dtype, device)
5021    )
5022    index = _maybe_convert_to_dtype(index, computation_dtype)
5023    result = start + step * index
5024    result = _maybe_convert_to_dtype(result, dtype)
5025
5026    if requires_grad:
5027        result.requires_grad_(True)
5028    return result
5029
5030
5031@register_decomposition(aten.lerp)
5032@out_wrapper()
5033@elementwise_type_promotion_wrapper(
5034    type_promoting_args=("start", "end", "weight"),
5035    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
5036)
5037def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
5038    inputs = [start, end]
5039    if isinstance(weight, Number):
5040        weight = start.new_full((), weight)  # type: ignore[arg-type]
5041    else:
5042        inputs.append(weight)
5043    assert isinstance(weight, Tensor)  # mypy
5044    # We implement it this way for numerical stability. We assume (in the stability optimisation)
5045    # that 0 <= weight <= 1. We take the abs to deal with complex numbers.
5046    # We want to perform operations near zero, which is where floating-point numbers are most precise;
5047    # thus, we perform the following optimisation:
5048    # If weight.abs() >= 0.5:
5049    #    return (1 - weight) * (start - end) + end
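    # For illustration: with weight = 0.9 the mask is True, so coeff = -0.1 and base = end,
    # giving end - 0.1 * (end - start), which is algebraically equal to start + 0.9 * (end - start)
    # but keeps the computed correction term small (near zero), where floating point is most precise.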
5050    mask = weight.abs() >= 0.5
5051    coeff = torch.where(mask, weight - 1, weight)
5052    base = torch.where(mask, end, start)
5053    output = coeff * (end - start) + base
5054    # make sure the decomposition output's stride is the same as the non-decomposition path.
5055    stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
5056    if output.stride() != stride:
5057        output = prims.copy_strided(output, stride)
5058
5059    return handle_noncontiguous_outputs(inputs, output)
5060
5061
5062@register_decomposition(aten.linspace)
5063@out_wrapper()
5064def linspace(
5065    start: Union[NumberType, TensorLikeType],
5066    end: Union[NumberType, TensorLikeType],
5067    steps: NumberType,
5068    *,
5069    dtype: Optional[torch.dtype] = None,
5070    device: Optional[DeviceLikeType] = None,
5071    layout: torch.layout = torch.strided,
5072    pin_memory: bool = False,
5073    requires_grad: bool = False,
5074) -> TensorLikeType:
5075    if isinstance(start, TensorLikeType):
5076        torch._check(
5077            start.dim() == 0,
5078            lambda: "linspace only supports 0-dimensional start and end tensors",
5079        )
5080        start = _maybe_convert_to_dtype(start, torch.float64)
5081    if isinstance(end, TensorLikeType):
5082        torch._check(
5083            end.dim() == 0,
5084            lambda: "linspace only supports 0-dimensional start and end tensors",
5085        )
5086        end = _maybe_convert_to_dtype(end, torch.float64)
5087
5088    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
5089        default_complex_dtype = utils.corresponding_complex_dtype(
5090            torch.get_default_dtype()
5091        )
5092        if dtype is None:
5093            dtype = default_complex_dtype
5094        else:
5095            torch._check(
5096                utils.is_complex_dtype(dtype),
5097                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
5098            )
5099    else:
5100        dtype = dtype or torch.get_default_dtype()
5101    assert isinstance(dtype, torch.dtype)
5102
5103    # steps does not participate in the computation of the dtype
5104    torch._check_type(
5105        isinstance(steps, IntLike),
5106        lambda: f"received an invalid combination of arguments - got \
5107({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
5108    )
5109    assert isinstance(steps, IntLike)  # for mypy
5110    torch._check(steps >= 0, lambda: "number of steps must be non-negative")
5111
5112    factory_kwargs = {
5113        "layout": layout,
5114        "device": device,
5115        "pin_memory": pin_memory,
5116        "requires_grad": requires_grad,
5117    }
5118    if steps == 0:
5119        return torch.full((0,), 0, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
5120    if steps == 1:
5121        if isinstance(start, TensorLikeType):
5122            return torch.empty((steps,), dtype=dtype, **factory_kwargs).copy_(start)  # type: ignore[arg-type]
5123        else:
5124            return torch.full((steps,), start, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
5125
5126    # Perform the arange in int because some backends like ATen or Triton do not support all the dtypes
5127    rg = torch.arange(0, steps, **factory_kwargs)  # type: ignore[arg-type]
5128
5129    # Small types need to be computed in higher precision as this is, at heart, an associative scan
5130    dtype_red = (
5131        torch.int64
5132        if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
5133        else dtype
5134    )
5135    computation_dtype, _ = utils.reduction_dtypes(
5136        rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
5137    )
5138    cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)
5139
5140    # We implement torch.lerp without performing rg / (steps - 1) explicitly
5141    # With this we get out[0] == start, out[-1] == end
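    # For illustration: linspace(0, 1, 5) uses step = 0.25; indices 0..2 are computed from the
    # start side (0.0, 0.25, 0.5) and indices 3..4 from the end side (1 - 0.25, 1 - 0), so the
    # endpoints are exact.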
5142    step = (end - start) / (steps - 1)
5143    out = torch.where(
5144        rg < steps / 2,
5145        start + step * cast_rg(rg),  # type: ignore[arg-type,operator]
5146        end - step * cast_rg((steps - 1) - rg),  # type: ignore[arg-type,operator]
5147    )
5148    return _maybe_convert_to_dtype(out, dtype)  # type: ignore[return-value]
5149
5150
5151@register_decomposition(aten.logspace)
5152@out_wrapper()
5153def logspace(
5154    start: Union[NumberType, TensorLikeType],
5155    end: Union[NumberType, TensorLikeType],
5156    steps: NumberType,
5157    base: NumberType = 10,
5158    *,
5159    dtype: Optional[torch.dtype] = None,
5160    device: Optional[DeviceLikeType] = None,
5161    layout: torch.layout = torch.strided,
5162    pin_memory: bool = False,
5163    requires_grad: bool = False,
5164) -> TensorLikeType:
5165    if dtype is None:
5166        dtype = torch.get_default_dtype()
5167
5168    # NB: NumPy doesn't have this cast
5169    if prims.utils.is_integer_dtype(dtype):
5170        if isinstance(start, FloatLike):
5171            start = sym_int(start)
5172        elif isinstance(start, TensorLikeType):
5173            torch._check(
5174                start.dim() == 0,
5175                lambda: "logspace only supports 0-dimensional start and end tensors",
5176            )
5177            start = _maybe_convert_to_dtype(start, dtype)
5178        if isinstance(end, FloatLike):
5179            end = sym_int(end)
5180        elif isinstance(end, TensorLikeType):
5181            torch._check(
5182                end.dim() == 0,
5183                lambda: "logspace only supports 0-dimensional start and end tensors",
5184            )
5185            end = _maybe_convert_to_dtype(end, dtype)
5186
5187    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
5188        default_complex_dtype = utils.corresponding_complex_dtype(
5189            torch.get_default_dtype()
5190        )
5191        dtype = default_complex_dtype
5192        _dtype = None  # torch.linspace will update the correct dtype
5193    else:
5194        _dtype = torch.float64
5195
5196    assert not isinstance(base, complex)  # for mypy
5197    if base < 0:
5198        raise NotImplementedError
5199    ret = torch.linspace(  # type: ignore[misc]
5200        start,  # type: ignore[arg-type]
5201        end,  # type: ignore[arg-type]
5202        steps,  # type: ignore[arg-type]
5203        dtype=_dtype,
5204        layout=layout,
5205        device=device,
5206        pin_memory=pin_memory,
5207        requires_grad=requires_grad,
5208    )
5209    return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)  # type: ignore[arg-type,return-value]
5210
5211
5212@overload
5213def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
5214    pass
5215
5216
5217@overload
5218def meshgrid(*tensors: TensorLikeType, indexing: str):
5219    pass
5220
5221
5222@register_decomposition(aten.meshgrid)  # type: ignore[misc]
5223def meshgrid(
5224    *tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
5225    indexing: str,
5226) -> List[TensorLikeType]:
5227    # This ref simultaneously handles two overloads (see stubs above)
5228    # The `indexing` argument is currently optional for torch.meshgrid, but we
5229    # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
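    # For illustration: with 1D inputs of lengths 3 and 4, indexing="ij" yields two grids of
    # shape (3, 4); indexing="xy" swaps the first two inputs before broadcasting (and swaps the
    # outputs back afterwards), yielding grids of shape (4, 3).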
5230    if isinstance(tensors[0], (list, tuple)):
5231        assert len(tensors) == 1
5232        tensors = tuple(tensors[0])
5233
5234    torch._check(
5235        builtins.all(isinstance(a, TensorLike) for a in tensors),
5236        lambda: "meshgrid expects its inputs to be tensors",
5237    )
5238
5239    torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")
5240
5241    for i in range(len(tensors) - 1):
5242        torch._check(
5243            tensors[i].dtype == tensors[i + 1].dtype,  # type: ignore[union-attr]
5244            lambda: "meshgrid expects all tensors to have the same dtype",
5245        )
5246        torch._check(
5247            tensors[i].device == tensors[i + 1].device,  # type: ignore[union-attr]
5248            lambda: "meshgrid expects all tensors to have the same device",
5249        )
5250
5251    swap_first_and_second_tensors = False
5252    if indexing == "xy":
5253        swap_first_and_second_tensors = len(tensors) >= 2
5254        if swap_first_and_second_tensors:
5255            tensors = (tensors[1], tensors[0], *tensors[2:])
5256    else:
5257        torch._check(
5258            indexing == "ij",
5259            lambda: (
5260                'torch.meshgrid: indexing must be one of "xy" or "ij", '
5261                f"but received: {indexing}"
5262            ),
5263        )
5264
5265    result_shape: List[int] = []
5266    for t in tensors:
5267        assert isinstance(t, TensorLike)  # mypy
5268        torch._check(
5269            t.ndim == 0 or t.ndim == 1,
5270            lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
5271        )
5272        result_shape.append(t.numel())
5273
5274    grids: List[TensorLikeType] = []
5275    for i, t in enumerate(tensors):
5276        assert isinstance(t, TensorLike)  # mypy
5277        if t.ndim == 0:
5278            t = t.view((1,))
5279        grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))
5280
5281    if swap_first_and_second_tensors:
5282        # Swap outputs if we originally swapped at the beginning
5283        grids[0], grids[1] = grids[1], grids[0]
5284
5285    return grids
5286
5287
5288# CompositeImplicitAutograd - don't register decomp
5289def movedim(
5290    input: TensorLikeType,
5291    source: Union[int, DimsSequenceType],
5292    destination: Union[int, DimsSequenceType],
5293) -> TensorLikeType:
5294    """
5295    Reference implementation of torch.movedim
5296    """
5297    if type(source) is int:
5298        source = (source,)
5299    if type(destination) is int:
5300        destination = (destination,)
5301
5302    # Convert to list to produce an error message compatible with core PyTorch,
5303    # which prints sequences in square brackets.
5304    torch._check(
5305        len(source) == len(destination),  # type: ignore[arg-type]
5306        lambda: (
5307            "movedim: Invalid source or destination dims: source "  # type: ignore[arg-type]
5308            f"({list(source)} dims) should contain the same number "  # type: ignore[arg-type]
5309            f"of dims as destination ({list(destination)} dims)"  # type: ignore[arg-type]
5310        ),
5311    )
5312
5313    rank = input.ndim
5314    ss = tuple(utils.canonicalize_dims(rank=rank, indices=source))  # type: ignore[arg-type]
5315    ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination))  # type: ignore[arg-type]
5316
5317    sss = set(ss)
5318    dss = set(ds)
5319
5320    # See above on why this converts to list in error messages.
5321    torch._check(
5322        len(ss) == len(sss),
5323        lambda: f"movedim: repeated dim in `source` ({list(source)})",  # type: ignore[arg-type]
5324    )
5325    torch._check(
5326        len(ds) == len(dss),
5327        lambda: f"movedim: repeated dim in `destination` ({list(destination)})",  # type: ignore[arg-type]
5328    )
5329
5330    m = dict(zip(ds, ss))
5331    dims = []
5332    si = 0  # source index
5333    for di in range(rank):
5334        # check if the destination index is in the mapping
5335        s = m.get(di)
5336        if s is not None:
5337            # insert source index if found
5338            dims.append(s)
5339        else:
5340            # insert source index sequentially, skipping indices from the mapping
5341            while si in sss:
5342                si += 1
5343            dims.append(si)
5344            si += 1
5345
5346    result = torch.permute(input, tuple(dims))
5347
5348    return result
5349
5350
5351# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
5352@register_decomposition(aten.empty_strided)
5353@out_wrapper()
5354def empty_strided(
5355    shape: Union[ShapeType, Tuple[ShapeType]],
5356    strides: StrideType,
5357    *,
5358    dtype: Optional[torch.dtype] = None,
5359    device: Optional[DeviceLikeType] = None,
5360    layout: torch.layout = torch.strided,
5361    requires_grad: bool = False,
5362    pin_memory: bool = False,
5363) -> TensorLikeType:
5364    # Layout == strided, pin_memory is False
5365    utils.check_layout(layout)
5366    utils.check_pin_memory(pin_memory)
5367
5368    shape = utils.extract_shape_from_varargs(shape)
5369    dtype = torch.get_default_dtype() if dtype is None else dtype
5370    device = torch.device("cpu") if device is None else device
5371
5372    return prims.empty_strided(
5373        shape,
5374        strides,
5375        dtype=dtype,
5376        device=device,
5377        requires_grad=requires_grad,
5378    )
5379
5380
5381@register_decomposition(aten.eye)
5382@out_wrapper()
5383def eye(
5384    n: int,
5385    m: Optional[int] = None,
5386    *,
5387    dtype: Optional[torch.dtype] = None,
5388    layout: torch.layout = torch.strided,
5389    device: Optional[DeviceLikeType] = None,
5390    pin_memory: bool = False,
5391    requires_grad: bool = False,  # TODO: unused
5392) -> TensorLikeType:
5393    """
5394    Reference implementation of torch.eye
5395    """
5396    if m is None:
5397        m = n
5398
5399    torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
5400    torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")
5401
5402    range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
5403    range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)
5404
5405    cond = range_n.unsqueeze(-1) == range_m
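    # For illustration: eye(2, 3) compares [[0], [1]] against [0, 1, 2], giving
    # cond = [[True, False, False], [False, True, False]].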
5406    if dtype is torch.bool:
5407        return cond
5408    else:
5409        one = torch.ones(
5410            (1,),
5411            dtype=dtype,
5412            layout=layout,
5413            device=device,
5414            pin_memory=pin_memory,
5415            requires_grad=False,
5416        )
5417        return torch.where(cond, one, 0)
5418    # TODO: Use requires_grad.  All refs taking the requires_grad kwarg must
5419    # return a leaf tensor.
5420    # result.requires_grad_(requires_grad)
5421
5422
5423@register_decomposition([aten.full.default, aten.full.out])
5424@out_wrapper()
5425def full(
5426    shape: ShapeType,
5427    fill_value: NumberType,
5428    *,
5429    dtype: Optional[torch.dtype] = None,
5430    layout: torch.layout = torch.strided,
5431    device: Optional[DeviceLikeType] = None,
5432    pin_memory: bool = False,
5433    requires_grad: bool = False,
5434) -> TensorLikeType:
5435    utils.check_layout(layout)
5436    utils.check_pin_memory(pin_memory)
5437
5438    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
5439    device = device if device is not None else torch.device("cpu")
5440
5441    e = empty(
5442        shape,
5443        dtype=dtype,
5444        layout=layout,
5445        device=device,
5446        pin_memory=pin_memory,
5447        requires_grad=requires_grad,
5448    )
5449    return torch.fill(e, fill_value)  # type: ignore[arg-type]
5450
5451
5452def full_like(
5453    a: TensorLikeType,
5454    fill_value: NumberType,
5455    *,
5456    dtype: Optional[torch.dtype] = None,
5457    layout: Optional[torch.layout] = None,
5458    device: Optional[DeviceLikeType] = None,
5459    pin_memory: bool = False,
5460    requires_grad: bool = False,
5461    memory_format: torch.memory_format = torch.preserve_format,
5462) -> TensorLikeType:
5463    e = torch.empty_like(
5464        a,
5465        dtype=dtype,
5466        layout=layout,
5467        device=device,
5468        pin_memory=pin_memory,
5469        requires_grad=requires_grad,
5470        memory_format=memory_format,
5471    )
5472    return fill(e, fill_value)
5473
5474
5475@register_decomposition(aten.zeros_like)
5476@out_wrapper()
5477def zeros_like(
5478    a: TensorLikeType,
5479    *,
5480    dtype: Optional[torch.dtype] = None,
5481    layout: Optional[torch.layout] = None,
5482    device: Optional[DeviceLikeType] = None,
5483    pin_memory: bool = False,
5484    requires_grad: bool = False,
5485    memory_format: torch.memory_format = torch.preserve_format,
5486) -> TensorLikeType:
5487    return torch.full_like(
5488        a,
5489        False if (dtype or a.dtype) == torch.bool else 0,
5490        dtype=dtype,
5491        layout=layout,
5492        device=device,
5493        pin_memory=pin_memory,
5494        requires_grad=requires_grad,
5495        memory_format=memory_format,
5496    )
5497
5498
5499@register_decomposition(aten.ones_like)
5500@out_wrapper()
5501def ones_like(
5502    a: TensorLikeType,
5503    *,
5504    dtype: Optional[torch.dtype] = None,
5505    layout: Optional[torch.layout] = None,
5506    device: Optional[DeviceLikeType] = None,
5507    pin_memory: bool = False,
5508    requires_grad: bool = False,
5509    memory_format: torch.memory_format = torch.preserve_format,
5510) -> TensorLikeType:
5511    return torch.full_like(
5512        a,
5513        True if (dtype or a.dtype) == torch.bool else 1,
5514        dtype=dtype,
5515        layout=layout,
5516        device=device,
5517        pin_memory=pin_memory,
5518        requires_grad=requires_grad,
5519        memory_format=memory_format,
5520    )
5521
5522
5523@register_decomposition(aten.randn.default)
5524@out_wrapper()
5525def randn(
5526    *shape,
5527    dtype: Optional[torch.dtype] = None,
5528    device: Optional[DeviceLikeType] = None,
5529    layout: Optional[torch.layout] = None,
5530    requires_grad: bool = False,
5531    pin_memory: bool = False,
5532) -> TensorLikeType:
5533    utils.check_pin_memory(pin_memory)
5534
5535    shape_ = utils.extract_shape_from_varargs(shape)
5536
5537    dtype = utils.dtype_or_default(dtype)
5538    device = utils.device_or_default(device)
5539
5540    return prims.normal(
5541        shape_,
5542        mean=0.0,
5543        std=1.0,
5544        dtype=dtype,
5545        device=device,
5546        requires_grad=requires_grad,
5547    )
5548
5549
5550def scalar_tensor(
5551    a: NumberType,
5552    *,
5553    dtype: Optional[torch.dtype] = None,
5554    layout: torch.layout = torch.strided,
5555    device: Optional[DeviceLikeType] = None,
5556    pin_memory: bool = False,
5557) -> TensorLikeType:
5558    utils.check_layout(layout)
5559    utils.check_pin_memory(pin_memory)
5560    dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
5561    device = device if device is not None else torch.device("cpu")
5562    return prims.scalar_tensor(a, dtype=dtype, device=device)
5563
5564
5565#
5566# Randomness References
5567#
5568
5569
5570def _uniform_helper(
5571    shape: ShapeType,
5572    low: Union[bool, int, float] = 0.0,
5573    high: Union[bool, int, float] = 1.0,
5574    *,
5575    dtype: torch.dtype,
5576    device: DeviceLikeType,
5577) -> TensorLikeType:
5578    utils.validate_shape(shape)
5579
5580    assert isinstance(low, Number)
5581    assert isinstance(high, Number)
5582    low = sym_float(low)
5583    high = sym_float(high)
5584
5585    assert isinstance(dtype, torch.dtype)
5586    device = utils.canonicalize_device(device)
5587
5588    return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)
5589
5590
5591@register_decomposition(aten.masked_fill)
5592@out_wrapper()
5593def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
5594    python_type = utils.dtype_to_type(a.dtype)
5595    if isinstance(value, Number):
5596        value_type = type(value)
5597    else:
5598        # NOTE: Could not use value = item(value) as it resulted in
5599        # RuntimeError: Cannot cast FakeTensor(cpu) to number
5600        value_ndim = value.ndim
5601        torch._check(
5602            value_ndim == 0,
5603            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
5604        )
5605        # `masked_fill` allows a CPU scalar `value` to be moved to a cuda, xpu, hpu, or
5605        # privateuse1 tensor, but not the other way around.
5606        is_cpu_scalar = (
5607            a.device.type
5608            in ["cuda", "xpu", torch._C._get_privateuse1_backend_name(), "hpu"]
5609            and value.device.type == "cpu"
5610        )
5611        torch._check(
5612            is_cpu_scalar or value.device == a.device,
5613            lambda: "Expected `value` to be on same device as `a`",
5614        )
5615        value_type = utils.dtype_to_type(value.dtype)
5616
5617    if value_type is complex:
5618        # Only downcasting from complex to a lower type is not allowed.
5619        # We allow casting `value` to a lower type in the other cases,
5620        # e.g. float -> int.
5621        # Ref: https://github.com/pytorch/pytorch/issues/79195
5622        torch._check(
5623            utils.is_weakly_lesser_type(value_type, python_type),
5624            lambda: f"could not convert to type {python_type} without overflow",
5625        )
5626
5627    # Since `where` allows type-promotion,
5628    # cast value to correct type before passing to `where`
5629    value = _maybe_convert_to_dtype(value, a.dtype)
5630    r = torch.where(mask, value, a)  # type: ignore[arg-type]
5631
5632    # aten.masked_fill always returns a new contiguous tensor
5633    # contiguous() is needed to correctly model the output stride
5634    return r.contiguous()
5635
5636
5637@register_decomposition(aten.masked_fill_)
5638def masked_fill_(
5639    a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
5640) -> TensorLikeType:
5641    b = torch.masked_fill(a, mask, value)  # type: ignore[arg-type]
5642    a.copy_(b)
5643    return a
5644
5645
5646# CompositeImplicitAutograd - don't register decomp
5647def allclose(
5648    a: TensorLikeType,
5649    b: TensorLikeType,
5650    rtol: float = 1e-05,
5651    atol: float = 1e-08,
5652    equal_nan: bool = False,
5653) -> bool:
5654    """
5655    Reference implementation of torch.allclose
5656    """
5657    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
5658
5659    return bool(
5660        torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
5661    )
5662
5663
5664def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
5665    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
5666    utils.check_same_dtype(a, b)
5667
5668    # Shape check
5669    if a.ndim != b.ndim:
5670        return False
5671
5672    for x, y in zip(a.shape, b.shape):
5673        if x != y:
5674            return False
5675
5676    # Short-circuits if there are no elements to validate
5677    if a.numel() == 0:
5678        return True
5679
5680    return item(all(eq(a, b)))  # type: ignore[return-value]
5681
5682
5683@register_decomposition(aten.norm)
5684@out_wrapper(exact_dtype=True)
5685def norm(
5686    input: TensorLikeType,
5687    p: Optional[Union[float, str]] = "fro",
5688    dim: Optional[DimsType] = None,
5689    keepdim: bool = False,
5690    *,
5691    dtype: Optional[torch.dtype] = None,
5692) -> TensorLikeType:
5693    # In these cases we compute the "Frobenius norm"
5694    if (
5695        p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
5696    ) or p is None:
5697        p = 2
5698    if isinstance(dim, Dim):
5699        dim = [dim]
5700    if isinstance(p, str):
5701        # Here we either compute the nuclear norm, or we call matrix_norm with some arguments
5702        # that will throw an error
5703        if dim is None:
5704            dim = tuple(range(input.ndim))
5705        return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
5706    else:
5707        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)
5708
5709
5710@register_decomposition(aten.trace)
5711@out_wrapper()
5712def trace(self: TensorLikeType) -> TensorLikeType:
5713    torch._check(
5714        self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
5715    )
5716    return torch.sum(torch.diag(self, 0))
5717
5718
5719def _make_r_binary_op(base_op):
5720    def rop(
5721        a: Union[TensorLikeType, NumberType],
5722        b: Union[TensorLikeType, NumberType],
5723    ) -> TensorLikeType:
5724        return base_op(b, a)
5725
5726    return rop
5727
5728
5729rtruediv = _make_r_binary_op(true_divide)
5730rfloordiv = _make_r_binary_op(floor_divide)
5731rpow = _make_r_binary_op(pow)
5732
5733
5734@register_decomposition(aten.triu)
5735@out_wrapper()
5736def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
5737    torch._check(
5738        a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
5739    )
5740    h, w = a.shape[-2:]
5741    mask = (
5742        torch.arange(w, device=a.device).unsqueeze(-2)
5743        - torch.arange(h, device=a.device).unsqueeze(-1)
5744    ) >= diagonal
5745
5746    # aten.triu always returns a new contiguous tensor
5747    # contiguous() is needed to correctly model the output stride
5748    return utils.mask_tensor(mask, a).contiguous()
5749
5750
5751@register_decomposition(aten.tril)
5752@out_wrapper()
5753def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
5754    torch._check(
5755        a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
5756    )
5757    h, w = a.shape[-2:]
5758    mask = (
5759        torch.arange(w, device=a.device).unsqueeze(-2)
5760        - torch.arange(h, device=a.device).unsqueeze(-1)
5761    ) <= diagonal
5762
5763    # aten.tril always returns a new contiguous tensor
5764    # contiguous() is needed to correctly model the output stride
5765    return utils.mask_tensor(mask, a).contiguous()
5766
5767
5768# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
5769# The components of the matrix that belong to the lower triangle with offset
5770# form a pentagon that can be broken down into a top trapezoid and a bottom
5771# rectangle. For the implementation of tril_indices, we need the sizes of
5772# both of these, as well as the length of the top side of the trapezoid.
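# For illustration: _get_tril_sizes(4, 3, 0) gives m_first_row=1, a top trapezoid of
# 1 + 2 + 3 = 6 elements and a bottom rectangle of 1 * 3 = 3 elements, i.e. the 9
# elements of the lower triangle of a 4x3 matrix.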
5773def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
5774    if row == 0 or col == 0:
5775        return 0, 0, 0
5776
5777    m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
5778    m_last_row = max(0, min(col, row + offset))
5779    n_row_all = max(0, min(row, row + offset))
5780    n_row_trapezoid = m_last_row - m_first_row + 1
5781
5782    # Number of elements in top trapezoid
5783    trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
5784    # Number of elements in bottom rectangle
5785    diff_row = n_row_all - n_row_trapezoid
5786    rectangle_size = max(0, diff_row * col)
5787
5788    return trapezoid_size, rectangle_size, m_first_row
5789
5790
5791def _trilu_checks(
5792    name: str,
5793    row: int,
5794    col: int,
5795    dtype: torch.dtype,
5796    layout: torch.layout,
5797    pin_memory: bool,
5798):
5799    torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
5800    torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
5801    torch._check(
5802        dtype in (torch.int32, torch.int64),
5803        lambda: f"\"{name}\" not implemented for '{dtype}'",
5804    )
5805
5806
5807# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
5808@register_decomposition(aten.tril_indices)
5809@out_wrapper()
5810def tril_indices(
5811    row: int,
5812    col: int,
5813    offset: int = 0,
5814    *,
5815    dtype: torch.dtype = torch.long,
5816    layout: torch.layout = torch.strided,
5817    device: DeviceLikeType = "cpu",
5818    pin_memory: bool = False,
5819) -> TensorLikeType:
5820    _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)
5821
5822    trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
5823    row_offset = max(0, -offset)
5824
5825    arange_kw = partial(
5826        torch.arange, layout=layout, device=device, pin_memory=pin_memory
5827    )
5828
5829    # first we do the indices for top trapezoid
5830    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
5831    b = m_first_row - 0.5
5832    row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
5833    col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
5834    row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
5835    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
5836
5837    # then bottom rectangle
5838    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
5839    row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
5840    col_inds2 = xs2 % col
5841
5842    return torch.stack(
5843        (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
5844    )
5845
5846
5847# Similar to _get_tril_sizes above, but here there is a top trapezoid and
5848# a bottom rectangle instead. Note that you can't reduce this to
5849# _get_tril_sizes(col, row, -offset) because that would correspond to
5850# decomposing into a left trapezoid and right rectangle.
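# For illustration: _get_triu_sizes(4, 3, 0) gives m_first_row=3, an empty top rectangle
# (since offset >= 0) and a trapezoid of 3 + 2 + 1 = 6 elements, i.e. the 6 elements of the
# upper triangle of a 4x3 matrix.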
5851def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
5852    if row == 0 or col == 0:
5853        return 0, 0, 0
5854
5855    m_first_row = max(0, col - offset) if offset > 0 else col
5856
5857    # Number of elements in top rectangle
5858    rectangle_size = max(0, min(row, -offset) * col)
5859
5860    # Number of elements in bottom trapezoid
5861    trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
5862    triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
5863    trapezoid_size = triu_size - rectangle_size
5864
5865    return trapezoid_size, rectangle_size, m_first_row
5866
5867
5868@register_decomposition(aten.triu_indices)
5869@out_wrapper()
5870def triu_indices(
5871    row: int,
5872    col: int,
5873    offset: int = 0,
5874    *,
5875    dtype: torch.dtype = torch.long,
5876    layout: torch.layout = torch.strided,
5877    device: DeviceLikeType = "cpu",
5878    pin_memory: bool = False,
5879) -> TensorLikeType:
5880    _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)
5881
5882    trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
5883    col_offset = max(0, offset)
5884
5885    arange_kw = partial(
5886        torch.arange, layout=layout, device=device, pin_memory=pin_memory
5887    )
5888
5889    # indices for top rectangle
5890    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
5891    row_inds2 = xs2 // col
5892    col_inds2 = xs2 % col
5893
5894    # bottom trapezoid
5895    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
5896    b = -0.5 - m_first_row
5897    row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
5898    col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
5899    row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
5900    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)
5901
5902    if col:
5903        row_inds1 = row_inds1 + (rectangle_size // col)
5904    col_inds1 = col_inds1 + col_offset
5905
5906    return torch.stack(
5907        (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
5908    )
5909
5910
5911@register_decomposition(aten.bucketize)
5912@out_wrapper(exact_dtype=True)
5913def bucketize(
5914    a: TensorOrNumberLikeType,
5915    boundaries: TensorLikeType,
5916    *,
5917    out_int32: bool = False,
5918    right: bool = False,
5919):
5920    torch._check(
5921        boundaries.dim() == 1,
5922        lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
5923    )
5924
5925    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
5926    out_dtype = torch.int32 if out_int32 else torch.int64
5927    n_boundaries = boundaries.shape[-1]
5928    if n_boundaries == 0:
5929        return torch.zeros_like(a)
5930    # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
5931    # each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
5932    # but each step of the search is done "in parallel" over all elements of `a`.
5933    # We can't use int32 as indices, so we do all computations with int64 and convert at the end.
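    # For illustration: with boundaries = [1, 3, 5, 7, 9] and a = 4 (right=False), the search
    # narrows [start, end) from [0, 5) to [0, 2) to [2, 2) and returns 2, since
    # boundaries[1] < 4 <= boundaries[2].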
5934    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
5935    end = start + n_boundaries
5936    # Max depth of the binary search
5937    # Since we can't break out of the loop at different points for different elements of a,
5938    # we just do the maximum number of iterations that binary search requires and add a condition
5939    # tensor (cond_update below) to stop updating once the search terminates
5940
5941    # For the first iteration through the loop we can skip some checks, so we have a separate implementation
5942    mid = start + (end - start) // 2
5943    mid_val = boundaries[mid]
5944    if right:
5945        cond_mid = mid_val > a
5946    else:
5947        cond_mid = mid_val >= a
5948    start = torch.where(cond_mid, start, mid + 1)
5949
5950    if n_boundaries > 1:
5951        cond_update = torch.ones_like(a, dtype=torch.bool)
5952        niters = int(math.log2(n_boundaries))
5953        for _ in range(niters):
5954            end = torch.where(cond_mid & cond_update, mid, end)
5955            cond_update = start < end
5956            # start might end up pointing one past the end; we guard against that
5957            mid = torch.where(cond_update, start + (end - start) // 2, 0)
5958            mid_val = boundaries[mid]
5959            # If right is true, the buckets are closed on the *left*
5960            # (i.e., we are doing the equivalent of std::upper_bound in C++)
5961            # Otherwise they are closed on the right (std::lower_bound)
5962            if right:
5963                cond_mid = mid_val > a
5964            else:
5965                cond_mid = mid_val >= a
5966            start = torch.where((~cond_mid) & cond_update, mid + 1, start)
5967
5968    return start.to(dtype=out_dtype)
5969
5970
5971@register_decomposition(aten.cauchy)
5972@out_wrapper()
5973@elementwise_type_promotion_wrapper(
5974    type_promoting_args=("self",),
5975    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
5976)
5977def cauchy(self, median=0, sigma=1, generator=None):
5978    assert generator is None
5979    torch._check(
5980        not utils.is_complex_dtype(self.dtype)
5981        and not utils.is_integer_dtype(self.dtype)
5982        and not utils.is_boolean_dtype(self.dtype),
5983        lambda: f"Cauchy distribution is a continuous probability distribution. \
5984        dtype must be a floating point but you specified {self.dtype}",
5985    )
5986    torch._check(
5987        sigma > 0.0,
5988        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
5989    )
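    # Inverse-CDF sampling: for U ~ Uniform(0, 1), median + sigma * tan(pi * (U - 0.5)) follows a
    # Cauchy(median, sigma) distribution.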
5990    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))
5991
5992
5993@register_decomposition(aten.exponential)
5994@out_wrapper()
5995@elementwise_type_promotion_wrapper(
5996    type_promoting_args=("self",),
5997    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
5998)
5999def exponential(self, rate=1, generator=None):
6000    assert generator is None
6001    torch._check(
6002        not utils.is_complex_dtype(self.dtype)
6003        and not utils.is_integer_dtype(self.dtype)
6004        and not utils.is_boolean_dtype(self.dtype),
6005        lambda: f"Exponential distribution is a continuous probability distribution. \
6006        dtype must be a floating point but you specified {self.dtype}",
6007    )
6008    torch._check(
6009        rate > 0.0,
6010        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
6011    )
6012
6013    uniform_val = torch.rand_like(self)
6014
6015    # Copying the numerics of transformation::exponential; see the comment there:
6016    # curand_uniform has (0, 1] bounds. log(1) is 0 and the exponential distribution excludes 0,
6017    # so we need log to be nonzero and to not underflow when converted to half.
6018    # The fast __logf approximation can underflow, so set log to -epsilon/2 for args that are 1 or close to 1.
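    # In other words this is inverse-CDF sampling: if U ~ Uniform(0, 1], then
    # -log(U) / rate ~ Exponential(rate).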
6019    epsilon = torch.finfo(uniform_val.dtype).eps / 2
6020    condition = uniform_val >= 1.0 - epsilon
6021    log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val))
6022
6023    return -1 / rate * log_uniform
6024
6025
6026@register_decomposition(aten.geometric)
6027@out_wrapper()
6028@elementwise_type_promotion_wrapper(
6029    type_promoting_args=("self",),
6030    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
6031)
6032def geometric(self, p, generator=None):
6033    assert generator is None
6034    # TODO: fix inductor rand_like for integer, bool dtypes
6035    torch._check(
6036        not utils.is_complex_dtype(self.dtype)
6037        and not utils.is_boolean_dtype(self.dtype),
6038        lambda: f"geometric not implemented for {self.dtype}",
6039    )
6040    torch._check(
6041        0 < p and p < 1,
6042        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
6043    )
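    # Inverse-CDF sampling: for U ~ Uniform(0, 1), floor(log(1 - U) / log(1 - p)) + 1 follows a
    # Geometric(p) distribution counting the number of trials up to and including the first success.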
6044    return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1
6045
6046
6047@register_decomposition(aten.log_normal)
6048@out_wrapper()
6049@elementwise_type_promotion_wrapper(
6050    type_promoting_args=("self",),
6051    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
6052)
6053def log_normal(self, mean=1, std=2, generator=None):
6054    assert generator is None
6055    torch._check(
6056        not utils.is_complex_dtype(self.dtype)
6057        and not utils.is_integer_dtype(self.dtype)
6058        and not utils.is_boolean_dtype(self.dtype),
6059        lambda: f"log_normal not implemented for {self.dtype}",
6060    )
6061    torch._check(
6062        0 < std,
6063        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
6064    )
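    # Note that mean and std parametrize the underlying normal distribution, so
    # exp(std * N(0, 1) + mean) is log-normal; they are not the mean and std of the resulting samples.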
6065    return torch.exp(std * torch.randn_like(self) + mean)
6066
6067
6068# TODO: add support for functionalization aten.normal_functional
6069# NOTE: the device and dtype will be ignored when size is None
6070@register_decomposition(aten.normal)
6071@out_wrapper()
6072@elementwise_type_promotion_wrapper(
6073    type_promoting_args=(
6074        "mean",
6075        "std",
6076    ),
6077    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
6078)
6079def normal(
6080    mean=0,
6081    std=1,
6082    size=None,
6083    *,
6084    generator=None,
6085    dtype=None,
6086    layout=None,
6087    device=None,
6088    pin_memory=None,
6089):
6090    assert layout is None or layout == torch.strided
6091
6092    if not isinstance(std, TensorLike):
6093        torch._check(
6094            std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
6095        )
6096
6097    if size is None:
6098        tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
6099        torch._check(
6100            len(tensors) > 0,
6101            lambda: "normal expects that either mean or std is a tensor, or size is defined",
6102        )
6103        torch._check(
6104            layout is None and pin_memory is None,
6105            lambda: "Cannot pass layout, or pin_memory without size",
6106        )
6107
6108        size = _broadcast_shapes(*(t.shape for t in tensors))
6109        dtype = tensors[0].dtype
6110        device = tensors[0].device
6111    else:
6112        torch._check(
6113            not isinstance(mean, TensorLike) and not isinstance(std, TensorLike),
6114            lambda: "normal expects mean and std to be scalars when size is defined",
6115        )
6116        dtype = torch.get_default_dtype() if dtype is None else dtype
6117        device = torch.device("cpu") if device is None else device
6118
6119    normal_samples = prims.normal(
6120        size,
6121        mean=0.0,
6122        std=1.0,
6123        dtype=dtype,
6124        device=device,
6125        requires_grad=False,
6126        generator=generator,
6127    )
6128    return std * normal_samples + mean
6129
6130
6131@register_decomposition(aten.normal_)
6132def normal_(self, mean=0, std=1, *, generator=None):
6133    return normal(mean, std, self.shape, out=self, generator=generator)
6134
6135
6136@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
6137def rad2deg(self: TensorLikeType):
6138    torch._check(
6139        not utils.is_complex_dtype(self.dtype),
6140        lambda: "rad2deg is not supported for complex tensors.",
6141    )
6142    M_180_PI = 57.295779513082320876798154814105170332405472466564
6143    return self * M_180_PI
6144
6145
6146@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
6147def deg2rad(self: TensorLikeType):
6148    torch._check(
6149        not utils.is_complex_dtype(self.dtype),
6150        lambda: "deg2rad is not supported for complex tensors.",
6151    )
6152    M_PI_180 = 0.017453292519943295769236907684886127134428718885417
6153    return self * M_PI_180
6154
6155
6156@register_decomposition(aten.count_nonzero)
6157@out_wrapper()
6158def count_nonzero(self, dim: Optional[DimsType] = None):
6159    return (self != 0).sum(dim)
6160
6161
6162def _dot_check(self, other):
6163    torch._check(
6164        self.dim() == 1 and other.dim() == 1,
6165        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
6166    )
6167
6168    def numel_error():
6169        return (
6170            f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
6171            f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
6172        )
6173
6174    torch._check(self.numel() == other.numel(), numel_error)
6175
6176
6177@register_decomposition(aten.dot)
6178@out_wrapper()
6179@elementwise_type_promotion_wrapper(
6180    type_promoting_args=("self", "other"),
6181    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
6182)
6183def dot(self, other):
6184    if self.is_complex():
6185        if self.is_conj():
6186            if other.is_conj():
6187                return torch.dot(self.conj(), other.conj()).conj()
6188            else:
6189                return torch.vdot(self.conj(), other)
6190        elif other.is_conj():
6191            return torch.vdot(other.conj(), self)
6192
6193    _dot_check(self, other)
6194    return (self * other).sum()
6195
6196
6197@register_decomposition(aten.vdot)
6198@out_wrapper()
6199@elementwise_type_promotion_wrapper(
6200    type_promoting_args=("self", "other"),
6201    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
6202)
6203def vdot(self, other):
6204    if not self.is_complex():
6205        return torch.dot(self, other)
6206
6207    if self.is_conj():
6208        if other.is_conj():
6209            return torch.vdot(other.conj(), self.conj())
6210        else:
6211            return torch.dot(self.conj(), other)
6212    elif other.is_conj():
6213        return torch.dot(self, other.conj()).conj()
6214
6215    _dot_check(self, other)
6216    # The decomposition fails if you do self.conj()... not sure why
6217    return (self.conj_physical() * other).sum()
6218
6219
6220@register_decomposition(aten.select_scatter)
6221@out_wrapper()
6222def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int):
6223    dim = utils.canonicalize_dim(x.ndim, dim)
6224    mask_shape = [1] * x.ndim
6225    mask_shape[dim] = -1
6226    if index < 0:
6227        index = index + x.shape[dim]
6228    mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index
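    # For illustration: for x of shape (3, 4), dim=0, index=1, the mask has shape (3, 1) and is
    # True only in row 1, so row 1 of the output comes from src and the other rows from x.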
6229    src = torch.unsqueeze(src, dim).expand(x.shape)
6230    return torch.where(mask, src, x)
6231
6232
6233# inplace
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
cumprod_ = _make_inplace(cumprod)
deg2rad_ = _make_inplace(deg2rad)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
rad2deg_ = _make_inplace(rad2deg)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)
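
# A minimal usage sketch for the wrappers above (illustrative only; it assumes
# _make_inplace turns a reference `fn` into a variant that writes the result
# back into its first argument and returns that argument):
#
#   t = torch.tensor([1.0, 4.0, 9.0])
#   sqrt_(t)            # writes the result back into t
#   t                   # -> tensor([1., 2., 3.])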

alias_copy = _make_copy_from_view(aten.alias)
as_strided_copy = _make_copy_from_view(aten.as_strided)
diagonal_copy = _make_copy_from_view(aten.diagonal)
expand_copy = _make_copy_from_view(aten.expand)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
t_copy = _make_copy_from_view(aten.t)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)
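
# A minimal sketch of the intended semantics (illustrative only; it assumes
# _make_copy_from_view wraps a view op so that the result is materialized
# rather than aliased):
#
#   a = torch.arange(6).reshape(2, 3)
#   b = t_copy(a)       # same values as a.t() ...
#   b[0, 0] = 100       # ... but writing to b does not modify a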


# xref: isStorage in torch/csrc/DynamicTypes.cpp
def _isStorage(obj):
    return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage))


# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp
def _compute_sizes(seq, scalar_type):
    MAX_DIMS = 128
    is_storage = _isStorage(seq)
    sizes = []
    # TODO: this is inaccurate, we actually test PySequence_Check
    while isinstance(seq, (list, tuple)):
        length = len(seq)
        if is_storage:
            length //= scalar_type.itemsize
        sizes.append(length)
        if len(sizes) > MAX_DIMS:
            raise ValueError(f"too many dimensions '{type(seq).__name__}'")
        if length == 0:
            break
        try:
            handle = seq[0]
        except Exception:
            raise ValueError(  # noqa: B904
                f"could not determine the shape of object type '{type(seq).__name__}'"
            )
        seq = handle

    return sizes
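
# Illustrative examples of the size computation above (not executed here):
#   _compute_sizes([[1, 2, 3], [4, 5, 6]], torch.float32)  ->  [2, 3]
#   _compute_sizes([], torch.float32)                      ->  [0]
# Only the first element of each nesting level is inspected (seq[0]).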


# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp
def _infer_scalar_type(obj):
    if isinstance(obj, FloatLike):
        return torch.get_default_dtype()
    if isinstance(obj, IntLike) and not isinstance(obj, bool):  # careful!
        return torch.int64
    if isinstance(obj, BoolLike):
        return torch.bool
    if isinstance(obj, complex):
        default_dtype = torch.get_default_dtype()
        if default_dtype is torch.float:
            return torch.cfloat
        elif default_dtype is torch.double:
            return torch.cdouble
        elif default_dtype is torch.half:
            return torch.chalf
        else:
            raise RuntimeError("invalid default scalar type for complex")
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, str):
        raise TypeError(f"new(): invalid data type '{type(obj).__name__}'")
    # TODO: this is inaccurate, we actually test PySequence_Check
    if isinstance(obj, (list, tuple)):
        scalarType = None
        length = len(obj)
        # match NumPy semantics, except use default tensor type instead of
        # double.
        if length == 0:
            return torch.get_default_dtype()
        for i in range(length):
            cur_item = obj[i]
            # TODO: test this
            """
            if cur_item is obj:
                raise TypeError("new(): self-referential lists are incompatible")
            """
            item_scalarType = _infer_scalar_type(cur_item)  # recurse!
            if scalarType is not None:
                scalarType = torch.promote_types(scalarType, item_scalarType)
            else:
                scalarType = item_scalarType
            if scalarType is torch.cdouble:
                # this won't change (unless we hit undefined, but that will
                # fail later)
                return scalarType
        return scalarType
    raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}")
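
# Illustrative examples of the inference above (assuming the default dtype is
# left at torch.float32):
#   _infer_scalar_type(3)           ->  torch.int64
#   _infer_scalar_type(3.0)         ->  torch.float32  (the default dtype)
#   _infer_scalar_type([True, 1])   ->  torch.int64    (bool promotes with int64)
#   _infer_scalar_type([1, 2.0])    ->  torch.float32  (promote_types(int64, float32))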


# Analogous to recursive_store
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
def _recursive_build(
    scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
):
    if isinstance(obj, Tensor) and obj.numel() == 1:
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
    elif isinstance(obj, Tensor):
        # In eager mode it is invalid to call torch.tensor([...]) with a non-scalar tensor:
        # >>> torch.tensor([torch.randn(2)])
        # ValueError: only one element tensors can be converted to Python scalars
        #
        # But it is possible with a NumPy array
        # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
        # torch.Size([1, 2])
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
    elif isinstance(obj, Number):
        return torch.scalar_tensor(obj, dtype=scalarType)

    # seq can be a list of tensors
    seq = obj
    return torch.stack([_recursive_build(scalarType, item) for item in seq])
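
# Illustrative example of the recursion above (not executed here):
#   _recursive_build(torch.float32, [[1, 2], [3, 4]])
# turns each leaf into a scalar tensor and stacks level by level, producing a
# float32 tensor of shape (2, 2).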


# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
def _internal_new_from_data(
    options,
    scalar_type,
    device_opt,
    data,
    copy_variables,
    copy_numpy,
    type_inference,
    pin_memory=False,
):
    if isinstance(data, torch.Tensor):
        torch._check(
            not pin_memory, lambda: "Can't pin tensor constructed from a variable"
        )
        var = data
        if copy_variables:
            var = var.detach()
        inferred_scalar_type = var.dtype if type_inference else scalar_type
        device = device_opt if device_opt is not None else var.device
        return var.to(
            device=device,
            dtype=inferred_scalar_type,
            non_blocking=False,
            copy=copy_variables,
        )

    # TODO
    if hasattr(data, "__cuda_array_interface__"):
        return NotImplemented

    # TODO: test for numpy input with PyArray_Check

    device = device_opt if device_opt is not None else options["device"]
    inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type

    # NB: Don't need to avoid tracing, as we aren't going to do any manual
    # pointer filling tricks
    if _isStorage(data):
        return NotImplemented
    else:
        if torch.device(device).type == "meta":
            return NotImplemented

        # In the C implementation, we would directly start poking the memory
        # of a freshly allocated CPU tensor.  Here, we're going to do an
        # alternate, heinously slow implementation: turn each individual
        # scalar into a tensor, and then repeatedly stack them together
        tensor = _recursive_build(inferred_scalar_type, data)

        tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)

    # NB: lift_fresh is not needed, because we built the tensor from scalars
    # guaranteeing a fresh tensor in this case
    return tensor
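
# Control-flow summary of the helper above (descriptive only): a torch.Tensor
# input is handled with a single .to(...) call; storage, __cuda_array_interface__,
# and meta-device inputs return NotImplemented; everything else (Python scalars
# and nested sequences) goes through _recursive_build followed by a final
# .to(device, ...).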


# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
    # TODO (or not): support names kwarg
    if isinstance(data, torch.Tensor):
        warnings.warn(
            "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
            "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)"
        )
    type_inference = dtype is None
    new_tensor = _internal_new_from_data(
        # device="cpu" because that's what torch.tensor(2) gives you when no
        # device is specified
6516        {"device": "cpu"},  # TODO: use torch.get_default_tensor_type
6517        dtype if dtype is not None else torch.get_default_dtype(),
6518        device,
6519        data,
6520        copy_variables=True,
6521        copy_numpy=True,
6522        type_inference=type_inference,
6523        pin_memory=pin_memory,
6524    )
6525    new_tensor.detach_()
6526    if requires_grad:
6527        new_tensor.requires_grad_(requires_grad)
6528    return new_tensor
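
# Illustrative usage of the reference constructor above (not executed here):
#   torch._refs.tensor([[1, 2], [3, 4]])                 # inferred int64, shape (2, 2)
#   torch._refs.tensor([1.0, 2.0], dtype=torch.float64)  # explicit dtype, no inference
# With dtype=None the dtype is inferred via _infer_scalar_type; otherwise the
# requested dtype is forwarded to _internal_new_from_data.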


# Views
# We can't model these as above: the `op(a, out=a)` pattern does not work for a view function,
# since writing through `out=` does not reshape the input, it only copies the result into it.

# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)
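
# A minimal sketch of why the out= pattern cannot give in-place view semantics
# (illustrative only, not real API usage): an in-place t_(a) on a (2, 3) tensor
# would have to leave `a` with shape (3, 2) and the strides of a view, but an
# out=-style implementation could only copy transposed values into `a`'s
# existing (2, 3) buffer, which is a different operation.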


import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special
