# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for py_func op.

Covers `script_ops.py_func` (PyFuncTest), `script_ops.eager_py_func`
(EagerPyFuncTest), and behavior shared by both, such as releasing registered
Python callables when graphs are discarded (PyFuncAndEagerPyFuncTest).
"""

import gc
import queue
import re

import numpy as np

from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


def np_func(x, y):
  """Returns `sinh(x) + cosh(y)` computed with NumPy; used as a py_func body."""
  return np.sinh(x) + np.cosh(y)


def matmul(x, y):
  """Returns `math_ops.matmul(x, y)`; used as an eager_py_func body."""
  return math_ops.matmul(x, y)


class PyFuncTestBase(test.TestCase):
  """Base test case providing the shared exception-mapping check helper."""

  def verifyExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that a Python exception raised inside a py_func surfaces as tf_exp.

    The wrapped callable raises `py_exp("blah")` from a nested helper. The
    string form of the resulting error must contain, in this order: the
    message, the frame names `raise_exception` and `inner_exception`, and the
    message again.

    Args:
      py_exp: Python exception class raised inside the wrapped callable.
      tf_exp: Exception type TensorFlow is expected to raise.
      eager: If True, wrap the callable with `script_ops.eager_py_func`;
        otherwise use `script_ops.py_func`. When True and executing eagerly,
        the error is expected from the `eager_py_func` call itself rather
        than from a later `self.evaluate`.
    """

    # NOTE: the names of these two helpers are matched by `expected_regexp`
    # below; renaming either one breaks the check.
    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    expected_regexp = r": blah.*"               # Error at the top
    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
    expected_regexp += r": blah"                # Stacktrace of raise
    def expected_error_check(exception):
      # DOTALL lets `.*` span the newlines of the embedded stack trace.
      return re.search(expected_regexp, str(exception), re.DOTALL)

    if eager:
      if context.executing_eagerly():
        # Eagerly, the callable runs immediately, so the error is raised by
        # the eager_py_func call and there is nothing left to evaluate.
        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
          f = script_ops.eager_py_func(raise_exception, [], [])
        return
      else:
        f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    # Graph construction above succeeded; the error must surface on execution.
    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
      self.evaluate(f)


