1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for XLA JIT compiler.""" 16 17import unittest 18 19import numpy as np 20import six 21 22from tensorflow.compiler.tests import xla_test 23from tensorflow.python.framework import dtypes 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import bitwise_ops 26from tensorflow.python.ops import gen_functional_ops 27from tensorflow.python.ops import gen_nn_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import nn_ops 30from tensorflow.python.platform import googletest 31 32 33def nhwc_to_format(x, data_format): 34 """Converts a numpy array from NHWC format to `data_format`.""" 35 rank = len(x.shape) 36 if data_format == "NCHW": 37 return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) 38 elif data_format == "NHWC": 39 return x 40 else: 41 raise ValueError("Unknown format {}".format(data_format)) 42 43 44class UnaryOpsTest(xla_test.XLATestCase): 45 """Test cases for unary operators.""" 46 47 def _assertOpOutputMatchesExpected(self, 48 op, 49 inp, 50 expected, 51 equality_test=None, 52 rtol=1e-3, 53 atol=1e-5): 54 """Verifies that 'op' produces 'expected' when fed input 'inp' . 55 56 Args: 57 op: operator to test 58 inp: numpy input array to use as input to 'op'. 59 expected: numpy array representing the expected output of 'op'. 60 equality_test: either None, or a function that tests two numpy arrays for 61 equality. If None, self.assertAllClose is used. 62 rtol: relative tolerance for equality test. 63 atol: absolute tolerance for equality test. 64 """ 65 with self.session() as session: 66 with self.test_scope(): 67 pinp = array_ops.placeholder( 68 dtypes.as_dtype(inp.dtype), inp.shape, name="a") 69 output = op(pinp) 70 result = session.run(output, {pinp: inp}) 71 if equality_test is None: 72 self.assertEqual(output.dtype, expected.dtype) 73 self.assertAllCloseAccordingToType( 74 expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) 75 else: 76 equality_test(result, expected, rtol=rtol, atol=atol) 77 78 def ListsAreClose(self, result, expected, rtol, atol): 79 """Tests closeness of two lists of floats.""" 80 self.assertEqual(len(result), len(expected)) 81 for i in range(len(result)): 82 self.assertAllClose(result[i], expected[i], rtol, atol) 83 84 def AssertCloseAndSorted(self, result, expected, rtol, atol): 85 """Tests that result and expeted are both close and sorted.""" 86 self.assertAllClose(result, expected, rtol, atol) 87 self.assertAllEqual(np.sort(result), result) 88 89 def AssertAllEqual(self, result, expected, rtol, atol): 90 """Tests that result and expeted are exactly equal.""" 91 self.assertAllEqual(result, expected) 92 93 def testAllTypeOps(self): 94 for dtype in self.numeric_types - {np.int8, np.uint8}: 95 self._assertOpOutputMatchesExpected( 96 array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), 97 np.array( 98 [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], 99 dtype=dtype)) 100 self._assertOpOutputMatchesExpected( 101 array_ops.diag_part, 102 np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), 103 np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) 104 self._assertOpOutputMatchesExpected( 105 array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), 106 np.array( 107 [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], 108 [[0, 0], [0, 4]]]], 109 dtype=dtype)) 110 111 self._assertOpOutputMatchesExpected( 112 array_ops.identity, 113 np.array([[-1, 1]], dtype=dtype), 114 expected=np.array([[-1, 1]], dtype=dtype)) 115 116 self._assertOpOutputMatchesExpected( 117 array_ops.prevent_gradient, 118 np.array([[-1, 1]], dtype=dtype), 119 expected=np.array([[-1, 1]], dtype=dtype)) 120 121 self._assertOpOutputMatchesExpected( 122 array_ops.squeeze, 123 np.array([[[[[]]]]], dtype=dtype), 124 expected=np.array([], dtype=dtype)) 125 self._assertOpOutputMatchesExpected( 126 array_ops.squeeze, 127 np.array([[[1], [2]]], dtype=dtype), 128 expected=np.array([1, 2], dtype=dtype)) 129 self._assertOpOutputMatchesExpected( 130 array_ops.squeeze, 131 np.array([[[1]], [[2]]], dtype=dtype), 132 expected=np.array([1, 2], dtype=dtype)) 133 self._assertOpOutputMatchesExpected( 134 array_ops.squeeze, 135 np.array([[[1, 2], [3, 4]]], dtype=dtype), 136 expected=np.array([[1, 2], [3, 4]], dtype=dtype)) 137 138 self._assertOpOutputMatchesExpected( 139 array_ops.stop_gradient, 140 np.array([[-1, 1]], dtype=dtype), 141 expected=np.array([[-1, 1]], dtype=dtype)) 142 143 def testLog(self): 144 for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}: 145 tol = 1e-4 if dtype == np.float32 else 1e-9 146 # pylint: disable=invalid-unary-operand-type 147 x = np.linspace(-np.e, np.e, num=1000, dtype=dtype) 148 self._assertOpOutputMatchesExpected( 149 math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) 150 151 x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype) 152 self._assertOpOutputMatchesExpected( 153 math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) 154 155 x = np.linspace(0., np.pi * 1e30, num=1000, dtype=dtype) 156 self._assertOpOutputMatchesExpected( 157 math_ops.log, x, expected=np.log(x), atol=tol, rtol=tol) 158 159 def testSin(self): 160 for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}: 161 tol = 1e-6 if dtype == np.float32 else 1e-12 162 163 x = np.linspace(-4 * np.e, 4 * np.e, num=1000, dtype=dtype) 164 self._assertOpOutputMatchesExpected( 165 math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol) 166 167 x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype) 168 self._assertOpOutputMatchesExpected( 169 math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=tol) 170 171 if dtype == np.float64: 172 x = np.linspace(0., np.e * 1e8, num=1000, dtype=dtype) 173 self._assertOpOutputMatchesExpected( 174 math_ops.sin, x, expected=np.sin(x), rtol=tol, atol=1e-5) 175 176 def testCos(self): 177 for dtype in self.float_types - {dtypes.bfloat16.as_numpy_dtype}: 178 tol = 1e-6 if dtype == np.float32 else 1e-12 179 180 x = np.linspace(-4 * np.e, 4 * np.e, num=1000, dtype=dtype) 181 self._assertOpOutputMatchesExpected( 182 math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=tol) 183 184 x = np.linspace(0., np.e * 1e-30, num=1000, dtype=dtype) 185 self._assertOpOutputMatchesExpected( 186 math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=tol) 187 188 if dtype == np.float64: 189 x = np.linspace(0., np.e * 1e8, num=1000, dtype=dtype) 190 self._assertOpOutputMatchesExpected( 191 math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5) 192 193 def testFloatOps(self): 194 for dtype in self.float_types: 195 x = np.arange(-0.90, 0.90, 0.25) 196 self._assertOpOutputMatchesExpected( 197 math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) 198 self._assertOpOutputMatchesExpected( 199 math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype)) 200 x = np.arange(-3, 3).reshape(1, 3, 2) 201 self._assertOpOutputMatchesExpected( 202 math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype)) 203 204 self._assertOpOutputMatchesExpected( 205 math_ops.acosh, 206 np.array([1, 2, 3, 4], dtype=dtype), 207 expected=np.array( 208 [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype)) 209 210 self._assertOpOutputMatchesExpected( 211 math_ops.asinh, 212 np.array([1, 2, 3, 4], dtype=dtype), 213 expected=np.array( 214 [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype)) 215 216 self._assertOpOutputMatchesExpected( 217 math_ops.atanh, 218 np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), 219 expected=np.array( 220 [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype)) 221 222 self._assertOpOutputMatchesExpected( 223 math_ops.ceil, 224 np.array([[-1.7, 1.2]], dtype=dtype), 225 expected=np.array([[-1, 2]], dtype=dtype)) 226 227 self._assertOpOutputMatchesExpected( 228 math_ops.cosh, 229 np.array([1, 2, 3, 4], dtype=dtype), 230 expected=np.array( 231 [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) 232 233 # Disable float16 testing for now 234 if dtype != np.float16: 235 x = np.arange(-10, 10, 1).astype(dtype) 236 with self.session() as session: 237 erf_x = session.run(math_ops.erf(x)) 238 erfc_x = session.run(math_ops.erfc(x)) 239 240 self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x) 241 self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x) 242 243 self._assertOpOutputMatchesExpected( 244 math_ops.exp, 245 np.array([[-1, 1]], dtype=dtype), 246 expected=np.array([[0.36787945, 2.7182817]], dtype=dtype)) 247 248 self._assertOpOutputMatchesExpected( 249 math_ops.expm1, 250 np.array([[-1, 1]], dtype=dtype), 251 expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), 252 rtol=1e-5) 253 254 self._assertOpOutputMatchesExpected( 255 math_ops.floor, 256 np.array([[-1.7, 1.2]], dtype=dtype), 257 expected=np.array([[-2, 1]], dtype=dtype)) 258 259 self._assertOpOutputMatchesExpected( 260 math_ops.is_finite, 261 np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], 262 dtype=dtype), 263 expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_)) 264 265 # Tests for tf.nn ops. 266 self._assertOpOutputMatchesExpected( 267 nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0)) 268 269 self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8)) 270 271 self._assertOpOutputMatchesExpected( 272 nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10)) 273 274 self._assertOpOutputMatchesExpected( 275 math_ops.reciprocal, 276 np.array([[1, 2]], dtype=dtype), 277 expected=np.array([[1, 0.5]], dtype=dtype)) 278 279 self._assertOpOutputMatchesExpected( 280 math_ops.log, 281 np.array([[1, 2]], dtype=dtype), 282 expected=np.array([[0, 0.69314718]], dtype=dtype)) 283 284 self._assertOpOutputMatchesExpected( 285 math_ops.sin, 286 np.array([[1, 2]], dtype=dtype), 287 expected=np.array([[0.841478, 0.909302]], dtype=dtype)) 288 289 self._assertOpOutputMatchesExpected( 290 math_ops.cos, 291 np.array([[1, 2]], dtype=dtype), 292 expected=np.array([[0.540297, -0.41614]], dtype=dtype)) 293 294 # Confirm that log1p will remain precise across a range of small values. 295 self._assertOpOutputMatchesExpected( 296 math_ops.log1p, 297 np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], 298 dtype=dtype), 299 expected=np.log1p( 300 np.array( 301 [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], 302 dtype=dtype)).astype(dtype), 303 rtol=1e-15 if dtype == np.float64 else 1e-4, 304 atol=1e-15 if dtype == np.float64 else 1e-4) 305 306 self._assertOpOutputMatchesExpected( 307 math_ops.rint, 308 np.array( 309 [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], 310 [0.5, 1.5, 2.5, 3.5]], 311 dtype=dtype), 312 expected=np.array( 313 [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) 314 self._assertOpOutputMatchesExpected( 315 math_ops.round, 316 np.array( 317 [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], 318 [0.5, 1.5, 2.5, 3.5]], 319 dtype=dtype), 320 expected=np.array( 321 [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype)) 322 323 self._assertOpOutputMatchesExpected( 324 math_ops.rsqrt, 325 np.array([[4, 16]], dtype=dtype), 326 expected=np.array([[0.5, 0.25]], dtype=dtype)) 327 328 self._assertOpOutputMatchesExpected( 329 math_ops.sigmoid, 330 np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), 331 expected=np.array( 332 [[0.7310586, 0.7310586, 0.7310586, 0.7310586], 333 [0.7310586, 0.880797, 0.95257413, 0.98201376]], 334 dtype=dtype)) 335 336 self._assertOpOutputMatchesExpected( 337 math_ops.sigmoid, 338 np.array([-300, -150, 0, 150, 300], dtype=dtype), 339 expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype)) 340 341 self._assertOpOutputMatchesExpected( 342 math_ops.sinh, 343 np.array([1, 2, 3, 4], dtype=dtype), 344 expected=np.array( 345 [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype)) 346 347 self._assertOpOutputMatchesExpected( 348 math_ops.sqrt, 349 np.array([[4, 9]], dtype=dtype), 350 expected=np.array([[2, 3]], dtype=dtype)) 351 352 self._assertOpOutputMatchesExpected( 353 math_ops.tan, 354 np.array([1, 2, 3, 4], dtype=dtype), 355 expected=np.array( 356 [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype)) 357 358 self._assertOpOutputMatchesExpected( 359 math_ops.tanh, 360 np.array([[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], 361 [19, -19, 22, -22]], 362 dtype=dtype), 363 expected=np.array( 364 [[0.76159418, 0.96402758, 0.99505478, 0.99932933], 365 [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], 366 dtype=dtype)) 367 368 self._assertOpOutputMatchesExpected( 369 nn_ops.log_softmax, 370 np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), 371 expected=np.array( 372 [[-1.3862944, -1.3862944, -1.3862944, -1.3862944], 373 [-3.4401896, -2.4401896, -1.4401897, -0.44018969]], 374 dtype=dtype)) 375 376 self._assertOpOutputMatchesExpected( 377 nn_ops.elu, 378 np.array([[-1, 0, 1, -1e-6]], dtype=dtype), 379 expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), 380 rtol=1e-5, 381 atol=1e-6) 382 383 self._assertOpOutputMatchesExpected( 384 nn_ops.selu, 385 np.array([[-1, 0, 1, -1e-5]], dtype=dtype), 386 expected=np.array( 387 [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], 388 dtype=dtype), 389 rtol=1e-5, 390 atol=1e-6) 391 392 self._assertOpOutputMatchesExpected( 393 nn_ops.relu, 394 np.array([[-1, 1]], dtype=dtype), 395 expected=np.array([[0, 1]], dtype=dtype)) 396 397 self._assertOpOutputMatchesExpected( 398 nn_ops.relu6, 399 np.array([[-0.05, 6.05, 5]], dtype=dtype), 400 expected=np.array([[0, 6, 5]], dtype=dtype)) 401 402 self._assertOpOutputMatchesExpected( 403 nn_ops.leaky_relu, 404 np.array([[-2, -1, 0, 1, 2]], dtype=dtype), 405 expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype)) 406 407 self._assertOpOutputMatchesExpected( 408 nn_ops.softmax, 409 np.array([1, 2, 3, 4], dtype=dtype), 410 expected=np.array([0.032058604, 0.087144323, 0.23688284, 0.64391428], 411 dtype=dtype)) 412 413 self._assertOpOutputMatchesExpected( 414 nn_ops.softmax, 415 np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), 416 expected=np.array( 417 [[0.25, 0.25, 0.25, 0.25], 418 [0.032058604, 0.087144323, 0.23688284, 0.64391428]], 419 dtype=dtype)) 420 421 self._assertOpOutputMatchesExpected( 422 nn_ops.softmax, 423 np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), 424 expected=np.array( 425 [[[0.5, 0.5], [0.5, 0.5]], 426 [[0.26894142, 0.73105858], [0.26894142, 0.73105858]]], 427 dtype=dtype)) 428 429 self._assertOpOutputMatchesExpected( 430 nn_ops.softsign, 431 np.array([[-2, -1, 0, 1, 2]], dtype=dtype), 432 expected=np.array( 433 [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype)) 434 435 self._assertOpOutputMatchesExpected( 436 math_ops.sign, 437 np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, 438 float("nan")]], 439 dtype=dtype), 440 expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, 441 float("nan")]], 442 dtype=dtype)) 443 444 self._assertOpOutputMatchesExpected( 445 math_ops.is_finite, 446 np.array([[42, float("inf"), -123], [float("nan"), 0, -0.0]], 447 dtype=dtype), 448 expected=np.array([[True, False, True], [False, True, True]], 449 dtype=np.bool_)) 450 451 self._assertOpOutputMatchesExpected( 452 math_ops.lgamma, 453 np.array(0.5, dtype=dtype), 454 expected=np.array(np.log(np.pi) / 2, dtype=dtype)) 455 456 self._assertOpOutputMatchesExpected( 457 math_ops.lgamma, 458 np.array( 459 [[1, 2, 3], [4, 5, 6], [1 / 2, 3 / 2, 5 / 2], 460 [-3 / 2, -7 / 2, -11 / 2]], 461 dtype=dtype), 462 expected=np.array( 463 [ 464 [0, 0, np.log(2.0)], 465 [np.log(6.0), np.log(24.0), 466 np.log(120)], 467 [ 468 np.log(np.pi) / 2, 469 np.log(np.pi) / 2 - np.log(2), 470 np.log(np.pi) / 2 - np.log(4) + np.log(3) 471 ], 472 [ 473 np.log(np.pi) / 2 - np.log(3) + np.log(4), 474 np.log(np.pi) / 2 - np.log(105) + np.log(16), 475 np.log(np.pi) / 2 - np.log(10395) + np.log(64), 476 ], 477 ], 478 dtype=dtype)) 479 480 # The actual result is complex. Take the real part. 481 self._assertOpOutputMatchesExpected( 482 math_ops.lgamma, 483 np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), 484 expected=np.array( 485 [ 486 np.log(np.pi) / 2 + np.log(2), 487 np.log(np.pi) / 2 - np.log(15) + np.log(8), 488 np.log(np.pi) / 2 - np.log(945) + np.log(32), 489 ], 490 dtype=dtype), 491 atol=1e-4) 492 493 self._assertOpOutputMatchesExpected( 494 math_ops.digamma, 495 np.array( 496 [[1.0, 0.5, 1 / 3.0], [0.25, 1 / 6.0, 0.125], [2.0, 3.0, 4.0], 497 [6.0, 8.0, 9.0]], 498 dtype=dtype), 499 expected=np.array( 500 [ 501 [ 502 -np.euler_gamma, -2 * np.log(2) - np.euler_gamma, 503 -np.pi / 2 / np.sqrt(3) - 3 * np.log(3) / 2 - 504 np.euler_gamma 505 ], 506 [ 507 -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, 508 -np.pi * np.sqrt(3) / 2 - 2 * np.log(2) - 509 3 * np.log(3) / 2 - np.euler_gamma, 510 -np.pi / 2 - 4 * np.log(2) - 511 (np.pi + np.log(2 + np.sqrt(2)) - np.log(2 - np.sqrt(2))) 512 / np.sqrt(2) - np.euler_gamma 513 ], 514 [ 515 1 - np.euler_gamma, 1.5 - np.euler_gamma, 516 11 / 6.0 - np.euler_gamma 517 ], 518 [ 519 137 / 60.0 - np.euler_gamma, 363 / 140.0 - np.euler_gamma, 520 761 / 280.0 - np.euler_gamma 521 ], 522 ], 523 dtype=dtype)) 524 525 def testSigmoidNumericalStability(self): 526 for dtype in self.float_types: 527 if dtype != np.float16: 528 self._assertOpOutputMatchesExpected( 529 lambda x: math_ops.sigmoid(x) / math_ops.log1p(math_ops.exp(x)), 530 np.array([-40, 40], dtype=dtype), 531 expected=np.array([1.0, 0.025], dtype=dtype)) 532 533 def testQuantizeAndDequantize(self): 534 for dtype in self.float_types: 535 536 def quantize_and_dequantize_v2(x): 537 return array_ops.quantize_and_dequantize( 538 x, -127, 127, signed_input=True, num_bits=8) 539 540 def quantize_and_dequantize_v3(x): 541 return array_ops.quantize_and_dequantize_v3( 542 x, -127, 127, num_bits=8, signed_input=True, range_given=False) 543 544 def quantize_and_dequantize_v4(x): 545 return array_ops.quantize_and_dequantize_v2( 546 x, -127, 127, signed_input=True, num_bits=8) 547 548 test_fns = (quantize_and_dequantize_v2, quantize_and_dequantize_v3, 549 quantize_and_dequantize_v4) 550 for test_fn in test_fns: 551 self._assertOpOutputMatchesExpected( 552 test_fn, 553 np.array([-1, -0.5, 0, 0.3], dtype=dtype), 554 expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) 555 556 def quantize_and_dequantize_v2_round_half_up(x): 557 return array_ops.quantize_and_dequantize( 558 x, 559 -1, 560 1.0, 561 signed_input=True, 562 num_bits=8, 563 range_given=True, 564 round_mode="HALF_UP") 565 566 self._assertOpOutputMatchesExpected( 567 quantize_and_dequantize_v2_round_half_up, 568 np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype), 569 expected=np.array([ 570 -102.0 / 127, 571 -63.0 / 127, 572 0, 573 38.0 / 127, 574 102.0 / 127, 575 -128.0 / 127, 576 1, 577 ], 578 dtype=dtype)) 579 580 def quantize_and_dequantize_v2_round_half_to_even(x): 581 return array_ops.quantize_and_dequantize( 582 x, 583 -1.0, 584 1.0, 585 signed_input=True, 586 num_bits=8, 587 range_given=True, 588 round_mode="HALF_TO_EVEN") 589 590 self._assertOpOutputMatchesExpected( 591 quantize_and_dequantize_v2_round_half_to_even, 592 np.array( 593 [ 594 -0.8, 595 # The -0.5 should become -63.5 after scaling and with 596 # rounding this should become -64. But with the test 597 # unary_ops_test_cpu_ondemand, this fails as the result 598 # before scaling becomes -63.499996 and gets rounded to -63. 599 # TODO(sreenik): Some one more familiar with this test needs 600 # to take a look and resolve this. This works on all other 601 # variations of the platform like cpu, and gpu. 602 # -0.5, 603 0, 604 0.3, 605 0.8, 606 -2, 607 33 608 ], 609 dtype=dtype), 610 expected=np.array( 611 [ 612 -102.0 / 127, 613 # -64.0 / 127, 614 0, 615 38.0 / 127, 616 102.0 / 127, 617 -128.0 / 127, 618 1, 619 ], 620 dtype=dtype)) 621 622 def testComplexOps(self): 623 for dtype in self.complex_types: 624 625 self._assertOpOutputMatchesExpected( 626 math_ops.acosh, 627 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), 628 expected=np.arccosh( 629 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) 630 631 self._assertOpOutputMatchesExpected( 632 math_ops.asinh, 633 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), 634 expected=np.arcsinh( 635 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) 636 637 self._assertOpOutputMatchesExpected( 638 math_ops.atanh, 639 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), 640 expected=np.arctanh( 641 np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) 642 643 self._assertOpOutputMatchesExpected( 644 math_ops.cosh, 645 np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype), 646 expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype))) 647 648 self._assertOpOutputMatchesExpected( 649 math_ops.sinh, 650 np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), 651 expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) 652 653 self._assertOpOutputMatchesExpected( 654 math_ops.exp, 655 np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), 656 expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) 657 658 self._assertOpOutputMatchesExpected( 659 math_ops.expm1, 660 np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), 661 expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), 662 rtol=1e-6, 663 atol=1e-6) 664 665 # For real part close to zero, or imaginary part close to a multiple of 666 # pi. 667 668 self._assertOpOutputMatchesExpected( 669 math_ops.expm1, 670 np.array([[1e-11 + 1j, -1e-11 - 1j, 1. + 1e-11j, 671 -1. - 1e-11j, 1e-13j + 1e-13j]], dtype=dtype), 672 # TODO(srvasude): Use numpy as the source of truth after we depend on 673 # latest numpy with this pull request: 674 # https://github.com/numpy/numpy/pull/15110. 675 # The numbers below were generated by scipy.special.expm1. 676 expected=np.array([[ 677 -4.59697694e-01+8.41470985e-01j, 678 -4.59697694e-01-8.41470985e-01j, 679 1.71828183e+00+2.71828183e-11j, 680 -6.32120559e-01-3.67879441e-12j, 681 -2.00000000e-26+2.00000000e-13j]], dtype=dtype), 682 rtol=1e-09, 683 atol=1e-20) 684 685 self._assertOpOutputMatchesExpected( 686 math_ops.reciprocal, 687 np.array([[1, 2j, 2 + 3j]], dtype=dtype), 688 expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) 689 690 self._assertOpOutputMatchesExpected( 691 math_ops.log, 692 np.array([[5j, 3 - 2j]], dtype=dtype), 693 expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) 694 695 self._assertOpOutputMatchesExpected( 696 math_ops.sin, 697 np.array([[5j, 3 - 2j]], dtype=dtype), 698 expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype))) 699 700 self._assertOpOutputMatchesExpected( 701 math_ops.cos, 702 np.array([[5j, 3 - 2j]], dtype=dtype), 703 expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) 704 705 self._assertOpOutputMatchesExpected( 706 math_ops.log1p, 707 np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), 708 expected=np.log1p( 709 np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), 710 rtol=1e-4, 711 atol=1e-6) 712 713 val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) 714 self._assertOpOutputMatchesExpected( 715 math_ops.rsqrt, val, expected=1 / np.sqrt(val)) 716 717 self._assertOpOutputMatchesExpected( 718 math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) 719 720 self._assertOpOutputMatchesExpected( 721 math_ops.sqrt, val, expected=np.sqrt(val)) 722 723 self._assertOpOutputMatchesExpected( 724 math_ops.tanh, 725 np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), 726 expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) 727 728 self._assertOpOutputMatchesExpected( 729 math_ops.tan, 730 np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), 731 expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) 732 733 ctypes = {np.complex64: np.float32, np.complex128: np.float64} 734 self._assertOpOutputMatchesExpected( 735 math_ops.abs, 736 np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), 737 expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype])) 738 739 self._assertOpOutputMatchesExpected( 740 math_ops.negative, 741 np.array([[-1 + 2j, -3j]], dtype=dtype), 742 expected=np.array([[1 - 2j, 3j]], dtype=dtype)) 743 744 self._assertOpOutputMatchesExpected( 745 math_ops.square, 746 np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype), 747 expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2) 748 749 self._assertOpOutputMatchesExpected( 750 array_ops.zeros_like, 751 np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype), 752 expected=np.array([[0, 0], [0, 0]], dtype=dtype)) 753 754 self._assertOpOutputMatchesExpected( 755 array_ops.ones_like, 756 np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), 757 expected=np.array([[1, 1], [1, 1]], dtype=dtype)) 758 759 self._assertOpOutputMatchesExpected( 760 math_ops.angle, 761 np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), 762 expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) 763 764 self._assertOpOutputMatchesExpected( 765 math_ops.conj, 766 np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), 767 expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype)) 768 769 self._assertOpOutputMatchesExpected( 770 math_ops.imag, 771 np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), 772 expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype])) 773 774 self._assertOpOutputMatchesExpected( 775 math_ops.real, 776 np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), 777 expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) 778 779 def testIntOps(self): 780 for dtype in self.int_types: 781 self._assertOpOutputMatchesExpected( 782 bitwise_ops.invert, 783 np.array([0, -1, 1, 16, 42], dtype=dtype), 784 expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) 785 786 # Test population_count for array inputs. 787 raw_inputs = [ 788 0, 1, -1, 3, -3, 5, -5, 14, -14, 127, 128, 255, 256, 65535, 65536, 789 2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32, -2**63 + 1, 790 2**63 - 1 791 ] 792 # Only choose inputs which fit in the int dtype. 793 raw_inputs = list( 794 filter(lambda x: np.iinfo(dtype).min <= x <= np.iinfo(dtype).max, 795 raw_inputs)) 796 inputs = np.array(raw_inputs, dtype=dtype) 797 798 def count_bits(x): 799 return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes())) 800 801 truth = [count_bits(x) for x in inputs] 802 self._assertOpOutputMatchesExpected( 803 bitwise_ops.population_count, 804 inputs, 805 expected=np.array(truth, dtype=np.uint8), 806 equality_test=self.AssertAllEqual) 807 808 # Test population_count for scalar inputs. 809 for raw_inp in raw_inputs: 810 inp = dtype(raw_inp) 811 truth = count_bits(inp) 812 self._assertOpOutputMatchesExpected( 813 bitwise_ops.population_count, 814 inp, 815 expected=np.uint8(truth), 816 equality_test=self.AssertAllEqual) 817 818 def testNumericOps(self): 819 for dtype in self.numeric_types - {np.int8, np.uint8}: 820 self._assertOpOutputMatchesExpected( 821 math_ops.abs, 822 np.array([[2, -1]], dtype=dtype), 823 expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype)) 824 825 self._assertOpOutputMatchesExpected( 826 math_ops.negative, 827 np.array([[-1, 1]], dtype=dtype), 828 expected=np.array([[1, -1]], dtype=dtype)) 829 830 self._assertOpOutputMatchesExpected( 831 math_ops.square, 832 np.array([[-2, 3]], dtype=dtype), 833 expected=np.array([[4, 9]], dtype=dtype)) 834 835 self._assertOpOutputMatchesExpected( 836 array_ops.zeros_like, 837 np.array([[4, 3], [2, 1]], dtype=dtype), 838 expected=np.array([[0, 0], [0, 0]], dtype=dtype)) 839 840 self._assertOpOutputMatchesExpected( 841 array_ops.ones_like, 842 np.array([[4, 3], [2, 1]], dtype=dtype), 843 expected=np.array([[1, 1], [1, 1]], dtype=dtype)) 844 845 # TODO(phawkins): these tests fail unless fastmath optimizations 846 # are disabled. Use more robust IsInf/IsNaN detection and enable these 847 # tests. 848 @unittest.skip("test case fails in fast-math mode") 849 def testIsInfAndIsNan(self): 850 for dtype in self.float_types: 851 self._assertOpOutputMatchesExpected( 852 math_ops.is_inf, 853 np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], 854 dtype=dtype), 855 expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool_)) 856 self._assertOpOutputMatchesExpected( 857 math_ops.is_nan, 858 np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], 859 dtype=dtype), 860 expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool_)) 861 self._assertOpOutputMatchesExpected( 862 math_ops.sign, 863 np.array([[np.nan]], dtype=dtype), 864 expected=np.array([[0.0]], dtype=dtype)) 865 866 def testLogicalOps(self): 867 self._assertOpOutputMatchesExpected( 868 math_ops.logical_not, 869 np.array([[True, False], [False, True]], dtype=np.bool_), 870 expected=np.array([[False, True], [True, False]], dtype=np.bool_)) 871 872 def testBiasAddGrad(self): 873 self._assertOpOutputMatchesExpected( 874 gen_nn_ops.bias_add_grad, 875 np.array([[1., 2.], [3., 4.]], dtype=np.float32), 876 expected=np.array([4., 6.], dtype=np.float32)) 877 878 self._assertOpOutputMatchesExpected( 879 lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"), 880 np.array( 881 [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32), 882 expected=np.array([14., 22.], dtype=np.float32)) 883 884 def testCast(self): 885 shapes = [[], [4], [2, 3], [2, 0, 4]] 886 types = { 887 dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64, 888 dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64 889 } 890 for src_type in types: 891 for dst_type in types: 892 src_np_dtype = src_type.as_numpy_dtype 893 dst_np_dtype = dst_type.as_numpy_dtype 894 895 for shape in shapes: 896 src = np.arange(np.prod(shape)).astype(src_np_dtype) 897 898 if src_type in self.complex_tf_types: 899 src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype) 900 src = src.reshape(shape) 901 dst = src.astype(dst_np_dtype) 902 self._assertOpOutputMatchesExpected( 903 lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), 904 src, 905 expected=dst) 906 907 # Check special values. 908 if src_type.is_integer: 909 imin = np.iinfo(src_np_dtype).min 910 imax = np.iinfo(src_np_dtype).max 911 src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) 912 elif src_type in self.float_tf_types: 913 if dst_type.is_integer: 914 imin = np.iinfo(dst_np_dtype).min 915 imax = np.iinfo(dst_np_dtype).max // 2 916 src = np.array([imin, imax, 0, 1], dtype=src_np_dtype) 917 elif dst_type in self.float_tf_types: 918 fmin = np.finfo(dst_np_dtype).min 919 fmax = np.finfo(dst_np_dtype).max 920 tiny = np.finfo(dst_np_dtype).tiny 921 eps = np.finfo(dst_np_dtype).eps 922 src = np.array( 923 [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf], 924 dtype=src_np_dtype) 925 dst = src.astype(dst_np_dtype) 926 self._assertOpOutputMatchesExpected( 927 lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), 928 src, 929 expected=dst) 930 931 def testBitcast(self): 932 self._assertOpOutputMatchesExpected( 933 lambda x: array_ops.bitcast(x, dtypes.int32), 934 np.array([1, 0x3f800000], np.int32), 935 expected=np.array([1, 0x3f800000], np.int32)) 936 self._assertOpOutputMatchesExpected( 937 lambda x: array_ops.bitcast(x, dtypes.float32), 938 np.array([1, 0x3f800000], np.int32), 939 expected=np.array([1e-45, 1.0], np.float32)) 940 self._assertOpOutputMatchesExpected( 941 lambda x: array_ops.bitcast(x, dtypes.int32), 942 np.array([1e-45, 1.0], np.float32), 943 expected=np.array([1, 0x3f800000], np.int32)) 944 if np.int64 in self.numeric_types: 945 self._assertOpOutputMatchesExpected( 946 lambda x: array_ops.bitcast(x, dtypes.int64), 947 np.array([1, 0x100000003f800000], np.uint64), 948 expected=np.array([1, 0x100000003f800000], np.int64)) 949 self._assertOpOutputMatchesExpected( 950 lambda x: array_ops.bitcast(x, dtypes.uint64), 951 np.array([1, 0x100000003f800000], np.int64), 952 expected=np.array([1, 0x100000003f800000], np.uint64)) 953 954 def testBitcastInt8ToFloat(self): 955 self._assertOpOutputMatchesExpected( 956 lambda x: array_ops.bitcast(x, dtypes.float32), 957 np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8), 958 expected=np.array([1e-45, 3.14159], np.float32)) 959 self._assertOpOutputMatchesExpected( 960 lambda x: array_ops.bitcast(x, dtypes.np.int8), 961 np.array([1e-45, 3.14159], np.float32), 962 expected=np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8)) 963 964 def testInvertPermutation(self): 965 for np_dtype in [np.int32, np.int64]: 966 self._assertOpOutputMatchesExpected( 967 array_ops.invert_permutation, 968 np.array([1, 2, 0], np_dtype), 969 expected=np.array([2, 0, 1], dtype=np_dtype)) 970 971 def testInvertPermutationTwiceIsNoop(self): 972 973 def invert_twice(x): 974 return array_ops.invert_permutation(array_ops.invert_permutation(x)) 975 976 for np_dtype in [np.int32, np.int64]: 977 self._assertOpOutputMatchesExpected( 978 invert_twice, 979 np.array([1, 2, 0], np_dtype), 980 expected=np.array([1, 2, 0], dtype=np_dtype)) 981 982 def testRank(self): 983 rank_op = lambda x: array_ops.rank_internal(x, optimize=False) 984 for dtype in self.numeric_types: 985 self._assertOpOutputMatchesExpected( 986 rank_op, dtype(7), expected=np.int32(0)) 987 self._assertOpOutputMatchesExpected( 988 rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) 989 self._assertOpOutputMatchesExpected( 990 rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) 991 self._assertOpOutputMatchesExpected( 992 rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) 993 self._assertOpOutputMatchesExpected( 994 rank_op, 995 np.array([[-1], [1], [4]], dtype=dtype), 996 expected=np.int32(2)) 997 998 def testShape(self): 999 shape_op = lambda x: array_ops.shape_internal(x, optimize=False) 1000 for dtype in self.numeric_types: 1001 self._assertOpOutputMatchesExpected( 1002 shape_op, dtype(7), expected=np.array([], dtype=np.int32)) 1003 self._assertOpOutputMatchesExpected( 1004 shape_op, 1005 np.array([[], []], dtype=dtype), 1006 expected=np.array([2, 0], dtype=np.int32)) 1007 self._assertOpOutputMatchesExpected( 1008 shape_op, 1009 np.array([-1, 1], dtype=dtype), 1010 expected=np.array([2], dtype=np.int32)) 1011 self._assertOpOutputMatchesExpected( 1012 shape_op, 1013 np.array([[-1, 1]], dtype=dtype), 1014 expected=np.array([1, 2], dtype=np.int32)) 1015 self._assertOpOutputMatchesExpected( 1016 shape_op, 1017 np.array([[-1], [1], [4]], dtype=dtype), 1018 expected=np.array([3, 1], dtype=np.int32)) 1019 1020 def testSize(self): 1021 size_op = lambda x: array_ops.size_internal(x, optimize=False) 1022 for dtype in self.numeric_types: 1023 self._assertOpOutputMatchesExpected( 1024 size_op, dtype(7), expected=np.int32(1)) 1025 self._assertOpOutputMatchesExpected( 1026 size_op, np.array([[], []], dtype=dtype), expected=np.int32(0)) 1027 self._assertOpOutputMatchesExpected( 1028 size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2)) 1029 self._assertOpOutputMatchesExpected( 1030 size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) 1031 self._assertOpOutputMatchesExpected( 1032 size_op, 1033 np.array([[-1], [1], [4]], dtype=dtype), 1034 expected=np.int32(3)) 1035 1036 def testSizeWithInt64OutType(self): 1037 1038 def size_op(x): 1039 return array_ops.size_internal(x, optimize=False, out_type=np.int64) 1040 1041 for dtype in self.numeric_types: 1042 self._assertOpOutputMatchesExpected( 1043 size_op, 1044 np.array([[-1], [1], [4]], dtype=dtype), 1045 expected=np.int64(3)) 1046 1047 def testUnpack(self): 1048 self._assertOpOutputMatchesExpected( 1049 array_ops.unstack, 1050 np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32), 1051 expected=[ 1052 np.array([1., 2.], dtype=np.float32), 1053 np.array([3., 4.], dtype=np.float32), 1054 np.array([5., 6.], dtype=np.float32), 1055 ], 1056 equality_test=self.ListsAreClose) 1057 1058 self._assertOpOutputMatchesExpected( 1059 lambda x: array_ops.unstack(x, axis=1), 1060 np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32), 1061 expected=[ 1062 np.array([1., 3., 5.], dtype=np.float32), 1063 np.array([2., 4., 6.], dtype=np.float32), 1064 ], 1065 equality_test=self.ListsAreClose) 1066 1067 def testDepthToSpace(self): 1068 1069 def make_op(data_format): 1070 1071 def op(x): 1072 return array_ops.depth_to_space( 1073 x, block_size=2, data_format=data_format) 1074 1075 return op 1076 1077 for dtype in self.numeric_types: 1078 for data_format in ["NCHW", "NHWC"]: 1079 self._assertOpOutputMatchesExpected( 1080 make_op(data_format), 1081 nhwc_to_format( 1082 np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format), 1083 expected=nhwc_to_format( 1084 np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format)) 1085 1086 self._assertOpOutputMatchesExpected( 1087 make_op(data_format), 1088 nhwc_to_format( 1089 np.array( 1090 [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), 1091 data_format), 1092 expected=nhwc_to_format( 1093 np.array( 1094 [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], 1095 dtype=dtype), data_format)) 1096 1097 self._assertOpOutputMatchesExpected( 1098 make_op(data_format), 1099 nhwc_to_format( 1100 np.array( 1101 [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], 1102 [13, 14, 15, 16]]]], 1103 dtype=dtype), data_format), 1104 expected=nhwc_to_format( 1105 np.array( 1106 [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], 1107 [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], 1108 dtype=dtype), data_format)) 1109 1110 self._assertOpOutputMatchesExpected( 1111 make_op("NCHW_VECT_C"), 1112 np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), 1113 expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]], 1114 [[16, 17, 18, 19], [24, 25, 26, 27]]], 1115 [[[4, 5, 6, 7], [12, 13, 14, 15]], 1116 [[20, 21, 22, 23], [28, 29, 30, 31]]]]], 1117 dtype=dtype)) 1118 1119 def testSpaceToDepth(self): 1120 1121 def make_op(data_format): 1122 1123 def op(x): 1124 return array_ops.space_to_depth( 1125 x, block_size=2, data_format=data_format) 1126 1127 return op 1128 1129 for dtype in self.numeric_types: 1130 for data_format in ["NCHW", "NHWC"]: 1131 self._assertOpOutputMatchesExpected( 1132 make_op(data_format), 1133 nhwc_to_format( 1134 np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format), 1135 expected=nhwc_to_format( 1136 np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format)) 1137 1138 self._assertOpOutputMatchesExpected( 1139 make_op(data_format), 1140 nhwc_to_format( 1141 np.array( 1142 [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], 1143 dtype=dtype), data_format), 1144 expected=nhwc_to_format( 1145 np.array( 1146 [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), 1147 data_format)) 1148 1149 self._assertOpOutputMatchesExpected( 1150 make_op(data_format), 1151 nhwc_to_format( 1152 np.array( 1153 [[[[1], [2], [5], [6]], [[3], [4], [7], [8]], 1154 [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], 1155 dtype=dtype), data_format), 1156 expected=nhwc_to_format( 1157 np.array( 1158 [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], 1159 [13, 14, 15, 16]]]], 1160 dtype=dtype), data_format)) 1161 1162 self._assertOpOutputMatchesExpected( 1163 make_op("NCHW_VECT_C"), 1164 np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), 1165 expected=np.array( 1166 [[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]], 1167 [[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]], 1168 [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]], 1169 dtype=dtype)) 1170 1171 def _assertSoftplusMatchesExpected(self, 1172 features, 1173 dtype, 1174 equality_test=None, 1175 rtol=1e-6, 1176 atol=9.1e-6): 1177 features = np.array(features, dtype=dtype) 1178 zero = np.asarray(0).astype(dtype) 1179 expected = np.logaddexp(zero, features).astype(dtype) 1180 self._assertOpOutputMatchesExpected( 1181 nn_ops.softplus, 1182 features, 1183 expected=expected, 1184 equality_test=equality_test, 1185 rtol=rtol, 1186 atol=atol) 1187 1188 def testSoftplus(self): 1189 for dtype in self.float_types & {dtypes.float32, dtypes.float64}: 1190 self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) 1191 self._assertSoftplusMatchesExpected( 1192 [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) 1193 if dtype == dtypes.bfloat16.as_numpy_dtype: 1194 log_eps = np.log(np.finfo(np.float32).eps) 1195 else: 1196 log_eps = np.log(np.finfo(dtype).eps) 1197 one = dtype(1) 1198 ten = dtype(10) 1199 self._assertSoftplusMatchesExpected([ 1200 log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten, 1201 -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten, 1202 -log_eps + ten 1203 ], dtype) 1204 1205 self._assertSoftplusMatchesExpected( 1206 [0.69302183, 0.69324386], 1207 dtype, 1208 equality_test=self.AssertCloseAndSorted, 1209 rtol=9e-5, 1210 atol=9e-5) 1211 1212 def testToBool(self): 1213 for dtype in self.numeric_types - self.complex_types: 1214 self._assertOpOutputMatchesExpected( 1215 gen_functional_ops.to_bool, 1216 np.array(5, dtype=dtype), 1217 expected=np.array(True)) 1218 1219 self._assertOpOutputMatchesExpected( 1220 gen_functional_ops.to_bool, 1221 np.array(0, dtype=dtype), 1222 expected=np.array(False)) 1223 1224 self._assertOpOutputMatchesExpected( 1225 gen_functional_ops.to_bool, 1226 np.array([], dtype=dtype), 1227 expected=np.array(False)) 1228 1229 self._assertOpOutputMatchesExpected( 1230 gen_functional_ops.to_bool, 1231 np.array([1, 2, 3], dtype=dtype), 1232 expected=np.array(True)) 1233 1234 1235if __name__ == "__main__": 1236 googletest.main() 1237