xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/unary_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for XLA JIT compiler."""
16
17import unittest
18
19import numpy as np
20import six
21
22from tensorflow.compiler.tests import xla_test
23from tensorflow.python.framework import dtypes
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import bitwise_ops
26from tensorflow.python.ops import gen_functional_ops
27from tensorflow.python.ops import gen_nn_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.platform import googletest
31
32
def nhwc_to_format(x, data_format):
  """Rearranges an NHWC-ordered numpy array into `data_format`.

  Args:
    x: numpy array whose last axis is treated as the channel axis.
    data_format: "NHWC" (returned as-is) or "NCHW".

  Returns:
    `x` itself for "NHWC"; for "NCHW", `x` transposed so the last axis moves
    to position 1 and the remaining middle axes keep their relative order.

  Raises:
    ValueError: if `data_format` is neither "NHWC" nor "NCHW".
  """
  num_dims = len(x.shape)
  if data_format == "NHWC":
    return x
  if data_format != "NCHW":
    raise ValueError("Unknown format {}".format(data_format))
  channel_axis = num_dims - 1
  # Batch axis stays first, channels come second, spatial axes follow.
  permutation = [0, channel_axis, *range(1, channel_axis)]
  return np.transpose(x, permutation)
42
43
44class UnaryOpsTest(xla_test.XLATestCase):
45  """Test cases for unary operators."""
46
47  def _assertOpOutputMatchesExpected(self,
48                                     op,
49                                     inp,
50                                     expected,
51                                     equality_test=None,
52                                     rtol=1e-3,
53                                     atol=1e-5):
54    """Verifies that 'op' produces 'expected' when fed input 'inp' .
55
56    Args:
57      op: operator to test
58      inp: numpy input array to use as input to 'op'.
59      expected: numpy array representing the expected output of 'op'.
60      equality_test: either None, or a function that tests two numpy arrays for
61        equality. If None, self.assertAllClose is used.
62      rtol: relative tolerance for equality test.
63      atol: absolute tolerance for equality test.
64    """
65    with self.session() as session:
66      with self.test_scope():
67        pinp = array_ops.placeholder(
68            dtypes.as_dtype(inp.dtype), inp.shape, name="a")
69        output = op(pinp)
70      result = session.run(output, {pinp: inp})
71      if equality_test is None:
72        self.assertEqual(output.dtype, expected.dtype)
73        self.assertAllCloseAccordingToType(
74            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
75      else:
76        equality_test(result, expected, rtol=rtol, atol=atol)
77
78  def ListsAreClose(self, result, expected, rtol, atol):
79    """Tests closeness of two lists of floats."""
80    self.assertEqual(len(result), len(expected))
81    for i in range(len(result)):
82      self.assertAllClose(result[i], expected[i], rtol, atol)
83
84  def AssertCloseAndSorted(self, result, expected, rtol, atol):
85    """Tests that result and expeted are both close and sorted."""
86    self.assertAllClose(result, expected, rtol, atol)
87    self.assertAllEqual(np.sort(result), result)
88
89  def AssertAllEqual(self, result, expected, rtol, atol):
90    """Tests that result and expeted are exactly equal."""
91    self.assertAllEqual(result, expected)
92
93  def testAllTypeOps(self):
94    for dtype in self.numeric_types - {np.int8, np.uint8}:
95      self._assertOpOutputMatchesExpected(
96          array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
97          np.array(
98              [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
99              dtype=dtype))
100      self._assertOpOutputMatchesExpected(
101          array_ops.diag_part,
102          np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
103          np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype))
104      self._assertOpOutputMatchesExpected(
105          array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype),
106          np.array(
107              [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]],
108                                                      [[0, 0], [0, 4]]]],
109              dtype=dtype))
110
111      self._assertOpOutputMatchesExpected(
112          array_ops.identity,
113          np.array([[-1, 1]], dtype=dtype),
114          expected=np.array([[-1, 1]], dtype=dtype))
115
116      self._assertOpOutputMatchesExpected(
117          array_ops.prevent_gradient,
118          np.array([[-1, 1]], dtype=dtype),
119          expected=np.array([[-1, 1]], dtype=dtype))
120
121      self._assertOpOutputMatchesExpected(
122          array_ops.squeeze,
123          np.array([[[[[]]]]], dtype=dtype),
124          expected=np.array([], dtype=dtype))
125      self._assertOpOutputMatchesExpected(
126          array_ops.squeeze,
127          np.array([[[1], [2]]], dtype=dtype),
128          expected=np.array([1, 2], dtype=dtype))
129      self._assertOpOutputMatchesExpected(
130          array_ops.squeeze,
131          np.array([[[1]], [[2]]], dtype=dtype),
132          expected=np.array([1, 2], dtype=dtype))
133      self._assertOpOutputMatchesExpected(
134          array_ops.squeeze,
135          np.array([[[1, 2], [3, 4]]], dtype=dtype),
136          expected=np.array([[1, 2], [3, 4]], dtype=dtype))
137
138      self._assertOpOutputMatchesExpected(
139          array_ops.stop_gradient,
140          np.array([[-1, 1]], dtype=dtype),
141          expected=np.array([[-1, 1]], dtype=dtype))
142
143  def testLog(self):
144    for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}:
145      tol = 1e-4 if dtype == np.float32 else 1e-9
146      # pylint: disable=invalid-unary-operand-type
147      x = np.linspace(-np.e, np.e, num=1000, dtype=dtype)
148      self._assertOpOutputMatchesExpected(
149          math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol)
150
151      x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype)
152      self._assertOpOutputMatchesExpected(
153          math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol)
154
155      x = np.linspace(0., np.pi * 1e30, num=1000, dtype=dtype)
156      self._assertOpOutputMatchesExpected(
157          math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol)
158
159  def testSin(self):
160    for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}:
161      tol = 1e-6 if dtype == np.float32 else 1e-12
162
163      x = np.linspace(-4 * np.e, 4 * np.e, num=1000, dtype=dtype)
164      self._assertOpOutputMatchesExpected(
165          math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol)
166
167      x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype)
168      self._assertOpOutputMatchesExpected(
169          math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol)
170
171      if dtype == np.float64:
172        x = np.linspace(0., np.e * 1e8, num=1000, dtype=dtype)
173        self._assertOpOutputMatchesExpected(
174            math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=1e-5)
175
176  def testCos(self):
177    for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}:
178      tol = 1e-6 if dtype == np.float32 else 1e-12
179
180      x = np.linspace(-4 * np.e, 4 * np.e, num=1000, dtype=dtype)
181      self._assertOpOutputMatchesExpected(
182          math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=tol)
183
184      x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype)
185      self._assertOpOutputMatchesExpected(
186          math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=tol)
187
188      if dtype == np.float64:
189        x = np.linspace(0., np.e * 1e8, num=1000, dtype=dtype)
190        self._assertOpOutputMatchesExpected(
191            math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5)
192
193  def testFloatOps(self):
194    for dtype in self.float_types:
195      x = np.arange(-0.90, 0.90, 0.25)
196      self._assertOpOutputMatchesExpected(
197          math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
198      self._assertOpOutputMatchesExpected(
199          math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype))
200      x = np.arange(-3, 3).reshape(1, 3, 2)
201      self._assertOpOutputMatchesExpected(
202          math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype))
203
204      self._assertOpOutputMatchesExpected(
205          math_ops.acosh,
206          np.array([1, 2, 3, 4], dtype=dtype),
207          expected=np.array(
208              [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype))
209
210      self._assertOpOutputMatchesExpected(
211          math_ops.asinh,
212          np.array([1, 2, 3, 4], dtype=dtype),
213          expected=np.array(
214              [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype))
215
216      self._assertOpOutputMatchesExpected(
217          math_ops.atanh,
218          np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype),
219          expected=np.array(
220              [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype))
221
222      self._assertOpOutputMatchesExpected(
223          math_ops.ceil,
224          np.array([[-1.7, 1.2]], dtype=dtype),
225          expected=np.array([[-1, 2]], dtype=dtype))
226
227      self._assertOpOutputMatchesExpected(
228          math_ops.cosh,
229          np.array([1, 2, 3, 4], dtype=dtype),
230          expected=np.array(
231              [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype))
232
233      # Disable float16 testing for now
234      if dtype != np.float16:
235        x = np.arange(-10, 10, 1).astype(dtype)
236        with self.session() as session:
237          erf_x = session.run(math_ops.erf(x))
238          erfc_x = session.run(math_ops.erfc(x))
239
240        self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x)
241        self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x)
242
243      self._assertOpOutputMatchesExpected(
244          math_ops.exp,
245          np.array([[-1, 1]], dtype=dtype),
246          expected=np.array([[0.36787945, 2.7182817]], dtype=dtype))
247
248      self._assertOpOutputMatchesExpected(
249          math_ops.expm1,
250          np.array([[-1, 1]], dtype=dtype),
251          expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype),
252          rtol=1e-5)
253
254      self._assertOpOutputMatchesExpected(
255          math_ops.floor,
256          np.array([[-1.7, 1.2]], dtype=dtype),
257          expected=np.array([[-2, 1]], dtype=dtype))
258
259      self._assertOpOutputMatchesExpected(
260          math_ops.is_finite,
261          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
262                   dtype=dtype),
263          expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_))
264
265      # Tests for tf.nn ops.
266      self._assertOpOutputMatchesExpected(
267          nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
268
269      self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8))
270
271      self._assertOpOutputMatchesExpected(
272          nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
273
274      self._assertOpOutputMatchesExpected(
275          math_ops.reciprocal,
276          np.array([[1, 2]], dtype=dtype),
277          expected=np.array([[1, 0.5]], dtype=dtype))
278
279      self._assertOpOutputMatchesExpected(
280          math_ops.log,
281          np.array([[1, 2]], dtype=dtype),
282          expected=np.array([[0, 0.69314718]], dtype=dtype))
283
284      self._assertOpOutputMatchesExpected(
285          math_ops.sin,
286          np.array([[1, 2]], dtype=dtype),
287          expected=np.array([[0.841478, 0.909302]], dtype=dtype))
288
289      self._assertOpOutputMatchesExpected(
290          math_ops.cos,
291          np.array([[1, 2]], dtype=dtype),
292          expected=np.array([[0.540297, -0.41614]], dtype=dtype))
293
294      # Confirm that log1p will remain precise across a range of small values.
295      self._assertOpOutputMatchesExpected(
296          math_ops.log1p,
297          np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
298                   dtype=dtype),
299          expected=np.log1p(
300              np.array(
301                  [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
302                  dtype=dtype)).astype(dtype),
303          rtol=1e-15 if dtype == np.float64 else 1e-4,
304          atol=1e-15 if dtype == np.float64 else 1e-4)
305
306      self._assertOpOutputMatchesExpected(
307          math_ops.rint,
308          np.array(
309              [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
310               [0.5, 1.5, 2.5, 3.5]],
311              dtype=dtype),
312          expected=np.array(
313              [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
314      self._assertOpOutputMatchesExpected(
315          math_ops.round,
316          np.array(
317              [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
318               [0.5, 1.5, 2.5, 3.5]],
319              dtype=dtype),
320          expected=np.array(
321              [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
322
323      self._assertOpOutputMatchesExpected(
324          math_ops.rsqrt,
325          np.array([[4, 16]], dtype=dtype),
326          expected=np.array([[0.5, 0.25]], dtype=dtype))
327
328      self._assertOpOutputMatchesExpected(
329          math_ops.sigmoid,
330          np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
331          expected=np.array(
332              [[0.7310586, 0.7310586, 0.7310586, 0.7310586],
333               [0.7310586, 0.880797, 0.95257413, 0.98201376]],
334              dtype=dtype))
335
336      self._assertOpOutputMatchesExpected(
337          math_ops.sigmoid,
338          np.array([-300, -150, 0, 150, 300], dtype=dtype),
339          expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype))
340
341      self._assertOpOutputMatchesExpected(
342          math_ops.sinh,
343          np.array([1, 2, 3, 4], dtype=dtype),
344          expected=np.array(
345              [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype))
346
347      self._assertOpOutputMatchesExpected(
348          math_ops.sqrt,
349          np.array([[4, 9]], dtype=dtype),
350          expected=np.array([[2, 3]], dtype=dtype))
351
352      self._assertOpOutputMatchesExpected(
353          math_ops.tan,
354          np.array([1, 2, 3, 4], dtype=dtype),
355          expected=np.array(
356              [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
357
358      self._assertOpOutputMatchesExpected(
359          math_ops.tanh,
360          np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20],
361                    [19, -19, 22, -22]],
362                   dtype=dtype),
363          expected=np.array(
364              [[0.76159418, 0.96402758, 0.99505478, 0.99932933],
365               [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]],
366              dtype=dtype))
367
368      self._assertOpOutputMatchesExpected(
369          nn_ops.log_softmax,
370          np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
371          expected=np.array(
372              [[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
373               [-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
374              dtype=dtype))
375
376      self._assertOpOutputMatchesExpected(
377          nn_ops.elu,
378          np.array([[-1, 0, 1, -1e-6]], dtype=dtype),
379          expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype),
380          rtol=1e-5,
381          atol=1e-6)
382
383      self._assertOpOutputMatchesExpected(
384          nn_ops.selu,
385          np.array([[-1, 0, 1, -1e-5]], dtype=dtype),
386          expected=np.array(
387              [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]],
388              dtype=dtype),
389          rtol=1e-5,
390          atol=1e-6)
391
392      self._assertOpOutputMatchesExpected(
393          nn_ops.relu,
394          np.array([[-1, 1]], dtype=dtype),
395          expected=np.array([[0, 1]], dtype=dtype))
396
397      self._assertOpOutputMatchesExpected(
398          nn_ops.relu6,
399          np.array([[-0.05, 6.05, 5]], dtype=dtype),
400          expected=np.array([[0, 6, 5]], dtype=dtype))
401
402      self._assertOpOutputMatchesExpected(
403          nn_ops.leaky_relu,
404          np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
405          expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype))
406
407      self._assertOpOutputMatchesExpected(
408          nn_ops.softmax,
409          np.array([1, 2, 3, 4], dtype=dtype),
410          expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428],
411                            dtype=dtype))
412
413      self._assertOpOutputMatchesExpected(
414          nn_ops.softmax,
415          np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
416          expected=np.array(
417              [[0.25, 0.25, 0.25, 0.25],
418               [0.032058604, 0.087144323, 0.23688284, 0.64391428]],
419              dtype=dtype))
420
421      self._assertOpOutputMatchesExpected(
422          nn_ops.softmax,
423          np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype),
424          expected=np.array(
425              [[[0.5, 0.5], [0.5, 0.5]],
426               [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]],
427              dtype=dtype))
428
429      self._assertOpOutputMatchesExpected(
430          nn_ops.softsign,
431          np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
432          expected=np.array(
433              [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
434
435      self._assertOpOutputMatchesExpected(
436          math_ops.sign,
437          np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0,
438                     float("nan")]],
439                   dtype=dtype),
440          expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0,
441                              float("nan")]],
442                            dtype=dtype))
443
444      self._assertOpOutputMatchesExpected(
445          math_ops.is_finite,
446          np.array([[42, float("inf"), -123], [float("nan"), 0, -0.0]],
447                   dtype=dtype),
448          expected=np.array([[True, False, True], [False, True, True]],
449                            dtype=np.bool_))
450
451      self._assertOpOutputMatchesExpected(
452          math_ops.lgamma,
453          np.array(0.5, dtype=dtype),
454          expected=np.array(np.log(np.pi) / 2, dtype=dtype))
455
456      self._assertOpOutputMatchesExpected(
457          math_ops.lgamma,
458          np.array(
459              [[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2],
460               [-3 / 2, -7 / 2, -11 / 2]],
461              dtype=dtype),
462          expected=np.array(
463              [
464                  [0, 0, np.log(2.0)],
465                  [np.log(6.0), np.log(24.0),
466                   np.log(120)],
467                  [
468                      np.log(np.pi) / 2,
469                      np.log(np.pi) / 2 - np.log(2),
470                      np.log(np.pi) / 2 - np.log(4) + np.log(3)
471                  ],
472                  [
473                      np.log(np.pi) / 2 - np.log(3) + np.log(4),
474                      np.log(np.pi) / 2 - np.log(105) + np.log(16),
475                      np.log(np.pi) / 2 - np.log(10395) + np.log(64),
476                  ],
477              ],
478              dtype=dtype))
479
480      # The actual result is complex. Take the real part.
481      self._assertOpOutputMatchesExpected(
482          math_ops.lgamma,
483          np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype),
484          expected=np.array(
485              [
486                  np.log(np.pi) / 2 + np.log(2),
487                  np.log(np.pi) / 2 - np.log(15) + np.log(8),
488                  np.log(np.pi) / 2 - np.log(945) + np.log(32),
489              ],
490              dtype=dtype),
491          atol=1e-4)
492
493      self._assertOpOutputMatchesExpected(
494          math_ops.digamma,
495          np.array(
496              [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0],
497               [6.0, 8.0, 9.0]],
498              dtype=dtype),
499          expected=np.array(
500              [
501                  [
502                      -np.euler_gamma, -2 * np.log(2) - np.euler_gamma,
503                      -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 -
504                      np.euler_gamma
505                  ],
506                  [
507                      -np.pi / 2 - 3 * np.log(2) - np.euler_gamma,
508                      -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) -
509                      3 * np.log(3) / 2 - np.euler_gamma,
510                      -np.pi / 2 - 4 * np.log(2) -
511                      (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2)))
512                      / np.sqrt(2) - np.euler_gamma
513                  ],
514                  [
515                      1 - np.euler_gamma, 1.5 - np.euler_gamma,
516                      11 / 6.0 - np.euler_gamma
517                  ],
518                  [
519                      137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma,
520                      761 / 280.0 - np.euler_gamma
521                  ],
522              ],
523              dtype=dtype))
524
525  def testSigmoidNumericalStability(self):
526    for dtype in self.float_types:
527      if dtype != np.float16:
528        self._assertOpOutputMatchesExpected(
529            lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)),
530            np.array([-40, 40], dtype=dtype),
531            expected=np.array([1.0, 0.025], dtype=dtype))
532
533  def testQuantizeAndDequantize(self):
534    for dtype in self.float_types:
535
536      def quantize_and_dequantize_v2(x):
537        return array_ops.quantize_and_dequantize(
538            x, -127, 127, signed_input=True, num_bits=8)
539
540      def quantize_and_dequantize_v3(x):
541        return array_ops.quantize_and_dequantize_v3(
542            x, -127, 127, num_bits=8, signed_input=True, range_given=False)
543
544      def quantize_and_dequantize_v4(x):
545        return array_ops.quantize_and_dequantize_v2(
546            x, -127, 127, signed_input=True, num_bits=8)
547
548      test_fns = (quantize_and_dequantize_v2, quantize_and_dequantize_v3,
549                  quantize_and_dequantize_v4)
550      for test_fn in test_fns:
551        self._assertOpOutputMatchesExpected(
552            test_fn,
553            np.array([-1, -0.5, 0, 0.3], dtype=dtype),
554            expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
555
556      def quantize_and_dequantize_v2_round_half_up(x):
557        return array_ops.quantize_and_dequantize(
558            x,
559            -1,
560            1.0,
561            signed_input=True,
562            num_bits=8,
563            range_given=True,
564            round_mode="HALF_UP")
565
566      self._assertOpOutputMatchesExpected(
567          quantize_and_dequantize_v2_round_half_up,
568          np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype),
569          expected=np.array([
570              -102.0 / 127,
571              -63.0 / 127,
572              0,
573              38.0 / 127,
574              102.0 / 127,
575              -128.0 / 127,
576              1,
577          ],
578                            dtype=dtype))
579
580      def quantize_and_dequantize_v2_round_half_to_even(x):
581        return array_ops.quantize_and_dequantize(
582            x,
583            -1.0,
584            1.0,
585            signed_input=True,
586            num_bits=8,
587            range_given=True,
588            round_mode="HALF_TO_EVEN")
589
590      self._assertOpOutputMatchesExpected(
591          quantize_and_dequantize_v2_round_half_to_even,
592          np.array(
593              [
594                  -0.8,
595                  # The -0.5 should become -63.5 after scaling and with
596                  # rounding this should become -64. But with the test
597                  # unary_ops_test_cpu_ondemand, this fails as the result
598                  # before scaling becomes -63.499996 and gets rounded to -63.
599                  # TODO(sreenik): Some one more familiar with this test needs
600                  # to take a look and resolve this. This works on all other
601                  # variations of the platform like cpu, and gpu.
602                  # -0.5,
603                  0,
604                  0.3,
605                  0.8,
606                  -2,
607                  33
608              ],
609              dtype=dtype),
610          expected=np.array(
611              [
612                  -102.0 / 127,
613                  # -64.0 / 127,
614                  0,
615                  38.0 / 127,
616                  102.0 / 127,
617                  -128.0 / 127,
618                  1,
619              ],
620              dtype=dtype))
621
622  def testComplexOps(self):
623    for dtype in self.complex_types:
624
625      self._assertOpOutputMatchesExpected(
626          math_ops.acosh,
627          np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
628          expected=np.arccosh(
629              np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
630
631      self._assertOpOutputMatchesExpected(
632          math_ops.asinh,
633          np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
634          expected=np.arcsinh(
635              np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
636
637      self._assertOpOutputMatchesExpected(
638          math_ops.atanh,
639          np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
640          expected=np.arctanh(
641              np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))
642
643      self._assertOpOutputMatchesExpected(
644          math_ops.cosh,
645          np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype),
646          expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype)))
647
648      self._assertOpOutputMatchesExpected(
649          math_ops.sinh,
650          np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
651          expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
652
653      self._assertOpOutputMatchesExpected(
654          math_ops.exp,
655          np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
656          expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)))
657
658      self._assertOpOutputMatchesExpected(
659          math_ops.expm1,
660          np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype),
661          expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)),
662          rtol=1e-6,
663          atol=1e-6)
664
665      # For real part close to zero, or imaginary part close to a multiple of
666      # pi.
667
668      self._assertOpOutputMatchesExpected(
669          math_ops.expm1,
670          np.array([[1e-11 + 1j, -1e-11 - 1j, 1. + 1e-11j,
671                     -1. - 1e-11j, 1e-13j + 1e-13j]], dtype=dtype),
672          # TODO(srvasude): Use numpy as the source of truth after we depend on
673          # latest numpy with this pull request:
674          # https://github.com/numpy/numpy/pull/15110.
675          # The numbers below were generated by scipy.special.expm1.
676          expected=np.array([[
677              -4.59697694e-01+8.41470985e-01j,
678              -4.59697694e-01-8.41470985e-01j,
679              1.71828183e+00+2.71828183e-11j,
680              -6.32120559e-01-3.67879441e-12j,
681              -2.00000000e-26+2.00000000e-13j]], dtype=dtype),
682          rtol=1e-09,
683          atol=1e-20)
684
685      self._assertOpOutputMatchesExpected(
686          math_ops.reciprocal,
687          np.array([[1, 2j, 2 + 3j]], dtype=dtype),
688          expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype))
689
690      self._assertOpOutputMatchesExpected(
691          math_ops.log,
692          np.array([[5j, 3 - 2j]], dtype=dtype),
693          expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype)))
694
695      self._assertOpOutputMatchesExpected(
696          math_ops.sin,
697          np.array([[5j, 3 - 2j]], dtype=dtype),
698          expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype)))
699
700      self._assertOpOutputMatchesExpected(
701          math_ops.cos,
702          np.array([[5j, 3 - 2j]], dtype=dtype),
703          expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype)))
704
705      self._assertOpOutputMatchesExpected(
706          math_ops.log1p,
707          np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype),
708          expected=np.log1p(
709              np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)),
710          rtol=1e-4,
711          atol=1e-6)
712
713      val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
714      self._assertOpOutputMatchesExpected(
715          math_ops.rsqrt, val, expected=1 / np.sqrt(val))
716
717      self._assertOpOutputMatchesExpected(
718          math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))
719
720      self._assertOpOutputMatchesExpected(
721          math_ops.sqrt, val, expected=np.sqrt(val))
722
723      self._assertOpOutputMatchesExpected(
724          math_ops.tanh,
725          np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
726          expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
727
728      self._assertOpOutputMatchesExpected(
729          math_ops.tan,
730          np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
731          expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))
732
733      ctypes = {np.complex64: np.float32, np.complex128: np.float64}
734      self._assertOpOutputMatchesExpected(
735          math_ops.abs,
736          np.array([[3 - 4j, -1j, np.inf]], dtype=dtype),
737          expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype]))
738
739      self._assertOpOutputMatchesExpected(
740          math_ops.negative,
741          np.array([[-1 + 2j, -3j]], dtype=dtype),
742          expected=np.array([[1 - 2j, 3j]], dtype=dtype))
743
744      self._assertOpOutputMatchesExpected(
745          math_ops.square,
746          np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype),
747          expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2)
748
749      self._assertOpOutputMatchesExpected(
750          array_ops.zeros_like,
751          np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype),
752          expected=np.array([[0, 0], [0, 0]], dtype=dtype))
753
754      self._assertOpOutputMatchesExpected(
755          array_ops.ones_like,
756          np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype),
757          expected=np.array([[1, 1], [1, 1]], dtype=dtype))
758
759      self._assertOpOutputMatchesExpected(
760          math_ops.angle,
761          np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
762          expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype)))
763
764      self._assertOpOutputMatchesExpected(
765          math_ops.conj,
766          np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
767          expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype))
768
769      self._assertOpOutputMatchesExpected(
770          math_ops.imag,
771          np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
772          expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype]))
773
774      self._assertOpOutputMatchesExpected(
775          math_ops.real,
776          np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
777          expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
778
779  def testIntOps(self):
780    for dtype in self.int_types:
781      self._assertOpOutputMatchesExpected(
782          bitwise_ops.invert,
783          np.array([0, -1, 1, 16, 42], dtype=dtype),
784          expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
785
786      # Test population_count for array inputs.
787      raw_inputs = [
788          0, 1, -1, 3, -3, 5, -5, 14, -14, 127, 128, 255, 256, 65535, 65536,
789          2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32, -2**63 + 1,
790          2**63 - 1
791      ]
792      # Only choose inputs which fit in the int dtype.
793      raw_inputs = list(
794          filter(lambda x: np.iinfo(dtype).min <= x <= np.iinfo(dtype).max,
795                 raw_inputs))
796      inputs = np.array(raw_inputs, dtype=dtype)
797
798      def count_bits(x):
799        return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))
800
801      truth = [count_bits(x) for x in inputs]
802      self._assertOpOutputMatchesExpected(
803          bitwise_ops.population_count,
804          inputs,
805          expected=np.array(truth, dtype=np.uint8),
806          equality_test=self.AssertAllEqual)
807
808      # Test population_count for scalar inputs.
809      for raw_inp in raw_inputs:
810        inp = dtype(raw_inp)
811        truth = count_bits(inp)
812        self._assertOpOutputMatchesExpected(
813            bitwise_ops.population_count,
814            inp,
815            expected=np.uint8(truth),
816            equality_test=self.AssertAllEqual)
817
818  def testNumericOps(self):
819    for dtype in self.numeric_types - {np.int8, np.uint8}:
820      self._assertOpOutputMatchesExpected(
821          math_ops.abs,
822          np.array([[2, -1]], dtype=dtype),
823          expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
824
825      self._assertOpOutputMatchesExpected(
826          math_ops.negative,
827          np.array([[-1, 1]], dtype=dtype),
828          expected=np.array([[1, -1]], dtype=dtype))
829
830      self._assertOpOutputMatchesExpected(
831          math_ops.square,
832          np.array([[-2, 3]], dtype=dtype),
833          expected=np.array([[4, 9]], dtype=dtype))
834
835      self._assertOpOutputMatchesExpected(
836          array_ops.zeros_like,
837          np.array([[4, 3], [2, 1]], dtype=dtype),
838          expected=np.array([[0, 0], [0, 0]], dtype=dtype))
839
840      self._assertOpOutputMatchesExpected(
841          array_ops.ones_like,
842          np.array([[4, 3], [2, 1]], dtype=dtype),
843          expected=np.array([[1, 1], [1, 1]], dtype=dtype))
844
845  # TODO(phawkins): these tests fail unless fastmath optimizations
846  # are disabled. Use more robust IsInf/IsNaN detection and enable these
847  # tests.
848  @unittest.skip("test case fails in fast-math mode")
849  def testIsInfAndIsNan(self):
850    for dtype in self.float_types:
851      self._assertOpOutputMatchesExpected(
852          math_ops.is_inf,
853          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
854                   dtype=dtype),
855          expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool_))
856      self._assertOpOutputMatchesExpected(
857          math_ops.is_nan,
858          np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
859                   dtype=dtype),
860          expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool_))
861      self._assertOpOutputMatchesExpected(
862          math_ops.sign,
863          np.array([[np.nan]], dtype=dtype),
864          expected=np.array([[0.0]], dtype=dtype))
865
866  def testLogicalOps(self):
867    self._assertOpOutputMatchesExpected(
868        math_ops.logical_not,
869        np.array([[True, False], [False, True]], dtype=np.bool_),
870        expected=np.array([[False, True], [True, False]], dtype=np.bool_))
871
872  def testBiasAddGrad(self):
873    self._assertOpOutputMatchesExpected(
874        gen_nn_ops.bias_add_grad,
875        np.array([[1., 2.], [3., 4.]], dtype=np.float32),
876        expected=np.array([4., 6.], dtype=np.float32))
877
878    self._assertOpOutputMatchesExpected(
879        lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
880        np.array(
881            [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
882        expected=np.array([14., 22.], dtype=np.float32))
883
884  def testCast(self):
885    shapes = [[], [4], [2, 3], [2, 0, 4]]
886    types = {
887        dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
888        dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
889    }
890    for src_type in types:
891      for dst_type in types:
892        src_np_dtype = src_type.as_numpy_dtype
893        dst_np_dtype = dst_type.as_numpy_dtype
894
895        for shape in shapes:
896          src = np.arange(np.prod(shape)).astype(src_np_dtype)
897
898          if src_type in self.complex_tf_types:
899            src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype)
900          src = src.reshape(shape)
901          dst = src.astype(dst_np_dtype)
902          self._assertOpOutputMatchesExpected(
903              lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
904              src,
905              expected=dst)
906
907        # Check special values.
908        if src_type.is_integer:
909          imin = np.iinfo(src_np_dtype).min
910          imax = np.iinfo(src_np_dtype).max
911          src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype)
912        elif src_type in self.float_tf_types:
913          if dst_type.is_integer:
914            imin = np.iinfo(dst_np_dtype).min
915            imax = np.iinfo(dst_np_dtype).max // 2
916            src = np.array([imin, imax, 0, 1], dtype=src_np_dtype)
917          elif dst_type in self.float_tf_types:
918            fmin = np.finfo(dst_np_dtype).min
919            fmax = np.finfo(dst_np_dtype).max
920            tiny = np.finfo(dst_np_dtype).tiny
921            eps = np.finfo(dst_np_dtype).eps
922            src = np.array(
923                [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf],
924                dtype=src_np_dtype)
925        dst = src.astype(dst_np_dtype)
926        self._assertOpOutputMatchesExpected(
927            lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
928            src,
929            expected=dst)
930
931  def testBitcast(self):
932    self._assertOpOutputMatchesExpected(
933        lambda x: array_ops.bitcast(x, dtypes.int32),
934        np.array([1, 0x3f800000], np.int32),
935        expected=np.array([1, 0x3f800000], np.int32))
936    self._assertOpOutputMatchesExpected(
937        lambda x: array_ops.bitcast(x, dtypes.float32),
938        np.array([1, 0x3f800000], np.int32),
939        expected=np.array([1e-45, 1.0], np.float32))
940    self._assertOpOutputMatchesExpected(
941        lambda x: array_ops.bitcast(x, dtypes.int32),
942        np.array([1e-45, 1.0], np.float32),
943        expected=np.array([1, 0x3f800000], np.int32))
944    if np.int64 in self.numeric_types:
945      self._assertOpOutputMatchesExpected(
946          lambda x: array_ops.bitcast(x, dtypes.int64),
947          np.array([1, 0x100000003f800000], np.uint64),
948          expected=np.array([1, 0x100000003f800000], np.int64))
949      self._assertOpOutputMatchesExpected(
950          lambda x: array_ops.bitcast(x, dtypes.uint64),
951          np.array([1, 0x100000003f800000], np.int64),
952          expected=np.array([1, 0x100000003f800000], np.uint64))
953
954  def testBitcastInt8ToFloat(self):
955    self._assertOpOutputMatchesExpected(
956        lambda x: array_ops.bitcast(x, dtypes.float32),
957        np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8),
958        expected=np.array([1e-45, 3.14159], np.float32))
959    self._assertOpOutputMatchesExpected(
960        lambda x: array_ops.bitcast(x, dtypes.np.int8),
961        np.array([1e-45, 3.14159], np.float32),
962        expected=np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8))
963
964  def testInvertPermutation(self):
965    for np_dtype in [np.int32, np.int64]:
966      self._assertOpOutputMatchesExpected(
967          array_ops.invert_permutation,
968          np.array([1, 2, 0], np_dtype),
969          expected=np.array([2, 0, 1], dtype=np_dtype))
970
971  def testInvertPermutationTwiceIsNoop(self):
972
973    def invert_twice(x):
974      return array_ops.invert_permutation(array_ops.invert_permutation(x))
975
976    for np_dtype in [np.int32, np.int64]:
977      self._assertOpOutputMatchesExpected(
978          invert_twice,
979          np.array([1, 2, 0], np_dtype),
980          expected=np.array([1, 2, 0], dtype=np_dtype))
981
982  def testRank(self):
983    rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
984    for dtype in self.numeric_types:
985      self._assertOpOutputMatchesExpected(
986          rank_op, dtype(7), expected=np.int32(0))
987      self._assertOpOutputMatchesExpected(
988          rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2))
989      self._assertOpOutputMatchesExpected(
990          rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1))
991      self._assertOpOutputMatchesExpected(
992          rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
993      self._assertOpOutputMatchesExpected(
994          rank_op,
995          np.array([[-1], [1], [4]], dtype=dtype),
996          expected=np.int32(2))
997
998  def testShape(self):
999    shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
1000    for dtype in self.numeric_types:
1001      self._assertOpOutputMatchesExpected(
1002          shape_op, dtype(7), expected=np.array([], dtype=np.int32))
1003      self._assertOpOutputMatchesExpected(
1004          shape_op,
1005          np.array([[], []], dtype=dtype),
1006          expected=np.array([2, 0], dtype=np.int32))
1007      self._assertOpOutputMatchesExpected(
1008          shape_op,
1009          np.array([-1, 1], dtype=dtype),
1010          expected=np.array([2], dtype=np.int32))
1011      self._assertOpOutputMatchesExpected(
1012          shape_op,
1013          np.array([[-1, 1]], dtype=dtype),
1014          expected=np.array([1, 2], dtype=np.int32))
1015      self._assertOpOutputMatchesExpected(
1016          shape_op,
1017          np.array([[-1], [1], [4]], dtype=dtype),
1018          expected=np.array([3, 1], dtype=np.int32))
1019
1020  def testSize(self):
1021    size_op = lambda x: array_ops.size_internal(x, optimize=False)
1022    for dtype in self.numeric_types:
1023      self._assertOpOutputMatchesExpected(
1024          size_op, dtype(7), expected=np.int32(1))
1025      self._assertOpOutputMatchesExpected(
1026          size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
1027      self._assertOpOutputMatchesExpected(
1028          size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
1029      self._assertOpOutputMatchesExpected(
1030          size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
1031      self._assertOpOutputMatchesExpected(
1032          size_op,
1033          np.array([[-1], [1], [4]], dtype=dtype),
1034          expected=np.int32(3))
1035
1036  def testSizeWithInt64OutType(self):
1037
1038    def size_op(x):
1039      return array_ops.size_internal(x, optimize=False, out_type=np.int64)
1040
1041    for dtype in self.numeric_types:
1042      self._assertOpOutputMatchesExpected(
1043          size_op,
1044          np.array([[-1], [1], [4]], dtype=dtype),
1045          expected=np.int64(3))
1046
1047  def testUnpack(self):
1048    self._assertOpOutputMatchesExpected(
1049        array_ops.unstack,
1050        np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
1051        expected=[
1052            np.array([1., 2.], dtype=np.float32),
1053            np.array([3., 4.], dtype=np.float32),
1054            np.array([5., 6.], dtype=np.float32),
1055        ],
1056        equality_test=self.ListsAreClose)
1057
1058    self._assertOpOutputMatchesExpected(
1059        lambda x: array_ops.unstack(x, axis=1),
1060        np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
1061        expected=[
1062            np.array([1., 3., 5.], dtype=np.float32),
1063            np.array([2., 4., 6.], dtype=np.float32),
1064        ],
1065        equality_test=self.ListsAreClose)
1066
  def testDepthToSpace(self):
    """Checks depth_to_space with block_size=2 for NCHW, NHWC, NCHW_VECT_C.

    NCHW/NHWC fixtures are written in NHWC and converted with nhwc_to_format.
    """

    # Returns a depth_to_space closure with block_size=2 and the given layout.
    def make_op(data_format):

      def op(x):
        return array_ops.depth_to_space(
            x, block_size=2, data_format=data_format)

      return op

    for dtype in self.numeric_types:
      for data_format in ["NCHW", "NHWC"]:
        # NHWC shapes: [1, 1, 1, 4] -> [1, 2, 2, 1].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format),
            expected=nhwc_to_format(
                np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format))

        # NHWC shapes: [1, 1, 1, 12] -> [1, 2, 2, 3].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array(
                    [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
                data_format),
            expected=nhwc_to_format(
                np.array(
                    [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
                    dtype=dtype), data_format))

        # NHWC shapes: [1, 2, 2, 4] -> [1, 4, 4, 1].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array(
                    [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
                                                     [13, 14, 15, 16]]]],
                    dtype=dtype), data_format),
            expected=nhwc_to_format(
                np.array(
                    [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
                      [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
                    dtype=dtype), data_format))

      # Runs once per dtype, outside the format loop. Shapes:
      # (1, 8, 1, 1, 4) -> (1, 2, 2, 2, 4). Presumably the trailing axis of
      # size 4 is the NCHW_VECT_C channel vector -- TODO confirm.
      self._assertOpOutputMatchesExpected(
          make_op("NCHW_VECT_C"),
          np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
          expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]],
                               [[16, 17, 18, 19], [24, 25, 26, 27]]],
                              [[[4, 5, 6, 7], [12, 13, 14, 15]],
                               [[20, 21, 22, 23], [28, 29, 30, 31]]]]],
                            dtype=dtype))
1118
  def testSpaceToDepth(self):
    """Checks space_to_depth with block_size=2 for NCHW, NHWC, NCHW_VECT_C.

    NCHW/NHWC fixtures are written in NHWC and converted with nhwc_to_format.
    """

    # Returns a space_to_depth closure with block_size=2 and the given layout.
    def make_op(data_format):

      def op(x):
        return array_ops.space_to_depth(
            x, block_size=2, data_format=data_format)

      return op

    for dtype in self.numeric_types:
      for data_format in ["NCHW", "NHWC"]:
        # NHWC shapes: [1, 2, 2, 1] -> [1, 1, 1, 4].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format),
            expected=nhwc_to_format(
                np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format))

        # NHWC shapes: [1, 2, 2, 3] -> [1, 1, 1, 12].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array(
                    [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
                    dtype=dtype), data_format),
            expected=nhwc_to_format(
                np.array(
                    [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
                data_format))

        # NHWC shapes: [1, 4, 4, 1] -> [1, 2, 2, 4].
        self._assertOpOutputMatchesExpected(
            make_op(data_format),
            nhwc_to_format(
                np.array(
                    [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
                      [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
                    dtype=dtype), data_format),
            expected=nhwc_to_format(
                np.array(
                    [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
                                                     [13, 14, 15, 16]]]],
                    dtype=dtype), data_format))

      # Runs once per dtype, outside the format loop. Shapes:
      # (1, 2, 2, 2, 4) -> (1, 8, 1, 1, 4). Presumably the trailing axis of
      # size 4 is the NCHW_VECT_C channel vector -- TODO confirm.
      self._assertOpOutputMatchesExpected(
          make_op("NCHW_VECT_C"),
          np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
          expected=np.array(
              [[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]],
                [[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]],
                [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]],
              dtype=dtype))
1170
1171  def _assertSoftplusMatchesExpected(self,
1172                                     features,
1173                                     dtype,
1174                                     equality_test=None,
1175                                     rtol=1e-6,
1176                                     atol=9.1e-6):
1177    features = np.array(features, dtype=dtype)
1178    zero = np.asarray(0).astype(dtype)
1179    expected = np.logaddexp(zero, features).astype(dtype)
1180    self._assertOpOutputMatchesExpected(
1181        nn_ops.softplus,
1182        features,
1183        expected=expected,
1184        equality_test=equality_test,
1185        rtol=rtol,
1186        atol=atol)
1187
1188  def testSoftplus(self):
1189    for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
1190      self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
1191      self._assertSoftplusMatchesExpected(
1192          [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype)
1193      if dtype == dtypes.bfloat16.as_numpy_dtype:
1194        log_eps = np.log(np.finfo(np.float32).eps)
1195      else:
1196        log_eps = np.log(np.finfo(dtype).eps)
1197      one = dtype(1)
1198      ten = dtype(10)
1199      self._assertSoftplusMatchesExpected([
1200          log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten,
1201          -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten,
1202          -log_eps + ten
1203      ], dtype)
1204
1205      self._assertSoftplusMatchesExpected(
1206          [0.69302183, 0.69324386],
1207          dtype,
1208          equality_test=self.AssertCloseAndSorted,
1209          rtol=9e-5,
1210          atol=9e-5)
1211
1212  def testToBool(self):
1213    for dtype in self.numeric_types - self.complex_types:
1214      self._assertOpOutputMatchesExpected(
1215          gen_functional_ops.to_bool,
1216          np.array(5, dtype=dtype),
1217          expected=np.array(True))
1218
1219      self._assertOpOutputMatchesExpected(
1220          gen_functional_ops.to_bool,
1221          np.array(0, dtype=dtype),
1222          expected=np.array(False))
1223
1224      self._assertOpOutputMatchesExpected(
1225          gen_functional_ops.to_bool,
1226          np.array([], dtype=dtype),
1227          expected=np.array(False))
1228
1229      self._assertOpOutputMatchesExpected(
1230          gen_functional_ops.to_bool,
1231          np.array([1, 2, 3], dtype=dtype),
1232          expected=np.array(True))
1233
1234
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's googletest runner.
  googletest.main()
1237