class PyFuncTest(PyFuncTestBase):
  """Encapsulates tests for py_func only."""

  def testRealDataTypes(self):
    def sum_func(x, y):
      return x + y
    # Round-trip 1 + 2 through py_func for every real dtype listed below.
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                  dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x = constant_op.constant(1, dtype=dtype)
        y = constant_op.constant(2, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
        self.assertEqual(z, 3)

  def testComplexDataTypes(self):
    def sub_func(x, y):
      return x - y
    # (1 + 1j) - (2 - 2j) == -1 + 3j for each complex dtype.
    for dtype in [dtypes.complex64, dtypes.complex128]:
      with self.cached_session():
        x = constant_op.constant(1 + 1j, dtype=dtype)
        y = constant_op.constant(2 - 2j, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
        self.assertEqual(z, -1 + 3j)

  def testBoolDataTypes(self):
    def and_func(x, y):
      return x and y
    dtype = dtypes.bool
    with self.cached_session():
      x = constant_op.constant(True, dtype=dtype)
      y = constant_op.constant(False, dtype=dtype)
      z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
      self.assertEqual(z, False)

  def testSingleType(self):
    # Tout given as a single dtype (not a list).
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
      self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))

  def testScalar(self):
    # NOTE(review): despite living in PyFuncTest, this exercises
    # eager_py_func, not py_func.
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(
          script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
      self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))

  @test_util.run_v1_only("b/120545219")
  def testArray(self):
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0], dtypes.float64)
      y = constant_op.constant([2.0, 3.0], dtypes.float64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
      self.assertAllEqual(z[0],
                          np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

  def testComplexType(self):
    with self.cached_session():
      x = constant_op.constant(1 + 2j, dtypes.complex64)
      y = constant_op.constant(3 + 4j, dtypes.complex64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
      self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))

  def testRFFT(self):
    with self.cached_session():
      x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)

      def rfft(x):
        # Cast so the result matches the declared Tout (complex64).
        return np.fft.rfft(x).astype(np.complex64)

      y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
      self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))

  def testPythonLiteral(self):
    # The callable returns a plain Python float rather than a NumPy value.
    with self.cached_session():

      def literal(x):
        return 1.0 if float(x) == 0.0 else 0.0

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
      self.assertAllClose(y, 1.0)

  def testList(self):
    with self.cached_session():

      def list_func(x):
        return [x, x + 1]

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

  def testTuple(self):
    # returns a tuple
    with self.cached_session():

      def tuple_func(x):
        return x, x + 1

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

    # returns a tuple, Tout and inp a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, (x,),
                             (dtypes.float64, dtypes.float64)))
      self.assertAllClose(y, [0.0, 1.0])

  @test_util.run_v1_only("b/120545219")
  def testStrings(self):

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant([b"hello", b"hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      # The 1-element `y` is combined element-wise with the 2-element `x`.
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testStringsAreConvertedToBytes(self):

    def read_fixed_length_numpy_strings():
      # Returns a unicode (str) array; the op output is expected as bytes.
      return np.array([" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testObjectArraysAreConvertedToBytes(self):

    def read_object_array():
      # Mixed bytes / unicode elements in an object array.
      return np.array([b" there", u" ya"], dtype=np.object_)

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y, = script_ops.py_func(read_object_array, [],
                              [dtypes.string])
      z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
      self.assertListEqual(list(self.evaluate(z)), [b"hello there", b"hi ya"])

  @test_util.run_v1_only("b/120545219")
  def testStringPadding(self):
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testStringPaddingAreConvertedToBytes(self):
    inp = ["this", "is", "a", "test"]
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testNulTerminatedStrings(self):
    # Trailing NUL characters are expected to be absent from the output.
    inp = np.array(["this\0", "is\0\0", "a\0", "test\0\0"], dtype=np.str_)
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    # Runs two py_funcs over a 1M-element input 100 times. There is no
    # assertion: the test only checks that every run completes without error.
    with self.cached_session() as sess:
      x = array_ops.zeros([1000000], dtype=np.float32)
      y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
      z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
      for _ in range(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    with self.cached_session():
      x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
      self.assertAllClose(x, 42.0)

  @test_util.run_v1_only("b/120545219")
  def testAlias(self):
    # Consuming the py_func output in another op must leave the NumPy array
    # returned by the callable unmodified.
    with self.cached_session():
      np_array = np.array([1.0, 2.0], dtype=np.float32)
      tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
      value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
      value.op.run()
      self.assertAllEqual(np_array, [1.0, 2.0])

  @test_util.run_v1_only("b/120545219")
  def testReturnUnicodeString(self):
    # A returned Python str must come back as its UTF-8 encoded bytes.
    with self.cached_session():
      correct = u"你好 世界"

      def unicode_string():
        return correct

      z, = script_ops.py_func(unicode_string, [], [dtypes.string])
      self.assertEqual(self.evaluate(z), correct.encode("utf8"))

  @test_util.run_v1_only("b/120545219")
  def testBadNumpyReturnType(self):
    with self.cached_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = script_ops.py_func(bad, [], [dtypes.float32])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported numpy data type"):
        self.evaluate(y)

  @test_util.run_v1_only("b/120545219")
  def testBadReturnType(self):
    with self.cached_session():

      def bad():
        # Non-string python objects aren't supported.
        return {"foo": dtypes.float32}

      z, = script_ops.py_func(bad, [], [dtypes.int64])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported object type"):
        self.evaluate(z)

  @test_util.run_v1_only("b/120545219")
  def testReturnInput(self):
    with self.cached_session():

      def ident(x):
        return x[0]

      p = array_ops.placeholder(dtypes.float32)

      # Create a numpy array aliasing a tensor and a tensor aliasing this array
      z, = script_ops.py_func(ident, [p], [dtypes.float32])
      z += 0.0  # Makes sure we release the tensor aliasing the numpy array x[0]
      # above instead of using its memory as the return value of
      # session.run
      self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))

  def testStateful(self):
    # Not using self.cached_session(), which disables optimization.
    # A stateful py_func re-runs the callable on every evaluation, so the
    # iterator advances each time.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 1)
      self.assertEqual(self.evaluate(x), 2)

  @test_util.enable_tf_xla_constant_folding("b/134376434")
  def testStateless(self):
    # Not using self.cached_session(), which disables optimization.
    # With stateful=False the assertions expect the first result (0) on every
    # evaluation, i.e. `producer` is not advanced again. Presumably the op is
    # constant-folded once optimization is enabled; confirm if this changes.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(
          lambda: next(producer), [], [dtypes.int64], stateful=False)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)

  @test_util.run_v1_only("b/120545219")
  def testGradientFunction(self):
    # Input to tf.compat.v1.py_func is necessary,
    # otherwise get_gradient_function() returns None per default.
    # Both stateful and stateless py_funcs must report no gradient function.
    a = constant_op.constant(0)
    x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64])
    y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False)
    self.assertEqual(None, ops.get_gradient_function(x.op))
    self.assertEqual(None, ops.get_gradient_function(y.op))

  @test_util.run_v1_only("b/120545219")
  def testCOrder(self):
    # The callable returns a Fortran-ordered array; the op output must still
    # match `val` element-wise.
    with self.cached_session():
      val = [[1, 2], [3, 4]]
      x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                              [dtypes.int64])
      self.assertAllEqual(val, self.evaluate(x))

  @test_util.run_v1_only("b/120545219")
  def testParallel(self):
    # Tests that tf.compat.v1.py_func's can run in parallel if they release
    # the GIL.
    with self.cached_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = script_ops.py_func(blocking_put, [], [dtypes.int64])
      y, = script_ops.py_func(blocking_get, [], [dtypes.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])

  def testNoReturnValueStateful(self):

    # Holds a NumPy counter that the py_func mutates in place; the op
    # declares no outputs (Tout=[]).
    class State:

      def __init__(self):
        self._value = np.array([1], np.int64)

      def _increment(self, diff):
        self._value += diff

      def increment(self, diff):
        return script_ops.py_func(self._increment, [diff], [], stateful=True)

      @property
      def value(self):
        return self._value

    with self.cached_session():
      s = State()
      op = s.increment(constant_op.constant(2, dtypes.int64))
      ret = self.evaluate(op)
      # The op has no outputs; its effect is visible only through `s.value`.
      self.assertIsNone(ret)
      self.assertAllEqual([3], s.value)

  @test_util.run_v1_only("b/120545219")
  def testNoReturnValueStateless(self):

    def do_nothing(unused_x):
      pass

    f = script_ops.py_func(
        do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
    with self.cached_session():
      self.assertEqual(self.evaluate(f), [])

  @test_util.run_v1_only("b/120545219")
  def testExceptionHandling(self):
    # Maps each Python exception type to the TensorFlow error it must
    # surface as.
    with self.cached_session():
      self.verifyExceptionHandling(ValueError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(TypeError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(StopIteration, errors.OutOfRangeError)
      self.verifyExceptionHandling(MemoryError, errors.ResourceExhaustedError)
      self.verifyExceptionHandling(NotImplementedError,
                                   errors.UnimplementedError)

      class WeirdError(Exception):
        pass

      # Exception types without a specific mapping become UnknownError.
      self.verifyExceptionHandling(WeirdError, errors.UnknownError)

  def testFunctionReferencesAreKept(self):
    g = ops.Graph()
    with g.as_default():
      c = constant_op.constant([1.], dtypes.float32)
      @batch_ops.batch_function(1, 10, 100000)
      def fn(x):
        # Upon exiting this function, the py_func holds the sole reference
        # to this lambda, without which it would be garbage collected.
        return script_ops.py_func(lambda x: x, [x], [dtypes.float32])
      result = fn(c)
      # Collect garbage before running the op, so the lambda must survive
      # collection through the py_func's reference alone.
      gc.collect()
      self.evaluate(result)


class PyFuncAndEagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests shared between py_func and eager_py_func."""

  def verifyPyFuncsNoIncrease(self, make_graph):
    """Asserts that calling `make_graph` repeatedly leaves no py_func entries.

    Resets the default graph and runs `gc.collect()` both before and after
    calling `make_graph` 1000 times, then requires `script_ops._py_funcs.size()`
    to equal its initial value.

    Args:
      make_graph: Zero-argument callable. In this file it builds a fresh
        graph containing py_func / eager_py_func ops.
    """
    ops.reset_default_graph()
    gc.collect()
    initial_size = script_ops._py_funcs.size()

    for _ in range(1000):
      make_graph()

    ops.reset_default_graph()
    gc.collect()
    self.assertEqual(initial_size, script_ops._py_funcs.size())

  def testCleanup(self):

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        c = constant_op.constant([1.], dtypes.float32)
        _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
        _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
        # These ops have a reference to 'c' which has a reference to the
        # graph.
        # Checks if the functions are being deleted even though the graph is
        # referenced from them (see #18292).
        script_ops.py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])
        script_ops.eager_py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])

    self.verifyPyFuncsNoIncrease(make_graph)

  def testCleanupInTfFunction(self):

    # Skipped (b/144098211): nothing below this line currently runs.
    self.skipTest("b/144098211")

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        @def_function.function
        def fn():
          c = constant_op.constant([1.], dtypes.float32)
          _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
          _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
          # These ops have a reference to 'c' which has a reference to the
          # graph.
          # Checks if the functions are being deleted even though the graph
          # is referenced from them (see #18292).
          script_ops.py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
          script_ops.eager_py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
        fn()

    self.verifyPyFuncsNoIncrease(make_graph)


class EagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests for eager_py_func only."""

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputInt32(self):
    # ones(3, 3) @ ones(3, 1) == [[3], [3], [3]].
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
    ret = self.evaluate(output)
    self.assertAllEqual(ret, [[3], [3], [3]])

  @test_util.run_in_graph_and_eager_modes
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    # Graph mode only: place the op on the worker job of a local cluster and
    # run it through a session connected to that worker.
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    session = session_lib.Session(worker.target)
    with ops.device("/job:worker/task:0/cpu:0"):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
    ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputFloat32(self):
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
      ret = self.evaluate(output)
      self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerArrayOutput(self):
    # Tout is a list, so the evaluated output is a list holding one matrix.
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(
          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerReturnNone(self):
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      # With no outputs the asserted results differ by mode: an empty
      # sequence eagerly, None in graph mode.
      if context.executing_eagerly():
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_tfrt("b/180469928")
  def testEagerPyFuncInDefun(self):
    # eager_py_func must also work when traced inside a defun.
    with test_util.device(use_gpu=True):
      def wrapper():
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)

      wrapped = function.defun(wrapper)
      ret = self.evaluate(wrapped())
      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerExceptionHandling(self):
    # Same exception mapping as PyFuncTest.testExceptionHandling, via
    # eager_py_func.
    with test_util.device(use_gpu=True):
      self.verifyExceptionHandling(
          ValueError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          TypeError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          StopIteration, errors.OutOfRangeError, eager=True)
      self.verifyExceptionHandling(
          MemoryError, errors.ResourceExhaustedError, eager=True)
      self.verifyExceptionHandling(
          NotImplementedError, errors.UnimplementedError, eager=True)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError, eager=True)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerReturningVariableRaisesError(self):
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    # The error may come from the call itself (eager) or from evaluation
    # (graph), so both statements sit inside the assertion context.
    with self.assertRaisesRegex(errors.UnknownError,
                                "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)

  @test_util.run_in_graph_and_eager_modes
  def testTapeCache(self):
    # Testing for b/198962664 (gh:#51839)
    # Checks that `script_ops.tape_cache` gains an entry only for an eager
    # call under an active tape, and that the entry is gone after the
    # gradient is computed.
    old_cache_size = len(script_ops.tape_cache)

    def f(x):
      return x**2

    x = constant_op.constant(3.0)

    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)

    # No cache if there is no active tape
    self.assertEqual(len(script_ops.tape_cache), old_cache_size)

    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
      # A new cache entry is created when running eagerly.
      if context.executing_eagerly():
        self.assertEqual(len(script_ops.tape_cache), old_cache_size + 1)
      else:
        self.assertEqual(len(script_ops.tape_cache), old_cache_size)
    dy_dx = tape.gradient(y, x)
    # Force an evaluation of the gradient.
    self.evaluate(dy_dx)
    # Cache entry consumed after gradient calculation.
    self.assertEqual(len(script_ops.tape_cache), old_cache_size)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTape(self):

    def f(x):
      return x**2

    # d/dx x**2 = 2x = 6 at x = 3.
    x = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = tape.gradient(y, x)
    self.assertAllClose(self.evaluate(dy_dx), 6.0)

    # Test complex values
    x = constant_op.constant(3.0 + 3.0j)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.complex128)
    dy_dx = tape.gradient(y, x)
    # Gradient of complex will be the conj
    # (2x = 6 + 6j, so the expected value is 6 - 6j).
    self.assertAllClose(self.evaluate(dy_dx), 6.0 - 6.0j)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraph(self):

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphTwoOutputs(self):

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    # d/dx (x*y + x/y) = y + 1/y = 2 + 0.5 = 2.5 at y = 2.
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 2.5)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTapeMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    # dz/dx = 2x = 6, dz/dy = 2y = 8.
    dz_dx, dz_dy = tape.gradient(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphLogHuber(self):

    # Data-dependent Python control flow inside the callable.
    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    with self.cached_session() as sess:
      # Takes the first branch of log_huber.
      # (|x| = 1 <= m = 2, so y = x**2 = 1 and dy/dx = 2x = 2.)
      y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y, 1.0)
      self.assertEqual(dy_dx, 2.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerRespectsDevicePlacementOfOp(self):

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    with self.session() as sess:
      # g(f(3)) = 2 * 3**2 = 18.
      output = sess.run(z, feed_dict={x: 3.0})
      self.assertEqual(output, 18.0)

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncOnGPUWithStrings(self):
    # NOTE(review): despite the name, no GPU device scope is set in this test.

    def fn(a):
      return str(a.dtype)

    x = constant_op.constant("x", dtype=dtypes.string)
    output = script_ops.eager_py_func(fn, inp=[x], Tout=dtypes.string)
    self.assertEqual(self.evaluate(output), "<dtype: 'string'>".encode("utf8"))

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncNotACallable(self):
    # Passing a tensor as `func` must be rejected with a ValueError.
    x = constant_op.constant("x", dtype=dtypes.string)

    with self.assertRaisesRegex(ValueError, "callable"):
      _ = script_ops.eager_py_func(x, inp=[x], Tout=dtypes.string)

  def testUnsupportedToutType(self):
    # A dict is not a valid Tout entry.
    with self.assertRaisesRegex(
        TypeError, "Cannot convert .* to a TensorFlow DType."):
      script_ops.eager_py_func(lambda x: x, [1], [{}])

  def testRaggedTensorArg(self):
    # reduce_sum over all ragged values: 1 + 2 + 3 + 4 + 5 + 6 = 21.
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    y, = script_ops.eager_py_func(math_ops.reduce_sum, [x], [dtypes.int32])
    self.assertAllEqual(y, 21)

  def testRaggedTensorReturn(self):

    def fn(v, l):
      return ragged_tensor.RaggedTensor.from_row_lengths(v, l)

    values = [1, 2, 3, 4, 5, 6]
    lengths = constant_op.constant([3, 1, 2], dtypes.int64)
    # Tout given as a TypeSpec makes the output a RaggedTensor.
    out_signature = [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]
    y, = script_ops.eager_py_func(fn, [values, lengths], out_signature)
    self.assertIsInstance(y, ragged_tensor.RaggedTensor)
    self.assertAllEqual(y, [[1, 2, 3], [4], [5, 6]])

  def testRaggedTensorBroadcast(self):
    # Check that eager_py_func preserves output shape information, which is
    # required by broadcasting.
    def fn(x):
      return 2 * x

    def foo(x):
      spec = ragged_tensor.RaggedTensorSpec.from_value(x)
      res = script_ops.eager_py_func(fn, [x], spec)
      return x + res

    # x + 2 * x == 3 * x; checked both eagerly and inside a tf.function.
    x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
    expected_result = [[3.0, 6.0], [9.0]]
    result1 = foo(x)
    self.assertAllEqual(result1, expected_result)
    result2 = def_function.function(foo)(x)
    self.assertAllEqual(result2, expected_result)

  # The next three tests declare Tout as a one-element list of ragged specs
  # and vary how the callable packages its single return value.
  def testRaggedExpectedListGotList(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: [v], [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedExpectedListGotTuple(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: (v,), [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedExpectedListGotSingleValue(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: v, [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedNoReturnValue(self):
    # Ragged input with no outputs; the expected result differs by mode.
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    result = self.evaluate(script_ops.eager_py_func(lambda v: None, [x], []))
    if context.executing_eagerly():
      self.assertEqual(result, [])
    else:
      self.assertIsNone(result)

  def testRaggedBadReturnTypeExpectedTensorReturnedRagged(self):
    rt = ragged_factory_ops.constant([[1, 2], [3, 4, 5]])
    # The error may be raised at call time (ValueError) or at evaluation
    # time (InvalidArgumentError), so both types are accepted.
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        "py_function: func=.* returned .* which did not match Tout=.*"):
      result = script_ops.eager_py_func(lambda x: x + 3, [rt], [dtypes.int32])
      self.evaluate(result)

  def testRaggedBadReturnTypeExpectedRaggedReturnedTensor(self):
    # A dense result where a RaggedTensorSpec was declared must be rejected.
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        "py_function: func=.* returned .* which did not match Tout=.*"):
      result = script_ops.eager_py_func(
          func=lambda x: x,
          inp=[constant_op.constant([[1, 2, 3]])],
          Tout=[ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)])
      self.evaluate(result)


if __name__ == "__main__":
  test.main()