1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport itertools 3*da0073e9SAndroid Build Coastguard Workerimport operator 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 8*da0073e9SAndroid Build Coastguard Workerfrom torch import _VF, Tensor 9*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _add_docstr 10*da0073e9SAndroid Build Coastguard Workerfrom torch._jit_internal import _overload as overload, boolean_dispatch 11*da0073e9SAndroid Build Coastguard Workerfrom torch._lowrank import pca_lowrank, svd_lowrank 12*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import ( 13*da0073e9SAndroid Build Coastguard Worker handle_torch_function, 14*da0073e9SAndroid Build Coastguard Worker has_torch_function, 15*da0073e9SAndroid Build Coastguard Worker has_torch_function_unary, 16*da0073e9SAndroid Build Coastguard Worker has_torch_function_variadic, 17*da0073e9SAndroid Build Coastguard Worker) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker__all__ = [ 21*da0073e9SAndroid Build Coastguard Worker "atleast_1d", 22*da0073e9SAndroid Build Coastguard Worker "atleast_2d", 23*da0073e9SAndroid Build Coastguard Worker "atleast_3d", 24*da0073e9SAndroid Build Coastguard Worker "align_tensors", 25*da0073e9SAndroid Build Coastguard Worker "broadcast_shapes", 26*da0073e9SAndroid Build Coastguard Worker "broadcast_tensors", 27*da0073e9SAndroid Build Coastguard Worker "cartesian_prod", 28*da0073e9SAndroid Build Coastguard Worker "block_diag", 29*da0073e9SAndroid Build Coastguard Worker "cdist", 30*da0073e9SAndroid Build Coastguard Worker "chain_matmul", 31*da0073e9SAndroid Build Coastguard Worker "einsum", 32*da0073e9SAndroid Build Coastguard Worker "istft", 33*da0073e9SAndroid Build Coastguard Worker "lu", 34*da0073e9SAndroid Build Coastguard Worker "norm", 35*da0073e9SAndroid Build Coastguard Worker "meshgrid", 36*da0073e9SAndroid Build Coastguard Worker "pca_lowrank", 37*da0073e9SAndroid Build Coastguard Worker "split", 38*da0073e9SAndroid Build Coastguard Worker "stft", 39*da0073e9SAndroid Build Coastguard Worker "svd_lowrank", 40*da0073e9SAndroid Build Coastguard Worker "tensordot", 41*da0073e9SAndroid Build Coastguard Worker "unique", 42*da0073e9SAndroid Build Coastguard Worker "unique_consecutive", 43*da0073e9SAndroid Build Coastguard Worker "unravel_index", 44*da0073e9SAndroid Build Coastguard Worker] 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerdef broadcast_tensors(*tensors): 48*da0073e9SAndroid Build Coastguard Worker r"""broadcast_tensors(*tensors) -> List of Tensors 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker Broadcasts the given tensors according to :ref:`broadcasting-semantics`. 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker Args: 53*da0073e9SAndroid Build Coastguard Worker *tensors: any number of tensors of the same type 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker .. warning:: 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker More than one element of a broadcasted tensor may refer to a single 58*da0073e9SAndroid Build Coastguard Worker memory location. As a result, in-place operations (especially ones that 59*da0073e9SAndroid Build Coastguard Worker are vectorized) may result in incorrect behavior. If you need to write 60*da0073e9SAndroid Build Coastguard Worker to the tensors, please clone them first. 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker Example:: 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker >>> x = torch.arange(3).view(1, 3) 65*da0073e9SAndroid Build Coastguard Worker >>> y = torch.arange(2).view(2, 1) 66*da0073e9SAndroid Build Coastguard Worker >>> a, b = torch.broadcast_tensors(x, y) 67*da0073e9SAndroid Build Coastguard Worker >>> a.size() 68*da0073e9SAndroid Build Coastguard Worker torch.Size([2, 3]) 69*da0073e9SAndroid Build Coastguard Worker >>> a 70*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1, 2], 71*da0073e9SAndroid Build Coastguard Worker [0, 1, 2]]) 72*da0073e9SAndroid Build Coastguard Worker """ 73*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 74*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 75*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(broadcast_tensors, tensors, *tensors) 76*da0073e9SAndroid Build Coastguard Worker return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined] 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Workerdef broadcast_shapes(*shapes): 80*da0073e9SAndroid Build Coastguard Worker r"""broadcast_shapes(*shapes) -> Size 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker Similar to :func:`broadcast_tensors` but for shapes. 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker This is equivalent to 85*da0073e9SAndroid Build Coastguard Worker ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape`` 86*da0073e9SAndroid Build Coastguard Worker but avoids the need create to intermediate tensors. This is useful for 87*da0073e9SAndroid Build Coastguard Worker broadcasting tensors of common batch shape but different rightmost shape, 88*da0073e9SAndroid Build Coastguard Worker e.g. to broadcast mean vectors with covariance matrices. 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker Example:: 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1)) 93*da0073e9SAndroid Build Coastguard Worker torch.Size([1, 3, 2]) 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker Args: 96*da0073e9SAndroid Build Coastguard Worker \*shapes (torch.Size): Shapes of tensors. 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker Returns: 99*da0073e9SAndroid Build Coastguard Worker shape (torch.Size): A shape compatible with all input shapes. 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker Raises: 102*da0073e9SAndroid Build Coastguard Worker RuntimeError: If shapes are incompatible. 103*da0073e9SAndroid Build Coastguard Worker """ 104*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 105*da0073e9SAndroid Build Coastguard Worker # TODO Move this to C++ once the jit has better support for torch.Size. 106*da0073e9SAndroid Build Coastguard Worker if not torch.jit.is_tracing(): 107*da0073e9SAndroid Build Coastguard Worker max_len = 0 108*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 109*da0073e9SAndroid Build Coastguard Worker if isinstance(shape, (int, torch.SymInt)): 110*da0073e9SAndroid Build Coastguard Worker if max_len < 1: 111*da0073e9SAndroid Build Coastguard Worker max_len = 1 112*da0073e9SAndroid Build Coastguard Worker elif isinstance(shape, (tuple, list)): 113*da0073e9SAndroid Build Coastguard Worker s = len(shape) 114*da0073e9SAndroid Build Coastguard Worker if max_len < s: 115*da0073e9SAndroid Build Coastguard Worker max_len = s 116*da0073e9SAndroid Build Coastguard Worker result = [1] * max_len 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.symbolic_shapes import guard_size_oblivious 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 121*da0073e9SAndroid Build Coastguard Worker if isinstance(shape, (int, torch.SymInt)): 122*da0073e9SAndroid Build Coastguard Worker shape = (shape,) 123*da0073e9SAndroid Build Coastguard Worker if isinstance(shape, (tuple, list)): 124*da0073e9SAndroid Build Coastguard Worker for i in range(-1, -1 - len(shape), -1): 125*da0073e9SAndroid Build Coastguard Worker if shape[i] < 0: 126*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 127*da0073e9SAndroid Build Coastguard Worker f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})" 128*da0073e9SAndroid Build Coastguard Worker ) 129*da0073e9SAndroid Build Coastguard Worker # NB: result is initialized to 1 so this is effectively an 130*da0073e9SAndroid Build Coastguard Worker # equals one test 131*da0073e9SAndroid Build Coastguard Worker if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious( 132*da0073e9SAndroid Build Coastguard Worker shape[i] == result[i] 133*da0073e9SAndroid Build Coastguard Worker ): 134*da0073e9SAndroid Build Coastguard Worker continue 135*da0073e9SAndroid Build Coastguard Worker if result[i] != 1: 136*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 137*da0073e9SAndroid Build Coastguard Worker "Shape mismatch: objects cannot be broadcast to a single shape" 138*da0073e9SAndroid Build Coastguard Worker ) 139*da0073e9SAndroid Build Coastguard Worker result[i] = shape[i] 140*da0073e9SAndroid Build Coastguard Worker else: 141*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 142*da0073e9SAndroid Build Coastguard Worker "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", 143*da0073e9SAndroid Build Coastguard Worker shape, 144*da0073e9SAndroid Build Coastguard Worker ) 145*da0073e9SAndroid Build Coastguard Worker return torch.Size(result) 146*da0073e9SAndroid Build Coastguard Worker else: 147*da0073e9SAndroid Build Coastguard Worker # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail 148*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 149*da0073e9SAndroid Build Coastguard Worker scalar = torch.zeros((), device="cpu") 150*da0073e9SAndroid Build Coastguard Worker tensors = [scalar.expand(shape) for shape in shapes] 151*da0073e9SAndroid Build Coastguard Worker tensors = broadcast_tensors(*tensors) 152*da0073e9SAndroid Build Coastguard Worker return tensors[0].shape 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Workerdef split( 156*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 157*da0073e9SAndroid Build Coastguard Worker split_size_or_sections: Union[int, List[int]], 158*da0073e9SAndroid Build Coastguard Worker dim: int = 0, 159*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, ...]: 160*da0073e9SAndroid Build Coastguard Worker r"""Splits the tensor into chunks. Each chunk is a view of the original tensor. 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will 163*da0073e9SAndroid Build Coastguard Worker be split into equally sized chunks (if possible). Last chunk will be smaller if 164*da0073e9SAndroid Build Coastguard Worker the tensor size along the given dimension :attr:`dim` is not divisible by 165*da0073e9SAndroid Build Coastguard Worker :attr:`split_size`. 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split 168*da0073e9SAndroid Build Coastguard Worker into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according 169*da0073e9SAndroid Build Coastguard Worker to :attr:`split_size_or_sections`. 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker Args: 172*da0073e9SAndroid Build Coastguard Worker tensor (Tensor): tensor to split. 173*da0073e9SAndroid Build Coastguard Worker split_size_or_sections (int) or (list(int)): size of a single chunk or 174*da0073e9SAndroid Build Coastguard Worker list of sizes for each chunk 175*da0073e9SAndroid Build Coastguard Worker dim (int): dimension along which to split the tensor. 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker Example:: 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker >>> a = torch.arange(10).reshape(5, 2) 180*da0073e9SAndroid Build Coastguard Worker >>> a 181*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1], 182*da0073e9SAndroid Build Coastguard Worker [2, 3], 183*da0073e9SAndroid Build Coastguard Worker [4, 5], 184*da0073e9SAndroid Build Coastguard Worker [6, 7], 185*da0073e9SAndroid Build Coastguard Worker [8, 9]]) 186*da0073e9SAndroid Build Coastguard Worker >>> torch.split(a, 2) 187*da0073e9SAndroid Build Coastguard Worker (tensor([[0, 1], 188*da0073e9SAndroid Build Coastguard Worker [2, 3]]), 189*da0073e9SAndroid Build Coastguard Worker tensor([[4, 5], 190*da0073e9SAndroid Build Coastguard Worker [6, 7]]), 191*da0073e9SAndroid Build Coastguard Worker tensor([[8, 9]])) 192*da0073e9SAndroid Build Coastguard Worker >>> torch.split(a, [1, 4]) 193*da0073e9SAndroid Build Coastguard Worker (tensor([[0, 1]]), 194*da0073e9SAndroid Build Coastguard Worker tensor([[2, 3], 195*da0073e9SAndroid Build Coastguard Worker [4, 5], 196*da0073e9SAndroid Build Coastguard Worker [6, 7], 197*da0073e9SAndroid Build Coastguard Worker [8, 9]])) 198*da0073e9SAndroid Build Coastguard Worker """ 199*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(tensor): 200*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 201*da0073e9SAndroid Build Coastguard Worker split, (tensor,), tensor, split_size_or_sections, dim=dim 202*da0073e9SAndroid Build Coastguard Worker ) 203*da0073e9SAndroid Build Coastguard Worker # Overwriting reason: 204*da0073e9SAndroid Build Coastguard Worker # This dispatches to two ATen functions depending on the type of 205*da0073e9SAndroid Build Coastguard Worker # split_size_or_sections. The branching code is in _tensor.py, which we 206*da0073e9SAndroid Build Coastguard Worker # call here. 207*da0073e9SAndroid Build Coastguard Worker return tensor.split(split_size_or_sections, dim) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Workerdef einsum(*args: Any) -> Tensor: 211*da0073e9SAndroid Build Coastguard Worker r"""einsum(equation, *operands) -> Tensor 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation 214*da0073e9SAndroid Build Coastguard Worker based on the Einstein summation convention. 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them 217*da0073e9SAndroid Build Coastguard Worker in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of 218*da0073e9SAndroid Build Coastguard Worker this format are described below, but the general idea is to label every dimension of the input :attr:`operands` 219*da0073e9SAndroid Build Coastguard Worker with some subscript and define which subscripts are part of the output. The output is then computed by summing 220*da0073e9SAndroid Build Coastguard Worker the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the 221*da0073e9SAndroid Build Coastguard Worker output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`. 222*da0073e9SAndroid Build Coastguard Worker Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why). 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker Equation: 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of 227*da0073e9SAndroid Build Coastguard Worker the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a 228*da0073e9SAndroid Build Coastguard Worker comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript 229*da0073e9SAndroid Build Coastguard Worker must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is 230*da0073e9SAndroid Build Coastguard Worker repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand 231*da0073e9SAndroid Build Coastguard Worker must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that 232*da0073e9SAndroid Build Coastguard Worker appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order. 233*da0073e9SAndroid Build Coastguard Worker The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based 234*da0073e9SAndroid Build Coastguard Worker on the subscripts, and then summing out the dimensions whose subscripts are not part of the output. 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation 237*da0073e9SAndroid Build Coastguard Worker followed by the subscripts for the output. For instance, the following equation computes the transpose of a 238*da0073e9SAndroid Build Coastguard Worker matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and 239*da0073e9SAndroid Build Coastguard Worker at most once for the output. 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. 242*da0073e9SAndroid Build Coastguard Worker Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts, 243*da0073e9SAndroid Build Coastguard Worker e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth 244*da0073e9SAndroid Build Coastguard Worker dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the 245*da0073e9SAndroid Build Coastguard Worker 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not 246*da0073e9SAndroid Build Coastguard Worker explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), 247*da0073e9SAndroid Build Coastguard Worker before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements 248*da0073e9SAndroid Build Coastguard Worker batch matrix multiplication `'...ij,...jk'`. 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis, 251*da0073e9SAndroid Build Coastguard Worker arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands. 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker .. note:: 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions 256*da0073e9SAndroid Build Coastguard Worker covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output. 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker .. note:: 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to 261*da0073e9SAndroid Build Coastguard Worker consume less memory by optimizing contraction order. This optimization occurs when there are at least three 262*da0073e9SAndroid Build Coastguard Worker inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem, 263*da0073e9SAndroid Build Coastguard Worker thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available, 264*da0073e9SAndroid Build Coastguard Worker the default order is to contract from left to right. 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path 267*da0073e9SAndroid Build Coastguard Worker calculation: `torch.backends.opt_einsum.enabled = False` 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line: 270*da0073e9SAndroid Build Coastguard Worker `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and 271*da0073e9SAndroid Build Coastguard Worker 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in 272*da0073e9SAndroid Build Coastguard Worker the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker .. note:: 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format, 277*da0073e9SAndroid Build Coastguard Worker subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists 278*da0073e9SAndroid Build Coastguard Worker follow their operands, and an extra sublist can appear at the end of the input to specify the output's 279*da0073e9SAndroid Build Coastguard Worker subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object 280*da0073e9SAndroid Build Coastguard Worker may be provided in a sublist to enable broadcasting as described in the Equation section above. 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker Args: 283*da0073e9SAndroid Build Coastguard Worker equation (str): The subscripts for the Einstein summation. 284*da0073e9SAndroid Build Coastguard Worker operands (List[Tensor]): The tensors to compute the Einstein summation of. 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker Examples:: 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 289*da0073e9SAndroid Build Coastguard Worker >>> # trace 290*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('ii', torch.randn(4, 4)) 291*da0073e9SAndroid Build Coastguard Worker tensor(-1.2104) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 294*da0073e9SAndroid Build Coastguard Worker >>> # diagonal 295*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('ii->i', torch.randn(4, 4)) 296*da0073e9SAndroid Build Coastguard Worker tensor([-0.1034, 0.7952, -0.2433, 0.4545]) 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 299*da0073e9SAndroid Build Coastguard Worker >>> # outer product 300*da0073e9SAndroid Build Coastguard Worker >>> x = torch.randn(5) 301*da0073e9SAndroid Build Coastguard Worker >>> y = torch.randn(4) 302*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('i,j->ij', x, y) 303*da0073e9SAndroid Build Coastguard Worker tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], 304*da0073e9SAndroid Build Coastguard Worker [-0.3744, 0.9381, 1.2685, -1.6070], 305*da0073e9SAndroid Build Coastguard Worker [ 0.7208, -1.8058, -2.4419, 3.0936], 306*da0073e9SAndroid Build Coastguard Worker [ 0.1713, -0.4291, -0.5802, 0.7350], 307*da0073e9SAndroid Build Coastguard Worker [ 0.5704, -1.4290, -1.9323, 2.4480]]) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 310*da0073e9SAndroid Build Coastguard Worker >>> # batch matrix multiplication 311*da0073e9SAndroid Build Coastguard Worker >>> As = torch.randn(3, 2, 5) 312*da0073e9SAndroid Build Coastguard Worker >>> Bs = torch.randn(3, 5, 4) 313*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('bij,bjk->bik', As, Bs) 314*da0073e9SAndroid Build Coastguard Worker tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], 315*da0073e9SAndroid Build Coastguard Worker [-1.6706, -0.8097, -0.8025, -2.1183]], 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker [[ 4.2239, 0.3107, -0.5756, -0.2354], 318*da0073e9SAndroid Build Coastguard Worker [-1.4558, -0.3460, 1.5087, -0.8530]], 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker [[ 2.8153, 1.8787, -4.3839, -1.2112], 321*da0073e9SAndroid Build Coastguard Worker [ 0.3728, -2.1131, 0.0921, 0.8305]]]) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 324*da0073e9SAndroid Build Coastguard Worker >>> # with sublist format and ellipsis 325*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) 326*da0073e9SAndroid Build Coastguard Worker tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], 327*da0073e9SAndroid Build Coastguard Worker [-1.6706, -0.8097, -0.8025, -2.1183]], 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker [[ 4.2239, 0.3107, -0.5756, -0.2354], 330*da0073e9SAndroid Build Coastguard Worker [-1.4558, -0.3460, 1.5087, -0.8530]], 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker [[ 2.8153, 1.8787, -4.3839, -1.2112], 333*da0073e9SAndroid Build Coastguard Worker [ 0.3728, -2.1131, 0.0921, 0.8305]]]) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker >>> # batch permute 336*da0073e9SAndroid Build Coastguard Worker >>> A = torch.randn(2, 3, 4, 5) 337*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('...ij->...ji', A).shape 338*da0073e9SAndroid Build Coastguard Worker torch.Size([2, 3, 5, 4]) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker >>> # equivalent to torch.nn.functional.bilinear 341*da0073e9SAndroid Build Coastguard Worker >>> A = torch.randn(3, 5, 4) 342*da0073e9SAndroid Build Coastguard Worker >>> l = torch.randn(2, 5) 343*da0073e9SAndroid Build Coastguard Worker >>> r = torch.randn(2, 4) 344*da0073e9SAndroid Build Coastguard Worker >>> torch.einsum('bn,anm,bm->ba', l, A, r) 345*da0073e9SAndroid Build Coastguard Worker tensor([[-0.3430, -5.2405, 0.4494], 346*da0073e9SAndroid Build Coastguard Worker [ 0.3311, 5.5201, -3.0356]]) 347*da0073e9SAndroid Build Coastguard Worker """ 348*da0073e9SAndroid Build Coastguard Worker import torch.backends.opt_einsum as opt_einsum 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 351*da0073e9SAndroid Build Coastguard Worker if len(args) < 2: 352*da0073e9SAndroid Build Coastguard Worker raise ValueError( 353*da0073e9SAndroid Build Coastguard Worker "einsum(): must specify the equation string and at least one operand, " 354*da0073e9SAndroid Build Coastguard Worker "or at least one operand and its subscripts list" 355*da0073e9SAndroid Build Coastguard Worker ) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker equation = None 358*da0073e9SAndroid Build Coastguard Worker operands = None 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker if isinstance(args[0], torch.Tensor): 361*da0073e9SAndroid Build Coastguard Worker # Convert the subscript list format which is an interleaving of operand and its subscripts 362*da0073e9SAndroid Build Coastguard Worker # list with an optional output subscripts list at the end (see documentation for more details on this) 363*da0073e9SAndroid Build Coastguard Worker # to the equation string format by creating the equation string from the subscripts list and grouping the 364*da0073e9SAndroid Build Coastguard Worker # input operands into a tensorlist (List[Tensor]). 365*da0073e9SAndroid Build Coastguard Worker def parse_subscript(n: int) -> str: 366*da0073e9SAndroid Build Coastguard Worker if n == Ellipsis: 367*da0073e9SAndroid Build Coastguard Worker return "..." 368*da0073e9SAndroid Build Coastguard Worker if n >= 0 and n < 26: 369*da0073e9SAndroid Build Coastguard Worker return chr(ord("A") + n) 370*da0073e9SAndroid Build Coastguard Worker if n >= 26 and n < 52: 371*da0073e9SAndroid Build Coastguard Worker return chr(ord("a") + n - 26) 372*da0073e9SAndroid Build Coastguard Worker raise ValueError( 373*da0073e9SAndroid Build Coastguard Worker "einsum(): subscript in subscript list is not within the valid range [0, 52)" 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker # Parse subscripts for input operands 377*da0073e9SAndroid Build Coastguard Worker equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2]) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker # Parse optional output subscripts (provided when the number of arguments is odd) 380*da0073e9SAndroid Build Coastguard Worker if len(args) % 2 == 1: 381*da0073e9SAndroid Build Coastguard Worker equation += "->" + "".join(parse_subscript(s) for s in args[-1]) 382*da0073e9SAndroid Build Coastguard Worker operands = args[:-1:2] 383*da0073e9SAndroid Build Coastguard Worker else: 384*da0073e9SAndroid Build Coastguard Worker operands = args[::2] 385*da0073e9SAndroid Build Coastguard Worker else: 386*da0073e9SAndroid Build Coastguard Worker equation = args[0] 387*da0073e9SAndroid Build Coastguard Worker operands = args[1:] 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker if has_torch_function(operands): 390*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(einsum, operands, equation, *operands) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker if len(operands) == 1 and isinstance(operands[0], (list, tuple)): 393*da0073e9SAndroid Build Coastguard Worker # the old interface of passing the operands as one list argument 394*da0073e9SAndroid Build Coastguard Worker _operands = operands[0] 395*da0073e9SAndroid Build Coastguard Worker # recurse incase operands contains value that has torch function 396*da0073e9SAndroid Build Coastguard Worker # in the original implementation this line is omitted 397*da0073e9SAndroid Build Coastguard Worker return einsum(equation, *_operands) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker if len(operands) <= 2 or not opt_einsum.enabled: 400*da0073e9SAndroid Build Coastguard Worker # the path for contracting 0 or 1 time(s) is already optimized 401*da0073e9SAndroid Build Coastguard Worker # or the user has disabled using opt_einsum 402*da0073e9SAndroid Build Coastguard Worker return _VF.einsum(equation, operands) # type: ignore[attr-defined] 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker path = None 405*da0073e9SAndroid Build Coastguard Worker if opt_einsum.is_available(): 406*da0073e9SAndroid Build Coastguard Worker _opt_einsum = opt_einsum.get_opt_einsum() 407*da0073e9SAndroid Build Coastguard Worker tupled_path = _opt_einsum.contract_path( 408*da0073e9SAndroid Build Coastguard Worker equation, *operands, optimize=opt_einsum.strategy 409*da0073e9SAndroid Build Coastguard Worker )[0] 410*da0073e9SAndroid Build Coastguard Worker # flatten path for dispatching to C++ 411*da0073e9SAndroid Build Coastguard Worker path = [item for pair in tupled_path for item in pair] 412*da0073e9SAndroid Build Coastguard Worker return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined] 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker# This wrapper exists to support variadic args. 416*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 417*da0073e9SAndroid Build Coastguard Worker # The JIT doesn't understand Union, so only add type annotation for mypy 418*da0073e9SAndroid Build Coastguard Worker def meshgrid( 419*da0073e9SAndroid Build Coastguard Worker *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None 420*da0073e9SAndroid Build Coastguard Worker ) -> Tuple[Tensor, ...]: 421*da0073e9SAndroid Build Coastguard Worker return _meshgrid(*tensors, indexing=indexing) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Workerelse: 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]: 426*da0073e9SAndroid Build Coastguard Worker r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors. 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker This is helpful when you want to visualize data over some 429*da0073e9SAndroid Build Coastguard Worker range of inputs. See below for a plotting example. 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as 432*da0073e9SAndroid Build Coastguard Worker inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`, 433*da0073e9SAndroid Build Coastguard Worker this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots 434*da0073e9SAndroid Build Coastguard Worker G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where 435*da0073e9SAndroid Build Coastguard Worker the output :math:`G_i` is constructed by expanding :math:`T_i` 436*da0073e9SAndroid Build Coastguard Worker to the result shape. 437*da0073e9SAndroid Build Coastguard Worker 438*da0073e9SAndroid Build Coastguard Worker .. note:: 439*da0073e9SAndroid Build Coastguard Worker 0D inputs are treated equivalently to 1D inputs of a 440*da0073e9SAndroid Build Coastguard Worker single element. 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker .. warning:: 443*da0073e9SAndroid Build Coastguard Worker `torch.meshgrid(*tensors)` currently has the same behavior 444*da0073e9SAndroid Build Coastguard Worker as calling `numpy.meshgrid(*arrays, indexing='ij')`. 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker In the future `torch.meshgrid` will transition to 447*da0073e9SAndroid Build Coastguard Worker `indexing='xy'` as the default. 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker https://github.com/pytorch/pytorch/issues/50276 tracks 450*da0073e9SAndroid Build Coastguard Worker this issue with the goal of migrating to NumPy's behavior. 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker .. seealso:: 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker :func:`torch.cartesian_prod` has the same effect but it 455*da0073e9SAndroid Build Coastguard Worker collects the data in a tensor of vectors. 456*da0073e9SAndroid Build Coastguard Worker 457*da0073e9SAndroid Build Coastguard Worker Args: 458*da0073e9SAndroid Build Coastguard Worker tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be 459*da0073e9SAndroid Build Coastguard Worker treated as tensors of size :math:`(1,)` automatically 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker indexing: (str, optional): the indexing mode, either "xy" 462*da0073e9SAndroid Build Coastguard Worker or "ij", defaults to "ij". See warning for future changes. 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker If "xy" is selected, the first dimension corresponds 465*da0073e9SAndroid Build Coastguard Worker to the cardinality of the second input and the second 466*da0073e9SAndroid Build Coastguard Worker dimension corresponds to the cardinality of the first 467*da0073e9SAndroid Build Coastguard Worker input. 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker If "ij" is selected, the dimensions are in the same 470*da0073e9SAndroid Build Coastguard Worker order as the cardinality of the inputs. 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker Returns: 473*da0073e9SAndroid Build Coastguard Worker seq (sequence of Tensors): If the input has :math:`N` 474*da0073e9SAndroid Build Coastguard Worker tensors of size :math:`S_0 \ldots S_{N-1}``, then the 475*da0073e9SAndroid Build Coastguard Worker output will also have :math:`N` tensors, where each tensor 476*da0073e9SAndroid Build Coastguard Worker is of shape :math:`(S_0, ..., S_{N-1})`. 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker Example:: 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor([1, 2, 3]) 481*da0073e9SAndroid Build Coastguard Worker >>> y = torch.tensor([4, 5, 6]) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker Observe the element-wise pairings across the grid, (1, 4), 484*da0073e9SAndroid Build Coastguard Worker (1, 5), ..., (3, 6). This is the same thing as the 485*da0073e9SAndroid Build Coastguard Worker cartesian product. 486*da0073e9SAndroid Build Coastguard Worker >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij') 487*da0073e9SAndroid Build Coastguard Worker >>> grid_x 488*da0073e9SAndroid Build Coastguard Worker tensor([[1, 1, 1], 489*da0073e9SAndroid Build Coastguard Worker [2, 2, 2], 490*da0073e9SAndroid Build Coastguard Worker [3, 3, 3]]) 491*da0073e9SAndroid Build Coastguard Worker >>> grid_y 492*da0073e9SAndroid Build Coastguard Worker tensor([[4, 5, 6], 493*da0073e9SAndroid Build Coastguard Worker [4, 5, 6], 494*da0073e9SAndroid Build Coastguard Worker [4, 5, 6]]) 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker This correspondence can be seen when these grids are 497*da0073e9SAndroid Build Coastguard Worker stacked properly. 498*da0073e9SAndroid Build Coastguard Worker >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))), 499*da0073e9SAndroid Build Coastguard Worker ... torch.cartesian_prod(x, y)) 500*da0073e9SAndroid Build Coastguard Worker True 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker `torch.meshgrid` is commonly used to produce a grid for 503*da0073e9SAndroid Build Coastguard Worker plotting. 504*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(module:matplotlib) 505*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW) 506*da0073e9SAndroid Build Coastguard Worker >>> import matplotlib.pyplot as plt 507*da0073e9SAndroid Build Coastguard Worker >>> xs = torch.linspace(-5, 5, steps=100) 508*da0073e9SAndroid Build Coastguard Worker >>> ys = torch.linspace(-5, 5, steps=100) 509*da0073e9SAndroid Build Coastguard Worker >>> x, y = torch.meshgrid(xs, ys, indexing='xy') 510*da0073e9SAndroid Build Coastguard Worker >>> z = torch.sin(torch.sqrt(x * x + y * y)) 511*da0073e9SAndroid Build Coastguard Worker >>> ax = plt.axes(projection='3d') 512*da0073e9SAndroid Build Coastguard Worker >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy()) 513*da0073e9SAndroid Build Coastguard Worker >>> plt.show() 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker .. image:: ../_static/img/meshgrid.png 516*da0073e9SAndroid Build Coastguard Worker :width: 512 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker """ 519*da0073e9SAndroid Build Coastguard Worker return _meshgrid(*tensors, indexing=indexing) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Workerdef _meshgrid(*tensors, indexing: Optional[str]): 523*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 524*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing) 525*da0073e9SAndroid Build Coastguard Worker if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): 526*da0073e9SAndroid Build Coastguard Worker # the old interface of passing the operands as one list argument 527*da0073e9SAndroid Build Coastguard Worker tensors = tensors[0] # type: ignore[assignment] 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker # Continue allowing call of old method that takes no indexing 530*da0073e9SAndroid Build Coastguard Worker # kwarg for forward compatibility reasons. 531*da0073e9SAndroid Build Coastguard Worker # 532*da0073e9SAndroid Build Coastguard Worker # Remove this two weeks after landing. 533*da0073e9SAndroid Build Coastguard Worker kwargs = {} if indexing is None else {"indexing": indexing} 534*da0073e9SAndroid Build Coastguard Worker return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined] 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Workerdef stft( 538*da0073e9SAndroid Build Coastguard Worker input: Tensor, 539*da0073e9SAndroid Build Coastguard Worker n_fft: int, 540*da0073e9SAndroid Build Coastguard Worker hop_length: Optional[int] = None, 541*da0073e9SAndroid Build Coastguard Worker win_length: Optional[int] = None, 542*da0073e9SAndroid Build Coastguard Worker window: Optional[Tensor] = None, 543*da0073e9SAndroid Build Coastguard Worker center: bool = True, 544*da0073e9SAndroid Build Coastguard Worker pad_mode: str = "reflect", 545*da0073e9SAndroid Build Coastguard Worker normalized: bool = False, 546*da0073e9SAndroid Build Coastguard Worker onesided: Optional[bool] = None, 547*da0073e9SAndroid Build Coastguard Worker return_complex: Optional[bool] = None, 548*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 549*da0073e9SAndroid Build Coastguard Worker r"""Short-time Fourier transform (STFT). 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker .. warning:: 552*da0073e9SAndroid Build Coastguard Worker From version 1.8.0, :attr:`return_complex` must always be given 553*da0073e9SAndroid Build Coastguard Worker explicitly for real inputs and `return_complex=False` has been 554*da0073e9SAndroid Build Coastguard Worker deprecated. Strongly prefer `return_complex=True` as in a future 555*da0073e9SAndroid Build Coastguard Worker pytorch release, this function will only return complex tensors. 556*da0073e9SAndroid Build Coastguard Worker 557*da0073e9SAndroid Build Coastguard Worker Note that :func:`torch.view_as_real` can be used to recover a real 558*da0073e9SAndroid Build Coastguard Worker tensor with an extra last dimension for real and imaginary components. 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker .. warning:: 561*da0073e9SAndroid Build Coastguard Worker From version 2.1, a warning will be provided if a :attr:`window` is 562*da0073e9SAndroid Build Coastguard Worker not specified. In a future release, this attribute will be required. 563*da0073e9SAndroid Build Coastguard Worker Not providing a window currently defaults to using a rectangular window, 564*da0073e9SAndroid Build Coastguard Worker which may result in undesirable artifacts. Consider using tapered windows, 565*da0073e9SAndroid Build Coastguard Worker such as :func:`torch.hann_window`. 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker The STFT computes the Fourier transform of short overlapping windows of the 568*da0073e9SAndroid Build Coastguard Worker input. This giving frequency components of the signal as they change over 569*da0073e9SAndroid Build Coastguard Worker time. The interface of this function is modeled after (but *not* a drop-in 570*da0073e9SAndroid Build Coastguard Worker replacement for) librosa_ stft function. 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker Ignoring the optional batch dimension, this method computes the following 575*da0073e9SAndroid Build Coastguard Worker expression: 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker .. math:: 578*da0073e9SAndroid Build Coastguard Worker X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}% 579*da0073e9SAndroid Build Coastguard Worker \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ % 580*da0073e9SAndroid Build Coastguard Worker \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right), 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker where :math:`m` is the index of the sliding window, and :math:`\omega` is 583*da0073e9SAndroid Build Coastguard Worker the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``, 584*da0073e9SAndroid Build Coastguard Worker or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``. 585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time 587*da0073e9SAndroid Build Coastguard Worker sequences. 588*da0073e9SAndroid Build Coastguard Worker 589*da0073e9SAndroid Build Coastguard Worker * If :attr:`hop_length` is ``None`` (default), it is treated as equal to 590*da0073e9SAndroid Build Coastguard Worker ``floor(n_fft / 4)``. 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker * If :attr:`win_length` is ``None`` (default), it is treated as equal to 593*da0073e9SAndroid Build Coastguard Worker :attr:`n_fft`. 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from 596*da0073e9SAndroid Build Coastguard Worker :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is 597*da0073e9SAndroid Build Coastguard Worker treated as if having :math:`1` everywhere in the window. If 598*da0073e9SAndroid Build Coastguard Worker :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on 599*da0073e9SAndroid Build Coastguard Worker both sides to length :attr:`n_fft` before being applied. 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on 602*da0073e9SAndroid Build Coastguard Worker both sides so that the :math:`t`-th frame is centered at time 603*da0073e9SAndroid Build Coastguard Worker :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame 604*da0073e9SAndroid Build Coastguard Worker begins at time :math:`t \times \text{hop\_length}`. 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker * :attr:`pad_mode` determines the padding method used on :attr:`input` when 607*da0073e9SAndroid Build Coastguard Worker :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for 608*da0073e9SAndroid Build Coastguard Worker all available options. Default is ``"reflect"``. 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker * If :attr:`onesided` is ``True`` (default for real input), only values for 611*da0073e9SAndroid Build Coastguard Worker :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor 612*da0073e9SAndroid Build Coastguard Worker \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because 613*da0073e9SAndroid Build Coastguard Worker the real-to-complex Fourier transform satisfies the conjugate symmetry, 614*da0073e9SAndroid Build Coastguard Worker i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`. 615*da0073e9SAndroid Build Coastguard Worker Note if the input or window tensors are complex, then :attr:`onesided` 616*da0073e9SAndroid Build Coastguard Worker output is not possible. 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker * If :attr:`normalized` is ``True`` (default is ``False``), the function 619*da0073e9SAndroid Build Coastguard Worker returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`. 620*da0073e9SAndroid Build Coastguard Worker 621*da0073e9SAndroid Build Coastguard Worker * If :attr:`return_complex` is ``True`` (default if input is complex), the 622*da0073e9SAndroid Build Coastguard Worker return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``, 623*da0073e9SAndroid Build Coastguard Worker the output is a ``input.dim() + 2`` dimensional real tensor where the last 624*da0073e9SAndroid Build Coastguard Worker dimension represents the real and imaginary components. 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker Returns either a complex tensor of size :math:`(* \times N \times T)` if 627*da0073e9SAndroid Build Coastguard Worker :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N 628*da0073e9SAndroid Build Coastguard Worker \times T \times 2)`. Where :math:`*` is the optional batch size of 629*da0073e9SAndroid Build Coastguard Worker :attr:`input`, :math:`N` is the number of frequencies where STFT is applied 630*da0073e9SAndroid Build Coastguard Worker and :math:`T` is the total number of frames used. 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker .. warning:: 633*da0073e9SAndroid Build Coastguard Worker This function changed signature at version 0.4.1. Calling with the 634*da0073e9SAndroid Build Coastguard Worker previous signature may cause error or return incorrect result. 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker Args: 637*da0073e9SAndroid Build Coastguard Worker input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional 638*da0073e9SAndroid Build Coastguard Worker batch dimension 639*da0073e9SAndroid Build Coastguard Worker n_fft (int): size of Fourier transform 640*da0073e9SAndroid Build Coastguard Worker hop_length (int, optional): the distance between neighboring sliding window 641*da0073e9SAndroid Build Coastguard Worker frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``) 642*da0073e9SAndroid Build Coastguard Worker win_length (int, optional): the size of window frame and STFT filter. 643*da0073e9SAndroid Build Coastguard Worker Default: ``None`` (treated as equal to :attr:`n_fft`) 644*da0073e9SAndroid Build Coastguard Worker window (Tensor, optional): the optional window function. 645*da0073e9SAndroid Build Coastguard Worker Shape must be 1d and `<= n_fft` 646*da0073e9SAndroid Build Coastguard Worker Default: ``None`` (treated as window of all :math:`1` s) 647*da0073e9SAndroid Build Coastguard Worker center (bool, optional): whether to pad :attr:`input` on both sides so 648*da0073e9SAndroid Build Coastguard Worker that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. 649*da0073e9SAndroid Build Coastguard Worker Default: ``True`` 650*da0073e9SAndroid Build Coastguard Worker pad_mode (str, optional): controls the padding method used when 651*da0073e9SAndroid Build Coastguard Worker :attr:`center` is ``True``. Default: ``"reflect"`` 652*da0073e9SAndroid Build Coastguard Worker normalized (bool, optional): controls whether to return the normalized STFT results 653*da0073e9SAndroid Build Coastguard Worker Default: ``False`` 654*da0073e9SAndroid Build Coastguard Worker onesided (bool, optional): controls whether to return half of results to 655*da0073e9SAndroid Build Coastguard Worker avoid redundancy for real inputs. 656*da0073e9SAndroid Build Coastguard Worker Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise. 657*da0073e9SAndroid Build Coastguard Worker return_complex (bool, optional): whether to return a complex tensor, or 658*da0073e9SAndroid Build Coastguard Worker a real tensor with an extra last dimension for the real and 659*da0073e9SAndroid Build Coastguard Worker imaginary components. 660*da0073e9SAndroid Build Coastguard Worker 661*da0073e9SAndroid Build Coastguard Worker .. versionchanged:: 2.0 662*da0073e9SAndroid Build Coastguard Worker ``return_complex`` is now a required argument for real inputs, 663*da0073e9SAndroid Build Coastguard Worker as the default is being transitioned to ``True``. 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker .. deprecated:: 2.0 666*da0073e9SAndroid Build Coastguard Worker ``return_complex=False`` is deprecated, instead use ``return_complex=True`` 667*da0073e9SAndroid Build Coastguard Worker Note that calling :func:`torch.view_as_real` on the output will 668*da0073e9SAndroid Build Coastguard Worker recover the deprecated output format. 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker Returns: 671*da0073e9SAndroid Build Coastguard Worker Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where 672*da0073e9SAndroid Build Coastguard Worker - `B?` is an optional batch dimension from the input. 673*da0073e9SAndroid Build Coastguard Worker - `N` is the number of frequency samples, `(n_fft // 2) + 1` for 674*da0073e9SAndroid Build Coastguard Worker `onesided=True`, or otherwise `n_fft`. 675*da0073e9SAndroid Build Coastguard Worker - `T` is the number of frames, `1 + L // hop_length` 676*da0073e9SAndroid Build Coastguard Worker for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise. 677*da0073e9SAndroid Build Coastguard Worker - `C?` is an optional length-2 dimension of real and imaginary 678*da0073e9SAndroid Build Coastguard Worker components, present when `return_complex=False`. 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker """ 681*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 682*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 683*da0073e9SAndroid Build Coastguard Worker stft, 684*da0073e9SAndroid Build Coastguard Worker (input,), 685*da0073e9SAndroid Build Coastguard Worker input, 686*da0073e9SAndroid Build Coastguard Worker n_fft, 687*da0073e9SAndroid Build Coastguard Worker hop_length=hop_length, 688*da0073e9SAndroid Build Coastguard Worker win_length=win_length, 689*da0073e9SAndroid Build Coastguard Worker window=window, 690*da0073e9SAndroid Build Coastguard Worker center=center, 691*da0073e9SAndroid Build Coastguard Worker pad_mode=pad_mode, 692*da0073e9SAndroid Build Coastguard Worker normalized=normalized, 693*da0073e9SAndroid Build Coastguard Worker onesided=onesided, 694*da0073e9SAndroid Build Coastguard Worker return_complex=return_complex, 695*da0073e9SAndroid Build Coastguard Worker ) 696*da0073e9SAndroid Build Coastguard Worker # NOTE: Do not edit. This code will be removed once the forward-compatibility 697*da0073e9SAndroid Build Coastguard Worker # period is over for PR #73432 698*da0073e9SAndroid Build Coastguard Worker if center: 699*da0073e9SAndroid Build Coastguard Worker signal_dim = input.dim() 700*da0073e9SAndroid Build Coastguard Worker extended_shape = [1] * (3 - signal_dim) + list(input.size()) 701*da0073e9SAndroid Build Coastguard Worker pad = int(n_fft // 2) 702*da0073e9SAndroid Build Coastguard Worker input = F.pad(input.view(extended_shape), [pad, pad], pad_mode) 703*da0073e9SAndroid Build Coastguard Worker input = input.view(input.shape[-signal_dim:]) 704*da0073e9SAndroid Build Coastguard Worker return _VF.stft( # type: ignore[attr-defined] 705*da0073e9SAndroid Build Coastguard Worker input, 706*da0073e9SAndroid Build Coastguard Worker n_fft, 707*da0073e9SAndroid Build Coastguard Worker hop_length, 708*da0073e9SAndroid Build Coastguard Worker win_length, 709*da0073e9SAndroid Build Coastguard Worker window, 710*da0073e9SAndroid Build Coastguard Worker normalized, 711*da0073e9SAndroid Build Coastguard Worker onesided, 712*da0073e9SAndroid Build Coastguard Worker return_complex, 713*da0073e9SAndroid Build Coastguard Worker ) 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Workeristft = _add_docstr( 717*da0073e9SAndroid Build Coastguard Worker torch.istft, 718*da0073e9SAndroid Build Coastguard Worker "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, " 719*da0073e9SAndroid Build Coastguard Worker "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n" 720*da0073e9SAndroid Build Coastguard Worker r""" 721*da0073e9SAndroid Build Coastguard WorkerInverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`. 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker.. warning:: 724*da0073e9SAndroid Build Coastguard Worker From version 2.1, a warning will be provided if a :attr:`window` is 725*da0073e9SAndroid Build Coastguard Worker not specified. In a future release, this attribute will be required. 726*da0073e9SAndroid Build Coastguard Worker Please provide the same window used in the stft call. 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard WorkerIt has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the 729*da0073e9SAndroid Build Coastguard Workerleast squares estimation of the original signal. The algorithm will check using the NOLA condition ( 730*da0073e9SAndroid Build Coastguard Workernonzero overlap). 731*da0073e9SAndroid Build Coastguard Worker 732*da0073e9SAndroid Build Coastguard WorkerImportant consideration in the parameters :attr:`window` and :attr:`center` so that the envelope 733*da0073e9SAndroid Build Coastguard Workercreated by the summation of all the windows is never zero at certain point in time. Specifically, 734*da0073e9SAndroid Build Coastguard Worker:math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`. 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard WorkerSince :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame, 737*da0073e9SAndroid Build Coastguard Worker``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False 738*da0073e9SAndroid Build Coastguard Workersince the signal isn't padded). If `length` is given in the arguments and is longer than expected, 739*da0073e9SAndroid Build Coastguard Worker``istft`` will pad zeros to the end of the returned signal. 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard WorkerIf :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc. 742*da0073e9SAndroid Build Coastguard WorkerLeft padding can be trimmed off exactly because they can be calculated but right padding cannot be 743*da0073e9SAndroid Build Coastguard Workercalculated without additional information. 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard WorkerExample: Suppose the last window is: 746*da0073e9SAndroid Build Coastguard Worker``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]`` 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard WorkerThe :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation 749*da0073e9SAndroid Build Coastguard Workerof right padding. These additional values could be zeros or a reflection of the signal so providing 750*da0073e9SAndroid Build Coastguard Worker:attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed 751*da0073e9SAndroid Build Coastguard Worker(some loss of signal). 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker[1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform," 754*da0073e9SAndroid Build Coastguard WorkerIEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984. 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard WorkerArgs: 757*da0073e9SAndroid Build Coastguard Worker input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`, 758*da0073e9SAndroid Build Coastguard Worker output. That is a complex tensor of shape `(B?, N, T)` where 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker - `B?` is an optional batch dimension 761*da0073e9SAndroid Build Coastguard Worker - `N` is the number of frequency samples, `(n_fft // 2) + 1` 762*da0073e9SAndroid Build Coastguard Worker for onesided input, or otherwise `n_fft`. 763*da0073e9SAndroid Build Coastguard Worker - `T` is the number of frames, `1 + length // hop_length` for centered stft, 764*da0073e9SAndroid Build Coastguard Worker or `1 + (length - n_fft) // hop_length` otherwise. 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker .. versionchanged:: 2.0 767*da0073e9SAndroid Build Coastguard Worker Real datatype inputs are no longer supported. Input must now have a 768*da0073e9SAndroid Build Coastguard Worker complex datatype, as returned by ``stft(..., return_complex=True)``. 769*da0073e9SAndroid Build Coastguard Worker n_fft (int): Size of Fourier transform 770*da0073e9SAndroid Build Coastguard Worker hop_length (Optional[int]): The distance between neighboring sliding window frames. 771*da0073e9SAndroid Build Coastguard Worker (Default: ``n_fft // 4``) 772*da0073e9SAndroid Build Coastguard Worker win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``) 773*da0073e9SAndroid Build Coastguard Worker window (Optional[torch.Tensor]): The optional window function. 774*da0073e9SAndroid Build Coastguard Worker Shape must be 1d and `<= n_fft` 775*da0073e9SAndroid Build Coastguard Worker (Default: ``torch.ones(win_length)``) 776*da0073e9SAndroid Build Coastguard Worker center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is 777*da0073e9SAndroid Build Coastguard Worker centered at time :math:`t \times \text{hop\_length}`. 778*da0073e9SAndroid Build Coastguard Worker (Default: ``True``) 779*da0073e9SAndroid Build Coastguard Worker normalized (bool): Whether the STFT was normalized. (Default: ``False``) 780*da0073e9SAndroid Build Coastguard Worker onesided (Optional[bool]): Whether the STFT was onesided. 781*da0073e9SAndroid Build Coastguard Worker (Default: ``True`` if `n_fft != fft_size` in the input size) 782*da0073e9SAndroid Build Coastguard Worker length (Optional[int]): The amount to trim the signal by (i.e. the 783*da0073e9SAndroid Build Coastguard Worker original signal length). Defaults to `(T - 1) * hop_length` for 784*da0073e9SAndroid Build Coastguard Worker centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T` 785*da0073e9SAndroid Build Coastguard Worker is the number of input frames. 786*da0073e9SAndroid Build Coastguard Worker return_complex (Optional[bool]): 787*da0073e9SAndroid Build Coastguard Worker Whether the output should be complex, or if the input should be 788*da0073e9SAndroid Build Coastguard Worker assumed to derive from a real signal and window. 789*da0073e9SAndroid Build Coastguard Worker Note that this is incompatible with ``onesided=True``. 790*da0073e9SAndroid Build Coastguard Worker (Default: ``False``) 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard WorkerReturns: 793*da0073e9SAndroid Build Coastguard Worker Tensor: Least squares estimation of the original signal of shape `(B?, length)` where 794*da0073e9SAndroid Build Coastguard Worker `B?` is an optional batch dimension from the input tensor. 795*da0073e9SAndroid Build Coastguard Worker""", 796*da0073e9SAndroid Build Coastguard Worker) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker 799*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 800*da0073e9SAndroid Build Coastguard Worker # These _impl functions return a variable number of tensors as output with 801*da0073e9SAndroid Build Coastguard Worker # __torch_function__; tuple unpacking is done already rather than being 802*da0073e9SAndroid Build Coastguard Worker # done by the caller of the _impl function 803*da0073e9SAndroid Build Coastguard Worker _unique_impl_out = Any 804*da0073e9SAndroid Build Coastguard Workerelse: 805*da0073e9SAndroid Build Coastguard Worker _unique_impl_out = Tuple[Tensor, Tensor, Tensor] 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Workerdef _unique_impl( 809*da0073e9SAndroid Build Coastguard Worker input: Tensor, 810*da0073e9SAndroid Build Coastguard Worker sorted: bool = True, 811*da0073e9SAndroid Build Coastguard Worker return_inverse: bool = False, 812*da0073e9SAndroid Build Coastguard Worker return_counts: bool = False, 813*da0073e9SAndroid Build Coastguard Worker dim: Optional[int] = None, 814*da0073e9SAndroid Build Coastguard Worker) -> _unique_impl_out: 815*da0073e9SAndroid Build Coastguard Worker r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor] 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker Returns the unique elements of the input tensor. 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that 820*da0073e9SAndroid Build Coastguard Worker this function also eliminates non-consecutive duplicate values. 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker .. note:: Currently in the CUDA implementation and the CPU implementation, 823*da0073e9SAndroid Build Coastguard Worker `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument. 824*da0073e9SAndroid Build Coastguard Worker Sorting could be slow, so if your input tensor is already sorted, it is recommended to use 825*da0073e9SAndroid Build Coastguard Worker :func:`torch.unique_consecutive` which avoids the sorting. 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker Args: 828*da0073e9SAndroid Build Coastguard Worker input (Tensor): the input tensor 829*da0073e9SAndroid Build Coastguard Worker sorted (bool): Whether to sort the unique elements in ascending order 830*da0073e9SAndroid Build Coastguard Worker before returning as output. 831*da0073e9SAndroid Build Coastguard Worker return_inverse (bool): Whether to also return the indices for where 832*da0073e9SAndroid Build Coastguard Worker elements in the original input ended up in the returned unique list. 833*da0073e9SAndroid Build Coastguard Worker return_counts (bool): Whether to also return the counts for each unique 834*da0073e9SAndroid Build Coastguard Worker element. 835*da0073e9SAndroid Build Coastguard Worker dim (int, optional): the dimension to operate upon. If ``None``, the 836*da0073e9SAndroid Build Coastguard Worker unique of the flattened input is returned. Otherwise, each of the 837*da0073e9SAndroid Build Coastguard Worker tensors indexed by the given dimension is treated as one of the 838*da0073e9SAndroid Build Coastguard Worker elements to apply the unique operation upon. See examples for more 839*da0073e9SAndroid Build Coastguard Worker details. Default: ``None`` 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker Returns: 842*da0073e9SAndroid Build Coastguard Worker (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker - **output** (*Tensor*): the output list of unique scalar elements. 845*da0073e9SAndroid Build Coastguard Worker - **inverse_indices** (*Tensor*): (optional) if 846*da0073e9SAndroid Build Coastguard Worker :attr:`return_inverse` is True, there will be an additional 847*da0073e9SAndroid Build Coastguard Worker returned tensor (same shape as input) representing the indices 848*da0073e9SAndroid Build Coastguard Worker for where elements in the original input map to in the output; 849*da0073e9SAndroid Build Coastguard Worker otherwise, this function will only return a single tensor. 850*da0073e9SAndroid Build Coastguard Worker - **counts** (*Tensor*): (optional) if 851*da0073e9SAndroid Build Coastguard Worker :attr:`return_counts` is True, there will be an additional 852*da0073e9SAndroid Build Coastguard Worker returned tensor (same shape as output or output.size(dim), 853*da0073e9SAndroid Build Coastguard Worker if dim was specified) representing the number of occurrences 854*da0073e9SAndroid Build Coastguard Worker for each unique value or tensor. 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker Example:: 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long)) 859*da0073e9SAndroid Build Coastguard Worker >>> output 860*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3]) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker >>> output, inverse_indices = torch.unique( 863*da0073e9SAndroid Build Coastguard Worker ... torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True) 864*da0073e9SAndroid Build Coastguard Worker >>> output 865*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3]) 866*da0073e9SAndroid Build Coastguard Worker >>> inverse_indices 867*da0073e9SAndroid Build Coastguard Worker tensor([0, 2, 1, 2]) 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker >>> output, inverse_indices = torch.unique( 870*da0073e9SAndroid Build Coastguard Worker ... torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True) 871*da0073e9SAndroid Build Coastguard Worker >>> output 872*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3]) 873*da0073e9SAndroid Build Coastguard Worker >>> inverse_indices 874*da0073e9SAndroid Build Coastguard Worker tensor([[0, 2], 875*da0073e9SAndroid Build Coastguard Worker [1, 2]]) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker >>> a = torch.tensor([ 878*da0073e9SAndroid Build Coastguard Worker ... [ 879*da0073e9SAndroid Build Coastguard Worker ... [1, 1, 0, 0], 880*da0073e9SAndroid Build Coastguard Worker ... [1, 1, 0, 0], 881*da0073e9SAndroid Build Coastguard Worker ... [0, 0, 1, 1], 882*da0073e9SAndroid Build Coastguard Worker ... ], 883*da0073e9SAndroid Build Coastguard Worker ... [ 884*da0073e9SAndroid Build Coastguard Worker ... [0, 0, 1, 1], 885*da0073e9SAndroid Build Coastguard Worker ... [0, 0, 1, 1], 886*da0073e9SAndroid Build Coastguard Worker ... [1, 1, 1, 1], 887*da0073e9SAndroid Build Coastguard Worker ... ], 888*da0073e9SAndroid Build Coastguard Worker ... [ 889*da0073e9SAndroid Build Coastguard Worker ... [1, 1, 0, 0], 890*da0073e9SAndroid Build Coastguard Worker ... [1, 1, 0, 0], 891*da0073e9SAndroid Build Coastguard Worker ... [0, 0, 1, 1], 892*da0073e9SAndroid Build Coastguard Worker ... ], 893*da0073e9SAndroid Build Coastguard Worker ... ]) 894*da0073e9SAndroid Build Coastguard Worker 895*da0073e9SAndroid Build Coastguard Worker >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]` 896*da0073e9SAndroid Build Coastguard Worker >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match 897*da0073e9SAndroid Build Coastguard Worker >>> # each other, so one of them will be removed. 898*da0073e9SAndroid Build Coastguard Worker >>> (a[0, :, :] == a[2, :, :]).all() 899*da0073e9SAndroid Build Coastguard Worker tensor(True) 900*da0073e9SAndroid Build Coastguard Worker >>> a_unique_dim0 = torch.unique(a, dim=0) 901*da0073e9SAndroid Build Coastguard Worker >>> a_unique_dim0 902*da0073e9SAndroid Build Coastguard Worker tensor([[[0, 0, 1, 1], 903*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 1], 904*da0073e9SAndroid Build Coastguard Worker [1, 1, 1, 1]], 905*da0073e9SAndroid Build Coastguard Worker [[1, 1, 0, 0], 906*da0073e9SAndroid Build Coastguard Worker [1, 1, 0, 0], 907*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 1]]]) 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker >>> # Notice which sub-tensors from `a` match with the sub-tensors from 910*da0073e9SAndroid Build Coastguard Worker >>> # `a_unique_dim0`: 911*da0073e9SAndroid Build Coastguard Worker >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all() 912*da0073e9SAndroid Build Coastguard Worker tensor(True) 913*da0073e9SAndroid Build Coastguard Worker >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all() 914*da0073e9SAndroid Build Coastguard Worker tensor(True) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are 917*da0073e9SAndroid Build Coastguard Worker >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of 918*da0073e9SAndroid Build Coastguard Worker >>> # them will be removed. 919*da0073e9SAndroid Build Coastguard Worker >>> (a[:, 0, :] == a[:, 1, :]).all() 920*da0073e9SAndroid Build Coastguard Worker tensor(True) 921*da0073e9SAndroid Build Coastguard Worker >>> torch.unique(a, dim=1) 922*da0073e9SAndroid Build Coastguard Worker tensor([[[0, 0, 1, 1], 923*da0073e9SAndroid Build Coastguard Worker [1, 1, 0, 0]], 924*da0073e9SAndroid Build Coastguard Worker [[1, 1, 1, 1], 925*da0073e9SAndroid Build Coastguard Worker [0, 0, 1, 1]], 926*da0073e9SAndroid Build Coastguard Worker [[0, 0, 1, 1], 927*da0073e9SAndroid Build Coastguard Worker [1, 1, 0, 0]]]) 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared. 930*da0073e9SAndroid Build Coastguard Worker >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and 931*da0073e9SAndroid Build Coastguard Worker >>> # `a[:, :, 3]` match each other as well. So in this case, two of the 932*da0073e9SAndroid Build Coastguard Worker >>> # sub-tensors will be removed. 933*da0073e9SAndroid Build Coastguard Worker >>> (a[:, :, 0] == a[:, :, 1]).all() 934*da0073e9SAndroid Build Coastguard Worker tensor(True) 935*da0073e9SAndroid Build Coastguard Worker >>> (a[:, :, 2] == a[:, :, 3]).all() 936*da0073e9SAndroid Build Coastguard Worker tensor(True) 937*da0073e9SAndroid Build Coastguard Worker >>> torch.unique(a, dim=2) 938*da0073e9SAndroid Build Coastguard Worker tensor([[[0, 1], 939*da0073e9SAndroid Build Coastguard Worker [0, 1], 940*da0073e9SAndroid Build Coastguard Worker [1, 0]], 941*da0073e9SAndroid Build Coastguard Worker [[1, 0], 942*da0073e9SAndroid Build Coastguard Worker [1, 0], 943*da0073e9SAndroid Build Coastguard Worker [1, 1]], 944*da0073e9SAndroid Build Coastguard Worker [[0, 1], 945*da0073e9SAndroid Build Coastguard Worker [0, 1], 946*da0073e9SAndroid Build Coastguard Worker [1, 0]]]) 947*da0073e9SAndroid Build Coastguard Worker """ 948*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 949*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 950*da0073e9SAndroid Build Coastguard Worker unique, 951*da0073e9SAndroid Build Coastguard Worker (input,), 952*da0073e9SAndroid Build Coastguard Worker input, 953*da0073e9SAndroid Build Coastguard Worker sorted=sorted, 954*da0073e9SAndroid Build Coastguard Worker return_inverse=return_inverse, 955*da0073e9SAndroid Build Coastguard Worker return_counts=return_counts, 956*da0073e9SAndroid Build Coastguard Worker dim=dim, 957*da0073e9SAndroid Build Coastguard Worker ) 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker if dim is not None: 960*da0073e9SAndroid Build Coastguard Worker output, inverse_indices, counts = _VF.unique_dim( 961*da0073e9SAndroid Build Coastguard Worker input, 962*da0073e9SAndroid Build Coastguard Worker dim, 963*da0073e9SAndroid Build Coastguard Worker sorted=sorted, 964*da0073e9SAndroid Build Coastguard Worker return_inverse=return_inverse, 965*da0073e9SAndroid Build Coastguard Worker return_counts=return_counts, 966*da0073e9SAndroid Build Coastguard Worker ) 967*da0073e9SAndroid Build Coastguard Worker else: 968*da0073e9SAndroid Build Coastguard Worker output, inverse_indices, counts = torch._unique2( 969*da0073e9SAndroid Build Coastguard Worker input, 970*da0073e9SAndroid Build Coastguard Worker sorted=sorted, 971*da0073e9SAndroid Build Coastguard Worker return_inverse=return_inverse, 972*da0073e9SAndroid Build Coastguard Worker return_counts=return_counts, 973*da0073e9SAndroid Build Coastguard Worker ) 974*da0073e9SAndroid Build Coastguard Worker return output, inverse_indices, counts 975*da0073e9SAndroid Build Coastguard Worker 976*da0073e9SAndroid Build Coastguard Worker 977*da0073e9SAndroid Build Coastguard Workerdef _unique_consecutive_impl( 978*da0073e9SAndroid Build Coastguard Worker input: Tensor, 979*da0073e9SAndroid Build Coastguard Worker return_inverse: bool = False, 980*da0073e9SAndroid Build Coastguard Worker return_counts: bool = False, 981*da0073e9SAndroid Build Coastguard Worker dim: Optional[int] = None, 982*da0073e9SAndroid Build Coastguard Worker) -> _unique_impl_out: 983*da0073e9SAndroid Build Coastguard Worker r"""Eliminates all but the first element from every consecutive group of equivalent elements. 984*da0073e9SAndroid Build Coastguard Worker 985*da0073e9SAndroid Build Coastguard Worker .. note:: This function is different from :func:`torch.unique` in the sense that this function 986*da0073e9SAndroid Build Coastguard Worker only eliminates consecutive duplicate values. This semantics is similar to `std::unique` 987*da0073e9SAndroid Build Coastguard Worker in C++. 988*da0073e9SAndroid Build Coastguard Worker 989*da0073e9SAndroid Build Coastguard Worker Args: 990*da0073e9SAndroid Build Coastguard Worker input (Tensor): the input tensor 991*da0073e9SAndroid Build Coastguard Worker return_inverse (bool): Whether to also return the indices for where 992*da0073e9SAndroid Build Coastguard Worker elements in the original input ended up in the returned unique list. 993*da0073e9SAndroid Build Coastguard Worker return_counts (bool): Whether to also return the counts for each unique 994*da0073e9SAndroid Build Coastguard Worker element. 995*da0073e9SAndroid Build Coastguard Worker dim (int): the dimension to apply unique. If ``None``, the unique of the 996*da0073e9SAndroid Build Coastguard Worker flattened input is returned. default: ``None`` 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker Returns: 999*da0073e9SAndroid Build Coastguard Worker (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker - **output** (*Tensor*): the output list of unique scalar elements. 1002*da0073e9SAndroid Build Coastguard Worker - **inverse_indices** (*Tensor*): (optional) if 1003*da0073e9SAndroid Build Coastguard Worker :attr:`return_inverse` is True, there will be an additional 1004*da0073e9SAndroid Build Coastguard Worker returned tensor (same shape as input) representing the indices 1005*da0073e9SAndroid Build Coastguard Worker for where elements in the original input map to in the output; 1006*da0073e9SAndroid Build Coastguard Worker otherwise, this function will only return a single tensor. 1007*da0073e9SAndroid Build Coastguard Worker - **counts** (*Tensor*): (optional) if 1008*da0073e9SAndroid Build Coastguard Worker :attr:`return_counts` is True, there will be an additional 1009*da0073e9SAndroid Build Coastguard Worker returned tensor (same shape as output or output.size(dim), 1010*da0073e9SAndroid Build Coastguard Worker if dim was specified) representing the number of occurrences 1011*da0073e9SAndroid Build Coastguard Worker for each unique value or tensor. 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker Example:: 1014*da0073e9SAndroid Build Coastguard Worker 1015*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2]) 1016*da0073e9SAndroid Build Coastguard Worker >>> output = torch.unique_consecutive(x) 1017*da0073e9SAndroid Build Coastguard Worker >>> output 1018*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3, 1, 2]) 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True) 1021*da0073e9SAndroid Build Coastguard Worker >>> output 1022*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3, 1, 2]) 1023*da0073e9SAndroid Build Coastguard Worker >>> inverse_indices 1024*da0073e9SAndroid Build Coastguard Worker tensor([0, 0, 1, 1, 2, 3, 3, 4]) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker >>> output, counts = torch.unique_consecutive(x, return_counts=True) 1027*da0073e9SAndroid Build Coastguard Worker >>> output 1028*da0073e9SAndroid Build Coastguard Worker tensor([1, 2, 3, 1, 2]) 1029*da0073e9SAndroid Build Coastguard Worker >>> counts 1030*da0073e9SAndroid Build Coastguard Worker tensor([2, 2, 1, 2, 1]) 1031*da0073e9SAndroid Build Coastguard Worker """ 1032*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1033*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 1034*da0073e9SAndroid Build Coastguard Worker unique_consecutive, 1035*da0073e9SAndroid Build Coastguard Worker (input,), 1036*da0073e9SAndroid Build Coastguard Worker input, 1037*da0073e9SAndroid Build Coastguard Worker return_inverse=return_inverse, 1038*da0073e9SAndroid Build Coastguard Worker return_counts=return_counts, 1039*da0073e9SAndroid Build Coastguard Worker dim=dim, 1040*da0073e9SAndroid Build Coastguard Worker ) 1041*da0073e9SAndroid Build Coastguard Worker output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined] 1042*da0073e9SAndroid Build Coastguard Worker input, return_inverse=return_inverse, return_counts=return_counts, dim=dim 1043*da0073e9SAndroid Build Coastguard Worker ) 1044*da0073e9SAndroid Build Coastguard Worker return output, inverse_indices, counts 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker 1047*da0073e9SAndroid Build Coastguard Workerdef _return_counts( 1048*da0073e9SAndroid Build Coastguard Worker input, 1049*da0073e9SAndroid Build Coastguard Worker sorted=True, 1050*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1051*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1052*da0073e9SAndroid Build Coastguard Worker dim=None, 1053*da0073e9SAndroid Build Coastguard Worker): 1054*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1057*da0073e9SAndroid Build Coastguard Worker return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim) 1060*da0073e9SAndroid Build Coastguard Worker return output, counts 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker 1063*da0073e9SAndroid Build Coastguard Workerdef _return_output( 1064*da0073e9SAndroid Build Coastguard Worker input, 1065*da0073e9SAndroid Build Coastguard Worker sorted=True, 1066*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1067*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1068*da0073e9SAndroid Build Coastguard Worker dim=None, 1069*da0073e9SAndroid Build Coastguard Worker): 1070*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1073*da0073e9SAndroid Build Coastguard Worker return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1074*da0073e9SAndroid Build Coastguard Worker 1075*da0073e9SAndroid Build Coastguard Worker output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim) 1076*da0073e9SAndroid Build Coastguard Worker return output 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Workerdef _return_inverse( 1080*da0073e9SAndroid Build Coastguard Worker input, 1081*da0073e9SAndroid Build Coastguard Worker sorted=True, 1082*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1083*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1084*da0073e9SAndroid Build Coastguard Worker dim=None, 1085*da0073e9SAndroid Build Coastguard Worker): 1086*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1089*da0073e9SAndroid Build Coastguard Worker return _unique_impl(input, sorted, return_inverse, return_counts, dim) 1090*da0073e9SAndroid Build Coastguard Worker 1091*da0073e9SAndroid Build Coastguard Worker output, inverse_indices, _ = _unique_impl( 1092*da0073e9SAndroid Build Coastguard Worker input, sorted, return_inverse, return_counts, dim 1093*da0073e9SAndroid Build Coastguard Worker ) 1094*da0073e9SAndroid Build Coastguard Worker return output, inverse_indices 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker_return_inverse_false = boolean_dispatch( 1098*da0073e9SAndroid Build Coastguard Worker arg_name="return_counts", 1099*da0073e9SAndroid Build Coastguard Worker arg_index=3, 1100*da0073e9SAndroid Build Coastguard Worker default=False, 1101*da0073e9SAndroid Build Coastguard Worker if_true=_return_counts, 1102*da0073e9SAndroid Build Coastguard Worker if_false=_return_output, 1103*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1104*da0073e9SAndroid Build Coastguard Worker func_name="unique", 1105*da0073e9SAndroid Build Coastguard Worker) 1106*da0073e9SAndroid Build Coastguard Worker 1107*da0073e9SAndroid Build Coastguard Worker_return_inverse_true = boolean_dispatch( 1108*da0073e9SAndroid Build Coastguard Worker arg_name="return_counts", 1109*da0073e9SAndroid Build Coastguard Worker arg_index=3, 1110*da0073e9SAndroid Build Coastguard Worker default=False, 1111*da0073e9SAndroid Build Coastguard Worker if_true=_unique_impl, 1112*da0073e9SAndroid Build Coastguard Worker if_false=_return_inverse, 1113*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1114*da0073e9SAndroid Build Coastguard Worker func_name="unique", 1115*da0073e9SAndroid Build Coastguard Worker) 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker# The return type of unique depends on `return_inverse`, and `return_counts` so in order to 1118*da0073e9SAndroid Build Coastguard Worker# resolve the output type in TorchScript we need to statically know the value of both parameters 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Workerunique = boolean_dispatch( 1121*da0073e9SAndroid Build Coastguard Worker arg_name="return_inverse", 1122*da0073e9SAndroid Build Coastguard Worker arg_index=2, 1123*da0073e9SAndroid Build Coastguard Worker default=False, 1124*da0073e9SAndroid Build Coastguard Worker if_true=_return_inverse_true, 1125*da0073e9SAndroid Build Coastguard Worker if_false=_return_inverse_false, 1126*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1127*da0073e9SAndroid Build Coastguard Worker func_name="unique", 1128*da0073e9SAndroid Build Coastguard Worker) 1129*da0073e9SAndroid Build Coastguard Workerunique.__doc__ = _unique_impl.__doc__ 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_counts( 1133*da0073e9SAndroid Build Coastguard Worker input, 1134*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1135*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1136*da0073e9SAndroid Build Coastguard Worker dim=None, 1137*da0073e9SAndroid Build Coastguard Worker): 1138*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1141*da0073e9SAndroid Build Coastguard Worker return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1142*da0073e9SAndroid Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker output, _, counts = _unique_consecutive_impl( 1144*da0073e9SAndroid Build Coastguard Worker input, return_inverse, return_counts, dim 1145*da0073e9SAndroid Build Coastguard Worker ) 1146*da0073e9SAndroid Build Coastguard Worker return output, counts 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_output( 1150*da0073e9SAndroid Build Coastguard Worker input, 1151*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1152*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1153*da0073e9SAndroid Build Coastguard Worker dim=None, 1154*da0073e9SAndroid Build Coastguard Worker): 1155*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Optional[int]) -> Tensor 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1158*da0073e9SAndroid Build Coastguard Worker return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1159*da0073e9SAndroid Build Coastguard Worker 1160*da0073e9SAndroid Build Coastguard Worker output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1161*da0073e9SAndroid Build Coastguard Worker return output 1162*da0073e9SAndroid Build Coastguard Worker 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Workerdef _consecutive_return_inverse( 1165*da0073e9SAndroid Build Coastguard Worker input, 1166*da0073e9SAndroid Build Coastguard Worker return_inverse=False, 1167*da0073e9SAndroid Build Coastguard Worker return_counts=False, 1168*da0073e9SAndroid Build Coastguard Worker dim=None, 1169*da0073e9SAndroid Build Coastguard Worker): 1170*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor] 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1173*da0073e9SAndroid Build Coastguard Worker return _unique_consecutive_impl(input, return_inverse, return_counts, dim) 1174*da0073e9SAndroid Build Coastguard Worker 1175*da0073e9SAndroid Build Coastguard Worker output, inverse_indices, _ = _unique_consecutive_impl( 1176*da0073e9SAndroid Build Coastguard Worker input, return_inverse, return_counts, dim 1177*da0073e9SAndroid Build Coastguard Worker ) 1178*da0073e9SAndroid Build Coastguard Worker return output, inverse_indices 1179*da0073e9SAndroid Build Coastguard Worker 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker_consecutive_return_inverse_false = boolean_dispatch( 1182*da0073e9SAndroid Build Coastguard Worker arg_name="return_counts", 1183*da0073e9SAndroid Build Coastguard Worker arg_index=1, 1184*da0073e9SAndroid Build Coastguard Worker default=False, 1185*da0073e9SAndroid Build Coastguard Worker if_true=_consecutive_return_counts, 1186*da0073e9SAndroid Build Coastguard Worker if_false=_consecutive_return_output, 1187*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1188*da0073e9SAndroid Build Coastguard Worker func_name="unique_consecutive", 1189*da0073e9SAndroid Build Coastguard Worker) 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker_consecutive_return_inverse_true = boolean_dispatch( 1192*da0073e9SAndroid Build Coastguard Worker arg_name="return_counts", 1193*da0073e9SAndroid Build Coastguard Worker arg_index=1, 1194*da0073e9SAndroid Build Coastguard Worker default=False, 1195*da0073e9SAndroid Build Coastguard Worker if_true=_unique_consecutive_impl, 1196*da0073e9SAndroid Build Coastguard Worker if_false=_consecutive_return_inverse, 1197*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1198*da0073e9SAndroid Build Coastguard Worker func_name="unique_consecutive", 1199*da0073e9SAndroid Build Coastguard Worker) 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker# The return type of unique depends on `return_inverse`, and `return_counts` so in order to 1202*da0073e9SAndroid Build Coastguard Worker# resolve the output type in TorchScript we need to statically know the value of both parameters 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Workerunique_consecutive = boolean_dispatch( 1205*da0073e9SAndroid Build Coastguard Worker arg_name="return_inverse", 1206*da0073e9SAndroid Build Coastguard Worker arg_index=2, 1207*da0073e9SAndroid Build Coastguard Worker default=False, 1208*da0073e9SAndroid Build Coastguard Worker if_true=_consecutive_return_inverse_true, 1209*da0073e9SAndroid Build Coastguard Worker if_false=_consecutive_return_inverse_false, 1210*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 1211*da0073e9SAndroid Build Coastguard Worker func_name="unique_consecutive", 1212*da0073e9SAndroid Build Coastguard Worker) 1213*da0073e9SAndroid Build Coastguard Workerunique_consecutive.__doc__ = _unique_consecutive_impl.__doc__ 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 1216*da0073e9SAndroid Build Coastguard Worker pass 1217*da0073e9SAndroid Build Coastguard Worker # There's no good way to use this type annotation without breaking JIT 1218*da0073e9SAndroid Build Coastguard Worker # overloads. So leave untyped for mypy for now. 1219*da0073e9SAndroid Build Coastguard Workerelse: 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker @overload 1222*da0073e9SAndroid Build Coastguard Worker def tensordot( 1223*da0073e9SAndroid Build Coastguard Worker a, 1224*da0073e9SAndroid Build Coastguard Worker b, 1225*da0073e9SAndroid Build Coastguard Worker dims: int = 2, 1226*da0073e9SAndroid Build Coastguard Worker out: Optional[torch.Tensor] = None, 1227*da0073e9SAndroid Build Coastguard Worker ): 1228*da0073e9SAndroid Build Coastguard Worker pass 1229*da0073e9SAndroid Build Coastguard Worker 1230*da0073e9SAndroid Build Coastguard Worker @overload 1231*da0073e9SAndroid Build Coastguard Worker def tensordot( # noqa: F811 1232*da0073e9SAndroid Build Coastguard Worker a, 1233*da0073e9SAndroid Build Coastguard Worker b, 1234*da0073e9SAndroid Build Coastguard Worker dims: Tuple[List[int], List[int]], 1235*da0073e9SAndroid Build Coastguard Worker out: Optional[torch.Tensor] = None, 1236*da0073e9SAndroid Build Coastguard Worker ): 1237*da0073e9SAndroid Build Coastguard Worker pass 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker @overload 1240*da0073e9SAndroid Build Coastguard Worker def tensordot( # noqa: F811 1241*da0073e9SAndroid Build Coastguard Worker a, 1242*da0073e9SAndroid Build Coastguard Worker b, 1243*da0073e9SAndroid Build Coastguard Worker dims: List[List[int]], 1244*da0073e9SAndroid Build Coastguard Worker out: Optional[torch.Tensor] = None, 1245*da0073e9SAndroid Build Coastguard Worker ): 1246*da0073e9SAndroid Build Coastguard Worker pass 1247*da0073e9SAndroid Build Coastguard Worker 1248*da0073e9SAndroid Build Coastguard Worker @overload 1249*da0073e9SAndroid Build Coastguard Worker def tensordot( # noqa: F811 1250*da0073e9SAndroid Build Coastguard Worker a, 1251*da0073e9SAndroid Build Coastguard Worker b, 1252*da0073e9SAndroid Build Coastguard Worker dims: torch.Tensor, 1253*da0073e9SAndroid Build Coastguard Worker out: Optional[torch.Tensor] = None, 1254*da0073e9SAndroid Build Coastguard Worker ): 1255*da0073e9SAndroid Build Coastguard Worker pass 1256*da0073e9SAndroid Build Coastguard Worker 1257*da0073e9SAndroid Build Coastguard Worker 1258*da0073e9SAndroid Build Coastguard Workerdef tensordot( # noqa: F811 1259*da0073e9SAndroid Build Coastguard Worker a, 1260*da0073e9SAndroid Build Coastguard Worker b, 1261*da0073e9SAndroid Build Coastguard Worker dims=2, 1262*da0073e9SAndroid Build Coastguard Worker out: Optional[torch.Tensor] = None, 1263*da0073e9SAndroid Build Coastguard Worker): 1264*da0073e9SAndroid Build Coastguard Worker r"""Returns a contraction of a and b over multiple dimensions. 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker :attr:`tensordot` implements a generalized matrix product. 1267*da0073e9SAndroid Build Coastguard Worker 1268*da0073e9SAndroid Build Coastguard Worker Args: 1269*da0073e9SAndroid Build Coastguard Worker a (Tensor): Left tensor to contract 1270*da0073e9SAndroid Build Coastguard Worker b (Tensor): Right tensor to contract 1271*da0073e9SAndroid Build Coastguard Worker dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to 1272*da0073e9SAndroid Build Coastguard Worker contract or explicit lists of dimensions for :attr:`a` and 1273*da0073e9SAndroid Build Coastguard Worker :attr:`b` respectively 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker When called with a non-negative integer argument :attr:`dims` = :math:`d`, and 1276*da0073e9SAndroid Build Coastguard Worker the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, 1277*da0073e9SAndroid Build Coastguard Worker respectively, :func:`~torch.tensordot` computes 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker .. math:: 1280*da0073e9SAndroid Build Coastguard Worker r_{i_0,...,i_{m-d}, i_d,...,i_n} 1281*da0073e9SAndroid Build Coastguard Worker = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}. 1282*da0073e9SAndroid Build Coastguard Worker 1283*da0073e9SAndroid Build Coastguard Worker When called with :attr:`dims` of the list form, the given dimensions will be contracted 1284*da0073e9SAndroid Build Coastguard Worker in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes 1285*da0073e9SAndroid Build Coastguard Worker in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted 1286*da0073e9SAndroid Build Coastguard Worker dimensions. 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker Examples:: 1289*da0073e9SAndroid Build Coastguard Worker 1290*da0073e9SAndroid Build Coastguard Worker >>> a = torch.arange(60.).reshape(3, 4, 5) 1291*da0073e9SAndroid Build Coastguard Worker >>> b = torch.arange(24.).reshape(4, 3, 2) 1292*da0073e9SAndroid Build Coastguard Worker >>> torch.tensordot(a, b, dims=([1, 0], [0, 1])) 1293*da0073e9SAndroid Build Coastguard Worker tensor([[4400., 4730.], 1294*da0073e9SAndroid Build Coastguard Worker [4532., 4874.], 1295*da0073e9SAndroid Build Coastguard Worker [4664., 5018.], 1296*da0073e9SAndroid Build Coastguard Worker [4796., 5162.], 1297*da0073e9SAndroid Build Coastguard Worker [4928., 5306.]]) 1298*da0073e9SAndroid Build Coastguard Worker 1299*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) 1300*da0073e9SAndroid Build Coastguard Worker >>> a = torch.randn(3, 4, 5, device='cuda') 1301*da0073e9SAndroid Build Coastguard Worker >>> b = torch.randn(4, 5, 6, device='cuda') 1302*da0073e9SAndroid Build Coastguard Worker >>> c = torch.tensordot(a, b, dims=2).cpu() 1303*da0073e9SAndroid Build Coastguard Worker tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741], 1304*da0073e9SAndroid Build Coastguard Worker [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744], 1305*da0073e9SAndroid Build Coastguard Worker [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]]) 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker >>> a = torch.randn(3, 5, 4, 6) 1308*da0073e9SAndroid Build Coastguard Worker >>> b = torch.randn(6, 4, 5, 3) 1309*da0073e9SAndroid Build Coastguard Worker >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0])) 1310*da0073e9SAndroid Build Coastguard Worker tensor([[ 7.7193, -2.4867, -10.3204], 1311*da0073e9SAndroid Build Coastguard Worker [ 1.5513, -14.4737, -6.5113], 1312*da0073e9SAndroid Build Coastguard Worker [ -0.2850, 4.2573, -3.5997]]) 1313*da0073e9SAndroid Build Coastguard Worker """ 1314*da0073e9SAndroid Build Coastguard Worker if has_torch_function_variadic(a, b): 1315*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out) 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)): 1318*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1319*da0073e9SAndroid Build Coastguard Worker "tensordot expects dims to be int or " 1320*da0073e9SAndroid Build Coastguard Worker + "Tuple[List[int], List[int]] or " 1321*da0073e9SAndroid Build Coastguard Worker + "List[List[int]] containing two lists, but got " 1322*da0073e9SAndroid Build Coastguard Worker + f"dims={dims}" 1323*da0073e9SAndroid Build Coastguard Worker ) 1324*da0073e9SAndroid Build Coastguard Worker 1325*da0073e9SAndroid Build Coastguard Worker dims_a: List[int] = [] 1326*da0073e9SAndroid Build Coastguard Worker dims_b: List[int] = [] 1327*da0073e9SAndroid Build Coastguard Worker 1328*da0073e9SAndroid Build Coastguard Worker if isinstance(dims, (tuple, list)): 1329*da0073e9SAndroid Build Coastguard Worker dims_a, dims_b = dims 1330*da0073e9SAndroid Build Coastguard Worker 1331*da0073e9SAndroid Build Coastguard Worker if isinstance(dims, torch.Tensor): 1332*da0073e9SAndroid Build Coastguard Worker num_elements = dims.numel() 1333*da0073e9SAndroid Build Coastguard Worker if num_elements > 1: 1334*da0073e9SAndroid Build Coastguard Worker assert dims.size()[0] == 2 1335*da0073e9SAndroid Build Coastguard Worker dims_a = torch.jit.annotate(List[int], dims[0].tolist()) 1336*da0073e9SAndroid Build Coastguard Worker dims_b = torch.jit.annotate(List[int], dims[1].tolist()) 1337*da0073e9SAndroid Build Coastguard Worker else: 1338*da0073e9SAndroid Build Coastguard Worker dims_val = int(dims.item()) 1339*da0073e9SAndroid Build Coastguard Worker if dims_val < 0: 1340*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") 1341*da0073e9SAndroid Build Coastguard Worker dims_a = list(range(-dims_val, 0)) 1342*da0073e9SAndroid Build Coastguard Worker dims_b = list(range(dims_val)) 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker if isinstance(dims, (int, torch.SymInt)): 1345*da0073e9SAndroid Build Coastguard Worker if dims < 0: 1346*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}") 1347*da0073e9SAndroid Build Coastguard Worker if dims > min(a.dim(), b.dim()): 1348*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1349*da0073e9SAndroid Build Coastguard Worker f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}" 1350*da0073e9SAndroid Build Coastguard Worker ) 1351*da0073e9SAndroid Build Coastguard Worker dims_a = list(range(-dims, 0)) 1352*da0073e9SAndroid Build Coastguard Worker dims_b = list(range(dims)) 1353*da0073e9SAndroid Build Coastguard Worker 1354*da0073e9SAndroid Build Coastguard Worker if out is None: 1355*da0073e9SAndroid Build Coastguard Worker return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined] 1356*da0073e9SAndroid Build Coastguard Worker else: 1357*da0073e9SAndroid Build Coastguard Worker return _VF.tensordot(a, b, dims_a, dims_b, out=out) # type: ignore[attr-defined] 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker 1360*da0073e9SAndroid Build Coastguard Workerdef cartesian_prod(*tensors: Tensor) -> Tensor: 1361*da0073e9SAndroid Build Coastguard Worker """Do cartesian product of the given sequence of tensors. The behavior is similar to 1362*da0073e9SAndroid Build Coastguard Worker python's `itertools.product`. 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker Args: 1365*da0073e9SAndroid Build Coastguard Worker *tensors: any number of 1 dimensional tensors. 1366*da0073e9SAndroid Build Coastguard Worker 1367*da0073e9SAndroid Build Coastguard Worker Returns: 1368*da0073e9SAndroid Build Coastguard Worker Tensor: A tensor equivalent to converting all the input tensors into lists, 1369*da0073e9SAndroid Build Coastguard Worker do `itertools.product` on these lists, and finally convert the resulting list 1370*da0073e9SAndroid Build Coastguard Worker into tensor. 1371*da0073e9SAndroid Build Coastguard Worker 1372*da0073e9SAndroid Build Coastguard Worker Example:: 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker >>> import itertools 1375*da0073e9SAndroid Build Coastguard Worker >>> a = [1, 2, 3] 1376*da0073e9SAndroid Build Coastguard Worker >>> b = [4, 5] 1377*da0073e9SAndroid Build Coastguard Worker >>> list(itertools.product(a, b)) 1378*da0073e9SAndroid Build Coastguard Worker [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] 1379*da0073e9SAndroid Build Coastguard Worker >>> tensor_a = torch.tensor(a) 1380*da0073e9SAndroid Build Coastguard Worker >>> tensor_b = torch.tensor(b) 1381*da0073e9SAndroid Build Coastguard Worker >>> torch.cartesian_prod(tensor_a, tensor_b) 1382*da0073e9SAndroid Build Coastguard Worker tensor([[1, 4], 1383*da0073e9SAndroid Build Coastguard Worker [1, 5], 1384*da0073e9SAndroid Build Coastguard Worker [2, 4], 1385*da0073e9SAndroid Build Coastguard Worker [2, 5], 1386*da0073e9SAndroid Build Coastguard Worker [3, 4], 1387*da0073e9SAndroid Build Coastguard Worker [3, 5]]) 1388*da0073e9SAndroid Build Coastguard Worker """ 1389*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 1390*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 1391*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(cartesian_prod, tensors, *tensors) 1392*da0073e9SAndroid Build Coastguard Worker return _VF.cartesian_prod(tensors) # type: ignore[attr-defined] 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Workerdef block_diag(*tensors): 1396*da0073e9SAndroid Build Coastguard Worker """Create a block diagonal matrix from provided tensors. 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker Args: 1399*da0073e9SAndroid Build Coastguard Worker *tensors: One or more tensors with 0, 1, or 2 dimensions. 1400*da0073e9SAndroid Build Coastguard Worker 1401*da0073e9SAndroid Build Coastguard Worker Returns: 1402*da0073e9SAndroid Build Coastguard Worker Tensor: A 2 dimensional tensor with all the input tensors arranged in 1403*da0073e9SAndroid Build Coastguard Worker order such that their upper left and lower right corners are 1404*da0073e9SAndroid Build Coastguard Worker diagonally adjacent. All other elements are set to 0. 1405*da0073e9SAndroid Build Coastguard Worker 1406*da0073e9SAndroid Build Coastguard Worker Example:: 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker >>> import torch 1409*da0073e9SAndroid Build Coastguard Worker >>> A = torch.tensor([[0, 1], [1, 0]]) 1410*da0073e9SAndroid Build Coastguard Worker >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]]) 1411*da0073e9SAndroid Build Coastguard Worker >>> C = torch.tensor(7) 1412*da0073e9SAndroid Build Coastguard Worker >>> D = torch.tensor([1, 2, 3]) 1413*da0073e9SAndroid Build Coastguard Worker >>> E = torch.tensor([[4], [5], [6]]) 1414*da0073e9SAndroid Build Coastguard Worker >>> torch.block_diag(A, B, C, D, E) 1415*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 1416*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 1417*da0073e9SAndroid Build Coastguard Worker [0, 0, 3, 4, 5, 0, 0, 0, 0, 0], 1418*da0073e9SAndroid Build Coastguard Worker [0, 0, 6, 7, 8, 0, 0, 0, 0, 0], 1419*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0, 7, 0, 0, 0, 0], 1420*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0, 0, 1, 2, 3, 0], 1421*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0, 0, 0, 0, 0, 4], 1422*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0, 0, 0, 0, 0, 5], 1423*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]]) 1424*da0073e9SAndroid Build Coastguard Worker """ 1425*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 1426*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 1427*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(block_diag, tensors, *tensors) 1428*da0073e9SAndroid Build Coastguard Worker return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined] 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Workerdef cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"): 1432*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor, float, str) -> (Tensor) 1433*da0073e9SAndroid Build Coastguard Worker r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker Args: 1436*da0073e9SAndroid Build Coastguard Worker x1 (Tensor): input tensor of shape :math:`B \times P \times M`. 1437*da0073e9SAndroid Build Coastguard Worker x2 (Tensor): input tensor of shape :math:`B \times R \times M`. 1438*da0073e9SAndroid Build Coastguard Worker p: p value for the p-norm distance to calculate between each vector pair 1439*da0073e9SAndroid Build Coastguard Worker :math:`\in [0, \infty]`. 1440*da0073e9SAndroid Build Coastguard Worker compute_mode: 1441*da0073e9SAndroid Build Coastguard Worker 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate 1442*da0073e9SAndroid Build Coastguard Worker euclidean distance (p = 2) if P > 25 or R > 25 1443*da0073e9SAndroid Build Coastguard Worker 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate 1444*da0073e9SAndroid Build Coastguard Worker euclidean distance (p = 2) 1445*da0073e9SAndroid Build Coastguard Worker 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate 1446*da0073e9SAndroid Build Coastguard Worker euclidean distance (p = 2) 1447*da0073e9SAndroid Build Coastguard Worker Default: use_mm_for_euclid_dist_if_necessary. 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the 1450*da0073e9SAndroid Build Coastguard Worker output will have shape :math:`B \times P \times R`. 1451*da0073e9SAndroid Build Coastguard Worker 1452*da0073e9SAndroid Build Coastguard Worker This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` 1453*da0073e9SAndroid Build Coastguard Worker if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to 1454*da0073e9SAndroid Build Coastguard Worker `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest 1455*da0073e9SAndroid Build Coastguard Worker scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`. 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker Example: 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]]) 1460*da0073e9SAndroid Build Coastguard Worker >>> a 1461*da0073e9SAndroid Build Coastguard Worker tensor([[ 0.9041, 0.0196], 1462*da0073e9SAndroid Build Coastguard Worker [-0.3108, -2.4423], 1463*da0073e9SAndroid Build Coastguard Worker [-0.4821, 1.0590]]) 1464*da0073e9SAndroid Build Coastguard Worker >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]]) 1465*da0073e9SAndroid Build Coastguard Worker >>> b 1466*da0073e9SAndroid Build Coastguard Worker tensor([[-2.1763, -0.4713], 1467*da0073e9SAndroid Build Coastguard Worker [-0.6986, 1.3702]]) 1468*da0073e9SAndroid Build Coastguard Worker >>> torch.cdist(a, b, p=2) 1469*da0073e9SAndroid Build Coastguard Worker tensor([[3.1193, 2.0959], 1470*da0073e9SAndroid Build Coastguard Worker [2.7138, 3.8322], 1471*da0073e9SAndroid Build Coastguard Worker [2.2830, 0.3791]]) 1472*da0073e9SAndroid Build Coastguard Worker """ 1473*da0073e9SAndroid Build Coastguard Worker if has_torch_function_variadic(x1, x2): 1474*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 1475*da0073e9SAndroid Build Coastguard Worker cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode 1476*da0073e9SAndroid Build Coastguard Worker ) 1477*da0073e9SAndroid Build Coastguard Worker if compute_mode == "use_mm_for_euclid_dist_if_necessary": 1478*da0073e9SAndroid Build Coastguard Worker return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined] 1479*da0073e9SAndroid Build Coastguard Worker elif compute_mode == "use_mm_for_euclid_dist": 1480*da0073e9SAndroid Build Coastguard Worker return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined] 1481*da0073e9SAndroid Build Coastguard Worker elif compute_mode == "donot_use_mm_for_euclid_dist": 1482*da0073e9SAndroid Build Coastguard Worker return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined] 1483*da0073e9SAndroid Build Coastguard Worker else: 1484*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"{compute_mode} is not a valid value for compute_mode") 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker 1487*da0073e9SAndroid Build Coastguard Workerdef atleast_1d(*tensors): 1488*da0073e9SAndroid Build Coastguard Worker r""" 1489*da0073e9SAndroid Build Coastguard Worker Returns a 1-dimensional view of each input tensor with zero dimensions. 1490*da0073e9SAndroid Build Coastguard Worker Input tensors with one or more dimensions are returned as-is. 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker Args: 1493*da0073e9SAndroid Build Coastguard Worker input (Tensor or list of Tensors) 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker Returns: 1496*da0073e9SAndroid Build Coastguard Worker output (Tensor or tuple of Tensors) 1497*da0073e9SAndroid Build Coastguard Worker 1498*da0073e9SAndroid Build Coastguard Worker Example:: 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker >>> x = torch.arange(2) 1501*da0073e9SAndroid Build Coastguard Worker >>> x 1502*da0073e9SAndroid Build Coastguard Worker tensor([0, 1]) 1503*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_1d(x) 1504*da0073e9SAndroid Build Coastguard Worker tensor([0, 1]) 1505*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(1.) 1506*da0073e9SAndroid Build Coastguard Worker >>> x 1507*da0073e9SAndroid Build Coastguard Worker tensor(1.) 1508*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_1d(x) 1509*da0073e9SAndroid Build Coastguard Worker tensor([1.]) 1510*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(0.5) 1511*da0073e9SAndroid Build Coastguard Worker >>> y = torch.tensor(1.) 1512*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_1d((x, y)) 1513*da0073e9SAndroid Build Coastguard Worker (tensor([0.5000]), tensor([1.])) 1514*da0073e9SAndroid Build Coastguard Worker """ 1515*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 1516*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 1517*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(atleast_1d, tensors, *tensors) 1518*da0073e9SAndroid Build Coastguard Worker if len(tensors) == 1: 1519*da0073e9SAndroid Build Coastguard Worker tensors = tensors[0] 1520*da0073e9SAndroid Build Coastguard Worker return _VF.atleast_1d(tensors) # type: ignore[attr-defined] 1521*da0073e9SAndroid Build Coastguard Worker 1522*da0073e9SAndroid Build Coastguard Worker 1523*da0073e9SAndroid Build Coastguard Workerdef atleast_2d(*tensors): 1524*da0073e9SAndroid Build Coastguard Worker r""" 1525*da0073e9SAndroid Build Coastguard Worker Returns a 2-dimensional view of each input tensor with zero dimensions. 1526*da0073e9SAndroid Build Coastguard Worker Input tensors with two or more dimensions are returned as-is. 1527*da0073e9SAndroid Build Coastguard Worker 1528*da0073e9SAndroid Build Coastguard Worker Args: 1529*da0073e9SAndroid Build Coastguard Worker input (Tensor or list of Tensors) 1530*da0073e9SAndroid Build Coastguard Worker 1531*da0073e9SAndroid Build Coastguard Worker Returns: 1532*da0073e9SAndroid Build Coastguard Worker output (Tensor or tuple of Tensors) 1533*da0073e9SAndroid Build Coastguard Worker 1534*da0073e9SAndroid Build Coastguard Worker Example:: 1535*da0073e9SAndroid Build Coastguard Worker 1536*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(1.) 1537*da0073e9SAndroid Build Coastguard Worker >>> x 1538*da0073e9SAndroid Build Coastguard Worker tensor(1.) 1539*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_2d(x) 1540*da0073e9SAndroid Build Coastguard Worker tensor([[1.]]) 1541*da0073e9SAndroid Build Coastguard Worker >>> x = torch.arange(4).view(2, 2) 1542*da0073e9SAndroid Build Coastguard Worker >>> x 1543*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1], 1544*da0073e9SAndroid Build Coastguard Worker [2, 3]]) 1545*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_2d(x) 1546*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1], 1547*da0073e9SAndroid Build Coastguard Worker [2, 3]]) 1548*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(0.5) 1549*da0073e9SAndroid Build Coastguard Worker >>> y = torch.tensor(1.) 1550*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_2d((x, y)) 1551*da0073e9SAndroid Build Coastguard Worker (tensor([[0.5000]]), tensor([[1.]])) 1552*da0073e9SAndroid Build Coastguard Worker """ 1553*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 1554*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 1555*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(atleast_2d, tensors, *tensors) 1556*da0073e9SAndroid Build Coastguard Worker if len(tensors) == 1: 1557*da0073e9SAndroid Build Coastguard Worker tensors = tensors[0] 1558*da0073e9SAndroid Build Coastguard Worker return _VF.atleast_2d(tensors) # type: ignore[attr-defined] 1559*da0073e9SAndroid Build Coastguard Worker 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Workerdef atleast_3d(*tensors): 1562*da0073e9SAndroid Build Coastguard Worker r""" 1563*da0073e9SAndroid Build Coastguard Worker Returns a 3-dimensional view of each input tensor with zero dimensions. 1564*da0073e9SAndroid Build Coastguard Worker Input tensors with three or more dimensions are returned as-is. 1565*da0073e9SAndroid Build Coastguard Worker 1566*da0073e9SAndroid Build Coastguard Worker Args: 1567*da0073e9SAndroid Build Coastguard Worker input (Tensor or list of Tensors) 1568*da0073e9SAndroid Build Coastguard Worker 1569*da0073e9SAndroid Build Coastguard Worker Returns: 1570*da0073e9SAndroid Build Coastguard Worker output (Tensor or tuple of Tensors) 1571*da0073e9SAndroid Build Coastguard Worker 1572*da0073e9SAndroid Build Coastguard Worker Example: 1573*da0073e9SAndroid Build Coastguard Worker 1574*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(0.5) 1575*da0073e9SAndroid Build Coastguard Worker >>> x 1576*da0073e9SAndroid Build Coastguard Worker tensor(0.5000) 1577*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_3d(x) 1578*da0073e9SAndroid Build Coastguard Worker tensor([[[0.5000]]]) 1579*da0073e9SAndroid Build Coastguard Worker >>> y = torch.arange(4).view(2, 2) 1580*da0073e9SAndroid Build Coastguard Worker >>> y 1581*da0073e9SAndroid Build Coastguard Worker tensor([[0, 1], 1582*da0073e9SAndroid Build Coastguard Worker [2, 3]]) 1583*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_3d(y) 1584*da0073e9SAndroid Build Coastguard Worker tensor([[[0], 1585*da0073e9SAndroid Build Coastguard Worker [1]], 1586*da0073e9SAndroid Build Coastguard Worker <BLANKLINE> 1587*da0073e9SAndroid Build Coastguard Worker [[2], 1588*da0073e9SAndroid Build Coastguard Worker [3]]]) 1589*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(1).view(1, 1, 1) 1590*da0073e9SAndroid Build Coastguard Worker >>> x 1591*da0073e9SAndroid Build Coastguard Worker tensor([[[1]]]) 1592*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_3d(x) 1593*da0073e9SAndroid Build Coastguard Worker tensor([[[1]]]) 1594*da0073e9SAndroid Build Coastguard Worker >>> x = torch.tensor(0.5) 1595*da0073e9SAndroid Build Coastguard Worker >>> y = torch.tensor(1.0) 1596*da0073e9SAndroid Build Coastguard Worker >>> torch.atleast_3d((x, y)) 1597*da0073e9SAndroid Build Coastguard Worker (tensor([[[0.5000]]]), tensor([[[1.]]])) 1598*da0073e9SAndroid Build Coastguard Worker """ 1599*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 1600*da0073e9SAndroid Build Coastguard Worker if has_torch_function(tensors): 1601*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(atleast_3d, tensors, *tensors) 1602*da0073e9SAndroid Build Coastguard Worker if len(tensors) == 1: 1603*da0073e9SAndroid Build Coastguard Worker tensors = tensors[0] 1604*da0073e9SAndroid Build Coastguard Worker return _VF.atleast_3d(tensors) # type: ignore[attr-defined] 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker 1607*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 1608*da0073e9SAndroid Build Coastguard Worker pass 1609*da0073e9SAndroid Build Coastguard Worker # There's no good way to use this type annotation; cannot rename norm() to 1610*da0073e9SAndroid Build Coastguard Worker # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped 1611*da0073e9SAndroid Build Coastguard Worker # for mypy for now. 1612*da0073e9SAndroid Build Coastguard Worker # def norm(input: Tensor, 1613*da0073e9SAndroid Build Coastguard Worker # p: Optional[Union[str, Number]] = "fro", 1614*da0073e9SAndroid Build Coastguard Worker # dim: Optional[Union[int, List[int]]] = None, 1615*da0073e9SAndroid Build Coastguard Worker # keepdim: bool = False, 1616*da0073e9SAndroid Build Coastguard Worker # out: Optional[Tensor] = None, 1617*da0073e9SAndroid Build Coastguard Worker # dtype: _dtype = None) -> Tensor: 1618*da0073e9SAndroid Build Coastguard Worker # return _norm_impl(input, p, dim, keepdim, out, dtype) 1619*da0073e9SAndroid Build Coastguard Workerelse: 1620*da0073e9SAndroid Build Coastguard Worker # TODO: type dim as BroadcastingList when 1621*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/33782 is fixed 1622*da0073e9SAndroid Build Coastguard Worker @overload 1623*da0073e9SAndroid Build Coastguard Worker def norm( 1624*da0073e9SAndroid Build Coastguard Worker input, 1625*da0073e9SAndroid Build Coastguard Worker p="fro", 1626*da0073e9SAndroid Build Coastguard Worker dim=None, 1627*da0073e9SAndroid Build Coastguard Worker keepdim=False, 1628*da0073e9SAndroid Build Coastguard Worker out=None, 1629*da0073e9SAndroid Build Coastguard Worker dtype=None, 1630*da0073e9SAndroid Build Coastguard Worker ): 1631*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor 1632*da0073e9SAndroid Build Coastguard Worker pass 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker @overload 1635*da0073e9SAndroid Build Coastguard Worker def norm( # noqa: F811 1636*da0073e9SAndroid Build Coastguard Worker input, 1637*da0073e9SAndroid Build Coastguard Worker p="fro", 1638*da0073e9SAndroid Build Coastguard Worker dim=None, 1639*da0073e9SAndroid Build Coastguard Worker keepdim=False, 1640*da0073e9SAndroid Build Coastguard Worker out=None, 1641*da0073e9SAndroid Build Coastguard Worker dtype=None, 1642*da0073e9SAndroid Build Coastguard Worker ): 1643*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor 1644*da0073e9SAndroid Build Coastguard Worker pass 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker @overload 1647*da0073e9SAndroid Build Coastguard Worker def norm( # noqa: F811 1648*da0073e9SAndroid Build Coastguard Worker input, 1649*da0073e9SAndroid Build Coastguard Worker p="fro", 1650*da0073e9SAndroid Build Coastguard Worker dim=None, 1651*da0073e9SAndroid Build Coastguard Worker keepdim=False, 1652*da0073e9SAndroid Build Coastguard Worker out=None, 1653*da0073e9SAndroid Build Coastguard Worker dtype=None, 1654*da0073e9SAndroid Build Coastguard Worker ): 1655*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor 1656*da0073e9SAndroid Build Coastguard Worker pass 1657*da0073e9SAndroid Build Coastguard Worker 1658*da0073e9SAndroid Build Coastguard Worker @overload 1659*da0073e9SAndroid Build Coastguard Worker def norm( # noqa: F811 1660*da0073e9SAndroid Build Coastguard Worker input, 1661*da0073e9SAndroid Build Coastguard Worker p="fro", 1662*da0073e9SAndroid Build Coastguard Worker dim=None, 1663*da0073e9SAndroid Build Coastguard Worker keepdim=False, 1664*da0073e9SAndroid Build Coastguard Worker out=None, 1665*da0073e9SAndroid Build Coastguard Worker dtype=None, 1666*da0073e9SAndroid Build Coastguard Worker ): 1667*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor 1668*da0073e9SAndroid Build Coastguard Worker pass 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Workerdef norm( # noqa: F811 1672*da0073e9SAndroid Build Coastguard Worker input, 1673*da0073e9SAndroid Build Coastguard Worker p: Optional[Union[float, str]] = "fro", 1674*da0073e9SAndroid Build Coastguard Worker dim=None, 1675*da0073e9SAndroid Build Coastguard Worker keepdim=False, 1676*da0073e9SAndroid Build Coastguard Worker out=None, 1677*da0073e9SAndroid Build Coastguard Worker dtype=None, 1678*da0073e9SAndroid Build Coastguard Worker): 1679*da0073e9SAndroid Build Coastguard Worker r"""Returns the matrix norm or vector norm of a given tensor. 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker .. warning:: 1682*da0073e9SAndroid Build Coastguard Worker 1683*da0073e9SAndroid Build Coastguard Worker torch.norm is deprecated and may be removed in a future PyTorch release. 1684*da0073e9SAndroid Build Coastguard Worker Its documentation and behavior may be incorrect, and it is no longer 1685*da0073e9SAndroid Build Coastguard Worker actively maintained. 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker Use :func:`torch.linalg.vector_norm` when computing vector norms and 1688*da0073e9SAndroid Build Coastguard Worker :func:`torch.linalg.matrix_norm` when computing matrix norms. 1689*da0073e9SAndroid Build Coastguard Worker For a function with a similar behavior as this one see :func:`torch.linalg.norm`. 1690*da0073e9SAndroid Build Coastguard Worker Note, however, the signature for these functions is slightly different than the 1691*da0073e9SAndroid Build Coastguard Worker signature for ``torch.norm``. 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker Args: 1694*da0073e9SAndroid Build Coastguard Worker input (Tensor): The input tensor. Its data type must be either a floating 1695*da0073e9SAndroid Build Coastguard Worker point or complex type. For complex inputs, the norm is calculated using the 1696*da0073e9SAndroid Build Coastguard Worker absolute value of each element. If the input is complex and neither 1697*da0073e9SAndroid Build Coastguard Worker :attr:`dtype` nor :attr:`out` is specified, the result's data type will 1698*da0073e9SAndroid Build Coastguard Worker be the corresponding floating point type (e.g. float if :attr:`input` is 1699*da0073e9SAndroid Build Coastguard Worker complexfloat). 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'`` 1702*da0073e9SAndroid Build Coastguard Worker The following norms can be calculated: 1703*da0073e9SAndroid Build Coastguard Worker 1704*da0073e9SAndroid Build Coastguard Worker ====== ============== ========================== 1705*da0073e9SAndroid Build Coastguard Worker ord matrix norm vector norm 1706*da0073e9SAndroid Build Coastguard Worker ====== ============== ========================== 1707*da0073e9SAndroid Build Coastguard Worker 'fro' Frobenius norm -- 1708*da0073e9SAndroid Build Coastguard Worker 'nuc' nuclear norm -- 1709*da0073e9SAndroid Build Coastguard Worker Number -- sum(abs(x)**ord)**(1./ord) 1710*da0073e9SAndroid Build Coastguard Worker ====== ============== ========================== 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker The vector norm can be calculated across any number of dimensions. 1713*da0073e9SAndroid Build Coastguard Worker The corresponding dimensions of :attr:`input` are flattened into 1714*da0073e9SAndroid Build Coastguard Worker one dimension, and the norm is calculated on the flattened 1715*da0073e9SAndroid Build Coastguard Worker dimension. 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker Frobenius norm produces the same result as ``p=2`` in all cases 1718*da0073e9SAndroid Build Coastguard Worker except when :attr:`dim` is a list of three or more dims, in which 1719*da0073e9SAndroid Build Coastguard Worker case Frobenius norm throws an error. 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build Coastguard Worker Nuclear norm can only be calculated across exactly two dimensions. 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker dim (int, tuple of ints, list of ints, optional): 1724*da0073e9SAndroid Build Coastguard Worker Specifies which dimension or dimensions of :attr:`input` to 1725*da0073e9SAndroid Build Coastguard Worker calculate the norm across. If :attr:`dim` is ``None``, the norm will 1726*da0073e9SAndroid Build Coastguard Worker be calculated across all dimensions of :attr:`input`. If the norm 1727*da0073e9SAndroid Build Coastguard Worker type indicated by :attr:`p` does not support the specified number of 1728*da0073e9SAndroid Build Coastguard Worker dimensions, an error will occur. 1729*da0073e9SAndroid Build Coastguard Worker keepdim (bool, optional): whether the output tensors have :attr:`dim` 1730*da0073e9SAndroid Build Coastguard Worker retained or not. Ignored if :attr:`dim` = ``None`` and 1731*da0073e9SAndroid Build Coastguard Worker :attr:`out` = ``None``. Default: ``False`` 1732*da0073e9SAndroid Build Coastguard Worker out (Tensor, optional): the output tensor. Ignored if 1733*da0073e9SAndroid Build Coastguard Worker :attr:`dim` = ``None`` and :attr:`out` = ``None``. 1734*da0073e9SAndroid Build Coastguard Worker dtype (:class:`torch.dtype`, optional): the desired data type of 1735*da0073e9SAndroid Build Coastguard Worker returned tensor. If specified, the input tensor is casted to 1736*da0073e9SAndroid Build Coastguard Worker :attr:`dtype` while performing the operation. Default: None. 1737*da0073e9SAndroid Build Coastguard Worker 1738*da0073e9SAndroid Build Coastguard Worker .. note:: 1739*da0073e9SAndroid Build Coastguard Worker Even though ``p='fro'`` supports any number of dimensions, the true 1740*da0073e9SAndroid Build Coastguard Worker mathematical definition of Frobenius norm only applies to tensors with 1741*da0073e9SAndroid Build Coastguard Worker exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'`` 1742*da0073e9SAndroid Build Coastguard Worker aligns with the mathematical definition, since it can only be applied across 1743*da0073e9SAndroid Build Coastguard Worker exactly two dimensions. 1744*da0073e9SAndroid Build Coastguard Worker 1745*da0073e9SAndroid Build Coastguard Worker Example:: 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker >>> import torch 1748*da0073e9SAndroid Build Coastguard Worker >>> a = torch.arange(9, dtype= torch.float) - 4 1749*da0073e9SAndroid Build Coastguard Worker >>> b = a.reshape((3, 3)) 1750*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(a) 1751*da0073e9SAndroid Build Coastguard Worker tensor(7.7460) 1752*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(b) 1753*da0073e9SAndroid Build Coastguard Worker tensor(7.7460) 1754*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(a, float('inf')) 1755*da0073e9SAndroid Build Coastguard Worker tensor(4.) 1756*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(b, float('inf')) 1757*da0073e9SAndroid Build Coastguard Worker tensor(4.) 1758*da0073e9SAndroid Build Coastguard Worker >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float) 1759*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(c, dim=0) 1760*da0073e9SAndroid Build Coastguard Worker tensor([1.4142, 2.2361, 5.0000]) 1761*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(c, dim=1) 1762*da0073e9SAndroid Build Coastguard Worker tensor([3.7417, 4.2426]) 1763*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(c, p=1, dim=1) 1764*da0073e9SAndroid Build Coastguard Worker tensor([6., 6.]) 1765*da0073e9SAndroid Build Coastguard Worker >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) 1766*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(d, dim=(1, 2)) 1767*da0073e9SAndroid Build Coastguard Worker tensor([ 3.7417, 11.2250]) 1768*da0073e9SAndroid Build Coastguard Worker >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) 1769*da0073e9SAndroid Build Coastguard Worker (tensor(3.7417), tensor(11.2250)) 1770*da0073e9SAndroid Build Coastguard Worker """ 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(input): 1773*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 1774*da0073e9SAndroid Build Coastguard Worker norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype 1775*da0073e9SAndroid Build Coastguard Worker ) 1776*da0073e9SAndroid Build Coastguard Worker 1777*da0073e9SAndroid Build Coastguard Worker # NB. All the repeated code and weird python is to please TorchScript. 1778*da0073e9SAndroid Build Coastguard Worker # For a more compact implementation see the relevant function in `_refs/__init__.py` 1779*da0073e9SAndroid Build Coastguard Worker 1780*da0073e9SAndroid Build Coastguard Worker # We don't do this for MPS or sparse tensors 1781*da0073e9SAndroid Build Coastguard Worker if input.layout == torch.strided and input.device.type in ( 1782*da0073e9SAndroid Build Coastguard Worker "cpu", 1783*da0073e9SAndroid Build Coastguard Worker "cuda", 1784*da0073e9SAndroid Build Coastguard Worker "meta", 1785*da0073e9SAndroid Build Coastguard Worker torch.utils.backend_registration._privateuse1_backend_name, 1786*da0073e9SAndroid Build Coastguard Worker ): 1787*da0073e9SAndroid Build Coastguard Worker if dim is not None: 1788*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, (int, torch.SymInt)): 1789*da0073e9SAndroid Build Coastguard Worker _dim = [dim] 1790*da0073e9SAndroid Build Coastguard Worker else: 1791*da0073e9SAndroid Build Coastguard Worker _dim = dim 1792*da0073e9SAndroid Build Coastguard Worker else: 1793*da0073e9SAndroid Build Coastguard Worker _dim = None # type: ignore[assignment] 1794*da0073e9SAndroid Build Coastguard Worker 1795*da0073e9SAndroid Build Coastguard Worker if isinstance(p, str): 1796*da0073e9SAndroid Build Coastguard Worker if p == "fro" and ( 1797*da0073e9SAndroid Build Coastguard Worker dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2 1798*da0073e9SAndroid Build Coastguard Worker ): 1799*da0073e9SAndroid Build Coastguard Worker if out is None: 1800*da0073e9SAndroid Build Coastguard Worker return torch.linalg.vector_norm( 1801*da0073e9SAndroid Build Coastguard Worker input, 2, _dim, keepdim, dtype=dtype 1802*da0073e9SAndroid Build Coastguard Worker ) 1803*da0073e9SAndroid Build Coastguard Worker else: 1804*da0073e9SAndroid Build Coastguard Worker return torch.linalg.vector_norm( 1805*da0073e9SAndroid Build Coastguard Worker input, 2, _dim, keepdim, dtype=dtype, out=out 1806*da0073e9SAndroid Build Coastguard Worker ) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker # Here we either call the nuclear norm, or we call matrix_norm with some arguments 1809*da0073e9SAndroid Build Coastguard Worker # that will throw an error 1810*da0073e9SAndroid Build Coastguard Worker if _dim is None: 1811*da0073e9SAndroid Build Coastguard Worker _dim = list(range(input.ndim)) 1812*da0073e9SAndroid Build Coastguard Worker if out is None: 1813*da0073e9SAndroid Build Coastguard Worker return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype) 1814*da0073e9SAndroid Build Coastguard Worker else: 1815*da0073e9SAndroid Build Coastguard Worker return torch.linalg.matrix_norm( 1816*da0073e9SAndroid Build Coastguard Worker input, p, _dim, keepdim, dtype=dtype, out=out 1817*da0073e9SAndroid Build Coastguard Worker ) 1818*da0073e9SAndroid Build Coastguard Worker else: 1819*da0073e9SAndroid Build Coastguard Worker # NB. p should be Union[str, number], not Optional! 1820*da0073e9SAndroid Build Coastguard Worker _p = 2.0 if p is None else p 1821*da0073e9SAndroid Build Coastguard Worker if out is None: 1822*da0073e9SAndroid Build Coastguard Worker return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype) 1823*da0073e9SAndroid Build Coastguard Worker else: 1824*da0073e9SAndroid Build Coastguard Worker return torch.linalg.vector_norm( 1825*da0073e9SAndroid Build Coastguard Worker input, _p, _dim, keepdim, dtype=dtype, out=out 1826*da0073e9SAndroid Build Coastguard Worker ) 1827*da0073e9SAndroid Build Coastguard Worker 1828*da0073e9SAndroid Build Coastguard Worker ndim = input.dim() 1829*da0073e9SAndroid Build Coastguard Worker 1830*da0073e9SAndroid Build Coastguard Worker # catch default case 1831*da0073e9SAndroid Build Coastguard Worker if dim is None and out is None and dtype is None and p is not None: 1832*da0073e9SAndroid Build Coastguard Worker if isinstance(p, str): 1833*da0073e9SAndroid Build Coastguard Worker if p == "fro": 1834*da0073e9SAndroid Build Coastguard Worker return _VF.frobenius_norm(input, dim=(), keepdim=keepdim) 1835*da0073e9SAndroid Build Coastguard Worker if not isinstance(p, str): 1836*da0073e9SAndroid Build Coastguard Worker _dim = list(range(ndim)) 1837*da0073e9SAndroid Build Coastguard Worker return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined] 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed 1840*da0073e9SAndroid Build Coastguard Worker # remove the overloads where dim is an int and replace with BraodcastingList1 1841*da0073e9SAndroid Build Coastguard Worker # and remove next four lines, replace _dim with dim 1842*da0073e9SAndroid Build Coastguard Worker if dim is not None: 1843*da0073e9SAndroid Build Coastguard Worker if isinstance(dim, (int, torch.SymInt)): 1844*da0073e9SAndroid Build Coastguard Worker _dim = [dim] 1845*da0073e9SAndroid Build Coastguard Worker else: 1846*da0073e9SAndroid Build Coastguard Worker _dim = dim 1847*da0073e9SAndroid Build Coastguard Worker else: 1848*da0073e9SAndroid Build Coastguard Worker _dim = None # type: ignore[assignment] 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker if isinstance(p, str): 1851*da0073e9SAndroid Build Coastguard Worker if p == "fro": 1852*da0073e9SAndroid Build Coastguard Worker if dtype is not None: 1853*da0073e9SAndroid Build Coastguard Worker raise ValueError("dtype argument is not supported in frobenius norm") 1854*da0073e9SAndroid Build Coastguard Worker 1855*da0073e9SAndroid Build Coastguard Worker if _dim is None: 1856*da0073e9SAndroid Build Coastguard Worker _dim = list(range(ndim)) 1857*da0073e9SAndroid Build Coastguard Worker if out is None: 1858*da0073e9SAndroid Build Coastguard Worker return _VF.frobenius_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] 1859*da0073e9SAndroid Build Coastguard Worker else: 1860*da0073e9SAndroid Build Coastguard Worker return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] 1861*da0073e9SAndroid Build Coastguard Worker elif p == "nuc": 1862*da0073e9SAndroid Build Coastguard Worker if dtype is not None: 1863*da0073e9SAndroid Build Coastguard Worker raise ValueError("dtype argument is not supported in nuclear norm") 1864*da0073e9SAndroid Build Coastguard Worker if _dim is None: 1865*da0073e9SAndroid Build Coastguard Worker if out is None: 1866*da0073e9SAndroid Build Coastguard Worker return _VF.nuclear_norm(input, keepdim=keepdim) # type: ignore[arg-type] 1867*da0073e9SAndroid Build Coastguard Worker else: 1868*da0073e9SAndroid Build Coastguard Worker return _VF.nuclear_norm(input, keepdim=keepdim, out=out) # type: ignore[arg-type] 1869*da0073e9SAndroid Build Coastguard Worker else: 1870*da0073e9SAndroid Build Coastguard Worker if out is None: 1871*da0073e9SAndroid Build Coastguard Worker return _VF.nuclear_norm(input, _dim, keepdim=keepdim) # type: ignore[arg-type] 1872*da0073e9SAndroid Build Coastguard Worker else: 1873*da0073e9SAndroid Build Coastguard Worker return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out) # type: ignore[arg-type] 1874*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}") 1875*da0073e9SAndroid Build Coastguard Worker else: 1876*da0073e9SAndroid Build Coastguard Worker if _dim is None: 1877*da0073e9SAndroid Build Coastguard Worker _dim = list(range(ndim)) 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker if out is None: 1880*da0073e9SAndroid Build Coastguard Worker if dtype is None: 1881*da0073e9SAndroid Build Coastguard Worker return _VF.norm(input, p, _dim, keepdim=keepdim) # type: ignore[attr-defined] 1882*da0073e9SAndroid Build Coastguard Worker else: 1883*da0073e9SAndroid Build Coastguard Worker return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype) # type: ignore[attr-defined] 1884*da0073e9SAndroid Build Coastguard Worker else: 1885*da0073e9SAndroid Build Coastguard Worker if dtype is None: 1886*da0073e9SAndroid Build Coastguard Worker return _VF.norm(input, p, _dim, keepdim=keepdim, out=out) # type: ignore[attr-defined] 1887*da0073e9SAndroid Build Coastguard Worker else: 1888*da0073e9SAndroid Build Coastguard Worker return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined] 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker 1891*da0073e9SAndroid Build Coastguard Workerdef unravel_index( 1892*da0073e9SAndroid Build Coastguard Worker indices: Tensor, 1893*da0073e9SAndroid Build Coastguard Worker shape: Union[int, Sequence[int], torch.Size], 1894*da0073e9SAndroid Build Coastguard Worker) -> Tuple[Tensor, ...]: 1895*da0073e9SAndroid Build Coastguard Worker r"""Converts a tensor of flat indices into a tuple of coordinate tensors that 1896*da0073e9SAndroid Build Coastguard Worker index into an arbitrary tensor of the specified shape. 1897*da0073e9SAndroid Build Coastguard Worker 1898*da0073e9SAndroid Build Coastguard Worker Args: 1899*da0073e9SAndroid Build Coastguard Worker indices (Tensor): An integer tensor containing indices into the 1900*da0073e9SAndroid Build Coastguard Worker flattened version of an arbitrary tensor of shape :attr:`shape`. 1901*da0073e9SAndroid Build Coastguard Worker All elements must be in the range ``[0, prod(shape) - 1]``. 1902*da0073e9SAndroid Build Coastguard Worker 1903*da0073e9SAndroid Build Coastguard Worker shape (int, sequence of ints, or torch.Size): The shape of the arbitrary 1904*da0073e9SAndroid Build Coastguard Worker tensor. All elements must be non-negative. 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker Returns: 1907*da0073e9SAndroid Build Coastguard Worker tuple of Tensors: Each ``i``-th tensor in the output corresponds with 1908*da0073e9SAndroid Build Coastguard Worker dimension ``i`` of :attr:`shape`. Each tensor has the same shape as 1909*da0073e9SAndroid Build Coastguard Worker ``indices`` and contains one index into dimension ``i`` for each of the 1910*da0073e9SAndroid Build Coastguard Worker flat indices given by ``indices``. 1911*da0073e9SAndroid Build Coastguard Worker 1912*da0073e9SAndroid Build Coastguard Worker Example:: 1913*da0073e9SAndroid Build Coastguard Worker 1914*da0073e9SAndroid Build Coastguard Worker >>> import torch 1915*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor(4), (3, 2)) 1916*da0073e9SAndroid Build Coastguard Worker (tensor(2), 1917*da0073e9SAndroid Build Coastguard Worker tensor(0)) 1918*da0073e9SAndroid Build Coastguard Worker 1919*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) 1920*da0073e9SAndroid Build Coastguard Worker (tensor([2, 0]), 1921*da0073e9SAndroid Build Coastguard Worker tensor([0, 1])) 1922*da0073e9SAndroid Build Coastguard Worker 1923*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) 1924*da0073e9SAndroid Build Coastguard Worker (tensor([0, 0, 1, 1, 2, 2]), 1925*da0073e9SAndroid Build Coastguard Worker tensor([0, 1, 0, 1, 0, 1])) 1926*da0073e9SAndroid Build Coastguard Worker 1927*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) 1928*da0073e9SAndroid Build Coastguard Worker (tensor([1, 5]), 1929*da0073e9SAndroid Build Coastguard Worker tensor([2, 6]), 1930*da0073e9SAndroid Build Coastguard Worker tensor([3, 7]), 1931*da0073e9SAndroid Build Coastguard Worker tensor([4, 8])) 1932*da0073e9SAndroid Build Coastguard Worker 1933*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) 1934*da0073e9SAndroid Build Coastguard Worker (tensor([[1], [5]]), 1935*da0073e9SAndroid Build Coastguard Worker tensor([[2], [6]]), 1936*da0073e9SAndroid Build Coastguard Worker tensor([[3], [7]]), 1937*da0073e9SAndroid Build Coastguard Worker tensor([[4], [8]])) 1938*da0073e9SAndroid Build Coastguard Worker 1939*da0073e9SAndroid Build Coastguard Worker >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) 1940*da0073e9SAndroid Build Coastguard Worker (tensor([[12], [56]]), 1941*da0073e9SAndroid Build Coastguard Worker tensor([[34], [78]])) 1942*da0073e9SAndroid Build Coastguard Worker """ 1943*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(indices): 1944*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(unravel_index, (indices,), indices, shape=shape) 1945*da0073e9SAndroid Build Coastguard Worker res_tensor = _unravel_index(indices, shape) 1946*da0073e9SAndroid Build Coastguard Worker return res_tensor.unbind(-1) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Workerdef _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: 1950*da0073e9SAndroid Build Coastguard Worker torch._check_type( 1951*da0073e9SAndroid Build Coastguard Worker not indices.is_complex() 1952*da0073e9SAndroid Build Coastguard Worker and not indices.is_floating_point() 1953*da0073e9SAndroid Build Coastguard Worker and not indices.dtype == torch.bool, 1954*da0073e9SAndroid Build Coastguard Worker lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}", 1955*da0073e9SAndroid Build Coastguard Worker ) 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker torch._check_type( 1958*da0073e9SAndroid Build Coastguard Worker isinstance(shape, (int, torch.SymInt, Sequence)), 1959*da0073e9SAndroid Build Coastguard Worker lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}", 1960*da0073e9SAndroid Build Coastguard Worker ) 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Worker if isinstance(shape, (int, torch.SymInt)): 1963*da0073e9SAndroid Build Coastguard Worker shape = torch.Size([shape]) 1964*da0073e9SAndroid Build Coastguard Worker else: 1965*da0073e9SAndroid Build Coastguard Worker for dim in shape: 1966*da0073e9SAndroid Build Coastguard Worker torch._check_type( 1967*da0073e9SAndroid Build Coastguard Worker isinstance(dim, (int, torch.SymInt)), 1968*da0073e9SAndroid Build Coastguard Worker lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}", 1969*da0073e9SAndroid Build Coastguard Worker ) 1970*da0073e9SAndroid Build Coastguard Worker shape = torch.Size(shape) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker torch._check_value( 1973*da0073e9SAndroid Build Coastguard Worker all(dim >= 0 for dim in shape), 1974*da0073e9SAndroid Build Coastguard Worker lambda: f"'shape' cannot have negative values, but got {tuple(shape)}", 1975*da0073e9SAndroid Build Coastguard Worker ) 1976*da0073e9SAndroid Build Coastguard Worker 1977*da0073e9SAndroid Build Coastguard Worker coefs = list( 1978*da0073e9SAndroid Build Coastguard Worker reversed( 1979*da0073e9SAndroid Build Coastguard Worker list( 1980*da0073e9SAndroid Build Coastguard Worker itertools.accumulate( 1981*da0073e9SAndroid Build Coastguard Worker reversed(shape[1:] + torch.Size([1])), func=operator.mul 1982*da0073e9SAndroid Build Coastguard Worker ) 1983*da0073e9SAndroid Build Coastguard Worker ) 1984*da0073e9SAndroid Build Coastguard Worker ) 1985*da0073e9SAndroid Build Coastguard Worker ) 1986*da0073e9SAndroid Build Coastguard Worker return indices.unsqueeze(-1).floor_divide( 1987*da0073e9SAndroid Build Coastguard Worker torch.tensor(coefs, device=indices.device, dtype=torch.int64) 1988*da0073e9SAndroid Build Coastguard Worker ) % torch.tensor(shape, device=indices.device, dtype=torch.int64) 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Workerdef chain_matmul(*matrices, out=None): 1992*da0073e9SAndroid Build Coastguard Worker r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed 1993*da0073e9SAndroid Build Coastguard Worker using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms 1994*da0073e9SAndroid Build Coastguard Worker of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N` 1995*da0073e9SAndroid Build Coastguard Worker needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned. 1996*da0073e9SAndroid Build Coastguard Worker If :math:`N` is 1, then this is a no-op - the original matrix is returned as is. 1997*da0073e9SAndroid Build Coastguard Worker 1998*da0073e9SAndroid Build Coastguard Worker .. warning:: 1999*da0073e9SAndroid Build Coastguard Worker 2000*da0073e9SAndroid Build Coastguard Worker :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release. 2001*da0073e9SAndroid Build Coastguard Worker Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors 2002*da0073e9SAndroid Build Coastguard Worker rather than multiple arguments. 2003*da0073e9SAndroid Build Coastguard Worker 2004*da0073e9SAndroid Build Coastguard Worker Args: 2005*da0073e9SAndroid Build Coastguard Worker matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined. 2006*da0073e9SAndroid Build Coastguard Worker out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``. 2007*da0073e9SAndroid Build Coastguard Worker 2008*da0073e9SAndroid Build Coastguard Worker Returns: 2009*da0073e9SAndroid Build Coastguard Worker Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product 2010*da0073e9SAndroid Build Coastguard Worker would be of dimensions :math:`p_{1} \times p_{N + 1}`. 2011*da0073e9SAndroid Build Coastguard Worker 2012*da0073e9SAndroid Build Coastguard Worker Example:: 2013*da0073e9SAndroid Build Coastguard Worker 2014*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +SKIP 2015*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2016*da0073e9SAndroid Build Coastguard Worker >>> a = torch.randn(3, 4) 2017*da0073e9SAndroid Build Coastguard Worker >>> b = torch.randn(4, 5) 2018*da0073e9SAndroid Build Coastguard Worker >>> c = torch.randn(5, 6) 2019*da0073e9SAndroid Build Coastguard Worker >>> d = torch.randn(6, 7) 2020*da0073e9SAndroid Build Coastguard Worker >>> # will raise a deprecation warning 2021*da0073e9SAndroid Build Coastguard Worker >>> torch.chain_matmul(a, b, c, d) 2022*da0073e9SAndroid Build Coastguard Worker tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614], 2023*da0073e9SAndroid Build Coastguard Worker [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163], 2024*da0073e9SAndroid Build Coastguard Worker [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]]) 2025*da0073e9SAndroid Build Coastguard Worker 2026*da0073e9SAndroid Build Coastguard Worker .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition 2027*da0073e9SAndroid Build Coastguard Worker """ 2028*da0073e9SAndroid Build Coastguard Worker # This wrapper exists to support variadic args. 2029*da0073e9SAndroid Build Coastguard Worker if has_torch_function(matrices): 2030*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(chain_matmul, matrices, *matrices) 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker if out is None: 2033*da0073e9SAndroid Build Coastguard Worker return _VF.chain_matmul(matrices) # type: ignore[attr-defined] 2034*da0073e9SAndroid Build Coastguard Worker else: 2035*da0073e9SAndroid Build Coastguard Worker return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined] 2036*da0073e9SAndroid Build Coastguard Worker 2037*da0073e9SAndroid Build Coastguard Worker 2038*da0073e9SAndroid Build Coastguard Workerdef _lu_impl(A, pivot=True, get_infos=False, out=None): 2039*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor] 2040*da0073e9SAndroid Build Coastguard Worker r"""Computes the LU factorization of a matrix or batches of matrices 2041*da0073e9SAndroid Build Coastguard Worker :attr:`A`. Returns a tuple containing the LU factorization and 2042*da0073e9SAndroid Build Coastguard Worker pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to 2043*da0073e9SAndroid Build Coastguard Worker ``True``. 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker .. warning:: 2046*da0073e9SAndroid Build Coastguard Worker 2047*da0073e9SAndroid Build Coastguard Worker :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor` 2048*da0073e9SAndroid Build Coastguard Worker and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a 2049*da0073e9SAndroid Build Coastguard Worker future PyTorch release. 2050*da0073e9SAndroid Build Coastguard Worker ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with 2051*da0073e9SAndroid Build Coastguard Worker 2052*da0073e9SAndroid Build Coastguard Worker .. code:: python 2053*da0073e9SAndroid Build Coastguard Worker 2054*da0073e9SAndroid Build Coastguard Worker LU, pivots = torch.linalg.lu_factor(A, compute_pivots) 2055*da0073e9SAndroid Build Coastguard Worker 2056*da0073e9SAndroid Build Coastguard Worker ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with 2057*da0073e9SAndroid Build Coastguard Worker 2058*da0073e9SAndroid Build Coastguard Worker .. code:: python 2059*da0073e9SAndroid Build Coastguard Worker 2060*da0073e9SAndroid Build Coastguard Worker LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots) 2061*da0073e9SAndroid Build Coastguard Worker 2062*da0073e9SAndroid Build Coastguard Worker .. note:: 2063*da0073e9SAndroid Build Coastguard Worker * The returned permutation matrix for every matrix in the batch is 2064*da0073e9SAndroid Build Coastguard Worker represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``. 2065*da0073e9SAndroid Build Coastguard Worker ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm, 2066*da0073e9SAndroid Build Coastguard Worker the ``i``-th row was permuted with the ``j-1``-th row. 2067*da0073e9SAndroid Build Coastguard Worker * LU factorization with :attr:`pivot` = ``False`` is not available 2068*da0073e9SAndroid Build Coastguard Worker for CPU, and attempting to do so will throw an error. However, 2069*da0073e9SAndroid Build Coastguard Worker LU factorization with :attr:`pivot` = ``False`` is available for 2070*da0073e9SAndroid Build Coastguard Worker CUDA. 2071*da0073e9SAndroid Build Coastguard Worker * This function does not check if the factorization was successful 2072*da0073e9SAndroid Build Coastguard Worker or not if :attr:`get_infos` is ``True`` since the status of the 2073*da0073e9SAndroid Build Coastguard Worker factorization is present in the third element of the return tuple. 2074*da0073e9SAndroid Build Coastguard Worker * In the case of batches of square matrices with size less or equal 2075*da0073e9SAndroid Build Coastguard Worker to 32 on a CUDA device, the LU factorization is repeated for 2076*da0073e9SAndroid Build Coastguard Worker singular matrices due to the bug in the MAGMA library 2077*da0073e9SAndroid Build Coastguard Worker (see magma issue 13). 2078*da0073e9SAndroid Build Coastguard Worker * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`. 2079*da0073e9SAndroid Build Coastguard Worker 2080*da0073e9SAndroid Build Coastguard Worker .. warning:: 2081*da0073e9SAndroid Build Coastguard Worker The gradients of this function will only be finite when :attr:`A` is full rank. 2082*da0073e9SAndroid Build Coastguard Worker This is because the LU decomposition is just differentiable at full rank matrices. 2083*da0073e9SAndroid Build Coastguard Worker Furthermore, if :attr:`A` is close to not being full rank, 2084*da0073e9SAndroid Build Coastguard Worker the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`. 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker Args: 2087*da0073e9SAndroid Build Coastguard Worker A (Tensor): the tensor to factor of size :math:`(*, m, n)` 2088*da0073e9SAndroid Build Coastguard Worker pivot (bool, optional): controls whether pivoting is done. Default: ``True`` 2089*da0073e9SAndroid Build Coastguard Worker get_infos (bool, optional): if set to ``True``, returns an info IntTensor. 2090*da0073e9SAndroid Build Coastguard Worker Default: ``False`` 2091*da0073e9SAndroid Build Coastguard Worker out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``, 2092*da0073e9SAndroid Build Coastguard Worker then the elements in the tuple are Tensor, IntTensor, 2093*da0073e9SAndroid Build Coastguard Worker and IntTensor. If :attr:`get_infos` is ``False``, then the 2094*da0073e9SAndroid Build Coastguard Worker elements in the tuple are Tensor, IntTensor. Default: ``None`` 2095*da0073e9SAndroid Build Coastguard Worker 2096*da0073e9SAndroid Build Coastguard Worker Returns: 2097*da0073e9SAndroid Build Coastguard Worker (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing 2098*da0073e9SAndroid Build Coastguard Worker 2099*da0073e9SAndroid Build Coastguard Worker - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)` 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`. 2102*da0073e9SAndroid Build Coastguard Worker ``pivots`` stores all the intermediate transpositions of rows. 2103*da0073e9SAndroid Build Coastguard Worker The final permutation ``perm`` could be reconstructed by 2104*da0073e9SAndroid Build Coastguard Worker applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``, 2105*da0073e9SAndroid Build Coastguard Worker where ``perm`` is initially the identity permutation of :math:`m` elements 2106*da0073e9SAndroid Build Coastguard Worker (essentially this is what :func:`torch.lu_unpack` is doing). 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of 2109*da0073e9SAndroid Build Coastguard Worker size :math:`(*)` where non-zero values indicate whether factorization for the matrix or 2110*da0073e9SAndroid Build Coastguard Worker each minibatch has succeeded or failed 2111*da0073e9SAndroid Build Coastguard Worker 2112*da0073e9SAndroid Build Coastguard Worker Example:: 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) 2115*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +IGNORE_WANT("non-deterministic") 2116*da0073e9SAndroid Build Coastguard Worker >>> A = torch.randn(2, 3, 3) 2117*da0073e9SAndroid Build Coastguard Worker >>> A_LU, pivots = torch.lu(A) 2118*da0073e9SAndroid Build Coastguard Worker >>> A_LU 2119*da0073e9SAndroid Build Coastguard Worker tensor([[[ 1.3506, 2.5558, -0.0816], 2120*da0073e9SAndroid Build Coastguard Worker [ 0.1684, 1.1551, 0.1940], 2121*da0073e9SAndroid Build Coastguard Worker [ 0.1193, 0.6189, -0.5497]], 2122*da0073e9SAndroid Build Coastguard Worker 2123*da0073e9SAndroid Build Coastguard Worker [[ 0.4526, 1.2526, -0.3285], 2124*da0073e9SAndroid Build Coastguard Worker [-0.7988, 0.7175, -0.9701], 2125*da0073e9SAndroid Build Coastguard Worker [ 0.2634, -0.9255, -0.3459]]]) 2126*da0073e9SAndroid Build Coastguard Worker >>> pivots 2127*da0073e9SAndroid Build Coastguard Worker tensor([[ 3, 3, 3], 2128*da0073e9SAndroid Build Coastguard Worker [ 3, 3, 3]], dtype=torch.int32) 2129*da0073e9SAndroid Build Coastguard Worker >>> A_LU, pivots, info = torch.lu(A, get_infos=True) 2130*da0073e9SAndroid Build Coastguard Worker >>> if info.nonzero().size(0) == 0: 2131*da0073e9SAndroid Build Coastguard Worker ... print('LU factorization succeeded for all samples!') 2132*da0073e9SAndroid Build Coastguard Worker LU factorization succeeded for all samples! 2133*da0073e9SAndroid Build Coastguard Worker """ 2134*da0073e9SAndroid Build Coastguard Worker # If get_infos is True, then we don't need to check for errors and vice versa 2135*da0073e9SAndroid Build Coastguard Worker return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) 2136*da0073e9SAndroid Build Coastguard Worker 2137*da0073e9SAndroid Build Coastguard Worker 2138*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 2139*da0073e9SAndroid Build Coastguard Worker _ListOrSeq = Sequence[Tensor] 2140*da0073e9SAndroid Build Coastguard Workerelse: 2141*da0073e9SAndroid Build Coastguard Worker _ListOrSeq = List[Tensor] 2142*da0073e9SAndroid Build Coastguard Worker 2143*da0073e9SAndroid Build Coastguard Worker 2144*da0073e9SAndroid Build Coastguard Workerdef _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None: 2145*da0073e9SAndroid Build Coastguard Worker get_infos_int = 1 if get_infos else 0 2146*da0073e9SAndroid Build Coastguard Worker if out_len - get_infos_int != 2: 2147*da0073e9SAndroid Build Coastguard Worker raise TypeError( 2148*da0073e9SAndroid Build Coastguard Worker f"expected tuple of {2 + int(get_infos)} elements but got {out_len}" 2149*da0073e9SAndroid Build Coastguard Worker ) 2150*da0073e9SAndroid Build Coastguard Worker if not isinstance(out, (tuple, list)): 2151*da0073e9SAndroid Build Coastguard Worker raise TypeError( 2152*da0073e9SAndroid Build Coastguard Worker f"argument 'out' must be tuple of Tensors, not {type(out).__name__}" 2153*da0073e9SAndroid Build Coastguard Worker ) 2154*da0073e9SAndroid Build Coastguard Worker 2155*da0073e9SAndroid Build Coastguard Worker 2156*da0073e9SAndroid Build Coastguard Workerdef _lu_with_infos(A, pivot=True, get_infos=False, out=None): 2157*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor] 2158*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(A): 2159*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 2160*da0073e9SAndroid Build Coastguard Worker lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out 2161*da0073e9SAndroid Build Coastguard Worker ) 2162*da0073e9SAndroid Build Coastguard Worker result = _lu_impl(A, pivot, get_infos, out) 2163*da0073e9SAndroid Build Coastguard Worker if out is not None: 2164*da0073e9SAndroid Build Coastguard Worker _check_list_size(len(out), get_infos, out) 2165*da0073e9SAndroid Build Coastguard Worker for i in range(len(out)): 2166*da0073e9SAndroid Build Coastguard Worker out[i].resize_as_(result[i]).copy_(result[i]) 2167*da0073e9SAndroid Build Coastguard Worker return out 2168*da0073e9SAndroid Build Coastguard Worker else: 2169*da0073e9SAndroid Build Coastguard Worker return result # A_LU, pivots, infos 2170*da0073e9SAndroid Build Coastguard Worker 2171*da0073e9SAndroid Build Coastguard Worker 2172*da0073e9SAndroid Build Coastguard Workerdef _lu_no_infos(A, pivot=True, get_infos=False, out=None): 2173*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor] 2174*da0073e9SAndroid Build Coastguard Worker # need to check for torch_function here so that we exit if 2175*da0073e9SAndroid Build Coastguard Worker if has_torch_function_unary(A): 2176*da0073e9SAndroid Build Coastguard Worker return handle_torch_function( 2177*da0073e9SAndroid Build Coastguard Worker lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out 2178*da0073e9SAndroid Build Coastguard Worker ) 2179*da0073e9SAndroid Build Coastguard Worker result = _lu_impl(A, pivot, get_infos, out) 2180*da0073e9SAndroid Build Coastguard Worker if out is not None: 2181*da0073e9SAndroid Build Coastguard Worker _check_list_size(len(out), get_infos, out) 2182*da0073e9SAndroid Build Coastguard Worker for i in range(len(out)): 2183*da0073e9SAndroid Build Coastguard Worker out[i].resize_as_(result[i]).copy_(result[i]) 2184*da0073e9SAndroid Build Coastguard Worker return out 2185*da0073e9SAndroid Build Coastguard Worker else: 2186*da0073e9SAndroid Build Coastguard Worker return result[0], result[1] # A_LU, pivots 2187*da0073e9SAndroid Build Coastguard Worker 2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker# The return type of lu depends on `get_infos`, so in order to resolve the output type 2190*da0073e9SAndroid Build Coastguard Worker# of lu in TorchScript we need to statically know the value of `get_infos` 2191*da0073e9SAndroid Build Coastguard Workerlu = boolean_dispatch( 2192*da0073e9SAndroid Build Coastguard Worker arg_name="get_infos", 2193*da0073e9SAndroid Build Coastguard Worker arg_index=2, 2194*da0073e9SAndroid Build Coastguard Worker default=False, 2195*da0073e9SAndroid Build Coastguard Worker if_true=_lu_with_infos, 2196*da0073e9SAndroid Build Coastguard Worker if_false=_lu_no_infos, 2197*da0073e9SAndroid Build Coastguard Worker module_name=__name__, 2198*da0073e9SAndroid Build Coastguard Worker func_name="lu", 2199*da0073e9SAndroid Build Coastguard Worker) 2200*da0073e9SAndroid Build Coastguard Workerlu.__doc__ = _lu_impl.__doc__ 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Workerdef align_tensors(*tensors): 2204*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("`align_tensors` not yet implemented.") 2205