xref: /aosp_15_r20/external/pytorch/torch/_prims/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import operator
3from enum import Enum
4from functools import partial, reduce
5from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
6
7import torch
8import torch._prims_common as utils
9import torch.library
10from torch import sym_float, Tensor
11from torch._C import _get_default_device
12from torch._higher_order_ops.effects import new_token_tensor
13from torch._library.utils import is_functional_schema
14from torch._prims.debug_prims import register_debug_prims
15from torch._prims.rng_prims import register_rng_prims
16from torch._prims_common import (
17    Dim,
18    DimsSequenceType,
19    DimsType,
20    IntLike,
21    Number,
22    NumberType,
23    RETURN_TYPE,
24    ShapeType,
25    StrideType,
26    TensorLike,
27    TensorLikeType,
28    type_to_dtype,
29)
30from torch._prims_common.wrappers import backwards_not_supported
31from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
32from torch.overrides import handle_torch_function, has_torch_function
33from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
34
35
# torch.library handles for the "prims" operator namespace:
#   prim                     -- "DEF" library; op schemas are defined through it
#   prim_impl                -- kernels registered under CompositeExplicitAutograd
#   prim_backend_select_impl -- kernels registered under BackendSelect
#   prim_autograd_impl       -- kernels registered under Autograd
#   prim_meta_impl           -- kernels registered under Meta
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
41
42# Experimental module containing prototype "primitive" operations.
43
# Public names of this module, grouped by prim category. Each name must land in
# the group matching how it is constructed below (e.g. names built with
# _make_elementwise_unary_prim belong under "Elementwise unary prims").
__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "erfcx",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "ndtri",
    "neg",
    "real",
    "reciprocal",
    "round",
    "rsqrt",  # unary: built with _make_elementwise_unary_prim
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "frexp",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    "view_element_type",
    #
    # Functionalized view mutations
    #
    "as_strided_scatter",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "clone",
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "copy_strided",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "xor_sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "empty_permuted",
    "scalar_tensor",
    "iota",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "_uniform_helper",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
    #
    # prims for making/sinking tokens
    #
    "_make_token",
    "_sink_tokens",
]
221
222
def TensorMeta(
    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
    *,
    shape: Optional[ShapeType] = None,
    strides: Optional[StrideType] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    """Allocate an uninitialized tensor whose metadata follows an example value.

    Metadata is taken from ``tensorlike`` and can be overridden field by field
    through the keyword arguments:

    * a Python number yields a 0-dim CPU tensor whose dtype is
      ``type_to_dtype(type(tensorlike))``. ``shape`` and ``strides`` must be
      ``None`` or an empty sequence.
    * a ``torch.Tensor`` supplies its shape, strides, dtype and device.
    * ``None`` requires ``shape``, ``strides``, ``dtype`` and ``device`` to
      all be passed explicitly.

    A string ``device`` is converted with ``torch.device``. The result is
    created with ``torch.empty_strided``.
    """
    if isinstance(tensorlike, Number):
        assert not shape and (shape is None or isinstance(shape, Sequence))
        assert not strides and (strides is None or isinstance(strides, Sequence))
        # TODO: This looks wrong, a number that is wrapped into a tensor
        # needs to behave differently than a scalar tensor for type
        # promotion purposes
        example = ((), (), type_to_dtype(type(tensorlike)), torch.device("cpu"))
    elif tensorlike is not None:
        assert isinstance(tensorlike, torch.Tensor)
        example = (
            tuple(tensorlike.shape),
            tuple(tensorlike.stride()),
            tensorlike.dtype,
            tensorlike.device,
        )
    else:
        # Without an example value every field must be given explicitly, so
        # the placeholders below are never used.
        assert shape is not None
        assert strides is not None
        assert dtype is not None
        assert device is not None
        example = (None, None, None, None)

    base_shape, base_strides, base_dtype, base_device = example

    out_shape = base_shape if shape is None else tuple(shape)
    out_strides = base_strides if strides is None else tuple(strides)
    out_dtype = base_dtype if dtype is None else dtype
    out_device = base_device if device is None else device
    if isinstance(out_device, str):
        out_device = torch.device(out_device)

    return torch.empty_strided(
        out_shape, out_strides, dtype=out_dtype, device=out_device
    )
264
265
def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
    tags: Optional[Sequence[torch.Tag]] = None,
    use_old_custom_ops_api: bool = False,
    register_conj_neg_fallthrough: bool = False,
):
    """
    Creates a primitive operation.

    Registers ``schema`` in the ``prims`` namespace and returns the default
    overload ``torch.ops.prims.<name>.default``.

    Args:
        schema: Full schema string, e.g. ``"abs(Tensor self) -> Tensor"``. The
            text before the first ``(`` is used as the op name.
        return_type: Stored as ``return_type`` on the returned overload and its
            packet. ``RETURN_TYPE.VIEW`` also triggers Conjugate/Negative
            fallthrough registration on the ``custom_op`` path.
        meta: Meta function. It is run before ``impl_aten`` on every call of
            the eager implementation, and registered as the Meta kernel (old
            API) or fake implementation (``custom_op`` path).
        impl_aten: Eager implementation, called after ``meta``.
        doc: Assigned as ``__doc__`` of the returned overload and its packet.
        tags: If truthy, replaces the overload's tags. Otherwise, when
            ``torch.ops.aten.<name>`` exists, the tags shared by all of its
            overloads are copied, minus ``core`` and ``data_dependent_output``.
        use_old_custom_ops_api: Force registration through ``prim.define``
            instead of ``torch.library.custom_op``. Non-functional schemas
            always use the old API.
        register_conj_neg_fallthrough: Register Conjugate/Negative
            fallthroughs even for non-view ops (``custom_op`` path only).

    Returns:
        The ``torch.ops.prims.<name>.default`` overload.
    """

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.)  Because we
    # don't have derivative formulas, we must setup a custom autograd function
    # that raises an error if backwards is invoked
    def _autograd_impl(*args, **kwargs):
        # `_prim` is bound later in this function; the closure reads it at call time.
        return backwards_not_supported(_prim)(*args, **kwargs)

    # BackendSelect kernel: routes to the meta function when a "meta" device is
    # requested (via a `device` kwarg or a torch.device positional arg),
    # otherwise to the eager implementation.
    def _backend_select_impl(*args, **kwargs):
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        if any(isinstance(x, torch.device) and x.type == "meta" for x in args):
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    # Split "name(args) -> ret" into the op name and the remaining signature;
    # from here on `schema` holds only the part after the name.
    name = schema.split("(")[0]
    schema = schema[len(name) :]

    # register non-functional ops with old custom ops API
    cpp_schema = torch._C.parse_schema(name + schema)
    if use_old_custom_ops_api or not is_functional_schema(cpp_schema):
        prim.define(name + schema, tags=torch.Tag.pt2_compliant_tag)
        prim_impl.impl(name, _prim_impl)
        prim_autograd_impl.impl(name, _autograd_impl)
        prim_meta_impl.impl(name, meta)
    else:
        # Collect the names of arguments the schema marks as written-to.
        mutates_args = []
        for arg in cpp_schema.arguments:
            if arg.alias_info is not None and arg.alias_info.is_write:
                mutates_args.append(arg.name)
        prim_def = torch.library.custom_op(
            "prims::" + name,
            _prim_impl,
            mutates_args=tuple(mutates_args),
            schema=schema,
        )
        prim_def.register_fake(meta)

        # all view ops get conj/neg fallthroughs
        if return_type == RETURN_TYPE.VIEW or register_conj_neg_fallthrough:
            prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Conjugate")
            prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Negative")

    _prim_packet = getattr(torch._ops.ops.prims, name)
    _prim = _prim_packet.default
    if tags:
        _prim._tags = tags
    elif aten_packet := getattr(torch.ops.aten, name, None):
        # Inherit only the tags common to every overload of the same-named aten op.
        overload_tags = [
            getattr(aten_packet, overload).tags for overload in aten_packet.overloads()
        ]
        tags_intersection = set(overload_tags[0])
        tags_intersection.intersection_update(*overload_tags[1:])

        # dont inadvertently add to prim ops
        tags_intersection.discard(torch.Tag.core)
        # causes errors with python ref executor tests, none of the
        # data dependent pytorch ops actually decompose to prims
        tags_intersection.discard(torch.Tag.data_dependent_output)

        # iter over first tags for determinism
        _prim._tags = tuple(t for t in overload_tags[0] if t in tags_intersection)

    from torch._subclasses.fake_tensor import contains_tensor_types

    # Ops with no tensor-typed arguments (plus device_put) get a BackendSelect
    # kernel so a meta device request can be routed to `meta`.
    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
        _prim
    ) in [
        # See https://github.com/pytorch/pytorch/issues/103532
        "prims.device_put.default"
    ]:
        prim_backend_select_impl.impl(name, _backend_select_impl)

    # Attach the same metadata to both the packet and its default overload.
    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta
        p.impl_aten = impl_aten

    return _prim
372
373
class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
    """How _prim_elementwise_meta derives an elementwise prim's output dtype.

    DEFAULT: the output dtype equals the input dtype.
    INT_TO_FLOAT: integer and boolean inputs produce torch.get_default_dtype();
        other dtypes are unchanged.
    ALWAYS_BOOL: the output dtype is always torch.bool.
    COMPLEX_TO_FLOAT: complex inputs produce their corresponding real dtype;
        other dtypes are unchanged.

    These rules are applied only when at least one tensor argument is present;
    the all-Python-number path of _prim_elementwise_meta ignores them.
    """

    # NOTE(review): the 1-tuple values (and the gap at 1) look like arbitrary
    # identifiers; code in this file only compares members with ==. Confirm
    # before relying on the raw values.
    DEFAULT = (0,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
379
380
381# TODO: implement dtype validation here, too, or on the corresponding refs
def _prim_elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
    args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
    """
    Meta function for elementwise operations that produce outputs in the same dtype
    as their inputs.

    Args:
        *args: Tensors and/or Python numbers. They must all share one dtype
            (``utils.check_same_dtype``).
        type_promotion: How the output dtype is derived from the input dtype
            when a tensor argument is present.
        args_with_fixed_dtypes: Extra tensors that take part in the device,
            shape and layout computations but are excluded from the same-dtype
            check and from output-dtype selection.

    Returns:
        If any tensor is present, an uninitialized tensor from
        ``torch.empty_permuted`` with the common shape, the layout from
        ``compute_elementwise_output_logical_to_physical_perm``, the selected
        device and the promoted dtype. Otherwise ``TensorMeta(number)`` for
        the first Python number argument.

    NOTE(review): an older note here said "Stride logic is currently
    incorrect"; the layout now comes from the logical-to-physical permutation
    above -- confirm whether that note still applies.
    NOTE(review): annotated as FakeTensor, but the value is whatever
    torch.empty_permuted / TensorMeta produce under the active mode.
    """

    assert len(args) > 0

    utils.check_same_dtype(*args)

    args_ = list(args)
    if args_with_fixed_dtypes is not None:
        args_ = list(args_with_fixed_dtypes) + args_

    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype
    # Only `args` are scanned (not args_with_fixed_dtypes). The first
    # non-CPU-scalar tensor wins; otherwise the last CPU scalar tensor seen is
    # used; Python number types are consulted only when no tensor was found.
    dtype = None
    scalar_type = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if not utils.is_cpu_scalar_tensor(arg):
                dtype = arg.dtype
                break
            else:
                dtype = arg.dtype
        elif isinstance(arg, Number):
            scalar_type = type(arg)

    if dtype is None and scalar_type is not None:
        dtype = utils.type_to_dtype(scalar_type)

    # Acquires the device (if it exists) or number
    # The device of the first non-CPU-scalar tensor wins; otherwise the first
    # CPU scalar tensor's device. `number` records the first Python number.
    device = None
    number = None
    for arg in args_:
        if isinstance(arg, TensorLike):
            if utils.is_cpu_scalar_tensor(arg):
                if device is None:
                    device = arg.device
                # keep going, in case there is a cuda tensor later
            else:
                device = arg.device
                break

        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # NOTE: type promotion behavior here is mostly hidden from tests because
    # references will typically handle the type promotion properly even if this doesn't
    # (but getting it wrong will cause too many casts to be inserted in traces!)
    # `device` is set iff at least one tensor was among args_.
    if device is not None:
        assert dtype is not None
        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
            dtype = dtype
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
            dtype = torch.bool
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
            if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
                dtype = torch.get_default_dtype()
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
            if utils.is_complex_dtype(dtype):
                dtype = utils.corresponding_real_dtype(dtype)
            else:
                dtype = dtype

        assert shape is not None
        return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]

    # Number case
    # TODO: fix number type promotion (bool, complex->float)

    # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
    # A symbolic number is promoted with sym_float when any argument is a
    # (Sym)Float; any other argument kind trips the "NYI" assertion.
    seen_float = False
    if isinstance(number, (torch.SymInt, torch.SymFloat)):
        for a in args:
            assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
            seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
        if seen_float:
            number = sym_float(number)

    return TensorMeta(number)  # type: ignore[arg-type]
