# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

import enum
import functools
import math
import numbers
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)


@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)


@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  dtype = np_utils.result_type_unary(a, dtype)

  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  dtype = np_utils.result_type_unary(a, dtype)
  return array_ops.ones_like(a, dtype)


@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if not M:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)


@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    dtype = np_utils.result_type_unary(result_t, dtype)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  max_ndmin = 32
  if ndmin > max_ndmin:
    raise ValueError('ndmin bigger than allowable number of dimensions: '
                     f'{max_ndmin}.')

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t
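

# Usage sketch for the constructors above (illustrative, not part of the
# library; assumes the public alias `tnp = tf.experimental.numpy`, which
# re-exports these functions, and eager execution):
#
#   >>> tnp.eye(3, 4, k=1)     # ones on the first superdiagonal of a 3x4 matrix
#   >>> tnp.eye(3, 4, k=-5)    # offset out of range, so an all-zero 3x4 matrix
#   >>> tnp.full((2, 3), 5.0)  # fill_value broadcast to shape (2, 3)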


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Since Tensors are immutable, a copy is made only if val is placed on a

  different device than the current one. Even if `copy` is False, a new Tensor
  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
  is an ndarray or a Tensor.
  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)


# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, `start` is treated as 0 and
      `start` value is used as `stop`. If specified, it is not included in the
      range if `step` is integer. When `step` is floating point, it may or may
      not be included.
    step: The difference between 2 consecutive values in the output range. It
      is recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
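

# Usage sketch for `arange` (illustrative; `tnp` as above):
#
#   >>> tnp.arange(5)         # [0, 1, 2, 3, 4]
#   >>> tnp.arange(2, 10, 3)  # [2, 5, 8]; integer `stop` is excluded
#   >>> tnp.arange(5, 1)      # start > stop with positive step: empty array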


# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_ops.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result


@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)

  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)

  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)
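

# Usage sketch for the diagonal helpers (illustrative; `tnp` as above):
#
#   >>> x = tnp.arange(9).reshape(3, 3)
#   >>> tnp.diag(x)          # 2-d input extracts the diagonal: [0, 4, 8]
#   >>> tnp.diag(x, k=1)     # first superdiagonal: [1, 5]
#   >>> tnp.diag([1, 2, 3])  # 1-d input builds a 3x3 diagonal matrix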


@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)


@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a


@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we
  # always return an ndarray.
  return math_ops.imag(val)
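

# Usage sketch (illustrative; `tnp` as above):
#
#   >>> tnp.compress([True, False, True], tnp.arange(4))  # [0, 2]
#   >>> tnp.cumsum(tnp.asarray([[1, 2], [3, 4]]))         # [1, 3, 6, 10]
#
# In the first call the 3-element condition is padded with False to cover all
# 4 elements, as noted in `compress` above; in the second, axis=None flattens
# the input before accumulating.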
499 """ 500 if dtype: 501 dtype = np_utils.result_type(dtype) 502 if keepdims is None: 503 keepdims = False 504 a = asarray(a, dtype=dtype) 505 if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and 506 tf_bool_fn is not None): 507 return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims) 508 if dtype is None: 509 dtype = a.dtype.as_numpy_dtype 510 if np.issubdtype(dtype, np.integer) or dtype == np.bool_: 511 if promote_int == _TO_INT_: 512 # If a is an integer/bool type and whose bit width is less than np.int_, 513 # numpy up-casts it to np.int_ based on the documentation at 514 # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html 515 if dtype == np.bool_: 516 is_signed = True 517 width = 8 # We can use any number here that is less than 64 518 else: 519 is_signed = np.issubdtype(dtype, np.signedinteger) 520 width = np.iinfo(dtype).bits 521 # Numpy int_ and uint are defined as 'long' and 'unsigned long', so 522 # should have the same bit width. 523 if width < np.iinfo(np.int_).bits: 524 if is_signed: 525 dtype = np.int_ 526 else: 527 dtype = np.uint 528 a = math_ops.cast(a, dtype) 529 elif promote_int == _TO_FLOAT: 530 a = math_ops.cast(a, np_dtypes.default_float_type()) 531 532 if isinstance(axis, ops.Tensor) and axis.dtype not in ( 533 dtypes.int32, dtypes.int64): 534 axis = math_ops.cast(axis, dtypes.int64) 535 536 return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims) 537 538 539# TODO (DarrenZhang01): Add `axis` support to the `size` API. 540@np_utils.np_doc('size') 541def size(x, axis=None): # pylint: disable=missing-docstring 542 if axis is not None: 543 raise NotImplementedError('axis argument is not supported in the current ' 544 '`np.size` implementation') 545 if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)): 546 return 1 547 x = asarray(x) 548 if x.shape.is_fully_defined(): 549 return np.prod(x.shape.as_list(), dtype=int) 550 else: 551 return array_ops.size_v2(x) 552 553 554@np_utils.np_doc('sum') 555def sum(a, axis=None, dtype=None, keepdims=None): # pylint: disable=redefined-builtin 556 return _reduce( 557 math_ops.reduce_sum, 558 a, 559 axis=axis, 560 dtype=dtype, 561 keepdims=keepdims, 562 tf_bool_fn=math_ops.reduce_any) 563 564 565@np_utils.np_doc('prod') 566def prod(a, axis=None, dtype=None, keepdims=None): 567 return _reduce( 568 math_ops.reduce_prod, 569 a, 570 axis=axis, 571 dtype=dtype, 572 keepdims=keepdims, 573 tf_bool_fn=math_ops.reduce_all) 574 575 576@np_utils.np_doc('mean', unsupported_params=['out']) 577def mean(a, axis=None, dtype=None, out=None, keepdims=None): 578 if out is not None: 579 raise ValueError('Setting out is not supported.') 580 return _reduce( 581 math_ops.reduce_mean, 582 a, 583 axis=axis, 584 dtype=dtype, 585 keepdims=keepdims, 586 promote_int=_TO_FLOAT) 587 588 589@np_utils.np_doc('amax', unsupported_params=['out']) 590def amax(a, axis=None, out=None, keepdims=None): 591 if out is not None: 592 raise ValueError('Setting out is not supported.') 593 return _reduce( 594 math_ops.reduce_max, 595 a, 596 axis=axis, 597 dtype=None, 598 keepdims=keepdims, 599 promote_int=None, 600 tf_bool_fn=math_ops.reduce_any, 601 preserve_bool=True) 602 603 604@np_utils.np_doc('amin', unsupported_params=['out']) 605def amin(a, axis=None, out=None, keepdims=None): 606 if out is not None: 607 raise ValueError('Setting out is not supported.') 608 return _reduce( 609 math_ops.reduce_min, 610 a, 611 axis=axis, 612 dtype=None, 613 keepdims=keepdims, 614 promote_int=None, 615 tf_bool_fn=math_ops.reduce_all, 616 


@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)


@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result
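

# Usage sketch (illustrative; `tnp` as above):
#
#   >>> tnp.var(tnp.asarray([1., 2., 3., 4.]))   # population variance: 1.25
#   >>> tnp.std(tnp.asarray([1., 2., 3., 4.]))   # sqrt of the above: ~1.118
#   >>> tnp.repeat(tnp.asarray([1, 2]), [2, 3])  # per-element counts:
#   >>>                                          # [1, 1, 2, 2, 2]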


@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)


setattr(np_arrays.ndarray, '__round__', around)


@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
  """order argument can only be 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)
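

# Usage sketch for `reshape` (illustrative; `tnp` as above). 'F' order is
# implemented as transpose-reshape-transpose, matching numpy's column-major
# read/write order:
#
#   >>> x = tnp.arange(6)
#   >>> tnp.reshape(x, (2, 3))             # [[0, 1, 2], [3, 4, 5]]
#   >>> tnp.reshape(x, (2, 3), order='F')  # [[0, 2, 4], [1, 3, 5]]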


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  def adjust_axes(axes, rank):

    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x

    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a


@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a


@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)
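

# Usage sketch for the axis-manipulation helpers (illustrative; `tnp` as
# above):
#
#   >>> x = tnp.zeros((3, 4, 5))
#   >>> tnp.swapaxes(x, 0, 2).shape   # (5, 4, 3)
#   >>> tnp.moveaxis(x, 0, -1).shape  # (4, 5, 3)
#   >>> tnp.pad(tnp.asarray([1, 2]), [[1, 1]], 'constant')  # [0, 1, 2, 0]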


@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)


@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output


@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0
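

# Usage sketch (illustrative; `tnp` as above). Note that unlike numpy, `take`
# defaults to mode='clip':
#
#   >>> tnp.take(tnp.asarray([10, 20, 30]), [0, 7])  # 7 clips to 2: [10, 30]
#   >>> tnp.where(tnp.asarray([True, False]), 1, 2)  # [1, 2]
#   >>> tnp.where(tnp.asarray([0, 3, 0, 5]))         # nonzero: ([1, 3],)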
986 """ 987 if axis >= len(a.shape): 988 raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape)) 989 total_size = a.shape[axis] 990 sizes = [] 991 sizes_sum = 0 992 prev = 0 993 for i, b in enumerate(boundaries): 994 size = b - prev 995 if size < 0: 996 raise ValueError('The %s-th boundary %s is smaller than the previous ' 997 'boundary %s' % (i, b, prev)) 998 size = min(size, max(0, total_size - sizes_sum)) 999 sizes.append(size) 1000 sizes_sum += size 1001 prev = b 1002 sizes.append(max(0, total_size - sizes_sum)) 1003 return sizes 1004 1005 1006@np_utils.np_doc('split') 1007def split(ary, indices_or_sections, axis=0): 1008 ary = asarray(ary) 1009 if not isinstance(indices_or_sections, int): 1010 indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis) 1011 return array_ops.split(ary, indices_or_sections, axis=axis) 1012 1013 1014def _split_on_axis(np_fun_name, axis): 1015 1016 @np_utils.np_doc(np_fun_name) 1017 def f(ary, indices_or_sections): 1018 if isinstance(indices_or_sections, int): 1019 ary_shape = ary.shape[axis] 1020 if ary_shape is not None and ary_shape % indices_or_sections: 1021 raise ValueError( 1022 'array split does not result in an equal division') 1023 return split(ary, indices_or_sections, axis=axis) 1024 1025 return f 1026 1027 1028vsplit = _split_on_axis('vsplit', axis=0) 1029hsplit = _split_on_axis('hsplit', axis=1) 1030dsplit = _split_on_axis('dsplit', axis=2) 1031 1032 1033@np_utils.np_doc('broadcast_to') 1034def broadcast_to(array, shape): # pylint: disable=redefined-outer-name 1035 return full(shape, array) 1036 1037 1038@np_utils.np_doc('stack') 1039def stack(arrays, axis=0): # pylint: disable=missing-function-docstring 1040 if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)): 1041 arrays = asarray(arrays) 1042 if axis == 0: 1043 return arrays 1044 else: 1045 return swapaxes(arrays, 0, axis) 1046 arrays = _promote_dtype(*arrays) # pylint: disable=protected-access 1047 unwrapped_arrays = [ 1048 a if isinstance(a, np_arrays.ndarray) else a for a in arrays 1049 ] 1050 return asarray(array_ops.stack(unwrapped_arrays, axis)) 1051 1052 1053@np_utils.np_doc('hstack') 1054def hstack(tup): 1055 arrays = [atleast_1d(a) for a in tup] 1056 arrays = _promote_dtype(*arrays) # pylint: disable=protected-access 1057 unwrapped_arrays = [ 1058 a if isinstance(a, np_arrays.ndarray) else a for a in arrays 1059 ] 1060 rank = array_ops.rank(unwrapped_arrays[0]) 1061 return np_utils.cond( 1062 math_ops.equal(rank, 1063 1), lambda: array_ops.concat(unwrapped_arrays, axis=0), 1064 lambda: array_ops.concat(unwrapped_arrays, axis=1)) 1065 1066 1067@np_utils.np_doc('vstack') 1068def vstack(tup): 1069 arrays = [atleast_2d(a) for a in tup] 1070 arrays = _promote_dtype(*arrays) # pylint: disable=protected-access 1071 unwrapped_arrays = [ 1072 a if isinstance(a, np_arrays.ndarray) else a for a in arrays 1073 ] 1074 return array_ops.concat(unwrapped_arrays, axis=0) 1075 1076 1077@np_utils.np_doc('dstack') 1078def dstack(tup): 1079 arrays = [atleast_3d(a) for a in tup] 1080 arrays = _promote_dtype(*arrays) # pylint: disable=protected-access 1081 unwrapped_arrays = [ 1082 a if isinstance(a, np_arrays.ndarray) else a for a in arrays 1083 ] 1084 return array_ops.concat(unwrapped_arrays, axis=2) 1085 1086 1087def _pad_left_to(n, old_shape): 1088 old_shape = asarray(old_shape, dtype=np.int32) 1089 new_shape = array_ops.pad( 1090 old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]], 1091 constant_values=1) 1092 return asarray(new_shape) 1093 


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1),
            lambda: array_ops.pad(old_shape, [[1, 1]], constant_values=1),
            lambda: array_ops.pad(old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)


@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
      a.shape.rank,
      axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim
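

# Usage sketch (illustrative; `tnp` as above). New axes are padded on the
# left, except for the third axis added by `atleast_3d`:
#
#   >>> tnp.atleast_2d(5).shape                 # (1, 1)
#   >>> tnp.atleast_2d(tnp.arange(3)).shape     # (1, 3)
#   >>> tnp.atleast_3d(tnp.arange(3)).shape     # (1, 3, 1)
#   >>> tnp.nonzero(tnp.asarray([0, 3, 0, 5]))  # ([1, 3],)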


@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)),
          dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r


@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)


@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)
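

# Usage sketch for the triangular helpers (illustrative; `tnp` as above):
#
#   >>> tnp.tri(3)       # lower-triangular matrix of ones
#   >>> x = tnp.arange(9).reshape(3, 3)
#   >>> tnp.tril(x, -1)  # keep only entries strictly below the diagonal
#   >>> tnp.triu(x)      # zero out entries below the main diagonal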


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)


@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N or x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_ops.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))


@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)
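

# Usage sketch (illustrative; `tnp` as above):
#
#   >>> tnp.vander(tnp.asarray([1, 2, 3]))  # [[1, 1, 1], [4, 2, 1], [9, 3, 1]]
#   >>> a, b = tnp.ix_([0, 1], [2, 4])      # shapes (2, 1) and (1, 2),
#   >>>                                     # suitable for outer indexing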


@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError(
        "tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result


# Note that np.take_along_axis may not be present in some supported versions
# of numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result
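

# Usage sketch for `take_along_axis` (illustrative; `tnp` as above; `argsort`
# lives in a sibling numpy_ops module):
#
#   >>> x = tnp.asarray([[3, 1, 2], [6, 5, 4]])
#   >>> i = tnp.argsort(x, axis=1)
#   >>> tnp.take_along_axis(x, i, axis=1)  # rows sorted: [[1, 2, 3], [4, 5, 6]]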


_SLICE_ERROR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERROR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0


class _UpdateMethod(enum.Enum):
  UPDATE = 0
  ADD = 1
  MIN = 2
  MAX = 3
1545 """ 1546 begin, end, strides = [], [], [] 1547 new_axis_mask, shrink_axis_mask = 0, 0 1548 begin_mask, end_mask = 0, 0 1549 ellipsis_mask = 0 1550 advanced_indices = [] 1551 shrink_indices = [] 1552 for index, s in enumerate(slice_spec): 1553 if isinstance(s, slice): 1554 if s.start is not None: 1555 begin.append(_as_index(s.start)[0]) 1556 else: 1557 begin.append(0) 1558 begin_mask |= (1 << index) 1559 if s.stop is not None: 1560 end.append(_as_index(s.stop)[0]) 1561 else: 1562 end.append(0) 1563 end_mask |= (1 << index) 1564 if s.step is not None: 1565 strides.append(_as_index(s.step)[0]) 1566 else: 1567 strides.append(1) 1568 elif s is Ellipsis: 1569 begin.append(0) 1570 end.append(0) 1571 strides.append(1) 1572 ellipsis_mask |= (1 << index) 1573 elif s is array_ops.newaxis: 1574 begin.append(0) 1575 end.append(0) 1576 strides.append(1) 1577 new_axis_mask |= (1 << index) 1578 else: 1579 s, is_scalar = _as_index(s, False) 1580 if is_scalar: 1581 begin.append(s) 1582 end.append(s + 1) 1583 strides.append(1) 1584 shrink_axis_mask |= (1 << index) 1585 shrink_indices.append(index) 1586 else: 1587 begin.append(0) 1588 end.append(0) 1589 strides.append(1) 1590 begin_mask |= (1 << index) 1591 end_mask |= (1 << index) 1592 advanced_indices.append((index, s, ellipsis_mask != 0)) 1593 1594 # stack possibly involves no tensors, so we must use op_scope correct graph. 1595 with ops.name_scope( 1596 None, 1597 'strided_slice', [tensor] + begin + end + strides, 1598 skip_on_eager=False) as name: 1599 if begin: 1600 packed_begin, packed_end, packed_strides = (array_ops.stack(begin), 1601 array_ops.stack(end), 1602 array_ops.stack(strides)) 1603 if (packed_begin.dtype == dtypes.int64 or 1604 packed_end.dtype == dtypes.int64 or 1605 packed_strides.dtype == dtypes.int64): 1606 if packed_begin.dtype != dtypes.int64: 1607 packed_begin = math_ops.cast(packed_begin, dtypes.int64) 1608 if packed_end.dtype != dtypes.int64: 1609 packed_end = math_ops.cast(packed_end, dtypes.int64) 1610 if packed_strides.dtype != dtypes.int64: 1611 packed_strides = math_ops.cast(packed_strides, dtypes.int64) 1612 else: 1613 var_empty = constant_op.constant([], dtype=dtypes.int32) 1614 packed_begin = packed_end = packed_strides = var_empty 1615 if update_method == _UpdateMethod.UPDATE and not advanced_indices: 1616 return array_ops.tensor_strided_slice_update( 1617 tensor, 1618 packed_begin, 1619 packed_end, 1620 packed_strides, 1621 updates, 1622 begin_mask=begin_mask, 1623 end_mask=end_mask, 1624 shrink_axis_mask=shrink_axis_mask, 1625 new_axis_mask=new_axis_mask, 1626 ellipsis_mask=ellipsis_mask, 1627 name=name) 1628 else: 1629 # TODO(b/164251540): Find a better way to support update that does not 1630 # involve one read + two writes. 1631 if updates is not None: 1632 original_tensor = tensor 1633 # TODO(agarwal): set_shape on tensor to set rank. 1634 tensor = array_ops.strided_slice( 1635 tensor, 1636 packed_begin, 1637 packed_end, 1638 packed_strides, 1639 begin_mask=begin_mask, 1640 end_mask=end_mask, 1641 shrink_axis_mask=shrink_axis_mask, 1642 new_axis_mask=new_axis_mask, 1643 ellipsis_mask=ellipsis_mask, 1644 name=name) 1645 if not advanced_indices: 1646 if update_method is None: 1647 return tensor 1648 assert update_method != _UpdateMethod.UPDATE 1649 # TF lacks TensorStridedSliceAdd and alike, so we need to do 1650 # read+add+update. 
      if update_method == _UpdateMethod.ADD:
        update_op = math_ops.add
      elif update_method == _UpdateMethod.MIN:
        update_op = math_ops.minimum
      elif update_method == _UpdateMethod.MAX:
        update_op = math_ops.maximum
      return array_ops.tensor_strided_slice_update(
          original_tensor,
          packed_begin,
          packed_end,
          packed_strides,
          update_op(tensor, updates),
          begin_mask=begin_mask,
          end_mask=end_mask,
          shrink_axis_mask=shrink_axis_mask,
          new_axis_mask=new_axis_mask,
          ellipsis_mask=ellipsis_mask,
          name=name + '_2')
    advanced_indices_map = {}
    for index, data, had_ellipsis in advanced_indices:
      if had_ellipsis:
        num_shrink = len([x for x in shrink_indices if x > index])
        dim = index - len(slice_spec) + num_shrink
      else:
        num_shrink = len([x for x in shrink_indices if x < index])
        dim = index - num_shrink
      advanced_indices_map[dim] = data
    dims = sorted(advanced_indices_map.keys())
    dims_contiguous = True
    if len(dims) > 1:
      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
        dims_contiguous = False
      else:
        for i in range(len(dims) - 1):
          if dims[i] + 1 != dims[i + 1]:
            dims_contiguous = False
            break
    indices = [advanced_indices_map[x] for x in dims]
    indices = _promote_dtype(*indices)
    indices = np_utils.tf_broadcast(*indices)
    stacked_indices = array_ops.stack(indices, axis=-1)
    # Skip the contiguous-dims optimization for update because there is no
    # tf.*scatter* op that supports the `axis` argument.
    if not dims_contiguous or updates is not None:
      if range(len(dims)) != dims:
        tensor = moveaxis(tensor, dims, range(len(dims)))
      tensor_shape_prefix = array_ops.shape(
          tensor, out_type=stacked_indices.dtype)[:len(dims)]
      stacked_indices = array_ops.where_v2(
          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
          stacked_indices)
      if updates is None:
        return array_ops.gather_nd(tensor, stacked_indices)
      else:
        # We only need to move-axis `updates` in the contiguous case because
        # only in this case the result dimensions of advanced indexing are in
        # the middle of `updates`. In the non-contiguous case, those dimensions
        # are always at the front.
        if dims_contiguous:
          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
          # `updates`)
          if stacked_indices.shape.rank is None:
            raise NotImplementedError(
                'Rank of the advanced indices must currently be known')
          batch_size = stacked_indices.shape.rank - 1
          batch_start = dims[0]
          if batch_start < 0:
            batch_start += len(dims) - batch_size

          def range_(start, length):
            return range(start, start + length)

          updates = moveaxis(updates, range_(batch_start, batch_size),
                             range(batch_size))
        if update_method == _UpdateMethod.UPDATE:
          update_op = array_ops.tensor_scatter_update
        elif update_method == _UpdateMethod.ADD:
          update_op = array_ops.tensor_scatter_add
        elif update_method == _UpdateMethod.MIN:
          update_op = array_ops.tensor_scatter_min
        elif update_method == _UpdateMethod.MAX:
          update_op = array_ops.tensor_scatter_max
        tensor = update_op(tensor, stacked_indices, updates)
        if range(len(dims)) != dims:
          tensor = moveaxis(tensor, range(len(dims)), dims)
        return array_ops.tensor_strided_slice_update(
            original_tensor,
            packed_begin,
            packed_end,
            packed_strides,
            tensor,
            begin_mask=begin_mask,
            end_mask=end_mask,
            shrink_axis_mask=shrink_axis_mask,
            new_axis_mask=new_axis_mask,
            ellipsis_mask=ellipsis_mask,
            name=name + '_2')
    # Note that gather_nd does not support gathering from inside the array.
    # To avoid shuffling data back and forth, we transform the indices and
    # do a gather instead.
    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
    dims = [(x + rank if x < 0 else x) for x in dims]
    shape_tensor = array_ops.shape(tensor)
    dim_sizes = array_ops.gather(shape_tensor, dims)
    if len(dims) == 1:
      stacked_indices = indices[0]
    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
    stacked_indices = array_ops.where_v2(stacked_indices < 0,
                                         stacked_indices + dim_sizes,
                                         stacked_indices)
    axis = dims[0]
    if len(dims) > 1:
      index_scaling = math_ops.cumprod(
          dim_sizes, reverse=True, exclusive=True)

      def _tensordot(a, b):
        # TODO(b/168657656): This function should be replaced by
        # tensordot(axis=1) once MatMul has int32 XLA kernel.
        b = array_ops.broadcast_to(b, array_ops.shape(a))
        return math_ops.reduce_sum(a * b, axis=-1)

      stacked_indices = _tensordot(stacked_indices, index_scaling)
      flat_shape = array_ops.concat(
          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
          axis=0)
      tensor = array_ops.reshape(tensor, flat_shape)

    return array_ops.gather(tensor, stacked_indices, axis=axis)


def _as_spec_tuple(slice_spec):
  """Convert slice_spec to tuple."""
  if isinstance(slice_spec,
                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
    is_index = True
    for s in slice_spec:
      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
        is_index = False
        break
      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
        is_index = False
        break
    if not is_index:
      return tuple(slice_spec)
  return (slice_spec,)


def _getitem(self, slice_spec):
  """Implementation of ndarray.__getitem__."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    return array_ops.boolean_mask(tensor=self, mask=slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  result_t = _slice_helper(self, slice_spec)
  return result_t


def _with_index_update_helper(update_method, a, slice_spec, updates):
  """Implementation of ndarray._with_index_*."""
  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                       slice_spec.dtype == dtypes.bool) or
      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
       slice_spec.dtype == np.bool_)):
    slice_spec = nonzero(slice_spec)

  if not isinstance(slice_spec, tuple):
    slice_spec = _as_spec_tuple(slice_spec)

  a_dtype = a.dtype
  a, updates = _promote_dtype_binary(a, updates)
  result_t = _slice_helper(a, slice_spec, update_method, updates)
  return result_t.astype(a_dtype)


setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(np_arrays.ndarray, '_with_index_update',
        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
setattr(np_arrays.ndarray, '_with_index_add',
        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
setattr(np_arrays.ndarray, '_with_index_min',
        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
setattr(np_arrays.ndarray, '_with_index_max',
        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
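

# Usage sketch for the numpy-style indexing wired up above (illustrative;
# `tnp` as above). `_numpy_style_getitem` backs `ndarray.__getitem__`, and the
# `_with_index_*` methods provide functional-style updates, since Tensors are
# immutable:
#
#   >>> x = tnp.arange(10)
#   >>> x[2:8:2]                # basic indexing -> strided_slice: [2, 4, 6]
#   >>> x[tnp.asarray([1, 3])]  # advanced indexing -> gather: [1, 3]
#   >>> x[x % 2 == 0]           # boolean mask: [0, 2, 4, 6, 8]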