1# mypy: ignore-errors 2 3"""A thin pytorch / numpy compat layer. 4 5Things imported from here have numpy-compatible signatures but operate on 6pytorch tensors. 7""" 8# Contents of this module ends up in the main namespace via _funcs.py 9# where type annotations are used in conjunction with the @normalizer decorator. 10from __future__ import annotations 11 12import builtins 13import itertools 14import operator 15from typing import Optional, Sequence, TYPE_CHECKING 16 17import torch 18 19from . import _dtypes_impl, _util 20 21 22if TYPE_CHECKING: 23 from ._normalizations import ( 24 ArrayLike, 25 ArrayLikeOrScalar, 26 CastingModes, 27 DTypeLike, 28 NDArray, 29 NotImplementedType, 30 OutArray, 31 ) 32 33 34def copy( 35 a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False 36): 37 return a.clone() 38 39 40def copyto( 41 dst: NDArray, 42 src: ArrayLike, 43 casting: Optional[CastingModes] = "same_kind", 44 where: NotImplementedType = None, 45): 46 (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting) 47 dst.copy_(src) 48 49 50def atleast_1d(*arys: ArrayLike): 51 res = torch.atleast_1d(*arys) 52 if isinstance(res, tuple): 53 return list(res) 54 else: 55 return res 56 57 58def atleast_2d(*arys: ArrayLike): 59 res = torch.atleast_2d(*arys) 60 if isinstance(res, tuple): 61 return list(res) 62 else: 63 return res 64 65 66def atleast_3d(*arys: ArrayLike): 67 res = torch.atleast_3d(*arys) 68 if isinstance(res, tuple): 69 return list(res) 70 else: 71 return res 72 73 74def _concat_check(tup, dtype, out): 75 if tup == (): 76 raise ValueError("need at least one array to concatenate") 77 78 """Check inputs in concatenate et al.""" 79 if out is not None and dtype is not None: 80 # mimic numpy 81 raise TypeError( 82 "concatenate() only takes `out` or `dtype` as an " 83 "argument, but both were provided." 
84 ) 85 86 87def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): 88 """Figure out dtypes, cast if necessary.""" 89 90 if out is not None or dtype is not None: 91 # figure out the type of the inputs and outputs 92 out_dtype = out.dtype.torch_dtype if dtype is None else dtype 93 else: 94 out_dtype = _dtypes_impl.result_type_impl(*tensors) 95 96 # cast input arrays if necessary; do not broadcast them agains `out` 97 tensors = _util.typecast_tensors(tensors, out_dtype, casting) 98 99 return tensors 100 101 102def _concatenate( 103 tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind" 104): 105 # pure torch implementation, used below and in cov/corrcoef below 106 tensors, axis = _util.axis_none_flatten(*tensors, axis=axis) 107 tensors = _concat_cast_helper(tensors, out, dtype, casting) 108 return torch.cat(tensors, axis) 109 110 111def concatenate( 112 ar_tuple: Sequence[ArrayLike], 113 axis=0, 114 out: Optional[OutArray] = None, 115 dtype: Optional[DTypeLike] = None, 116 casting: Optional[CastingModes] = "same_kind", 117): 118 _concat_check(ar_tuple, dtype, out=out) 119 result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting) 120 return result 121 122 123def vstack( 124 tup: Sequence[ArrayLike], 125 *, 126 dtype: Optional[DTypeLike] = None, 127 casting: Optional[CastingModes] = "same_kind", 128): 129 _concat_check(tup, dtype, out=None) 130 tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) 131 return torch.vstack(tensors) 132 133 134row_stack = vstack 135 136 137def hstack( 138 tup: Sequence[ArrayLike], 139 *, 140 dtype: Optional[DTypeLike] = None, 141 casting: Optional[CastingModes] = "same_kind", 142): 143 _concat_check(tup, dtype, out=None) 144 tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) 145 return torch.hstack(tensors) 146 147 148def dstack( 149 tup: Sequence[ArrayLike], 150 *, 151 dtype: Optional[DTypeLike] = None, 152 casting: Optional[CastingModes] 
= "same_kind", 153): 154 # XXX: in numpy 1.24 dstack does not have dtype and casting keywords 155 # but {h,v}stack do. Hence add them here for consistency. 156 _concat_check(tup, dtype, out=None) 157 tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) 158 return torch.dstack(tensors) 159 160 161def column_stack( 162 tup: Sequence[ArrayLike], 163 *, 164 dtype: Optional[DTypeLike] = None, 165 casting: Optional[CastingModes] = "same_kind", 166): 167 # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords 168 # but row_stack does. (because row_stack is an alias for vstack, really). 169 # Hence add these keywords here for consistency. 170 _concat_check(tup, dtype, out=None) 171 tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) 172 return torch.column_stack(tensors) 173 174 175def stack( 176 arrays: Sequence[ArrayLike], 177 axis=0, 178 out: Optional[OutArray] = None, 179 *, 180 dtype: Optional[DTypeLike] = None, 181 casting: Optional[CastingModes] = "same_kind", 182): 183 _concat_check(arrays, dtype, out=out) 184 185 tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting) 186 result_ndim = tensors[0].ndim + 1 187 axis = _util.normalize_axis_index(axis, result_ndim) 188 return torch.stack(tensors, axis=axis) 189 190 191def append(arr: ArrayLike, values: ArrayLike, axis=None): 192 if axis is None: 193 if arr.ndim != 1: 194 arr = arr.flatten() 195 values = values.flatten() 196 axis = arr.ndim - 1 197 return _concatenate((arr, values), axis=axis) 198 199 200# ### split ### 201 202 203def _split_helper(tensor, indices_or_sections, axis, strict=False): 204 if isinstance(indices_or_sections, int): 205 return _split_helper_int(tensor, indices_or_sections, axis, strict) 206 elif isinstance(indices_or_sections, (list, tuple)): 207 # NB: drop split=..., it only applies to split_helper_int 208 return _split_helper_list(tensor, list(indices_or_sections), axis) 209 else: 210 raise TypeError("split_helper: ", 
type(indices_or_sections)) 211 212 213def _split_helper_int(tensor, indices_or_sections, axis, strict=False): 214 if not isinstance(indices_or_sections, int): 215 raise NotImplementedError("split: indices_or_sections") 216 217 axis = _util.normalize_axis_index(axis, tensor.ndim) 218 219 # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n 220 l, n = tensor.shape[axis], indices_or_sections 221 222 if n <= 0: 223 raise ValueError 224 225 if l % n == 0: 226 num, sz = n, l // n 227 lst = [sz] * num 228 else: 229 if strict: 230 raise ValueError("array split does not result in an equal division") 231 232 num, sz = l % n, l // n + 1 233 lst = [sz] * num 234 235 lst += [sz - 1] * (n - num) 236 237 return torch.split(tensor, lst, axis) 238 239 240def _split_helper_list(tensor, indices_or_sections, axis): 241 if not isinstance(indices_or_sections, list): 242 raise NotImplementedError("split: indices_or_sections: list") 243 # numpy expects indices, while torch expects lengths of sections 244 # also, numpy appends zero-size arrays for indices above the shape[axis] 245 lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] 246 num_extra = len(indices_or_sections) - len(lst) 247 248 lst.append(tensor.shape[axis]) 249 lst = [ 250 lst[0], 251 ] + [a - b for a, b in zip(lst[1:], lst[:-1])] 252 lst += [0] * num_extra 253 254 return torch.split(tensor, lst, axis) 255 256 257def array_split(ary: ArrayLike, indices_or_sections, axis=0): 258 return _split_helper(ary, indices_or_sections, axis) 259 260 261def split(ary: ArrayLike, indices_or_sections, axis=0): 262 return _split_helper(ary, indices_or_sections, axis, strict=True) 263 264 265def hsplit(ary: ArrayLike, indices_or_sections): 266 if ary.ndim == 0: 267 raise ValueError("hsplit only works on arrays of 1 or more dimensions") 268 axis = 1 if ary.ndim > 1 else 0 269 return _split_helper(ary, indices_or_sections, axis, strict=True) 270 271 272def vsplit(ary: ArrayLike, indices_or_sections): 273 if ary.ndim < 
2: 274 raise ValueError("vsplit only works on arrays of 2 or more dimensions") 275 return _split_helper(ary, indices_or_sections, 0, strict=True) 276 277 278def dsplit(ary: ArrayLike, indices_or_sections): 279 if ary.ndim < 3: 280 raise ValueError("dsplit only works on arrays of 3 or more dimensions") 281 return _split_helper(ary, indices_or_sections, 2, strict=True) 282 283 284def kron(a: ArrayLike, b: ArrayLike): 285 return torch.kron(a, b) 286 287 288def vander(x: ArrayLike, N=None, increasing=False): 289 return torch.vander(x, N, increasing) 290 291 292# ### linspace, geomspace, logspace and arange ### 293 294 295def linspace( 296 start: ArrayLike, 297 stop: ArrayLike, 298 num=50, 299 endpoint=True, 300 retstep=False, 301 dtype: Optional[DTypeLike] = None, 302 axis=0, 303): 304 if axis != 0 or retstep or not endpoint: 305 raise NotImplementedError 306 if dtype is None: 307 dtype = _dtypes_impl.default_dtypes().float_dtype 308 # XXX: raises TypeError if start or stop are not scalars 309 return torch.linspace(start, stop, num, dtype=dtype) 310 311 312def geomspace( 313 start: ArrayLike, 314 stop: ArrayLike, 315 num=50, 316 endpoint=True, 317 dtype: Optional[DTypeLike] = None, 318 axis=0, 319): 320 if axis != 0 or not endpoint: 321 raise NotImplementedError 322 base = torch.pow(stop / start, 1.0 / (num - 1)) 323 logbase = torch.log(base) 324 return torch.logspace( 325 torch.log(start) / logbase, 326 torch.log(stop) / logbase, 327 num, 328 base=base, 329 ) 330 331 332def logspace( 333 start, 334 stop, 335 num=50, 336 endpoint=True, 337 base=10.0, 338 dtype: Optional[DTypeLike] = None, 339 axis=0, 340): 341 if axis != 0 or not endpoint: 342 raise NotImplementedError 343 return torch.logspace(start, stop, num, base=base, dtype=dtype) 344 345 346def arange( 347 start: Optional[ArrayLikeOrScalar] = None, 348 stop: Optional[ArrayLikeOrScalar] = None, 349 step: Optional[ArrayLikeOrScalar] = 1, 350 dtype: Optional[DTypeLike] = None, 351 *, 352 like: NotImplementedType = 
None, 353): 354 if step == 0: 355 raise ZeroDivisionError 356 if stop is None and start is None: 357 raise TypeError 358 if stop is None: 359 # XXX: this breaks if start is passed as a kwarg: 360 # arange(start=4) should raise (no stop) but doesn't 361 start, stop = 0, start 362 if start is None: 363 start = 0 364 365 # the dtype of the result 366 if dtype is None: 367 dtype = ( 368 _dtypes_impl.default_dtypes().float_dtype 369 if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step)) 370 else _dtypes_impl.default_dtypes().int_dtype 371 ) 372 work_dtype = torch.float64 if dtype.is_complex else dtype 373 374 # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager. 375 if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)): 376 raise NotImplementedError 377 378 if (step > 0 and start > stop) or (step < 0 and start < stop): 379 # empty range 380 return torch.empty(0, dtype=dtype) 381 382 result = torch.arange(start, stop, step, dtype=work_dtype) 383 result = _util.cast_if_needed(result, dtype) 384 return result 385 386 387# ### zeros/ones/empty/full ### 388 389 390def empty( 391 shape, 392 dtype: Optional[DTypeLike] = None, 393 order: NotImplementedType = "C", 394 *, 395 like: NotImplementedType = None, 396): 397 if dtype is None: 398 dtype = _dtypes_impl.default_dtypes().float_dtype 399 return torch.empty(shape, dtype=dtype) 400 401 402# NB: *_like functions deliberately deviate from numpy: it has subok=True 403# as the default; we set subok=False and raise on anything else. 


