# mypy: ignore-errors

"""A thin pytorch / numpy compat layer.

Things imported from here have numpy-compatible signatures but operate on
pytorch tensors.
"""
# The contents of this module end up in the main namespace via _funcs.py,
# where type annotations are used in conjunction with the @normalizer decorator.
from __future__ import annotations

import builtins
import itertools
import operator
from typing import Optional, Sequence, TYPE_CHECKING

import torch

from . import _dtypes_impl, _util


if TYPE_CHECKING:
    from ._normalizations import (
        ArrayLike,
        ArrayLikeOrScalar,
        CastingModes,
        DTypeLike,
        NDArray,
        NotImplementedType,
        OutArray,
    )


def copy(
    a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
):
    return a.clone()


def copyto(
    dst: NDArray,
    src: ArrayLike,
    casting: Optional[CastingModes] = "same_kind",
    where: NotImplementedType = None,
):
    (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
    dst.copy_(src)


def atleast_1d(*arys: ArrayLike):
    res = torch.atleast_1d(*arys)
    if isinstance(res, tuple):
        return list(res)
    else:
        return res


def atleast_2d(*arys: ArrayLike):
    res = torch.atleast_2d(*arys)
    if isinstance(res, tuple):
        return list(res)
    else:
        return res


def atleast_3d(*arys: ArrayLike):
    res = torch.atleast_3d(*arys)
    if isinstance(res, tuple):
        return list(res)
    else:
        return res


def _concat_check(tup, dtype, out):
    """Check inputs in concatenate et al."""
    if tup == ():
        raise ValueError("need at least one array to concatenate")

    if out is not None and dtype is not None:
        # mimic numpy
        raise TypeError(
            "concatenate() only takes `out` or `dtype` as an "
            "argument, but both were provided."
        )


def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
    """Figure out dtypes, cast if necessary."""

    if out is not None or dtype is not None:
        # figure out the type of the inputs and outputs
        out_dtype = out.dtype.torch_dtype if dtype is None else dtype
    else:
        out_dtype = _dtypes_impl.result_type_impl(*tensors)

    # cast input arrays if necessary; do not broadcast them against `out`
    tensors = _util.typecast_tensors(tensors, out_dtype, casting)

    return tensors


def _concatenate(
    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
):
    # pure torch implementation, used by concatenate and by cov/corrcoef below
    tensors, axis = _util.axis_none_flatten(*tensors, axis=axis)
    tensors = _concat_cast_helper(tensors, out, dtype, casting)
    return torch.cat(tensors, axis)


def concatenate(
    ar_tuple: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    _concat_check(ar_tuple, dtype, out=out)
    result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
    return result


def vstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    _concat_check(tup, dtype, out=None)
    tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    return torch.vstack(tensors)


row_stack = vstack


def hstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    _concat_check(tup, dtype, out=None)
    tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    return torch.hstack(tensors)


def dstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
    # but {h,v}stack do.  Hence add them here for consistency.
    _concat_check(tup, dtype, out=None)
    tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    return torch.dstack(tensors)


def column_stack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
    # but row_stack does. (because row_stack is an alias for vstack, really).
    # Hence add these keywords here for consistency.
    _concat_check(tup, dtype, out=None)
    tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    return torch.column_stack(tensors)


def stack(
    arrays: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    _concat_check(arrays, dtype, out=out)

    tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
    result_ndim = tensors[0].ndim + 1
    axis = _util.normalize_axis_index(axis, result_ndim)
    return torch.stack(tensors, axis=axis)


def append(arr: ArrayLike, values: ArrayLike, axis=None):
    if axis is None:
        if arr.ndim != 1:
            arr = arr.flatten()
        values = values.flatten()
        axis = arr.ndim - 1
    return _concatenate((arr, values), axis=axis)


# ### split ###


def _split_helper(tensor, indices_or_sections, axis, strict=False):
    if isinstance(indices_or_sections, int):
        return _split_helper_int(tensor, indices_or_sections, axis, strict)
    elif isinstance(indices_or_sections, (list, tuple)):
        # NB: drop strict=..., it only applies to _split_helper_int
        return _split_helper_list(tensor, list(indices_or_sections), axis)
    else:
        raise TypeError(f"split_helper: {type(indices_or_sections)}")


def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
    if not isinstance(indices_or_sections, int):
        raise NotImplementedError("split: indices_or_sections")

    axis = _util.normalize_axis_index(axis, tensor.ndim)

    # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n
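    # e.g. l = 5, n = 3: 5 % 3 == 2 chunks of size 5 // 3 + 1 == 2, then one chunk of
    # size 1, i.e. section lengths [2, 2, 1], matching np.array_split on a length-5 axis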
    l, n = tensor.shape[axis], indices_or_sections

    if n <= 0:
        raise ValueError

    if l % n == 0:
        num, sz = n, l // n
        lst = [sz] * num
    else:
        if strict:
            raise ValueError("array split does not result in an equal division")

        num, sz = l % n, l // n + 1
        lst = [sz] * num

    lst += [sz - 1] * (n - num)

    return torch.split(tensor, lst, axis)


def _split_helper_list(tensor, indices_or_sections, axis):
    if not isinstance(indices_or_sections, list):
        raise NotImplementedError("split: indices_or_sections: list")
    # numpy expects indices, while torch expects lengths of sections
    # also, numpy appends zero-size arrays for indices above the shape[axis]
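    # e.g. shape[axis] == 6 and indices [2, 4, 9]: 9 is dropped (num_extra == 1),
    # lst becomes [2, 4, 6] -> section lengths [2, 2, 2], plus one trailing 0-length
    # section, matching numpy's behaviour for splitting at [2, 4, 9]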
    lst = [x for x in indices_or_sections if x <= tensor.shape[axis]]
    num_extra = len(indices_or_sections) - len(lst)

    lst.append(tensor.shape[axis])
    lst = [
        lst[0],
    ] + [a - b for a, b in zip(lst[1:], lst[:-1])]
    lst += [0] * num_extra

    return torch.split(tensor, lst, axis)


def array_split(ary: ArrayLike, indices_or_sections, axis=0):
    return _split_helper(ary, indices_or_sections, axis)


def split(ary: ArrayLike, indices_or_sections, axis=0):
    return _split_helper(ary, indices_or_sections, axis, strict=True)


def hsplit(ary: ArrayLike, indices_or_sections):
    if ary.ndim == 0:
        raise ValueError("hsplit only works on arrays of 1 or more dimensions")
    axis = 1 if ary.ndim > 1 else 0
    return _split_helper(ary, indices_or_sections, axis, strict=True)


def vsplit(ary: ArrayLike, indices_or_sections):
    if ary.ndim < 2:
        raise ValueError("vsplit only works on arrays of 2 or more dimensions")
    return _split_helper(ary, indices_or_sections, 0, strict=True)


def dsplit(ary: ArrayLike, indices_or_sections):
    if ary.ndim < 3:
        raise ValueError("dsplit only works on arrays of 3 or more dimensions")
    return _split_helper(ary, indices_or_sections, 2, strict=True)


def kron(a: ArrayLike, b: ArrayLike):
    return torch.kron(a, b)


def vander(x: ArrayLike, N=None, increasing=False):
    return torch.vander(x, N, increasing)


# ### linspace, geomspace, logspace and arange ###


def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    retstep=False,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    if axis != 0 or retstep or not endpoint:
        raise NotImplementedError
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    # XXX: raises TypeError if start or stop are not scalars
    return torch.linspace(start, stop, num, dtype=dtype)


def geomspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    if axis != 0 or not endpoint:
        raise NotImplementedError
    base = torch.pow(stop / start, 1.0 / (num - 1))
    logbase = torch.log(base)
    return torch.logspace(
        torch.log(start) / logbase,
        torch.log(stop) / logbase,
        num,
        base=base,
    )


def logspace(
    start,
    stop,
    num=50,
    endpoint=True,
    base=10.0,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    if axis != 0 or not endpoint:
        raise NotImplementedError
    return torch.logspace(start, stop, num, base=base, dtype=dtype)


def arange(
    start: Optional[ArrayLikeOrScalar] = None,
    stop: Optional[ArrayLikeOrScalar] = None,
    step: Optional[ArrayLikeOrScalar] = 1,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    if step == 0:
        raise ZeroDivisionError
    if stop is None and start is None:
        raise TypeError
    if stop is None:
        # XXX: this breaks if start is passed as a kwarg:
        # arange(start=4) should raise (no stop) but doesn't
        start, stop = 0, start
    if start is None:
        start = 0

    # the dtype of the result
    if dtype is None:
        dtype = (
            _dtypes_impl.default_dtypes().float_dtype
            if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step))
            else _dtypes_impl.default_dtypes().int_dtype
        )
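    # (e.g. arange(3) uses the default int dtype, while arange(0.0, 1.0, 0.25) picks
    #  the default float dtype, mirroring numpy's type inference)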
    work_dtype = torch.float64 if dtype.is_complex else dtype

    # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager.
    if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)):
        raise NotImplementedError

    if (step > 0 and start > stop) or (step < 0 and start < stop):
        # empty range
        return torch.empty(0, dtype=dtype)

    result = torch.arange(start, stop, step, dtype=work_dtype)
    result = _util.cast_if_needed(result, dtype)
    return result


# ### zeros/ones/empty/full ###


def empty(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.empty(shape, dtype=dtype)


# NB: *_like functions deliberately deviate from numpy: numpy has subok=True
# as the default; we set subok=False and raise on anything else.


def empty_like(
    prototype: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    result = torch.empty_like(prototype, dtype=dtype)
    if shape is not None:
        result = result.reshape(shape)
    return result


def full(
    shape,
    fill_value: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    if dtype is None:
        dtype = fill_value.dtype
    if not isinstance(shape, (tuple, list)):
        shape = (shape,)
    return torch.full(shape, fill_value, dtype=dtype)


def full_like(
    a: ArrayLike,
    fill_value,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    # XXX: fill_value broadcasts
    result = torch.full_like(a, fill_value, dtype=dtype)
    if shape is not None:
        result = result.reshape(shape)
    return result


def ones(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.ones(shape, dtype=dtype)


def ones_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    result = torch.ones_like(a, dtype=dtype)
    if shape is not None:
        result = result.reshape(shape)
    return result


def zeros(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.zeros(shape, dtype=dtype)


def zeros_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    result = torch.zeros_like(a, dtype=dtype)
    if shape is not None:
        result = result.reshape(shape)
    return result


# ### cov & corrcoef ###


def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
    """Prepare inputs for cov and corrcoef."""

    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
    if y_tensor is not None:
        # make sure x and y are at least 2D
        ndim_extra = 2 - x_tensor.ndim
        if ndim_extra > 0:
            x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape)
        if not rowvar and x_tensor.shape[0] != 1:
            x_tensor = x_tensor.mT
        x_tensor = x_tensor.clone()

        ndim_extra = 2 - y_tensor.ndim
        if ndim_extra > 0:
            y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape)
        if not rowvar and y_tensor.shape[0] != 1:
            y_tensor = y_tensor.mT
        y_tensor = y_tensor.clone()

        x_tensor = _concatenate((x_tensor, y_tensor), axis=0)

    return x_tensor


def corrcoef(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=None,
    ddof=None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    if bias is not None or ddof is not None:
        # deprecated in NumPy
        raise NotImplementedError
    xy_tensor = _xy_helper_corrcoef(x, y, rowvar)

    is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu
    if is_half:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    xy_tensor = _util.cast_if_needed(xy_tensor, dtype)
    result = torch.corrcoef(xy_tensor)

    if is_half:
        result = result.to(torch.float16)

    return result


def cov(
    m: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=False,
    ddof=None,
    fweights: Optional[ArrayLike] = None,
    aweights: Optional[ArrayLike] = None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    m = _xy_helper_corrcoef(m, y, rowvar)

    if ddof is None:
        ddof = 1 if bias == 0 else 0

    is_half = (m.dtype == torch.float16) and m.is_cpu
    if is_half:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    m = _util.cast_if_needed(m, dtype)
    result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights)

    if is_half:
        result = result.to(torch.float16)

    return result


def _conv_corr_impl(a, v, mode):
    dt = _dtypes_impl.result_type_impl(a, v)
    a = _util.cast_if_needed(a, dt)
    v = _util.cast_if_needed(v, dt)

    padding = v.shape[0] - 1 if mode == "full" else mode

    if padding == "same" and v.shape[0] % 2 == 0:
        # UserWarning: Using padding='same' with even kernel lengths and odd
        # dilation may require a zero-padded copy of the input be created
        # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.)
        raise NotImplementedError("mode='same' and even-length weights")

    # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
    aa = a[None, :]
    vv = v[None, None, :]

    result = torch.nn.functional.conv1d(aa, vv, padding=padding)

    # torch returns a 2D result, numpy returns a 1D array
    return result[0, :]


def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
    # NumPy: if v is longer than a, the arrays are swapped before computation
    if a.shape[0] < v.shape[0]:
        a, v = v, a

    # flip the weights since numpy does and torch does not
    v = torch.flip(v, (0,))

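    # illustrative check, assuming the NumPy semantics this mirrors:
    # np.convolve([1, 2, 3], [0, 1, 0.5]) gives [0., 1., 2.5, 4., 1.5], and so should
    # this function, since the kernel is flipped above before the correlation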
    return _conv_corr_impl(a, v, mode)


def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
    v = torch.conj_physical(v)
    return _conv_corr_impl(a, v, mode)


# ### logic & element selection ###


def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
    if x.numel() == 0:
        # edge case allowed by numpy
        x = x.new_empty(0, dtype=int)

    int_dtype = _dtypes_impl.default_dtypes().int_dtype
    (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe")

    return torch.bincount(x, weights, minlength)


def where(
    condition: ArrayLike,
    x: Optional[ArrayLikeOrScalar] = None,
    y: Optional[ArrayLikeOrScalar] = None,
    /,
):
    if (x is None) != (y is None):
        raise ValueError("either both or neither of x and y should be given")

    if condition.dtype != torch.bool:
        condition = condition.to(torch.bool)

    if x is None and y is None:
        result = torch.where(condition)
    else:
        result = torch.where(condition, x, y)
    return result


# ###### module-level queries of object properties


def ndim(a: ArrayLike):
    return a.ndim


def shape(a: ArrayLike):
    return tuple(a.shape)


def size(a: ArrayLike, axis=None):
    if axis is None:
        return a.numel()
    else:
        return a.shape[axis]


# ###### shape manipulations and indexing


def expand_dims(a: ArrayLike, axis):
    shape = _util.expand_shape(a.shape, axis)
    return a.view(shape)  # never copies


def flip(m: ArrayLike, axis=None):
    # XXX: semantic difference: np.flip returns a view, torch.flip copies
    if axis is None:
        axis = tuple(range(m.ndim))
    else:
        axis = _util.normalize_axis_tuple(axis, m.ndim)
    return torch.flip(m, axis)


def flipud(m: ArrayLike):
    return torch.flipud(m)


def fliplr(m: ArrayLike):
    return torch.fliplr(m)


def rot90(m: ArrayLike, k=1, axes=(0, 1)):
    axes = _util.normalize_axis_tuple(axes, m.ndim)
    return torch.rot90(m, k, axes)


# ### broadcasting and indices ###


def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
    return torch.broadcast_to(array, size=shape)


# This is a function from tuples to tuples, so we just reuse it
from torch import broadcast_shapes


def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
    return torch.broadcast_tensors(*args)


def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
    ndim = len(xi)

    if indexing not in ["xy", "ij"]:
        raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")

    s0 = (1,) * ndim
    output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)]

    if indexing == "xy" and ndim > 1:
        # switch first and second axis
        output[0] = output[0].reshape((1, -1) + s0[2:])
        output[1] = output[1].reshape((-1, 1) + s0[2:])
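        # (e.g. a length-3 x and a length-2 y with "xy" give outputs broadcasting to
        #  shape (2, 3); with "ij" they broadcast to (3, 2), as in numpy)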

    if not sparse:
        # Return the full N-D matrix (not only the 1-D vector)
        output = torch.broadcast_tensors(*output)

    if copy:
        output = [x.clone() for x in output]

    return list(output)  # match numpy, return a list


def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False):
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
    dimensions = tuple(dimensions)
    N = len(dimensions)
    shape = (1,) * N
    if sparse:
        res = ()
    else:
        res = torch.empty((N,) + dimensions, dtype=dtype)
    for i, dim in enumerate(dimensions):
        idx = torch.arange(dim, dtype=dtype).reshape(
            shape[:i] + (dim,) + shape[i + 1 :]
        )
        if sparse:
            res = res + (idx,)
        else:
            res[i] = idx
    return res


# ### tri*-something ###


def tril(m: ArrayLike, k=0):
    return torch.tril(m, k)


def triu(m: ArrayLike, k=0):
    return torch.triu(m, k)


def tril_indices(n, k=0, m=None):
    if m is None:
        m = n
    return torch.tril_indices(n, m, offset=k)


def triu_indices(n, k=0, m=None):
    if m is None:
        m = n
    return torch.triu_indices(n, m, offset=k)


def tril_indices_from(arr: ArrayLike, k=0):
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)


def triu_indices_from(arr: ArrayLike, k=0):
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)


def tri(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    if M is None:
        M = N
    tensor = torch.ones((N, M), dtype=dtype)
    return torch.tril(tensor, diagonal=k)


# ### equality, equivalence, allclose ###


def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
    dtype = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)
    return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
    dtype = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)
    return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def _tensor_equal(a1, a2, equal_nan=False):
    # Implementation of array_equal/array_equiv.
    if a1.shape != a2.shape:
        return False
    cond = a1 == a2
    if equal_nan:
        cond = cond | (torch.isnan(a1) & torch.isnan(a2))
    return cond.all().item()


def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False):
    return _tensor_equal(a1, a2, equal_nan=equal_nan)


def array_equiv(a1: ArrayLike, a2: ArrayLike):
    # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not
    try:
        a1_t, a2_t = torch.broadcast_tensors(a1, a2)
    except RuntimeError:
        # failed to broadcast => not equivalent
        return False
    return _tensor_equal(a1_t, a2_t)


def nan_to_num(
    x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None
):
    # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble'
    if x.is_complex():
        re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf)
        im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf)
        return re + 1j * im
    else:
        return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)


# ### put/take_along_axis ###


def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)
    idx = (slice(None),) * axis + (indices, ...)
    result = a[idx]
    return result


def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)
    return torch.take_along_dim(arr, indices, axis)


def put(
    a: NDArray,
    indices: ArrayLike,
    values: ArrayLike,
    mode: NotImplementedType = "raise",
):
    v = values.type(a.dtype)
    # If indices is larger than v, expand v to at least the size of indices. Any
    # unnecessary trailing elements are then trimmed.
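    # (e.g. 5 indices and 2 values: ratio == 3, v is tiled to 6 elements and then
    #  trimmed back to 5 elements below)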
    if indices.numel() > v.numel():
        ratio = (indices.numel() + v.numel() - 1) // v.numel()
        v = v.unsqueeze(0).expand((ratio,) + v.shape)
    # Trim unnecessary elements, whether or not v was expanded. Note
    # np.put() trims v to match indices by default too.
    if indices.numel() < v.numel():
        v = v.flatten()
        v = v[: indices.numel()]
    a.put_(indices, v)
    return None


def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    axis = _util.normalize_axis_index(axis, arr.ndim)

    indices, values = torch.broadcast_tensors(indices, values)
    values = _util.cast_if_needed(values, arr.dtype)
    result = torch.scatter(arr, axis, indices, values)
    arr.copy_(result.reshape(arr.shape))
    return None


def choose(
    a: ArrayLike,
    choices: Sequence[ArrayLike],
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    # First, broadcast elements of `choices`
    choices = torch.stack(torch.broadcast_tensors(*choices))

    # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`:
    # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939)
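    # (e.g. for choices.shape == (3, 4): idx_list holds arange(3).view(3, 1) and
    #  arange(4).view(1, 4); replacing idx_list[0] with `a` below selects
    #  choices[a[j], j] for every position j, which is what np.choose does)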
    idx_list = [
        torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1))
        for i, dim in enumerate(choices.shape)
    ]

    idx_list[0] = a
    return choices[idx_list].squeeze(0)


# ### unique et al. ###


def unique(
    ar: ArrayLike,
    return_index: NotImplementedType = False,
    return_inverse=False,
    return_counts=False,
    axis=None,
    *,
    equal_nan: NotImplementedType = True,
):
    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
    axis = _util.normalize_axis_index(axis, ar.ndim)

    result = torch.unique(
        ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis
    )

    return result


def nonzero(a: ArrayLike):
    return torch.nonzero(a, as_tuple=True)


def argwhere(a: ArrayLike):
    return torch.argwhere(a)


def flatnonzero(a: ArrayLike):
    return torch.flatten(a).nonzero(as_tuple=True)[0]


def clip(
    a: ArrayLike,
    min: Optional[ArrayLike] = None,
    max: Optional[ArrayLike] = None,
    out: Optional[OutArray] = None,
):
    return torch.clamp(a, min, max)


def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
    return torch.repeat_interleave(a, repeats, axis)


def tile(A: ArrayLike, reps):
    if isinstance(reps, int):
        reps = (reps,)
    return torch.tile(A, reps)


def resize(a: ArrayLike, new_shape=None):
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
    if new_shape is None:
        return a

    if isinstance(new_shape, int):
        new_shape = (new_shape,)

    a = a.flatten()

    new_size = 1
    for dim_length in new_shape:
        new_size *= dim_length
        if dim_length < 0:
            raise ValueError("all elements of `new_shape` must be non-negative")

    if a.numel() == 0 or new_size == 0:
        # First case must zero fill. The second would have repeats == 0.
        return torch.zeros(new_shape, dtype=a.dtype)

    repeats = -(-new_size // a.numel())  # ceil division
    a = concatenate((a,) * repeats)[:new_size]
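    # (e.g. resizing arange(4) to (2, 3): repeats == 2, the repeated data
    #  [0, 1, 2, 3, 0, 1, 2, 3] is cut to 6 elements and reshaped to
    #  [[0, 1, 2], [3, 0, 1]], matching np.resize)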

    return reshape(a, new_shape)


# ### diag et al. ###


def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
    axis1 = _util.normalize_axis_index(axis1, a.ndim)
    axis2 = _util.normalize_axis_index(axis2, a.ndim)
    return torch.diagonal(a, offset, axis1, axis2)


def trace(
    a: ArrayLike,
    offset=0,
    axis1=0,
    axis2=1,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
    return result


def eye(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    if M is None:
        M = N
    z = torch.zeros(N, M, dtype=dtype)
    z.diagonal(k).fill_(1)
    return z


def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None):
    return torch.eye(n, dtype=dtype)


def diag(v: ArrayLike, k=0):
    return torch.diag(v, k)


def diagflat(v: ArrayLike, k=0):
    return torch.diagflat(v, k)


def diag_indices(n, ndim=2):
    idx = torch.arange(n)
    return (idx,) * ndim


def diag_indices_from(arr: ArrayLike):
    if not arr.ndim >= 2:
        raise ValueError("input array must be at least 2-d")
    # For more than d=2, the strided formula is only valid for arrays with
    # all dimensions equal, so we check first.
    s = arr.shape
    if s[1:] != s[:-1]:
        raise ValueError("All dimensions of input must be of equal length")
    return diag_indices(s[0], arr.ndim)


def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    if val.numel() == 0 and not wrap:
        a.fill_diagonal_(val)
        return a

    if val.ndim == 0:
        val = val.unsqueeze(0)

    # torch.Tensor.fill_diagonal_ only accepts scalars
    # If the size of val is too large, then val is trimmed
    if a.ndim == 2:
        tall = a.shape[0] > a.shape[1]
        # wrap does nothing for wide matrices...
        if not wrap or not tall:
            # Never wraps
            diag = a.diagonal()
            diag.copy_(val[: diag.numel()])
        else:
            # wraps and tall... leaving one empty line between diagonals?!
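            # (e.g. filling a 5x3 zeros matrix with wrap=True sets entries (0, 0),
            #  (1, 1), (2, 2) and (4, 0) -- row 3 stays untouched, as np.fill_diagonal
            #  does with wrap=True)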
            max_, min_ = a.shape
            idx = torch.arange(max_ - max_ // (min_ + 1))
            mod = idx % min_
            div = idx // min_
            a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()]
    else:
        idx = diag_indices_from(a)
        # a.shape = (n, n, ..., n)
        a[idx] = val[: a.shape[0]]

    return a


def vdot(a: ArrayLike, b: ArrayLike, /):
    # 1. torch only accepts 1D arrays, numpy flattens
    # 2. torch requires matching dtype, while numpy casts (?)
    t_a, t_b = torch.atleast_1d(a, b)
    if t_a.ndim > 1:
        t_a = t_a.flatten()
    if t_b.ndim > 1:
        t_b = t_b.flatten()

    dtype = _dtypes_impl.result_type_impl(t_a, t_b)
    is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu)
    is_bool = dtype == torch.bool

    # work around torch's "dot" not implemented for 'Half', 'Bool'
    if is_half:
        dtype = torch.float32
    elif is_bool:
        dtype = torch.uint8

    t_a = _util.cast_if_needed(t_a, dtype)
    t_b = _util.cast_if_needed(t_b, dtype)

    result = torch.vdot(t_a, t_b)

    if is_half:
        result = result.to(torch.float16)
    elif is_bool:
        result = result.to(torch.bool)

    return result


def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
    if isinstance(axes, (list, tuple)):
        axes = [[ax] if isinstance(ax, int) else ax for ax in axes]

    target_dtype = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, target_dtype)
    b = _util.cast_if_needed(b, target_dtype)

    return torch.tensordot(a, b, dims=axes)


def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    dtype = _dtypes_impl.result_type_impl(a, b)
    is_bool = dtype == torch.bool
    if is_bool:
        dtype = torch.uint8

    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    if a.ndim == 0 or b.ndim == 0:
        result = a * b
    else:
        result = torch.matmul(a, b)

    if is_bool:
        result = result.to(torch.bool)

    return result


def inner(a: ArrayLike, b: ArrayLike, /):
    dtype = _dtypes_impl.result_type_impl(a, b)
    is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu)
    is_bool = dtype == torch.bool

    if is_half:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32
    elif is_bool:
        dtype = torch.uint8

    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    result = torch.inner(a, b)

    if is_half:
        result = result.to(torch.float16)
    elif is_bool:
        result = result.to(torch.bool)
    return result


def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    return torch.outer(a, b)


def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None):
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685
    if axis is not None:
        axisa, axisb, axisc = (axis,) * 3

    # Check axisa and axisb are within bounds
    axisa = _util.normalize_axis_index(axisa, a.ndim)
    axisb = _util.normalize_axis_index(axisb, b.ndim)

    # Move working axis to the end of the shape
    a = torch.moveaxis(a, axisa, -1)
    b = torch.moveaxis(b, axisb, -1)
    msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)"
    if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
        raise ValueError(msg)

    # Create the output array
    shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape)
    if a.shape[-1] == 3 or b.shape[-1] == 3:
        shape += (3,)
        # Check axisc is within bounds
        axisc = _util.normalize_axis_index(axisc, len(shape))
    dtype = _dtypes_impl.result_type_impl(a, b)
    cp = torch.empty(shape, dtype=dtype)

    # recast arrays as dtype
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    # create local aliases for readability
    a0 = a[..., 0]
    a1 = a[..., 1]
    if a.shape[-1] == 3:
        a2 = a[..., 2]
    b0 = b[..., 0]
    b1 = b[..., 1]
    if b.shape[-1] == 3:
        b2 = b[..., 2]
    if cp.ndim != 0 and cp.shape[-1] == 3:
        cp0 = cp[..., 0]
        cp1 = cp[..., 1]
        cp2 = cp[..., 2]

    if a.shape[-1] == 2:
        if b.shape[-1] == 2:
            # a0 * b1 - a1 * b0
            cp[...] = a0 * b1 - a1 * b0
            return cp
        else:
            assert b.shape[-1] == 3
            # cp0 = a1 * b2 - 0  (a2 = 0)
            # cp1 = 0 - a0 * b2  (a2 = 0)
            # cp2 = a0 * b1 - a1 * b0
            cp0[...] = a1 * b2
            cp1[...] = -a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
    else:
        assert a.shape[-1] == 3
        if b.shape[-1] == 3:
            cp0[...] = a1 * b2 - a2 * b1
            cp1[...] = a2 * b0 - a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
        else:
            assert b.shape[-1] == 2
            cp0[...] = -a2 * b1
            cp1[...] = a2 * b0
            cp2[...] = a0 * b1 - a1 * b0

    return torch.moveaxis(cp, -1, axisc)


def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    # Have to manually normalize *operands and **kwargs, following the NumPy signature.
    # We use a local import to avoid polluting the module namespace, since anything
    # imported at module level would be re-exported via _funcs.py.
    from ._ndarray import ndarray
    from ._normalizations import (
        maybe_copy_to,
        normalize_array_like,
        normalize_casting,
        normalize_dtype,
        wrap_tensors,
    )

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, str, op, str ... [sublistout] format: normalize every other argument

        # - if sublistout is not given, the length of operands is even, and we pick
        #   every other element starting from the first, which are the arrays.
        # - if sublistout is given, the length of operands is odd; we peel off
        #   the last one and again pick every other element, which are the arrays.
        #   Without [:-1], we would have picked sublistout, too.
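        # (e.g. einsum(a, [0, 1], b, [1, 2], [0, 2]) is a matmul in sublist format:
        #  the arrays a and b sit at positions 0 and 2, and [0, 2] is the sublistout)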
        array_operands = operands[:-1][::2]
    else:
        # ("ij->", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

    # work around 'bmm' not implemented for 'Half' etc
    is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors)
    if is_half:
        target_dtype = torch.float32

    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
    if is_short_int:
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    from torch.backends import opt_einsum

    try:
        # set the global state to handle the optimize=... argument, restore on exit
        if opt_einsum.is_available():
            old_strategy = torch.backends.opt_einsum.strategy
            old_enabled = torch.backends.opt_einsum.enabled

            # torch.einsum calls opt_einsum.contract_path, which runs into
            # https://github.com/dgasmith/opt_einsum/issues/219
            # for strategy={True, False}
            if optimize is True:
                optimize = "auto"
            elif optimize is False:
                torch.backends.opt_einsum.enabled = False

            torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine operands
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            operands = list(itertools.chain.from_iterable(zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        if opt_einsum.is_available():
            torch.backends.opt_einsum.strategy = old_strategy
            torch.backends.opt_einsum.enabled = old_enabled

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)


# ### sort and partition ###


def _sort_helper(tensor, axis, kind, order):
    if tensor.dtype.is_complex:
        raise NotImplementedError(f"sorting {tensor.dtype} is not supported")
    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
    axis = _util.normalize_axis_index(axis, tensor.ndim)

    stable = kind == "stable"

    return tensor, axis, stable


def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    # The `order` keyword arg is only relevant for structured dtypes, so it is not
    # supported here.
    a, axis, stable = _sort_helper(a, axis, kind, order)
    result = torch.sort(a, dim=axis, stable=stable)
    return result.values


def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    a, axis, stable = _sort_helper(a, axis, kind, order)
    return torch.argsort(a, dim=axis, stable=stable)


def searchsorted(
    a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
):
    if a.dtype.is_complex:
        raise NotImplementedError(f"searchsorted with dtype={a.dtype}")

    return torch.searchsorted(a, v, side=side, sorter=sorter)


# ### swap/move/roll axis ###


def moveaxis(a: ArrayLike, source, destination):
    source = _util.normalize_axis_tuple(source, a.ndim, "source")
    destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
    return torch.moveaxis(a, source, destination)


def swapaxes(a: ArrayLike, axis1, axis2):
    axis1 = _util.normalize_axis_index(axis1, a.ndim)
    axis2 = _util.normalize_axis_index(axis2, a.ndim)
    return torch.swapaxes(a, axis1, axis2)


def rollaxis(a: ArrayLike, axis, start=0):
    # Straight vendor from:
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
    #
    # Also note this function in NumPy is mostly retained for backwards compat
    # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
    # so let's not touch it unless hard pressed.
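    # (e.g. for a.shape == (3, 4, 5, 6), rollaxis(a, 2) yields shape (5, 3, 4, 6) and
    #  rollaxis(a, 3, 1) yields shape (3, 6, 4, 5), as in the NumPy docs)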
    n = a.ndim
    axis = _util.normalize_axis_index(axis, n)
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
    if axis < start:
        # it's been removed
        start -= 1
    if axis == start:
        # numpy returns a view, here we try returning the tensor itself
        # return tensor[...]
        return a
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    # numpy does a.transpose(axes) here; permute is the torch equivalent
    return a.permute(axes)


def roll(a: ArrayLike, shift, axis=None):
    if axis is not None:
        axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
        if not isinstance(shift, tuple):
            shift = (shift,) * len(axis)
    return torch.roll(a, shift, axis)


# ### shape manipulations ###


def squeeze(a: ArrayLike, axis=None):
    if axis == ():
        result = a
    elif axis is None:
        result = a.squeeze()
    else:
        if isinstance(axis, tuple):
            # squeeze the requested axes from the innermost out, so that earlier
            # squeezes do not shift the indices of the remaining axes
            result = a
            for ax in sorted(_util.normalize_axis_tuple(axis, a.ndim), reverse=True):
                result = result.squeeze(ax)
        else:
            result = a.squeeze(axis)
    return result


def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
    newshape = newshape[0] if len(newshape) == 1 else newshape
    return a.reshape(newshape)


# NB: cannot use torch.reshape(a, newshape) above, because of
# (Pdb) torch.reshape(torch.as_tensor([1]), 1)
# *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int


def transpose(a: ArrayLike, axes=None):
    # numpy allows both .transpose(sh) and .transpose(*sh)
    # also older code uses axes being a list
    if axes in [(), None, (None,)]:
        axes = tuple(reversed(range(a.ndim)))
    elif len(axes) == 1:
        axes = axes[0]
    return a.permute(axes)


def ravel(a: ArrayLike, order: NotImplementedType = "C"):
    return torch.flatten(a)


def diff(
    a: ArrayLike,
    n=1,
    axis=-1,
    prepend: Optional[ArrayLike] = None,
    append: Optional[ArrayLike] = None,
):
    axis = _util.normalize_axis_index(axis, a.ndim)

    if n < 0:
        raise ValueError(f"order must be non-negative but got {n}")

    if n == 0:
        # match numpy and return the input immediately
        return a

    if prepend is not None:
        shape = list(a.shape)
        shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1
        prepend = torch.broadcast_to(prepend, shape)

    if append is not None:
        shape = list(a.shape)
        shape[axis] = append.shape[axis] if append.ndim > 0 else 1
        append = torch.broadcast_to(append, shape)

    return torch.diff(a, n, axis=axis, prepend=prepend, append=append)


# ### math functions ###


def angle(z: ArrayLike, deg=False):
    result = torch.angle(z)
    if deg:
        result = result * (180 / torch.pi)
    return result


def sinc(x: ArrayLike):
    return torch.sinc(x)


# NB: have to normalize *varargs manually
def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
    N = f.ndim  # number of dimensions

    varargs = _util.ndarrays_to_tensors(varargs)

    if axis is None:
        axes = tuple(range(N))
    else:
        axes = _util.normalize_axis_tuple(axis, N)

    len_axes = len(axes)
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
        # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = torch.as_tensor(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1d")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError(
                    "when 1d, distances must match "
                    "the length of the corresponding dimension"
                )
            if not (distances.dtype.is_floating_point or distances.dtype.is_complex):
                distances = distances.double()

            diffx = torch.diff(distances)
            # if distances are constant reduce to the scalar case
            # since it brings a consistent speedup
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    # use central differences on interior and one-sided differences on the
    # endpoints. This preserves second-order accuracy over the full domain.
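    # (for uniform spacing the interior stencil below is
    #  f'(x_i) ~ (f[i+1] - f[i-1]) / (2 * dx); the a, b, c coefficients further down
    #  generalize it to nonuniform spacing)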

    outvals = []

    # create slice objects --- initially all are [:, :, ..., :]
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    if _dtypes_impl.python_type_for_torch(otype) in (int, bool):
        # Convert to floating point.
        # First check if f is a numpy integer type; if so, convert f to float64
        # to avoid modular arithmetic when computing the changes in f.
        f = f.double()
        otype = torch.float64

    for axis, ax_dx in zip(axes, dx):
        if f.shape[axis] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required."
            )
        # result allocation
        out = torch.empty_like(f, dtype=otype)

        # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

        # Numerical differentiation: 2nd order interior
        slice1[axis] = slice(1, -1)
        slice2[axis] = slice(None, -2)
        slice3[axis] = slice(1, -1)
        slice4[axis] = slice(2, None)

        if uniform_spacing:
            out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
        else:
            dx1 = ax_dx[0:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # fix the shape for broadcasting
            shape = [1] * N
            shape[axis] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)
            # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[axis] = 0
            slice2[axis] = 1
            slice3[axis] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0

            slice1[axis] = -1
            slice2[axis] = -1
            slice3[axis] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n

        # Numerical differentiation: 2nd order edges
        else:
            slice1[axis] = 0
            slice2[axis] = 0
            slice3[axis] = 1
            slice4[axis] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2.0 / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[axis] = -1
            slice2[axis] = -3
            slice3[axis] = -2
            slice4[axis] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2.0 / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = (dx2) / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[axis] = slice(None)
        slice2[axis] = slice(None)
        slice3[axis] = slice(None)
        slice4[axis] = slice(None)

    if len_axes == 1:
        return outvals[0]
    else:
        return outvals


# ### Type/shape etc queries ###


def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
    if a.is_floating_point():
        result = torch.round(a, decimals=decimals)
    elif a.is_complex():
        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
        result = torch.complex(
            torch.round(a.real, decimals=decimals),
            torch.round(a.imag, decimals=decimals),
        )
    else:
        # RuntimeError: "round_cpu" not implemented for 'int'
        result = a
    return result


around = round
round_ = round


def real_if_close(a: ArrayLike, tol=100):
    if not torch.is_complex(a):
        return a
    if tol > 1:
        # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
        # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
        # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
        tol = tol * torch.finfo(a.dtype).eps
1753
1754    mask = torch.abs(a.imag) < tol
1755    return a.real if mask.all() else a
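
# Illustrative usage (editor's sketch): with the default tol=100 the imaginary
# part is compared against 100 * eps of the dtype, so tiny numerical noise is
# dropped while genuinely complex values are kept:
#
#     >>> z = torch.tensor([2.0 + 1e-14j], dtype=torch.complex128)
#     >>> real_if_close(z)           # |imag| < 100 * float64 eps  ->  real-valued result
#     >>> real_if_close(z + 1e-3j)   # imaginary part too large    ->  stays complex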
1756
1757
1758def real(a: ArrayLike):
1759    return torch.real(a)
1760
1761
1762def imag(a: ArrayLike):
1763    if a.is_complex():
1764        return a.imag
1765    return torch.zeros_like(a)
1766
1767
1768def iscomplex(x: ArrayLike):
1769    if torch.is_complex(x):
1770        return x.imag != 0
1771    return torch.zeros_like(x, dtype=torch.bool)
1772
1773
1774def isreal(x: ArrayLike):
1775    if torch.is_complex(x):
1776        return x.imag == 0
1777    return torch.ones_like(x, dtype=torch.bool)
1778
1779
1780def iscomplexobj(x: ArrayLike):
1781    return torch.is_complex(x)
1782
1783
1784def isrealobj(x: ArrayLike):
1785    return not torch.is_complex(x)
1786
1787
1788def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
1789    return torch.isneginf(x)
1790
1791
1792def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
1793    return torch.isposinf(x)
1794
1795
1796def i0(x: ArrayLike):
1797    return torch.special.i0(x)
1798
1799
1800def isscalar(a):
1801    # We need to use normalize_array_like, but we don't want to export it in _funcs.py
1802    from ._normalizations import normalize_array_like
1803
1804    try:
1805        t = normalize_array_like(a)
1806        return t.numel() == 1
1807    except Exception:
1808        return False
1809
1810
1811# ### Filter windows ###
1812
1813
1814def hamming(M):
1815    dtype = _dtypes_impl.default_dtypes().float_dtype
1816    return torch.hamming_window(M, periodic=False, dtype=dtype)
1817
1818
1819def hanning(M):
1820    dtype = _dtypes_impl.default_dtypes().float_dtype
1821    return torch.hann_window(M, periodic=False, dtype=dtype)
1822
1823
1824def kaiser(M, beta):
1825    dtype = _dtypes_impl.default_dtypes().float_dtype
1826    return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
1827
1828
1829def blackman(M):
1830    dtype = _dtypes_impl.default_dtypes().float_dtype
1831    return torch.blackman_window(M, periodic=False, dtype=dtype)
1832
1833
1834def bartlett(M):
1835    dtype = _dtypes_impl.default_dtypes().float_dtype
1836    return torch.bartlett_window(M, periodic=False, dtype=dtype)
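
# Editor's note (illustrative): numpy's M-point windows are symmetric, which
# maps onto torch's window functions with periodic=False.  E.g. a 5-point Hann
# window starts and ends at zero:
#
#     >>> torch.hann_window(5, periodic=False)
#     tensor([0.0000, 0.5000, 1.0000, 0.5000, 0.0000])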
1837
1838
1839# ### Dtype routines ###
1840
1841# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
1842
1843
1844array_type = [
1845    [torch.float16, torch.float32, torch.float64],
1846    [None, torch.complex64, torch.complex128],
1847]
1848array_precision = {
1849    torch.float16: 0,
1850    torch.float32: 1,
1851    torch.float64: 2,
1852    torch.complex64: 1,
1853    torch.complex128: 2,
1854}
1855
1856
1857def common_type(*tensors: ArrayLike):
1858    is_complex = False
1859    precision = 0
1860    for a in tensors:
1861        t = a.dtype
1862        if iscomplexobj(a):
1863            is_complex = True
1864        if not (t.is_floating_point or t.is_complex):
1865            p = 2  # non-inexact (integer/bool) inputs count as double; array_precision[_nx.double] in numpy
1866        else:
1867            p = array_precision.get(t, None)
1868            if p is None:
1869                raise TypeError("can't get common type for non-numeric array")
1870        precision = builtins.max(precision, p)
1871    if is_complex:
1872        return array_type[1][precision]
1873    else:
1874        return array_type[0][precision]
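
# Illustrative behaviour (editor's sketch): precision is looked up per input in
# `array_precision`, integer inputs count as double precision, and any complex
# input switches the lookup to the complex row of `array_type`:
#
#     >>> f16, f32 = torch.ones(2, dtype=torch.float16), torch.ones(2, dtype=torch.float32)
#     >>> common_type(f16, f32)
#     torch.float32
#     >>> common_type(torch.ones(2, dtype=torch.int64), torch.ones(2, dtype=torch.complex64))
#     torch.complex128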
1875
1876
1877# ### histograms ###
1878
1879
1880def histogram(
1881    a: ArrayLike,
1882    bins: ArrayLike = 10,
1883    range=None,
1884    normed=None,
1885    weights: Optional[ArrayLike] = None,
1886    density=None,
1887):
1888    if normed is not None:
1889        raise ValueError("normed argument is deprecated, use density= instead")
1890
1891    if weights is not None and weights.dtype.is_complex:
1892        raise NotImplementedError("complex weights are not supported in histogram")
1893
1894    is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
1895    is_w_int = weights is None or not weights.dtype.is_floating_point
1896    if is_a_int:
1897        a = a.double()
1898
1899    if weights is not None:
1900        weights = _util.cast_if_needed(weights, a.dtype)
1901
1902    if isinstance(bins, torch.Tensor):
1903        if bins.ndim == 0:
1904            # bins was a single int
1905            bins = operator.index(bins)
1906        else:
1907            bins = _util.cast_if_needed(bins, a.dtype)
1908
1909    if range is None:
1910        h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
1911    else:
1912        h, b = torch.histogram(
1913            a, bins, range=range, weight=weights, density=bool(density)
1914        )
1915
1916    if not density and is_w_int:
1917        h = h.long()
1918    if is_a_int:
1919        b = b.long()
1920
1921    return h, b
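
# Illustrative behaviour (editor's sketch): the underlying torch.histogram call
# works with floating-point data, so integer samples are histogrammed in double
# precision and the counts / bin edges are cast back to integer dtypes afterwards:
#
#     >>> h, edges = histogram(torch.arange(5), bins=5)
#     >>> h.dtype, edges.dtype
#     (torch.int64, torch.int64)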
1922
1923
1924def histogram2d(
1925    x,
1926    y,
1927    bins=10,
1928    range: Optional[ArrayLike] = None,
1929    normed=None,
1930    weights: Optional[ArrayLike] = None,
1931    density=None,
1932):
1933    # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821
1934    if len(x) != len(y):
1935        raise ValueError("x and y must have the same length.")
1936
1937    try:
1938        N = len(bins)
1939    except TypeError:
1940        N = 1
1941
1942    if N != 1 and N != 2:
1943        bins = [bins, bins]
1944
1945    h, e = histogramdd((x, y), bins, range, normed, weights, density)
1946
1947    return h, e[0], e[1]
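
# Editor's note (illustrative): mirroring numpy's histogram2d, `bins` may be a
# single spec shared by both axes or a pair of per-axis specs.  A sequence of
# bin edges (length other than 1 or 2) is duplicated for both axes before
# delegating to histogramdd, e.g. bins=[0, 1, 2, 3] becomes
# [[0, 1, 2, 3], [0, 1, 2, 3]].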
1948
1949
1950def histogramdd(
1951    sample,
1952    bins=10,
1953    range: Optional[ArrayLike] = None,
1954    normed=None,
1955    weights: Optional[ArrayLike] = None,
1956    density=None,
1957):
1958    # have to normalize manually: `sample` is interpreted differently when it is
1959    # a list of lists (one sequence per coordinate, hence the transpose below) vs. a 2D (N, D) array
1960    if normed is not None:
1961        raise ValueError("normed argument is deprecated, use density= instead")
1962
1963    from ._normalizations import normalize_array_like, normalize_seq_array_like
1964
1965    if isinstance(sample, (list, tuple)):
1966        sample = normalize_array_like(sample).T
1967    else:
1968        sample = normalize_array_like(sample)
1969
1970    sample = torch.atleast_2d(sample)
1971
1972    if not (sample.dtype.is_floating_point or sample.dtype.is_complex):
1973        sample = sample.double()
1974
1975    # bins is either an int, a sequence of ints, or a sequence of arrays
1976    bins_is_array = not (
1977        isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins)
1978    )
1979    if bins_is_array:
1980        bins = normalize_seq_array_like(bins)
1981        bins_dtypes = [b.dtype for b in bins]
1982        bins = [_util.cast_if_needed(b, sample.dtype) for b in bins]
1983
1984    if range is not None:
1985        range = range.flatten().tolist()
1986
1987    if weights is not None:
1988        # range=... is required here: interleave min and max values per dimension
1989        mm = sample.aminmax(dim=0)
1990        range = torch.cat(mm).reshape(2, -1).T.flatten()
1991        range = tuple(range.tolist())
1992        weights = _util.cast_if_needed(weights, sample.dtype)
1993        w_kwd = {"weight": weights}
1994    else:
1995        w_kwd = {}
1996
1997    h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd)
1998
1999    if bins_is_array:
2000        b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)]
2001
2002    return h, b
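
# Editor's note (illustrative): torch.histogramdd takes `range` as a flat
# (min_0, max_0, min_1, max_1, ...) sequence, which is what the
# aminmax / cat / reshape / transpose dance above builds from the sample:
#
#     >>> s = torch.tensor([[0.0, 10.0], [3.0, 40.0]])
#     >>> mm = s.aminmax(dim=0)
#     >>> tuple(torch.cat(mm).reshape(2, -1).T.flatten().tolist())
#     (0.0, 3.0, 10.0, 40.0)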
2003
2004
2005# ### odds and ends
2006
2007
2008def min_scalar_type(a: ArrayLike, /):
2009    # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288
2010
2011    from ._dtypes import DType
2012
2013    if a.numel() > 1:
2014        # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified."
2015        return DType(a.dtype)
2016
2017    if a.dtype == torch.bool:
2018        dtype = torch.bool
2019
2020    elif a.dtype.is_complex:
2021        fi = torch.finfo(torch.float32)
2022        fits_in_single = a.dtype == torch.complex64 or (
2023            fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
2024        )
2025        dtype = torch.complex64 if fits_in_single else torch.complex128
2026
2027    elif a.dtype.is_floating_point:
2028        for dt in [torch.float16, torch.float32, torch.float64]:
2029            fi = torch.finfo(dt)
2030            if fi.min <= a <= fi.max:
2031                dtype = dt
2032                break
2033    else:
2034        # must be integer
2035        for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
2036            # Prefer unsigned int where possible, as numpy does.
2037            ii = torch.iinfo(dt)
2038            if ii.min <= a <= ii.max:
2039                dtype = dt
2040                break
2041
2042    return DType(dtype)
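
# Illustrative behaviour (editor's sketch), calling the implementation above
# directly with zero-dim tensors: the value is fitted into the narrowest dtype
# whose range contains it, trying unsigned integers first, while arrays with
# more than one element keep their dtype:
#
#     >>> min_scalar_type(torch.tensor(255))                          # fits uint8
#     >>> min_scalar_type(torch.tensor(-1))                           # needs a sign bit -> int8
#     >>> min_scalar_type(torch.tensor(1e40, dtype=torch.float64))    # too big for float32 -> float64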
2043
2044
2045def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs):
2046    if mode != "constant":
2047        raise NotImplementedError
2048    value = kwargs.get("constant_values", 0)
2049    # `value` must be a python scalar for torch.nn.functional.pad
2050    typ = _dtypes_impl.python_type_for_torch(array.dtype)
2051    value = typ(value)
2052
2053    pad_width = torch.broadcast_to(pad_width, (array.ndim, 2))
2054    pad_width = torch.flip(pad_width, (0,)).flatten()
2055
2056    return torch.nn.functional.pad(array, tuple(pad_width), value=value)
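
# Editor's note (illustrative): torch.nn.functional.pad expects a flat tuple
# ordered from the last dimension backwards, (before_last, after_last,
# before_prev, after_prev, ...), while numpy's pad_width lists
# (before, after) pairs starting from the leading dimension.  Broadcasting to
# (ndim, 2), flipping along dim 0 and flattening performs exactly that
# reordering:
#
#     >>> pw = torch.tensor([[1, 2], [3, 4]])        # numpy-style pad_width for a 2D array
#     >>> tuple(torch.flip(pw, (0,)).flatten().tolist())
#     (3, 4, 1, 2)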
2057