1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import operator
5import warnings
6from contextlib import nullcontext
7from enum import Enum
8from functools import reduce
9from typing import (
10    Any,
11    Callable,
12    cast,
13    List,
14    NamedTuple,
15    Optional,
16    overload,
17    Sequence,
18    Tuple,
19    Type,
20    TYPE_CHECKING,
21    Union,
22)
23from typing_extensions import deprecated, TypeAlias
24
25
26if TYPE_CHECKING:
27    # Import the following modules during type checking to enable code intelligence features,
28    # such as auto-completion in tools like pylance, even when these modules are not explicitly
29    # imported in user code.
30
31    import sympy
32
33import torch
34from torch import sym_float, sym_int, sym_max
35
36
37ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
38StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
39DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
40DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
41# TODO: Type[torch.SymInt], Type[torch.SymFloat]
42NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
43# TODO: This needs a lot more type annotations
44# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
45NumberType: TypeAlias = Union[bool, int, float, complex]
46RealNumberType: TypeAlias = Union[bool, int, float]
47
48Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool)
49# I don't call it Integral because numbers.Integral includes bool, but IntLike
50# does not
51Dim = int
52IntLike = (int, torch.SymInt)
53FloatLike = (float, torch.SymFloat)
54BoolLike = (bool, torch.SymBool)
55IntWithoutSymInt = int
56FloatWithoutSymFloat = float
57DeviceLikeType: TypeAlias = Union[str, torch.device, int]
58Tensor = torch.Tensor
59
60
61torch_function_passthrough = {
62    torch.device,
63    torch.sym_not,
64    torch.sym_float,
65    torch.sym_int,
66    torch.sym_max,
67    torch.sym_min,
68    torch._sym_sqrt,  # type: ignore[attr-defined]
69    torch.sym_ite,
70    torch.Tensor.dim,
71    torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
72    torch.Tensor.numel,
73    torch.Tensor.size,
74    torch.Tensor.storage_offset,
75    torch.Tensor.stride,
76    torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
77    torch.Tensor.is_sparse.__get__,  # type: ignore[attr-defined]
78    torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
79    torch.Tensor.device.__get__,  # type: ignore[attr-defined]
80    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
81    torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
82    torch.Tensor.is_contiguous,
83    # For TorchRefsMode only
84    torch.Tensor.__format__,
85    torch.Tensor.__repr__,
87    torch.Tensor.__getitem__,
88}
89
90
91TensorLikeType = torch.Tensor
92TensorLike = torch.Tensor
93TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
94TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]
95
96CustomOutParamAnnotation = "__custom_out_param__"
97
98
99def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
100    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
101
102    if len(a) != len(b):
103        return False
104
105    for x, y in zip(a, b):
106        if allow_rhs_unbacked:
107            # TODO: We should check that the symbols are consistent
108            # with each other
109            if isinstance(y, torch.SymInt):
110                continue
111        # NB: Naively, you would not expect to have to do an oblivious guard
112        # here because there is seemingly no broadcasting here, but in fact we
113        # use this in some situations to determine if we need to do an expand
114        # on the tensor because they don't line up, so you can definitely end
115        # up trying to prove u0 != 1 in this situation.  See
116        # python test/test_proxy_tensor.py -k test_cumsum_unbacked
117        if guard_size_oblivious(x != y):
118            return False
119
120    return True
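
# A minimal usage sketch (illustrative only, not part of the upstream module);
# the helper name `_example_same_shape` is hypothetical.
def _example_same_shape():
    a = torch.empty(2, 3)
    b = torch.empty(2, 3)
    # Equal ranks and equal lengths compare as the same shape
    assert same_shape(a.shape, b.shape)
    # A rank mismatch is never the same shape
    assert not same_shape(a.shape, (2, 3, 1))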
121
122
123def _maybe_get_pytype(t):
124    if t is torch.SymFloat:
125        return float
126    elif t is torch.SymInt:
127        return int
128    elif t is torch.SymBool:
129        return bool
130    else:
131        return t
132
133
134# TODO: look at using torch.testing.assert_close instead with an option
135#   to just compare metadata
136def compare_tensor_meta(
137    a: TensorLikeType,
138    b: TensorLikeType,
139    check_strides=False,
140    *,
141    allow_rhs_unbacked=False,
142    check_conj=True,
143):
144    """
    Checks that two tensor-likes have the same shape,
    dtype and device, as well as their negative bit and
    (when check_conj=True) their conjugate bit.

    Strides and storage offsets are compared only when
    check_strides=True.
150    """
151    assert isinstance(a, TensorLike)
152    assert isinstance(b, TensorLike)
153
154    if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
155        msg = f"Shapes {a.shape} and {b.shape} are not equal!"
156        raise AssertionError(msg)
157
158    if a.dtype != b.dtype:
159        msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
160        raise AssertionError(msg)
161
162    if a.device != b.device:
163        # Handles special cuda:0 vs cuda case
164        # TODO: we should review why this happens and see about fixing it
165        if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
166            str(b.device) == "cuda:0" or str(b.device) == "cuda"
167        ):
168            pass
169        else:
170            msg = f"Devices {a.device} and {b.device} are not equal!"
171            raise AssertionError(msg)
172
173    # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
174    if check_strides:
175        same_strides, idx = check_significant_strides(a, b)
176        if not same_strides:
177            msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
178            raise RuntimeError(msg)
179
180        if a.storage_offset() != b.storage_offset():
181            msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
182            raise RuntimeError(msg)
183
184    if check_conj:
185        if a.is_conj() != b.is_conj():
186            raise RuntimeError(
187                f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
188            )
189
190    if a.is_neg() != b.is_neg():
191        raise RuntimeError(
192            f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
193        )
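
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_compare_tensor_meta` is a hypothetical helper.
def _example_compare_tensor_meta():
    # Same shape, dtype and device: passes silently
    compare_tensor_meta(torch.ones(2, 3), torch.zeros(2, 3))
    # A dtype mismatch raises an AssertionError
    try:
        compare_tensor_meta(torch.ones(2), torch.ones(2, dtype=torch.int64))
    except AssertionError:
        pass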
194
195
196def _check_strides_helper(
197    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
198) -> Tuple[bool, Optional[int]]:
199    # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
200    # See https://github.com/pytorch/pytorch/issues/77553
201    # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
202    # and for tensors with more than one element
203    if (
204        not only_cuda or a.device.type == "cuda" or b.device.type == "cuda"
205    ) and a.numel() > 0:
206        for idx in range(a.ndim):
207            check = not significant_only or a.shape[idx] > 1
208            if a.stride()[idx] != b.stride()[idx] and check:
209                return False, idx
210
211    return True, None
212
213
214def check_significant_strides(
215    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
216) -> Tuple[bool, Optional[int]]:
217    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)
218
219
220def check_all_strides(
221    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
222) -> Tuple[bool, Optional[int]]:
223    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
224
225
226# This function is equivalent to compute_contiguous() from TensorImpl.cpp
227def is_contiguous(a: TensorLikeType) -> bool:
228    """
229    Tests whether a tensor is contiguous or not.
230
    Tensors are contiguous when they have no elements,
    one element, or when their strides are "nested": each
    significant stride (strides of length-1 dimensions are ignored)
    equals the product of the lengths of the dimensions to its right.
233    """
234    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
235
236    if guard_size_oblivious(a.numel() < 2):
237        return True
238
239    expected_stride = 1
240    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
241        # Skips checking strides when a dimension has length 1
242        if guard_size_oblivious(x == 1):
243            continue
244
245        if guard_size_oblivious(y != expected_stride):
246            return False
247        expected_stride = expected_stride * x
248
249    return True
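
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_is_contiguous` is a hypothetical helper.
def _example_is_contiguous():
    t = torch.empty(2, 3)
    assert is_contiguous(t)
    # Transposing swaps the strides, so the result is no longer contiguous
    assert not is_contiguous(t.t())
    # Tensors with zero or one elements are always contiguous
    assert is_contiguous(torch.empty(0, 5))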
250
251
252# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
253def is_channels_last_contiguous_2d(a: Tensor) -> bool:
254    # NHWC or not channels last 2D contiguous
255    if a.ndim != 4:
256        return False
257
258    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
259
260    expected_stride = 1
261    for idx in (1, 3, 2, 0):
262        length = a.shape[idx]
263        if guard_size_oblivious(length == 1):
264            continue
265
266        stride = a.stride()[idx]
267        if guard_size_oblivious(stride != expected_stride):
268            return False
269
270        expected_stride *= length
271
272    return True
273
274
275def is_channels_last_contiguous_3d(a: Tensor) -> bool:
276    # NDHWC or not channels last 3D contiguous
277    if a.ndim != 5:
278        return False
279
280    expected_stride = 1
281    for idx in (1, 4, 3, 2, 0):
282        length = a.shape[idx]
283        if length == 1:
284            continue
285
286        stride = a.stride()[idx]
287        if stride != expected_stride:
288            return False
289
290        expected_stride *= length
291
292    return True
293
294
295_memory_formats = {
296    torch.contiguous_format,
297    torch.preserve_format,
298    torch.channels_last,
299    torch.channels_last_3d,
300}
301
302
303def validate_memory_format(memory_format: torch.memory_format):
304    torch._check(
305        memory_format in _memory_formats,
306        lambda: f"Received unknown memory format {memory_format}!",
307    )
308
309
310def is_contiguous_for_memory_format(  # type: ignore[return]
311    a: Tensor, *, memory_format: torch.memory_format
312) -> bool:
313    validate_memory_format(memory_format)
314
315    if memory_format == torch.contiguous_format:
316        return is_contiguous(a)
317    if memory_format == torch.channels_last:
318        return is_channels_last_contiguous_2d(a)
319    if memory_format == torch.channels_last_3d:
320        return is_channels_last_contiguous_3d(a)
321
322    torch._check(
323        False,
324        lambda: f"is_contiguous received unsupported memory format {memory_format}",
325    )
326
327
# NOTE: it is unclear how tensors with no elements should be treated for channels-last contiguity
329def is_channels_last_contiguous(a: Tensor) -> bool:
330    """
331    True when a tensor is channels-last contiguous.
332
333    This requires that:
334
      - the tensor has either 4 (NHWC) or 5 (NDHWC) dimensions
      - if we name the tensor's dimensions NCHW or NCDHW, then the stride of the
        'C' dimension (Cs) is 1 and the strides can be ordered
        Cs <= Ws <= Hs <= (Ds) <= Ns and are "nested" -- for example,
        Ws = Cs * Cl, where Cl is the length of the 'C' dimension.
341    """
342    return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
343
344
345def is_non_overlapping_and_dense(a: Tensor) -> bool:
346    """
347    True when a tensor is non-overlapping and dense.
348
349    A tensor is non-overlapping and dense when there exists a permutation of
350    its dimensions that is contiguous.
351    """
352
353    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
354
355    if a.is_sparse:
356        return False
357
358    # Short-circuits if the tensor is already contiguous or channels-last contiguous
359    if is_contiguous(a) or is_channels_last_contiguous(a):
360        return True
361
362    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
363
364    # Short-circuits for tensors of rank one, which are
365    # non-overlapping and "dense" if their stride is one
366    if a.ndim == 1:
367        return a.stride()[0] == 1
368
369    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
370    # Sorts (length, stride) pairs by stride
371    #
372    # This sort is done in a size-oblivious way, which helps if we do a
373    # comparison like 2048*u0 > u0; we just want this to return True
374    # (and not worry about what if u0 is zero).
375    class K(NamedTuple):
376        size: int
377        stride: int
378
379        def __lt__(self, other):
380            return guard_size_oblivious(self.stride < other.stride)
381
382        def __gt__(self, other):
383            return guard_size_oblivious(self.stride > other.stride)
384
385        def __le__(self, other):
386            return guard_size_oblivious(self.stride <= other.stride)
387
388        def __ge__(self, other):
389            return guard_size_oblivious(self.stride >= other.stride)
390
391        def __eq__(self, other):
392            return guard_size_oblivious(self.stride == other.stride)
393
394    lengths_and_strides = sorted(map(K, a.shape, a.stride()))
395
396    expected_stride = 1
397    for length, stride in lengths_and_strides:
398        if guard_size_oblivious(length == 1):
399            continue
400
401        if stride != expected_stride:
402            return False
403
404        expected_stride *= length
405
406    return True
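
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_is_non_overlapping_and_dense` is a hypothetical helper.
def _example_is_non_overlapping_and_dense():
    t = torch.empty(2, 3, 4)
    # A permutation of a contiguous tensor is still non-overlapping and dense
    assert is_non_overlapping_and_dense(t.permute(2, 0, 1))
    # An expanded tensor has a zero stride, so its elements overlap
    assert not is_non_overlapping_and_dense(torch.empty(1, 3).expand(4, 3))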
407
408
# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] there is incorrect: it says that
# strides will be preserved even if they are not "non overlapping and
# dense", whereas the output of elementwise operations is always given
# non overlapping and dense strides.
# This is also incorrect because it does not model TensorIterator's
# short-circuit, which can cause different strides.
417def compute_elementwise_output_logical_to_physical_perm(
418    *tensors, _skip_checks=False
419) -> List[int]:
420    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
421
422    if not _skip_checks and len(tensors) == 0:
423        msg = "Can't compute elementwise output strides for zero tensors!"
424        raise ValueError(msg)
425
426    if not _skip_checks:
427        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
428
429    # Filters the tensors to actual tensors
430    if not _skip_checks:
431        tensors = tuple(
432            a
433            for a in tensors
434            if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
435        )
436
437    # Short-circuits for CPU scalar case
438    if len(tensors) == 0:
439        return []
440
441    # Short-circuits for shapes with zero or one dimensions
442    # TODO: are these necessary?
443    ndim = tensors[0].ndim
444    if ndim == 0:
445        return []
446    if ndim == 1:
447        return [0]
448
449    # Short-circuits if contiguous or channels last, following the fake fast path.
450    # This reduces the number of guards we end up making
451    is_contiguous = True
452    is_channels_last = True
453    for t in tensors:
454        is_contiguous = is_contiguous and t.is_contiguous(
455            memory_format=torch.contiguous_format
456        )
457        is_channels_last = is_channels_last and t.is_contiguous(
458            memory_format=torch.channels_last
459        )
460
461    if is_contiguous and not is_channels_last:
462        return list(range(ndim))
463
464    if is_channels_last and not is_contiguous:
465        return [0, *list(range(2, ndim)), 1]
466
467    shape = tensors[0].shape
468
469    def should_swap(idx_a, idx_b):
470        for tensor in tensors:
471            stride_a = tensor.stride()[idx_a]
472            stride_b = tensor.stride()[idx_b]
473
474            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
475                stride_b == 0
476            ):
477                continue
478
479            if guard_size_oblivious(stride_a < stride_b):
480                return -1
481
482            if guard_size_oblivious(stride_a > stride_b):
483                return 1
484
485            # stride_a == stride_b
486            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
487                return 1
488
489        # Note: this case is hit if all strides are zero,
490        # or all strides are equal and all dimensions have the same length
491        return 0
492
493    # The "sort" order for the permutation is back-to-front, but
494    # the natural order for permutations is front-to-back.  Do the
495    # sorting back-to-front and then reverse it on output.
496    #
497    # also, note this returns the logical to physical shape permutation
498    perm = list(reversed(range(ndim)))
499
500    # insertion sort with support for ambiguous comparisons
501    for i in range(1, ndim):
502        dim1 = i
503        for dim0 in reversed(range(i)):
504            comparison = should_swap(perm[dim0], perm[dim1])
505            if comparison > 0:
506                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
507                dim1 = dim0
508            elif comparison < 0:
509                break
510
511    return list(reversed(perm))
512
513
514def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
515    """
516    Computes the output strides for elementwise operations.
517    """
518    if len(tensors) == 0:
519        msg = "Can't compute elementwise output strides for zero tensors!"
520        raise ValueError(msg)
521
522    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
523
524    # Filters the tensors to actual tensors
525    tensors = tuple(
526        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
527    )
528
529    # Short-circuits for CPU scalar case
530    if len(tensors) == 0:
531        return ()
532
533    ndim = tensors[0].ndim
534    shape = tensors[0].shape
535
536    if ndim == 0:
537        return ()
538    if ndim == 1:
539        return (1,)
540
541    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
542        *tensors, _skip_checks=True
543    )
544    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
545
546    new_strides = make_contiguous_strides_for(permuted_shape)
547    permuted_strides = apply_perm(
548        new_strides, invert_perm(logical_to_physical_perm)
549    )  # to logical
550
551    return tuple(permuted_strides)
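
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_elementwise_output_strides` is a hypothetical helper. It shows that
# the channels-last layout of the inputs is propagated to the output strides.
def _example_elementwise_output_strides():
    t = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
    assert compute_elementwise_output_logical_to_physical_perm(t, t) == [0, 2, 3, 1]
    assert compute_elementwise_output_strides(t, t) == t.stride()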
552
553
# The identity permutation of rank n is [0, 1, ..., n - 1]; e.g. [0, 1, 2] for rank 3
555def apply_perm(inp, perm):
556    ndim = len(inp)
557    permuted_inp = [-1] * ndim
558    for idx, x in enumerate(perm):
559        permuted_inp[idx] = inp[x]
560    return permuted_inp
561
562
563def invert_perm(perm):
564    ndim = len(perm)
565    new_perm = [-1] * ndim
566    for idx, x in enumerate(perm):
567        new_perm[x] = idx
568    return new_perm
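
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_perm_round_trip` is a hypothetical helper.
def _example_perm_round_trip():
    perm = [2, 0, 1]
    permuted = apply_perm(["a", "b", "c"], perm)
    assert permuted == ["c", "a", "b"]
    # Applying the inverse permutation recovers the original order
    assert apply_perm(permuted, invert_perm(perm)) == ["a", "b", "c"]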
569
570
571#
572# Common helper functions
573#
574
575
576def validate_dim_length(length: int):
577    """
578    Validates that an object represents a valid
579    dimension length.
580    """
581
582    if isinstance(length, (int, torch.SymInt)):
583        torch._check_is_size(length)
584    else:
585        # sometimes called with sympy expression by inductor
586        assert length >= 0
587
588
589def validate_shape(shape: ShapeType):
590    """
591    Validates that a sequence represents a valid shape.
592    """
593
594    assert isinstance(shape, Sequence), type(shape)
595    for l in shape:
596        validate_dim_length(l)
597
598
599def validate_strides(strides: StrideType):
600    """
601    Verifies the object specifies valid strides.
602    """
603
604    assert isinstance(strides, Sequence)
605    for stride in strides:
606        assert stride >= 0
607
608
609def validate_idx(rank: int, idx: int):
610    """
    Validates that idx is a valid index for a tensor of the given rank.
612    Assumes the index is already canonicalized.
613    """
614
615    assert isinstance(idx, Dim)
616    assert isinstance(rank, Dim)
617
    # idx == 0 is also allowed for rank 0 (scalar) tensors
    assert (idx >= 0 and idx < rank) or idx == 0
619
620
621def validate_dimension_indices(rank: int, indices: DimsSequenceType):
622    for idx in indices:
623        validate_idx(rank, idx)
624
625
626def validate_exclusive_idx(rank: int, ex_idx: int):
627    """
    Validates that ex_idx is a valid exclusive index
    for a tensor of the given rank.
630    """
631
632    assert isinstance(ex_idx, Dim)
633    assert isinstance(rank, Dim)
634    assert ex_idx > 0 and ex_idx <= rank
635
636
637# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
638# specified using negative indices. If `wrap_scalar` is true then scalar
639# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
640# idx should be in the range [-rank, rank-1].
641def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
642    if rank < 0:
643        msg = f"Rank cannot be negative but got {rank}"
644        raise IndexError(msg)
645
646    if rank == 0:
647        if not wrap_scalar:
648            msg = f"Dimension specified as {idx} but tensor has no dimensions"
649            raise IndexError(msg)
650        rank = 1
651
652    if idx >= 0 and idx < rank:
653        return idx
654
655    if idx < 0:
656        _idx = idx + rank
657    else:
658        _idx = idx
659
660    if _idx < 0 or _idx >= rank:
661        # Same error message as in aten/src/ATen/WrapDimUtils.h:49
662        msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
663        raise IndexError(msg)
664
665    return _idx
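
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_canonicalize_dim` is a hypothetical helper.
def _example_canonicalize_dim():
    assert canonicalize_dim(4, -1) == 3
    assert canonicalize_dim(4, 2) == 2
    # Scalar tensors (rank 0) accept dims in [-1, 0] when wrap_scalar=True
    assert canonicalize_dim(0, -1) == 0
    # Out-of-range dims raise an IndexError
    try:
        canonicalize_dim(4, 4)
    except IndexError:
        pass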
666
667
668# Takes a dimension or sequence of dimensions and "wraps" them,
669# mapping negative offsets to positive ones
670@overload
671def canonicalize_dims(
672    rank: int, indices: Sequence[int], wrap_scalar: bool = True
673) -> Tuple[int, ...]:
674    pass
675
676
677@overload
678def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
679    pass
680
681
682def canonicalize_dims(rank, indices, wrap_scalar=True):
683    if isinstance(indices, Dim):
684        return canonicalize_dim(rank, indices, wrap_scalar)
685
686    return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)
687
688
689def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
690    """
691    Validates that perm is a permutation of length rank.
692    """
693
694    return isinstance(perm, Sequence) and sorted(perm) == list(range(rank))
695
696
697def is_same_shape(a: Sequence, b: Sequence) -> bool:
698    """
699    Compares two shapes a and b, returning True if they are the same
700    (their ranks and corresponding lengths match) and False otherwise.
701    """
702
703    return tuple(a) == tuple(b)
704
705
706def is_cpu_scalar_tensor(a: Any) -> bool:
707    return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"
708
709
710def check_same_device(*args, allow_cpu_scalar_tensors):
711    """
712    Checks that all Tensors in args have the same device.
713
714    Raises a RuntimeError when:
715      - args contains an object whose type is not Tensor or Number
716      - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
717    """
718    # Short-circuits if all (one or fewer) arguments are trivially on the same device
719    if len(args) <= 1:
720        return
721
722    # Note: cannot initialize device to the first arg's device (it may not have one)
723    device = None
724    for arg in args:
725        if isinstance(arg, Number):
726            continue
727        elif isinstance(arg, TensorLike):
728            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
729                continue
730
731            if device is None:
732                device = arg.device
733
734            if device != arg.device:
735                msg = (
736                    "Tensor on device "
737                    + str(arg.device)
738                    + " is not on the expected device "
739                    + str(device)
740                    + "!"
741                )
742                raise RuntimeError(msg)
743        else:
744            msg = (
745                "Unexpected type when checking for same device, " + str(type(arg)) + "!"
746            )
747            raise RuntimeError(msg)
748
749
750def canonicalize_device(device: DeviceLikeType) -> torch.device:
751    if isinstance(device, torch.device):
752        return device
753
754    assert isinstance(device, str)
755    return torch.device(device)
756
757
# Raises a RuntimeError if any of the following are true:
#   - an argument is neither a Number nor a Tensor
#   - the tensor arguments do not all have the same shape
761def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
762    """
763    Checks that all Tensors in args have the same shape.
764
765    Raises a RuntimeError when:
766      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different shapes, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
768    """
769    shape = None
770
771    for arg in args:
772        if isinstance(arg, Number):
773            continue
774        elif isinstance(arg, TensorLike):
775            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
776                continue
777
778            if shape is None:
779                shape = arg.shape
780
781            if not is_same_shape(shape, arg.shape):
782                msg = f"Shape {arg.shape} is not the expected shape {shape}!"
783                raise RuntimeError(msg)
784        else:
785            msg = (
786                "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
787            )
788            raise RuntimeError(msg)
789
790
791# Acquires a common shape, if it exists, from one or more tensor arguments,
792# filtering number arguments
793def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
794    shape = None
795    scalar_shape = None
796
797    for arg in args:
798        if isinstance(arg, Number):
799            continue
800        elif isinstance(arg, TensorLike):
801            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
802                scalar_shape = arg.shape
803                continue
804
805            if shape is None:
806                shape = arg.shape
807
808            if not is_same_shape(shape, arg.shape):
809                return None
810        else:
811            return None
812
813    return shape if shape is not None else scalar_shape
814
815
816# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute.
818def extract_dims_from_varargs(
819    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
820) -> DimsSequenceType:
821    if dims and isinstance(dims[0], Sequence):
822        assert len(dims) == 1
823        dims = cast(Tuple[DimsSequenceType], dims)
824        return dims[0]
825    else:
826        return cast(DimsSequenceType, dims)
827
828
829def extract_shape_from_varargs(
830    shape: Union[ShapeType, Tuple[ShapeType]],
831    validate=True,
832) -> Tuple[int, ...]:
833    """
834    Returns a shape from varargs.
835
    In PyTorch, operations that accept shapes often accept them as varargs, like
    foo(*shape). However, a user can pass the shape either as individual integers,
    like this:

      foo(1, 2, 3)

    or as a single sequence of integers:

      foo((1, 2, 3))
845
846    In the first case shape will be a tuple of integers, and in the second case it's a tuple
847    containing a tuple of integers. This validates those inputs and canonicalizes them
848    to a tuple of integers.
849    """
850
851    # Handles tuple unwrapping
852    if len(shape) == 1 and isinstance(shape[0], Sequence):
853        shape = shape[0]
854
855    if validate:
856        validate_shape(shape)  # type: ignore[arg-type]
857    return shape  # type: ignore[return-value]
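
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_extract_shape_from_varargs` and `_fake_op` are hypothetical helpers.
def _example_extract_shape_from_varargs():
    def _fake_op(*shape):
        return tuple(extract_shape_from_varargs(shape))

    # Both calling conventions canonicalize to the same tuple of integers
    assert _fake_op(1, 2, 3) == (1, 2, 3)
    assert _fake_op((1, 2, 3)) == (1, 2, 3)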
858
859
860def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
861    ndim = max(len(a), len(b))
862    expandedSizes = [0] * ndim
863
864    for i in range(ndim - 1, -1, -1):
865        offset = ndim - 1 - i
866        dimA = len(a) - 1 - offset
867        dimB = len(b) - 1 - offset
868        sizeA = a[dimA] if dimA >= 0 else 1
869        sizeB = b[dimB] if dimB >= 0 else 1
870
871        torch._check(
872            (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
873            lambda: (
874                f"The size of tensor a ({sizeA}) must match the size of "
875                f"tensor b ({sizeB}) at non-jagged dimension {i}"
876            ),
877        )
878
879        # 1s map to the other size (even 0)
880        expandedSizes[i] = sizeB if sizeA == 1 else sizeA
881
882    return tuple(expandedSizes)
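
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_infer_size_shapes` is a hypothetical helper showing standard
# right-aligned broadcasting of two shapes.
def _example_infer_size_shapes():
    assert infer_size_shapes((3, 1, 5), (4, 1)) == (3, 4, 5)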
883
884
885def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
886    """
887    Infers the size of a dim with size -1, if it exists.
888    Also checks that new shape is compatible with the number of elements.
889    """
890    dim = None
891    newsize = 1
892    for i, d in enumerate(shape):
893        if d == -1:
894            torch._check(dim is None, lambda: "only one dimension can be inferred")
895            dim = i
896        elif d >= 0:
897            newsize *= d
898        else:
899            torch._check(False, lambda: f"invalid shape dimension {d}")
900    if dim is None:
901        torch._check(
902            numel == newsize,
903            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
904        )
905    else:
906        from torch.fx.experimental.symbolic_shapes import definitely_true
907
908        torch._check(
909            newsize != 0,
910            lambda: (
911                f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
912                f"unspecified dimension size -1 can be any value and is ambiguous"
913                if definitely_true(numel == 0)
914                else f"shape '{list(shape)}' is invalid for input of size {numel}"
915            ),
916        )
917        torch._check(
918            numel % newsize == 0,
919            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
920        )
921        # Convert to list to produce a compatible error message with core
922        # PyTorch, which prints sequences in square brackets.
923        shape = list(shape)
924        shape[dim] = numel // newsize
925        # NB: This is pretty important when you have unbacked SymInts.
926        # Suppose you have (i0, 12) resizing into (2, -1, 12).  The old
927        # range for i0 is typically [2, inf], which means if you divide
928        # by two the new range should be [1, inf].  But this is bad news
929        # if you have an unbacked SymInt: we need to reapply the unsound
930        # assumption that the size is >= 2.
931        torch._check_is_size(shape[dim])
932    return tuple(shape)
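
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_infer_size` is a hypothetical helper.
def _example_infer_size():
    # The -1 dimension is inferred from the total number of elements
    assert infer_size((2, -1, 4), 24) == (2, 3, 4)
    # Without a -1, the shape must account for the element count exactly
    assert infer_size((6,), 6) == (6,)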
933
934
935_integer_dtypes = (
936    torch.uint8,
937    torch.uint16,
938    torch.uint32,
939    torch.uint64,
940    torch.int8,
941    torch.int16,
942    torch.int32,
943    torch.int64,
944)
945_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
946_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
947
948
949def is_boolean_dtype(dtype: torch.dtype) -> bool:
950    assert isinstance(dtype, torch.dtype)
951    return dtype is torch.bool
952
953
954def is_integer_dtype(dtype: torch.dtype) -> bool:
955    assert isinstance(dtype, torch.dtype)
956    return dtype in _integer_dtypes
957
958
959def is_low_precision_dtype(dtype: torch.dtype) -> bool:
960    assert isinstance(dtype, torch.dtype)
961    return dtype in _low_precision_dtypes
962
963
964def is_float_dtype(dtype: torch.dtype) -> bool:
965    assert isinstance(dtype, torch.dtype)
966    return dtype.is_floating_point
967
968
969def is_complex_dtype(dtype: torch.dtype) -> bool:
970    assert isinstance(dtype, torch.dtype)
971    return dtype in _complex_dtypes
972
973
974def is_grad_dtype(dtype: torch.dtype) -> bool:
975    """
976    Checks if the dtype can require a gradient.
977    """
978    return dtype.is_floating_point or is_complex_dtype(dtype)
979
980
981_complex_to_real_dtype_map = {
982    torch.complex128: torch.float64,
983    torch.complex64: torch.float32,
984    torch.complex32: torch.float16,
985}
986
987_real_to_complex_dtype_map = {
988    torch.float16: torch.complex32,
989    torch.bfloat16: torch.complex64,
990    torch.float32: torch.complex64,
991    torch.float64: torch.complex128,
992}
993
994
995def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
996    return _complex_to_real_dtype_map[dtype]
997
998
999def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
1000    return _real_to_complex_dtype_map[dtype]
1001
1002
1003def dtype_to_type(dtype: torch.dtype) -> type:
1004    """
1005    Computes the corresponding Python type (AKA "type kind") for the
1006    given dtype.
1007    """
1008    assert isinstance(dtype, torch.dtype)
1009
1010    if dtype is torch.bool:
1011        return bool
1012    if dtype in _integer_dtypes:
1013        return int
1014    if dtype.is_floating_point:
1015        return float
1016    if dtype in _complex_dtypes:
1017        return complex
1018
1019    raise ValueError("Invalid dtype!")
1020
1021
1022def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
1023    """
1024    Computes the corresponding Python type constructor for the
1025    given dtype.
1026    """
1027    assert isinstance(dtype, torch.dtype)
1028
1029    if dtype is torch.bool:
1030        return lambda x: bool(x)
1031    if dtype in _integer_dtypes:
1032        return sym_int
1033    if dtype.is_floating_point:
1034        return sym_float
1035    if dtype in _complex_dtypes:
1036        # TODO: type error here is real, replace with sym_complex
1037        return lambda x: complex(x)  # type: ignore[arg-type]
1038
1039    raise ValueError("Invalid dtype!")
1040
1041
1042def type_to_dtype(typ: type) -> torch.dtype:
1043    """
1044    Computes the corresponding dtype for a Number type.
1045    """
1046
1047    assert isinstance(typ, type)
1048
1049    if typ in (bool, torch.SymBool):
1050        return torch.bool
1051    if typ in (int, torch.SymInt):
1052        return torch.long
1053    if typ in (float, torch.SymFloat):
1054        return torch.get_default_dtype()
1055    # TODO: sym_complex_float?
1056    if typ is complex:
1057        return corresponding_complex_dtype(torch.get_default_dtype())
1058
1059    raise ValueError(f"Invalid type {typ}!")
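
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_dtype_type_mapping` is a hypothetical helper.
def _example_dtype_type_mapping():
    assert dtype_to_type(torch.int32) is int
    assert dtype_to_type(torch.complex64) is complex
    assert type_to_dtype(bool) == torch.bool
    assert type_to_dtype(int) == torch.long
    # float maps to the current default dtype (usually torch.float32)
    assert type_to_dtype(float) == torch.get_default_dtype()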
1060
1061
1062def get_dtype(x: Union[torch.Tensor, NumberType]):
1063    if isinstance(x, torch.Tensor):
1064        return x.dtype
1065    else:
1066        return type_to_dtype(type(x))
1067
1068
1069_ordered_types = (bool, int, float, complex)
1070
1071
1072def check_fp_or_complex(
1073    dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
1074):
1075    """
1076    Checks whether the input is floating point or complex.
1077    If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
1078    """
1079    torch._check(
1080        is_float_dtype(dtype) or is_complex_dtype(dtype),
1081        lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
1082    )
1083    torch._check(
1084        allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
1085        lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
1086    )
1087
1088
1089def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
1090    torch._check(
1091        len(A.shape) >= 2,
1092        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
1093    )
1094
1095
1096def get_higher_type(a: type, b: type) -> type:
1097    """
1098    Returns the higher of the two given Number types.
1099
1100    The types are ordered bool -> int -> float -> complex.
1101    """
1102    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
1103    # Type checking
1104    if a not in _ordered_types or b not in _ordered_types:
1105        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")
1106
1107    if a is b:
1108        return a
1109
1110    for typ in _ordered_types:
1111        if a is typ:
1112            return b
1113        if b is typ:
1114            return a
1115
1116    raise ValueError("Unknown Python scalar type!")
1117
1118
1119# Returns the higher of two torch datatypes a and b or, if the two
1120#   are not ordered relative to each other, the next
1121#   higher datatype
1122def get_higher_dtype(
1123    a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
1124    b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
1125) -> Optional[torch.dtype]:
1126    """
1127    Computes the "lowest" datatype that is weakly
1128    "higher" than both a and b.
1129    """
1130
1131    # Type checking
1132    assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
1133    assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))
1134
1135    def _extract_dtype(
1136        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
1137    ) -> Optional[torch.dtype]:
1138        if x is None:
1139            return None
1140        if isinstance(x, torch.dtype):
1141            return x
1142        if isinstance(x, TensorLike):
1143            return x.dtype
1144        if isinstance(x, Number):
1145            return type_to_dtype(type(x))
1146
1147        raise RuntimeError("Unexpected type given to _extract_dtype!")
1148
1149    a, b = _extract_dtype(a), _extract_dtype(b)
1150
1151    if a is b:
1152        return a
1153
1154    if a is None:
1155        return b
1156
1157    if b is None:
1158        return a
1159
1160    ordered_datatypes = (
1161        (torch.bool,),
1162        (torch.uint8, torch.int8),
1163        (torch.int16,),
1164        (torch.int32,),
1165        (torch.int64,),
1166        (torch.float16, torch.bfloat16),
1167        (torch.float32,),
1168        (torch.float64,),
1169        (torch.complex32,),
1170        (torch.complex64,),
1171        (torch.complex128,),
1172    )
1173
1174    for idx, dtypes in enumerate(ordered_datatypes):
1175        if a in dtypes and b in dtypes:
1176            return ordered_datatypes[idx + 1][0]
1177        if a in dtypes:
1178            return b
1179        if b in dtypes:
1180            return a
1181
1182    raise RuntimeError("Unexpected termination!")
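
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_get_higher_dtype` is a hypothetical helper.
def _example_get_higher_dtype():
    # Ordered dtypes: the higher of the two is returned
    assert get_higher_dtype(torch.float32, torch.int64) == torch.float32
    # Unordered pairs (e.g. float16 vs. bfloat16) promote to the next higher dtype
    assert get_higher_dtype(torch.float16, torch.bfloat16) == torch.float32
    assert get_higher_dtype(torch.uint8, torch.int8) == torch.int16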
1183
1184
1185def check_pin_memory(pin_memory: bool):
1186    torch._check_not_implemented(
1187        not pin_memory, lambda: "PrimTorch does not support pinned memory"
1188    )
1189
1190
1191def check_layout(layout: torch.layout):
1192    torch._check_not_implemented(
1193        layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}"
1194    )
1195
1196
1197# TODO: maybe unify with can_cast_to?
1198def is_weakly_lesser_type(a: type, b: type) -> bool:
1199    """
1200    Compares two types, a and b, returning True if a is weakly "less" than b.
1201
1202    The comparison is determined by the following type ordering: bool, int, float, complex.
1203    """
1204
1205    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
1206
1207    if a not in _ordered_types or b not in _ordered_types:
1208        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")
1209
1210    for typ in _ordered_types:
1211        if a == typ:
1212            return True
1213        if b == typ:
1214            return False
1215
1216    raise RuntimeError("Unexpected termination!")
1217
1218
1219def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
1220    for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
1221        if fn(cast_to):
1222            return True
1223        if fn(cast_from):
1224            return False
1225
1226    raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!")
1227
1228
1229def check_same_dtype(*args):
1230    """
    Checks that all Tensors in args have the same dtype and that all Numbers have the
    same corresponding Python type.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different dtypes
      - two Number objects in args have different types
      - there are Tensors and Numbers in args, and one of those Tensors' corresponding
          Python types is different from the type of one of those Numbers
1240    """
1241    full_dtype = None
1242    scalar_type = None
1243
1244    for arg in args:
1245        if isinstance(arg, Number):
1246            # Scalar type checking is disabled (and may be removed in the future)
1247            continue
1248            # if scalar_type is None:
1249            #     scalar_type = type(arg)
1250
1251            # if scalar_type is not type(arg):
1252            #     msg = (
1253            #         "Scalar of type "
1254            #         + str(type(arg))
1255            #         + " is not the expected type of "
1256            #         + str(scalar_type)
1257            #         + "!"
1258            #     )
1259            #     raise RuntimeError(msg)
1260        elif isinstance(arg, TensorLike):
1261            if full_dtype is None:
1262                full_dtype = arg.dtype
1263            if scalar_type is None:
1264                scalar_type = dtype_to_type(arg.dtype)
1265
1266            if full_dtype is not arg.dtype:
1267                msg = (
1268                    "Tensor with dtype "
1269                    + str(arg.dtype)
1270                    + " is not the expected dtype of "
1271                    + str(full_dtype)
1272                    + "!"
1273                )
1274                raise RuntimeError(msg)
1275
1276            arg_type = dtype_to_type(arg.dtype)
1277            if arg_type is not scalar_type:
1278                msg = (
1279                    "Tensor with corresponding Python type "
1280                    + str(arg_type)
1281                    + " is not the expected type of "
1282                    + str(scalar_type)
1283                    + "!"
1284                )
1285                raise RuntimeError(msg)
1286        else:
1287            msg = (
1288                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
1289            )
1290            raise RuntimeError(msg)
1291
1292
1293# Maps datatypes to their computation types for elementwise operations
1294_computation_dtype_map = {
1295    torch.bfloat16: torch.float32,
1296    torch.float16: torch.float32,
1297    torch.complex32: torch.complex64,
1298}
1299
1300
1301def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
1302    return _computation_dtype_map.get(dtype, dtype)
1303
1304
1305_cpu_acc_type_map = {
1306    torch.bfloat16: torch.float64,
1307    torch.float16: torch.float64,
1308    torch.float32: torch.float64,
1309    torch.complex32: torch.complex128,
1310    torch.complex64: torch.complex128,
1311}
1312
1313
1314def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
1315    # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
1316    if device.type == "cpu":
1317        return _cpu_acc_type_map.get(dtype, dtype)
1318    else:
1319        return get_computation_dtype(dtype)
1320
1321
1322class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
1323    DEFAULT = (0,)
1324    NO_OPMATH = (1,)
1325    INT_TO_FLOAT = (2,)
1326    ALWAYS_BOOL = (3,)
1327    COMPLEX_TO_FLOAT = (4,)
1328    BOOL_TO_LONG = (5,)
1329
1330
1331class REDUCTION_OUTPUT_TYPE_KIND(Enum):
1332    SAME = (0,)
    COMPLEX_TO_FLOAT = (1,)  # for complex dtypes, outputs the corresponding real dtype
    KEEP_PROMOTED_TYPE = (2,)  # keeps the output in the opmath (promoted) dtype; needed for mean
1335    ALWAYS_BOOL = (3,)
1336
1337
# Describes the return type of the primitive:
#
#   - NEW, a new tensor is created
#   - VIEW, a view of an input tensor is returned
#   - INPLACE, one or more input tensors is modified
#   - NONE, no value is returned
#
# These descriptors are mutually exclusive and exhaustive.
1345class RETURN_TYPE(Enum):
1346    NEW = (0,)
1347    VIEW = (1,)
1348    INPLACE = (2,)
1349    NONE = (3,)
1350
1351
1352# TODO: when NumberType contains the sym types, can simplify this
1353def number_type(
1354    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
1355) -> Type:
1356    if isinstance(x, torch.SymInt):
1357        return int
1358    elif isinstance(x, torch.SymFloat):
1359        return float
1360    elif isinstance(x, torch.SymBool):
1361        return bool
1362    else:
1363        return type(x)
1364
1365
1366def expr_type(x: sympy.Basic) -> Type:
1367    import sympy
1368
1369    if x.kind is sympy.core.kind.BooleanKind:
1370        return bool
1371    elif x.is_integer:  # type: ignore[attr-defined]
1372        return int
1373    else:
1374        # NB: Not strictly correct, but we don't support SymPy complex or bool.
1375        return float
1376
1377
1378# TODO: document type promotion kinds
1379def elementwise_dtypes(
1380    *_args,
1381    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
1382) -> Tuple[torch.dtype, torch.dtype]:
1383    """
1384    Computes the computation and result dtypes for elementwise type promotion
1385    on the given arguments and with the given elementwise type promotion kind.
1386
1387    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
1388    For example, the "alpha" parameter of torch.add does not participate in type promotion,
1389    although it may be cast to the Python type corresponding to the computation dtype that
1390    the type promotion algorithm determines.
1391
1392    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
1393    first decides which of four ordered types to use:
1394
1395    bool -> integer -> floating point -> complex
1396
1397    The selected type is the "lowest" type in the above list such that all number arguments
1398    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
1399    type for their dtype.
1400
1401    Once the type is determined, the particular result dtype is found. The dtypes are
1402    partially ordered as follows:
1403
1404    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
1405      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128
1406
1407    The result dtype is selected by:
1408      - if no tensor's dtype has the same corresponding type as the one selected,
1409          then the result dtype is the (default) dtype corresponding to the selected type
1410          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
1411      - if the result type is complex then the dtype is:
1412        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the highest corresponding complex dtype among those tensors, where each floating
            point dtype maps to its corresponding complex dtype per the table below
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the highest corresponding complex dtype among those tensors
1418      - if the first two cases do not apply, the result dtype is the highest dtype among
1419          all tensors with one or more dimensions of the output type, and if there are no such
1420          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
1421          (for example, long + half -> half, even if the half tensor has zero dimensions)
1422
1423    The "corresponding complex dtypes" are:
1424      float16    -> complex32
1425      bfloat16   -> complex64
1426      float32    -> complex64
1427      float64    -> complex128
1428      complex32  -> complex32
1429      complex64  -> complex64
1430      complex128 -> complex128
1431
1432    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
1433    dtype by mapping low precision floating point and complex dtypes as follows:
1434
1435      float16   -> float32
1436      bfloat16  -> float32
1437      complex32 -> complex64
1438
1439    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
1440    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
1441    which perform no mathematical operations on their tensors (see below for examples).
1442
1443    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
1444    and computation dtypes to the appropriate op math dtype.
1445
1446    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
1447    mapping:
1448
1449        complex32  -> float16
1450        complex64  -> float32
1451        complex128 -> float64
1452
1453    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.
1454
1455    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.
1456
1457    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.
1458
1459    Example operators for each type promotion option:
1460      DEFAULT                 : add
1461      NO_OPMATH               : where, nextafter, cat
1462      INT_TO_FLOAT            : sin
1463      COMPLEX_TO_FLOAT        : abs
1464      BOOL_TO_LONG            : pow
1465      ALWAYS_BOOL             : eq
1466
1467    """
1468
1469    args = tuple(x for x in _args if x is not None)
1470
1471    highest_type: type = bool
1472
1473    # Import sympy locally, as importing it eagerly at a module level is too slow
1474    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
1475    import sympy
1476
1477    for x in args:
1478        if not isinstance(x, (Number, TensorLike, sympy.Basic)):
1479            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
1480            raise ValueError(msg)
1481
1482        if isinstance(x, Number):
1483            highest_type = get_higher_type(highest_type, number_type(x))
1484        elif isinstance(x, sympy.Basic):
1485            highest_type = get_higher_type(highest_type, expr_type(x))
1486        else:
1487            # x is a TensorLike
1488            highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))
1489
1490    result_dtype = None
1491
1492    def _find_highest_dtype_filtered(
1493        args, filter, *, float_as_complex=False
1494    ) -> Optional[torch.dtype]:
1495        zero_dim_tensor_dtype = None
1496        one_plus_dim_tensor_dtype = None
1497        for x in args:
1498            if isinstance(x, TensorLike) and filter(x.dtype):
1499                _dtype = x.dtype
1500                if float_as_complex and is_float_dtype(_dtype):
1501                    _dtype = corresponding_complex_dtype(_dtype)
1502                if x.ndim == 0:
1503                    zero_dim_tensor_dtype = get_higher_dtype(
1504                        zero_dim_tensor_dtype, _dtype
1505                    )
1506                else:
1507                    # x.ndim > 0
1508                    one_plus_dim_tensor_dtype = get_higher_dtype(
1509                        one_plus_dim_tensor_dtype, _dtype
1510                    )
1511
1512        # Prefers dtype of tensors with one or more dimensions
1513        if one_plus_dim_tensor_dtype is not None:
1514            return one_plus_dim_tensor_dtype
1515
1516        return zero_dim_tensor_dtype
1517
1518    if highest_type is float:
1519        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
1520        result_dtype = (
1521            torch.get_default_dtype() if result_dtype is None else result_dtype
1522        )
1523    elif highest_type is complex:
1524        result_dtype = _find_highest_dtype_filtered(
1525            args,
1526            lambda x: is_float_dtype(x) or is_complex_dtype(x),
1527            float_as_complex=True,
1528        )
1529        if result_dtype is None:
1530            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
1531    elif highest_type is int:
1532        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
1533        result_dtype = torch.long if result_dtype is None else result_dtype
1534    else:
1535        # highest_type is bool
1536        result_dtype = torch.bool
1537
1538    if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
1539        return get_computation_dtype(result_dtype), result_dtype
1540    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
1541        return result_dtype, result_dtype
1542    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
1543        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
1544            result_dtype = torch.get_default_dtype()
1545        return get_computation_dtype(result_dtype), result_dtype
1546    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
1547        # NOTE: computation can still occur in a complex dtype
1548        computation_dtype = get_computation_dtype(result_dtype)
1549        if is_complex_dtype(result_dtype):
1550            result_dtype = corresponding_real_dtype(result_dtype)
1551        return computation_dtype, result_dtype
1552    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
1553        if is_boolean_dtype(result_dtype):
1554            return torch.long, torch.long
1555        return get_computation_dtype(result_dtype), result_dtype
1556    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
1557        return get_computation_dtype(result_dtype), torch.bool
1558    else:
1559        raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")
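
# A minimal usage sketch (illustrative only, not part of the upstream module);
# `_example_elementwise_dtypes` is a hypothetical helper.
def _example_elementwise_dtypes():
    a = torch.empty(3, dtype=torch.float16)
    # DEFAULT: computation happens in "op math" (float32), the result stays float16
    assert elementwise_dtypes(
        a, 2.0, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ) == (torch.float32, torch.float16)
    # INT_TO_FLOAT: integer inputs produce a floating point result
    b = torch.empty(3, dtype=torch.int32)
    _, result_dtype = elementwise_dtypes(
        b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )
    assert result_dtype == torch.get_default_dtype()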
1560
1561
1562def reduction_dtypes(
1563    arg,
1564    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
1565    dtype: Optional[torch.dtype] = None,
1566) -> Tuple[torch.dtype, Optional[torch.dtype]]:
1567    # even though some reductions, like amin or amax, don't strictly require type promotion,
1568    # all the math ops (including comparisons) are still defined only for a computation type,
1569    # so promotion will still happen. We are doing it explicitly here
1570    inp_dtype = dtype if dtype is not None else arg.dtype
1571    computation_dtype = get_computation_dtype(inp_dtype)
1572    if (
1573        output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME
1574        or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
1575    ):
1576        result_dtype = dtype if dtype else arg.dtype
1577        if (
1578            output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
1579            and is_complex_dtype(result_dtype)
1580        ):
1581            result_dtype = corresponding_real_dtype(result_dtype)
1582    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
1583        result_dtype = None
1584    else:  # ALWAYS_BOOL
1585        result_dtype = torch.bool
1586    return computation_dtype, result_dtype
1587
1588
1589# This function's logic is borrowed from the following functions defined in C++:
1590# batched_matrix_contiguous_strides and contiguous_strides
1591def make_contiguous_strides_for(
1592    shape: ShapeType, row_major: bool = True
1593) -> Tuple[int, ...]:
1594    """
    Returns the strides of a contiguous tensor if row_major=True.
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous
    matrices. This is often used when calling external libraries like BLAS/LAPACK/cuSolver, etc.
1598    """
1599    # contiguous_strides from c10/util/strides.h
1600    validate_shape(shape)
1601    if not shape:
1602        return ()
1603
1604    from torch.fx.experimental.symbolic_shapes import is_nested_int
1605
1606    multiplier = 1
1607    strides = []
1608    for l in reversed(shape):
1609        strides.append(multiplier)
1610        multiplier *= l if is_nested_int(l) else sym_max(l, 1)
1611
1612    result = tuple(reversed(strides))
1613
1614    # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
1615    if row_major:
1616        return result
1617    else:
1618        if len(shape) < 2:
1619            return result
1620        return result[:-2] + (1, max(shape[-2], 1))
1621
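# A quick worked example (illustrative): for shape (2, 3, 4) the row-major strides
# are (3 * 4, 4, 1), and with row_major=False the last two dims become Fortran-ordered:
#
#   >>> make_contiguous_strides_for((2, 3, 4))
#   (12, 4, 1)
#   >>> make_contiguous_strides_for((2, 3, 4), row_major=False)
#   (12, 1, 3)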
1622
1623def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1624    torch._check(
1625        len(shape) == 3,
1626        lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
1627    )
1628
1629    multiplier = 1
1630    strides = [0] * 3
1631    for idx in (1, -1, 0):
1632        # NOTE: intentional divergence from make_contiguous_strides_for;
1633        # this is consistent with eager
1634        strides[idx] = multiplier
1635        multiplier *= shape[idx]
1636
1637    return tuple(strides)
1638
1639
1640def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1641    # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
1642    torch._check(
1643        len(shape) == 4,
1644        lambda: "Only tensors of rank 4 can use the channels_last memory format",
1645    )
1646
1647    multiplier = 1
1648    strides = [0] * 4
1649    for idx in (1, -1, -2, 0):
1650        # NOTE: intentional divergence from make_contiguous_strides_for;
1651        # this is consistent with eager
1652        strides[idx] = multiplier
1653        multiplier *= shape[idx]
1654
1655    return tuple(strides)
1656
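# Worked example (illustrative): for an NCHW shape (2, 3, 4, 5) the channels become
# the fastest-moving dimension, so the strides are (H*W*C, 1, W*C, C):
#
#   >>> make_channels_last_2d_strides_for((2, 3, 4, 5))
#   (60, 1, 15, 3)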
1657
1658def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1659    torch._check(
1660        len(shape) == 5,
1661        lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
1662    )
1663
1664    multiplier = 1
1665    strides = [0] * 5
1666    for idx in (1, -1, -2, -3, 0):
1667        # NOTE: intentional divergence from make_contiguous_strides_for;
1668        # this is consistent with eager
1669        strides[idx] = multiplier
1670        multiplier *= shape[idx]
1671
1672    return tuple(strides)
1673
1674
1675def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1676    ndim = len(shape) if isinstance(shape, Sequence) else 1
1677    if ndim == 3:
1678        return make_channels_last_1d_strides_for(shape)
1679    elif ndim == 4:
1680        return make_channels_last_2d_strides_for(shape)
1681    elif ndim == 5:
1682        return make_channels_last_3d_strides_for(shape)
1683    else:
1684        raise RuntimeError(
1685            f"no channels last format strides exist in {ndim} dimensions"
1686        )
1687
1688
1689def compute_reduction_output_shape(
1690    shape: ShapeType, dimensions: Sequence
1691) -> Tuple[int, ...]:
1692    for idx in dimensions:
1693        validate_idx(len(shape), idx)
1694
1695    new_shape = []
1696    for idx in range(len(shape)):
1697        if idx in dimensions:
1698            continue
1699
1700        new_shape.append(shape[idx])
1701
1702    return tuple(new_shape)
1703
1704
1705def validate_no_repeating_dims(dims: Sequence):
1706    if len(dims) != len(set(dims)):
1707        raise RuntimeError("duplicate value in the list of dims")
1708
1709
1710def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
1711    if dims is None:
1712        return tuple(range(len(shape)))
1713    dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
1714    validate_no_repeating_dims(dims)
1715    return dims
1716
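# For example (illustrative): dims are wrapped and validated, and reduced dimensions
# are dropped from the output shape:
#
#   >>> reduction_dims((2, 3, 4), [-1])
#   (2,)
#   >>> compute_reduction_output_shape((2, 3, 4), (2,))
#   (2, 3)
#   >>> reduction_dims((2, 3, 4), None)
#   (0, 1, 2)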
1717
1718def set_correction(
1719    unbiased: Optional[bool] = None,
1720    correction: Optional[NumberType] = None,
1721) -> float:
1722    if correction is not None and unbiased is not None:
1723        raise RuntimeError("cannot specify both correction and unbiased arguments")
1724    elif correction is None and unbiased is None:
1725        correction = 1.0
1726    elif correction is None and unbiased is not None:
1727        correction = 0.0 if unbiased is False else 1.0
1728    # NB: we don't actually support symint here, but it's harmless to accept
1729    if not isinstance(correction, (IntLike, FloatLike)):
1730        raise ValueError("correction argument should be integer or float")
1731    if correction < 0:
1732        raise ValueError("correction argument should be non-negative")
1733    return sym_float(correction)
1734
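# For example (illustrative), mirroring how var/std handle their overlapping
# ``unbiased`` and ``correction`` arguments:
#
#   >>> set_correction()
#   1.0
#   >>> set_correction(unbiased=False)
#   0.0
#   >>> set_correction(correction=2)
#   2.0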
1735
1736def compute_required_storage_length(
1737    shape: ShapeType, strides: StrideType, storage_offset: int
1738) -> int:
1739    """Computes the minimum storage size to hold the given tensor geometry.
1740
1741    Example
1742    =======
1743
1744    This is the size of a newly allocated tensor's storage, in units of elements
1745
1746    >>> t = torch.empty((10, 20))
1747    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
1748    200
1749
1750    >>> # xdoctest: +SKIP(failing)
1751    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
1752    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
1753    >>> size == t2.storage().size()
1754    True
1755
1756    A valid tensor may have a larger storage size, but never smaller
1757
1758    >>> slice = torch.empty(100)[20:40]
1759    >>> slice.storage().size()
1760    100
1761
1762    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
1763    40
1764
1765    """
1766    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1767
1768    # Short-circuits if the shape has no elements
1769    if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
1770        return 0
1771
1772    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
1773    # +1 to account for the first element, from which the offsets are measured
1774    return 1 + storage_offset + max_offset
1775
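# Worked arithmetic for the strided example in the docstring above (illustrative):
# shape (1, 2, 3) with strides (5, 7, 11) and offset 0 gives
# max_offset = 0*5 + 1*7 + 2*11 = 29, so the required storage length is 1 + 0 + 29 = 30.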
1776
1777def check_in_bounds_for_storage(
1778    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
1779):
1780    """
1781    Validates that the given shape, strides, and offset fit within the given storage,
    raising ValueError if they do not.
1782    """
1783
1784    required_length = compute_required_storage_length(shape, strides, storage_offset)
1785    if a.size() < required_length:
1786        msg = (
1787            f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, "
1788            f"shape of {str(shape)}, and strides of {str(strides)}, "
1789            f"which requires a storage of size {required_length}"
1790        )
1791        raise ValueError(msg)
1792
1793
1794# NOTE: This function should ideally be removed, but some Meta internal models
1795# packaged with `torch.package` are using it, so it will have to be removed
1796# at some point in the future when those models no longer use this function.
1797@deprecated(
1798    "`torch._prims_common.check` is deprecated and will be removed in the future. "
1799    "Please use `torch._check*` functions instead.",
1800    category=FutureWarning,
1801)
1802def check(
1803    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
1804) -> None:
1805    """
1806    Helper function for raising an exc_type (default: RuntimeError) if a boolean condition fails.
1807    The error message is a callable producing a string, to avoid wasting time on
1808    string formatting in the non-error case and to make it easier for torchdynamo
1809    to trace.
1810
1811    .. note:: This function is planned for removal in the future. Please use
1812        `torch._check*` functions instead.
1813    """
1814    torch._check_with(exc_type, b, s)
1815
1816
1817# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
1818# c10/core/MemoryFormat.h into one function
1819def are_strides_like_channels_last(
1820    shape: Sequence[int], strides: Sequence[int]
1821) -> bool:
1822    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1823
1824    ndim = len(shape)
1825
1826    if ndim == 4:
1827        # Check for channels_last_2d
1828        dim_order = [1, 3, 2, 0]
1829    elif ndim == 5:
1830        # Check for channels_last_3d
1831        dim_order = [1, 4, 3, 2, 0]
1832    else:
1833        return False
1834
1835    if guard_size_oblivious(strides[1] == 0):
1836        return False
1837
1838    min = 0
1839    for d in dim_order:
1840        if guard_size_oblivious(shape[d] == 0):
1841            return False
1842        if guard_size_oblivious(strides[d] < min):
1843            return False
1844        if d == 0 and min == strides[1]:
1845            return False
1846        min = strides[d]
1847        if guard_size_oblivious(strides[d] > 1):
1848            min *= shape[d]
1849    return True
1850
1851
1852def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
1853    if x.layout != torch.strided:
1854        return torch.contiguous_format
1855
1856    if are_strides_like_channels_last(x.shape, x.stride()):
1857        return torch.channels_last if x.ndim == 4 else torch.channels_last_3d
1858
1859    return torch.contiguous_format
1860
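# For example (illustrative):
#
#   >>> x = torch.empty(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
#   >>> suggest_memory_format(x)
#   torch.channels_last
#   >>> suggest_memory_format(torch.empty(2, 3, 4, 5))
#   torch.contiguous_format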
1861
1862def prod(xs: Sequence[NumberType]) -> NumberType:
1863    """Product of elements in input sequence. Returns 1 for empty sequence"""
1864    return reduce(operator.mul, xs, 1)
1865
1866
1867def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
1868    """Checks if a shape can be expanded to another shape.
1869    This is equivalent to checking if the two shapes are broadcastable.
1870    """
1871    # This is a Python implementation of
1872    # aten/src/ATen/ExpandUtils.h:is_expandable_to
1873    if len(shape) > len(desired):
1874        return False
1875    for i in range(len(shape)):
1876        if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1:
1877            return False
1878    return True
1879
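# For example (illustrative), trailing dimensions must match or be 1:
#
#   >>> is_expandable_to((3, 1), (2, 3, 4))
#   True
#   >>> is_expandable_to((2, 3), (2, 3, 4))
#   False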
1880
1881def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
1882    """
1883    Similar to torch.where(mask, t, 0) but if t is boolean,
1884    result is also boolean and not promoted to int.
1885    """
1886    # torch.where(mask, t, False) is equivalent
1887    # but feels hacky and might break in the future
1888    if t.dtype is torch.bool:
1889        return mask.logical_and(t)
1890    else:
1891        return torch.where(mask, t, 0)
1892
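# For example (illustrative): boolean inputs stay boolean, other dtypes are zero-filled:
#
#   >>> m = torch.tensor([True, False, True])
#   >>> mask_tensor(m, torch.tensor([1, 2, 3]))
#   tensor([1, 0, 3])
#   >>> mask_tensor(m, torch.tensor([True, True, False])).dtype
#   torch.bool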
1893
1894def get_aten_op(fn: Callable, name: str):
1895    """
1896    Given the __module__ of a reference and its name, returns
1897    (our best guess of) the associated ATen operation.
1898
1899    Note: In ATen, the __name__ of a function within a module often
1900    starts with the module name. E.g. linalg_eigh, or special_zeta
1901    """
1902    module = fn.__module__
1903    prefix = "torch._refs"
1904    assert module.startswith(prefix)
1905    module = module[len(prefix) :]
1906    # We want to go from .special / .nn.functional
1907    # to the special_ / nn_functional_ prefixes
1908    if module:
1909        module = module[1:]
1910        module = module.replace(".", "_")
1911        module = module + "_"
1912    return getattr(torch._ops.ops.aten, f"{module}{name}")
1913
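# For example (illustrative, assuming the corresponding references exist): a reference
# defined in ``torch._refs`` named "add" maps to ``torch.ops.aten.add``, while one in
# ``torch._refs.special`` named "zeta" maps to ``torch.ops.aten.special_zeta``.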
1914
1915def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
1916    return dtype if dtype is not None else torch.get_default_dtype()
1917
1918
1919def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
1920    return device if device is not None else torch.device("cpu")
1921
1922
1923def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
1924    return layout if layout is not None else torch.strided
1925
1926
1927def clone_preserve_strides(x):
1928    needed_size = compute_required_storage_length(
1929        x.size(), x.stride(), x.storage_offset()
1930    )
1931    # Our eager implementations for *_scatter ops are all primitives w.r.t. autograd,
1932    # so these as_strided() calls are not seen by autograd.
1933    # We need to mimic this behavior in our ref/prim implementations.
1934    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
1935    # We should revisit this when we add a compositional as_strided op,
1936    # and also as part of https://github.com/pytorch/pytorch/issues/90507
1937    try:
1938        old = torch._C._dispatch_tls_is_dispatch_key_excluded(
1939            torch._C.DispatchKey.ADInplaceOrView
1940        )
1941        torch._C._dispatch_tls_set_dispatch_key_excluded(
1942            torch._C.DispatchKey.ADInplaceOrView, True
1943        )
1944        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
1945        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
1946    finally:
1947        torch._C._dispatch_tls_set_dispatch_key_excluded(
1948            torch._C.DispatchKey.ADInplaceOrView, old
1949        )
1950
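# For example (illustrative): cloning a slice keeps its strides and storage offset,
# while the new storage only spans the required length (40 elements here, not 100):
#
#   >>> x = torch.empty(100)[20:40]
#   >>> y = clone_preserve_strides(x)
#   >>> (tuple(y.shape), y.stride(), y.storage_offset())
#   ((20,), (1,), 20)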
1951
1952def alert_not_deterministic(caller: str):
1953    if torch.are_deterministic_algorithms_enabled():
1954        if torch.is_deterministic_algorithms_warn_only_enabled():
1955            warnings.warn(
1956                f"{caller} does not have a deterministic implementation, but you set "
1957                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
1958                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
1959                f"to help us prioritize adding deterministic support for this operation."
1960            )
1961        else:
1962            torch._check(
1963                False,
1964                lambda: (
1965                    f"{caller} does not have a deterministic implementation, but you set "
1966                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
1967                    f"determinism just for this operation, or you can use the "
1968                    f"'warn_only=True' option, if that's acceptable for your application. "
1969                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
1970                    f"to help us prioritize adding deterministic support for this operation."
1971                ),
1972            )
1973
1974
1975class CUDARngStateHelper:
1976    @staticmethod
1977    def get_torch_state_as_tuple(fake_mode=nullcontext()):
1978        if not torch.cuda.is_available():
1979            raise RuntimeError("CUDA not available")
1980
1981        with fake_mode:
1982            seed = torch.tensor(torch.cuda.initial_seed())
1983            offset = torch.tensor(torch.cuda._get_rng_state_offset())
1984            return seed, offset
1985
1986    @staticmethod
1987    def set_torch_state_tensor(seed, offset):
1988        # Rng state is [64-bit seed, 64-bit offset]
1989        seed_portion = seed.reshape([1]).view(torch.uint8)
1990        offset_portion = offset.reshape([1]).view(torch.uint8)
1991        new_state = torch.cat([seed_portion, offset_portion])
1992        torch.cuda.set_rng_state(new_state)
1993
1994    @staticmethod
1995    def set_new_offset(relative_offset):
1996        torch.cuda._set_rng_state_offset(relative_offset.item())
1997
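# A hedged usage sketch (requires a CUDA build; illustrative only): the RNG state is
# round-tripped as a pair of int64 tensors, matching the [64-bit seed, 64-bit offset] layout.
#
#   >>> # xdoctest: +SKIP("requires CUDA")
#   >>> seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
#   >>> CUDARngStateHelper.set_torch_state_tensor(seed, offset)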