xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/numpy_ops/np_array_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common array methods."""
# pylint: disable=g-direct-tensorflow-import

import enum
import functools
import math
import numbers
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
from tensorflow.python.ops.numpy_ops import np_utils
from tensorflow.python.util import nest


newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)


@np_utils.np_doc('empty')
def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
  return zeros(shape, dtype)


@np_utils.np_doc('empty_like')
def empty_like(a, dtype=None):
  return zeros_like(a, dtype)


@np_utils.np_doc('zeros')
def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
  dtype = (
      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
  return array_ops.zeros(shape, dtype=dtype)


@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
  dtype = np_utils.result_type_unary(a, dtype)

  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
  return array_ops.zeros_like(a, dtype)


@np_utils.np_doc('ones')
def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
  if dtype:
    dtype = np_utils.result_type(dtype)
  return array_ops.ones(shape, dtype=dtype)


@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
  dtype = np_utils.result_type_unary(a, dtype)
  return array_ops.ones_like(a, dtype)


@np_utils.np_doc('eye')
def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
  if dtype:
    dtype = np_utils.result_type(dtype)
  if M is None:
    M = N
  # Making sure N, M and k are `int`
  N = int(N)
  M = int(M)
  k = int(k)
  if k >= M or -k >= N:
    # tf.linalg.diag will raise an error in this case
    return zeros([N, M], dtype=dtype)
  if k == 0:
    return linalg_ops.eye(N, M, dtype=dtype)
  # We need the precise length, otherwise tf.linalg.diag will raise an error
  diag_len = min(N, M)
  if k > 0:
    if N >= M:
      diag_len -= k
    elif N + k > M:
      diag_len = M - k
  elif k <= 0:
    if M >= N:
      diag_len += k
    elif M - k > N:
      diag_len = N + k
  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
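
# Illustrative examples of the diagonal-length arithmetic above (values follow
# NumPy semantics; shown for exposition, not executed here):
#   eye(4, 5, k=2)  -> ones at (0, 2), (1, 3), (2, 4); diag_len = M - k = 3
#   eye(5, 4, k=-2) -> ones at (2, 0), (3, 1), (4, 2); diag_len = N + k = 3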


@np_utils.np_doc('identity')
def identity(n, dtype=float):
  return eye(N=n, M=n, dtype=dtype)


@np_utils.np_doc('full')
def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
  if not isinstance(shape, np_arrays.ndarray):
    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
  shape = atleast_1d(shape)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, shape)


# Using doc only here since np full_like signature doesn't seem to have the
# shape argument (even though it exists in the documentation online).
@np_utils.np_doc_only('full_like')
def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
132  """order, subok and shape arguments mustn't be changed."""
  if order != 'K':
    raise ValueError('Non-standard orders are not supported.')
  if not subok:
    raise ValueError('subok being False is not supported.')
  if shape:
    raise ValueError('Overriding the shape is not supported.')

  a = asarray(a)
  dtype = dtype or np_utils.result_type(a)
  fill_value = asarray(fill_value, dtype=dtype)
  return array_ops.broadcast_to(fill_value, array_ops.shape(a))


def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
  """Main implementation of np.array()."""
  result_t = val

  if not isinstance(result_t, ops.Tensor):
    dtype = np_utils.result_type_unary(result_t, dtype)
    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
    # while np.array allows them. We need to convert-then-cast.

    # EagerTensor conversion complains about "mixed types" when converting
    # tensors with no dtype information. This is because it infers types based
    # on one selected item in the list. So e.g. when converting [2., 2j]
    # to a tensor, it will select float32 as the inferred type and not be able
    # to convert the list to a float32 tensor.
    # Since we have some information about the final dtype we care about, we
    # supply that information so that convert_to_tensor will do best-effort
    # conversion to that dtype first.
    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
    result_t = math_ops.cast(result_t, dtype=dtype)
  elif dtype:
    result_t = math_ops.cast(result_t, dtype)

  if copy:
    result_t = array_ops.identity(result_t)

  max_ndmin = 32
  if ndmin > max_ndmin:
    raise ValueError('ndmin bigger than allowable number of dimensions: '
                     f'{max_ndmin}.')

  if ndmin == 0:
    return result_t

  ndims = array_ops.rank(result_t)

  def true_fn():
    old_shape = array_ops.shape(result_t)
    new_shape = array_ops.concat(
        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
    return array_ops.reshape(result_t, new_shape)

  result_t = np_utils.cond(
      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
  return result_t


# TODO(wangpeng): investigate whether we can make `copy` default to False.
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
@np_utils.np_doc_only('array')
def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
197  """Since Tensors are immutable, a copy is made only if val is placed on a
198
199  different device than the current one. Even if `copy` is False, a new Tensor
200  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
201  is an ndarray or a Tensor.
202  """  # pylint:disable=g-docstring-missing-newline
  if dtype:
    dtype = np_utils.result_type(dtype)
  return _array_internal(val, dtype, copy, ndmin)
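
# Illustrative behavior of the convert-then-cast path and `ndmin` above
# (values follow NumPy semantics; shown for exposition, not executed here):
#   array(5.5, dtype=np.int32)  -> 5            (convert first, then cast)
#   array([1, 2], ndmin=2)      -> shape (1, 2) (padded with leading 1-dims)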


# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args


@np_utils.np_doc('asarray')
def asarray(a, dtype=None):
  if dtype:
    dtype = np_utils.result_type(dtype)
  if isinstance(a, np_arrays.ndarray) and (
      not dtype or dtype == a.dtype.as_numpy_dtype):
    return a
  return array(a, dtype, copy=False)


@np_utils.np_doc('asanyarray')
def asanyarray(a, dtype=None):
  return asarray(a, dtype)


@np_utils.np_doc('ascontiguousarray')
def ascontiguousarray(a, dtype=None):
  return array(a, dtype, ndmin=1)


# Numerical ranges.
@np_utils.np_doc('arange')
def arange(start, stop=None, step=1, dtype=None):
  """Returns `step`-separated values in the range [start, stop).

  Args:
    start: Start of the interval. Included in the range.
    stop: End of the interval. If not specified, the range starts at 0 and the
      value of `start` is used as `stop`. If specified, it is not included in
      the range if `step` is an integer. When `step` is floating point, it may
      or may not be included.
    step: The difference between 2 consecutive values in the output range. It is
      recommended to use `linspace` instead of using non-integer values for
      `step`.
    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
      `start`, `stop`, `step` is used.

  Raises:
    ValueError: If step is zero.
  """
  if not step:
    raise ValueError('step must be non-zero.')
  if dtype:
    dtype = np_utils.result_type(dtype)
  else:
    if stop is None:
      dtype = np_utils.result_type(start, step)
    else:
      dtype = np_utils.result_type(start, step, stop)
  if step > 0 and ((stop is not None and start > stop) or
                   (stop is None and start < 0)):
    return array([], dtype=dtype)
  if step < 0 and ((stop is not None and start < stop) or
                   (stop is None and start > 0)):
    return array([], dtype=dtype)
  # TODO(srbs): There are some bugs when start or stop is float type and dtype
  # is integer type.
  return math_ops.cast(
      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
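
# A few expected values, per the semantics documented above (illustrative, not
# executed here):
#   arange(5)       -> [0, 1, 2, 3, 4]
#   arange(1, 7, 2) -> [1, 3, 5]
#   arange(5, 1)    -> []  (positive step with start > stop yields empty)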


# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0):  # pylint: disable=missing-docstring
  """Raises an error if input is not 1- or 2-d."""
  v = asarray(v)
  v_rank = array_ops.rank(v)

  v.shape.with_rank_at_most(2)

  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
  # tracing time if the shape is known.
  control_flow_ops.Assert(
      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
      [v_rank])

  def _diag(v, k):
    return np_utils.cond(
        math_ops.equal(array_ops.size(v), 0),
        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
        lambda: array_ops.matrix_diag(v, k=k))

  def _diag_part(v, k):
    v_shape = array_ops.shape(v)
    v, k = np_utils.cond(
        np_utils.logical_or(
            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
    result = array_ops.matrix_diag_part(v, k=k)
    return result

  result = np_utils.cond(
      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
  return result


@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
  a = asarray(a)

  maybe_rank = a.shape.rank
  if maybe_rank is not None and offset == 0 and (
      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                   axis2 == -1):
    return array_ops.matrix_diag_part(a)

  a = moveaxis(a, (axis1, axis2), (-2, -1))

  a_shape = array_ops.shape(a)

  def _zeros():  # pylint: disable=missing-docstring
    return (array_ops.zeros(
        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)

  # All zeros since diag_part doesn't handle all possible k (aka offset).
  # Written this way since cond will run shape inference on both branches,
  # and diag_part shape inference will fail when offset is out of bounds.
  a, offset = np_utils.cond(
      np_utils.logical_or(
          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
      ), _zeros, lambda: (a, offset))

  a = array_ops.matrix_diag_part(a, k=offset)
  return a


@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
  v = asarray(v)
  return diag(array_ops.reshape(v, [-1]), k)


def _promote_dtype(*arrays):
  dtype = np_utils.result_type(*arrays)
  def _fast_asarray(a):
    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
      return a
    return _array_internal(a, dtype=dtype, copy=False)
  return [_fast_asarray(a) for a in arrays]


def _promote_dtype_binary(t1, t2):
  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
  if not (
      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
    t1 = _array_internal(t1, dtype=dtype, copy=False)
  if not (
      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
    t2 = _array_internal(t2, dtype=dtype, copy=False)
  return t1, t2


@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  a = asarray(a, dtype=bool)
  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)


@np_utils.np_doc('compress')
def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
  condition = asarray(condition, dtype=bool)
  a = asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')
  # `np.compress` treats scalars as 1-d arrays.
  if a.ndim == 0:
    a = ravel(a)

  if axis is None:
    a = ravel(a)
    axis = 0

  if axis < 0:
    axis += a.ndim

  assert 0 <= axis < a.ndim

  # `tf.boolean_mask` requires the first dimensions of array and condition to
  # match. `np.compress` pads condition with False when it is shorter.
  condition_t = condition
  a_t = a
  if condition.shape[0] < a.shape[axis]:
    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
    condition_t = array_ops.concat([condition_t, padding], axis=0)
  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)


@np_utils.np_doc('copy')
def copy(a):
  return array(a, copy=True)


def _maybe_promote_to_int(a):
  if dtypes.as_dtype(a.dtype).is_integer:
    # If a is an integer type and its precision is less than that of `int`,
    # the output type will be `int`.
    a_numpy_dtype = a.dtype.as_numpy_dtype
    output_type = np.promote_types(a_numpy_dtype, int)
    if output_type != a_numpy_dtype:
      a = asarray(a, dtype=output_type)

  return a


@np_utils.np_doc('cumprod')
def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumprod(a, axis)


@np_utils.np_doc('cumsum')
def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
  a = asarray(a, dtype=dtype)

  if dtype is None:
    a = _maybe_promote_to_int(a)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  elif axis < 0:
    axis += array_ops.rank(a)
  return math_ops.cumsum(a, axis)


@np_utils.np_doc('imag')
def imag(val):
  val = asarray(val)
  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
  # return an ndarray.
  return math_ops.imag(val)


_TO_INT_ = 0
_TO_FLOAT = 1


def _reduce(tf_fn,
            a,
            axis=None,
            dtype=None,
            keepdims=None,
            promote_int=_TO_INT_,
            tf_bool_fn=None,
            preserve_bool=False):
  """A general reduction function.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
      `_TO_FLOAT` always promotes them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
      is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to integers,
      while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  if keepdims is None:
    keepdims = False
  a = asarray(a, dtype=dtype)
  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
      tf_bool_fn is not None):
    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
  if dtype is None:
    dtype = a.dtype.as_numpy_dtype
    if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
      if promote_int == _TO_INT_:
        # If a is an integer/bool type whose bit width is less than that of
        # np.int_, numpy up-casts it to np.int_, per the documentation at
        # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
        if dtype == np.bool_:
          is_signed = True
          width = 8  # We can use any number here that is less than 64
        else:
          is_signed = np.issubdtype(dtype, np.signedinteger)
          width = np.iinfo(dtype).bits
        # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
        # should have the same bit width.
        if width < np.iinfo(np.int_).bits:
          if is_signed:
            dtype = np.int_
          else:
            dtype = np.uint
          a = math_ops.cast(a, dtype)
      elif promote_int == _TO_FLOAT:
        a = math_ops.cast(a, np_dtypes.default_float_type())

  if isinstance(axis, ops.Tensor) and axis.dtype not in (
      dtypes.int32, dtypes.int64):
    axis = math_ops.cast(axis, dtypes.int64)

  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
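
# Illustrative consequences of the promotion rules above, using `sum` (defined
# below); values follow NumPy semantics and are shown for exposition:
#   sum([True, True, True])            -> 3    (bools promoted to np.int_)
#   sum([True, False], dtype=np.bool_) -> True (dispatches to reduce_any)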


# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@np_utils.np_doc('size')
def size(x, axis=None):  # pylint: disable=missing-docstring
  if axis is not None:
    raise NotImplementedError('axis argument is not supported in the current '
                              '`np.size` implementation')
  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
    return 1
  x = asarray(x)
  if x.shape.is_fully_defined():
    return np.prod(x.shape.as_list(), dtype=int)
  else:
    return array_ops.size_v2(x)


@np_utils.np_doc('sum')
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  return _reduce(
      math_ops.reduce_sum,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_any)


@np_utils.np_doc('prod')
def prod(a, axis=None, dtype=None, keepdims=None):
  return _reduce(
      math_ops.reduce_prod,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      tf_bool_fn=math_ops.reduce_all)


@np_utils.np_doc('mean', unsupported_params=['out'])
def mean(a, axis=None, dtype=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_mean,
      a,
      axis=axis,
      dtype=dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('amax', unsupported_params=['out'])
def amax(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_max,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_any,
      preserve_bool=True)


@np_utils.np_doc('amin', unsupported_params=['out'])
def amin(a, axis=None, out=None, keepdims=None):
  if out is not None:
    raise ValueError('Setting out is not supported.')
  return _reduce(
      math_ops.reduce_min,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=None,
      tf_bool_fn=math_ops.reduce_all,
      preserve_bool=True)


@np_utils.np_doc('var')
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
  if dtype:
    working_dtype = np_utils.result_type(a, dtype)
  else:
    working_dtype = None
  if out is not None:
    raise ValueError('Setting out is not supported.')
  if ddof != 0:
    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
    def reduce_fn(input_tensor, axis, keepdims):
      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
      centered = input_tensor - means
      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
        centered = math_ops.cast(
            math_ops.real(centered * math_ops.conj(centered)),
            input_tensor.dtype)
      else:
        centered = math_ops.square(centered)
      squared_deviations = math_ops.reduce_sum(
          centered, axis=axis, keepdims=keepdims)

      if axis is None:
        n = array_ops.size(input_tensor)
      else:
        if axis < 0:
          axis += array_ops.rank(input_tensor)
        n = math_ops.reduce_prod(
            array_ops.gather(array_ops.shape(input_tensor), axis))
      n = math_ops.cast(n - ddof, input_tensor.dtype)

      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
  else:
    reduce_fn = math_ops.reduce_variance

  result = _reduce(
      reduce_fn,
      a,
      axis=axis,
      dtype=working_dtype,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result


@np_utils.np_doc('std')
def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
  return _reduce(
      math_ops.reduce_std,
      a,
      axis=axis,
      dtype=None,
      keepdims=keepdims,
      promote_int=_TO_FLOAT)


@np_utils.np_doc('ravel')
def ravel(a):  # pylint: disable=missing-docstring
  a = asarray(a)
  return array_ops.reshape(a, [-1])


@np_utils.np_doc('real')
def real(val):
  val = asarray(val)
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  return math_ops.real(val)


@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)
  original_shape = a._shape_as_list()  # pylint: disable=protected-access
  # Best effort recovery of the shape.
  known_shape = original_shape is not None and None not in original_shape
  if known_shape:
    if not original_shape:
      original_shape = (repeats,)
    else:
      repeats_np = np.ravel(np.array(repeats))
      if repeats_np.size == 1:
        repeats_np = repeats_np.item()
        if axis is None:
          original_shape = (repeats_np * np.prod(original_shape),)
        else:
          original_shape[axis] = repeats_np * original_shape[axis]
      else:
        if axis is None:
          original_shape = (repeats_np.sum(),)
        else:
          original_shape[axis] = repeats_np.sum()

  repeats = asarray(repeats)
  result = array_ops.repeat(a, repeats, axis)
  if known_shape:
    result.set_shape(original_shape)

  return result


@np_utils.np_doc('around')
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = asarray(a)
  dtype = a.dtype.as_numpy_dtype
  factor = math.pow(10, decimals)
  if np.issubdtype(dtype, np.inexact):
    factor = math_ops.cast(factor, dtype)
  else:
    # Use float as the working dtype when a.dtype is exact (e.g. integer),
    # because `decimals` can be negative.
    float_dtype = np_dtypes.default_float_type()
    a = a.astype(float_dtype)
    factor = math_ops.cast(factor, float_dtype)
  a = math_ops.multiply(a, factor)
  a = math_ops.round(a)
  a = math_ops.divide(a, factor)
  return a.astype(dtype)
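
# Illustrative: with an integer input and negative `decimals`, the computation
# runs in the default float dtype and is cast back, e.g. (NumPy semantics)
# around(1234, decimals=-2) -> 1200.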


setattr(np_arrays.ndarray, '__round__', around)


@np_utils.np_doc('reshape')
def reshape(a, newshape, order='C'):
746  """order argument can only b 'C' or 'F'."""
  if order not in {'C', 'F'}:
    raise ValueError('Unsupported order argument {}'.format(order))

  a = asarray(a)
  if isinstance(newshape, int):
    newshape = [newshape]

  if order == 'F':
    r = array_ops.transpose(
        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
  else:
    r = array_ops.reshape(a, newshape)

  return r


def _reshape_method_wrapper(a, *newshape, **kwargs):
  order = kwargs.pop('order', 'C')
  if kwargs:
    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))

  if len(newshape) == 1 and not isinstance(newshape[0], int):
    newshape = newshape[0]

  return reshape(a, newshape, order=order)


@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
  a = asarray(a)
  return array_ops.expand_dims(a, axis=axis)


@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
  a = asarray(a)
  return array_ops.squeeze(a, axis)


@np_utils.np_doc('transpose')
def transpose(a, axes=None):
  a = asarray(a)
  if axes is not None:
    axes = asarray(axes)
  return array_ops.transpose(a=a, perm=axes)


@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)
  def adjust_axes(axes, rank):
    def f(x):
      if isinstance(x, int):
        if x < 0:
          x = x + rank
      else:
        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
      return x
    return nest.map_structure(f, axes)

  if (a.shape.rank is not None and
      isinstance(axis1, int) and isinstance(axis2, int)):
    # This branch makes sure `perm` is statically known, to avoid a
    # not-compile-time-constant XLA error.
    a_rank = a.shape.rank
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = list(range(a_rank))
    perm[axis1] = axis2
    perm[axis2] = axis1
  else:
    a_rank = array_ops.rank(a)
    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
    perm = math_ops.range(a_rank)
    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                           [axis2, axis1])
  a = array_ops.transpose(a, perm)
  return a
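
# Illustrative: for a statically-known rank-3 input, swapaxes(a, 0, 2) takes
# the static branch above and builds perm = [2, 1, 0], so an array of shape
# (2, 3, 4) comes out with shape (4, 3, 2).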


@np_utils.np_doc('moveaxis')
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a

  a = asarray(a)

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)
  if len(source) != len(destination):
    raise ValueError('The lengths of source and destination must be equal.')

  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = math_ops.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))

      i = 0
      result = []

      for item in items:
        result.append(a[i:item])
        i = item + 1

      result.append(a[i:])

      return array_ops.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = array_ops.scatter_nd(
        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = array_ops.tensor_scatter_update(
        perm, array_ops.expand_dims(destination, 1), source)
  a = array_ops.transpose(a, perm)

  return a
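
# Illustrative: for shape (2, 3, 4), moveaxis(a, 0, -1) resolves destination to
# 2, builds perm = [1, 2, 0] via the static branch above, and returns an array
# of shape (3, 4, 2).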


@np_utils.np_doc('pad')
def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
  constant_values = kwargs.get('constant_values', 0)
  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
    raise ValueError('Unsupported padding mode: ' + mode)
  mode = mode.upper()
  array = asarray(array)
  pad_width = asarray(pad_width, dtype=dtypes.int32)
  return array_ops.pad(
      tensor=array,
      paddings=pad_width,
      mode=mode,
      constant_values=constant_values)


@np_utils.np_doc('take')
def take(a, indices, axis=None, out=None, mode='clip'):
  """out argument is not supported, and default mode is clip."""
  if out is not None:
    raise ValueError('out argument is not supported in take.')

  if mode not in {'raise', 'clip', 'wrap'}:
    raise ValueError("Invalid mode '{}' for take".format(mode))

  a = asarray(a)
  indices = asarray(indices)

  if axis is None:
    a = array_ops.reshape(a, [-1])
    axis = 0

  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
  if mode == 'clip':
    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
  elif mode == 'wrap':
    indices = math_ops.floormod(indices, axis_size)
  else:
    raise ValueError("The 'raise' mode to take is not supported.")

  return array_ops.gather(a, indices, axis=axis)
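
# Illustrative index handling for a = [10, 20, 30] (NumPy-style semantics,
# shown for exposition):
#   take(a, [-1, 4], mode='clip') -> [10, 30]  (indices clamped to [0, 2])
#   take(a, [-1, 4], mode='wrap') -> [30, 20]  (indices taken modulo 3)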


@np_utils.np_doc_only('where')
def where(condition, x=None, y=None):
  """Raises ValueError if exactly one of x or y is not None."""
  condition = asarray(condition, dtype=np.bool_)
  if x is None and y is None:
    return nonzero(condition)
  elif x is not None and y is not None:
    x, y = _promote_dtype(x, y)
    return array_ops.where_v2(condition, x, y)
  raise ValueError('Both x and y must be ndarrays, or both must be None.')


@np_utils.np_doc('select')
def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
  if len(condlist) != len(choicelist):
    msg = 'condlist must have length equal to choicelist ({} vs {})'
    raise ValueError(msg.format(len(condlist), len(choicelist)))
  if not condlist:
    raise ValueError('condlist must be non-empty')
  choices = _promote_dtype(default, *choicelist)
  choicelist = choices[1:]
  output = choices[0]
  # The traversal is in reverse order so we can return the first value in
  # choicelist where condlist is True.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output


@np_utils.np_doc('shape', link=np_utils.Link(
    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
def shape(a):
  a = asarray(a)
  return a.shape


@np_utils.np_doc('ndim', link=np_utils.NoLink())
def ndim(a):
  a = asarray(a)
  return a.ndim


@np_utils.np_doc('isscalar')
def isscalar(num):
  return ndim(num) == 0


def _boundaries_to_sizes(a, boundaries, axis):
  """Converting boundaries of splits to sizes of splits.

  Args:
    a: the array to be split.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split.

  Returns:
    A list of sizes of the splits, as in tf.split.
  """
  if axis >= len(a.shape):
    raise ValueError('axis %s is out of bounds for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  sizes_sum = 0
  prev = 0
  for i, b in enumerate(boundaries):
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    size = min(size, max(0, total_size - sizes_sum))
    sizes.append(size)
    sizes_sum += size
    prev = b
  sizes.append(max(0, total_size - sizes_sum))
  return sizes
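
# Illustrative: for a split axis of length 10, boundaries [2, 5, 7] become
# sizes [2, 3, 2, 3]; the trailing size covers the remainder, as in np.split.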


@np_utils.np_doc('split')
def split(ary, indices_or_sections, axis=0):
  ary = asarray(ary)
  if not isinstance(indices_or_sections, int):
    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
  return array_ops.split(ary, indices_or_sections, axis=axis)


def _split_on_axis(np_fun_name, axis):

  @np_utils.np_doc(np_fun_name)
  def f(ary, indices_or_sections):
    if isinstance(indices_or_sections, int):
      ary_shape = ary.shape[axis]
      if ary_shape is not None and ary_shape % indices_or_sections:
        raise ValueError(
            'array split does not result in an equal division')
    return split(ary, indices_or_sections, axis=axis)

  return f


vsplit = _split_on_axis('vsplit', axis=0)
hsplit = _split_on_axis('hsplit', axis=1)
dsplit = _split_on_axis('dsplit', axis=2)


@np_utils.np_doc('broadcast_to')
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  return full(shape, array)


@np_utils.np_doc('stack')
def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
    arrays = asarray(arrays)
    if axis == 0:
      return arrays
    else:
      return swapaxes(arrays, 0, axis)
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = list(arrays)
  return asarray(array_ops.stack(unwrapped_arrays, axis))


@np_utils.np_doc('hstack')
def hstack(tup):
  arrays = [atleast_1d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = list(arrays)
  rank = array_ops.rank(unwrapped_arrays[0])
  return np_utils.cond(
      math_ops.equal(rank, 1),
      lambda: array_ops.concat(unwrapped_arrays, axis=0),
      lambda: array_ops.concat(unwrapped_arrays, axis=1))


@np_utils.np_doc('vstack')
def vstack(tup):
  arrays = [atleast_2d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = list(arrays)
  return array_ops.concat(unwrapped_arrays, axis=0)


@np_utils.np_doc('dstack')
def dstack(tup):
  arrays = [atleast_3d(a) for a in tup]
  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
  unwrapped_arrays = list(arrays)
  return array_ops.concat(unwrapped_arrays, axis=2)


def _pad_left_to(n, old_shape):
  old_shape = asarray(old_shape, dtype=np.int32)
  new_shape = array_ops.pad(
      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
      constant_values=1)
  return asarray(new_shape)


def _atleast_nd(n, new_shape, *arys):
  """Reshape arrays to be at least `n`-dimensional.

  Args:
    n: The minimal rank.
    new_shape: a function that takes `n` and the old shape and returns the
      desired new shape.
    *arys: ndarray(s) to be reshaped.

  Returns:
    The reshaped array(s).
  """

  def f(x):
    # pylint: disable=g-long-lambda
    x = asarray(x)
    return asarray(
        np_utils.cond(
            np_utils.greater(n, array_ops.rank(x)),
            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
            lambda: x))

  arys = list(map(f, arys))
  if len(arys) == 1:
    return arys[0]
  else:
    return arys


@np_utils.np_doc('atleast_1d')
def atleast_1d(*arys):
  return _atleast_nd(1, _pad_left_to, *arys)


@np_utils.np_doc('atleast_2d')
def atleast_2d(*arys):
  return _atleast_nd(2, _pad_left_to, *arys)


@np_utils.np_doc('atleast_3d')
def atleast_3d(*arys):  # pylint: disable=missing-docstring

  def new_shape(_, old_shape):
    # pylint: disable=g-long-lambda
    ndim_ = array_ops.size(old_shape)
    return np_utils.cond(
        math_ops.equal(ndim_, 0),
        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
        lambda: np_utils.cond(
            math_ops.equal(ndim_, 1), lambda: array_ops.pad(
                old_shape, [[1, 1]], constant_values=1), lambda: array_ops.pad(
                    old_shape, [[0, 1]], constant_values=1)))

  return _atleast_nd(3, new_shape, *arys)


@np_utils.np_doc('nonzero')
def nonzero(a):
  a = atleast_1d(a)
  if a.shape.rank is None:
    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                     'arrays to return.')
  return array_ops.unstack(
      array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
      a.shape.rank,
      axis=1)


@np_utils.np_doc('diag_indices')
def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
  if n < 0:
    raise ValueError(
        'n argument to diag_indices must be nonnegative, got {}'.format(n))
  if ndim < 0:
    raise ValueError(
        'ndim argument to diag_indices must be nonnegative, got {}'.format(
            ndim))

  return (math_ops.range(n),) * ndim


@np_utils.np_doc('tri')
def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
  M = M if M is not None else N
  if dtype is not None:
    dtype = np_utils.result_type(dtype)
  else:
    dtype = np_dtypes.default_float_type()

  if k < 0:
    lower = -k - 1
    if lower > N:
      r = array_ops.zeros([N, M], dtype)
    else:
      # Keep as tf bool, since we create an upper triangular matrix and invert
      # it.
      o = array_ops.ones([N, M], dtype=dtypes.bool)
      r = math_ops.cast(
          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
  else:
    o = array_ops.ones([N, M], dtype)
    if k > M:
      r = o
    else:
      r = array_ops.matrix_band_part(o, -1, k)
  return r
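
# Illustrative masks (1 = kept), per NumPy semantics; shown for exposition:
#   tri(3)       -> [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
#   tri(3, k=-1) -> [[0, 0, 0], [1, 0, 0], [1, 1, 0]]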


@np_utils.np_doc('tril')
def tril(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to tril should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to tril must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)


@np_utils.np_doc('triu')
def triu(m, k=0):  # pylint: disable=missing-docstring
  m = asarray(m)
  if m.shape.ndims is None:
    raise ValueError('Argument to triu should have known rank')
  m_shape = m.shape.as_list()

  if len(m_shape) < 2:
    raise ValueError('Argument to triu must have rank at least 2')

  if m_shape[-1] is None or m_shape[-2] is None:
    raise ValueError('Currently, the last two dimensions of the input array '
                     'need to be known.')

  z = constant_op.constant(0, m.dtype)

  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return array_ops.where_v2(
      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)


@np_utils.np_doc('flip')
def flip(m, axis=None):  # pylint: disable=missing-docstring
  m = asarray(m)

  if axis is None:
    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))

  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access

  return array_ops.reverse(m, [axis])


@np_utils.np_doc('flipud')
def flipud(m):  # pylint: disable=missing-docstring
  return flip(m, 0)


@np_utils.np_doc('fliplr')
def fliplr(m):  # pylint: disable=missing-docstring
  return flip(m, 1)


@np_utils.np_doc('roll')
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a)

  if axis is not None:
    return manip_ops.roll(a, shift, axis)

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = array_ops.shape(a)
  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
  return array_ops.reshape(a, original_shape)


@np_utils.np_doc('rot90')
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = array_ops.rank(m)
  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = math_ops.range(m_rank)
    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)


@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
  x = asarray(x)

  x_shape = array_ops.shape(x)
  N = N or x_shape[0]

  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
  if N_temp is not None:
    N = N_temp
    if N < 0:
      raise ValueError('N must be nonnegative')
  else:
    control_flow_ops.Assert(N >= 0, [N])

  rank = array_ops.rank(x)
  rank_temp = np_utils.get_static_value(rank)
  if rank_temp is not None:
    rank = rank_temp
    if rank != 1:
      raise ValueError('x must be a one-dimensional array')
  else:
    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])

  if increasing:
    start = 0
    limit = N
    delta = 1
  else:
    start = N - 1
    limit = -1
    delta = -1

  x = array_ops.expand_dims(x, -1)
  return math_ops.pow(
      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))
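
# Illustrative: by default powers decrease left to right (NumPy semantics),
# e.g. vander([1, 2, 3]) ->
#   [[1, 1, 1],
#    [4, 2, 1],
#    [9, 3, 1]]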


@np_utils.np_doc('ix_')
def ix_(*args):  # pylint: disable=missing-docstring
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    a_rank = array_ops.rank(a)
    a_rank_temp = np_utils.get_static_value(a_rank)
    if a_rank_temp is not None:
      a_rank = a_rank_temp
      if a_rank != 1:
        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
            i, a_rank))
    else:
      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])

    new_shape = [1] * n
    new_shape[i] = -1
    dtype = a.dtype
    if dtype == dtypes.bool:
      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
    elif dtype.is_integer:
      output.append(array_ops.reshape(a, new_shape))
    else:
      raise ValueError(
          'Only integer and bool dtypes are supported, got {}'.format(dtype))

  return output


@np_utils.np_doc('broadcast_arrays')
def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
  subok = kwargs.pop('subok', False)
  if subok:
    raise ValueError('subok=True is not supported.')
  if kwargs:
    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))

  args = [asarray(arg) for arg in args]
  return np_utils.tf_broadcast(*args)


@np_utils.np_doc_only('sign')
def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
  if out:
    raise ValueError("tf.numpy doesn't support setting out.")
  if where:
    raise ValueError("tf.numpy doesn't support setting where.")
  if kwargs:
    raise ValueError(
        "tf.numpy doesn't support setting {}".format(kwargs.keys()))

  x = asarray(x)
  dtype = x.dtype.as_numpy_dtype
  if np.issubdtype(dtype, np.complexfloating):
    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
  else:
    result = math_ops.sign(x)

  return result


# Note that np.take_along_axis may not be present in some supported versions of
# numpy.
@np_utils.np_doc('take_along_axis')
def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
  arr = asarray(arr)
  indices = asarray(indices)

  if axis is None:
    return take_along_axis(arr.ravel(), indices, 0)

  rank = array_ops.rank(arr)
  axis = axis + rank if axis < 0 else axis

  # Broadcast shapes to match, ensure that the axis of interest is not
  # broadcast.
  arr_shape_original = array_ops.shape(arr)
  indices_shape_original = array_ops.shape(indices)
  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
                                                  [[axis]], [1])
  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
                                                        indices_shape)
  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
                                              [arr_shape_original[axis]])
  indices_shape = array_ops.tensor_scatter_update(
      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
  arr = array_ops.broadcast_to(arr, arr_shape)
  indices = array_ops.broadcast_to(indices, indices_shape)

  # Save indices shape so we can restore it later.
  possible_result_shape = indices.shape

  # Correct indices since gather doesn't correctly handle negative indices.
  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)

  swapaxes_ = lambda t: swapaxes(t, axis, -1)

  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
                      lambda: swapaxes_(arr))
  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
                          lambda: swapaxes_(indices))

  arr_shape = array_ops.shape(arr)
  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])

  indices_shape = array_ops.shape(indices)
  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])

  result = array_ops.gather(arr, indices, batch_dims=1)
  result = array_ops.reshape(result, indices_shape)
  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
                         lambda: swapaxes_(result))
  result.set_shape(possible_result_shape)

  return result


_SLICE_ERORR = (
    'only integers, slices (`:`), ellipsis (`...`), '
    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')


def _as_index(idx, need_scalar=True):
  """Helper function to parse idx as an index.

  Args:
    idx: index
    need_scalar: If idx needs to be a scalar value.

  Returns:
    A pair, (index, bool). First one is the parsed index and can be a tensor,
    or scalar integer / Dimension. Second one is True if rank is known to be 0.

  Raises:
    IndexError: For incorrect indices.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return idx, True
  data = asarray(idx)
  if data.dtype == dtypes.bool:
    if data.shape.ndims != 1:
      # TODO(agarwal): handle higher rank boolean masks.
      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
    data = array_ops.where_v2(data)
    data = array_ops.reshape(data, [-1])
  if need_scalar and data.shape.rank not in (None, 0):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  np_dtype = data.dtype.as_numpy_dtype
  if not np.issubdtype(np_dtype, np.integer):
    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  if data.dtype not in (dtypes.int64, dtypes.int32):
    # TF slicing can only handle int32/int64. So we need to cast.
    promoted_dtype = np.promote_types(np.int32, np_dtype)
    if promoted_dtype == np.int32:
      data = math_ops.cast(data, dtypes.int32)
    elif promoted_dtype == np.int64:
      data = math_ops.cast(data, dtypes.int64)
    else:
      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
  return data, data.shape.rank == 0
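
# Illustrative: _as_index(3) returns (3, True) with no conversion. A rank-1
# boolean mask is only accepted with need_scalar=False, e.g.
# _as_index(np.array([True, False, True]), need_scalar=False) yields the index
# tensor [0, 2] and is_scalar=False.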


class _UpdateMethod(enum.Enum):
  UPDATE = 0
  ADD = 1
  MIN = 2
  MAX = 3


def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
  """Helper function for __getitem__ and _with_index_update_helper.

  This function collects the indices in `slice_spec` into two buckets, which we
  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
  for `gather`.  They also correspond to "basic indices" and "advanced indices"
  in numpy.  This function supports both reading and writing at the indices.
  The reading path can be summarized as `gather(strided_slice(tensor, idx1),
  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`.  (`gather` here
  means `tf.gather` or `tf.gather_nd`; `scatter` here means
  `tf.tensor_scatter_update`.)  The writing path is inefficient because it needs
  to first read out a portion (probably much larger than `updates`) of `tensor`
  using `strided_slice`, update it, and then write the portion back. An
  alternative approach is to only use `scatter`, which amounts to using the
  indexing mechanism of gather/scatter to implement
  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
  not TF gather/scatter because they don't support spans (except those that
  cover entire dimensions, i.e. `:`).  If we materialize spans into individual
  indices, the size of the index tensor would explode.  (Note that XLA
  Gather/Scatter have a similar problem for stride > 1 because they don't
  support strides.  Indices such as `1:2:8` will need to be materialized into
  individual indices such as [1, 3, 5, 7].)

  Args:
    tensor: the tensor to be read from or write into.
    slice_spec: the indices.
    update_method: (optional) a member of `_UpdateMethod`, indicating how to
      update the values (replacement, add, etc.). `None` indicates just reading.
    updates: (optional) the new values to write into `tensor`. It must have the
      same dtype as `tensor`.

  Returns:
    The result of reading (if `update_method` is `None`) or the updated `tensor`
    after writing.
  """
  begin, end, strides = [], [], []
  new_axis_mask, shrink_axis_mask = 0, 0
  begin_mask, end_mask = 0, 0
  ellipsis_mask = 0
  advanced_indices = []
  shrink_indices = []
  for index, s in enumerate(slice_spec):
    if isinstance(s, slice):
      if s.start is not None:
        begin.append(_as_index(s.start)[0])
      else:
        begin.append(0)
        begin_mask |= (1 << index)
      if s.stop is not None:
        end.append(_as_index(s.stop)[0])
      else:
        end.append(0)
        end_mask |= (1 << index)
      if s.step is not None:
        strides.append(_as_index(s.step)[0])
      else:
        strides.append(1)
    elif s is Ellipsis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      ellipsis_mask |= (1 << index)
    elif s is array_ops.newaxis:
      begin.append(0)
      end.append(0)
      strides.append(1)
      new_axis_mask |= (1 << index)
    else:
      s, is_scalar = _as_index(s, False)
      if is_scalar:
        begin.append(s)
        end.append(s + 1)
        strides.append(1)
        shrink_axis_mask |= (1 << index)
        shrink_indices.append(index)
      else:
        begin.append(0)
        end.append(0)
        strides.append(1)
        begin_mask |= (1 << index)
        end_mask |= (1 << index)
        advanced_indices.append((index, s, ellipsis_mask != 0))

  # `stack` possibly involves no tensors, so we must use name_scope to pick the
  # correct graph.
1595  with ops.name_scope(
1596      None,
1597      'strided_slice', [tensor] + begin + end + strides,
1598      skip_on_eager=False) as name:
1599    if begin:
1600      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
1601                                                  array_ops.stack(end),
1602                                                  array_ops.stack(strides))
1603      if (packed_begin.dtype == dtypes.int64 or
1604          packed_end.dtype == dtypes.int64 or
1605          packed_strides.dtype == dtypes.int64):
1606        if packed_begin.dtype != dtypes.int64:
1607          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
1608        if packed_end.dtype != dtypes.int64:
1609          packed_end = math_ops.cast(packed_end, dtypes.int64)
1610        if packed_strides.dtype != dtypes.int64:
1611          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
1612    else:
1613      var_empty = constant_op.constant([], dtype=dtypes.int32)
1614      packed_begin = packed_end = packed_strides = var_empty
1615    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
1616      return array_ops.tensor_strided_slice_update(
1617          tensor,
1618          packed_begin,
1619          packed_end,
1620          packed_strides,
1621          updates,
1622          begin_mask=begin_mask,
1623          end_mask=end_mask,
1624          shrink_axis_mask=shrink_axis_mask,
1625          new_axis_mask=new_axis_mask,
1626          ellipsis_mask=ellipsis_mask,
1627          name=name)
1628    else:
1629      # TODO(b/164251540): Find a better way to support update that does not
1630      #   involve one read + two writes.
1631      if updates is not None:
1632        original_tensor = tensor
1633      # TODO(agarwal): set_shape on tensor to set rank.
1634      tensor = array_ops.strided_slice(
1635          tensor,
1636          packed_begin,
1637          packed_end,
1638          packed_strides,
1639          begin_mask=begin_mask,
1640          end_mask=end_mask,
1641          shrink_axis_mask=shrink_axis_mask,
1642          new_axis_mask=new_axis_mask,
1643          ellipsis_mask=ellipsis_mask,
1644          name=name)
1645    if not advanced_indices:
1646      if update_method is None:
1647        return tensor
1648      assert update_method != _UpdateMethod.UPDATE
1649      # TF lacks TensorStridedSliceAdd and the like, so we need to do
1650      # read + combine + write-back.
1651      if update_method == _UpdateMethod.ADD:
1652        update_op = math_ops.add
1653      elif update_method == _UpdateMethod.MIN:
1654        update_op = math_ops.minimum
1655      elif update_method == _UpdateMethod.MAX:
1656        update_op = math_ops.maximum
1657      return array_ops.tensor_strided_slice_update(
1658          original_tensor,
1659          packed_begin,
1660          packed_end,
1661          packed_strides,
1662          update_op(tensor, updates),
1663          begin_mask=begin_mask,
1664          end_mask=end_mask,
1665          shrink_axis_mask=shrink_axis_mask,
1666          new_axis_mask=new_axis_mask,
1667          ellipsis_mask=ellipsis_mask,
1668          name=name + '_2')
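    # Map each advanced index from its position in `slice_spec` to the
    # dimension it addresses in the sliced `tensor`: shrink axes before it
    # disappear from the result, and indices after an ellipsis are counted
    # from the end. E.g. (hypothetical) for slice_spec = (0, Ellipsis, [1, 2])
    # the advanced index lands on dim -1.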
1669    advanced_indices_map = {}
1670    for index, data, had_ellipsis in advanced_indices:
1671      if had_ellipsis:
1672        num_shrink = len([x for x in shrink_indices if x > index])
1673        dim = index - len(slice_spec) + num_shrink
1674      else:
1675        num_shrink = len([x for x in shrink_indices if x < index])
1676        dim = index - num_shrink
1677      advanced_indices_map[dim] = data
1678    dims = sorted(advanced_indices_map.keys())
1679    dims_contiguous = True
1680    if len(dims) > 1:
1681      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
1682        dims_contiguous = False
1683      else:
1684        for i in range(len(dims) - 1):
1685          if dims[i] + 1 != dims[i + 1]:
1686            dims_contiguous = False
1687            break
1688    indices = [advanced_indices_map[x] for x in dims]
1689    indices = _promote_dtype(*indices)
1690    indices = np_utils.tf_broadcast(*indices)
1691    stacked_indices = array_ops.stack(indices, axis=-1)
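    # `stacked_indices` now has shape broadcast_shape + [len(dims)]: one
    # coordinate tuple per output element, which is the layout that
    # gather_nd/scatter_nd expect in the innermost index dimension.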
1692    # Skip the contiguous-dims optimization for update because there is no
1693    # tf.*scatter* op that supports the `axis` argument.
1694    if not dims_contiguous or updates is not None:
1695      if list(range(len(dims))) != dims:
1696        tensor = moveaxis(tensor, dims, range(len(dims)))
1697      tensor_shape_prefix = array_ops.shape(
1698          tensor, out_type=stacked_indices.dtype)[:len(dims)]
1699      stacked_indices = array_ops.where_v2(
1700          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
1701          stacked_indices)
1702      if updates is None:
1703        return array_ops.gather_nd(tensor, stacked_indices)
1704      else:
1705        # We only need to move axes of `updates` in the contiguous case,
1706        # because only then do the result dimensions of advanced indexing
1707        # sit in the middle of `updates`. In the non-contiguous case, those
1708        # dimensions are always at the front.
1709        if dims_contiguous:
1710          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
1711          #   `updates`)
1712          if stacked_indices.shape.rank is None:
1713            raise NotImplementedError(
1714                'Rank of the advanced indices must currently be known')
1715          batch_size = stacked_indices.shape.rank - 1
1716          batch_start = dims[0]
1717          if batch_start < 0:
1718            batch_start += len(dims) - batch_size
1719          def range_(start, length):
1720            return range(start, start + length)
1721          updates = moveaxis(updates, range_(batch_start, batch_size),
1722                             range(batch_size))
1723        if update_method == _UpdateMethod.UPDATE:
1724          update_op = array_ops.tensor_scatter_update
1725        elif update_method == _UpdateMethod.ADD:
1726          update_op = array_ops.tensor_scatter_add
1727        elif update_method == _UpdateMethod.MIN:
1728          update_op = array_ops.tensor_scatter_min
1729        elif update_method == _UpdateMethod.MAX:
1730          update_op = array_ops.tensor_scatter_max
1731        tensor = update_op(
1732            tensor, stacked_indices, updates)
1733        if list(range(len(dims))) != dims:
1734          tensor = moveaxis(tensor, range(len(dims)), dims)
1735        return array_ops.tensor_strided_slice_update(
1736            original_tensor,
1737            packed_begin,
1738            packed_end,
1739            packed_strides,
1740            tensor,
1741            begin_mask=begin_mask,
1742            end_mask=end_mask,
1743            shrink_axis_mask=shrink_axis_mask,
1744            new_axis_mask=new_axis_mask,
1745            ellipsis_mask=ellipsis_mask,
1746            name=name + '_2')
1747    # Note that gather_nd only gathers along leading dims (no `axis` arg),
1748    # i.e. not from inside the array. To avoid shuffling data back and
1749    # forth, we transform the indices and do a gather instead.
1750    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
1751    dims = [(x + rank if x < 0 else x) for x in dims]
1752    shape_tensor = array_ops.shape(tensor)
1753    dim_sizes = array_ops.gather(shape_tensor, dims)
1754    if len(dims) == 1:
1755      stacked_indices = indices[0]
1756    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
1757    stacked_indices = array_ops.where_v2(stacked_indices < 0,
1758                                         stacked_indices + dim_sizes,
1759                                         stacked_indices)
1760    axis = dims[0]
1761    if len(dims) > 1:
1762      index_scaling = math_ops.cumprod(
1763          dim_sizes, reverse=True, exclusive=True)
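      # Row-major linearization (illustrative): for dim_sizes = [4, 5] this
      # gives index_scaling = [5, 1], so a coordinate (i, j) maps to flat
      # index 5*i + j within the merged dimension created below.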
1764      def _tensordot(a, b):
1765        # TODO(b/168657656): This function should be replaced by
1766        # tensordot(axis=1) once MatMul has int32 XLA kernel.
1767        b = array_ops.broadcast_to(b, array_ops.shape(a))
1768        return math_ops.reduce_sum(a * b, axis=-1)
1769      stacked_indices = _tensordot(stacked_indices, index_scaling)
1770      flat_shape = array_ops.concat(
1771          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
1772          axis=0)
1773      tensor = array_ops.reshape(tensor, flat_shape)
1774
1775    return array_ops.gather(tensor, stacked_indices, axis=axis)
1776
1777
1778def _as_spec_tuple(slice_spec):
1779  """Convert slice_spec to tuple."""
1780  if isinstance(slice_spec,
1781                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
1782    is_index = True
1783    for s in slice_spec:
1784      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
1785        is_index = False
1786        break
1787      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
1788        is_index = False
1789        break
1790    if not is_index:
1791      return tuple(slice_spec)
1792  return (slice_spec,)
1793
1794
1795def _getitem(self, slice_spec):
1796  """Implementation of ndarray.__getitem__."""
1797  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
1798                                       slice_spec.dtype == dtypes.bool) or
1799      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
1800       slice_spec.dtype == np.bool_)):
1801    return array_ops.boolean_mask(tensor=self, mask=slice_spec)
1802
1803  if not isinstance(slice_spec, tuple):
1804    slice_spec = _as_spec_tuple(slice_spec)
1805
1806  result_t = _slice_helper(self, slice_spec)
1807  return result_t
1808
1809
1810def _with_index_update_helper(update_method, a, slice_spec, updates):
1811  """Implementation of ndarray._with_index_*."""
1812  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
1813                                       slice_spec.dtype == dtypes.bool) or
1814      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
1815       slice_spec.dtype == np.bool_)):
1816    slice_spec = nonzero(slice_spec)
1817
1818  if not isinstance(slice_spec, tuple):
1819    slice_spec = _as_spec_tuple(slice_spec)
1820
1821  a_dtype = a.dtype
1822  a, updates = _promote_dtype_binary(a, updates)
1823  result_t = _slice_helper(a, slice_spec, update_method, updates)
1824  return result_t.astype(a_dtype)
1825
1826
1827setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
1828setattr(np_arrays.ndarray, '_with_index_update',
1829        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
1830setattr(np_arrays.ndarray, '_with_index_add',
1831        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
1832setattr(np_arrays.ndarray, '_with_index_min',
1833        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
1834setattr(np_arrays.ndarray, '_with_index_max',
1835        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
1836
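
# Illustrative usage (hedged: the exact wiring to Tensor.__getitem__ lives in
# np_arrays and the NumPy-behavior switch, e.g. tf.experimental.numpy): with
# NumPy behavior enabled, `x[1:3, [0, 2]]` routes through
# `_numpy_style_getitem`, and `x._with_index_add(idx, v)` computes the
# equivalent of NumPy's `x[idx] += v` via `_UpdateMethod.ADD`.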