xref: /aosp_15_r20/external/pytorch/torch/functional.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport itertools
3*da0073e9SAndroid Build Coastguard Workerimport operator
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
8*da0073e9SAndroid Build Coastguard Workerfrom torch import _VF, Tensor
9*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _add_docstr
10*da0073e9SAndroid Build Coastguard Workerfrom torch._jit_internal import _overload as overload, boolean_dispatch
11*da0073e9SAndroid Build Coastguard Workerfrom torch._lowrank import pca_lowrank, svd_lowrank
12*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import (
13*da0073e9SAndroid Build Coastguard Worker    handle_torch_function,
14*da0073e9SAndroid Build Coastguard Worker    has_torch_function,
15*da0073e9SAndroid Build Coastguard Worker    has_torch_function_unary,
16*da0073e9SAndroid Build Coastguard Worker    has_torch_function_variadic,
17*da0073e9SAndroid Build Coastguard Worker)
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
# Public API of this module: the names that ``from torch.functional import *``
# exports. Several entries (e.g. ``pca_lowrank``, ``svd_lowrank``) are
# re-exports of names imported at the top of the file rather than local defs.
__all__ = [
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "align_tensors",
    "broadcast_shapes",
    "broadcast_tensors",
    "cartesian_prod",
    "block_diag",
    "cdist",
    "chain_matmul",
    "einsum",
    "istft",
    "lu",
    "norm",
    "meshgrid",
    "pca_lowrank",
    "split",
    "stft",
    "svd_lowrank",
    "tensordot",
    "unique",
    "unique_consecutive",
    "unravel_index",
]
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
def broadcast_tensors(*tensors):
    r"""broadcast_tensors(*tensors) -> List of Tensors

    Expands every input tensor to one common shape following
    :ref:`broadcasting-semantics`.

    Args:
        *tensors: any number of tensors of the same type

    .. warning::

        The returned tensors may alias memory: several elements of a broadcast
        result can refer to the same storage location. In-place operations on
        them (vectorized ones in particular) can therefore behave incorrectly.
        Clone the outputs first if you need to write to them.

    Example::

        >>> x = torch.arange(3).view(1, 3)
        >>> y = torch.arange(2).view(2, 1)
        >>> a, b = torch.broadcast_tensors(x, y)
        >>> a.size()
        torch.Size([2, 3])
        >>> a
        tensor([[0, 1, 2],
                [0, 1, 2]])
    """
    # Plain-tensor path: the native op expects all tensors packed into one
    # sequence; this Python wrapper only exists to accept them variadically.
    if not has_torch_function(tensors):
        return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
    # At least one argument overrides __torch_function__; defer to it.
    return handle_torch_function(broadcast_tensors, tensors, *tensors)
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Workerdef broadcast_shapes(*shapes):
80*da0073e9SAndroid Build Coastguard Worker    r"""broadcast_shapes(*shapes) -> Size
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker    Similar to :func:`broadcast_tensors` but for shapes.
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker    This is equivalent to
85*da0073e9SAndroid Build Coastguard Worker    ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
86*da0073e9SAndroid Build Coastguard Worker    but avoids the need create to intermediate tensors. This is useful for
87*da0073e9SAndroid Build Coastguard Worker    broadcasting tensors of common batch shape but different rightmost shape,
88*da0073e9SAndroid Build Coastguard Worker    e.g. to broadcast mean vectors with covariance matrices.
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker    Example::
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker        >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
93*da0073e9SAndroid Build Coastguard Worker        torch.Size([1, 3, 2])
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker    Args:
96*da0073e9SAndroid Build Coastguard Worker        \*shapes (torch.Size): Shapes of tensors.
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    Returns:
99*da0073e9SAndroid Build Coastguard Worker        shape (torch.Size): A shape compatible with all input shapes.
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker    Raises:
102*da0073e9SAndroid Build Coastguard Worker        RuntimeError: If shapes are incompatible.
103*da0073e9SAndroid Build Coastguard Worker    """
104*da0073e9SAndroid Build Coastguard Worker    # This wrapper exists to support variadic args.
105*da0073e9SAndroid Build Coastguard Worker    # TODO Move this to C++ once the jit has better support for torch.Size.
106*da0073e9SAndroid Build Coastguard Worker    if not torch.jit.is_tracing():
107*da0073e9SAndroid Build Coastguard Worker        max_len = 0
108*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
109*da0073e9SAndroid Build Coastguard Worker            if isinstance(shape, (int, torch.SymInt)):
110*da0073e9SAndroid Build Coastguard Worker                if max_len < 1:
111*da0073e9SAndroid Build Coastguard Worker                    max_len = 1
112*da0073e9SAndroid Build Coastguard Worker            elif isinstance(shape, (tuple, list)):
113*da0073e9SAndroid Build Coastguard Worker                s = len(shape)
114*da0073e9SAndroid Build Coastguard Worker                if max_len < s:
115*da0073e9SAndroid Build Coastguard Worker                    max_len = s
116*da0073e9SAndroid Build Coastguard Worker        result = [1] * max_len
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
121*da0073e9SAndroid Build Coastguard Worker            if isinstance(shape, (int, torch.SymInt)):
122*da0073e9SAndroid Build Coastguard Worker                shape = (shape,)
123*da0073e9SAndroid Build Coastguard Worker            if isinstance(shape, (tuple, list)):
124*da0073e9SAndroid Build Coastguard Worker                for i in range(-1, -1 - len(shape), -1):
125*da0073e9SAndroid Build Coastguard Worker                    if shape[i] < 0:
126*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError(
127*da0073e9SAndroid Build Coastguard Worker                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
128*da0073e9SAndroid Build Coastguard Worker                        )
129*da0073e9SAndroid Build Coastguard Worker                    # NB: result is initialized to 1 so this is effectively an
130*da0073e9SAndroid Build Coastguard Worker                    # equals one test
131*da0073e9SAndroid Build Coastguard Worker                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
132*da0073e9SAndroid Build Coastguard Worker                        shape[i] == result[i]
133*da0073e9SAndroid Build Coastguard Worker                    ):
134*da0073e9SAndroid Build Coastguard Worker                        continue
135*da0073e9SAndroid Build Coastguard Worker                    if result[i] != 1:
136*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError(
137*da0073e9SAndroid Build Coastguard Worker                            "Shape mismatch: objects cannot be broadcast to a single shape"
138*da0073e9SAndroid Build Coastguard Worker                        )
139*da0073e9SAndroid Build Coastguard Worker                    result[i] = shape[i]
140*da0073e9SAndroid Build Coastguard Worker            else:
141*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
142*da0073e9SAndroid Build Coastguard Worker                    "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
143*da0073e9SAndroid Build Coastguard Worker                    shape,
144*da0073e9SAndroid Build Coastguard Worker                )
145*da0073e9SAndroid Build Coastguard Worker        return torch.Size(result)
146*da0073e9SAndroid Build Coastguard Worker    else:
147*da0073e9SAndroid Build Coastguard Worker        # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
148*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
149*da0073e9SAndroid Build Coastguard Worker            scalar = torch.zeros((), device="cpu")
150*da0073e9SAndroid Build Coastguard Worker            tensors = [scalar.expand(shape) for shape in shapes]
151*da0073e9SAndroid Build Coastguard Worker            tensors = broadcast_tensors(*tensors)
152*da0073e9SAndroid Build Coastguard Worker            return tensors[0].shape
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker
def split(
    tensor: Tensor,
    split_size_or_sections: Union[int, List[int]],
    dim: int = 0,
) -> Tuple[Tensor, ...]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size_or_sections`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor. Default: ``0``.

    Example::

        >>> a = torch.arange(10).reshape(5, 2)
        >>> a
        tensor([[0, 1],
                [2, 3],
                [4, 5],
                [6, 7],
                [8, 9]])
        >>> torch.split(a, 2)
        (tensor([[0, 1],
                 [2, 3]]),
         tensor([[4, 5],
                 [6, 7]]),
         tensor([[8, 9]]))
        >>> torch.split(a, [1, 4])
        (tensor([[0, 1]]),
         tensor([[2, 3],
                 [4, 5],
                 [6, 7],
                 [8, 9]]))
    """
    # Tensor subclasses / __torch_function__ overrides get first say.
    if has_torch_function_unary(tensor):
        return handle_torch_function(
            split, (tensor,), tensor, split_size_or_sections, dim=dim
        )
    # Overwriting reason:
    # This dispatches to two ATen functions depending on the type of
    # split_size_or_sections. The branching code is in _tensor.py, which we
    # call here.
    return tensor.split(split_size_or_sections, dim)
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Workerdef einsum(*args: Any) -> Tensor:
211*da0073e9SAndroid Build Coastguard Worker    r"""einsum(equation, *operands) -> Tensor
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker    Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
214*da0073e9SAndroid Build Coastguard Worker    based on the Einstein summation convention.
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker    Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
217*da0073e9SAndroid Build Coastguard Worker    in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
218*da0073e9SAndroid Build Coastguard Worker    this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
219*da0073e9SAndroid Build Coastguard Worker    with some subscript and define which subscripts are part of the output. The output is then computed by summing
220*da0073e9SAndroid Build Coastguard Worker    the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
221*da0073e9SAndroid Build Coastguard Worker    output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
222*da0073e9SAndroid Build Coastguard Worker    Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker    Equation:
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
227*da0073e9SAndroid Build Coastguard Worker        the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
228*da0073e9SAndroid Build Coastguard Worker        comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
229*da0073e9SAndroid Build Coastguard Worker        must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
230*da0073e9SAndroid Build Coastguard Worker        repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
231*da0073e9SAndroid Build Coastguard Worker        must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
232*da0073e9SAndroid Build Coastguard Worker        appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
233*da0073e9SAndroid Build Coastguard Worker        The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
234*da0073e9SAndroid Build Coastguard Worker        on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
237*da0073e9SAndroid Build Coastguard Worker        followed by the subscripts for the output. For instance, the following equation computes the transpose of a
238*da0073e9SAndroid Build Coastguard Worker        matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
239*da0073e9SAndroid Build Coastguard Worker        at most once for the output.
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
242*da0073e9SAndroid Build Coastguard Worker        Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
243*da0073e9SAndroid Build Coastguard Worker        e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth
244*da0073e9SAndroid Build Coastguard Worker        dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
245*da0073e9SAndroid Build Coastguard Worker        'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
246*da0073e9SAndroid Build Coastguard Worker        explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
247*da0073e9SAndroid Build Coastguard Worker        before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
248*da0073e9SAndroid Build Coastguard Worker        batch matrix multiplication `'...ij,...jk'`.
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
251*da0073e9SAndroid Build Coastguard Worker        arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
252*da0073e9SAndroid Build Coastguard Worker
253*da0073e9SAndroid Build Coastguard Worker    .. note::
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker        ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
256*da0073e9SAndroid Build Coastguard Worker        covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker    .. note::
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to
261*da0073e9SAndroid Build Coastguard Worker        consume less memory by optimizing contraction order. This optimization occurs when there are at least three
262*da0073e9SAndroid Build Coastguard Worker        inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem,
263*da0073e9SAndroid Build Coastguard Worker        thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available,
264*da0073e9SAndroid Build Coastguard Worker        the default order is to contract from left to right.
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker        To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path
267*da0073e9SAndroid Build Coastguard Worker        calculation: `torch.backends.opt_einsum.enabled = False`
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker        To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
270*da0073e9SAndroid Build Coastguard Worker        `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and
271*da0073e9SAndroid Build Coastguard Worker        'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
272*da0073e9SAndroid Build Coastguard Worker        the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker    .. note::
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker        As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
277*da0073e9SAndroid Build Coastguard Worker        subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists
278*da0073e9SAndroid Build Coastguard Worker        follow their operands, and an extra sublist can appear at the end of the input to specify the output's
279*da0073e9SAndroid Build Coastguard Worker        subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object
280*da0073e9SAndroid Build Coastguard Worker        may be provided in a sublist to enable broadcasting as described in the Equation section above.
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker    Args:
283*da0073e9SAndroid Build Coastguard Worker        equation (str): The subscripts for the Einstein summation.
284*da0073e9SAndroid Build Coastguard Worker        operands (List[Tensor]): The tensors to compute the Einstein summation of.
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker    Examples::
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
289*da0073e9SAndroid Build Coastguard Worker        >>> # trace
290*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('ii', torch.randn(4, 4))
291*da0073e9SAndroid Build Coastguard Worker        tensor(-1.2104)
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
294*da0073e9SAndroid Build Coastguard Worker        >>> # diagonal
295*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('ii->i', torch.randn(4, 4))
296*da0073e9SAndroid Build Coastguard Worker        tensor([-0.1034,  0.7952, -0.2433,  0.4545])
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
299*da0073e9SAndroid Build Coastguard Worker        >>> # outer product
300*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(5)
301*da0073e9SAndroid Build Coastguard Worker        >>> y = torch.randn(4)
302*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('i,j->ij', x, y)
303*da0073e9SAndroid Build Coastguard Worker        tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
304*da0073e9SAndroid Build Coastguard Worker                [-0.3744,  0.9381,  1.2685, -1.6070],
305*da0073e9SAndroid Build Coastguard Worker                [ 0.7208, -1.8058, -2.4419,  3.0936],
306*da0073e9SAndroid Build Coastguard Worker                [ 0.1713, -0.4291, -0.5802,  0.7350],
307*da0073e9SAndroid Build Coastguard Worker                [ 0.5704, -1.4290, -1.9323,  2.4480]])
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
310*da0073e9SAndroid Build Coastguard Worker        >>> # batch matrix multiplication
311*da0073e9SAndroid Build Coastguard Worker        >>> As = torch.randn(3, 2, 5)
312*da0073e9SAndroid Build Coastguard Worker        >>> Bs = torch.randn(3, 5, 4)
313*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('bij,bjk->bik', As, Bs)
314*da0073e9SAndroid Build Coastguard Worker        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
315*da0073e9SAndroid Build Coastguard Worker                [-1.6706, -0.8097, -0.8025, -2.1183]],
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker                [[ 4.2239,  0.3107, -0.5756, -0.2354],
318*da0073e9SAndroid Build Coastguard Worker                [-1.4558, -0.3460,  1.5087, -0.8530]],
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker                [[ 2.8153,  1.8787, -4.3839, -1.2112],
321*da0073e9SAndroid Build Coastguard Worker                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
324*da0073e9SAndroid Build Coastguard Worker        >>> # with sublist format and ellipsis
325*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
326*da0073e9SAndroid Build Coastguard Worker        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
327*da0073e9SAndroid Build Coastguard Worker                [-1.6706, -0.8097, -0.8025, -2.1183]],
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker                [[ 4.2239,  0.3107, -0.5756, -0.2354],
330*da0073e9SAndroid Build Coastguard Worker                [-1.4558, -0.3460,  1.5087, -0.8530]],
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker                [[ 2.8153,  1.8787, -4.3839, -1.2112],
333*da0073e9SAndroid Build Coastguard Worker                [ 0.3728, -2.1131,  0.0921,  0.8305]]])
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        >>> # batch permute
336*da0073e9SAndroid Build Coastguard Worker        >>> A = torch.randn(2, 3, 4, 5)
337*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('...ij->...ji', A).shape
338*da0073e9SAndroid Build Coastguard Worker        torch.Size([2, 3, 5, 4])
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker        >>> # equivalent to torch.nn.functional.bilinear
341*da0073e9SAndroid Build Coastguard Worker        >>> A = torch.randn(3, 5, 4)
342*da0073e9SAndroid Build Coastguard Worker        >>> l = torch.randn(2, 5)
343*da0073e9SAndroid Build Coastguard Worker        >>> r = torch.randn(2, 4)
344*da0073e9SAndroid Build Coastguard Worker        >>> torch.einsum('bn,anm,bm->ba', l, A, r)
345*da0073e9SAndroid Build Coastguard Worker        tensor([[-0.3430, -5.2405,  0.4494],
346*da0073e9SAndroid Build Coastguard Worker                [ 0.3311,  5.5201, -3.0356]])
347*da0073e9SAndroid Build Coastguard Worker    """
348*da0073e9SAndroid Build Coastguard Worker    import torch.backends.opt_einsum as opt_einsum
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker    # This wrapper exists to support variadic args.
351*da0073e9SAndroid Build Coastguard Worker    if len(args) < 2:
352*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
353*da0073e9SAndroid Build Coastguard Worker            "einsum(): must specify the equation string and at least one operand, "
354*da0073e9SAndroid Build Coastguard Worker            "or at least one operand and its subscripts list"
355*da0073e9SAndroid Build Coastguard Worker        )
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker    equation = None
358*da0073e9SAndroid Build Coastguard Worker    operands = None
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker    if isinstance(args[0], torch.Tensor):
361*da0073e9SAndroid Build Coastguard Worker        # Convert the subscript list format which is an interleaving of operand and its subscripts
362*da0073e9SAndroid Build Coastguard Worker        # list with an optional output subscripts list at the end (see documentation for more details on this)
363*da0073e9SAndroid Build Coastguard Worker        # to the equation string format by creating the equation string from the subscripts list and grouping the
364*da0073e9SAndroid Build Coastguard Worker        # input operands into a tensorlist (List[Tensor]).
365*da0073e9SAndroid Build Coastguard Worker        def parse_subscript(n: int) -> str:
366*da0073e9SAndroid Build Coastguard Worker            if n == Ellipsis:
367*da0073e9SAndroid Build Coastguard Worker                return "..."
368*da0073e9SAndroid Build Coastguard Worker            if n >= 0 and n < 26:
369*da0073e9SAndroid Build Coastguard Worker                return chr(ord("A") + n)
370*da0073e9SAndroid Build Coastguard Worker            if n >= 26 and n < 52:
371*da0073e9SAndroid Build Coastguard Worker                return chr(ord("a") + n - 26)
372*da0073e9SAndroid Build Coastguard Worker            raise ValueError(
373*da0073e9SAndroid Build Coastguard Worker                "einsum(): subscript in subscript list is not within the valid range [0, 52)"
374*da0073e9SAndroid Build Coastguard Worker            )
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker        # Parse subscripts for input operands
377*da0073e9SAndroid Build Coastguard Worker        equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        # Parse optional output subscripts (provided when the number of arguments is odd)
380*da0073e9SAndroid Build Coastguard Worker        if len(args) % 2 == 1:
381*da0073e9SAndroid Build Coastguard Worker            equation += "->" + "".join(parse_subscript(s) for s in args[-1])
382*da0073e9SAndroid Build Coastguard Worker            operands = args[:-1:2]
383*da0073e9SAndroid Build Coastguard Worker        else:
384*da0073e9SAndroid Build Coastguard Worker            operands = args[::2]
385*da0073e9SAndroid Build Coastguard Worker    else:
386*da0073e9SAndroid Build Coastguard Worker        equation = args[0]
387*da0073e9SAndroid Build Coastguard Worker        operands = args[1:]
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker    if has_torch_function(operands):
390*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(einsum, operands, equation, *operands)
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
393*da0073e9SAndroid Build Coastguard Worker        # the old interface of passing the operands as one list argument
394*da0073e9SAndroid Build Coastguard Worker        _operands = operands[0]
395*da0073e9SAndroid Build Coastguard Worker        # recurse incase operands contains value that has torch function
396*da0073e9SAndroid Build Coastguard Worker        # in the original implementation this line is omitted
397*da0073e9SAndroid Build Coastguard Worker        return einsum(equation, *_operands)
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker    if len(operands) <= 2 or not opt_einsum.enabled:
400*da0073e9SAndroid Build Coastguard Worker        # the path for contracting 0 or 1 time(s) is already optimized
401*da0073e9SAndroid Build Coastguard Worker        # or the user has disabled using opt_einsum
402*da0073e9SAndroid Build Coastguard Worker        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    path = None
405*da0073e9SAndroid Build Coastguard Worker    if opt_einsum.is_available():
406*da0073e9SAndroid Build Coastguard Worker        _opt_einsum = opt_einsum.get_opt_einsum()
407*da0073e9SAndroid Build Coastguard Worker        tupled_path = _opt_einsum.contract_path(
408*da0073e9SAndroid Build Coastguard Worker            equation, *operands, optimize=opt_einsum.strategy
409*da0073e9SAndroid Build Coastguard Worker        )[0]
410*da0073e9SAndroid Build Coastguard Worker        # flatten path for dispatching to C++
411*da0073e9SAndroid Build Coastguard Worker        path = [item for pair in tupled_path for item in pair]
412*da0073e9SAndroid Build Coastguard Worker    return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker# This wrapper exists to support variadic args.
416*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
417*da0073e9SAndroid Build Coastguard Worker    # The JIT doesn't understand Union, so only add type annotation for mypy
418*da0073e9SAndroid Build Coastguard Worker    def meshgrid(
419*da0073e9SAndroid Build Coastguard Worker        *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
420*da0073e9SAndroid Build Coastguard Worker    ) -> Tuple[Tensor, ...]:
421*da0073e9SAndroid Build Coastguard Worker        return _meshgrid(*tensors, indexing=indexing)
422*da0073e9SAndroid Build Coastguard Worker
423*da0073e9SAndroid Build Coastguard Workerelse:
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker    def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
426*da0073e9SAndroid Build Coastguard Worker        r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker        This is helpful when you want to visualize data over some
429*da0073e9SAndroid Build Coastguard Worker        range of inputs. See below for a plotting example.
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
432*da0073e9SAndroid Build Coastguard Worker        inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
433*da0073e9SAndroid Build Coastguard Worker        this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
434*da0073e9SAndroid Build Coastguard Worker        G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
435*da0073e9SAndroid Build Coastguard Worker        the output :math:`G_i` is constructed by expanding :math:`T_i`
436*da0073e9SAndroid Build Coastguard Worker        to the result shape.
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker        .. note::
439*da0073e9SAndroid Build Coastguard Worker            0D inputs are treated equivalently to 1D inputs of a
440*da0073e9SAndroid Build Coastguard Worker            single element.
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker        .. warning::
443*da0073e9SAndroid Build Coastguard Worker            `torch.meshgrid(*tensors)` currently has the same behavior
444*da0073e9SAndroid Build Coastguard Worker            as calling `numpy.meshgrid(*arrays, indexing='ij')`.
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Worker            In the future `torch.meshgrid` will transition to
447*da0073e9SAndroid Build Coastguard Worker            `indexing='xy'` as the default.
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker            https://github.com/pytorch/pytorch/issues/50276 tracks
450*da0073e9SAndroid Build Coastguard Worker            this issue with the goal of migrating to NumPy's behavior.
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker        .. seealso::
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker            :func:`torch.cartesian_prod` has the same effect but it
455*da0073e9SAndroid Build Coastguard Worker            collects the data in a tensor of vectors.
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker        Args:
458*da0073e9SAndroid Build Coastguard Worker            tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
459*da0073e9SAndroid Build Coastguard Worker                treated as tensors of size :math:`(1,)` automatically
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker            indexing: (str, optional): the indexing mode, either "xy"
462*da0073e9SAndroid Build Coastguard Worker                or "ij", defaults to "ij". See warning for future changes.
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker                If "xy" is selected, the first dimension corresponds
465*da0073e9SAndroid Build Coastguard Worker                to the cardinality of the second input and the second
466*da0073e9SAndroid Build Coastguard Worker                dimension corresponds to the cardinality of the first
467*da0073e9SAndroid Build Coastguard Worker                input.
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker                If "ij" is selected, the dimensions are in the same
470*da0073e9SAndroid Build Coastguard Worker                order as the cardinality of the inputs.
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        Returns:
473*da0073e9SAndroid Build Coastguard Worker            seq (sequence of Tensors): If the input has :math:`N`
474*da0073e9SAndroid Build Coastguard Worker            tensors of size :math:`S_0 \ldots S_{N-1}``, then the
475*da0073e9SAndroid Build Coastguard Worker            output will also have :math:`N` tensors, where each tensor
476*da0073e9SAndroid Build Coastguard Worker            is of shape :math:`(S_0, ..., S_{N-1})`.
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        Example::
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker            >>> x = torch.tensor([1, 2, 3])
481*da0073e9SAndroid Build Coastguard Worker            >>> y = torch.tensor([4, 5, 6])
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            Observe the element-wise pairings across the grid, (1, 4),
484*da0073e9SAndroid Build Coastguard Worker            (1, 5), ..., (3, 6). This is the same thing as the
485*da0073e9SAndroid Build Coastguard Worker            cartesian product.
486*da0073e9SAndroid Build Coastguard Worker            >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
487*da0073e9SAndroid Build Coastguard Worker            >>> grid_x
488*da0073e9SAndroid Build Coastguard Worker            tensor([[1, 1, 1],
489*da0073e9SAndroid Build Coastguard Worker                    [2, 2, 2],
490*da0073e9SAndroid Build Coastguard Worker                    [3, 3, 3]])
491*da0073e9SAndroid Build Coastguard Worker            >>> grid_y
492*da0073e9SAndroid Build Coastguard Worker            tensor([[4, 5, 6],
493*da0073e9SAndroid Build Coastguard Worker                    [4, 5, 6],
494*da0073e9SAndroid Build Coastguard Worker                    [4, 5, 6]])
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker            This correspondence can be seen when these grids are
497*da0073e9SAndroid Build Coastguard Worker            stacked properly.
498*da0073e9SAndroid Build Coastguard Worker            >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
499*da0073e9SAndroid Build Coastguard Worker            ...             torch.cartesian_prod(x, y))
500*da0073e9SAndroid Build Coastguard Worker            True
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker            `torch.meshgrid` is commonly used to produce a grid for
503*da0073e9SAndroid Build Coastguard Worker            plotting.
504*da0073e9SAndroid Build Coastguard Worker            >>> # xdoctest: +REQUIRES(module:matplotlib)
505*da0073e9SAndroid Build Coastguard Worker            >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
506*da0073e9SAndroid Build Coastguard Worker            >>> import matplotlib.pyplot as plt
507*da0073e9SAndroid Build Coastguard Worker            >>> xs = torch.linspace(-5, 5, steps=100)
508*da0073e9SAndroid Build Coastguard Worker            >>> ys = torch.linspace(-5, 5, steps=100)
509*da0073e9SAndroid Build Coastguard Worker            >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
510*da0073e9SAndroid Build Coastguard Worker            >>> z = torch.sin(torch.sqrt(x * x + y * y))
511*da0073e9SAndroid Build Coastguard Worker            >>> ax = plt.axes(projection='3d')
512*da0073e9SAndroid Build Coastguard Worker            >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
513*da0073e9SAndroid Build Coastguard Worker            >>> plt.show()
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker        .. image:: ../_static/img/meshgrid.png
516*da0073e9SAndroid Build Coastguard Worker            :width: 512
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        """
519*da0073e9SAndroid Build Coastguard Worker        return _meshgrid(*tensors, indexing=indexing)
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Workerdef _meshgrid(*tensors, indexing: Optional[str]):
523*da0073e9SAndroid Build Coastguard Worker    if has_torch_function(tensors):
524*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
525*da0073e9SAndroid Build Coastguard Worker    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
526*da0073e9SAndroid Build Coastguard Worker        # the old interface of passing the operands as one list argument
527*da0073e9SAndroid Build Coastguard Worker        tensors = tensors[0]  # type: ignore[assignment]
528*da0073e9SAndroid Build Coastguard Worker
529*da0073e9SAndroid Build Coastguard Worker    # Continue allowing call of old method that takes no indexing
530*da0073e9SAndroid Build Coastguard Worker    # kwarg for forward compatibility reasons.
531*da0073e9SAndroid Build Coastguard Worker    #
532*da0073e9SAndroid Build Coastguard Worker    # Remove this two weeks after landing.
533*da0073e9SAndroid Build Coastguard Worker    kwargs = {} if indexing is None else {"indexing": indexing}
534*da0073e9SAndroid Build Coastguard Worker    return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Workerdef stft(
538*da0073e9SAndroid Build Coastguard Worker    input: Tensor,
539*da0073e9SAndroid Build Coastguard Worker    n_fft: int,
540*da0073e9SAndroid Build Coastguard Worker    hop_length: Optional[int] = None,
541*da0073e9SAndroid Build Coastguard Worker    win_length: Optional[int] = None,
542*da0073e9SAndroid Build Coastguard Worker    window: Optional[Tensor] = None,
543*da0073e9SAndroid Build Coastguard Worker    center: bool = True,
544*da0073e9SAndroid Build Coastguard Worker    pad_mode: str = "reflect",
545*da0073e9SAndroid Build Coastguard Worker    normalized: bool = False,
546*da0073e9SAndroid Build Coastguard Worker    onesided: Optional[bool] = None,
547*da0073e9SAndroid Build Coastguard Worker    return_complex: Optional[bool] = None,
548*da0073e9SAndroid Build Coastguard Worker) -> Tensor:
549*da0073e9SAndroid Build Coastguard Worker    r"""Short-time Fourier transform (STFT).
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker    .. warning::
552*da0073e9SAndroid Build Coastguard Worker        From version 1.8.0, :attr:`return_complex` must always be given
553*da0073e9SAndroid Build Coastguard Worker        explicitly for real inputs and `return_complex=False` has been
554*da0073e9SAndroid Build Coastguard Worker        deprecated. Strongly prefer `return_complex=True` as in a future
555*da0073e9SAndroid Build Coastguard Worker        pytorch release, this function will only return complex tensors.
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker        Note that :func:`torch.view_as_real` can be used to recover a real
558*da0073e9SAndroid Build Coastguard Worker        tensor with an extra last dimension for real and imaginary components.
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker    .. warning::
561*da0073e9SAndroid Build Coastguard Worker        From version 2.1, a warning will be provided if a :attr:`window` is
562*da0073e9SAndroid Build Coastguard Worker        not specified. In a future release, this attribute will be required.
563*da0073e9SAndroid Build Coastguard Worker        Not providing a window currently defaults to using a rectangular window,
564*da0073e9SAndroid Build Coastguard Worker        which may result in undesirable artifacts. Consider using tapered windows,
565*da0073e9SAndroid Build Coastguard Worker        such as :func:`torch.hann_window`.
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker    The STFT computes the Fourier transform of short overlapping windows of the
568*da0073e9SAndroid Build Coastguard Worker    input. This giving frequency components of the signal as they change over
569*da0073e9SAndroid Build Coastguard Worker    time. The interface of this function is modeled after (but *not* a drop-in
570*da0073e9SAndroid Build Coastguard Worker    replacement for) librosa_ stft function.
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker    .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker    Ignoring the optional batch dimension, this method computes the following
575*da0073e9SAndroid Build Coastguard Worker    expression:
576*da0073e9SAndroid Build Coastguard Worker
577*da0073e9SAndroid Build Coastguard Worker    .. math::
578*da0073e9SAndroid Build Coastguard Worker        X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
579*da0073e9SAndroid Build Coastguard Worker                            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
580*da0073e9SAndroid Build Coastguard Worker                            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker    where :math:`m` is the index of the sliding window, and :math:`\omega` is
583*da0073e9SAndroid Build Coastguard Worker    the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
584*da0073e9SAndroid Build Coastguard Worker    or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
585*da0073e9SAndroid Build Coastguard Worker
586*da0073e9SAndroid Build Coastguard Worker    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
587*da0073e9SAndroid Build Coastguard Worker      sequences.
588*da0073e9SAndroid Build Coastguard Worker
589*da0073e9SAndroid Build Coastguard Worker    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
590*da0073e9SAndroid Build Coastguard Worker      ``floor(n_fft / 4)``.
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
593*da0073e9SAndroid Build Coastguard Worker      :attr:`n_fft`.
594*da0073e9SAndroid Build Coastguard Worker
595*da0073e9SAndroid Build Coastguard Worker    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
596*da0073e9SAndroid Build Coastguard Worker      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
597*da0073e9SAndroid Build Coastguard Worker      treated as if having :math:`1` everywhere in the window. If
598*da0073e9SAndroid Build Coastguard Worker      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
599*da0073e9SAndroid Build Coastguard Worker      both sides to length :attr:`n_fft` before being applied.
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
602*da0073e9SAndroid Build Coastguard Worker      both sides so that the :math:`t`-th frame is centered at time
603*da0073e9SAndroid Build Coastguard Worker      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
604*da0073e9SAndroid Build Coastguard Worker      begins at time  :math:`t \times \text{hop\_length}`.
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Worker    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
607*da0073e9SAndroid Build Coastguard Worker      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
608*da0073e9SAndroid Build Coastguard Worker      all available options. Default is ``"reflect"``.
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker    * If :attr:`onesided` is ``True`` (default for real input), only values for
611*da0073e9SAndroid Build Coastguard Worker      :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
612*da0073e9SAndroid Build Coastguard Worker      \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
613*da0073e9SAndroid Build Coastguard Worker      the real-to-complex Fourier transform satisfies the conjugate symmetry,
614*da0073e9SAndroid Build Coastguard Worker      i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
615*da0073e9SAndroid Build Coastguard Worker      Note if the input or window tensors are complex, then :attr:`onesided`
616*da0073e9SAndroid Build Coastguard Worker      output is not possible.
617*da0073e9SAndroid Build Coastguard Worker
618*da0073e9SAndroid Build Coastguard Worker    * If :attr:`normalized` is ``True`` (default is ``False``), the function
619*da0073e9SAndroid Build Coastguard Worker      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
620*da0073e9SAndroid Build Coastguard Worker
621*da0073e9SAndroid Build Coastguard Worker    * If :attr:`return_complex` is ``True`` (default if input is complex), the
622*da0073e9SAndroid Build Coastguard Worker      return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
623*da0073e9SAndroid Build Coastguard Worker      the output is a ``input.dim() + 2`` dimensional real tensor where the last
624*da0073e9SAndroid Build Coastguard Worker      dimension represents the real and imaginary components.
625*da0073e9SAndroid Build Coastguard Worker
626*da0073e9SAndroid Build Coastguard Worker    Returns either a complex tensor of size :math:`(* \times N \times T)` if
627*da0073e9SAndroid Build Coastguard Worker    :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
628*da0073e9SAndroid Build Coastguard Worker    \times T \times 2)`. Where :math:`*` is the optional batch size of
629*da0073e9SAndroid Build Coastguard Worker    :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
630*da0073e9SAndroid Build Coastguard Worker    and :math:`T` is the total number of frames used.
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker    .. warning::
633*da0073e9SAndroid Build Coastguard Worker      This function changed signature at version 0.4.1. Calling with the
634*da0073e9SAndroid Build Coastguard Worker      previous signature may cause error or return incorrect result.
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker    Args:
637*da0073e9SAndroid Build Coastguard Worker        input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
638*da0073e9SAndroid Build Coastguard Worker            batch dimension
639*da0073e9SAndroid Build Coastguard Worker        n_fft (int): size of Fourier transform
640*da0073e9SAndroid Build Coastguard Worker        hop_length (int, optional): the distance between neighboring sliding window
641*da0073e9SAndroid Build Coastguard Worker            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
642*da0073e9SAndroid Build Coastguard Worker        win_length (int, optional): the size of window frame and STFT filter.
643*da0073e9SAndroid Build Coastguard Worker            Default: ``None``  (treated as equal to :attr:`n_fft`)
644*da0073e9SAndroid Build Coastguard Worker        window (Tensor, optional): the optional window function.
645*da0073e9SAndroid Build Coastguard Worker            Shape must be 1d and `<= n_fft`
646*da0073e9SAndroid Build Coastguard Worker            Default: ``None`` (treated as window of all :math:`1` s)
647*da0073e9SAndroid Build Coastguard Worker        center (bool, optional): whether to pad :attr:`input` on both sides so
648*da0073e9SAndroid Build Coastguard Worker            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
649*da0073e9SAndroid Build Coastguard Worker            Default: ``True``
650*da0073e9SAndroid Build Coastguard Worker        pad_mode (str, optional): controls the padding method used when
651*da0073e9SAndroid Build Coastguard Worker            :attr:`center` is ``True``. Default: ``"reflect"``
652*da0073e9SAndroid Build Coastguard Worker        normalized (bool, optional): controls whether to return the normalized STFT results
653*da0073e9SAndroid Build Coastguard Worker             Default: ``False``
654*da0073e9SAndroid Build Coastguard Worker        onesided (bool, optional): controls whether to return half of results to
655*da0073e9SAndroid Build Coastguard Worker            avoid redundancy for real inputs.
656*da0073e9SAndroid Build Coastguard Worker            Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
657*da0073e9SAndroid Build Coastguard Worker        return_complex (bool, optional): whether to return a complex tensor, or
658*da0073e9SAndroid Build Coastguard Worker            a real tensor with an extra last dimension for the real and
659*da0073e9SAndroid Build Coastguard Worker            imaginary components.
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker            .. versionchanged:: 2.0
662*da0073e9SAndroid Build Coastguard Worker               ``return_complex`` is now a required argument for real inputs,
663*da0073e9SAndroid Build Coastguard Worker               as the default is being transitioned to ``True``.
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker            .. deprecated:: 2.0
666*da0073e9SAndroid Build Coastguard Worker               ``return_complex=False`` is deprecated, instead use ``return_complex=True``
667*da0073e9SAndroid Build Coastguard Worker               Note that calling :func:`torch.view_as_real` on the output will
668*da0073e9SAndroid Build Coastguard Worker               recover the deprecated output format.
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker    Returns:
671*da0073e9SAndroid Build Coastguard Worker        Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
672*da0073e9SAndroid Build Coastguard Worker           - `B?` is an optional batch dimension from the input.
673*da0073e9SAndroid Build Coastguard Worker           - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
674*da0073e9SAndroid Build Coastguard Worker             `onesided=True`, or otherwise `n_fft`.
675*da0073e9SAndroid Build Coastguard Worker           - `T` is the number of frames, `1 + L // hop_length`
676*da0073e9SAndroid Build Coastguard Worker             for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
677*da0073e9SAndroid Build Coastguard Worker           - `C?` is an optional length-2 dimension of real and imaginary
678*da0073e9SAndroid Build Coastguard Worker             components, present when `return_complex=False`.
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker    """
681*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
682*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
683*da0073e9SAndroid Build Coastguard Worker            stft,
684*da0073e9SAndroid Build Coastguard Worker            (input,),
685*da0073e9SAndroid Build Coastguard Worker            input,
686*da0073e9SAndroid Build Coastguard Worker            n_fft,
687*da0073e9SAndroid Build Coastguard Worker            hop_length=hop_length,
688*da0073e9SAndroid Build Coastguard Worker            win_length=win_length,
689*da0073e9SAndroid Build Coastguard Worker            window=window,
690*da0073e9SAndroid Build Coastguard Worker            center=center,
691*da0073e9SAndroid Build Coastguard Worker            pad_mode=pad_mode,
692*da0073e9SAndroid Build Coastguard Worker            normalized=normalized,
693*da0073e9SAndroid Build Coastguard Worker            onesided=onesided,
694*da0073e9SAndroid Build Coastguard Worker            return_complex=return_complex,
695*da0073e9SAndroid Build Coastguard Worker        )
696*da0073e9SAndroid Build Coastguard Worker    # NOTE: Do not edit. This code will be removed once the forward-compatibility
697*da0073e9SAndroid Build Coastguard Worker    #       period is over for PR #73432
698*da0073e9SAndroid Build Coastguard Worker    if center:
699*da0073e9SAndroid Build Coastguard Worker        signal_dim = input.dim()
700*da0073e9SAndroid Build Coastguard Worker        extended_shape = [1] * (3 - signal_dim) + list(input.size())
701*da0073e9SAndroid Build Coastguard Worker        pad = int(n_fft // 2)
702*da0073e9SAndroid Build Coastguard Worker        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
703*da0073e9SAndroid Build Coastguard Worker        input = input.view(input.shape[-signal_dim:])
704*da0073e9SAndroid Build Coastguard Worker    return _VF.stft(  # type: ignore[attr-defined]
705*da0073e9SAndroid Build Coastguard Worker        input,
706*da0073e9SAndroid Build Coastguard Worker        n_fft,
707*da0073e9SAndroid Build Coastguard Worker        hop_length,
708*da0073e9SAndroid Build Coastguard Worker        win_length,
709*da0073e9SAndroid Build Coastguard Worker        window,
710*da0073e9SAndroid Build Coastguard Worker        normalized,
711*da0073e9SAndroid Build Coastguard Worker        onesided,
712*da0073e9SAndroid Build Coastguard Worker        return_complex,
713*da0073e9SAndroid Build Coastguard Worker    )
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker
# Attach the user-facing documentation to ``torch.istft`` via ``_add_docstr``
# and bind the returned object to this module's public name ``istft``.
# The two plain string fragments and the raw string below are implicitly
# concatenated into a single argument: the plain fragments form the signature
# line shown at the top of the rendered docs, the raw string is the
# reStructuredText body. This text is runtime data (presumably it becomes
# ``torch.istft.__doc__``), so edits to it are user-visible documentation changes.
istft = _add_docstr(
    torch.istft,
    "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
    "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
    r"""
Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.

.. warning::
    From version 2.1, a warning will be provided if a :attr:`window` is
    not specified. In a future release, this attribute will be required.
    Please provide the same window used in the stft call.

It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).

Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelope
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`.

Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
``istft`` will pad zeros to the end of the returned signal.

If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
Left padding can be trimmed off exactly because they can be calculated but right padding cannot be
calculated without additional information.

Example: Suppose the last window is:
``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``

The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
of right padding. These additional values could be zeros or a reflection of the signal so providing
:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
(some loss of signal).

[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

Args:
    input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`,
        output. That is a complex tensor of shape `(B?, N, T)` where

        - `B?` is an optional batch dimension
        - `N` is the number of frequency samples, `(n_fft // 2) + 1`
          for onesided input, or otherwise `n_fft`.
        - `T` is the number of frames, `1 + length // hop_length` for centered stft,
          or `1 + (length - n_fft) // hop_length` otherwise.

        .. versionchanged:: 2.0
            Real datatype inputs are no longer supported. Input must now have a
            complex datatype, as returned by ``stft(..., return_complex=True)``.
    n_fft (int): Size of Fourier transform
    hop_length (Optional[int]): The distance between neighboring sliding window frames.
        (Default: ``n_fft // 4``)
    win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
    window (Optional[torch.Tensor]): The optional window function.
        Shape must be 1d and `<= n_fft`
        (Default: ``torch.ones(win_length)``)
    center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
        centered at time :math:`t \times \text{hop\_length}`.
        (Default: ``True``)
    normalized (bool): Whether the STFT was normalized. (Default: ``False``)
    onesided (Optional[bool]): Whether the STFT was onesided.
        (Default: ``True`` if `n_fft != fft_size` in the input size)
    length (Optional[int]): The amount to trim the signal by (i.e. the
        original signal length). Defaults to `(T - 1) * hop_length` for
        centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
        is the number of input frames.
    return_complex (Optional[bool]):
        Whether the output should be complex, or if the input should be
        assumed to derive from a real signal and window.
        Note that this is incompatible with ``onesided=True``.
        (Default: ``False``)

Returns:
    Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
        `B?` is an optional batch dimension from the input tensor.
""",
)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker
if not TYPE_CHECKING:
    # Runtime view of the return annotation shared by ``_unique_impl`` and
    # ``_unique_consecutive_impl``: a fixed (output, inverse_indices, counts)
    # triple, which is exactly what the dispatch wrappers unpack.
    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
else:
    # Static checkers get ``Any`` instead. When ``input`` overrides
    # ``__torch_function__``, the _impl functions return the override's result
    # as-is (already shaped for the end caller, not necessarily a 3-tuple), so
    # no single tuple type describes them accurately.
    _unique_impl_out = Any
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker
def _unique_impl(
    input: Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: Optional[int] = None,
) -> _unique_impl_out:
    r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]

    Returns the unique elements of the input tensor.

    .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
        this function also eliminates non-consecutive duplicate values.

    .. note:: Currently in the CUDA implementation and the CPU implementation,
        `torch.unique` always sorts the tensor at the beginning regardless of the `sorted` argument.
        Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
        :func:`torch.unique_consecutive` which avoids the sorting.

    Args:
        input (Tensor): the input tensor
        sorted (bool): Whether to sort the unique elements in ascending order
            before returning as output. Default: ``True``
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
            Default: ``False``
        return_counts (bool): Whether to also return the counts for each unique
            element. Default: ``False``
        dim (int, optional): the dimension to operate upon. If ``None``, the
            unique of the flattened input is returned. Otherwise, each of the
            tensors indexed by the given dimension is treated as one of the
            elements to apply the unique operation upon. See examples for more
            details. Default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
        >>> output
        tensor([1, 2, 3])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([0, 2, 1, 2])

        >>> output, inverse_indices = torch.unique(
        ...     torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([1, 2, 3])
        >>> inverse_indices
        tensor([[0, 2],
                [1, 2]])

        >>> a = torch.tensor([
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ...     [
        ...         [0, 0, 1, 1],
        ...         [0, 0, 1, 1],
        ...         [1, 1, 1, 1],
        ...     ],
        ...     [
        ...         [1, 1, 0, 0],
        ...         [1, 1, 0, 0],
        ...         [0, 0, 1, 1],
        ...     ],
        ... ])

        >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
        >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
        >>> # each other, so one of them will be removed.
        >>> (a[0, :, :] == a[2, :, :]).all()
        tensor(True)
        >>> a_unique_dim0 = torch.unique(a, dim=0)
        >>> a_unique_dim0
        tensor([[[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 1, 1]]])

        >>> # Notice which sub-tensors from `a` match with the sub-tensors from
        >>> # `a_unique_dim0`:
        >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
        tensor(True)
        >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
        tensor(True)

        >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
        >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
        >>> # them will be removed.
        >>> (a[:, 0, :] == a[:, 1, :]).all()
        tensor(True)
        >>> torch.unique(a, dim=1)
        tensor([[[0, 0, 1, 1],
                 [1, 1, 0, 0]],
                [[1, 1, 1, 1],
                 [0, 0, 1, 1]],
                [[0, 0, 1, 1],
                 [1, 1, 0, 0]]])

        >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
        >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
        >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
        >>> # sub-tensors will be removed.
        >>> (a[:, :, 0] == a[:, :, 1]).all()
        tensor(True)
        >>> (a[:, :, 2] == a[:, :, 3]).all()
        tensor(True)
        >>> torch.unique(a, dim=2)
        tensor([[[0, 1],
                 [0, 1],
                 [1, 0]],
                [[1, 0],
                 [1, 0],
                 [1, 1]],
                [[0, 1],
                 [0, 1],
                 [1, 0]]])
    """
    # Inputs overriding ``__torch_function__`` are routed through the public
    # ``unique`` dispatcher (the module-level name is looked up at call time),
    # so the override sees the API the user called rather than this helper.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique,
            (input,),
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    if dim is not None:
        # Whole slices along ``dim`` are the elements being deduplicated.
        output, inverse_indices, counts = _VF.unique_dim(
            input,
            dim,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    else:
        # Scalar elements of the flattened input are deduplicated.
        output, inverse_indices, counts = torch._unique2(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
            return_counts=return_counts,
        )
    # Always the full triple; the boolean-dispatched wrappers keep only the
    # pieces the caller asked for.
    return output, inverse_indices, counts
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker
def _unique_consecutive_impl(
    input: Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: Optional[int] = None,
) -> _unique_impl_out:
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    .. note:: This function is different from :func:`torch.unique` in the sense that this function
        only eliminates consecutive duplicate values. These semantics are similar to `std::unique`
        in C++.

    Args:
        input (Tensor): the input tensor
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
            Default: ``False``
        return_counts (bool): Whether to also return the counts for each unique
            element. Default: ``False``
        dim (int, optional): the dimension to apply unique. If ``None``, the unique of the
            flattened input is returned. Default: ``None``

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be an additional
              returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.
            - **counts** (*Tensor*): (optional) if
              :attr:`return_counts` is True, there will be an additional
              returned tensor (same shape as output or output.size(dim),
              if dim was specified) representing the number of occurrences
              for each unique value or tensor.

    Example::

        >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
        >>> output = torch.unique_consecutive(x)
        >>> output
        tensor([1, 2, 3, 1, 2])

        >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> inverse_indices
        tensor([0, 0, 1, 1, 2, 3, 3, 4])

        >>> output, counts = torch.unique_consecutive(x, return_counts=True)
        >>> output
        tensor([1, 2, 3, 1, 2])
        >>> counts
        tensor([2, 2, 1, 2, 1])
    """
    # Inputs overriding ``__torch_function__`` are routed through the public
    # ``unique_consecutive`` function so the override sees the user-facing API.
    if has_torch_function_unary(input):
        return handle_torch_function(
            unique_consecutive,
            (input,),
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )
    # Unlike ``_unique_impl``, one ``_VF`` call handles both the flattened
    # (``dim=None``) case and the per-dimension case.
    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
    )
    # Always the full triple; the dispatch wrappers select what was requested.
    return output, inverse_indices, counts
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker
# Dispatch target for ``unique(..., return_inverse=False, return_counts=True)``
# (selected by ``_return_inverse_false``): returns ``(output, counts)``.
#
# NOTE(review): documented with comments instead of a docstring on purpose.
# boolean_dispatch pairs this function with ``_return_output`` and appears to
# accept a docstring on only one branch; confirm against torch._jit_internal
# before adding one.
def _return_counts(
    input,
    sorted=True,
    return_inverse=False,
    return_counts=False,
    dim=None,
):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

    # With a __torch_function__ override, the override's result is already
    # shaped for the caller, so it is returned without unpacking.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    # Drop the inverse indices; only the unique values and counts were requested.
    output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return output, counts
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker
# Dispatch target for ``unique(..., return_inverse=False, return_counts=False)``
# (selected by ``_return_inverse_false``): returns only the unique values.
#
# NOTE(review): documented with comments instead of a docstring on purpose.
# boolean_dispatch pairs this function with ``_return_counts`` and appears to
# accept a docstring on only one branch; confirm against torch._jit_internal
# before adding one.
def _return_output(
    input,
    sorted=True,
    return_inverse=False,
    return_counts=False,
    dim=None,
):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor

    # With a __torch_function__ override, the override's result is already
    # shaped for the caller, so it is returned without unpacking.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    # Keep only the unique values; inverse indices and counts are discarded.
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
    return output
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker
# Dispatch target for ``unique(..., return_inverse=True, return_counts=False)``
# (selected by ``_return_inverse_true``): returns ``(output, inverse_indices)``.
#
# NOTE(review): documented with comments instead of a docstring on purpose.
# boolean_dispatch pairs this function with ``_unique_impl`` (which already has
# a docstring) and appears to accept a docstring on only one branch; confirm
# against torch._jit_internal before adding one.
def _return_inverse(
    input,
    sorted=True,
    return_inverse=False,
    return_counts=False,
    dim=None,
):
    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]

    # With a __torch_function__ override, the override's result is already
    # shaped for the caller, so it is returned without unpacking.
    if has_torch_function_unary(input):
        return _unique_impl(input, sorted, return_inverse, return_counts, dim)

    # Drop the counts; only the unique values and inverse indices were requested.
    output, inverse_indices, _ = _unique_impl(
        input, sorted, return_inverse, return_counts, dim
    )
    return output, inverse_indices
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker
# Second-level dispatcher reached by ``unique(..., return_inverse=False)``.
# It branches on ``return_counts`` (positional index 3, default False) to pick
# the wrapper whose fixed return type matches the request.
_return_inverse_false = boolean_dispatch(
    module_name=__name__,
    func_name="unique",
    arg_name="return_counts",
    arg_index=3,
    default=False,
    if_false=_return_output,
    if_true=_return_counts,
)
1106*da0073e9SAndroid Build Coastguard Worker
# Second-level dispatcher reached by ``unique(..., return_inverse=True)``.
# It branches on ``return_counts`` (positional index 3, default False). When
# counts are also requested, the full (output, inverse_indices, counts) triple
# from ``_unique_impl`` is exactly what the caller wants.
_return_inverse_true = boolean_dispatch(
    module_name=__name__,
    func_name="unique",
    arg_name="return_counts",
    arg_index=3,
    default=False,
    if_false=_return_inverse,
    if_true=_unique_impl,
)
1116*da0073e9SAndroid Build Coastguard Worker
# Public ``unique`` entry point. How many tensors it returns depends on both
# ``return_inverse`` and ``return_counts``, and TorchScript needs the output
# type statically. The two flags are therefore resolved by nested
# boolean_dispatch calls: this level branches on ``return_inverse``
# (positional index 2, default False), and each branch then branches on
# ``return_counts``.
unique = boolean_dispatch(
    module_name=__name__,
    func_name="unique",
    arg_name="return_inverse",
    arg_index=2,
    default=False,
    if_false=_return_inverse_false,
    if_true=_return_inverse_true,
)
# The public dispatcher carries the complete user-facing documentation.
unique.__doc__ = _unique_impl.__doc__
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_counts(
1133*da0073e9SAndroid Build Coastguard Worker    input,
1134*da0073e9SAndroid Build Coastguard Worker    return_inverse=False,
1135*da0073e9SAndroid Build Coastguard Worker    return_counts=False,
1136*da0073e9SAndroid Build Coastguard Worker    dim=None,
1137*da0073e9SAndroid Build Coastguard Worker):
1138*da0073e9SAndroid Build Coastguard Worker    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1139*da0073e9SAndroid Build Coastguard Worker
1140*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1141*da0073e9SAndroid Build Coastguard Worker        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1142*da0073e9SAndroid Build Coastguard Worker
1143*da0073e9SAndroid Build Coastguard Worker    output, _, counts = _unique_consecutive_impl(
1144*da0073e9SAndroid Build Coastguard Worker        input, return_inverse, return_counts, dim
1145*da0073e9SAndroid Build Coastguard Worker    )
1146*da0073e9SAndroid Build Coastguard Worker    return output, counts
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_output(
1150*da0073e9SAndroid Build Coastguard Worker    input,
1151*da0073e9SAndroid Build Coastguard Worker    return_inverse=False,
1152*da0073e9SAndroid Build Coastguard Worker    return_counts=False,
1153*da0073e9SAndroid Build Coastguard Worker    dim=None,
1154*da0073e9SAndroid Build Coastguard Worker):
1155*da0073e9SAndroid Build Coastguard Worker    # type: (Tensor, bool, bool, Optional[int]) -> Tensor
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1158*da0073e9SAndroid Build Coastguard Worker        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker    output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1161*da0073e9SAndroid Build Coastguard Worker    return output
1162*da0073e9SAndroid Build Coastguard Worker
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_inverse(
1165*da0073e9SAndroid Build Coastguard Worker    input,
1166*da0073e9SAndroid Build Coastguard Worker    return_inverse=False,
1167*da0073e9SAndroid Build Coastguard Worker    return_counts=False,
1168*da0073e9SAndroid Build Coastguard Worker    dim=None,
1169*da0073e9SAndroid Build Coastguard Worker):
1170*da0073e9SAndroid Build Coastguard Worker    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_unary(input):
1173*da0073e9SAndroid Build Coastguard Worker        return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker    output, inverse_indices, _ = _unique_consecutive_impl(
1176*da0073e9SAndroid Build Coastguard Worker        input, return_inverse, return_counts, dim
1177*da0073e9SAndroid Build Coastguard Worker    )
1178*da0073e9SAndroid Build Coastguard Worker    return output, inverse_indices
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker_consecutive_return_inverse_false = boolean_dispatch(
1182*da0073e9SAndroid Build Coastguard Worker    arg_name="return_counts",
1183*da0073e9SAndroid Build Coastguard Worker    arg_index=1,
1184*da0073e9SAndroid Build Coastguard Worker    default=False,
1185*da0073e9SAndroid Build Coastguard Worker    if_true=_consecutive_return_counts,
1186*da0073e9SAndroid Build Coastguard Worker    if_false=_consecutive_return_output,
1187*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1188*da0073e9SAndroid Build Coastguard Worker    func_name="unique_consecutive",
1189*da0073e9SAndroid Build Coastguard Worker)
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker_consecutive_return_inverse_true = boolean_dispatch(
1192*da0073e9SAndroid Build Coastguard Worker    arg_name="return_counts",
1193*da0073e9SAndroid Build Coastguard Worker    arg_index=1,
1194*da0073e9SAndroid Build Coastguard Worker    default=False,
1195*da0073e9SAndroid Build Coastguard Worker    if_true=_unique_consecutive_impl,
1196*da0073e9SAndroid Build Coastguard Worker    if_false=_consecutive_return_inverse,
1197*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1198*da0073e9SAndroid Build Coastguard Worker    func_name="unique_consecutive",
1199*da0073e9SAndroid Build Coastguard Worker)
1200*da0073e9SAndroid Build Coastguard Worker
1201*da0073e9SAndroid Build Coastguard Worker# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
1202*da0073e9SAndroid Build Coastguard Worker# resolve the output type in TorchScript we need to statically know the value of both parameters
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Workerunique_consecutive = boolean_dispatch(
1205*da0073e9SAndroid Build Coastguard Worker    arg_name="return_inverse",
1206*da0073e9SAndroid Build Coastguard Worker    arg_index=2,
1207*da0073e9SAndroid Build Coastguard Worker    default=False,
1208*da0073e9SAndroid Build Coastguard Worker    if_true=_consecutive_return_inverse_true,
1209*da0073e9SAndroid Build Coastguard Worker    if_false=_consecutive_return_inverse_false,
1210*da0073e9SAndroid Build Coastguard Worker    module_name=__name__,
1211*da0073e9SAndroid Build Coastguard Worker    func_name="unique_consecutive",
1212*da0073e9SAndroid Build Coastguard Worker)
1213*da0073e9SAndroid Build Coastguard Workerunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
1216*da0073e9SAndroid Build Coastguard Worker    pass
1217*da0073e9SAndroid Build Coastguard Worker    # There's no good way to use this type annotation without breaking JIT
1218*da0073e9SAndroid Build Coastguard Worker    # overloads. So leave untyped for mypy for now.
1219*da0073e9SAndroid Build Coastguard Workerelse:
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker    @overload
1222*da0073e9SAndroid Build Coastguard Worker    def tensordot(
1223*da0073e9SAndroid Build Coastguard Worker        a,
1224*da0073e9SAndroid Build Coastguard Worker        b,
1225*da0073e9SAndroid Build Coastguard Worker        dims: int = 2,
1226*da0073e9SAndroid Build Coastguard Worker        out: Optional[torch.Tensor] = None,
1227*da0073e9SAndroid Build Coastguard Worker    ):
1228*da0073e9SAndroid Build Coastguard Worker        pass
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker    @overload
1231*da0073e9SAndroid Build Coastguard Worker    def tensordot(  # noqa: F811
1232*da0073e9SAndroid Build Coastguard Worker        a,
1233*da0073e9SAndroid Build Coastguard Worker        b,
1234*da0073e9SAndroid Build Coastguard Worker        dims: Tuple[List[int], List[int]],
1235*da0073e9SAndroid Build Coastguard Worker        out: Optional[torch.Tensor] = None,
1236*da0073e9SAndroid Build Coastguard Worker    ):
1237*da0073e9SAndroid Build Coastguard Worker        pass
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker    @overload
1240*da0073e9SAndroid Build Coastguard Worker    def tensordot(  # noqa: F811
1241*da0073e9SAndroid Build Coastguard Worker        a,
1242*da0073e9SAndroid Build Coastguard Worker        b,
1243*da0073e9SAndroid Build Coastguard Worker        dims: List[List[int]],
1244*da0073e9SAndroid Build Coastguard Worker        out: Optional[torch.Tensor] = None,
1245*da0073e9SAndroid Build Coastguard Worker    ):
1246*da0073e9SAndroid Build Coastguard Worker        pass
1247*da0073e9SAndroid Build Coastguard Worker
1248*da0073e9SAndroid Build Coastguard Worker    @overload
1249*da0073e9SAndroid Build Coastguard Worker    def tensordot(  # noqa: F811
1250*da0073e9SAndroid Build Coastguard Worker        a,
1251*da0073e9SAndroid Build Coastguard Worker        b,
1252*da0073e9SAndroid Build Coastguard Worker        dims: torch.Tensor,
1253*da0073e9SAndroid Build Coastguard Worker        out: Optional[torch.Tensor] = None,
1254*da0073e9SAndroid Build Coastguard Worker    ):
1255*da0073e9SAndroid Build Coastguard Worker        pass
1256*da0073e9SAndroid Build Coastguard Worker
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Workerdef tensordot(  # noqa: F811
1259*da0073e9SAndroid Build Coastguard Worker    a,
1260*da0073e9SAndroid Build Coastguard Worker    b,
1261*da0073e9SAndroid Build Coastguard Worker    dims=2,
1262*da0073e9SAndroid Build Coastguard Worker    out: Optional[torch.Tensor] = None,
1263*da0073e9SAndroid Build Coastguard Worker):
1264*da0073e9SAndroid Build Coastguard Worker    r"""Returns a contraction of a and b over multiple dimensions.
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker    :attr:`tensordot` implements a generalized matrix product.
1267*da0073e9SAndroid Build Coastguard Worker
1268*da0073e9SAndroid Build Coastguard Worker    Args:
1269*da0073e9SAndroid Build Coastguard Worker      a (Tensor): Left tensor to contract
1270*da0073e9SAndroid Build Coastguard Worker      b (Tensor): Right tensor to contract
1271*da0073e9SAndroid Build Coastguard Worker      dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
1272*da0073e9SAndroid Build Coastguard Worker         contract or explicit lists of dimensions for :attr:`a` and
1273*da0073e9SAndroid Build Coastguard Worker         :attr:`b` respectively
1274*da0073e9SAndroid Build Coastguard Worker
1275*da0073e9SAndroid Build Coastguard Worker    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
1276*da0073e9SAndroid Build Coastguard Worker    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
1277*da0073e9SAndroid Build Coastguard Worker    respectively, :func:`~torch.tensordot` computes
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    .. math::
1280*da0073e9SAndroid Build Coastguard Worker        r_{i_0,...,i_{m-d}, i_d,...,i_n}
1281*da0073e9SAndroid Build Coastguard Worker          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
1282*da0073e9SAndroid Build Coastguard Worker
1283*da0073e9SAndroid Build Coastguard Worker    When called with :attr:`dims` of the list form, the given dimensions will be contracted
1284*da0073e9SAndroid Build Coastguard Worker    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
1285*da0073e9SAndroid Build Coastguard Worker    in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
1286*da0073e9SAndroid Build Coastguard Worker    dimensions.
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker    Examples::
1289*da0073e9SAndroid Build Coastguard Worker
1290*da0073e9SAndroid Build Coastguard Worker        >>> a = torch.arange(60.).reshape(3, 4, 5)
1291*da0073e9SAndroid Build Coastguard Worker        >>> b = torch.arange(24.).reshape(4, 3, 2)
1292*da0073e9SAndroid Build Coastguard Worker        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
1293*da0073e9SAndroid Build Coastguard Worker        tensor([[4400., 4730.],
1294*da0073e9SAndroid Build Coastguard Worker                [4532., 4874.],
1295*da0073e9SAndroid Build Coastguard Worker                [4664., 5018.],
1296*da0073e9SAndroid Build Coastguard Worker                [4796., 5162.],
1297*da0073e9SAndroid Build Coastguard Worker                [4928., 5306.]])
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
1300*da0073e9SAndroid Build Coastguard Worker        >>> a = torch.randn(3, 4, 5, device='cuda')
1301*da0073e9SAndroid Build Coastguard Worker        >>> b = torch.randn(4, 5, 6, device='cuda')
1302*da0073e9SAndroid Build Coastguard Worker        >>> c = torch.tensordot(a, b, dims=2).cpu()
1303*da0073e9SAndroid Build Coastguard Worker        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
1304*da0073e9SAndroid Build Coastguard Worker                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
1305*da0073e9SAndroid Build Coastguard Worker                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker        >>> a = torch.randn(3, 5, 4, 6)
1308*da0073e9SAndroid Build Coastguard Worker        >>> b = torch.randn(6, 4, 5, 3)
1309*da0073e9SAndroid Build Coastguard Worker        >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
1310*da0073e9SAndroid Build Coastguard Worker        tensor([[  7.7193,  -2.4867, -10.3204],
1311*da0073e9SAndroid Build Coastguard Worker                [  1.5513, -14.4737,  -6.5113],
1312*da0073e9SAndroid Build Coastguard Worker                [ -0.2850,   4.2573,  -3.5997]])
1313*da0073e9SAndroid Build Coastguard Worker    """
1314*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(a, b):
1315*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker    if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
1318*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
1319*da0073e9SAndroid Build Coastguard Worker            "tensordot expects dims to be int or "
1320*da0073e9SAndroid Build Coastguard Worker            + "Tuple[List[int], List[int]] or "
1321*da0073e9SAndroid Build Coastguard Worker            + "List[List[int]] containing two lists, but got "
1322*da0073e9SAndroid Build Coastguard Worker            + f"dims={dims}"
1323*da0073e9SAndroid Build Coastguard Worker        )
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker    dims_a: List[int] = []
1326*da0073e9SAndroid Build Coastguard Worker    dims_b: List[int] = []
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker    if isinstance(dims, (tuple, list)):
1329*da0073e9SAndroid Build Coastguard Worker        dims_a, dims_b = dims
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker    if isinstance(dims, torch.Tensor):
1332*da0073e9SAndroid Build Coastguard Worker        num_elements = dims.numel()
1333*da0073e9SAndroid Build Coastguard Worker        if num_elements > 1:
1334*da0073e9SAndroid Build Coastguard Worker            assert dims.size()[0] == 2
1335*da0073e9SAndroid Build Coastguard Worker            dims_a = torch.jit.annotate(List[int], dims[0].tolist())
1336*da0073e9SAndroid Build Coastguard Worker            dims_b = torch.jit.annotate(List[int], dims[1].tolist())
1337*da0073e9SAndroid Build Coastguard Worker        else:
1338*da0073e9SAndroid Build Coastguard Worker            dims_val = int(dims.item())
1339*da0073e9SAndroid Build Coastguard Worker            if dims_val < 0:
1340*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
1341*da0073e9SAndroid Build Coastguard Worker            dims_a = list(range(-dims_val, 0))
1342*da0073e9SAndroid Build Coastguard Worker            dims_b = list(range(dims_val))
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker    if isinstance(dims, (int, torch.SymInt)):
1345*da0073e9SAndroid Build Coastguard Worker        if dims < 0:
1346*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
1347*da0073e9SAndroid Build Coastguard Worker        if dims > min(a.dim(), b.dim()):
1348*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1349*da0073e9SAndroid Build Coastguard Worker                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
1350*da0073e9SAndroid Build Coastguard Worker            )
1351*da0073e9SAndroid Build Coastguard Worker        dims_a = list(range(-dims, 0))
1352*da0073e9SAndroid Build Coastguard Worker        dims_b = list(range(dims))
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker    if out is None:
1355*da0073e9SAndroid Build Coastguard Worker        return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore[attr-defined]
1356*da0073e9SAndroid Build Coastguard Worker    else:
1357*da0073e9SAndroid Build Coastguard Worker        return _VF.tensordot(a, b, dims_a, dims_b, out=out)  # type: ignore[attr-defined]
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Workerdef cartesian_prod(*tensors: Tensor) -> Tensor:
1361*da0073e9SAndroid Build Coastguard Worker    """Do cartesian product of the given sequence of tensors. The behavior is similar to
1362*da0073e9SAndroid Build Coastguard Worker    python's `itertools.product`.
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker    Args:
1365*da0073e9SAndroid Build Coastguard Worker        *tensors: any number of 1 dimensional tensors.
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker    Returns:
1368*da0073e9SAndroid Build Coastguard Worker        Tensor: A tensor equivalent to converting all the input tensors into lists,
1369*da0073e9SAndroid Build Coastguard Worker        do `itertools.product` on these lists, and finally convert the resulting list
1370*da0073e9SAndroid Build Coastguard Worker        into tensor.
1371*da0073e9SAndroid Build Coastguard Worker
1372*da0073e9SAndroid Build Coastguard Worker    Example::
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        >>> import itertools
1375*da0073e9SAndroid Build Coastguard Worker        >>> a = [1, 2, 3]
1376*da0073e9SAndroid Build Coastguard Worker        >>> b = [4, 5]
1377*da0073e9SAndroid Build Coastguard Worker        >>> list(itertools.product(a, b))
1378*da0073e9SAndroid Build Coastguard Worker        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
1379*da0073e9SAndroid Build Coastguard Worker        >>> tensor_a = torch.tensor(a)
1380*da0073e9SAndroid Build Coastguard Worker        >>> tensor_b = torch.tensor(b)
1381*da0073e9SAndroid Build Coastguard Worker        >>> torch.cartesian_prod(tensor_a, tensor_b)
1382*da0073e9SAndroid Build Coastguard Worker        tensor([[1, 4],
1383*da0073e9SAndroid Build Coastguard Worker                [1, 5],
1384*da0073e9SAndroid Build Coastguard Worker                [2, 4],
1385*da0073e9SAndroid Build Coastguard Worker                [2, 5],
1386*da0073e9SAndroid Build Coastguard Worker                [3, 4],
1387*da0073e9SAndroid Build Coastguard Worker                [3, 5]])
1388*da0073e9SAndroid Build Coastguard Worker    """
1389*da0073e9SAndroid Build Coastguard Worker    # This wrapper exists to support variadic args.
1390*da0073e9SAndroid Build Coastguard Worker    if has_torch_function(tensors):
1391*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(cartesian_prod, tensors, *tensors)
1392*da0073e9SAndroid Build Coastguard Worker    return _VF.cartesian_prod(tensors)  # type: ignore[attr-defined]
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Workerdef block_diag(*tensors):
1396*da0073e9SAndroid Build Coastguard Worker    """Create a block diagonal matrix from provided tensors.
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker    Args:
1399*da0073e9SAndroid Build Coastguard Worker        *tensors: One or more tensors with 0, 1, or 2 dimensions.
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker    Returns:
1402*da0073e9SAndroid Build Coastguard Worker        Tensor: A 2 dimensional tensor with all the input tensors arranged in
1403*da0073e9SAndroid Build Coastguard Worker        order such that their upper left and lower right corners are
1404*da0073e9SAndroid Build Coastguard Worker        diagonally adjacent. All other elements are set to 0.
1405*da0073e9SAndroid Build Coastguard Worker
1406*da0073e9SAndroid Build Coastguard Worker    Example::
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker        >>> import torch
1409*da0073e9SAndroid Build Coastguard Worker        >>> A = torch.tensor([[0, 1], [1, 0]])
1410*da0073e9SAndroid Build Coastguard Worker        >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
1411*da0073e9SAndroid Build Coastguard Worker        >>> C = torch.tensor(7)
1412*da0073e9SAndroid Build Coastguard Worker        >>> D = torch.tensor([1, 2, 3])
1413*da0073e9SAndroid Build Coastguard Worker        >>> E = torch.tensor([[4], [5], [6]])
1414*da0073e9SAndroid Build Coastguard Worker        >>> torch.block_diag(A, B, C, D, E)
1415*da0073e9SAndroid Build Coastguard Worker        tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
1416*da0073e9SAndroid Build Coastguard Worker                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1417*da0073e9SAndroid Build Coastguard Worker                [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
1418*da0073e9SAndroid Build Coastguard Worker                [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
1419*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
1420*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
1421*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
1422*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
1423*da0073e9SAndroid Build Coastguard Worker                [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
1424*da0073e9SAndroid Build Coastguard Worker    """
1425*da0073e9SAndroid Build Coastguard Worker    # This wrapper exists to support variadic args.
1426*da0073e9SAndroid Build Coastguard Worker    if has_torch_function(tensors):
1427*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(block_diag, tensors, *tensors)
1428*da0073e9SAndroid Build Coastguard Worker    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Workerdef cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
1432*da0073e9SAndroid Build Coastguard Worker    # type: (Tensor, Tensor, float, str) -> (Tensor)
1433*da0073e9SAndroid Build Coastguard Worker    r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
1434*da0073e9SAndroid Build Coastguard Worker
1435*da0073e9SAndroid Build Coastguard Worker    Args:
1436*da0073e9SAndroid Build Coastguard Worker        x1 (Tensor): input tensor of shape :math:`B \times P \times M`.
1437*da0073e9SAndroid Build Coastguard Worker        x2 (Tensor): input tensor of shape :math:`B \times R \times M`.
1438*da0073e9SAndroid Build Coastguard Worker        p: p value for the p-norm distance to calculate between each vector pair
1439*da0073e9SAndroid Build Coastguard Worker            :math:`\in [0, \infty]`.
1440*da0073e9SAndroid Build Coastguard Worker        compute_mode:
1441*da0073e9SAndroid Build Coastguard Worker            'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
1442*da0073e9SAndroid Build Coastguard Worker            euclidean distance (p = 2) if P > 25 or R > 25
1443*da0073e9SAndroid Build Coastguard Worker            'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
1444*da0073e9SAndroid Build Coastguard Worker            euclidean distance (p = 2)
1445*da0073e9SAndroid Build Coastguard Worker            'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
1446*da0073e9SAndroid Build Coastguard Worker            euclidean distance (p = 2)
1447*da0073e9SAndroid Build Coastguard Worker            Default: use_mm_for_euclid_dist_if_necessary.
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker    If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
1450*da0073e9SAndroid Build Coastguard Worker    output will have shape :math:`B \times P \times R`.
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker    This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`
1453*da0073e9SAndroid Build Coastguard Worker    if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
1454*da0073e9SAndroid Build Coastguard Worker    `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
1455*da0073e9SAndroid Build Coastguard Worker    scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    Example:
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker        >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
1460*da0073e9SAndroid Build Coastguard Worker        >>> a
1461*da0073e9SAndroid Build Coastguard Worker        tensor([[ 0.9041,  0.0196],
1462*da0073e9SAndroid Build Coastguard Worker                [-0.3108, -2.4423],
1463*da0073e9SAndroid Build Coastguard Worker                [-0.4821,  1.0590]])
1464*da0073e9SAndroid Build Coastguard Worker        >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
1465*da0073e9SAndroid Build Coastguard Worker        >>> b
1466*da0073e9SAndroid Build Coastguard Worker        tensor([[-2.1763, -0.4713],
1467*da0073e9SAndroid Build Coastguard Worker                [-0.6986,  1.3702]])
1468*da0073e9SAndroid Build Coastguard Worker        >>> torch.cdist(a, b, p=2)
1469*da0073e9SAndroid Build Coastguard Worker        tensor([[3.1193, 2.0959],
1470*da0073e9SAndroid Build Coastguard Worker                [2.7138, 3.8322],
1471*da0073e9SAndroid Build Coastguard Worker                [2.2830, 0.3791]])
1472*da0073e9SAndroid Build Coastguard Worker    """
1473*da0073e9SAndroid Build Coastguard Worker    if has_torch_function_variadic(x1, x2):
1474*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(
1475*da0073e9SAndroid Build Coastguard Worker            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
1476*da0073e9SAndroid Build Coastguard Worker        )
1477*da0073e9SAndroid Build Coastguard Worker    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
1478*da0073e9SAndroid Build Coastguard Worker        return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
1479*da0073e9SAndroid Build Coastguard Worker    elif compute_mode == "use_mm_for_euclid_dist":
1480*da0073e9SAndroid Build Coastguard Worker        return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
1481*da0073e9SAndroid Build Coastguard Worker    elif compute_mode == "donot_use_mm_for_euclid_dist":
1482*da0073e9SAndroid Build Coastguard Worker        return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
1483*da0073e9SAndroid Build Coastguard Worker    else:
1484*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Workerdef atleast_1d(*tensors):
1488*da0073e9SAndroid Build Coastguard Worker    r"""
1489*da0073e9SAndroid Build Coastguard Worker    Returns a 1-dimensional view of each input tensor with zero dimensions.
1490*da0073e9SAndroid Build Coastguard Worker    Input tensors with one or more dimensions are returned as-is.
1491*da0073e9SAndroid Build Coastguard Worker
1492*da0073e9SAndroid Build Coastguard Worker    Args:
1493*da0073e9SAndroid Build Coastguard Worker        input (Tensor or list of Tensors)
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker    Returns:
1496*da0073e9SAndroid Build Coastguard Worker        output (Tensor or tuple of Tensors)
1497*da0073e9SAndroid Build Coastguard Worker
1498*da0073e9SAndroid Build Coastguard Worker    Example::
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.arange(2)
1501*da0073e9SAndroid Build Coastguard Worker        >>> x
1502*da0073e9SAndroid Build Coastguard Worker        tensor([0, 1])
1503*da0073e9SAndroid Build Coastguard Worker        >>> torch.atleast_1d(x)
1504*da0073e9SAndroid Build Coastguard Worker        tensor([0, 1])
1505*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.tensor(1.)
1506*da0073e9SAndroid Build Coastguard Worker        >>> x
1507*da0073e9SAndroid Build Coastguard Worker        tensor(1.)
1508*da0073e9SAndroid Build Coastguard Worker        >>> torch.atleast_1d(x)
1509*da0073e9SAndroid Build Coastguard Worker        tensor([1.])
1510*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.tensor(0.5)
1511*da0073e9SAndroid Build Coastguard Worker        >>> y = torch.tensor(1.)
1512*da0073e9SAndroid Build Coastguard Worker        >>> torch.atleast_1d((x, y))
1513*da0073e9SAndroid Build Coastguard Worker        (tensor([0.5000]), tensor([1.]))
1514*da0073e9SAndroid Build Coastguard Worker    """
1515*da0073e9SAndroid Build Coastguard Worker    # This wrapper exists to support variadic args.
1516*da0073e9SAndroid Build Coastguard Worker    if has_torch_function(tensors):
1517*da0073e9SAndroid Build Coastguard Worker        return handle_torch_function(atleast_1d, tensors, *tensors)
1518*da0073e9SAndroid Build Coastguard Worker    if len(tensors) == 1:
1519*da0073e9SAndroid Build Coastguard Worker        tensors = tensors[0]
1520*da0073e9SAndroid Build Coastguard Worker    return _VF.atleast_1d(tensors)  # type: ignore[attr-defined]
1521*da0073e9SAndroid Build Coastguard Worker
1522*da0073e9SAndroid Build Coastguard Worker
def atleast_2d(*tensors):
    r"""
    Returns a view of each input tensor with at least two dimensions.

    A zero-dimensional input becomes shape ``(1, 1)`` and a one-dimensional
    input of shape ``(N,)`` becomes shape ``(1, N)``. Input tensors with two
    or more dimensions are returned as-is.

    Args:
        input (Tensor or list of Tensors): a single tensor, several tensors
            passed as separate positional arguments, or a single list/tuple
            of tensors.

    Returns:
        output (Tensor or tuple of Tensors): a single tensor when a single
        tensor is given, otherwise one result per input tensor.

    Example::

        >>> x = torch.tensor(1.)
        >>> x
        tensor(1.)
        >>> torch.atleast_2d(x)
        tensor([[1.]])
        >>> x = torch.arange(3)
        >>> torch.atleast_2d(x)
        tensor([[0, 1, 2]])
        >>> x = torch.arange(4).view(2, 2)
        >>> x
        tensor([[0, 1],
                [2, 3]])
        >>> torch.atleast_2d(x)
        tensor([[0, 1],
                [2, 3]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.)
        >>> torch.atleast_2d((x, y))
        (tensor([[0.5000]]), tensor([[1.]]))
    """
    # This wrapper exists to support variadic args: __torch_function__
    # overrides among the positional arguments are dispatched first.
    if has_torch_function(tensors):
        return handle_torch_function(atleast_2d, tensors, *tensors)
    # A lone argument (a tensor, or a list/tuple of tensors) is forwarded
    # unwrapped; several positional tensors are forwarded as one tuple.
    arg = tensors[0] if len(tensors) == 1 else tensors
    return _VF.atleast_2d(arg)  # type: ignore[attr-defined]
1559*da0073e9SAndroid Build Coastguard Worker
1560*da0073e9SAndroid Build Coastguard Worker
def atleast_3d(*tensors):
    r"""
    Returns a view of each input tensor with at least three dimensions.

    Inputs with fewer than three dimensions are padded with size-one
    dimensions: a zero-dimensional tensor becomes shape ``(1, 1, 1)``, a
    tensor of shape ``(N,)`` becomes ``(1, N, 1)`` and a tensor of shape
    ``(M, N)`` becomes ``(M, N, 1)``. Input tensors with three or more
    dimensions are returned as-is.

    Args:
        input (Tensor or list of Tensors): a single tensor, several tensors
            passed as separate positional arguments, or a single list/tuple
            of tensors.

    Returns:
        output (Tensor or tuple of Tensors): a single tensor when a single
        tensor is given, otherwise one result per input tensor.

    Example::

        >>> x = torch.tensor(0.5)
        >>> x
        tensor(0.5000)
        >>> torch.atleast_3d(x)
        tensor([[[0.5000]]])
        >>> torch.atleast_3d(torch.arange(3)).shape
        torch.Size([1, 3, 1])
        >>> y = torch.arange(4).view(2, 2)
        >>> y
        tensor([[0, 1],
                [2, 3]])
        >>> torch.atleast_3d(y)
        tensor([[[0],
                 [1]],
                <BLANKLINE>
                [[2],
                 [3]]])
        >>> x = torch.tensor(1).view(1, 1, 1)
        >>> x
        tensor([[[1]]])
        >>> torch.atleast_3d(x)
        tensor([[[1]]])
        >>> x = torch.tensor(0.5)
        >>> y = torch.tensor(1.0)
        >>> torch.atleast_3d((x, y))
        (tensor([[[0.5000]]]), tensor([[[1.]]]))
    """
    # This wrapper exists to support variadic args: __torch_function__
    # overrides among the positional arguments are dispatched first.
    if has_torch_function(tensors):
        return handle_torch_function(atleast_3d, tensors, *tensors)
    # A lone argument (a tensor, or a list/tuple of tensors) is forwarded
    # unwrapped; several positional tensors are forwarded as one tuple.
    arg = tensors[0] if len(tensors) == 1 else tensors
    return _VF.atleast_3d(arg)  # type: ignore[attr-defined]
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker
if TYPE_CHECKING:
    # Static type checkers take this branch, so they never see the overload
    # declarations in the ``else`` branch below.
    pass
    # There's no good way to use this type annotation; cannot rename norm() to
    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
    # for mypy for now.
    #    def norm(input: Tensor,
    #             p: Optional[Union[str, Number]] = "fro",
    #             dim: Optional[Union[int, List[int]]] = None,
    #             keepdim: bool = False,
    #             out: Optional[Tensor] = None,
    #             dtype: _dtype = None) -> Tensor:
    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
else:
    # Runtime-only overload declarations for ``norm``, registered through
    # ``torch._jit_internal._overload`` (imported as ``overload``). The real
    # implementation is the ``norm`` defined right after this block.
    #
    # Each declaration's signature is spelled out in the comment on the line
    # after ``):``. Together the four cover every combination of
    #   p:   ``str``  or ``Optional[number]``
    #   dim: ``Optional[List[int]]`` or ``Optional[int]``
    #
    # NOTE(review): as far as I know ``_overload`` only accepts a body that is
    # exactly ``pass`` (or ``...``), so keep these bodies empty and do not add
    # docstrings to them.
    #
    # TODO: type dim as BroadcastingList when
    # https://github.com/pytorch/pytorch/issues/33782 is fixed

    # p: str, dim: list of ints
    @overload
    def norm(
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # p: number, dim: list of ints
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # p: number, dim: single int
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # p: str, dim: single int
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker
def norm(  # noqa: F811
    input,
    p: Optional[Union[float, str]] = "fro",
    dim=None,
    keepdim=False,
    out=None,
    dtype=None,
):
    r"""Returns the matrix norm or vector norm of a given tensor.

    .. warning::

        torch.norm is deprecated and may be removed in a future PyTorch release.
        Its documentation and behavior may be incorrect, and it is no longer
        actively maintained.

        Use :func:`torch.linalg.vector_norm` when computing vector norms and
        :func:`torch.linalg.matrix_norm` when computing matrix norms.
        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
        Note, however, the signature for these functions is slightly different than the
        signature for ``torch.norm``.

    Args:
        input (Tensor): The input tensor. Its data type must be either a floating
            point or complex type. For complex inputs, the norm is calculated using the
            absolute value of each element. If the input is complex and neither
            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
            be the corresponding floating point type (e.g. float if :attr:`input` is
            complexfloat).

        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
            The following norms can be calculated:

            ======  ==============  ==========================
            p       matrix norm     vector norm
            ======  ==============  ==========================
            'fro'   Frobenius norm  --
            'nuc'   nuclear norm    --
            Number  --              sum(abs(x)**p)**(1./p)
            ======  ==============  ==========================

            The vector norm can be calculated across any number of dimensions.
            The corresponding dimensions of :attr:`input` are flattened into
            one dimension, and the norm is calculated on the flattened
            dimension.

            Frobenius norm produces the same result as ``p=2`` in all cases
            except when :attr:`dim` is a list of three or more dims, in which
            case Frobenius norm throws an error.

            Nuclear norm can only be calculated across exactly two dimensions.

        dim (int, tuple of ints, list of ints, optional):
            Specifies which dimension or dimensions of :attr:`input` to
            calculate the norm across. If :attr:`dim` is ``None``, the norm will
            be calculated across all dimensions of :attr:`input`. If the norm
            type indicated by :attr:`p` does not support the specified number of
            dimensions, an error will occur.
        keepdim (bool, optional): whether the output tensor has :attr:`dim`
            retained or not. The flag is forwarded to the underlying norm
            computation even when :attr:`dim` is ``None``. Default: ``False``
        out (Tensor, optional): the output tensor. Default: ``None``
        dtype (:class:`torch.dtype`, optional): the desired data type of
            returned tensor. If specified, the input tensor is cast to
            :attr:`dtype` while performing the operation. On the fallback
            path (inputs that are not strided, or whose device type is not
            cpu, cuda, meta or the PrivateUse1 backend, e.g. MPS or sparse
            tensors), ``dtype`` is not supported for ``p='fro'`` or
            ``p='nuc'`` and raises a ``ValueError``. Default: None.

    .. note::
        Even though ``p='fro'`` supports any number of dimensions, the true
        mathematical definition of Frobenius norm only applies to tensors with
        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
        aligns with the mathematical definition, since it can only be applied across
        exactly two dimensions.

    Example::

        >>> import torch
        >>> a = torch.arange(9, dtype=torch.float) - 4
        >>> b = a.reshape((3, 3))
        >>> torch.norm(a)
        tensor(7.7460)
        >>> torch.norm(b)
        tensor(7.7460)
        >>> torch.norm(a, float('inf'))
        tensor(4.)
        >>> torch.norm(b, float('inf'))
        tensor(4.)
        >>> c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
        >>> torch.norm(c, dim=0)
        tensor([1.4142, 2.2361, 5.0000])
        >>> torch.norm(c, dim=1)
        tensor([3.7417, 4.2426])
        >>> torch.norm(c, p=1, dim=1)
        tensor([6., 6.])
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
        >>> torch.norm(d, dim=(1, 2))
        tensor([ 3.7417, 11.2250])
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        (tensor(3.7417), tensor(11.2250))
    """

    if has_torch_function_unary(input):
        return handle_torch_function(
            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
        )

    # NB. All the repeated code and weird python is to please TorchScript.
    #     For a more compact implementation see the relevant function in `_refs/__init__.py`

    # Fast path: strided tensors on cpu / cuda / meta / the registered
    # PrivateUse1 backend are routed to torch.linalg.{vector,matrix}_norm.
    # We don't do this for MPS or sparse tensors; those (and any other device
    # type or layout) fall through to the legacy _VF kernels further below.
    if input.layout == torch.strided and input.device.type in (
        "cpu",
        "cuda",
        "meta",
        torch.utils.backend_registration._privateuse1_backend_name,
    ):
        # Normalize dim: a single int becomes a one-element list; None stays
        # None, meaning "all dimensions".
        if dim is not None:
            if isinstance(dim, (int, torch.SymInt)):
                _dim = [dim]
            else:
                _dim = dim
        else:
            _dim = None  # type: ignore[assignment]

        if isinstance(p, str):
            # 'fro' with dim=None, a single int dim, or at most two dims is the
            # 2-norm of the selected entries, so compute it with vector_norm.
            if p == "fro" and (
                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
            ):
                if out is None:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype
                    )
                else:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype, out=out
                    )

            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
            # that will throw an error
            # (for example 'fro' over three or more dims, or an unsupported string).
            if _dim is None:
                _dim = list(range(input.ndim))
            if out is None:
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.matrix_norm(
                    input, p, _dim, keepdim, dtype=dtype, out=out
                )
        else:
            # NB. p should be Union[str, number], not Optional!
            # p=None is treated as the 2-norm here.
            _p = 2.0 if p is None else p
            if out is None:
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.vector_norm(
                    input, _p, _dim, keepdim, dtype=dtype, out=out
                )

    # Legacy path (MPS, non-strided layouts, other device types): dispatch to
    # the _VF kernels.
    ndim = input.dim()

    # catch default case
    # (no dim, out or dtype given): 'fro' reduces with dim=(), a numeric p
    # reduces over every dimension, and 'nuc' falls through to the code below.
    if dim is None and out is None and dtype is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
        if not isinstance(p, str):
            _dim = list(range(ndim))
            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]

    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
    # remove the overloads where dim is an int and replace with BroadcastingList1
    # and remove next four lines, replace _dim with dim
    if dim is not None:
        if isinstance(dim, (int, torch.SymInt)):
            _dim = [dim]
        else:
            _dim = dim
    else:
        _dim = None  # type: ignore[assignment]

    if isinstance(p, str):
        if p == "fro":
            # dtype is not supported by the legacy Frobenius kernel call.
            if dtype is not None:
                raise ValueError("dtype argument is not supported in frobenius norm")

            if _dim is None:
                _dim = list(range(ndim))
            if out is None:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
            else:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        elif p == "nuc":
            # dtype is not supported by the legacy nuclear-norm kernel call.
            if dtype is not None:
                raise ValueError("dtype argument is not supported in nuclear norm")
            # Without dim, nuclear_norm is called with no dim argument at all.
            if _dim is None:
                if out is None:
                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
            else:
                if out is None:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        # Any other string value is rejected.
        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
    else:
        # Numeric (or None) p: vector norm over _dim, defaulting to all dims.
        # dtype and out are forwarded only when given, hence the four branches.
        if _dim is None:
            _dim = list(range(ndim))

        if out is None:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
        else:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
1889*da0073e9SAndroid Build Coastguard Worker
1890*da0073e9SAndroid Build Coastguard Worker
def unravel_index(
    indices: Tensor,
    shape: Union[int, Sequence[int], torch.Size],
) -> Tuple[Tensor, ...]:
    r"""Converts flat indices into per-dimension coordinates for a tensor of
    shape :attr:`shape`.

    Flat indices are interpreted in row-major (C-contiguous) order, i.e. the
    last dimension of :attr:`shape` varies fastest.

    Args:
        indices (Tensor): Integer tensor of indices into the flattened form of
            a tensor whose shape is :attr:`shape`. Every element must lie in
            ``[0, prod(shape) - 1]``.

        shape (int, sequence of ints, or torch.Size): Shape of the tensor being
            indexed. All sizes must be non-negative.

    Returns:
        tuple of Tensors: One tensor per dimension of :attr:`shape`. The
        ``i``-th tensor has the same shape as ``indices`` and holds, for every
        flat index in ``indices``, its coordinate along dimension ``i``.

    Example::

        >>> import torch
        >>> torch.unravel_index(torch.tensor(4), (3, 2))
        (tensor(2),
         tensor(0))

        >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
        (tensor([2, 0]),
         tensor([0, 1]))

        >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
        (tensor([0, 0, 1, 1, 2, 2]),
         tensor([0, 1, 0, 1, 0, 1]))

        >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
        (tensor([1, 5]),
         tensor([2, 6]),
         tensor([3, 7]),
         tensor([4, 8]))

        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
        (tensor([[1], [5]]),
         tensor([[2], [6]]),
         tensor([[3], [7]]),
         tensor([[4], [8]]))

        >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
        (tensor([[12], [56]]),
         tensor([[34], [78]]))
    """
    if not has_torch_function_unary(indices):
        # The helper stacks the coordinates along a new trailing dimension;
        # split that dimension into one tensor per entry of ``shape``.
        return _unravel_index(indices, shape).unbind(-1)
    # ``indices`` carries a ``__torch_function__`` override: defer to it.
    return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
1947*da0073e9SAndroid Build Coastguard Worker
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Workerdef _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
1950*da0073e9SAndroid Build Coastguard Worker    torch._check_type(
1951*da0073e9SAndroid Build Coastguard Worker        not indices.is_complex()
1952*da0073e9SAndroid Build Coastguard Worker        and not indices.is_floating_point()
1953*da0073e9SAndroid Build Coastguard Worker        and not indices.dtype == torch.bool,
1954*da0073e9SAndroid Build Coastguard Worker        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
1955*da0073e9SAndroid Build Coastguard Worker    )
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker    torch._check_type(
1958*da0073e9SAndroid Build Coastguard Worker        isinstance(shape, (int, torch.SymInt, Sequence)),
1959*da0073e9SAndroid Build Coastguard Worker        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
1960*da0073e9SAndroid Build Coastguard Worker    )
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker    if isinstance(shape, (int, torch.SymInt)):
1963*da0073e9SAndroid Build Coastguard Worker        shape = torch.Size([shape])
1964*da0073e9SAndroid Build Coastguard Worker    else:
1965*da0073e9SAndroid Build Coastguard Worker        for dim in shape:
1966*da0073e9SAndroid Build Coastguard Worker            torch._check_type(
1967*da0073e9SAndroid Build Coastguard Worker                isinstance(dim, (int, torch.SymInt)),
1968*da0073e9SAndroid Build Coastguard Worker                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
1969*da0073e9SAndroid Build Coastguard Worker            )
1970*da0073e9SAndroid Build Coastguard Worker        shape = torch.Size(shape)
1971*da0073e9SAndroid Build Coastguard Worker
1972*da0073e9SAndroid Build Coastguard Worker    torch._check_value(
1973*da0073e9SAndroid Build Coastguard Worker        all(dim >= 0 for dim in shape),
1974*da0073e9SAndroid Build Coastguard Worker        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
1975*da0073e9SAndroid Build Coastguard Worker    )
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker    coefs = list(
1978*da0073e9SAndroid Build Coastguard Worker        reversed(
1979*da0073e9SAndroid Build Coastguard Worker            list(
1980*da0073e9SAndroid Build Coastguard Worker                itertools.accumulate(
1981*da0073e9SAndroid Build Coastguard Worker                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
1982*da0073e9SAndroid Build Coastguard Worker                )
1983*da0073e9SAndroid Build Coastguard Worker            )
1984*da0073e9SAndroid Build Coastguard Worker        )
1985*da0073e9SAndroid Build Coastguard Worker    )
1986*da0073e9SAndroid Build Coastguard Worker    return indices.unsqueeze(-1).floor_divide(
1987*da0073e9SAndroid Build Coastguard Worker        torch.tensor(coefs, device=indices.device, dtype=torch.int64)
1988*da0073e9SAndroid Build Coastguard Worker    ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Worker
def chain_matmul(*matrices, out=None):
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm, which selects the multiplication order that incurs the lowest
    cost in terms of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the
    product, :math:`N` needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix
    product is returned. If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

    .. warning::

        :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
        Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
        rather than multiple arguments.

    Args:
        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
        out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.

    Returns:
        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
        would be of dimensions :math:`p_{1} \times p_{N + 1}`.

    Example::

        >>> # xdoctest: +SKIP
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> a = torch.randn(3, 4)
        >>> b = torch.randn(4, 5)
        >>> c = torch.randn(5, 6)
        >>> d = torch.randn(6, 7)
        >>> # will raise a deprecation warning
        >>> torch.chain_matmul(a, b, c, d)
        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])

    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
    """
    # This wrapper exists to support variadic args.
    if has_torch_function(matrices):
        # Forward ``out`` as well: previously it was silently dropped, so an
        # override never learned that the caller requested an output tensor.
        return handle_torch_function(chain_matmul, matrices, *matrices, out=out)

    if out is None:
        return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
    else:
        return _VF.chain_matmul(matrices, out=out)  # type: ignore[attr-defined]
2036*da0073e9SAndroid Build Coastguard Worker
2037*da0073e9SAndroid Build Coastguard Worker
def _lu_impl(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor]
    r"""Computes the LU factorization of a matrix or batches of matrices
    :attr:`A`. Returns a tuple containing the LU factorization and
    pivots of :attr:`A`.  Pivoting is done if :attr:`pivot` is set to
    ``True``.

    .. warning::

        :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
        and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
        future PyTorch release.
        ``LU, pivots = torch.lu(A, compute_pivots)`` should be replaced with

        .. code:: python

            LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

        ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with

        .. code:: python

            LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

    .. note::
        * The returned permutation matrix for every matrix in the batch is
          represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
          ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
          the ``i``-th row was permuted with the ``j-1``-th row.
        * LU factorization with :attr:`pivot` = ``False`` is not available
          for CPU, and attempting to do so will throw an error. However,
          LU factorization with :attr:`pivot` = ``False`` is available for
          CUDA.
        * This function does not check if the factorization was successful
          or not if :attr:`get_infos` is ``True`` since the status of the
          factorization is present in the third element of the return tuple.
        * In the case of batches of square matrices with size less than or equal
          to 32 on a CUDA device, the LU factorization is repeated for
          singular matrices due to the bug in the MAGMA library
          (see magma issue 13).
        * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.

    .. warning::
        The gradients of this function will only be finite when :attr:`A` is full rank.
        This is because the LU decomposition is only differentiable at full rank matrices.
        Furthermore, if :attr:`A` is close to not being full rank,
        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.

    Args:
        A (Tensor): the tensor to factor of size :math:`(*, m, n)`
        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
                                    Default: ``False``
        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
                               then the elements in the tuple are Tensor, IntTensor,
                               and IntTensor. If :attr:`get_infos` is ``False``, then the
                               elements in the tuple are Tensor, IntTensor. Default: ``None``

    Returns:
        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing

            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`

            - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
              ``pivots`` stores all the intermediate transpositions of rows.
              The final permutation ``perm`` could be reconstructed by
              applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
              where ``perm`` is initially the identity permutation of :math:`m` elements
              (essentially this is what :func:`torch.lu_unpack` is doing).

            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
              size :math:`(*)` in which a zero entry means the factorization of the corresponding
              matrix succeeded and a non-zero entry means it failed.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = torch.lu(A)
        >>> A_LU
        tensor([[[ 1.3506,  2.5558, -0.0816],
                 [ 0.1684,  1.1551,  0.1940],
                 [ 0.1193,  0.6189, -0.5497]],

                [[ 0.4526,  1.2526, -0.3285],
                 [-0.7988,  0.7175, -0.9701],
                 [ 0.2634, -0.9255, -0.3459]]])
        >>> pivots
        tensor([[ 3,  3,  3],
                [ 3,  3,  3]], dtype=torch.int32)
        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
        >>> if info.nonzero().size(0) == 0:
        ...     print('LU factorization succeeded for all samples!')
        LU factorization succeeded for all samples!
    """
    # ``out`` is not used here; the ``lu`` wrappers copy the result into it.
    # If get_infos is True, then we don't need to check for errors and vice versa:
    # the caller inspects ``infos`` itself.
    return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker
# Annotation type for the ``out`` argument of ``_check_list_size``.
# Static type checkers see ``Sequence[Tensor]``, but at runtime the alias is
# ``List[Tensor]`` — presumably because TorchScript (which resolves ``lu``
# through ``boolean_dispatch``) understands ``List`` but not ``Sequence``.
# NOTE(review): confirm that rationale before changing either branch.
if TYPE_CHECKING:
    _ListOrSeq = Sequence[Tensor]
else:
    _ListOrSeq = List[Tensor]
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
    """Validate the ``out`` argument passed to ``torch.lu``.

    ``out`` must be a tuple or list holding 3 tensors when ``get_infos`` is
    true (LU, pivots, infos) and 2 otherwise (LU, pivots).

    Args:
        out_len: ``len(out)``, computed by the caller.
        get_infos: whether the caller requested the ``infos`` tensor.
        out: the user-supplied output container.

    Raises:
        TypeError: if ``out`` is not a tuple/list, or has the wrong length.
    """
    # Check the container type first: passing e.g. a single Tensor should
    # report that, not a misleading length mismatch based on its first dim.
    if not isinstance(out, (tuple, list)):
        raise TypeError(
            f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
        )
    expected_len = 3 if get_infos else 2
    if out_len != expected_len:
        raise TypeError(
            f"expected tuple of {expected_len} elements but got {out_len}"
        )
2154*da0073e9SAndroid Build Coastguard Worker
2155*da0073e9SAndroid Build Coastguard Worker
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
    # Implementation of ``lu`` selected by ``boolean_dispatch`` when
    # ``get_infos`` is True: returns ``(A_LU, pivots, infos)``, or ``out``
    # after filling it when ``out`` is given.
    # NOTE(review): intentionally left without a docstring, presumably because
    # ``boolean_dispatch`` allows only one of its two implementations to carry
    # one; the public docs are copied from ``_lu_impl`` onto ``lu``.
    if has_torch_function_unary(A):
        # ``A`` has a ``__torch_function__`` override: hand the call to it,
        # passing the public ``lu`` as the API being overridden.
        return handle_torch_function(
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
        )
    # NOTE(review): the factorization runs before ``out`` is validated, so an
    # ill-formed ``out`` is only rejected after the work has been done.
    result = _lu_impl(A, pivot, get_infos, out)
    if out is not None:
        # Ensure ``out`` is a tuple/list of the expected length, then copy each
        # result into the matching user tensor, resizing it to fit first.
        _check_list_size(len(out), get_infos, out)
        for i in range(len(out)):
            out[i].resize_as_(result[i]).copy_(result[i])
        return out
    else:
        return result  # A_LU, pivots, infos
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker
def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
    # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
    # Implementation of ``lu`` selected by ``boolean_dispatch`` when
    # ``get_infos`` is False (the default): returns ``(A_LU, pivots)`` without
    # ``infos``, or ``out`` after filling it when ``out`` is given.
    # NOTE(review): intentionally left without a docstring, presumably because
    # ``boolean_dispatch`` allows only one of its two implementations to carry
    # one; the public docs are copied from ``_lu_impl`` onto ``lu``.
    # need to check for torch_function here so that we exit early, handing the
    # call to the override, if ``A`` defines ``__torch_function__``
    if has_torch_function_unary(A):
        return handle_torch_function(
            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
        )
    result = _lu_impl(A, pivot, get_infos, out)
    if out is not None:
        # Ensure ``out`` is a tuple/list of the expected length, then copy each
        # result into the matching user tensor, resizing it to fit first.
        _check_list_size(len(out), get_infos, out)
        for i in range(len(out)):
            out[i].resize_as_(result[i]).copy_(result[i])
        return out
    else:
        # ``_lu_impl`` always yields three tensors; drop ``infos`` here.
        return result[0], result[1]  # A_LU, pivots
2187*da0073e9SAndroid Build Coastguard Worker
2188*da0073e9SAndroid Build Coastguard Worker
# ``lu`` returns either two or three tensors depending on ``get_infos``.
# TorchScript needs a single static return type per function, so the public
# entry point is a boolean dispatcher that picks an implementation from the
# statically known value of ``get_infos`` (keyword, or positional index 2).
lu = boolean_dispatch(
    func_name="lu",
    module_name=__name__,
    default=False,
    arg_index=2,
    arg_name="get_infos",
    if_false=_lu_no_infos,
    if_true=_lu_with_infos,
)
# Expose the full user-facing documentation on the dispatcher itself.
lu.__doc__ = _lu_impl.__doc__
2201*da0073e9SAndroid Build Coastguard Worker
2202*da0073e9SAndroid Build Coastguard Worker
def align_tensors(*tensors):
    """Not implemented; calling this always raises.

    Raises:
        NotImplementedError: always. It subclasses :class:`RuntimeError`, so
            callers that catch ``RuntimeError`` keep working.
    """
    raise NotImplementedError("`align_tensors` not yet implemented.")
2205