475
476
def _complex_only_elementwise_meta(*args, **kwargs):
    """Elementwise meta that first requires the first argument to be complex.

    Raises (through ``torch._check``) when ``args[0].dtype`` is not complex,
    then defers entirely to ``_prim_elementwise_meta``.
    """
    leading = args[0]
    torch._check(
        utils.is_complex_dtype(leading.dtype),
        lambda: "Only complex dtype is supported",
    )
    return _prim_elementwise_meta(*args, **kwargs)
482
483
def _make_elementwise_unary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Build a one-input elementwise prim named ``name``.

    The schema is ``"<name>(Tensor self) -> Tensor"``, the return type is
    ``RETURN_TYPE.NEW``, and the meta function is ``_prim_elementwise_meta``
    bound to ``type_promotion``. Remaining keyword arguments (e.g.
    ``impl_aten`` and ``doc``) are forwarded to ``_make_prim``.
    """
    unary_schema = f"{name}(Tensor self) -> Tensor"
    unary_meta = partial(_prim_elementwise_meta, type_promotion=type_promotion)
    return _make_prim(
        schema=unary_schema,
        meta=unary_meta,
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )
497
498
def _make_elementwise_binary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Build a two-input elementwise prim named ``name``.

    The schema is ``"<name>(Tensor self, Tensor other) -> Tensor"``, the
    return type is ``RETURN_TYPE.NEW``, and the meta function is
    ``_prim_elementwise_meta`` bound to ``type_promotion``. Remaining keyword
    arguments (e.g. ``impl_aten`` and ``doc``) are forwarded to ``_make_prim``.
    """
    binary_schema = f"{name}(Tensor self, Tensor other) -> Tensor"
    binary_meta = partial(_prim_elementwise_meta, type_promotion=type_promotion)
    return _make_prim(
        schema=binary_schema,
        meta=binary_meta,
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )
512
513
def _not_impl(*args, **kwargs):
    """Placeholder implementation: ignores all arguments and always raises.

    Raises:
        NotImplementedError: unconditionally, with no message.
    """
    raise NotImplementedError()
516
517
518#
519# Elementwise unary operations
520#
521
522
# Elementwise unary prims built with _make_elementwise_unary_prim. Each gets the
# schema "<name>(Tensor self) -> Tensor", uses impl_aten as its eager kernel,
# and passes type_promotion to _prim_elementwise_meta to compute the output dtype.
# NOTE: this module-level `abs` shadows the Python builtin within this module.
# COMPLEX_TO_FLOAT: complex inputs yield the corresponding real dtype.
abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Bessel-family prims dispatch to the torch.special / torch.i0 kernels.
bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
634
635
636def _cbrt_aten(a: torch.Tensor) -> Tensor:
637    torch._check(
638        not a.is_complex(),
639        lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
640    )
641    # Returns the real cubic root of the number.
642    # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number
643    # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
644    # which is a complex number.
645    # For more info see the section Note in
646    # https://en.cppreference.com/w/cpp/numeric/math/cbrt
647    return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
648
649
# Real cube root; the eager kernel is _cbrt_aten, which rejects complex inputs.
cbrt = _make_elementwise_unary_prim(
    "cbrt",
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
662)
663
664
def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
    """Meta for prims.conj_physical: same metadata as ``input``, elementwise strides.

    Raises:
        RuntimeError: if ``input`` does not have a complex dtype.
    """
    if input.dtype.is_complex:
        output_strides = utils.compute_elementwise_output_strides(input)
        return TensorMeta(input, strides=output_strides)
    raise RuntimeError("prims.conj_physical is only defined for complex dtypes")
