# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops."""
from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python import tf2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import googletest


@test_util.run_all_in_graph_and_eager_modes
class ReduceTest(test_util.TensorFlowTestCase):

  def testReduceAllDims(self):
    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    with test_util.device(use_gpu=True):
      y_tf = self.evaluate(math_ops.reduce_sum(x))
      self.assertEqual(y_tf, 21)

  def testReduceExtendType(self):
    in_f32 = np.random.randn(1000, 1000).astype(np.float32)
    in_bf16 = math_ops.cast(in_f32, dtypes.bfloat16)

    out_f32 = self.evaluate(math_ops.reduce_sum(in_f32))
    out_bf16 = self.evaluate(math_ops.reduce_sum(in_bf16))
    expected = math_ops.cast(out_f32, dtypes.bfloat16)

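    # bfloat16 keeps only 7 mantissa bits, so use a loose tolerance here.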
    self.assertAllClose(out_bf16, expected, 1e-3)

  def testReduceExplicitAxes(self):
    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    with test_util.device(use_gpu=True):
      for axis in (0, -2):
        self.assertAllEqual(
            self.evaluate(math_ops.reduce_sum(x, axis=axis)), [5, 7, 9])
      for axis in (1, -1):
        self.assertAllEqual(
            self.evaluate(math_ops.reduce_sum(x, axis=axis)), [6, 15])
      for axis in (None, (0, 1), (1, 0), (-1, 0), (0, -1), (-2, 1), (1, -2),
                   (-1, -2), (-2, -1)):
        self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21)

  def testReduceInvalidAxis(self):
    if context.executing_eagerly():
      # The shape check only runs at graph construction time. Eager mode skips
      # it and silently returns a result despite the wrongly shaped axis.
      return
    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    axis = np.array([[0], [1]])
    with self.assertRaisesRegex(ValueError, "must be at most rank 1"):
      math_ops.reduce_sum(x, axis)

  def testReduceVar(self):
    x = np.array([[0, 0, 0], [0, 0, 0]], "float32")
    self.assertAllClose(self.evaluate(math_ops.reduce_variance(x)), 0)
    self.assertAllClose(
        self.evaluate(math_ops.reduce_variance(x, axis=0)), [0, 0, 0])

    x = [[1, 2, 1, 1], [1, 1, 0, 1]]
    with self.assertRaisesRegex(TypeError, "must be either real or complex"):
      math_ops.reduce_variance(x)

    x = [[1., 2., 1., 1.], [1., 1., 0., 1.]]
    self.assertEqual(self.evaluate(math_ops.reduce_variance(x)), 0.25)
    x_np = np.array(x)
    self.assertEqual(np.var(x_np), 0.25)
    self.assertEqual(self.evaluate(math_ops.reduce_variance(x_np)), 0.25)

    x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.],
                                     []])
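    # Reducing a ragged tensor along axis=0 only includes rows that are long
    # enough to have a value in that column.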
    self.assertAllClose(math_ops.reduce_variance(x, axis=0), [0., 16., 1., 0.])

  def testReduceVarComplex(self):
    # Ensure that complex values are handled consistently with numpy.
    complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64),
                  (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32),
                  (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)]
    for y, dtype in complex_ys:
      y_result = math_ops.reduce_variance(y)
      self.assertEqual(np.var(y), 1.0)
      self.assertEqual(self.evaluate(y_result), 1.0)
      self.assertEqual(y_result.dtype, dtype)

  def testReduceStd(self):
    x = np.array([[0, 0, 0], [0, 0, 0]], "float32")
    self.assertAllClose(self.evaluate(math_ops.reduce_std(x)), 0)
    self.assertAllClose(
        self.evaluate(math_ops.reduce_std(x, axis=0)), [0, 0, 0])

    x = [[1, 2, 1, 1], [1, 1, 0, 1]]
    with self.assertRaisesRegex(TypeError, "must be either real or complex"):
      math_ops.reduce_std(x)

    x = [[1., 2., 1., 1.], [1., 1., 0., 1.]]
    self.assertEqual(self.evaluate(math_ops.reduce_std(x)), 0.5)
    x_np = np.array(x)
    self.assertEqual(np.std(x_np), 0.5)
    self.assertEqual(self.evaluate(math_ops.reduce_std(x_np)), 0.5)

    x = ragged_factory_ops.constant([[5., 1., 4., 1.], [], [5., 9., 2.], [5.],
                                     []])
    self.assertAllClose(math_ops.reduce_std(x, axis=0), [0., 4., 1., 0.])

  def testReduceStdComplex(self):
    # Ensure that complex values are handled consistently with numpy.
    complex_ys = [([0 - 1j, 0 + 1j], dtypes.float64),
                  (np.array([0 - 1j, 0 + 1j], "complex64"), dtypes.float32),
                  (np.array([0 - 1j, 0 + 1j], "complex128"), dtypes.float64)]
    for y, dtype in complex_ys:
      y_result = math_ops.reduce_std(y)
      self.assertEqual(np.std(y), 1.0)
      self.assertEqual(self.evaluate(y_result), 1.0)
      self.assertEqual(y_result.dtype, dtype)


@test_util.run_all_in_graph_and_eager_modes
class LogSumExpTest(test_util.TensorFlowTestCase):

  def testReduceLogSumExp(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with test_util.use_gpu():
        y_tf_np = math_ops.reduce_logsumexp(x_np)
        y_np = np.log(np.sum(np.exp(x_np)))
        self.assertAllClose(y_tf_np, y_np)

  def testReductionIndices(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with test_util.use_gpu():
        y_tf = math_ops.reduce_logsumexp(x_np, axis=[0])
        y_np = np.log(np.sum(np.exp(x_np), axis=0))
        self.assertShapeEqual(y_np, y_tf)
        y_tf_np = self.evaluate(y_tf)
        self.assertAllClose(y_tf_np, y_np)

  def testReductionIndices2(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with test_util.use_gpu():
        y_tf = math_ops.reduce_logsumexp(x_np, axis=0)
        y_np = np.log(np.sum(np.exp(x_np), axis=0))
        self.assertShapeEqual(y_np, y_tf)
        y_tf_np = self.evaluate(y_tf)
        self.assertAllClose(y_tf_np, y_np)

  def testKeepDims(self):
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.random.rand(5, 5).astype(dtype)
      with test_util.use_gpu():
        y_tf_np = math_ops.reduce_logsumexp(x_np, keepdims=True)
        self.assertEqual(y_tf_np.shape.rank, x_np.ndim)
        y_np = np.log(np.sum(np.exp(x_np), keepdims=True))
        self.assertAllClose(y_tf_np, y_np)

  def testOverflow(self):
    x = [1000, 1001, 1002, 1003]
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.array(x, dtype=dtype)
      max_np = np.max(x_np)
      with self.assertRaisesRegex(RuntimeWarning,
                                  "overflow encountered in exp"):
        out = np.log(np.sum(np.exp(x_np)))
        if out == np.inf:
          raise RuntimeWarning("overflow encountered in exp")

      with test_util.use_gpu():
        x_tf = constant_op.constant(x_np, shape=x_np.shape)
        y_tf_np = math_ops.reduce_logsumexp(x_tf)
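        # Stable reference: logsumexp(x) = log(sum(exp(x - max(x)))) + max(x).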
        y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np
        self.assertAllClose(y_tf_np, y_np)

  def testUnderflow(self):
    x = [-1000, -1001, -1002, -1003]
    for dtype in [np.float16, np.float32, np.double]:
      x_np = np.array(x, dtype=dtype)
      max_np = np.max(x_np)
      with self.assertRaisesRegex(RuntimeWarning,
                                  "divide by zero encountered in log"):
        out = np.log(np.sum(np.exp(x_np)))
        if out == -np.inf:
          raise RuntimeWarning("divide by zero encountered in log")

      with test_util.use_gpu():
        x_tf = constant_op.constant(x_np, shape=x_np.shape)
        y_tf_np = math_ops.reduce_logsumexp(x_tf)
        y_np = np.log(np.sum(np.exp(x_np - max_np))) + max_np
        self.assertAllClose(y_tf_np, y_np)

  def testInfinity(self):
    with test_util.use_gpu():
      res = math_ops.reduce_logsumexp(-np.inf)
      self.assertEqual(-np.inf, self.evaluate(res))

  def testRaggedTensor(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.double]:
      x_rt = ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]], dtype=dtype)
      x_np = np.array(self.evaluate(x_rt.flat_values))
      with test_util.use_gpu():
        y_rt = math_ops.reduce_logsumexp(x_rt)
        y_np = np.log(np.sum(np.exp(x_np - np.max(x_np)))) + np.max(x_np)
        self.assertAllClose(y_rt, y_np)


@test_util.run_all_in_graph_and_eager_modes
class RoundTest(test_util.TensorFlowTestCase):

  def testRounding(self):
    x = np.arange(-5.0, 5.0, .25)
    for dtype in [np.float32, np.double, np.int32]:
      x_np = np.array(x, dtype=dtype)
      with test_util.device(use_gpu=True):
        x_tf = constant_op.constant(x_np, shape=x_np.shape)
        y_tf = math_ops.round(x_tf)
        y_tf_np = self.evaluate(y_tf)
        y_np = np.round(x_np)
        self.assertAllClose(y_tf_np, y_np, atol=1e-2)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class MatMulTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Test for matmul."""

  SUPPORTED_DTYPES = [
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
      dtypes.int64, dtypes.bfloat16, dtypes.complex64, dtypes.complex128
  ]

  def testMatMul2D(self):
    for dtype in self.SUPPORTED_DTYPES:
      a = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=dtype)
      b = constant_op.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=dtype)
      c = math_ops.matmul(a, b)
      c_np = constant_op.constant([[58, 64], [139, 154]],
                                  shape=(2, 2),
                                  dtype=dtype)
      self.assertAllClose(c, c_np, atol=1e-2)

  def testBatchMatMul(self):
    for dtype in self.SUPPORTED_DTYPES:
      a = constant_op.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=dtype)
      b = constant_op.constant(np.arange(13, 25), shape=[2, 3, 2], dtype=dtype)
      c = math_ops.matmul(a, b)
      c_np = constant_op.constant(
          [[[94, 100], [229, 244]], [[508, 532], [697, 730]]],
          shape=[2, 2, 2],
          dtype=dtype)
      self.assertAllClose(c, c_np, atol=1e-2)

  def testUnsupportedtypeMatmul(self):
    a = constant_op.constant(
        np.arange(1, 13), shape=[2, 2, 3], dtype=dtypes.int8)
    b = constant_op.constant(
        np.arange(13, 25), shape=[2, 3, 2], dtype=dtypes.int8)
    with self.assertRaisesRegex((TypeError, errors.InvalidArgumentError),
                                "list of allowed values:"):
      math_ops.matmul(a, b)

  @parameterized.parameters((dtypes.int8, dtypes.int8),
                            (dtypes.int8, dtypes.uint8),
                            (dtypes.uint8, dtypes.int8))
  # TODO(shivaniagrawal): matmul (dtypes.uint8, dtypes.uint8) fails in xla_gpu.
  def testInt8MatMul2D(self, a_dtype, b_dtype):
    a = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3], dtype=a_dtype)
    b = constant_op.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=b_dtype)
    c = math_ops.matmul(a, b, output_type=dtypes.int32)
    c_np = constant_op.constant([[58, 64], [139, 154]],
                                shape=(2, 2),
                                dtype=dtypes.int32)
    self.assertAllClose(c, c_np)

  @parameterized.parameters((dtypes.int8), (dtypes.uint8))
  def testMixPrecMatMul2D(self, b_dtype):
    a = constant_op.constant([1, 2, 3, 4, 5, 6],
                             shape=[2, 3],
                             dtype=dtypes.bfloat16)
    b = constant_op.constant([7, 8, 9, 10, 11, 12], shape=[3, 2], dtype=b_dtype)
    c = math_ops.matmul(a, b, output_type=dtypes.bfloat16)
    c_np = constant_op.constant([[58, 64], [139, 154]],
                                shape=(2, 2),
                                dtype=dtypes.bfloat16)
    self.assertAllClose(c, c_np, atol=1e-2)

  @parameterized.parameters((dtypes.int8, dtypes.int8),
                            (dtypes.int8, dtypes.uint8),
                            (dtypes.uint8, dtypes.int8))
  # TODO(shivaniagrawal): matmul (dtypes.uint8, dtypes.uint8) fails in xla_gpu.
  def testInt8BatchMatmul(self, a_dtype, b_dtype):
    a = constant_op.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=a_dtype)
    b = constant_op.constant(np.arange(13, 25), shape=[2, 3, 2], dtype=b_dtype)
    c_np = constant_op.constant(
        [[[94, 100], [229, 244]], [[508, 532], [697, 730]]],
        shape=[2, 2, 2],
        dtype=dtypes.int32)
    c = math_ops.matmul(a, b, output_type=dtypes.int32)
    self.assertAllEqual(c, c_np)

  @parameterized.parameters((dtypes.int8), (dtypes.uint8))
  def testMixPrecBatchMatmul(self, b_dtype):
    a = constant_op.constant(
        np.arange(1, 13), shape=[2, 2, 3], dtype=dtypes.bfloat16)
    b = constant_op.constant(np.arange(13, 25), shape=[2, 3, 2], dtype=b_dtype)
    c_np = constant_op.constant(
        [[[94, 100], [229, 244]], [[508, 532], [697, 730]]],
        shape=[2, 2, 2],
        dtype=dtypes.bfloat16)
    c = math_ops.matmul(a, b, output_type=dtypes.bfloat16)
    self.assertAllClose(c, c_np, atol=1e-2)

  def testInvalidOutputTypeMatmul(self):
    for dtype in [dtypes.int8, dtypes.bfloat16]:
      a = constant_op.constant(np.arange(1, 13), shape=[2, 2, 3], dtype=dtype)
      b = constant_op.constant(
          np.arange(13, 25), shape=[2, 3, 2], dtype=dtypes.int8)
      if context.executing_eagerly():
        if context.is_tfrt_enabled():
          with self.assertRaisesRegex(errors.InvalidArgumentError,
                                      "NodeDef expected inputs"):
            math_ops.matmul(a, b, output_type=dtypes.float32)
        else:
          with self.assertRaisesRegex(errors.NotFoundError,
                                      "Could not find device for node:"):
            math_ops.matmul(a, b, output_type=dtypes.float32)
      else:
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    "No OpKernel was registered to support Op"):
          self.evaluate(math_ops.matmul(a, b, output_type=dtypes.float32))


@test_util.run_all_in_graph_and_eager_modes
class ModTest(test_util.TensorFlowTestCase):

  def testFloat(self):
    x = [0.5, 0.7, 0.3]
    for dtype in [np.float32, np.double]:
      # Test scalar and vector versions.
      for denom in [x[0], [x[0]] * 3]:
        x_np = np.array(x, dtype=dtype)
        with test_util.use_gpu():
          x_tf = constant_op.constant(x_np, shape=x_np.shape)
          y_tf = math_ops.mod(x_tf, denom)
          y_tf_np = self.evaluate(y_tf)
          y_np = np.fmod(x_np, denom)
        self.assertAllClose(y_tf_np, y_np, atol=1e-2)

  def testFixed(self):
    x = [5, 10, 23]
    for dtype in [np.int32, np.int64]:
      # Test scalar and vector versions.
      for denom in [x[0], x]:
        x_np = np.array(x, dtype=dtype)
        with test_util.use_gpu():
          x_tf = constant_op.constant(x_np, shape=x_np.shape)
          y_tf = math_ops.mod(x_tf, denom)
          y_tf_np = self.evaluate(y_tf)
          y_np = np.mod(x_np, denom)
        self.assertAllClose(y_tf_np, y_np)


@test_util.run_all_in_graph_and_eager_modes
class SquaredDifferenceTest(test_util.TensorFlowTestCase):

  def testSquaredDifference(self):
    for dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype,
        np.int32, np.int64
    ]:
      x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
      y = np.array([-3, -2, -1], dtype=dtype)
      z = (x - y) * (x - y)
      with test_util.device(use_gpu=True):
        z_tf = self.evaluate(math_ops.squared_difference(x, y))
        self.assertAllClose(z, z_tf)

  def testComplexSquaredDifference(self):
    for dtype in [np.complex64, np.complex128]:
      x = np.array([[1 + 3j, 2 + 2j, 3 + 1j], [4 - 1j, 5 - 2j, 6 - 3j]],
                   dtype=dtype)
      y = np.array([-3 + 1j, -2 + 2j, -1 + 3j], dtype=dtype)
      z = np.conj(x - y) * (x - y)
      with test_util.device(use_gpu=False):
        z_tf = self.evaluate(math_ops.squared_difference(x, y))
        self.assertAllClose(z, z_tf)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class ApproximateEqualTest(test_util.TensorFlowTestCase):

  def testApproximateEqual(self):
    for dtype in [np.float32, np.double]:
      x = dtype(1)
      y = dtype(1.00009)
      z = False
      with test_util.device(use_gpu=True):
        # Default tolerance is 0.00001
        z_tf = self.evaluate(math_ops.approximate_equal(x, y))
        self.assertAllEqual(z, z_tf)

    for dtype in [np.float32, np.double]:
      x = dtype(1)
      y = dtype(1.000009)
      z = True
      with test_util.device(use_gpu=True):
        # Default tolerance is 0.00001
        z_tf = self.evaluate(math_ops.approximate_equal(x, y))
        self.assertAllEqual(z, z_tf)

    for dtype in [np.float32, np.double]:
      x = np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype)
      y = np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype)
      z = np.array([[[[False, True], [True, False]]]], dtype=np.bool_)
      with test_util.device(use_gpu=True):
        z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
        self.assertAllEqual(z, z_tf)

  def testApproximateEqualShape(self):
    for dtype in [np.float32, np.double]:
      x = np.array([1, 2], dtype=dtype)
      y = np.array([[1, 2]], dtype=dtype)
      # The inputs 'x' and 'y' must have the same shape.
      with self.assertRaisesRegex(
          (ValueError, errors.InvalidArgumentError),
          "Shapes must be equal rank|must be of the same shape"):
        math_ops.approximate_equal(x, y)

  def testApproximateEqualShapeXla(self):

    @def_function.function(jit_compile=True)
    def approximate_equal(x, y):
      return math_ops.approximate_equal(x, y)

    for dtype in [np.float32, np.double]:
      x = np.array([1, 2], dtype=dtype)
      y = np.array([[1, 2]], dtype=dtype)
      with self.assertRaisesRegex(
          (ValueError, errors.InvalidArgumentError),
          "Shapes must be equal rank|must be of the same shape"):
        approximate_equal(x, y)


@test_util.run_all_in_graph_and_eager_modes
class ScalarMulTest(test_util.TensorFlowTestCase):

  def testAcceptsRefs(self):
    if context.executing_eagerly():
      var = resource_variable_ops.ResourceVariable(10, name="var")
    else:
      var = variables.Variable(10)
    result = math_ops.scalar_mul(3, var)
    init = variables.global_variables_initializer()
    with test_util.device(use_gpu=True):
      self.evaluate(init)
      self.assertEqual(30, self.evaluate(result))

  def testAcceptsConstant(self):
    const = constant_op.constant(10)
    result = math_ops.scalar_mul(3, const)
    with test_util.device(use_gpu=True):
      self.assertEqual(30, self.evaluate(result))

  def testAcceptsTensor(self):
    tensor = array_ops.ones([10, 10])
    result = math_ops.scalar_mul(3, tensor)
    expected = array_ops.ones([10, 10]) * 3

    with test_util.device(use_gpu=True):
      self.assertAllEqual(self.evaluate(expected), self.evaluate(result))

  def testAcceptsIndexedSlices(self):
    values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2])
    indices = constant_op.constant([0, 2, 5])
    x = math_ops.scalar_mul(-3, indexed_slices.IndexedSlices(values, indices))
    with test_util.device(use_gpu=True):
      self.assertAllEqual(
          self.evaluate(x.values), [[-6, -9], [-15, -21], [0, 3]])
      self.assertAllEqual(self.evaluate(x.indices), [0, 2, 5])


@test_util.run_all_in_graph_and_eager_modes
class AddNTest(test_util.TensorFlowTestCase):

  def testPartials(self):
    """Test that previously revealed a bug in buffer forwarding for AddN."""
    partials = []
    for _ in range(98):
      partials.append(math_ops.add_n([constant_op.constant(1)]))
    partials.append(
        math_ops.add_n([constant_op.constant(1),
                        constant_op.constant(1)]))

    res = math_ops.add_n(partials) + constant_op.constant(0)
    with test_util.use_gpu():
      self.assertAllEqual(res, 100)

  def testFloat(self):
    np.random.seed(12345)
    for num_inputs in range(1, 10):
      x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(num_inputs)]
      tf_x = ops.convert_n_to_tensor(x)
      with test_util.use_gpu():
        self.assertAllClose(sum(x), math_ops.add_n(tf_x))
        self.assertAllClose(x[0] * num_inputs,
                            math_ops.add_n([tf_x[0]] * num_inputs))

  def testInt(self):
    np.random.seed(54321)
    for num_inputs in range(1, 10):
      x = [
          np.random.randint(-128, 128, (5, 4, 3, 2, 1))
          for _ in range(num_inputs)
      ]
      tf_x = ops.convert_n_to_tensor(x)
      with test_util.use_gpu():
        self.assertAllEqual(sum(x), math_ops.add_n(tf_x))
        self.assertAllEqual(x[0] * num_inputs,
                            math_ops.add_n([tf_x[0]] * num_inputs))

  def testGrad(self):
    np.random.seed(42)
    for num_inputs in range(1, 10):
      with test_util.use_gpu():
        input_vars = [
            variables.Variable(10.0 * np.random.random())
            for _ in range(0, num_inputs)
        ]
        self.evaluate(variables.global_variables_initializer())
        if context.executing_eagerly():
          with backprop.GradientTape() as tape:
            tape.watch(input_vars)
            addn = math_ops.add_n(input_vars)
            add_n_grad = tape.gradient(addn, input_vars)
        else:
          addn = math_ops.add_n(input_vars)
          add_n_grad = gradients.gradients(addn, input_vars)

        self.assertAllEqual(
            np.repeat(1.0, num_inputs),  # d/dx (x + y + ...) = 1
            [self.evaluate(g) for g in add_n_grad])

  def testIndexedSlices(self):
    slc = indexed_slices.IndexedSlices(
        array_ops.constant([1, 2], shape=[1, 2]), array_ops.constant([1]),
        array_ops.constant([2, 2]))
    slc_as_dense = np.array([[0, 0], [1, 2]])
    with test_util.use_gpu():
      # add_n currently always converts IndexedSlices to dense
      self.assertAllEqual(slc_as_dense, math_ops.add_n([slc]))
      self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc]))

  def test_iterable(self):
    """Test that add_n supports iterables (e.g. generators and dict values)."""

    def fn():
      yield 1
      yield 2

    values_dict = {"a": 1, "b": 2}
    with test_util.use_gpu():
      self.assertAllEqual(3, math_ops.add_n(fn()))
      self.assertAllEqual(3, math_ops.add_n(values_dict.values()))


@test_util.run_all_in_graph_and_eager_modes
class DivAndModTest(test_util.TensorFlowTestCase):
  # TODO(aselle): Test more types before exposing new division operators.

  def intTestData(self):
    nums = np.arange(-10, 10, 1).reshape(20, 1)
    divs = np.arange(-3, 4, 2).reshape(1, 4)
    return nums, divs

  def floatTestData(self):
    nums = np.arange(-10, 10, .25).reshape(80, 1)
    divs = np.arange(-3, 0, .25).reshape(1, 12)
    return nums, divs

  def numpySafeFloorDivInt(self, x, y):
    z = x // y
    # Numpy produces 0 for INT_MIN/-1, but we expect an overflow to INT_MIN
    # so that (INT_MIN/-1) + (INT_MIN % -1) = INT_MIN + 0 = INT_MIN.
    z[(x == np.iinfo(x.dtype).min) & (y == -1)] = np.iinfo(x.dtype).min
    return z

  def numpySafeFloorModInt(self, x, y):
    # Numpy crashes with a FPE for INT_MIN % -1.
    z = self.numpySafeFloorDivInt(x, y)
    return x - z * y

  def numpySafeTruncateDivInt(self, x, y):
    z = self.numpySafeFloorDivInt(x, y)
    # Round up if non-zero remainder and inputs have opposite signs.
    z[(x != z * y) & ((x < 0) != (y < 0))] += 1
    return z

  def numpySafeTruncateModInt(self, x, y):
    # Numpy crashes with a FPE for INT_MIN % -1.
    z = self.numpySafeTruncateDivInt(x, y)
    return x - z * y

  def testFloorModInt(self):
    nums, divs = self.intTestData()
    for dtype in [np.int32, np.int64]:
      x = nums.astype(dtype)
      y = divs.astype(dtype)
      tf_result = math_ops.floormod(x, y)
      np_result = self.numpySafeFloorModInt(x, y)
      self.assertAllEqual(tf_result, np_result)
      tf2_result = (array_ops.constant(x) % array_ops.constant(y))
      self.assertAllEqual(tf2_result, tf_result)

  def testFloorModFloat(self):
    nums, divs = self.floatTestData()
    for dtype in [np.float16, np.float32, np.float64]:
      x = nums.astype(dtype)
      y = divs.astype(dtype)
      tf_result = math_ops.floormod(x, y)
      np_result = x % y
      self.assertAllEqual(tf_result, np_result)
      tf2_result = (array_ops.constant(x) % array_ops.constant(y))
      self.assertAllEqual(tf2_result, tf_result)

  def testFloorModBfloat16(self):
    nums, divs = self.floatTestData()
    tf_result = math_ops.floormod(
        math_ops.cast(nums, dtypes.bfloat16),
        math_ops.cast(divs, dtypes.bfloat16))
    np_result = nums % divs
    self.assertAllEqual(tf_result, np_result)

  def testTruncateModInt(self):
    nums, divs = self.intTestData()
    tf_result = math_ops.truncatemod(nums, divs)
    np_result = np.fmod(nums, divs)
    self.assertAllEqual(tf_result, np_result)

  def testTruncateModFloat(self):
    nums, divs = self.floatTestData()
    tf_result = math_ops.truncatemod(nums, divs)
    np_result = np.fmod(nums, divs)
    self.assertAllEqual(tf_result, np_result)

  def testFloorDivideInt(self):
    nums, divs = self.intTestData()
    tf_result = math_ops.floor_div(nums, divs)
    np_result = self.numpySafeFloorDivInt(nums, divs)
    self.assertAllEqual(tf_result, np_result)
    tf2_result = (array_ops.constant(nums) // array_ops.constant(divs))
    self.assertAllEqual(tf2_result, tf_result)

  def testTruncateDivideInt(self):
    nums, divs = self.intTestData()
    tf_result = math_ops.truncatediv(nums, divs)
    np_result = self.numpySafeTruncateDivInt(nums, divs)
    self.assertAllEqual(tf_result, np_result)

  @test_util.deprecated_graph_mode_only
  def testDivideName(self):
    op = math_ops.divide(
        array_ops.constant(3), array_ops.constant(4), name="my_cool_divide")
    self.assertEqual(op.name, "my_cool_divide:0")

  def testRealDiv(self):
    nums, divs = self.floatTestData()
    tf_result = math_ops.realdiv(nums, divs)
    np_result = np.divide(nums, divs)
    self.assertAllClose(tf_result, np_result)

  def testDivideType(self):
    a = array_ops.constant([2], dtype=dtypes.int32)
    # Since __future__.division is in effect, always upgrade to float64.
    b = math_ops.divide(a, 1)
    self.assertEqual(b.dtype, dtypes.float64)
    self.assertEqual(2.0, self.evaluate(b))
    c = math_ops.divide(a, 4)
    self.assertEqual(c.dtype, dtypes.float64)
    self.assertEqual(0.5, self.evaluate(c))

  def testComplexDiv(self):
    foo = array_ops.constant([1. + 3.j])
    _ = math_ops.divide(foo, 1.)
    _ = math_ops.div(foo, 2.)

  def testFloorDivGrad(self):
    a = variables.Variable(2.)
    b = variables.Variable(4.)
    input_vars = [a, b]
    self.evaluate(variables.global_variables_initializer())
    if context.executing_eagerly():
      # TODO(rmlarsen): Is there a more compact way of
      # writing this for multiple expressions?
      with backprop.GradientTape() as tape:
        tape.watch(input_vars)
        c_grad0 = tape.gradient(math_ops.divide(a, b), input_vars)
      with backprop.GradientTape() as tape:
        tape.watch(input_vars)
        c_grad1 = tape.gradient(math_ops.div(a, b), input_vars)
      with backprop.GradientTape() as tape:
        tape.watch(input_vars)
        c_grad2 = tape.gradient(math_ops.floordiv(a, b), input_vars)
    else:
      c_grad0 = gradients.gradients(math_ops.divide(a, b), input_vars)
      c_grad1 = gradients.gradients(math_ops.div(a, b), input_vars)
      c_grad2 = gradients.gradients(math_ops.floordiv(a, b), input_vars)
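    # At a=2, b=4: d(a/b)/da = 1/b = 0.25, d(a/b)/db = -a/b**2 = -0.125.
    # floordiv is piecewise constant, so its gradients are None.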
    self.assertAllEqual([self.evaluate(x) for x in c_grad0], [.25, -.125])
    self.assertAllEqual([self.evaluate(x) for x in c_grad1], [.25, -.125])
    self.assertAllEqual(
        [None if x is None else self.evaluate(x) for x in c_grad2],
        [None, None])

  def testConsistent(self):
    nums, divs = self.intTestData()
    tf_result = (
        math_ops.floor_div(nums, divs) * divs + math_ops.floormod(nums, divs))
    tf_nums = array_ops.constant(nums)
    tf_divs = array_ops.constant(divs)
    tf2_result = (tf_nums // tf_divs * tf_divs + tf_nums % tf_divs)
    np_result = (nums // divs) * divs + (nums % divs)
    # Consistent with numpy
    self.assertAllEqual(tf_result, np_result)
    # Consistent with two forms of divide
    self.assertAllEqual(tf_result, tf2_result)
    # consistency for truncation form
    tf3_result = (
        math_ops.truncatediv(nums, divs) * divs +
        math_ops.truncatemod(nums, divs))
    expanded_nums = np.reshape(
        np.tile(nums, divs.shape[1]), (nums.shape[0], divs.shape[1]))
    # Consistent with desire to get numerator
    self.assertAllEqual(tf3_result, expanded_nums)
    # Consistent with desire to get numerator
    self.assertAllEqual(tf_result, expanded_nums)

  def testWithPythonValue(self):
    # Test case for https://github.com/tensorflow/tensorflow/issues/39475
    x = math_ops.divide(5, 2)
    self.assertIsInstance(x, ops.Tensor)
    x = math_ops.divide(5, array_ops.constant(2.0))
    self.assertIsInstance(x, ops.Tensor)

  def intEdgeTestData(self, dtype):
    """Edge-case test data for integer types."""
    # INT_MIN/-1 expected to produce signed-integer overflow,
    # INT_MIN/INT_MAX expected to work.
    nums = np.array([np.iinfo(dtype).min, -1, 1,
                     np.iinfo(dtype).max],
                    dtype=dtype).reshape([4, 1])
    divs = nums.reshape([1, 4])
    return nums, divs

  @test_util.disable_asan("Expected signed integer overflow.")
  @test_util.disable_ubsan("Expected signed integer overflow.")
  def testFloorDivModIntEdges(self):
    for dtype in [np.int32, np.int64]:
      x, y = self.intEdgeTestData(dtype)
      tf_floor_div = math_ops.floor_div(x, y)
      np_floor_div = self.numpySafeFloorDivInt(x, y)
      self.assertAllEqual(tf_floor_div, np_floor_div)
      tf_floor_mod = math_ops.floormod(x, y)
      np_floor_mod = self.numpySafeFloorModInt(x, y)
      self.assertAllEqual(tf_floor_mod, np_floor_mod)
      z = math_ops.add(math_ops.multiply(tf_floor_div, y), tf_floor_mod)
      # x = floor_div(x, y) * y + floor_mod(x, y)
      self.assertAllEqual(z, np.broadcast_to(x, z.shape))

  @test_util.disable_asan("Expected signed integer overflow.")
  @test_util.disable_ubsan("Expected signed integer overflow.")
  def testTruncateDivModIntEdges(self):
    for dtype in [np.int32, np.int64]:
      x, y = self.intEdgeTestData(dtype)
      tf_truncate_div = math_ops.truncatediv(x, y)
      np_truncate_div = self.numpySafeTruncateDivInt(x, y)
      self.assertAllEqual(tf_truncate_div, np_truncate_div)
      tf_truncate_mod = math_ops.truncatemod(x, y)
      np_truncate_mod = self.numpySafeTruncateModInt(x, y)
      self.assertAllEqual(tf_truncate_mod, np_truncate_mod)
      z = math_ops.add(math_ops.multiply(tf_truncate_div, y), tf_truncate_mod)
      # x = truncatediv(x, y) * y + truncatemod(x, y)
      self.assertAllEqual(z, np.broadcast_to(x, z.shape))


@test_util.run_all_in_graph_and_eager_modes
class DivNoNanTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  @parameterized.parameters((dtypes.bfloat16), (dtypes.float16),
                            (dtypes.float32), (dtypes.float64),
                            (dtypes.complex64), (dtypes.complex128))
  def testBasic(self, dtype):
    nums = np.arange(-10, 10, .25).reshape(80, 1)
    divs = np.arange(-3, 3, .25).reshape(1, 24)

    tf_nums = constant_op.constant(nums, dtype=dtype)
    tf_divs = constant_op.constant(divs, dtype=dtype)

    # Use tf versions for expected value to ensure inputs are identical
    # (e.g. in the case of bfloat16).
    np_nums = self.evaluate(tf_nums)
    np_divs = self.evaluate(tf_divs)
    np_result = np.true_divide(np_nums, np_divs)
    np_result[:, np_divs[0] == 0] = 0

    with test_util.use_gpu():
      tf_result = math_ops.div_no_nan(tf_nums, tf_divs)
      self.assertAllCloseAccordingToType(tf_result, np_result)

  @parameterized.parameters((dtypes.bfloat16), (dtypes.float16),
                            (dtypes.float32), (dtypes.float64),
                            (dtypes.complex64), (dtypes.complex128))
  def testSmall(self, dtype):
    # Choose values whose squared magnitude underflows to zero/subnormal.
    zero = constant_op.constant([0, 0, 0, 0], dtype=dtype)
    divs = constant_op.constant([1e-25, -1e-20, 1e-165, -1e-160], dtype=dtype)
    tf_result = math_ops.div_no_nan(zero, divs)

    # Results should always be exactly zero.
    self.assertAllEqual(tf_result, zero)

  @parameterized.parameters((dtypes.bfloat16), (dtypes.float16),
                            (dtypes.float32), (dtypes.float64),
                            (dtypes.complex64), (dtypes.complex128))
  def testNonFiniteInNumerator(self, dtype):
    nums = constant_op.constant([np.nan, np.inf, np.NINF], dtype=dtype)
    zeros = constant_op.constant([0, 0, 0], dtype=dtype)
    ones = constant_op.constant([1, 1, 1], dtype=dtype)
    with test_util.use_gpu():
      tf_result_zeros = math_ops.div_no_nan(nums, zeros)
      self.assertAllEqual([0, 0, 0], tf_result_zeros)
      tf_result_ones = math_ops.div_no_nan(nums, ones)
      self.assertAllEqual(nums / ones, tf_result_ones)


@test_util.run_all_in_graph_and_eager_modes
class MultiplyNoNanTest(test_util.TensorFlowTestCase):

  def testBasic(self):
    for dtype in [np.float32, np.float64]:
      values = [0, 1, np.nan, np.inf, np.NINF]
      x = constant_op.constant(values, dtype=dtype)
      zeros = constant_op.constant(np.zeros((5,)), dtype=dtype)
      ones = constant_op.constant(np.ones((5,)), dtype=dtype)
      with test_util.use_gpu():
        tf_result_zeros = math_ops.multiply_no_nan(x, zeros)
        self.assertAllEqual(tf_result_zeros, zeros)
        tf_result_ones = math_ops.multiply_no_nan(x, ones)
        self.assertAllEqual(tf_result_ones, x)
        # Normal floating point arithmetic if nonfinite values are in the
        # second argument.
        tf_result_reverseargs = math_ops.multiply_no_nan(zeros, x)
        self.assertAllEqual(zeros * x, tf_result_reverseargs)


@test_util.run_all_in_graph_and_eager_modes
class XlogyTest(test_util.TensorFlowTestCase):

  def testXlogyNoZero(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xlogy = self.evaluate(math_ops.xlogy(x, y))
        xtimeslogy = self.evaluate(x * math_ops.log(y))
        self.assertAllClose(xlogy, xtimeslogy)

  def testXlogyWithZero(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y))
        self.assertAllClose(xlogy_tf_np, zeros_np)

  def testXlogyWithZeroBroadcast(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.], [1.]], dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
        xtimes_logy = self.evaluate(math_ops.log(y[1]))
        self.assertAllClose(zeros_np, xlogy_tf_np[0])
        self.assertAllClose(xtimes_logy, xlogy_tf_np[1])


@test_util.run_all_in_graph_and_eager_modes
class Xlog1pyTest(test_util.TensorFlowTestCase):

  def testXlog1pyNoNeg1(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
      y = constant_op.constant([[-0.1, -0.2, 3.5], [3.1, -0.9, 2.]],
                               dtype=dtype)
      with test_util.use_gpu():
        xlog1py = self.evaluate(math_ops.xlog1py(x, y))
        xtimeslog1py = self.evaluate(x * math_ops.log1p(y))
        self.assertAllClose(xlog1py, xtimeslog1py)

  def testXlog1pyWithNegOne(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [-1., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y))
        self.assertAllClose(xlog1py_tf_np, zeros_np)

  def testXlog1pyWithZeroBroadcast(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.], [1.]], dtype=dtype)
      y = constant_op.constant([[-0.1, -0.2, -1.], [0., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xlog1py_tf_np = self.evaluate(math_ops.xlog1py(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
        xtimes_log1py = self.evaluate(math_ops.log1p(y[1]))
        self.assertAllClose(zeros_np, xlog1py_tf_np[0])
        self.assertAllClose(xtimes_log1py, xlog1py_tf_np[1])


@test_util.run_all_in_graph_and_eager_modes
class XdivyTest(test_util.TensorFlowTestCase):

  def testXdivyNoZero(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xdivy = self.evaluate(math_ops.xdivy(x, y))
        x_over_y = self.evaluate(x / y)
        self.assertAllClose(xdivy, x_over_y)

  def testXdivyWithZero(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y))
        self.assertAllClose(xdivy_tf_np, zeros_np)

  def testXdivyWithZeroBroadcast(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[0.], [1.]], dtype=dtype)
      y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
      with test_util.use_gpu():
        xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
        zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
        x_over_y = self.evaluate(1 / y[1])
        self.assertAllClose(zeros_np, xdivy_tf_np[0])
        self.assertAllClose(x_over_y, xdivy_tf_np[1])


@test_util.run_all_in_graph_and_eager_modes
class NextAfterTest(test_util.TensorFlowTestCase):

  # Basic NextAfter tests that replicate numpy nextafter tests.
  def testBasic(self):

    for dtype in [dtypes.float32, dtypes.float64]:
      one = constant_op.constant([1], dtype=dtype)
      two = constant_op.constant([2], dtype=dtype)
      zero = constant_op.constant([0], dtype=dtype)
      nan = constant_op.constant([np.nan], dtype=dtype)

      eps = constant_op.constant([np.finfo(dtype.as_numpy_dtype).eps],
                                 dtype=dtype)

      self.assertAllEqual(math_ops.nextafter(one, two) - one, eps)
      self.assertAllLess(math_ops.nextafter(one, zero) - one, 0)
      self.assertAllEqual(math_ops.is_nan(math_ops.nextafter(nan, one)), [True])
      self.assertAllEqual(math_ops.is_nan(math_ops.nextafter(one, nan)), [True])
      self.assertAllEqual(math_ops.nextafter(one, one), one)

  def testBroadcasting(self):

    for dtype in [dtypes.float32, dtypes.float64]:
      one = constant_op.constant([1, 1], dtype=dtype)
      two = constant_op.constant([2], dtype=dtype)

      eps = np.finfo(dtype.as_numpy_dtype).eps

      eps_const = constant_op.constant([eps, eps], dtype=dtype)

      self.assertAllEqual(math_ops.nextafter(one, two) - one, eps_const)


@test_util.run_all_in_graph_and_eager_modes
class BinaryOpsTest(test_util.TensorFlowTestCase):

  def testErrorReceivedIfDtypeMismatchFromOp(self):
    if context.executing_eagerly():
      error = errors_impl.InvalidArgumentError
      error_message = (
          r"cannot compute Add(V2)? as input #1\(zero-based\) was expected to "
          r"be a int32 tensor but is a float tensor \[Op:Add(V2)?\]")
    else:
      error = TypeError
      error_message = (
          "Input 'y' of 'Add(V2)?' Op has type float32 that does not "
          "match type int32 of argument 'x'.")
    with self.assertRaisesRegex(error, error_message):
      a = array_ops.ones([1], dtype=dtypes.int32) + 1.0
      self.evaluate(a)

  def testRHSDispatchingAndErrorRaising(self):
    if context.executing_eagerly():
      error = ValueError
      error_message = (
          r"Attempt to convert a value .* with an unsupported type")
    else:
      error = TypeError
      error_message = (r"Failed to convert elements of .* to Tensor")

    class RHSReturnsTrue:

      def __radd__(self, other):
        return True

    a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsTrue()
    self.assertEqual(a, True)

    class RHSRaisesError:

      def __radd__(self, other):
        raise TypeError("RHS not implemented")

    with self.assertRaisesRegex(error, error_message):
      a = array_ops.ones([1], dtype=dtypes.int32) + RHSRaisesError()
      self.evaluate(a)

    class RHSReturnsNotImplemented:

      def __radd__(self, other):
        return NotImplemented

    with self.assertRaisesRegex(error, error_message):
      a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsNotImplemented()
      self.evaluate(a)

    class RHSNotImplemented:
      pass

    with self.assertRaisesRegex(error, error_message):
      a = array_ops.ones([1], dtype=dtypes.int32) + RHSNotImplemented()
      self.evaluate(a)


class SignTest(test_util.TensorFlowTestCase):

  def test_complex_sign_gradient(self):
    with context.eager_mode():
      x = math_ops.complex(1., 1.)
      with backprop.GradientTape() as t:
        t.watch(x)
        y = math_ops.sign(x)
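      # 0.353553 ~= 1 / (2 * sqrt(2)).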
      self.assertAllClose(
          t.gradient(y, x), math_ops.complex(0.353553, -0.353553))


@test_util.run_all_in_graph_and_eager_modes
class ReciprocalNoNanTest(test_util.TensorFlowTestCase):

  allowed_dtypes = [
      dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
      dtypes.complex128
  ]

  def testBasic(self):
    for dtype in self.allowed_dtypes:
      x = constant_op.constant([1.0, 2.0, 0.0, 4.0], dtype=dtype)

      y = math_ops.reciprocal_no_nan(x)

      target = constant_op.constant([1.0, 0.5, 0.0, 0.25], dtype=dtype)

      self.assertAllEqual(y, target)
      self.assertEqual(y.dtype.base_dtype, target.dtype.base_dtype)

  def testInverse(self):
    for dtype in self.allowed_dtypes:
      x = np.random.choice([0, 1, 2, 4, 5], size=(5, 5, 5))
      x = constant_op.constant(x, dtype=dtype)

      y = math_ops.reciprocal_no_nan(math_ops.reciprocal_no_nan(x))

      self.assertAllClose(y, x)
      self.assertEqual(y.dtype.base_dtype, x.dtype.base_dtype)


class EqualityTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  @test_util.run_all_in_graph_and_eager_modes
  def testEqualityNone(self):
    x = constant_op.constant([1.0, 2.0, 0.0, 4.0], dtype=dtypes.float32)
    self.assertNotEqual(x, None)
    self.assertNotEqual(None, x)
    self.assertFalse(math_ops.tensor_equals(x, None))
    self.assertTrue(math_ops.tensor_not_equals(x, None))

  @parameterized.named_parameters(
      (f"-is_equals={is_equals}-float_literal_type={type(float_literal)}"  # pylint: disable=g-complex-comprehension
       f"-float_literal={float_literal}", is_equals, float_literal)
      for float_literal in [4.6, np.float32(4.6), 4.4, np.float32(4.4)]
      for is_equals in [True, False])
  def testEqualityNoDowncast(self, is_equals, float_literal):
    if (tf2.enabled() and isinstance(float_literal, np.float32) or
        not tf2.enabled() and isinstance(float_literal, float)):
      # TODO(b/199262800): Remove this skip
      self.skipTest("There is a bug in type promotion.")
    if is_equals:
      op = math_ops.tensor_equals
    else:
      op = math_ops.tensor_not_equals
    x = constant_op.constant(4)
    try:
      result = op(x, float_literal)
      if isinstance(result, ops.Tensor):
        result = self.evaluate(result)
    except TypeError:
      # Throwing a TypeError is OK
      return
    self.assertEqual(result, not is_equals)


@test_util.run_all_in_graph_and_eager_modes
class RangeTest(test_util.TensorFlowTestCase):

  def testConvertToTensorRange(self):
    values = range(5)
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((5,), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  def testInputsNearInt64Max(self):
    int64_t_max = 2**63 - 1
    x = math_ops.range(0, 201, int64_t_max - 200, dtype=dtypes.int64)
    self.assertAllEqual((0,), self.evaluate(x))  # just below potential overflow
    x = math_ops.range(0, 202, int64_t_max - 200, dtype=dtypes.int64)
    self.assertAllEqual(
        (0,), self.evaluate(x))  # smallest input with potential overflow


@test_util.run_all_in_graph_and_eager_modes
class ErfcinvTest(test_util.TensorFlowTestCase):

  def testErfcinv(self):
    values = np.random.uniform(0.1, 1.9, size=int(1e4)).astype(np.float32)
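    # erfcinv inverts erfc on (0, 2), so the round trip should recover values.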
    approx_id = math_ops.erfc(math_ops.erfcinv(values))
    self.assertAllClose(values, self.evaluate(approx_id))


@test_util.run_all_in_graph_and_eager_modes
class ArgMaxMinTest(test_util.TensorFlowTestCase):

  def _generateRandomTensor(self, dtype, shape):
    if dtype.is_integer:
      array = np.random.default_rng().integers(
          low=dtype.min, high=dtype.max, size=shape, endpoint=True)
      return constant_op.constant(array, dtype=dtype)
    else:
      array = np.random.default_rng().uniform(low=-1.0, high=1.0, size=shape)
      return constant_op.constant(array, dtype=dtype)

  def _getValidDtypes(self):
    return (dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64,
            dtypes.int32, dtypes.int64)

  def testArgMax(self):
    shape = (24, 8)
    for dtype in self._getValidDtypes():
      tf_values = self._generateRandomTensor(dtype, shape)
      np_values = self.evaluate(tf_values)
      for axis in range(0, len(shape)):
        np_max = np.argmax(np_values, axis=axis)
        tf_max = math_ops.argmax(tf_values, axis=axis)
        self.assertAllEqual(tf_max, np_max)

  def testArgMaxReturnsFirstOccurence(self):
    for dtype in self._getValidDtypes():
      values = constant_op.constant(
          [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype)
      self.assertAllEqual(
          math_ops.argmax(values, axis=1),
          np.argmax(self.evaluate(values), axis=1))

      # Long tensor to ensure this works with multithreading/GPU
      values = array_ops.zeros(shape=(193681,), dtype=dtype)
      self.assertAllEqual(math_ops.argmax(values), 0)

  def testArgMaxUint16(self):
    shape = (24, 8)
    for dtype in self._getValidDtypes():
      tf_values = self._generateRandomTensor(dtype, shape)
      np_values = self.evaluate(tf_values)
      for axis in range(0, len(shape)):
        np_max = np.argmax(np_values, axis=axis)
        tf_max = math_ops.argmax(
            tf_values, axis=axis, output_type=dtypes.uint16)
        self.assertAllEqual(tf_max, np_max)

  def testArgMin(self):
    shape = (24, 8)
    for dtype in self._getValidDtypes():
      tf_values = self._generateRandomTensor(dtype, shape)
      np_values = self.evaluate(tf_values)
      for axis in range(0, len(shape)):
        np_min = np.argmin(np_values, axis=axis)
        tf_min = math_ops.argmin(tf_values, axis=axis)
        self.assertAllEqual(tf_min, np_min)

  def testArgMinReturnsFirstOccurence(self):
    for dtype in self._getValidDtypes():
      values = constant_op.constant(
          [[10, 11, 15, 15, 10], [12, 12, 10, 10, 12]], dtype=dtype)
      self.assertAllEqual(
          math_ops.argmin(values, axis=1),
          np.argmin(self.evaluate(values), axis=1))

      # Long tensor to ensure this works with multithreading/GPU
      values = array_ops.zeros(shape=(193681,), dtype=dtype)
      self.assertAllEqual(math_ops.argmin(values), 0)


class CastTest(test_util.TensorFlowTestCase):

  def testCastWithFullType(self):

    @def_function.function
    def test_fn():
      ta = tensor_array_ops.TensorArray(dtypes.int32, size=1)
      h = math_ops.cast(ta.flow, dtypes.variant)

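      # Mark the cast output as a TensorArray (TFT_ARRAY) so the variant can
      # still be used as the flow of a new TensorArray below.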
      t = full_type_pb2.FullTypeDef(
          type_id=full_type_pb2.TFT_PRODUCT,
          args=[full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)])
      h.op.experimental_set_type(t)

      ta = tensor_array_ops.TensorArray(dtypes.int32, flow=h)
      ta = ta.write(0, constant_op.constant(1))
      return ta.stack()

    self.assertAllEqual(self.evaluate(test_fn()), [1])


if __name__ == "__main__":
  googletest.main()