def empty_like(
    prototype: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Uninitialized tensor like ``prototype``; ``shape`` overrides its shape."""
    result = torch.empty_like(prototype, dtype=dtype)
    if shape is not None:
        # numpy lets `shape` change the element count too, so allocate anew
        # (same dtype/device as `result`) rather than reshaping
        result = result.new_empty(shape)
    return result


def full(
    shape,
    fill_value: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of ``shape`` filled with ``fill_value`` (dtype defaults to its dtype).

    An array-valued ``fill_value`` is broadcast against ``shape``, as in numpy.
    """
    if dtype is None:
        dtype = fill_value.dtype
    # an int (or any other non-sequence) shape means a 1-D result
    if not isinstance(shape, (tuple, list)):
        shape = (shape,)
    if fill_value.ndim > 0:
        # torch.full only takes a scalar fill value; clone() materializes the
        # broadcast view into a fresh tensor
        return torch.broadcast_to(fill_value.to(dtype), shape).clone()
    return torch.full(shape, fill_value, dtype=dtype)


def full_like(
    a: ArrayLike,
    fill_value,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor like ``a`` filled with ``fill_value``; ``shape`` overrides its shape."""
    # XXX: fill_value broadcasts
    result = torch.full_like(a, fill_value, dtype=dtype)
    if shape is not None:
        # numpy lets `shape` change the element count too
        result = result.new_full(shape, fill_value)
    return result


def ones(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of ones; dtype defaults to the module's default float."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.ones(shape, dtype=dtype)


def ones_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor of ones like ``a``; ``shape`` overrides its shape."""
    result = torch.ones_like(a, dtype=dtype)
    if shape is not None:
        # numpy lets `shape` change the element count too
        result = result.new_ones(shape)
    return result


def zeros(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Tensor of zeros; dtype defaults to the module's default float."""
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.zeros(shape, dtype=dtype)


def zeros_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Tensor of zeros like ``a``; ``shape`` overrides its shape."""
    result = torch.zeros_like(a, dtype=dtype)
    if shape is not None:
        # numpy lets `shape` change the element count too
        result = result.new_zeros(shape)
    return result


# ### cov & corrcoef ###


def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
    """Prepare inputs for cov and corrcoef.

    Returns a tensor laid out the way torch.cov / torch.corrcoef expect, i.e.
    variables along the rows. As in numpy, ``rowvar=False`` transposes 2-D
    inputs, and ``y`` (if given) contributes extra rows stacked below ``x``.
    """

    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
    def _as_rows(t):
        # make sure `t` is at least 2D, then put variables along the rows
        ndim_extra = 2 - t.ndim
        if ndim_extra > 0:
            t = t.view((1,) * ndim_extra + t.shape)
        if not rowvar and t.shape[0] != 1:
            t = t.mT
        return t

    if y_tensor is None:
        # torch.cov/corrcoef have no rowvar switch, so honour rowvar=False by
        # transposing here. 0-D/1-D inputs are a single variable for torch
        # anyway; >2-D inputs are left for torch to reject.
        if not rowvar and x_tensor.ndim == 2 and x_tensor.shape[0] != 1:
            x_tensor = x_tensor.mT
        return x_tensor

    x_tensor = _as_rows(x_tensor).clone()
    y_tensor = _as_rows(y_tensor).clone()
    return _concatenate((x_tensor, y_tensor), axis=0)


def corrcoef(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=None,
    ddof=None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """numpy.corrcoef via ``torch.corrcoef``; ``bias``/``ddof`` are not supported."""
    if bias is not None or ddof is not None:
        # deprecated in NumPy
        raise NotImplementedError
    xy_tensor = _xy_helper_corrcoef(x, y, rowvar)

    is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu
    if is_half:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    xy_tensor = _util.cast_if_needed(xy_tensor, dtype)
    result = torch.corrcoef(xy_tensor)

    if is_half:
        result = result.to(torch.float16)

    return result


def cov(
    m: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=False,
    ddof=None,
    fweights: Optional[ArrayLike] = None,
    aweights: Optional[ArrayLike] = None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """numpy.cov via ``torch.cov``; ``ddof`` defaults to 1, or 0 when ``bias``."""
    m = _xy_helper_corrcoef(m, y, rowvar)

    if ddof is None:
        ddof = 1 if bias == 0 else 0

    is_half = (m.dtype == torch.float16) and m.is_cpu
    if is_half:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    m = _util.cast_if_needed(m, dtype)
    result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights)

    if is_half:
        result = result.to(torch.float16)

    return result


def _conv_corr_impl(a, v, mode):
    """1-D cross-correlation of ``a`` with ``v`` via ``conv1d``.

    ``mode`` 'full' becomes explicit padding of ``len(v) - 1``; other modes are
    passed through as the conv1d padding string.
    """
    dt = _dtypes_impl.result_type_impl(a, v)
    a = _util.cast_if_needed(a, dt)
    v = _util.cast_if_needed(v, dt)

    padding = v.shape[0] - 1 if mode == "full" else mode

    if padding == "same" and v.shape[0] % 2 == 0:
        # UserWarning: Using padding='same' with even kernel lengths and odd
        # dilation may require a zero-padded copy of the input be created
        # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.)
        raise NotImplementedError("mode='same' and even-length weights")

    # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
    aa = a[None, :]
    vv = v[None, None, :]

    result = torch.nn.functional.conv1d(aa, vv, padding=padding)

    # torch returns a 2D result, numpy returns a 1D array
    return result[0, :]


def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
    """numpy.convolve (1-D discrete linear convolution)."""
    # NumPy: if v is longer than a, the arrays are swapped before computation
    if a.shape[0] < v.shape[0]:
        a, v = v, a

    # flip the weights since numpy does and torch does not
    v = torch.flip(v, (0,))

    return _conv_corr_impl(a, v, mode)


def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
    """numpy.correlate (1-D cross-correlation, conjugating ``v``)."""
    v = torch.conj_physical(v)
    return _conv_corr_impl(a, v, mode)


# ### logic & element selection ###


def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
    """numpy.bincount; ``x`` is cast to the default int dtype with 'safe' casting."""
    if x.numel() == 0:
        # edge case allowed by numpy
        x = x.new_empty(0, dtype=int)

    int_dtype = _dtypes_impl.default_dtypes().int_dtype
    (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe")

    return torch.bincount(x, weights, minlength)


def where(
    condition: ArrayLike,
    x: Optional[ArrayLikeOrScalar] = None,
    y: Optional[ArrayLikeOrScalar] = None,
    /,
):
    """numpy.where: select from ``x``/``y`` by ``condition``, or index where true."""
    if (x is None) != (y is None):
        raise ValueError("either both or neither of x and y should be given")

    if condition.dtype != torch.bool:
        condition = condition.to(torch.bool)

    if x is None and y is None:
        result = torch.where(condition)
    else:
        result = torch.where(condition, x, y)
    return result


# ###### module-level queries of object properties


def ndim(a: ArrayLike):
    """Number of dimensions of ``a``."""
    return a.ndim


def shape(a: ArrayLike):
    """Shape of ``a`` as a plain tuple."""
    return tuple(a.shape)


def size(a: ArrayLike, axis=None):
    """Total number of elements, or the length along ``axis``."""
    if axis is None:
        return a.numel()
    else:
        return a.shape[axis]


# ###### shape manipulations and indexing


def expand_dims(a: ArrayLike, axis):
    """Insert length-1 axes at ``axis`` (returns a view)."""
    shape = _util.expand_shape(a.shape, axis)
    return a.view(shape)  # never copies


def flip(m: ArrayLike, axis=None):
    """Reverse ``m`` along ``axis`` (all axes when None)."""
    # XXX: semantic difference: np.flip returns a view, torch.flip copies
    if axis is None:
        axis = tuple(range(m.ndim))
    else:
        axis = _util.normalize_axis_tuple(axis, m.ndim)
    return torch.flip(m, axis)


def flipud(m: ArrayLike):
    """Reverse along axis 0 (``torch.flipud``)."""
    return torch.flipud(m)


def fliplr(m: ArrayLike):
    """Reverse along axis 1 (``torch.fliplr``)."""
    return torch.fliplr(m)


def rot90(m: ArrayLike, k=1, axes=(0, 1)):
    """Rotate by 90 degrees ``k`` times in the plane given by ``axes``."""
    axes = _util.normalize_axis_tuple(axes, m.ndim)
    return torch.rot90(m, k, axes)


# ### broadcasting and indices ###


def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
    """Broadcast ``array`` to ``shape`` (``torch.broadcast_to``)."""
    return torch.broadcast_to(array, size=shape)


# This is a function from tuples to tuples, so we just reuse it
from torch import broadcast_shapes


def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
    """Broadcast all inputs against each other (``torch.broadcast_tensors``)."""
    return torch.broadcast_tensors(*args)


def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
    """numpy.meshgrid; returns a list of tensors.

    Raises:
        ValueError: if ``indexing`` is not 'xy' or 'ij'.
    """
    ndim = len(xi)

    if indexing not in ["xy", "ij"]:
        raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")

    s0 = (1,) * ndim
    output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)]

    if indexing == "xy" and ndim > 1:
        # switch first and second axis
        output[0] = output[0].reshape((1, -1) + s0[2:])
        output[1] = output[1].reshape((-1, 1) + s0[2:])

    if not sparse:
        # Return the full N-D matrix (not only the 1-D vector)
        output = torch.broadcast_tensors(*output)

    if copy:
        output = [x.clone() for x in output]

    return list(output)  # match numpy, return a list


def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False):
    """numpy.indices: grid index arrays (a tuple of tensors when ``sparse``)."""
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
    dimensions = tuple(dimensions)
    N = len(dimensions)
    shape = (1,) * N
    if sparse:
        res = ()
    else:
        res = torch.empty((N,) + dimensions, dtype=dtype)
    for i, dim in enumerate(dimensions):
        idx = torch.arange(dim, dtype=dtype).reshape(
            shape[:i] + (dim,) + shape[i + 1 :]
        )
        if sparse:
            res = res + (idx,)
        else:
            res[i] = idx
    return res


# ### tri*-something ###


def tril(m: ArrayLike, k=0):
    """Lower triangle of ``m`` on and below diagonal ``k``."""
    return torch.tril(m, k)


def triu(m: ArrayLike, k=0):
    """Upper triangle of ``m`` on and above diagonal ``k``."""
    return torch.triu(m, k)


def tril_indices(n, k=0, m=None):
    """Indices of the lower triangle of an (n, m) array, as a 2 x N tensor."""
    if m is None:
        m = n
    return torch.tril_indices(n, m, offset=k)


def triu_indices(n, k=0, m=None):
    """Indices of the upper triangle of an (n, m) array, as a 2 x N tensor."""
    if m is None:
        m = n
    return torch.triu_indices(n, m, offset=k)


def tril_indices_from(arr: ArrayLike, k=0):
    """tril_indices for the shape of the 2-D ``arr``."""
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k)


def triu_indices_from(arr: ArrayLike, k=0):
    """triu_indices for the shape of the 2-D ``arr``."""
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k)


def tri(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    """(N, M) array with ones at and below diagonal ``k``, zeros elsewhere.

    ``dtype`` defaults to the module's default float, like ``ones``/``zeros``
    (numpy's default is float).
    """
    if dtype is None:
        dtype = _dtypes_impl.default_dtypes().float_dtype
    if M is None:
        M = N
    tensor = torch.ones((N, M), dtype=dtype)
    return torch.tril(tensor, diagonal=k)


# ### equality, equivalence, allclose ###


def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
    """Elementwise closeness test after casting both inputs to their result type."""
    dtype = _dtypes_impl.result_type_impl(a, b)
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)
    return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
827def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False): 828 dtype = _dtypes_impl.result_type_impl(a, b) 829 a = _util.cast_if_needed(a, dtype) 830 b = _util.cast_if_needed(b, dtype) 831 return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) 832 833 834def _tensor_equal(a1, a2, equal_nan=False): 835 # Implementation of array_equal/array_equiv. 836 if a1.shape != a2.shape: 837 return False 838 cond = a1 == a2 839 if equal_nan: 840 cond = cond | (torch.isnan(a1) & torch.isnan(a2)) 841 return cond.all().item() 842 843 844def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False): 845 return _tensor_equal(a1, a2, equal_nan=equal_nan) 846 847 848def array_equiv(a1: ArrayLike, a2: ArrayLike): 849 # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not 850 try: 851 a1_t, a2_t = torch.broadcast_tensors(a1, a2) 852 except RuntimeError: 853 # failed to broadcast => not equivalent 854 return False 855 return _tensor_equal(a1_t, a2_t) 856 857 858def nan_to_num( 859 x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None 860): 861 # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble' 862 if x.is_complex(): 863 re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf) 864 im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf) 865 return re + 1j * im 866 else: 867 return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) 868 869 870# ### put/take_along_axis ### 871 872 873def take( 874 a: ArrayLike, 875 indices: ArrayLike, 876 axis=None, 877 out: Optional[OutArray] = None, 878 mode: NotImplementedType = "raise", 879): 880 (a,), axis = _util.axis_none_flatten(a, axis=axis) 881 axis = _util.normalize_axis_index(axis, a.ndim) 882 idx = (slice(None),) * axis + (indices, ...) 
883 result = a[idx] 884 return result 885 886 887def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): 888 (arr,), axis = _util.axis_none_flatten(arr, axis=axis) 889 axis = _util.normalize_axis_index(axis, arr.ndim) 890 return torch.take_along_dim(arr, indices, axis) 891 892 893def put( 894 a: NDArray, 895 indices: ArrayLike, 896 values: ArrayLike, 897 mode: NotImplementedType = "raise", 898): 899 v = values.type(a.dtype) 900 # If indices is larger than v, expand v to at least the size of indices. Any 901 # unnecessary trailing elements are then trimmed. 902 if indices.numel() > v.numel(): 903 ratio = (indices.numel() + v.numel() - 1) // v.numel() 904 v = v.unsqueeze(0).expand((ratio,) + v.shape) 905 # Trim unnecessary elements, regardless if v was expanded or not. Note 906 # np.put() trims v to match indices by default too. 907 if indices.numel() < v.numel(): 908 v = v.flatten() 909 v = v[: indices.numel()] 910 a.put_(indices, v) 911 return None 912 913 914def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): 915 (arr,), axis = _util.axis_none_flatten(arr, axis=axis) 916 axis = _util.normalize_axis_index(axis, arr.ndim) 917 918 indices, values = torch.broadcast_tensors(indices, values) 919 values = _util.cast_if_needed(values, arr.dtype) 920 result = torch.scatter(arr, axis, indices, values) 921 arr.copy_(result.reshape(arr.shape)) 922 return None 923 924 925def choose( 926 a: ArrayLike, 927 choices: Sequence[ArrayLike], 928 out: Optional[OutArray] = None, 929 mode: NotImplementedType = "raise", 930): 931 # First, broadcast elements of `choices` 932 choices = torch.stack(torch.broadcast_tensors(*choices)) 933 934 # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`: 935 # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939) 936 idx_list = [ 937 torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1)) 938 for i, dim in enumerate(choices.shape) 939 ] 940 941 
idx_list[0] = a 942 return choices[idx_list].squeeze(0) 943 944 945# ### unique et al. ### 946 947 948def unique( 949 ar: ArrayLike, 950 return_index: NotImplementedType = False, 951 return_inverse=False, 952 return_counts=False, 953 axis=None, 954 *, 955 equal_nan: NotImplementedType = True, 956): 957 (ar,), axis = _util.axis_none_flatten(ar, axis=axis) 958 axis = _util.normalize_axis_index(axis, ar.ndim) 959 960 result = torch.unique( 961 ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis 962 ) 963 964 return result 965 966 967def nonzero(a: ArrayLike): 968 return torch.nonzero(a, as_tuple=True) 969 970 971def argwhere(a: ArrayLike): 972 return torch.argwhere(a) 973 974 975def flatnonzero(a: ArrayLike): 976 return torch.flatten(a).nonzero(as_tuple=True)[0] 977 978 979def clip( 980 a: ArrayLike, 981 min: Optional[ArrayLike] = None, 982 max: Optional[ArrayLike] = None, 983 out: Optional[OutArray] = None, 984): 985 return torch.clamp(a, min, max) 986 987 988def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None): 989 return torch.repeat_interleave(a, repeats, axis) 990 991 992def tile(A: ArrayLike, reps): 993 if isinstance(reps, int): 994 reps = (reps,) 995 return torch.tile(A, reps) 996 997 998def resize(a: ArrayLike, new_shape=None): 999 # implementation vendored from 1000 # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497 1001 if new_shape is None: 1002 return a 1003 1004 if isinstance(new_shape, int): 1005 new_shape = (new_shape,) 1006 1007 a = a.flatten() 1008 1009 new_size = 1 1010 for dim_length in new_shape: 1011 new_size *= dim_length 1012 if dim_length < 0: 1013 raise ValueError("all elements of `new_shape` must be non-negative") 1014 1015 if a.numel() == 0 or new_size == 0: 1016 # First case must zero fill. The second would have repeats == 0. 
1017 return torch.zeros(new_shape, dtype=a.dtype) 1018 1019 repeats = -(-new_size // a.numel()) # ceil division 1020 a = concatenate((a,) * repeats)[:new_size] 1021 1022 return reshape(a, new_shape) 1023 1024 1025# ### diag et al. ### 1026 1027 1028def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): 1029 axis1 = _util.normalize_axis_index(axis1, a.ndim) 1030 axis2 = _util.normalize_axis_index(axis2, a.ndim) 1031 return torch.diagonal(a, offset, axis1, axis2) 1032 1033 1034def trace( 1035 a: ArrayLike, 1036 offset=0, 1037 axis1=0, 1038 axis2=1, 1039 dtype: Optional[DTypeLike] = None, 1040 out: Optional[OutArray] = None, 1041): 1042 result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype) 1043 return result 1044 1045 1046def eye( 1047 N, 1048 M=None, 1049 k=0, 1050 dtype: Optional[DTypeLike] = None, 1051 order: NotImplementedType = "C", 1052 *, 1053 like: NotImplementedType = None, 1054): 1055 if dtype is None: 1056 dtype = _dtypes_impl.default_dtypes().float_dtype 1057 if M is None: 1058 M = N 1059 z = torch.zeros(N, M, dtype=dtype) 1060 z.diagonal(k).fill_(1) 1061 return z 1062 1063 1064def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None): 1065 return torch.eye(n, dtype=dtype) 1066 1067 1068def diag(v: ArrayLike, k=0): 1069 return torch.diag(v, k) 1070 1071 1072def diagflat(v: ArrayLike, k=0): 1073 return torch.diagflat(v, k) 1074 1075 1076def diag_indices(n, ndim=2): 1077 idx = torch.arange(n) 1078 return (idx,) * ndim 1079 1080 1081def diag_indices_from(arr: ArrayLike): 1082 if not arr.ndim >= 2: 1083 raise ValueError("input array must be at least 2-d") 1084 # For more than d=2, the strided formula is only valid for arrays with 1085 # all dimensions equal, so we check first. 
1086 s = arr.shape 1087 if s[1:] != s[:-1]: 1088 raise ValueError("All dimensions of input must be of equal length") 1089 return diag_indices(s[0], arr.ndim) 1090 1091 1092def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): 1093 if a.ndim < 2: 1094 raise ValueError("array must be at least 2-d") 1095 if val.numel() == 0 and not wrap: 1096 a.fill_diagonal_(val) 1097 return a 1098 1099 if val.ndim == 0: 1100 val = val.unsqueeze(0) 1101 1102 # torch.Tensor.fill_diagonal_ only accepts scalars 1103 # If the size of val is too large, then val is trimmed 1104 if a.ndim == 2: 1105 tall = a.shape[0] > a.shape[1] 1106 # wrap does nothing for wide matrices... 1107 if not wrap or not tall: 1108 # Never wraps 1109 diag = a.diagonal() 1110 diag.copy_(val[: diag.numel()]) 1111 else: 1112 # wraps and tall... leaving one empty line between diagonals?! 1113 max_, min_ = a.shape 1114 idx = torch.arange(max_ - max_ // (min_ + 1)) 1115 mod = idx % min_ 1116 div = idx // min_ 1117 a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()] 1118 else: 1119 idx = diag_indices_from(a) 1120 # a.shape = (n, n, ..., n) 1121 a[idx] = val[: a.shape[0]] 1122 1123 return a 1124 1125 1126def vdot(a: ArrayLike, b: ArrayLike, /): 1127 # 1. torch only accepts 1D arrays, numpy flattens 1128 # 2. torch requires matching dtype, while numpy casts (?) 
1129 t_a, t_b = torch.atleast_1d(a, b) 1130 if t_a.ndim > 1: 1131 t_a = t_a.flatten() 1132 if t_b.ndim > 1: 1133 t_b = t_b.flatten() 1134 1135 dtype = _dtypes_impl.result_type_impl(t_a, t_b) 1136 is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu) 1137 is_bool = dtype == torch.bool 1138 1139 # work around torch's "dot" not implemented for 'Half', 'Bool' 1140 if is_half: 1141 dtype = torch.float32 1142 elif is_bool: 1143 dtype = torch.uint8 1144 1145 t_a = _util.cast_if_needed(t_a, dtype) 1146 t_b = _util.cast_if_needed(t_b, dtype) 1147 1148 result = torch.vdot(t_a, t_b) 1149 1150 if is_half: 1151 result = result.to(torch.float16) 1152 elif is_bool: 1153 result = result.to(torch.bool) 1154 1155 return result 1156 1157 1158def tensordot(a: ArrayLike, b: ArrayLike, axes=2): 1159 if isinstance(axes, (list, tuple)): 1160 axes = [[ax] if isinstance(ax, int) else ax for ax in axes] 1161 1162 target_dtype = _dtypes_impl.result_type_impl(a, b) 1163 a = _util.cast_if_needed(a, target_dtype) 1164 b = _util.cast_if_needed(b, target_dtype) 1165 1166 return torch.tensordot(a, b, dims=axes) 1167 1168 1169def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): 1170 dtype = _dtypes_impl.result_type_impl(a, b) 1171 is_bool = dtype == torch.bool 1172 if is_bool: 1173 dtype = torch.uint8 1174 1175 a = _util.cast_if_needed(a, dtype) 1176 b = _util.cast_if_needed(b, dtype) 1177 1178 if a.ndim == 0 or b.ndim == 0: 1179 result = a * b 1180 else: 1181 result = torch.matmul(a, b) 1182 1183 if is_bool: 1184 result = result.to(torch.bool) 1185 1186 return result 1187 1188 1189def inner(a: ArrayLike, b: ArrayLike, /): 1190 dtype = _dtypes_impl.result_type_impl(a, b) 1191 is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu) 1192 is_bool = dtype == torch.bool 1193 1194 if is_half: 1195 # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" 1196 dtype = torch.float32 1197 elif is_bool: 1198 dtype = torch.uint8 1199 1200 a = _util.cast_if_needed(a, 
dtype) 1201 b = _util.cast_if_needed(b, dtype) 1202 1203 result = torch.inner(a, b) 1204 1205 if is_half: 1206 result = result.to(torch.float16) 1207 elif is_bool: 1208 result = result.to(torch.bool) 1209 return result 1210 1211 1212def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): 1213 return torch.outer(a, b) 1214 1215 1216def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None): 1217 # implementation vendored from 1218 # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685 1219 if axis is not None: 1220 axisa, axisb, axisc = (axis,) * 3 1221 1222 # Check axisa and axisb are within bounds 1223 axisa = _util.normalize_axis_index(axisa, a.ndim) 1224 axisb = _util.normalize_axis_index(axisb, b.ndim) 1225 1226 # Move working axis to the end of the shape 1227 a = torch.moveaxis(a, axisa, -1) 1228 b = torch.moveaxis(b, axisb, -1) 1229 msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)" 1230 if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): 1231 raise ValueError(msg) 1232 1233 # Create the output array 1234 shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape) 1235 if a.shape[-1] == 3 or b.shape[-1] == 3: 1236 shape += (3,) 1237 # Check axisc is within bounds 1238 axisc = _util.normalize_axis_index(axisc, len(shape)) 1239 dtype = _dtypes_impl.result_type_impl(a, b) 1240 cp = torch.empty(shape, dtype=dtype) 1241 1242 # recast arrays as dtype 1243 a = _util.cast_if_needed(a, dtype) 1244 b = _util.cast_if_needed(b, dtype) 1245 1246 # create local aliases for readability 1247 a0 = a[..., 0] 1248 a1 = a[..., 1] 1249 if a.shape[-1] == 3: 1250 a2 = a[..., 2] 1251 b0 = b[..., 0] 1252 b1 = b[..., 1] 1253 if b.shape[-1] == 3: 1254 b2 = b[..., 2] 1255 if cp.ndim != 0 and cp.shape[-1] == 3: 1256 cp0 = cp[..., 0] 1257 cp1 = cp[..., 1] 1258 cp2 = cp[..., 2] 1259 1260 if a.shape[-1] == 2: 1261 if b.shape[-1] == 2: 1262 # a0 * b1 - a1 * b0 1263 cp[...] 
= a0 * b1 - a1 * b0 1264 return cp 1265 else: 1266 assert b.shape[-1] == 3 1267 # cp0 = a1 * b2 - 0 (a2 = 0) 1268 # cp1 = 0 - a0 * b2 (a2 = 0) 1269 # cp2 = a0 * b1 - a1 * b0 1270 cp0[...] = a1 * b2 1271 cp1[...] = -a0 * b2 1272 cp2[...] = a0 * b1 - a1 * b0 1273 else: 1274 assert a.shape[-1] == 3 1275 if b.shape[-1] == 3: 1276 cp0[...] = a1 * b2 - a2 * b1 1277 cp1[...] = a2 * b0 - a0 * b2 1278 cp2[...] = a0 * b1 - a1 * b0 1279 else: 1280 assert b.shape[-1] == 2 1281 cp0[...] = -a2 * b1 1282 cp1[...] = a2 * b0 1283 cp2[...] = a0 * b1 - a1 * b0 1284 1285 return torch.moveaxis(cp, -1, axisc) 1286 1287 1288def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): 1289 # Have to manually normalize *operands and **kwargs, following the NumPy signature 1290 # We have a local import to avoid poluting the global space, as it will be then 1291 # exported in funcs.py 1292 from ._ndarray import ndarray 1293 from ._normalizations import ( 1294 maybe_copy_to, 1295 normalize_array_like, 1296 normalize_casting, 1297 normalize_dtype, 1298 wrap_tensors, 1299 ) 1300 1301 dtype = normalize_dtype(dtype) 1302 casting = normalize_casting(casting) 1303 if out is not None and not isinstance(out, ndarray): 1304 raise TypeError("'out' must be an array") 1305 if order != "K": 1306 raise NotImplementedError("'order' parameter is not supported.") 1307 1308 # parse arrays and normalize them 1309 sublist_format = not isinstance(operands[0], str) 1310 if sublist_format: 1311 # op, str, op, str ... [sublistout] format: normalize every other argument 1312 1313 # - if sublistout is not given, the length of operands is even, and we pick 1314 # odd-numbered elements, which are arrays. 1315 # - if sublistout is given, the length of operands is odd, we peel off 1316 # the last one, and pick odd-numbered elements, which are arrays. 1317 # Without [:-1], we would have picked sublistout, too. 
1318 array_operands = operands[:-1][::2] 1319 else: 1320 # ("ij->", arrays) format 1321 subscripts, array_operands = operands[0], operands[1:] 1322 1323 tensors = [normalize_array_like(op) for op in array_operands] 1324 target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype 1325 1326 # work around 'bmm' not implemented for 'Half' etc 1327 is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors) 1328 if is_half: 1329 target_dtype = torch.float32 1330 1331 is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] 1332 if is_short_int: 1333 target_dtype = torch.int64 1334 1335 tensors = _util.typecast_tensors(tensors, target_dtype, casting) 1336 1337 from torch.backends import opt_einsum 1338 1339 try: 1340 # set the global state to handle the optimize=... argument, restore on exit 1341 if opt_einsum.is_available(): 1342 old_strategy = torch.backends.opt_einsum.strategy 1343 old_enabled = torch.backends.opt_einsum.enabled 1344 1345 # torch.einsum calls opt_einsum.contract_path, which runs into 1346 # https://github.com/dgasmith/opt_einsum/issues/219 1347 # for strategy={True, False} 1348 if optimize is True: 1349 optimize = "auto" 1350 elif optimize is False: 1351 torch.backends.opt_einsum.enabled = False 1352 1353 torch.backends.opt_einsum.strategy = optimize 1354 1355 if sublist_format: 1356 # recombine operands 1357 sublists = operands[1::2] 1358 has_sublistout = len(operands) % 2 == 1 1359 if has_sublistout: 1360 sublistout = operands[-1] 1361 operands = list(itertools.chain.from_iterable(zip(tensors, sublists))) 1362 if has_sublistout: 1363 operands.append(sublistout) 1364 1365 result = torch.einsum(*operands) 1366 else: 1367 result = torch.einsum(subscripts, *tensors) 1368 1369 finally: 1370 if opt_einsum.is_available(): 1371 torch.backends.opt_einsum.strategy = old_strategy 1372 torch.backends.opt_einsum.enabled = old_enabled 1373 1374 result = maybe_copy_to(out, result) 1375 return 
wrap_tensors(result) 1376 1377 1378# ### sort and partition ### 1379 1380 1381def _sort_helper(tensor, axis, kind, order): 1382 if tensor.dtype.is_complex: 1383 raise NotImplementedError(f"sorting {tensor.dtype} is not supported") 1384 (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) 1385 axis = _util.normalize_axis_index(axis, tensor.ndim) 1386 1387 stable = kind == "stable" 1388 1389 return tensor, axis, stable 1390 1391 1392def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): 1393 # `order` keyword arg is only relevant for structured dtypes; so not supported here. 1394 a, axis, stable = _sort_helper(a, axis, kind, order) 1395 result = torch.sort(a, dim=axis, stable=stable) 1396 return result.values 1397 1398 1399def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): 1400 a, axis, stable = _sort_helper(a, axis, kind, order) 1401 return torch.argsort(a, dim=axis, stable=stable) 1402 1403 1404def searchsorted( 1405 a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None 1406): 1407 if a.dtype.is_complex: 1408 raise NotImplementedError(f"searchsorted with dtype={a.dtype}") 1409 1410 return torch.searchsorted(a, v, side=side, sorter=sorter) 1411 1412 1413# ### swap/move/roll axis ### 1414 1415 1416def moveaxis(a: ArrayLike, source, destination): 1417 source = _util.normalize_axis_tuple(source, a.ndim, "source") 1418 destination = _util.normalize_axis_tuple(destination, a.ndim, "destination") 1419 return torch.moveaxis(a, source, destination) 1420 1421 1422def swapaxes(a: ArrayLike, axis1, axis2): 1423 axis1 = _util.normalize_axis_index(axis1, a.ndim) 1424 axis2 = _util.normalize_axis_index(axis2, a.ndim) 1425 return torch.swapaxes(a, axis1, axis2) 1426 1427 1428def rollaxis(a: ArrayLike, axis, start=0): 1429 # Straight vendor from: 1430 # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259 1431 # 1432 # Also note this function in NumPy is mostly retained for 
backwards compat 1433 # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing) 1434 # so let's not touch it unless hard pressed. 1435 n = a.ndim 1436 axis = _util.normalize_axis_index(axis, n) 1437 if start < 0: 1438 start += n 1439 msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" 1440 if not (0 <= start < n + 1): 1441 raise _util.AxisError(msg % ("start", -n, "start", n + 1, start)) 1442 if axis < start: 1443 # it's been removed 1444 start -= 1 1445 if axis == start: 1446 # numpy returns a view, here we try returning the tensor itself 1447 # return tensor[...] 1448 return a 1449 axes = list(range(0, n)) 1450 axes.remove(axis) 1451 axes.insert(start, axis) 1452 return a.view(axes) 1453 1454 1455def roll(a: ArrayLike, shift, axis=None): 1456 if axis is not None: 1457 axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) 1458 if not isinstance(shift, tuple): 1459 shift = (shift,) * len(axis) 1460 return torch.roll(a, shift, axis) 1461 1462 1463# ### shape manipulations ### 1464 1465 1466def squeeze(a: ArrayLike, axis=None): 1467 if axis == (): 1468 result = a 1469 elif axis is None: 1470 result = a.squeeze() 1471 else: 1472 if isinstance(axis, tuple): 1473 result = a 1474 for ax in axis: 1475 result = a.squeeze(ax) 1476 else: 1477 result = a.squeeze(axis) 1478 return result 1479 1480 1481def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"): 1482 # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh) 1483 newshape = newshape[0] if len(newshape) == 1 else newshape 1484 return a.reshape(newshape) 1485 1486 1487# NB: cannot use torch.reshape(a, newshape) above, because of 1488# (Pdb) torch.reshape(torch.as_tensor([1]), 1) 1489# *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int 1490 1491 1492def transpose(a: ArrayLike, axes=None): 1493 # numpy allows both .transpose(sh) and .transpose(*sh) 1494 # also older code uses axes being a list 1495 
if axes in [(), None, (None,)]: 1496 axes = tuple(reversed(range(a.ndim))) 1497 elif len(axes) == 1: 1498 axes = axes[0] 1499 return a.permute(axes) 1500 1501 1502def ravel(a: ArrayLike, order: NotImplementedType = "C"): 1503 return torch.flatten(a) 1504 1505 1506def diff( 1507 a: ArrayLike, 1508 n=1, 1509 axis=-1, 1510 prepend: Optional[ArrayLike] = None, 1511 append: Optional[ArrayLike] = None, 1512): 1513 axis = _util.normalize_axis_index(axis, a.ndim) 1514 1515 if n < 0: 1516 raise ValueError(f"order must be non-negative but got {n}") 1517 1518 if n == 0: 1519 # match numpy and return the input immediately 1520 return a 1521 1522 if prepend is not None: 1523 shape = list(a.shape) 1524 shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1 1525 prepend = torch.broadcast_to(prepend, shape) 1526 1527 if append is not None: 1528 shape = list(a.shape) 1529 shape[axis] = append.shape[axis] if append.ndim > 0 else 1 1530 append = torch.broadcast_to(append, shape) 1531 1532 return torch.diff(a, n, axis=axis, prepend=prepend, append=append) 1533 1534 1535# ### math functions ### 1536 1537 1538def angle(z: ArrayLike, deg=False): 1539 result = torch.angle(z) 1540 if deg: 1541 result = result * (180 / torch.pi) 1542 return result 1543 1544 1545def sinc(x: ArrayLike): 1546 return torch.sinc(x) 1547 1548 1549# NB: have to normalize *varargs manually 1550def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1): 1551 N = f.ndim # number of dimensions 1552 1553 varargs = _util.ndarrays_to_tensors(varargs) 1554 1555 if axis is None: 1556 axes = tuple(range(N)) 1557 else: 1558 axes = _util.normalize_axis_tuple(axis, N) 1559 1560 len_axes = len(axes) 1561 n = len(varargs) 1562 if n == 0: 1563 # no spacing argument - use 1 in all axes 1564 dx = [1.0] * len_axes 1565 elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0): 1566 # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0) 1567 dx = varargs * len_axes 1568 elif n == len_axes: 
1569 # scalar or 1d array for each axis 1570 dx = list(varargs) 1571 for i, distances in enumerate(dx): 1572 distances = torch.as_tensor(distances) 1573 if distances.ndim == 0: 1574 continue 1575 elif distances.ndim != 1: 1576 raise ValueError("distances must be either scalars or 1d") 1577 if len(distances) != f.shape[axes[i]]: 1578 raise ValueError( 1579 "when 1d, distances must match " 1580 "the length of the corresponding dimension" 1581 ) 1582 if not (distances.dtype.is_floating_point or distances.dtype.is_complex): 1583 distances = distances.double() 1584 1585 diffx = torch.diff(distances) 1586 # if distances are constant reduce to the scalar case 1587 # since it brings a consistent speedup 1588 if (diffx == diffx[0]).all(): 1589 diffx = diffx[0] 1590 dx[i] = diffx 1591 else: 1592 raise TypeError("invalid number of arguments") 1593 1594 if edge_order > 2: 1595 raise ValueError("'edge_order' greater than 2 not supported") 1596 1597 # use central differences on interior and one-sided differences on the 1598 # endpoints. This preserves second order-accuracy over the full domain. 1599 1600 outvals = [] 1601 1602 # create slice objects --- initially all are [:, :, ..., :] 1603 slice1 = [slice(None)] * N 1604 slice2 = [slice(None)] * N 1605 slice3 = [slice(None)] * N 1606 slice4 = [slice(None)] * N 1607 1608 otype = f.dtype 1609 if _dtypes_impl.python_type_for_torch(otype) in (int, bool): 1610 # Convert to floating point. 1611 # First check if f is a numpy integer type; if so, convert f to float64 1612 # to avoid modular arithmetic when computing the changes in f. 1613 f = f.double() 1614 otype = torch.float64 1615 1616 for axis, ax_dx in zip(axes, dx): 1617 if f.shape[axis] < edge_order + 1: 1618 raise ValueError( 1619 "Shape of array too small to calculate a numerical gradient, " 1620 "at least (edge_order + 1) elements are required." 
1621 ) 1622 # result allocation 1623 out = torch.empty_like(f, dtype=otype) 1624 1625 # spacing for the current axis (NB: np.ndim(ax_dx) == 0) 1626 uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0 1627 1628 # Numerical differentiation: 2nd order interior 1629 slice1[axis] = slice(1, -1) 1630 slice2[axis] = slice(None, -2) 1631 slice3[axis] = slice(1, -1) 1632 slice4[axis] = slice(2, None) 1633 1634 if uniform_spacing: 1635 out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx) 1636 else: 1637 dx1 = ax_dx[0:-1] 1638 dx2 = ax_dx[1:] 1639 a = -(dx2) / (dx1 * (dx1 + dx2)) 1640 b = (dx2 - dx1) / (dx1 * dx2) 1641 c = dx1 / (dx2 * (dx1 + dx2)) 1642 # fix the shape for broadcasting 1643 shape = [1] * N 1644 shape[axis] = -1 1645 a = a.reshape(shape) 1646 b = b.reshape(shape) 1647 c = c.reshape(shape) 1648 # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] 1649 out[tuple(slice1)] = ( 1650 a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] 1651 ) 1652 1653 # Numerical differentiation: 1st order edges 1654 if edge_order == 1: 1655 slice1[axis] = 0 1656 slice2[axis] = 1 1657 slice3[axis] = 0 1658 dx_0 = ax_dx if uniform_spacing else ax_dx[0] 1659 # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0]) 1660 out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0 1661 1662 slice1[axis] = -1 1663 slice2[axis] = -1 1664 slice3[axis] = -2 1665 dx_n = ax_dx if uniform_spacing else ax_dx[-1] 1666 # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2]) 1667 out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n 1668 1669 # Numerical differentiation: 2nd order edges 1670 else: 1671 slice1[axis] = 0 1672 slice2[axis] = 0 1673 slice3[axis] = 1 1674 slice4[axis] = 2 1675 if uniform_spacing: 1676 a = -1.5 / ax_dx 1677 b = 2.0 / ax_dx 1678 c = -0.5 / ax_dx 1679 else: 1680 dx1 = ax_dx[0] 1681 dx2 = ax_dx[1] 1682 a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2)) 1683 b = (dx1 + dx2) / (dx1 * dx2) 
1684 c = -dx1 / (dx2 * (dx1 + dx2)) 1685 # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2] 1686 out[tuple(slice1)] = ( 1687 a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] 1688 ) 1689 1690 slice1[axis] = -1 1691 slice2[axis] = -3 1692 slice3[axis] = -2 1693 slice4[axis] = -1 1694 if uniform_spacing: 1695 a = 0.5 / ax_dx 1696 b = -2.0 / ax_dx 1697 c = 1.5 / ax_dx 1698 else: 1699 dx1 = ax_dx[-2] 1700 dx2 = ax_dx[-1] 1701 a = (dx2) / (dx1 * (dx1 + dx2)) 1702 b = -(dx2 + dx1) / (dx1 * dx2) 1703 c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2)) 1704 # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1] 1705 out[tuple(slice1)] = ( 1706 a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] 1707 ) 1708 1709 outvals.append(out) 1710 1711 # reset the slice object in this dimension to ":" 1712 slice1[axis] = slice(None) 1713 slice2[axis] = slice(None) 1714 slice3[axis] = slice(None) 1715 slice4[axis] = slice(None) 1716 1717 if len_axes == 1: 1718 return outvals[0] 1719 else: 1720 return outvals 1721 1722 1723# ### Type/shape etc queries ### 1724 1725 1726def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None): 1727 if a.is_floating_point(): 1728 result = torch.round(a, decimals=decimals) 1729 elif a.is_complex(): 1730 # RuntimeError: "round_cpu" not implemented for 'ComplexFloat' 1731 result = torch.complex( 1732 torch.round(a.real, decimals=decimals), 1733 torch.round(a.imag, decimals=decimals), 1734 ) 1735 else: 1736 # RuntimeError: "round_cpu" not implemented for 'int' 1737 result = a 1738 return result 1739 1740 1741around = round 1742round_ = round 1743 1744 1745def real_if_close(a: ArrayLike, tol=100): 1746 if not torch.is_complex(a): 1747 return a 1748 if tol > 1: 1749 # Undocumented in numpy: if tol < 1, it's an absolute tolerance! 
1750 # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon 1751 # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577 1752 tol = tol * torch.finfo(a.dtype).eps 1753 1754 mask = torch.abs(a.imag) < tol 1755 return a.real if mask.all() else a 1756 1757 1758def real(a: ArrayLike): 1759 return torch.real(a) 1760 1761 1762def imag(a: ArrayLike): 1763 if a.is_complex(): 1764 return a.imag 1765 return torch.zeros_like(a) 1766 1767 1768def iscomplex(x: ArrayLike): 1769 if torch.is_complex(x): 1770 return x.imag != 0 1771 return torch.zeros_like(x, dtype=torch.bool) 1772 1773 1774def isreal(x: ArrayLike): 1775 if torch.is_complex(x): 1776 return x.imag == 0 1777 return torch.ones_like(x, dtype=torch.bool) 1778 1779 1780def iscomplexobj(x: ArrayLike): 1781 return torch.is_complex(x) 1782 1783 1784def isrealobj(x: ArrayLike): 1785 return not torch.is_complex(x) 1786 1787 1788def isneginf(x: ArrayLike, out: Optional[OutArray] = None): 1789 return torch.isneginf(x) 1790 1791 1792def isposinf(x: ArrayLike, out: Optional[OutArray] = None): 1793 return torch.isposinf(x) 1794 1795 1796def i0(x: ArrayLike): 1797 return torch.special.i0(x) 1798 1799 1800def isscalar(a): 1801 # We need to use normalize_array_like, but we don't want to export it in funcs.py 1802 from ._normalizations import normalize_array_like 1803 1804 try: 1805 t = normalize_array_like(a) 1806 return t.numel() == 1 1807 except Exception: 1808 return False 1809 1810 1811# ### Filter windows ### 1812 1813 1814def hamming(M): 1815 dtype = _dtypes_impl.default_dtypes().float_dtype 1816 return torch.hamming_window(M, periodic=False, dtype=dtype) 1817 1818 1819def hanning(M): 1820 dtype = _dtypes_impl.default_dtypes().float_dtype 1821 return torch.hann_window(M, periodic=False, dtype=dtype) 1822 1823 1824def kaiser(M, beta): 1825 dtype = _dtypes_impl.default_dtypes().float_dtype 1826 return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype) 1827 1828 1829def 
blackman(M): 1830 dtype = _dtypes_impl.default_dtypes().float_dtype 1831 return torch.blackman_window(M, periodic=False, dtype=dtype) 1832 1833 1834def bartlett(M): 1835 dtype = _dtypes_impl.default_dtypes().float_dtype 1836 return torch.bartlett_window(M, periodic=False, dtype=dtype) 1837 1838 1839# ### Dtype routines ### 1840 1841# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666 1842 1843 1844array_type = [ 1845 [torch.float16, torch.float32, torch.float64], 1846 [None, torch.complex64, torch.complex128], 1847] 1848array_precision = { 1849 torch.float16: 0, 1850 torch.float32: 1, 1851 torch.float64: 2, 1852 torch.complex64: 1, 1853 torch.complex128: 2, 1854} 1855 1856 1857def common_type(*tensors: ArrayLike): 1858 is_complex = False 1859 precision = 0 1860 for a in tensors: 1861 t = a.dtype 1862 if iscomplexobj(a): 1863 is_complex = True 1864 if not (t.is_floating_point or t.is_complex): 1865 p = 2 # array_precision[_nx.double] 1866 else: 1867 p = array_precision.get(t, None) 1868 if p is None: 1869 raise TypeError("can't get common type for non-numeric array") 1870 precision = builtins.max(precision, p) 1871 if is_complex: 1872 return array_type[1][precision] 1873 else: 1874 return array_type[0][precision] 1875 1876 1877# ### histograms ### 1878 1879 1880def histogram( 1881 a: ArrayLike, 1882 bins: ArrayLike = 10, 1883 range=None, 1884 normed=None, 1885 weights: Optional[ArrayLike] = None, 1886 density=None, 1887): 1888 if normed is not None: 1889 raise ValueError("normed argument is deprecated, use density= instead") 1890 1891 if weights is not None and weights.dtype.is_complex: 1892 raise NotImplementedError("complex weights histogram.") 1893 1894 is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex) 1895 is_w_int = weights is None or not weights.dtype.is_floating_point 1896 if is_a_int: 1897 a = a.double() 1898 1899 if weights is not None: 1900 weights = _util.cast_if_needed(weights, a.dtype) 1901 1902 if 
isinstance(bins, torch.Tensor): 1903 if bins.ndim == 0: 1904 # bins was a single int 1905 bins = operator.index(bins) 1906 else: 1907 bins = _util.cast_if_needed(bins, a.dtype) 1908 1909 if range is None: 1910 h, b = torch.histogram(a, bins, weight=weights, density=bool(density)) 1911 else: 1912 h, b = torch.histogram( 1913 a, bins, range=range, weight=weights, density=bool(density) 1914 ) 1915 1916 if not density and is_w_int: 1917 h = h.long() 1918 if is_a_int: 1919 b = b.long() 1920 1921 return h, b 1922 1923 1924def histogram2d( 1925 x, 1926 y, 1927 bins=10, 1928 range: Optional[ArrayLike] = None, 1929 normed=None, 1930 weights: Optional[ArrayLike] = None, 1931 density=None, 1932): 1933 # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821 1934 if len(x) != len(y): 1935 raise ValueError("x and y must have the same length.") 1936 1937 try: 1938 N = len(bins) 1939 except TypeError: 1940 N = 1 1941 1942 if N != 1 and N != 2: 1943 bins = [bins, bins] 1944 1945 h, e = histogramdd((x, y), bins, range, normed, weights, density) 1946 1947 return h, e[0], e[1] 1948 1949 1950def histogramdd( 1951 sample, 1952 bins=10, 1953 range: Optional[ArrayLike] = None, 1954 normed=None, 1955 weights: Optional[ArrayLike] = None, 1956 density=None, 1957): 1958 # have to normalize manually because `sample` interpretation differs 1959 # for a list of lists and a 2D array 1960 if normed is not None: 1961 raise ValueError("normed argument is deprecated, use density= instead") 1962 1963 from ._normalizations import normalize_array_like, normalize_seq_array_like 1964 1965 if isinstance(sample, (list, tuple)): 1966 sample = normalize_array_like(sample).T 1967 else: 1968 sample = normalize_array_like(sample) 1969 1970 sample = torch.atleast_2d(sample) 1971 1972 if not (sample.dtype.is_floating_point or sample.dtype.is_complex): 1973 sample = sample.double() 1974 1975 # bins is either an int, or a sequence of ints or a sequence of arrays 1976 
bins_is_array = not ( 1977 isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins) 1978 ) 1979 if bins_is_array: 1980 bins = normalize_seq_array_like(bins) 1981 bins_dtypes = [b.dtype for b in bins] 1982 bins = [_util.cast_if_needed(b, sample.dtype) for b in bins] 1983 1984 if range is not None: 1985 range = range.flatten().tolist() 1986 1987 if weights is not None: 1988 # range=... is required : interleave min and max values per dimension 1989 mm = sample.aminmax(dim=0) 1990 range = torch.cat(mm).reshape(2, -1).T.flatten() 1991 range = tuple(range.tolist()) 1992 weights = _util.cast_if_needed(weights, sample.dtype) 1993 w_kwd = {"weight": weights} 1994 else: 1995 w_kwd = {} 1996 1997 h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd) 1998 1999 if bins_is_array: 2000 b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)] 2001 2002 return h, b 2003 2004 2005# ### odds and ends 2006 2007 2008def min_scalar_type(a: ArrayLike, /): 2009 # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288 2010 2011 from ._dtypes import DType 2012 2013 if a.numel() > 1: 2014 # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified." 2015 return DType(a.dtype) 2016 2017 if a.dtype == torch.bool: 2018 dtype = torch.bool 2019 2020 elif a.dtype.is_complex: 2021 fi = torch.finfo(torch.float32) 2022 fits_in_single = a.dtype == torch.complex64 or ( 2023 fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max 2024 ) 2025 dtype = torch.complex64 if fits_in_single else torch.complex128 2026 2027 elif a.dtype.is_floating_point: 2028 for dt in [torch.float16, torch.float32, torch.float64]: 2029 fi = torch.finfo(dt) 2030 if fi.min <= a <= fi.max: 2031 dtype = dt 2032 break 2033 else: 2034 # must be integer 2035 for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: 2036 # Prefer unsigned int where possible, as numpy does. 
2037 ii = torch.iinfo(dt) 2038 if ii.min <= a <= ii.max: 2039 dtype = dt 2040 break 2041 2042 return DType(dtype) 2043 2044 2045def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs): 2046 if mode != "constant": 2047 raise NotImplementedError 2048 value = kwargs.get("constant_values", 0) 2049 # `value` must be a python scalar for torch.nn.functional.pad 2050 typ = _dtypes_impl.python_type_for_torch(array.dtype) 2051 value = typ(value) 2052 2053 pad_width = torch.broadcast_to(pad_width, (array.ndim, 2)) 2054 pad_width = torch.flip(pad_width, (0,)).flatten() 2055 2056 return torch.nn.functional.pad(array, tuple(pad_width), value=value) 2057