671
672
# Materialized (non-lazy) conjugation; _conj_physical_meta rejects non-complex
# inputs and gives the output elementwise strides derived from the input.
conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)
680
681
def _clone_meta(
    input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    """Meta for prims.clone: an uninitialized tensor shaped like ``input``.

    With ``torch.preserve_format`` the output uses the elementwise output
    strides computed from ``input``. Any other ``memory_format`` is passed
    straight to ``torch.empty``. dtype, layout and device always follow
    ``input``.
    """
    like_input = {
        "dtype": input.dtype,
        "layout": input.layout,
        "device": input.device,
    }
    if memory_format == torch.preserve_format:
        kept_strides = utils.compute_elementwise_output_strides(input)
        return torch.empty_strided(input.shape, kept_strides, **like_input)
    return torch.empty(input.shape, memory_format=memory_format, **like_input)
703
704
# Copy prim. register_conj_neg_fallthrough=True asks _make_prim to register
# Conjugate/Negative fallthroughs even though clone is not a view.
clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
    return_type=RETURN_TYPE.NEW,
    register_conj_neg_fallthrough=True,
)
712)
713
# Special-function and exponential unary prims (same-dtype output).
digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# The prim name is erf_inv; the eager kernel is torch.special.erfinv.
erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfcx = _make_elementwise_unary_prim(
    "erfcx",
    impl_aten=torch.special.erfcx,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
769
770
def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """Meta for prims.fill: output metadata is derived from ``a`` alone.

    Delegates to ``_prim_elementwise_meta`` with DEFAULT promotion, so the
    result keeps ``a``'s dtype. ``value`` exists only to match the schema.
    """
    del value  # the fill value never affects shape, dtype or device
    promotion = ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    return _prim_elementwise_meta(a, type_promotion=promotion)
775
776
# NOTE: fill uses _make_prim directly because it has a value parameter
fill = _make_prim(
    schema="fill(Tensor self, Scalar value) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_fill_meta,
    impl_aten=torch.fill,
    doc="",
)

floor = _make_elementwise_unary_prim(
    "floor",
    impl_aten=torch.floor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# imag/real are view prims (the schema aliases self via Tensor(a)). The meta
# rejects non-complex inputs and reports the corresponding real dtype.
imag = _make_prim(
    schema="imag(Tensor(a) self) -> Tensor(a)",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.imag,
    doc="",
)

# Predicate prim: the output dtype is always torch.bool.
isfinite = _make_elementwise_unary_prim(
    "isfinite",
    impl_aten=torch.isfinite,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lgamma = _make_elementwise_unary_prim(
    "lgamma",
    impl_aten=torch.lgamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log = _make_elementwise_unary_prim(
    "log",
    impl_aten=torch.log,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log1p = _make_elementwise_unary_prim(
    "log1p",
    impl_aten=torch.log1p,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log2 = _make_elementwise_unary_prim(
    "log2",
    impl_aten=torch.log2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log10 = _make_elementwise_unary_prim(
    "log10",
    impl_aten=torch.log10,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

real = _make_prim(
    schema="real(Tensor(a) self) -> Tensor(a)",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.real,
    doc="",
)

reciprocal = _make_elementwise_unary_prim(
    "reciprocal",
    impl_aten=torch.reciprocal,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ndtri = _make_elementwise_unary_prim(
    "ndtri",
    impl_aten=torch.special.ndtri,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

neg = _make_elementwise_unary_prim(
    "neg",
    impl_aten=torch.neg,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# NOTE: this module-level `round` shadows the Python builtin within this module.
round = _make_elementwise_unary_prim(
    "round",
    impl_aten=torch.round,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

rsqrt = _make_elementwise_unary_prim(
    "rsqrt",
    impl_aten=torch.rsqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sign = _make_elementwise_unary_prim(
    "sign",
    impl_aten=torch.sign,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
898
# Predicate prim: torch.signbit returns a boolean tensor, so the meta function
# must report torch.bool (ALWAYS_BOOL, like isfinite and the comparison prims).
# DEFAULT promotion would claim the input's dtype and disagree with impl_aten.
signbit = _make_elementwise_unary_prim(
    "signbit",
    impl_aten=torch.signbit,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
904)
905
# Remaining elementwise unary prims (same-dtype output).
sin = _make_elementwise_unary_prim(
    "sin",
    impl_aten=torch.sin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sinh = _make_elementwise_unary_prim(
    "sinh",
    impl_aten=torch.sinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

spherical_bessel_j0 = _make_elementwise_unary_prim(
    "spherical_bessel_j0",
    impl_aten=torch.special.spherical_bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sqrt = _make_elementwise_unary_prim(
    "sqrt",
    impl_aten=torch.sqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tan = _make_elementwise_unary_prim(
    "tan",
    impl_aten=torch.tan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tanh = _make_elementwise_unary_prim(
    "tanh",
    impl_aten=torch.tanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

trunc = _make_elementwise_unary_prim(
    "trunc",
    impl_aten=torch.trunc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
953)
954
955#
956# Elementwise binary operations
957#
958
# Elementwise binary prims built with _make_elementwise_binary_prim. Each gets
# the schema "<name>(Tensor self, Tensor other) -> Tensor"; the meta requires
# both inputs to share dtype, device and shape (CPU scalar tensors excepted).
add = _make_elementwise_binary_prim(
    name="add",
    impl_aten=torch.add,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan2 = _make_elementwise_binary_prim(
    name="atan2",
    impl_aten=torch.atan2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_and = _make_elementwise_binary_prim(
    "bitwise_and",
    impl_aten=torch.bitwise_and,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_or = _make_elementwise_binary_prim(
    "bitwise_or",
    impl_aten=torch.bitwise_or,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_xor = _make_elementwise_binary_prim(
    "bitwise_xor",
    impl_aten=torch.bitwise_xor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
992)
993
994# TODO: complex needs a special meta to account for its float -> complex behavior
995# complex = _make_elementwise_binary_prim(
996#   impl_aten=torch.complex,
997#   doc="",
998# )
999
1000
1001# div prim performs truncation division on integer inputs
1002#   and true division for floating and complex inputs
def _div_aten(a, b):
    """Eager kernel for prims.div.

    Uses truncation division (``rounding_mode="trunc"``) when ``a`` is a
    Python bool/int/SymInt or a tensor with an integer dtype (per
    ``utils.is_integer_dtype``). Otherwise uses true division. Only ``a``
    decides the mode.
    """
    if isinstance(a, torch.Tensor):
        truncate = utils.is_integer_dtype(a.dtype)
    else:
        truncate = isinstance(a, (bool, int, torch.SymInt))

    if not truncate:
        return torch.true_divide(a, b)
    return torch.div(a, b, rounding_mode="trunc")
1012
1013
# div: truncation division for integer inputs, true division otherwise
# (the eager kernel is _div_aten).
div = _make_elementwise_binary_prim(
    "div",
    impl_aten=_div_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Comparison prims (eq, ge, gt, le, lt) use ALWAYS_BOOL: the output is torch.bool.
eq = _make_elementwise_binary_prim(
    "eq",
    impl_aten=torch.eq,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

fmax = _make_elementwise_binary_prim(
    "fmax",
    impl_aten=torch.fmax,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

fmin = _make_elementwise_binary_prim(
    "fmin",
    impl_aten=torch.fmin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

fmod = _make_elementwise_binary_prim(
    "fmod",
    impl_aten=torch.fmod,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


gcd = _make_elementwise_binary_prim(
    "gcd",
    impl_aten=torch.gcd,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


ge = _make_elementwise_binary_prim(
    "ge",
    impl_aten=torch.ge,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

gt = _make_elementwise_binary_prim(
    "gt",
    impl_aten=torch.gt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

hypot = _make_elementwise_binary_prim(
    "hypot",
    impl_aten=torch.hypot,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# igamma / igammac dispatch to torch.special.gammainc / gammaincc.
igamma = _make_elementwise_binary_prim(
    "igamma",
    impl_aten=torch.special.gammainc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

igammac = _make_elementwise_binary_prim(
    "igammac",
    impl_aten=torch.special.gammaincc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

le = _make_elementwise_binary_prim(
    "le",
    impl_aten=torch.le,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lt = _make_elementwise_binary_prim(
    "lt",
    impl_aten=torch.lt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
1105)
1106
1107
1108# Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs
1109def _maximum_aten(
1110    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1111) -> TensorLikeType:
1112    if isinstance(a, TensorLike) and isinstance(b, Number):
1113        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1114    elif isinstance(b, TensorLike) and isinstance(a, Number):
1115        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1116
1117    return torch.maximum(a, b)  # type: ignore[arg-type]
1118
1119
# Eager implementation is _maximum_aten, which converts a scalar operand to a
# tensor before calling torch.maximum.
maximum = _make_elementwise_binary_prim(
    "maximum",
    impl_aten=_maximum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
1126
1127
1128def _minimum_aten(
1129    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1130) -> TensorLikeType:
1131    if isinstance(a, TensorLike) and isinstance(b, Number):
1132        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1133    elif isinstance(b, TensorLike) and isinstance(a, Number):
1134        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1135
1136    return torch.minimum(a, b)  # type: ignore[arg-type]
1137
1138
# Eager implementation is _minimum_aten, which converts a scalar operand to a
# tensor before calling torch.minimum.
minimum = _make_elementwise_binary_prim(
    "minimum",
    impl_aten=_minimum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

mul = _make_elementwise_binary_prim(
    "mul",
    impl_aten=torch.mul,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Comparison prim: ALWAYS_BOOL type promotion.
ne = _make_elementwise_binary_prim(
    "ne",
    impl_aten=torch.ne,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

nextafter = _make_elementwise_binary_prim(
    "nextafter",
    impl_aten=torch.nextafter,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

pow = _make_elementwise_binary_prim(
    "pow",
    impl_aten=torch.pow,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

remainder = _make_elementwise_binary_prim(
    "remainder",
    impl_aten=torch.remainder,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)


# Eager implementation is torch.bitwise_left_shift.
shift_left = _make_elementwise_binary_prim(
    "shift_left",
    impl_aten=torch.bitwise_left_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Eager implementation is torch.bitwise_right_shift.
shift_right_arithmetic = _make_elementwise_binary_prim(
    "shift_right_arithmetic",
    impl_aten=torch.bitwise_right_shift,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# No prim is built for the logical right shift; the name is bound to the
# _not_impl placeholder (defined elsewhere in this file).
shift_right_logical = _not_impl

sub = _make_elementwise_binary_prim(
    "sub",
    impl_aten=torch.sub,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# Eager implementation is torch.special.zeta.
zeta = _make_elementwise_binary_prim(
    "zeta",
    impl_aten=torch.special.zeta,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
1211
1212
1213#
1214# View operations
1215def _as_strided_meta(
1216    a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
1217) -> TensorLikeType:
1218    assert len(size) == len(stride)
1219    assert storage_offset >= 0
1220    utils.validate_strides(stride)
1221    utils.validate_shape(size)
1222
1223    if reduce(operator.mul, size) == 0:
1224        # NOTE: This special case is to avoid having to acquire the storage below
1225        # as_strided to shapes with no elements are trivially valid, so it's OK
1226        pass
1227    elif isinstance(a, torch.Tensor):
1228        utils.check_in_bounds_for_storage(
1229            a._typed_storage(), size, stride, storage_offset
1230        )
1231
1232    return torch.as_strided(a, size, stride, storage_offset)
1233
1234
1235def _as_strided_aten(
1236    a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
1237) -> Tensor:
1238    return torch.as_strided(a, size, stride, storage_offset)
1239
1240
# Docstring attached to the as_strided prim through doc= below.
_as_strided_doc = """
    Creates a view of the tensor with the given shape (size), strides (stride) and
    storage offset (storage_offset).
"""

# The schema puts the input and the result in the same mutable alias set
# (Tensor(a!)), and the prim is registered with RETURN_TYPE.VIEW.
as_strided = _make_prim(
    schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
    meta=_as_strided_meta,
    impl_aten=_as_strided_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_as_strided_doc,
)
1253
1254
def _broadcast_in_dim_meta(
    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
    """Meta function for ``prims.broadcast_in_dim``.

    Returns a view of ``a`` (via ``as_strided``) with the requested ``shape``.
    Input dimension ``i`` is placed at output position
    ``broadcast_dimensions[i]``; every other output position is a newly added
    dimension. Arguments are checked with ``assert`` and ``torch._check``.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, Dim)
        assert x > acc
        assert x < len(shape)

        return x

    # Folds left to right starting from -1, so each position must be strictly
    # greater than the previous one; the fold result itself is discarded.
    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    # (an input dim of length 1 may map to any length; others must match)
    for idx, new_idx in enumerate(broadcast_dimensions):
        if not guard_size_oblivious(a.shape[idx] == 1):
            torch._check(
                a.shape[idx] == shape[new_idx],
                lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
            )

    new_strides = []
    original_idx = 0  # next input dimension not yet mapped to an output position
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_size_oblivious(a.shape[original_idx] != shape[idx]):
                new_strides.append(0)
            else:
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            # Newly inserted dimension: stride 0 unless its length is 1. For a
            # length-1 dim the stride never affects addressing; these branches
            # pick 1 after the last input dim, otherwise stride * size of the
            # next input dim.
            if guard_size_oblivious(shape[idx] != 1):
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())
1311
1312
1313def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
1314    s = list(shape)
1315    for broadcast_dimension in broadcast_dimensions:
1316        s[broadcast_dimension] = -1
1317
1318    v = a
1319    for idx, x in enumerate(s):
1320        if x != -1:
1321            v = v.unsqueeze(idx)
1322
1323    return v.expand(shape)
1324
1325
# Docstring attached to the broadcast_in_dim prim through doc= below.
_broadcast_in_dim_doc = """
  Creates a view of a with the specified shape.

  Allows adding dimensions of any length and broadcasting
  dimensions of length one in a to any length.

  The location of the broadcast dimensions must be specified
  using the broadcast_dimensions argument. Changing the
  relative order of dimensions is not supported.
  """

# In the schema, shape is SymInt[] while broadcast_dimensions is a plain
# int[]; the result aliases `a` (Tensor(a)) and is registered as a VIEW.
broadcast_in_dim = _make_prim(
    schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
    meta=_broadcast_in_dim_meta,
    impl_aten=_broadcast_in_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_broadcast_in_dim_doc,
)
1344
1345
1346def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
1347    # Special-case for zero dimensional tensors
1348    ndim = max(1, a.dim())
1349    utils.validate_idx(ndim, start)
1350    utils.validate_idx(ndim, end)
1351
1352    # Verifies end is strictly greater than start
1353    # (Collapse requires a non-empty interval)
1354    torch._check_value(
1355        end >= start,
1356        lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
1357    )
1358
1359
1360def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
1361    """
1362    Returns the shape of a with dims in [start, end) merged into a single dimension.
1363    """
1364    # Special-case for zero dimensional tensors
1365    shape = (1,) if len(shape) == 0 else tuple(shape)
1366
1367    dim_length = 1
1368    for s in shape[start : end + 1]:
1369        dim_length = dim_length * s
1370
1371    return shape[0:start] + (dim_length,) + shape[end + 1 :]
1372
1373
def _collapse_view_helper(
    a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    """Computes the shape and strides of ``a`` with dims ``start`` through
    ``end`` (inclusive) merged into one, when that is expressible as a view.

    Returns ``(new_shape, new_strides)``, or ``(None, None)`` when the span
    cannot be merged without copying. A zero-dim ``a`` is treated as shape
    ``(1,)`` with stride ``(1,)``. Tensors with no elements get contiguous
    strides for the new shape.
    """
    assert isinstance(a, TensorLike)

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    _validate_collapse_args(a, start, end)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1,)
        strides = (1,)
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()  # type: ignore[assignment]

    # Nothing to merge: a zero-dim tensor or a single-dimension span.
    if a.ndim == 0 or (end == start):
        return shape, strides

    # Walk the span from the innermost dim (end) outward to start, building
    # the merged length and the smallest stride seen so far.
    length = shape[end]
    stride = strides[end]
    for idx in range(end - 1, start - 1, -1):
        # A zero-length dim anywhere in the span makes the merged dim empty.
        if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
            shape[idx + 1] == 0
        ):
            length = 0
            stride = 0
            break

        # Length-1 dims do not change the length, and their stride is ignored.
        if guard_size_oblivious(shape[idx] == 1):
            continue

        length = length * shape[idx]
        # Keep the smaller of the running stride and this dim's stride.
        if guard_size_oblivious(stride < strides[idx]):
            stride = stride
        else:
            stride = strides[idx]

        # Dims nest only if strides[idx] == strides[idx + 1] * shape[idx + 1].
        # This is checked for non-empty tensors when the inner dim is not length 1.
        if (
            guard_size_oblivious(a.numel() > 0)
            and guard_size_oblivious(shape[idx + 1] != 1)
            and not guard_size_oblivious(
                strides[idx] == strides[idx + 1] * shape[idx + 1]
            )
        ):
            return None, None

    new_shape = shape[:start] + (length,) + shape[end + 1 :]
    new_strides = strides[:start] + (stride,) + strides[end + 1 :]

    # NOTE: when the input has no elements it's restrided as if it were contiguous
    if guard_size_oblivious(a.numel() == 0):
        new_strides = utils.make_contiguous_strides_for(new_shape)

    return new_shape, new_strides
1430
1431
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
    """Meta function for ``prims.collapse_view``.

    Returns an ``as_strided`` view of ``a`` with dims ``start`` through ``end``
    (inclusive) merged. Raises ValueError when no such view exists.
    """
    shape, strides = _collapse_view_helper(a, start, end)
    if shape is not None:
        assert strides is not None
        return a.as_strided(shape, strides, a.storage_offset())
    raise ValueError("Attempting to view a collapsed tensor, but no such view exists!")
1441
1442
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
    """Eager ``prims.collapse_view``: ``a.view`` with dims start..end (inclusive) merged."""
    return a.view(_collapsed_shape(a.shape, start, end))
1446
1447
# Docstring attached to the collapse_view prim. The span is inclusive at both
# ends: _collapsed_shape and _collapse_view_helper merge shape[start : end + 1].
_collapse_view_doc = """
  Creates a view of a with the dimensions between
  start (inclusive) and end (inclusive) merged into a
  single dimension.

  If it's not possible to take such a view then an error
  is thrown. See collapse instead.

  The dimensions can be merged if and only if
  they are all "nested" with each other. That is, they all
  have the property that

  stride[i] = stride[i+1] * shape[i+1]

  for all i in [start, end).
  """

collapse_view = _make_prim(
    schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
    meta=_collapse_view_meta,
    impl_aten=_collapse_view_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_collapse_view_doc,
)
1472
1473
1474def _conj_meta(a: TensorLikeType) -> TensorLikeType:
1475    if not a.dtype.is_complex:
1476        raise RuntimeError("Expected complex dtype in prims.conj")
1477    out = a.as_strided(a.shape, a.stride(), a.storage_offset())
1478    torch._C._set_conj(out, not a.is_conj())
1479    return out
1480
1481
# Docstring attached to the conj prim through doc= below.
_conj_doc = """
Returns a conjugated view of the original tensor
"""

# Eager implementation is torch.conj; the meta function flips the conj bit on
# an alias of the input.
conj = _make_prim(
    schema="conj(Tensor(a) a) -> Tensor(a)",
    meta=_conj_meta,
    impl_aten=torch.conj,
    return_type=RETURN_TYPE.VIEW,
    doc=_conj_doc,
)
1493
1494
def expand_dims(
    a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
) -> TensorLikeType:
    """
    Creates a view of a with a.ndim + len(dimensions) dimensions, with new
    dimensions of length one at the dimensions specified by dimensions.

    ``dimensions`` are canonicalized against ``ndim`` when it is given,
    otherwise against ``a.ndim``. Duplicate dimensions raise ValueError.
    """
    if ndim is not None:
        # TODO: this is only here to support the unsqueeze ref
        dims = sorted(utils.canonicalize_dims(ndim, dimensions))  # type: ignore[arg-type]
    else:
        dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
    if len(set(dims)) != len(dims):
        msg = f"Received duplicate dimensions to expand in {str(dimensions)}"
        raise ValueError(msg)

    new_shape = list(a.shape)
    # Inserting in ascending order places each new 1 at its final index.
    for idx in dims:
        new_shape.insert(idx, 1)

    # Filter on the canonicalized dims, not the raw (possibly negative) input,
    # so exactly a.ndim positions remain for broadcast_in_dim.
    new_dims = set(dims)
    broadcast_dimensions = [
        idx for idx in range(len(new_shape)) if idx not in new_dims
    ]
    return broadcast_in_dim(a, new_shape, broadcast_dimensions)
1519
1520
# Note: saves the Python slice object because we're about to clobber its name with the slice prim
# (functions that need the builtin, such as _slice_aten, refer to pyslice).
pyslice: Type[slice] = slice  # type: ignore[has-type]
1523
1524
1525def _slice_meta(
1526    a: TensorLikeType,
1527    start_indices: DimsSequenceType,
1528    limit_indices: DimsSequenceType,
1529    strides: Optional[StrideType] = None,
1530) -> TensorLikeType:
1531    _strides = strides if strides is not None else [1] * len(start_indices)
1532
1533    if a.ndim != len(start_indices):
1534        msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!"
1535        raise ValueError(msg)
1536
1537    if a.ndim != len(limit_indices):
1538        msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!"
1539        raise ValueError(msg)
1540
1541    if a.ndim != len(_strides):
1542        msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(limit_indices)}!"
1543        raise ValueError(msg)
1544
1545    for x, y in zip(start_indices, a.shape):
1546        if x < 0:
1547            msg = f"Attempting to slice a tensor with a negative start index of {x}!"
1548            raise ValueError(msg)
1549        if x > y:
1550            msg = (
1551                f"Attempting to slice a tensor but a start index in {start_indices} is greater than"
1552                f" the length of its corresponding dimension in shape {a.shape}"
1553            )
1554            raise ValueError(msg)
1555
1556    for x, y, z in zip(limit_indices, a.shape, start_indices):
1557        if x < 0:
1558            msg = f"Attempting to slice a tensor with a negative stop index of {x}!"
1559            raise ValueError(msg)
1560        if x > y:
1561            msg = (
1562                f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of "
1563                f" its corresponding dimension in shape {a.shape}"
1564            )
1565            raise ValueError(msg)
1566        if x < z:
1567            msg = (
1568                f"Attempting to slice a tensor but a start index in {x} is greater than "
1569                f" its corresponding stop index {z}"
1570            )
1571
1572    for x in _strides:
1573        if x <= 0:
1574            msg = f"Attempting to slice a tensor with a non-positive step of {x}!"
1575            raise ValueError(msg)
1576
1577    new_shape = []
1578    for x, y, z in zip(start_indices, limit_indices, _strides):
1579        new_shape.append(1 + (y - x - 1) // z)
1580
1581    new_strides = []
1582    for x, y in zip(a.stride(), _strides):
1583        new_strides.append(x * y)
1584
1585    return a.as_strided(new_shape, new_strides, a.storage_offset())
1586
1587
def _slice_aten(
    a: Tensor,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> Tensor:
    """Eager ``prims.slice``: indexes ``a`` with one Python slice per dimension.

    ``strides`` defaults to a step of 1 in every dimension.
    """
    steps = [1] * len(start_indices) if strides is None else strides
    index = [
        pyslice(lo, hi, step)
        for lo, hi, step in zip(start_indices, limit_indices, steps)
    ]
    return operator.getitem(a, index)  # type: ignore[call-overload]
1601
1602
# Docstring attached to the slice prim through doc= below.
_slice_doc = """
    Creates a view of a "bounding box" within the tensor.

    The bounding box is specified independently in each of the tensor's dimensions.
    start_indices and limit_indices describe the box's boundaries for their corresponding
    dimensions. If strides is specified then they specify the step size between elements
    in their corresponding dimension.

    This operation is analogous to slicing in NumPy, but does not permit slices where
    the stop indices are less than the start indices.
    """

# This assignment rebinds the module-level name `slice` to the prim; the
# Python builtin stays reachable as pyslice.
slice = _make_prim(
    schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
    meta=_slice_meta,
    impl_aten=_slice_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_slice_doc,
)
1622
1623
1624def _slice_in_dim_meta(
1625    a: TensorLikeType,
1626    start_index: int,
1627    limit_index: int,
1628    stride: int = 1,
1629    axis: int = 0,
1630) -> TensorLikeType:
1631    if axis < 0:
1632        msg = f"slice_in_dim: received a negative axis {axis}"
1633        raise ValueError(msg)
1634    if axis >= a.ndim:
1635        msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor"
1636        raise ValueError(msg)
1637
1638    if start_index < 0:
1639        msg = f"slice_in_dim: received a negative start_index {start_index}"
1640        raise ValueError(msg)
1641
1642    if start_index > a.shape[axis]:
1643        msg = f"slice_in_dim: start_index is greater than the length {start_index} of dimension {axis}"
1644        raise ValueError(msg)
1645
1646    if limit_index > a.shape[axis]:
1647        msg = f"slice_in_dim: limit_index is greater than the length {limit_index} of dimension {axis}"
1648        raise ValueError(msg)
1649
1650    if limit_index < start_index:
1651        msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}"
1652        raise ValueError(msg)
1653
1654    if stride < 0:
1655        msg = f"slice_in_dim: received a non-positive stride of {stride}!"
1656        raise ValueError(msg)
1657
1658    start_indices = [0] * a.ndim
1659    limit_indices = list(a.shape)
1660    strides = [1] * a.ndim
1661
1662    start_indices[axis] = start_index
1663    limit_indices[axis] = limit_index
1664    strides[axis] = stride
1665
1666    return _slice_meta(a, start_indices, limit_indices, strides)
1667
1668
def _slice_in_dim_aten(
    a: Tensor,
    start_index: int,
    limit_index: int,
    stride: int = 1,
    axis: int = 0,
) -> Tensor:
    """Eager ``prims.slice_in_dim``: calls the ``slice`` prim with full-range,
    step-1 slices everywhere except ``axis``."""
    starts, limits, steps = [0] * a.ndim, list(a.shape), [1] * a.ndim
    starts[axis], limits[axis], steps[axis] = start_index, limit_index, stride
    return slice(a, starts, limits, steps)
1685
1686
# Docstring attached to the slice_in_dim prim through doc= below.
_slice_in_dim_doc = """
    Convenience wrapper for slicing just one dimension using slice.
    """

# TODO: make stride SymInt
# (the schema currently types stride and axis as plain int, the indices as SymInt)
slice_in_dim = _make_prim(
    schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
    meta=_slice_in_dim_meta,
    impl_aten=_slice_in_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_slice_in_dim_doc,
)
1699
1700
1701def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
1702    assert isinstance(a, TensorLike)
1703    utils.validate_idx(a.ndim, dim)
1704    utils.validate_dim_length(outer_length)
1705
1706    # Verifies the dim can be split with the specified lhs_length
1707    inner_length = a.shape[dim] // outer_length
1708
1709    if (a.shape[dim] % outer_length) != 0:
1710        msg = (
1711            f"Attempting to split dimension of length {a.shape[dim]}, "
1712            f"but outer length of {outer_length} divides it with a remainder!"
1713        )
1714        raise ValueError(msg)
1715
1716    new_shape: List[int] = []
1717    new_strides: List[int] = []
1718    for idx in range(a.ndim):
1719        if idx == dim:
1720            new_shape.extend((outer_length, inner_length))
1721            new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
1722        else:
1723            new_shape.append(a.shape[idx])
1724            new_strides.append(a.stride()[idx])
1725
1726    return a.as_strided(new_shape, new_strides, a.storage_offset())
1727
1728
1729def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
1730    inner_length = a.shape[dim] // outer_length
1731    new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
1732
1733    return a.view(new_shape)
1734
1735
# Docstring attached to the split_dim prim through doc= below.
_split_dim_doc = """
  Creates a view of a with the given dimension (of length l) split
  into two dimensions, with the outer of the two having
  length outer_length and the inner of the two having computed
  length inner_length such outer_length * inner_length = l.
  """

# TODO: consider renaming split_dim_view
# (in the schema, dim is a plain int while outer_length is a SymInt)
split_dim = _make_prim(
    schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
    meta=_split_dim_meta,
    impl_aten=_split_dim_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_split_dim_doc,
)
1751
1752
1753# Note: allows dimensions to be specified redundantly
1754def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
1755    assert isinstance(a, TensorLike)
1756
1757    for idx in dimensions:
1758        utils.validate_idx(a.ndim, idx)
1759        assert a.shape[idx] == 1
1760
1761    new_shape = []
1762    new_strides = []
1763    for idx in range(len(a.shape)):
1764        if idx in dimensions:
1765            continue
1766
1767        new_shape.append(a.shape[idx])
1768        new_strides.append(a.stride()[idx])
1769
1770    return a.as_strided(new_shape, new_strides, a.storage_offset())
1771
1772
# Docstring attached to the squeeze prim through doc= below.
_squeeze_doc = """
  Creates a view of the tensor with the specified dimensions removed.

  The removed dimensions must each have length one.
  """

# The eager implementation is torch.squeeze, called with the same
# (a, dimensions) arguments.
squeeze = _make_prim(
    schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
    meta=_squeeze_meta,
    impl_aten=torch.squeeze,
    return_type=RETURN_TYPE.VIEW,
    doc=_squeeze_doc,
)
1786
1787
1788def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
1789    if a.ndim != len(permutation):
1790        msg = f"Attempting to permute a tensor of rank {a.ndim}, but received a permutation of length {len(permutation)}!"
1791        raise ValueError(msg)
1792
1793    if not utils.is_valid_permutation(a.ndim, permutation):
1794        msg = f"Received an invalid permutation, {permutation}!"
1795        raise ValueError(msg)
1796
1797    new_shape = [0] * a.ndim
1798    new_strides = [0] * a.ndim
1799    for idx, dim in enumerate(permutation):
1800        new_shape[idx] = a.shape[dim]
1801        new_strides[idx] = a.stride()[dim]
1802
1803    return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
1804
1805
1806def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
1807    return torch.permute(a, permutation)
1808
1809
# Docstring attached to the transpose prim through doc= below.
_transpose_doc = """
    Creates a view of the tensor with its dimensions permuted.

    The length of the permutation must be the rank of the tensor,
    and each element of the permutation specifies the new order
    for the corresponding dimension.
    """

# Despite its name, this prim takes a full permutation; the eager
# implementation (_transpose_aten) calls torch.permute.
transpose = _make_prim(
    schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
    meta=_transpose_meta,
    impl_aten=_transpose_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_transpose_doc,
)
1825
1826
1827def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
1828    return a.as_strided(a.shape, a.stride(), a.storage_offset())
1829
1830
1831def _view_of_aten(a: Tensor) -> Tensor:
1832    return a.view(a.shape)
1833
1834
# Docstring attached to the view_of prim through doc= below.
_view_of_doc = """
    Creates a view of the tensor.
    """

# Returns an alias of `a` (Tensor(a) in the schema), registered as a VIEW.
view_of = _make_prim(
    schema="view_of(Tensor(a) a) -> Tensor(a)",
    meta=_view_of_meta,
    impl_aten=_view_of_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_view_of_doc,
)
1846
1847
1848def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
1849    return a.view(dtype)
1850
1851
1852def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
1853    return a.view(dtype)
1854
1855
# Docstring attached to the view_element_type prim through doc= below.
_view_element_type_doc = """
    Creates a view of the tensor with a different dtype.
    """

# Note: the op name registered in the schema is "view_of_dtype", which differs
# from the Python name view_element_type.
view_element_type = _make_prim(
    schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor(a)",
    meta=_view_element_type_meta,
    impl_aten=_view_element_type_aten,
    return_type=RETURN_TYPE.VIEW,
    doc=_view_element_type_doc,
)
1867
1868#
1869# Functionalized view mutations
1870#
1871
1872
1873def _as_strided_scatter_meta(
1874    input: TensorLikeType,
1875    src: TensorLikeType,
1876    size: ShapeType,
1877    stride: StrideType,
1878    storage_offset: int,
1879) -> TensorLikeType:
1880    utils.validate_shape(size)
1881    utils.validate_strides(stride)
1882
1883    required_size = utils.compute_required_storage_length(size, stride, storage_offset)
1884    torch._check(
1885        input.numel() >= required_size,
1886        lambda: (
1887            f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
1888            f" and itemsize {input.element_size()} requiring a storage size of "
1889            f"{required_size * input.element_size()} are out of bounds "
1890            f"for storage of size {input.numel() * input.element_size()}"
1891        ),
1892    )
1893    torch._check(
1894        utils.is_same_shape(src.shape, size),
1895        lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
1896    )
1897
1898    return utils.clone_preserve_strides(input)
1899
1900
# Docstring attached to the as_strided_scatter prim through doc= below.
_as_strided_scatter_doc = """
    Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
    ``out.as_strided(size, stride, storage_offset).copy_(src)``.
"""

# Functional (non-aliasing) counterpart of an as_strided write: registered
# with RETURN_TYPE.NEW; the eager implementation is torch.as_strided_scatter.
as_strided_scatter = _make_prim(
    schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
    meta=_as_strided_scatter_meta,
    impl_aten=torch.as_strided_scatter,
    return_type=RETURN_TYPE.NEW,
    doc=_as_strided_scatter_doc,
)
1913
1914
1915#
1916# Shape operations
1917#
1918
1919
def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
    """Meta function for ``prims.collapse``: a new uninitialized tensor whose
    shape is ``a.shape`` with dims start..end (inclusive) merged."""
    _validate_collapse_args(a, start, end)
    # Zero-dim inputs are handled by _collapsed_shape, which treats them as (1,).
    return a.new_empty(_collapsed_shape(a.shape, start, end))
1925
1926
def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
    """Eager ``prims.collapse``: copies ``a`` into a new tensor with dims
    start..end (inclusive) merged."""
    result = a.new_empty(_collapsed_shape(a.shape, start, end))
    # Copy through a view of the result shaped like `a`; no_grad keeps this
    # in-place copy from being recorded by autograd.
    with torch.no_grad():
        result.view_as(a).copy_(a)
    return result
1933
1934
# Docstring attached to the collapse prim through doc= below.
_collapse_doc = """
Collapse a span of neighboring dimensions into one.

See collapse_view for the corresponding view operation.
"""
# Copying counterpart of collapse_view: returns a new tensor (RETURN_TYPE.NEW).
collapse = _make_prim(
    schema="collapse(Tensor a, int start, int end) -> Tensor",
    meta=_collapse_meta,
    impl_aten=_collapse_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_collapse_doc,
)
1947
1948
1949# TODO: review stride logic
1950# NB: unlike torch.cat, this is more strict about empty tensors and dim is
1951# never negative
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    """Meta function for ``prims.cat``.

    All tensors must have the same rank and agree in every dimension except
    ``dim``. The result has the summed length along ``dim`` and contiguous
    strides.
    """
    assert dim >= 0
    reference = tensors[0].shape
    total = 0
    for position, t in enumerate(tensors):
        assert len(reference) == len(t.shape)
        for axis, (expected, actual) in enumerate(zip(reference, t.shape)):
            if axis == dim:
                total = total + actual
                continue
            torch._check(
                actual == expected,
                lambda: f"Sizes of tensors must match except in dimension {dim}. "
                f"Expected {expected} but got {actual} for tensor number "
                f"{position} in the list",
            )

    out_shape = list(reference)
    out_shape[dim] = total
    return TensorMeta(
        tensors[0],
        shape=out_shape,
        strides=utils.make_contiguous_strides_for(out_shape),
    )
1977
1978
1979def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
1980    return torch.cat(tensors, dim)
1981
1982
# Docstring attached to the cat prim through doc= below.
_cat_doc = """
  Concatenates tensors along the specified dimension.

  The tensors' shapes must have the same rank and same length for other dimensions.
  """

# Returns a new tensor (RETURN_TYPE.NEW); dim is a plain int in the schema.
cat = _make_prim(
    schema="cat(Tensor[] tensors, int dim) -> Tensor",
    meta=_cat_meta,
    impl_aten=_cat_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_cat_doc,
)
1996
1997
1998def _reshape_meta(a: TensorLikeType, shape: ShapeType):
1999    assert isinstance(a, TensorLike)
2000    utils.validate_shape(shape)
2001
2002    # Validates the tensor and the requested shape have the
2003    # same number of elements
2004    numel = reduce(operator.mul, shape)
2005    if numel != a.numel():
2006        msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!"
2007        raise ValueError(msg)
2008
2009    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
2010
2011
2012def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
2013    return a.reshape(shape).contiguous().clone()
2014
2015
# Docstring attached to the reshape prim through doc= below.
_reshape_doc = """
  Creates a contiguous tensor with the specified shape
  containing a copy of the data in a.
  """
# Unlike the view prims above, reshape returns a new tensor (RETURN_TYPE.NEW).
reshape = _make_prim(
    schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
    meta=_reshape_meta,
    impl_aten=_reshape_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_reshape_doc,
)
2027
2028
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """
    Meta function for ``prims.rev``.

    After validating ``dims`` against ``a``'s rank, returns an uninitialized
    tensor shaped like ``a`` (memory format preserved).
    """
    utils.validate_dimension_indices(a.ndim, dims)
    reversed_like = torch.empty_like(a, memory_format=torch.preserve_format)
    return reversed_like
2032
2033
# Docstring attached to the ``prims.rev`` operator registered below.
_rev_doc = """
    Reverses the order of elements along the given dimensions.
    """

# Registers ``prims.rev``; the eager implementation is ``torch.flip``.
rev = _make_prim(
    schema="rev(Tensor a, int[] dims) -> Tensor",
    meta=_rev_meta,
    impl_aten=torch.flip,
    return_type=RETURN_TYPE.NEW,
    doc=_rev_doc,
)
2045
2046#
2047# Conditional prims
2048#
2049
2050
def _where_meta(
    pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
    """
    Meta function for ``prims.where``.

    Delegates to the shared elementwise meta, with ``a`` and ``b`` as the
    type-promoted operands and ``pred`` passed as a fixed-dtype argument.
    """
    promotion_kind = ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    # pred goes through args_with_fixed_dtypes, so presumably it is excluded
    # from the promotion of a and b (confirm in _prim_elementwise_meta).
    fixed_dtype_args = (pred,)
    return _prim_elementwise_meta(
        a,
        b,
        type_promotion=promotion_kind,
        args_with_fixed_dtypes=fixed_dtype_args,
    )
2060
2061
# Docstring attached to the ``prims.where`` operator registered below.
_where_doc = """
  Selects elements from a and b according to pred.

  Where pred is true the result contains the element from a, and
  where pred is false the result contains the element from b.
  """

# Registers ``prims.where``; the eager implementation is ``torch.where``.
where = _make_prim(
    schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
    meta=_where_meta,
    impl_aten=torch.where,
    return_type=RETURN_TYPE.NEW,
    doc=_where_doc,
)
2076
2077
2078#
2079# Type conversions
2080#
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    """
    Meta function for ``prims.convert_element_type``.

    The result is described with ``dtype``. Its strides are ``a``'s own strides
    when ``a`` is non-overlapping and dense; otherwise they are the strides an
    elementwise op would produce for ``a``.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(dtype, torch.dtype)

    keeps_layout = torch._prims_common.is_non_overlapping_and_dense(a)
    strides = a.stride() if keeps_layout else utils.compute_elementwise_output_strides(a)
    return TensorMeta(a, strides=strides, dtype=dtype)
2093
2094
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
    """
    Eager implementation of ``prims.convert_element_type``.

    Allocates a tensor like ``a`` (same device) with the requested ``dtype``
    and copies ``a`` into it with ``copy_to`` under ``torch.no_grad()``.
    ``requires_grad`` is carried over from ``a`` only when ``dtype`` is a
    dtype for which ``utils.is_grad_dtype`` holds.
    """
    if utils.is_grad_dtype(dtype):
        # TODO: update meta objects so this can be acquired directly
        try:
            requires_grad = a.requires_grad
        except Exception:
            # Best effort: objects that cannot report requires_grad are
            # treated as not requiring it.
            requires_grad = False
    else:
        requires_grad = False

    result = torch.empty_like(
        a, device=a.device, dtype=dtype, requires_grad=requires_grad
    )
    # The fill itself must not be recorded by autograd; result is a fresh leaf.
    with torch.no_grad():
        return copy_to(result, a)
2111
2112
# Docstring attached to the ``prims.convert_element_type`` operator below.
_convert_element_type_doc = """
  Creates a copy of a tensor with the given dtype.
  """

# Registers ``prims.convert_element_type`` and tags it as pointwise.
convert_element_type = _make_prim(
    schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
    meta=_convert_element_type_meta,
    impl_aten=_convert_element_type_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_convert_element_type_doc,
    tags=(torch.Tag.pointwise,),
)
2125
2126
def _device_put_meta(
    a: TensorLikeType, device: Union[str, torch.device]
) -> TensorLikeType:
    """
    Meta function for ``prims.device_put``.

    Builds a TensorMeta from ``a`` whose device is the canonicalized form of
    ``device`` (the remaining metadata presumably comes from ``a``).
    """
    assert isinstance(a, TensorLike)
    assert isinstance(device, (str, torch.device))

    target_device = utils.canonicalize_device(device)
    return TensorMeta(a, device=target_device)
2134
2135
def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
    """
    Eager implementation of ``prims.device_put``.

    Returns a copy of ``a`` on ``device``. ``copy=True`` is required because
    ``Tensor.to`` returns ``a`` itself when it is already on ``device``. The
    prim is declared to return a new tensor, and its doc promises a copy.
    """
    return a.to(device, copy=True)
2138
2139
# Docstring attached to the ``prims.device_put`` operator registered below.
_device_put_doc = """
  Creates a copy of a tensor on the given device.
  """

# Registers ``prims.device_put``: meta ``_device_put_meta``, eager
# implementation ``_device_put_aten``.
device_put = _make_prim(
    schema="device_put(Tensor a, Device device) -> Tensor",
    meta=_device_put_meta,
    impl_aten=_device_put_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_device_put_doc,
)
2151
2152
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
    """
    Meta function for ``prims.item``.

    Returns a TensorMeta built from a placeholder Python number whose type
    corresponds to ``a.dtype``. The placeholder value -1 is presumably
    arbitrary.
    """
    python_type = utils.dtype_to_type(a.dtype)
    placeholder = python_type(-1)
    return TensorMeta(placeholder)
2158
2159
# Docstring attached to the ``prims.item`` operator registered below.
_item_doc = """
    Converts a tensor with one element to a Python number.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
# The eager implementation is ``torch.Tensor.item``; the schema returns a Scalar.
item = _make_prim(
    schema="item(Tensor a) -> Scalar",
    meta=_item_meta,
    impl_aten=torch.Tensor.item,
    return_type=RETURN_TYPE.NEW,
    doc=_item_doc,
)
2174
2175
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """
    Meta function for ``prims.maximum_value``.

    Returns a TensorMeta built from a placeholder Python number of the type
    that corresponds to ``dtype`` (the value -1 is presumably arbitrary).
    """
    python_type = utils.dtype_to_type(dtype)
    return TensorMeta(python_type(-1))
2181
2182
def _maximum_value_aten(dtype: torch.dtype):
    """
    Return the largest finite value representable by ``dtype``.

    ``True`` for ``torch.bool``; ``torch.finfo(dtype).max`` for floating-point
    and complex dtypes; ``torch.iinfo(dtype).max`` otherwise.
    """
    if dtype == torch.bool:
        return True
    float_like = dtype.is_floating_point or dtype.is_complex
    limits = torch.finfo(dtype) if float_like else torch.iinfo(dtype)
    return limits.max
2190
2191
# Docstring attached to the ``prims.maximum_value`` operator registered below.
_maximum_value_doc = """
    Return the maximum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
# NOTE(review): _maximum_value_aten returns True for torch.bool, so the FIXME
# above may be stale for this prim -- confirm.
maximum_value = _make_prim(
    schema="maximum_value(ScalarType dtype) -> Scalar",
    meta=_maximum_value_meta,
    impl_aten=_maximum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_maximum_value_doc,
)
2206
2207
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """
    Meta function for ``prims.minimum_value``.

    Returns a TensorMeta built from a placeholder Python number of the type
    that corresponds to ``dtype`` (the value -1 is presumably arbitrary).
    """
    python_type = utils.dtype_to_type(dtype)
    return TensorMeta(python_type(-1))
2213
2214
def _minimum_value_aten(dtype: torch.dtype):
    """
    Return the smallest finite value representable by ``dtype``.

    ``False`` for ``torch.bool``; ``torch.finfo(dtype).min`` for floating-point
    and complex dtypes; ``torch.iinfo(dtype).min`` otherwise.
    """
    if dtype == torch.bool:
        return False
    float_like = dtype.is_floating_point or dtype.is_complex
    limits = torch.finfo(dtype) if float_like else torch.iinfo(dtype)
    return limits.min
2222
2223
# Docstring attached to the ``prims.minimum_value`` operator registered below.
_minimum_value_doc = """
    Return the minimum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
# NOTE(review): _minimum_value_aten returns False for torch.bool, so the FIXME
# above may be stale for this prim -- confirm.
minimum_value = _make_prim(
    schema="minimum_value(ScalarType dtype) -> Scalar",
    meta=_minimum_value_meta,
    impl_aten=_minimum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_minimum_value_doc,
)
2238
2239#
2240# Inplace operators
2241#
2242
2243
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    """
    Meta function for ``prims.copy_to``.

    Checks that source ``b`` and destination ``a`` hold the same number of
    elements and returns the destination ``a`` itself.

    Raises:
        RuntimeError: if the element counts differ.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # TODO: validating that b's dtype can be cast safely to a's dtype belongs
    # on the reference, not on this prim.

    dst_numel = a.numel()
    src_numel = b.numel()
    if dst_numel != src_numel:
        raise RuntimeError(
            f"Attempting to copy {src_numel} elements to a tensor with {dst_numel} elements!"
        )
    return a
2261
2262
def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
    """Eager implementation of ``prims.copy_to``: writes ``b`` into ``a`` in place and returns ``a``."""
    a.copy_(b)
    return a
2265
2266
# Docstring attached to the ``prims.copy_to`` operator registered below.
_copy_to_doc = """
  Copies the data in b to a and returns the modified a.
  """

# TODO: Remove safe casting and implement on reference instead
# The schema marks ``a`` as mutated and aliased by the output (``Tensor(a!)``),
# matching RETURN_TYPE.INPLACE.
copy_to = _make_prim(
    schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
    meta=_copy_to_meta,
    impl_aten=_copy_to_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_copy_to_doc,
    register_conj_neg_fallthrough=True,
)
2280
2281
def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
    """
    Meta function for ``prims.copy_strided``.

    Returns an uninitialized tensor with ``a``'s shape, dtype, layout, device
    and requires_grad, but with the given ``stride``.
    """
    assert isinstance(a, TensorLike)
    like_a = {
        "dtype": a.dtype,
        "layout": a.layout,
        "device": a.device,
        "requires_grad": a.requires_grad,
    }
    return torch.empty_strided(a.shape, stride, **like_a)
2292
2293
def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
    """
    Eager implementation of ``prims.copy_strided``.

    Allocates a tensor with ``a``'s size, dtype, layout, device and
    requires_grad but with the given ``stride``, copies ``a`` into it and
    returns it.
    """
    out = torch.empty_strided(
        a.size(),
        stride=stride,
        dtype=a.dtype,
        layout=a.layout,
        device=a.device,
        requires_grad=a.requires_grad,
    )
    # ``out`` is a fresh leaf. When it requires grad, an in-place copy_ in
    # grad mode raises ("a leaf Variable that requires grad is being used in
    # an in-place operation"). So fill it without recording, as
    # _convert_element_type_aten and _normal_aten in this file do.
    with torch.no_grad():
        out.copy_(a)
    return out
2305
2306
# Docstring attached to the ``prims.copy_strided`` operator registered below.
_copy_strided_doc = """
  Copies the data in a to a new tensor that has the same shape as a but the given strides.
  """


# Registers ``prims.copy_strided``: meta ``_copy_strided_meta``, eager
# implementation ``_copy_strided_aten``.
copy_strided = _make_prim(
    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
    meta=_copy_strided_meta,
    impl_aten=_copy_strided_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_copy_strided_doc,
)
2319
2320
def _resize_meta(a: TensorLikeType, shape: ShapeType):
    """Meta function for ``prims.resize``: resizes ``a`` in place via ``resize_`` and returns it."""
    resized = a.resize_(shape)
    return resized
2323
2324
def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
    """Eager implementation of ``prims.resize``: resizes ``a`` in place via ``resize_`` and returns it."""
    resized = a.resize_(shape)
    return resized
2327
2328
# Docstring attached to the ``prims.resize`` operator registered below.
_resize_doc = """
  Gives a tensor with no elements a new shape, returning the modified tensor.

  The tensor's strides are contiguous and its values are uninitialized.
  """

# TODO: review support arbitrary resizes
# The schema marks ``a`` as mutated and aliased by the output (``Tensor(a!)``).
resize = _make_prim(
    schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
    meta=_resize_meta,
    impl_aten=_resize_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_resize_doc,
)
2343
2344
def _reduction_meta(inp, dims, *, output_dtype=None):
    """
    Meta function for single-output reduction prims.

    The result has the reduced shape from
    ``utils.compute_reduction_output_shape``, contiguous strides, dtype
    ``output_dtype`` (``inp.dtype`` when omitted) and ``inp``'s device.

    NOTE: the original author flags this stride logic as incorrect; it always
    produces contiguous strides regardless of ``inp``'s layout.
    """
    assert isinstance(inp, TensorLike)
    reduced_shape = utils.compute_reduction_output_shape(inp.shape, dims)
    result_dtype = inp.dtype if output_dtype is None else output_dtype
    return TensorMeta(
        shape=reduced_shape,
        strides=utils.make_contiguous_strides_for(reduced_shape),
        dtype=result_dtype,
        device=inp.device,
    )
2360
2361
def _var_reduction_meta(inp, dims, correction):
    """
    Meta function for variance-style reduction prims.

    Same as ``_reduction_meta``, except that complex inputs produce the
    corresponding real dtype. ``correction`` does not affect the metadata.
    """
    result_dtype = inp.dtype
    if utils.is_complex_dtype(result_dtype):
        result_dtype = utils.corresponding_real_dtype(result_dtype)
    return _reduction_meta(inp, dims, output_dtype=result_dtype)
2368
2369
# Docstrings attached to the reduction prims registered below.
_sum_doc = """
    Computes the sum of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_xor_sum_doc = """
    Computes the xor sum of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_prod_doc = """
    Computes the product of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_amax_doc = """
    Computes the maximum value of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
_amin_doc = """
    Computes the minimum value of elements in the input tensor over the list of dimensions
    specified in the dim argument
    """
# The var prim takes a correction argument that defaults to 1, so it is not
# always the biased estimator.
_var_doc = """
    Computes the variance of elements in the input tensor over the list of dimensions
    specified in the dim argument, using the given correction (correction=1 applies
    Bessel's correction and gives the unbiased sample variance; correction=0 gives the
    biased population variance)
    """
2393
2394
def _make_reduction_prim(name: str, impl_aten, doc):
    """
    Create and register a single-output reduction prim named ``name``.

    The prim takes ``(inp, dims, *, output_dtype=None)``, uses
    ``_reduction_meta`` for metadata and ``impl_aten`` as its eager
    implementation, and returns a new tensor.
    """
    reduction_schema = (
        f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor"
    )
    return _make_prim(
        schema=reduction_schema,
        meta=_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )
2404
2405
def _make_var_reduction_prim(name: str, impl_aten, doc):
    """
    Create and register a variance-style reduction prim named ``name``.

    Like ``_make_reduction_prim``, but the schema adds a ``correction``
    argument (default 1) and metadata comes from ``_var_reduction_meta``.
    That meta maps complex inputs to the corresponding real dtype.
    """
    var_schema = (
        f"{name}(Tensor inp, int[]? dims, float? correction=1, *, ScalarType? output_dtype=None) -> Tensor"
    )
    return _make_prim(
        schema=var_schema,
        meta=_var_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )
2415
2416
# Registers ``prims.sum`` with ``torch.sum`` as the eager implementation.
# NOTE: binding ``sum`` at module level shadows the builtin ``sum`` for the
# rest of this module.
sum = _make_reduction_prim(
    name="sum",
    impl_aten=torch.sum,
    doc=_sum_doc,
)
2422
2423
def _xor_sum_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """
    Eager implementation of ``prims.xor_sum``, which does not exist.

    ``output_dtype`` is accepted because the reduction-prim schema declares
    it, so a call that passes it hits the intended NotImplementedError rather
    than an unrelated TypeError. ``dtype`` is kept for backward compatibility.

    Raises:
        NotImplementedError: always; xor_sum is only implemented with inductor.
    """
    raise NotImplementedError("xor_sum only implemented with inductor")
2431
2432
# Registers ``prims.xor_sum``; its eager implementation always raises
# NotImplementedError (inductor only).
xor_sum = _make_reduction_prim(
    name="xor_sum",
    impl_aten=_xor_sum_aten,
    doc=_xor_sum_doc,
)
2438
2439
def _prod_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
    output_dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """
    Eager implementation of ``prims.prod``.

    Args:
        inp: tensor to reduce.
        dims: non-negative dimensions to reduce over. ``None`` reduces over
            all dimensions; an empty sequence returns a copy of ``inp``.
        dtype: dtype forwarded to ``torch.prod``.
        output_dtype: the name the prim schema uses for the same option; used
            when ``dtype`` is not given.
    """
    if dtype is None:
        dtype = output_dtype
    if dims is None:
        # torch.prod has no overload that takes an optional dim, so a full
        # reduction must omit the dim argument entirely.
        return torch.prod(inp, dtype=dtype)
    if len(dims) == 0:
        return inp.clone() if dtype is None else inp.to(dtype, copy=True)
    # torch.prod reduces a single dim per call. Going from the highest dim
    # down keeps the remaining (lower) dim indices valid after each step.
    for d in sorted(dims, reverse=True):
        assert d >= 0
        inp = torch.prod(inp, d, dtype=dtype)
    return inp
2455
2456
# Registers ``prims.prod`` with ``_prod_aten`` as the eager implementation.
prod = _make_reduction_prim(
    name="prod",
    impl_aten=_prod_aten,
    doc=_prod_doc,
)
2462
2463
# torch.var, but correction is not kwarg-only
def torch_var(input, dim=None, correction=1, **kwargs):
    """
    Adapter around ``torch.var`` that also accepts ``dim`` and ``correction``
    positionally, so it can serve as the eager implementation of ``prims.var``.
    Any extra keyword arguments are forwarded unchanged.
    """
    variance = torch.var(input, dim=dim, correction=correction, **kwargs)
    return variance
2467
2468
# Registers ``prims.var``; ``torch_var`` adapts ``torch.var`` to accept
# ``correction`` positionally.
var = _make_var_reduction_prim(
    name="var",
    impl_aten=torch_var,
    doc=_var_doc,
)

# Registers ``prims.amax`` with ``torch.amax`` as the eager implementation.
amax = _make_reduction_prim(
    name="amax",
    impl_aten=torch.amax,
    doc=_amax_doc,
)

# Registers ``prims.amin`` with ``torch.amin`` as the eager implementation.
amin = _make_reduction_prim(
    name="amin",
    impl_aten=torch.amin,
    doc=_amin_doc,
)
2486
2487
# Docstring attached to the ``prims.iota`` operator.
_iota_doc = """
    Constructs a 1-D tensor t where ``t[i] == start + i * step``.
"""
2491
2492
2493# TODO: layout, pin_memory, memory_format
2494# TODO: model requires_grad on TensorMeta
2495def _iota_meta(
2496    length: int,
2497    *,
2498    start: int,
2499    step: int,
2500    dtype: torch.dtype,
2501    device: torch.device,
2502    requires_grad: bool,
2503) -> TensorLikeType:
2504    torch._check(
2505        utils.is_integer_dtype(dtype),
2506        lambda: "prims.iota only supports integer dtypes",
2507    )
2508    torch._check(step != 0, lambda: "step must be nonzero")
2509    return torch.empty(
2510        length,
2511        dtype=dtype,
2512        device=device,
2513        requires_grad=requires_grad,
2514    )
2515
2516
def _iota_aten(
    length: int,
    *,
    start: int,
    step: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """
    Eager implementation of ``prims.iota``.

    ``torch.arange`` over ``[start, start + length * step)`` with the given
    ``step`` yields exactly ``length`` elements ``start + i * step``.
    """
    stop = start + step * length
    factory_kwargs = {"dtype": dtype, "device": device, "requires_grad": requires_grad}
    return torch.arange(start, stop, step, **factory_kwargs)
2530
2531
# Registers ``prims.iota``: meta ``_iota_meta``, eager implementation ``_iota_aten``.
iota = _make_prim(
    schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_iota_meta,
    impl_aten=_iota_aten,
    doc=_iota_doc,
)
2539
2540
# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _empty_meta(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> TensorLikeType:
    """
    Meta function for ``prims.empty``: a contiguous tensor of ``shape``.
    ``requires_grad`` is accepted but not modelled (see TODO above).
    """
    contiguous = utils.make_contiguous_strides_for(shape)
    return TensorMeta(dtype=dtype, device=device, shape=shape, strides=contiguous)
2548
2549
def _empty_aten(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> Tensor:
    """Eager implementation of ``prims.empty`` via ``torch.empty``."""
    factory_kwargs = {"dtype": dtype, "device": device, "requires_grad": requires_grad}
    return torch.empty(shape, **factory_kwargs)
2554
2555
# Docstring attached to the ``prims.empty`` operator registered below.
_empty_doc = """
    Creates a tensor with uninitialized values and the specified shape, dtype, and device.
"""

# Registers ``prims.empty``: meta ``_empty_meta``, eager implementation ``_empty_aten``.
empty = _make_prim(
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_empty_meta,
    impl_aten=_empty_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_doc,
)
2567
2568
def _empty_strided_meta(
    shape: ShapeType,
    strides: StrideType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """
    Meta function for ``prims.empty_strided``: exactly the requested shape and
    strides. ``requires_grad`` is accepted but not modelled.
    """
    return TensorMeta(dtype=dtype, device=device, shape=shape, strides=strides)
2578
2579
# Docstring attached to the ``prims.empty_strided`` operator registered below.
_empty_strided_doc = """
    Creates a tensor with uninitialized values.
"""

# TODO: add layout, pin_memory
# The eager implementation is ``torch.empty_strided`` itself.
empty_strided = _make_prim(
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_empty_strided_meta,
    impl_aten=torch.empty_strided,
    doc=_empty_strided_doc,
)
2592
2593
2594def _empty_permuted_meta(
2595    shape: ShapeType,
2596    physical_layout: DimsSequenceType,
2597    *,
2598    dtype: torch.dtype,
2599    device: torch.device,
2600    requires_grad: bool,
2601) -> TensorLikeType:
2602    p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
2603    dim = len(shape)
2604    torch._check(
2605        len(physical_layout) == dim,
2606        lambda: (
2607            "Number of dimensions in the tensor input does not match the "
2608            f"length of the physical layout; i.e. len(size) = {dim} "
2609            f"is not equal to len(physical_layout) = {len(physical_layout)}"
2610        ),
2611    )
2612    strides = [0] * len(shape)
2613    seen_dims = set()
2614    for p, l in enumerate(physical_layout):
2615        torch._check(
2616            0 <= l < dim,
2617            lambda: (
2618                f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
2619                f"{l} at index {p}).  NB: negative dims "
2620                "not currently supported; file an issue if you want it."
2621            ),
2622        )
2623        torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
2624        strides[l] = p_strides[p]
2625        seen_dims.add(l)
2626    return TensorMeta(
2627        shape=shape,
2628        strides=strides,
2629        dtype=dtype,
2630        device=device,
2631    )
2632
2633
# Docstring attached to the ``prims.empty_permuted`` operator registered below.
_empty_permuted_doc = """
    Creates a tensor with uninitialized values according to some physical layout,
    that is guaranteed to be non-overlapping and dense.
"""

# TODO: add layout, pin_memory
# The eager implementation is ``torch.empty_permuted`` itself.
empty_permuted = _make_prim(
    schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_empty_permuted_meta,
    impl_aten=torch.empty_permuted,
    doc=_empty_permuted_doc,
)
2647
2648
def _full_meta(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """
    Meta function for ``prims.full``: a contiguous tensor of ``shape``.
    ``fill_value`` and ``requires_grad`` do not affect the metadata.
    """
    contiguous = utils.make_contiguous_strides_for(shape)
    return TensorMeta(dtype=dtype, device=device, shape=shape, strides=contiguous)
2659
2660
def _full_aten(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    """Eager implementation of ``prims.full`` via ``torch.full``."""
    factory_kwargs = {"dtype": dtype, "device": device, "requires_grad": requires_grad}
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full(shape, fill_value, **factory_kwargs)  # type: ignore[arg-type]
2673
2674
# Docstring attached to the ``prims.full`` operator registered below.
_full_doc = """
    Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
"""

# TODO: add layout
# Registers ``prims.full``: meta ``_full_meta``, eager implementation ``_full_aten``.
full = _make_prim(
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_meta,
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)
2687
2688
def _full_like_meta(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    """
    Meta function for ``prims.full_like``.

    Uses elementwise-output strides derived from ``a``. When ``a`` has no
    elements, ``a``'s own strides are used instead.
    """
    elementwise_strides = utils.compute_elementwise_output_strides(a)
    strides = a.stride() if a.numel() == 0 else elementwise_strides
    return TensorMeta(a, strides=strides, dtype=dtype, device=device)
2702
2703
def _full_like_aten(
    a: Tensor,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    """Eager implementation of ``prims.full_like`` via ``torch.full_like``."""
    factory_kwargs = {"dtype": dtype, "device": device, "requires_grad": requires_grad}
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full_like(a, fill_value, **factory_kwargs)  # type: ignore[arg-type]
2716
2717
# Docstring attached to the ``prims.full_like`` operator registered below.
_full_like_doc = """
    Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
    given tensor by default. The dtype and device settings can be overridden
    by specifying them explicitly.
"""

# Registers ``prims.full_like``: meta ``_full_like_meta``, eager
# implementation ``_full_like_aten``.
full_like = _make_prim(
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_like_meta,
    impl_aten=_full_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_like_doc,
)
2731
2732
def _scalar_tensor_meta(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> TensorLikeType:
    """
    Meta function for ``prims.scalar_tensor``: a 0-d tensor built from
    ``scalar`` with the requested dtype and device.
    """
    zero_d_shape: ShapeType = []
    return TensorMeta(
        scalar,
        shape=zero_d_shape,
        strides=utils.make_contiguous_strides_for(zero_d_shape),
        dtype=dtype,
        device=device,
    )
2742
2743
def _scalar_tensor_aten(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> Tensor:
    """
    Eager implementation of ``prims.scalar_tensor`` via ``torch.scalar_tensor``.

    Raises:
        TypeError: if ``scalar`` is complex but ``dtype`` is missing or not complex.
    """
    complex_target = dtype is not None and utils.is_complex_dtype(dtype)
    if isinstance(scalar, complex) and not complex_target:
        raise TypeError("Complex scalar requires complex tensor dtype.")
    # Note that Mypy thinks torch.scalar can't accept a complex scalar
    return torch.scalar_tensor(scalar, dtype=dtype, device=device)  # type: ignore[arg-type]
2756
2757
# Docstring attached to the ``prims.scalar_tensor`` operator registered below.
_scalar_tensor_doc = """
    Wraps a Number into a Tensor with the specified dtype and device.
"""

# TODO: add layout and pin_memory support
# Registers ``prims.scalar_tensor``: meta ``_scalar_tensor_meta``, eager
# implementation ``_scalar_tensor_aten``.
scalar_tensor = _make_prim(
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
    meta=_scalar_tensor_meta,
    impl_aten=_scalar_tensor_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_scalar_tensor_doc,
)
2770
2771
2772#
2773# Linear algebra (linalg) prims
2774#
2775
2776
def _svd_meta(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
    """
    Meta function for ``prims.svd``.

    Checks ``A`` with ``utils.check_is_matrix`` and requires a floating-point
    or complex dtype (low-precision dtypes are rejected). For ``A`` of shape
    ``(*batch, m, n)`` and ``k = min(m, n)`` it describes:

    * ``U``:  ``(*batch, m, m)`` if ``full_matrices`` else ``(*batch, m, k)``,
      strides from ``make_contiguous_strides_for(..., row_major=False)``,
      dtype of ``A``.
    * ``S``:  ``(*batch, k)``, contiguous; the real counterpart of ``A``'s
      dtype when ``A`` is complex, else ``A``'s dtype.
    * ``Vh``: ``(*batch, n, n)`` if ``full_matrices`` else ``(*batch, k, n)``,
      row-major strides when ``A`` is on a CUDA device, otherwise
      ``row_major=False`` strides; dtype of ``A``.
    """
    utils.check_is_matrix(A, "linalg.svd")
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)

    A_shape = A.shape
    batch = A_shape[:-2]
    m, n = A_shape[-2:]
    k = min(m, n)

    shape_U = batch + (m, m if full_matrices else k)
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)

    # Singular values are real even for complex inputs.
    shape_S = batch + (k,)
    strides_S = utils.make_contiguous_strides_for(shape_S)
    S = TensorMeta(
        shape=shape_S,
        strides=strides_S,
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
        device=A.device,
    )

    shape_Vh = batch + (n if full_matrices else k, n)
    # The CPU backend returns V, but the cuSolver backend returns V^H
    # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
    is_cuda = A.device.type == "cuda"
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
    # Also makes sure this is CUDA or HIP:
    # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
    # NOTE(review): this tests whether CUDA/HIP is available in the process,
    # not whether A is on a CUDA device (compare ``is_cuda`` above). A complex
    # CPU input on a CUDA-enabled machine therefore also gets a conjugated
    # Vh -- confirm this matches what the real kernels return.
    if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available():
        Vh = Vh.conj()
    return U, S, Vh
2812
2813
def _svd_aten(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Eager implementation of ``prims.svd``.

    Returns ``torch.linalg.svd``'s result (``U, S, Vh``) unchanged.
    """
    decomposition = torch.linalg.svd(A, full_matrices=full_matrices)
    return decomposition
2818
2819
# Docstring attached to the ``prims.svd`` operator registered below.
_svd_doc = """
    Returns the SVD of a matrix or batch of matrices.

    The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
"""

# Registers ``prims.svd``, a prim with three new outputs (U, S, Vh).
svd = _make_prim(
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
    meta=_svd_meta,
    impl_aten=_svd_aten,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    doc=_svd_doc,
)
2833
2834
2835#
2836# Randomness Prims
2837#
2838
2839
2840def _normal_meta(
2841    shape: ShapeType,
2842    *,
2843    mean: Union[float, complex],
2844    std: float,
2845    dtype: torch.dtype,
2846    device: torch.device,
2847    requires_grad: bool,
2848    generator: Optional[torch.Generator] = None,
2849) -> TensorLikeType:
2850    torch._check(
2851        std >= 0.0,
2852        lambda: f"expected non-negative standard deviation, but got std={std}",
2853    )
2854
2855    torch._check(
2856        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
2857        lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
2858    )
2859
2860    strides = utils.make_contiguous_strides_for(shape)
2861    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2862
2863
def _normal_aten(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """
    Eager implementation of ``prims.normal``.

    Allocates the output (with ``requires_grad`` as requested) and fills it
    in place with ``normal_`` outside of autograd.
    """
    samples = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
    with torch.no_grad():
        # NOTE: normal_ is incorrectly annotated to expect mean to be a float
        samples.normal_(mean, std, generator=generator)  # type: ignore[arg-type]
    return samples
2879
2880
# Docstring attached to the ``prims.normal`` operator registered below.
# _normal_meta accepts both floating-point and complex dtypes.
_normal_doc = """
    Constructs a tensor filled with values drawn from a normal distribution with the specified mean
    and standard deviation.

    Only supports floating-point and complex types.
"""

# Registers ``prims.normal``: meta ``_normal_meta``, eager implementation ``_normal_aten``.
normal = _make_prim(
    schema=(
        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"  # noqa: B950
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_normal_meta,
    impl_aten=_normal_aten,
    doc=_normal_doc,
)
2897
2898
def _uniform_meta(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> TensorLikeType:
    """Meta function for ``prims.uniform``: a contiguous tensor of ``shape``."""
    contiguous = utils.make_contiguous_strides_for(shape)
    return TensorMeta(dtype=dtype, device=device, shape=shape, strides=contiguous)
2910
2911
def _uniform_aten(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """
    Eager implementation of ``prims.uniform``: a new tensor filled in place by
    ``uniform_(low, high)`` (which returns the tensor itself).
    """
    return torch.empty(shape, dtype=dtype, device=device).uniform_(
        low, high, generator=generator
    )
2924
2925
# Docstring attached to the ``prims.uniform`` operator registered below.
_uniform_doc = """
    Constructs a tensor filled with values drawn uniformly from low to high.
"""

# TODO: we should more seriously review randomness modeling and prims
# Registered as ``prims.uniform`` but bound to the private name ``_uniform_helper``.
_uniform_helper = _make_prim(
    schema=(
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_uniform_meta,
    impl_aten=_uniform_aten,
    doc=_uniform_doc,
)
2940
2941#
2942# FFT prims
2943#
2944
2945
def _fft_r2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    """
    Meta function for ``prims.fft_r2c``.

    Canonicalizes ``dim`` and rejects repeats. The output has ``input``'s
    shape and the complex counterpart of its dtype; with ``onesided`` the last
    transformed dim shrinks to ``size // 2 + 1``.
    """
    canonical_dims = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(canonical_dims)

    out_shape = list(input.shape)
    if onesided:
        last = canonical_dims[-1]
        out_shape[last] = out_shape[last] // 2 + 1

    return TensorMeta(
        shape=out_shape,
        strides=utils.make_contiguous_strides_for(out_shape),
        dtype=utils.corresponding_complex_dtype(input.dtype),
        device=input.device,
    )
2963
2964
def _fft_r2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    """Eager implementation of ``prims.fft_r2c``: an unnormalized real-to-complex FFT."""
    no_normalization = 0
    return torch._fft_r2c(input, dim, no_normalization, onesided)
2973
2974
# Docstring attached to the ``prims.fft_r2c`` operator registered below.
_fft_r2c_doc = """
    Performs a real to complex Fast Fourier Transform
"""


# Registers ``prims.fft_r2c``: meta ``_fft_r2c_meta``, eager implementation ``_fft_r2c_aten``.
fft_r2c = _make_prim(
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
    meta=_fft_r2c_meta,
    impl_aten=_fft_r2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_r2c_doc,
)
2987
2988
def _fft_c2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    """
    Meta function for ``prims.fft_c2c``.

    Canonicalizes ``dim`` and rejects repeats; the output keeps ``input``'s
    shape, dtype and device with contiguous strides. ``forward`` does not
    affect the metadata.
    """
    canonical_dims = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(canonical_dims)

    out_shape = input.shape
    return TensorMeta(
        shape=out_shape,
        strides=utils.make_contiguous_strides_for(out_shape),
        dtype=input.dtype,
        device=input.device,
    )
3003
3004
def _fft_c2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    """Eager implementation of ``prims.fft_c2c``: an unnormalized forward or inverse complex FFT."""
    no_normalization = 0
    return torch._fft_c2c(input, dim, no_normalization, forward)
3013
3014
# Docstring attached to the ``prims.fft_c2c`` operator registered below.
_fft_c2c_doc = """
    Performs either a Fast Fourier Transform, or its inverse
"""


# Registers ``prims.fft_c2c``: meta ``_fft_c2c_meta``, eager implementation ``_fft_c2c_aten``.
fft_c2c = _make_prim(
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
    meta=_fft_c2c_meta,
    impl_aten=_fft_c2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2c_doc,
)
3027
3028
def _fft_c2r_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    """
    Meta function for ``prims.fft_c2r``.

    Canonicalizes ``dim`` and rejects repeats. The output has ``input``'s
    shape, except that the last transformed dim becomes ``last_dim_size``, and
    it uses the real counterpart of ``input``'s dtype.
    """
    canonical_dims = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(canonical_dims)

    out_shape = list(input.shape)
    out_shape[canonical_dims[-1]] = last_dim_size
    return TensorMeta(
        shape=out_shape,
        strides=utils.make_contiguous_strides_for(out_shape),
        dtype=utils.corresponding_real_dtype(input.dtype),
        device=input.device,
    )
3043
3044
def _fft_c2r_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    """Eager implementation of ``prims.fft_c2r``: an unnormalized complex-to-real inverse FFT."""
    no_normalization = 0
    return torch._fft_c2r(input, dim, no_normalization, last_dim_size)
3053
3054
# Docstring attached to the ``prims.fft_c2r`` operator registered below.
_fft_c2r_doc = """
    Performs a complex to real Inverse Fast Fourier Transform
"""


# Registers ``prims.fft_c2r``: meta ``_fft_c2r_meta``, eager implementation ``_fft_c2r_aten``.
fft_c2r = _make_prim(
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
    meta=_fft_c2r_meta,
    impl_aten=_fft_c2r_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2r_doc,
)
3067
3068
3069def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
3070    torch._check(
3071        self.dtype.is_floating_point,
3072        lambda: "torch.frexp() only supports floating-point dtypes",
3073    )
3074    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)
3075
3076
# Registers ``prims.frexp``; the eager implementation is ``torch.frexp``.
frexp = _make_prim(
    schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
    meta=_frexp_meta,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    impl_aten=torch.frexp,
    doc=(
        "Decomposes a floating-point tensor into a mantissa tensor and an int32 "
        "exponent tensor such that ``self == mantissa * 2 ** exponent``."
    ),
)
3084
3085
3086def _make_token_aten() -> TensorLikeType:
3087    return new_token_tensor()
3088
3089
# Registers ``prims._make_token``; the same function serves as meta and eager
# implementation.
_make_token = _make_prim(
    schema="_make_token() -> Tensor",
    meta=_make_token_aten,
    return_type=RETURN_TYPE.NEW,
    impl_aten=_make_token_aten,
    doc="Creates a token used for keeping track of side effects.",
)
3097
3098
def _sink_tokens_aten(tokens) -> None:
    """Consume ``tokens``; nothing needs computing, so this is intentionally a no-op."""
    return None
3101
3102
# Registers ``prims._sink_tokens``, which returns nothing (RETURN_TYPE.NONE);
# the same no-op function serves as meta and eager implementation.
_sink_tokens = _make_prim(
    schema="_sink_tokens(Tensor[] tokens) -> ()",
    meta=_sink_tokens_aten,
    return_type=RETURN_TYPE.NONE,
    impl_aten=_sink_tokens_aten,
    doc="Sink all of the tokens which were previously used for keeping track of side effects.",
)
3110
3111
# Register the additional prims provided by torch._prims.rng_prims and
# torch._prims.debug_prims (imported at the top of this module).
register_rng_prims()
register_debug_prims()
3114