# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/control_flow/py_func_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for py_func op."""

import gc
import queue
import re

import numpy as np

from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


def np_func(x, y):
  """Returns `np.sinh(x) + np.cosh(y)`.

  Serves both as the Python function wrapped by `py_func` in the tests and,
  called directly, as the source of the expected values.
  """
  return np.sinh(x) + np.cosh(y)


def matmul(x, y):
  """Returns `math_ops.matmul(x, y)`; used as an `eager_py_func` body."""
  return math_ops.matmul(x, y)


class PyFuncTestBase(test.TestCase):
  """Shared helpers for the py_func / eager_py_func test cases."""

  def verifyExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that a Python exception raised in a py_func surfaces as `tf_exp`.

    The wrapped function raises `py_exp("blah")` from a nested call. The
    resulting error message must contain the original message and a stack
    trace that names both the outer and the inner Python frame.

    Args:
      py_exp: Python exception class raised inside the wrapped function.
      tf_exp: TensorFlow error class the failure is expected to map to.
      eager: If True, wrap with `script_ops.eager_py_func`; otherwise wrap
        with `script_ops.py_func`.
    """

    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    expected_regexp = r": blah.*"               # Error at the top
    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
    expected_regexp += r": blah"                # Stacktrace of raise

    def expected_error_check(exception):
      return re.search(expected_regexp, str(exception), re.DOTALL)

    if eager:
      if context.executing_eagerly():
        # When executing eagerly the error is expected from the
        # eager_py_func call itself; there is nothing left to evaluate.
        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
          script_ops.eager_py_func(raise_exception, [], [])
        return
      f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    # In graph mode the failure is expected when the op is evaluated.
    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
      self.evaluate(f)


class PyFuncTest(PyFuncTestBase):
  """Encapsulates tests for py_func only."""

  def testRealDataTypes(self):
    def sum_func(x, y):
      return x + y
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                  dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x = constant_op.constant(1, dtype=dtype)
        y = constant_op.constant(2, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
        self.assertEqual(z, 3)

  def testComplexDataTypes(self):
    def sub_func(x, y):
      return x - y
    for dtype in [dtypes.complex64, dtypes.complex128]:
      with self.cached_session():
        x = constant_op.constant(1 + 1j, dtype=dtype)
        y = constant_op.constant(2 - 2j, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
        self.assertEqual(z, -1 + 3j)

  def testBoolDataTypes(self):
    def and_func(x, y):
      return x and y
    dtype = dtypes.bool
    with self.cached_session():
      x = constant_op.constant(True, dtype=dtype)
      y = constant_op.constant(False, dtype=dtype)
      z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
      self.assertEqual(z, False)

  def testSingleType(self):
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
      self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))

  def testScalar(self):
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(
          script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
      self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))

  @test_util.run_v1_only("b/120545219")
  def testArray(self):
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0], dtypes.float64)
      y = constant_op.constant([2.0, 3.0], dtypes.float64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
      self.assertAllEqual(z[0],
                          np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

  def testComplexType(self):
    with self.cached_session():
      x = constant_op.constant(1 + 2j, dtypes.complex64)
      y = constant_op.constant(3 + 4j, dtypes.complex64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
      self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))

  def testRFFT(self):
    with self.cached_session():
      x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)

      def rfft(x):
        return np.fft.rfft(x).astype(np.complex64)

      y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
      self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))

  def testPythonLiteral(self):
    with self.cached_session():

      # Returns a plain Python float rather than a numpy value.
      def literal(x):
        return 1.0 if float(x) == 0.0 else 0.0

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
      self.assertAllClose(y, 1.0)

  def testList(self):
    with self.cached_session():

      def list_func(x):
        return [x, x + 1]

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

  def testTuple(self):

    # Shared by both sub-cases below.
    def tuple_func(x):
      return x, x + 1

    # returns a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

    # returns a tuple, Tout and inp a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, (x,),
                             (dtypes.float64, dtypes.float64)))
      self.assertAllClose(y, [0.0, 1.0])

  @test_util.run_v1_only("b/120545219")
  def testStrings(self):

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant([b"hello", b"hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testStringsAreConvertedToBytes(self):

    # Same as testStrings, but the Python side produces `str`, not `bytes`;
    # the outputs are still expected as bytes.
    def read_fixed_length_numpy_strings():
      return np.array([" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testObjectArraysAreConvertedToBytes(self):

    # An object array mixing bytes and str elements.
    def read_object_array():
      return np.array([b" there", u" ya"], dtype=np.object_)

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y, = script_ops.py_func(read_object_array, [],
                              [dtypes.string])
      z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
      self.assertListEqual(list(self.evaluate(z)), [b"hello there", b"hi ya"])

  @test_util.run_v1_only("b/120545219")
  def testStringPadding(self):
    # The strings have different lengths; each must come back unpadded.
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testStringPaddingAreConvertedToBytes(self):
    inp = ["this", "is", "a", "test"]
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testNulTerminatedStrings(self):
    # Trailing NUL characters are expected to be dropped from the output.
    inp = np.array(["this\0", "is\0\0", "a\0", "test\0\0"], dtype=np.str_)
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    with self.cached_session() as sess:
      x = array_ops.zeros([1000000], dtype=np.float32)
      y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
      z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
      for _ in range(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    with self.cached_session():
      x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
      self.assertAllClose(x, 42.0)

  @test_util.run_v1_only("b/120545219")
  def testAlias(self):
    # Running an op on the py_func output must not modify the numpy array
    # the Python function returned.
    with self.cached_session():
      np_array = np.array([1.0, 2.0], dtype=np.float32)
      tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
      value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
      value.op.run()
      self.assertAllEqual(np_array, [1.0, 2.0])

  @test_util.run_v1_only("b/120545219")
  def testReturnUnicodeString(self):
    with self.cached_session():
      correct = u"你好 世界"

      def unicode_string():
        return correct

      z, = script_ops.py_func(unicode_string, [], [dtypes.string])
      self.assertEqual(self.evaluate(z), correct.encode("utf8"))

  @test_util.run_v1_only("b/120545219")
  def testBadNumpyReturnType(self):
    with self.cached_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = script_ops.py_func(bad, [], [dtypes.float32])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported numpy data type"):
        self.evaluate(y)

  @test_util.run_v1_only("b/120545219")
  def testBadReturnType(self):
    with self.cached_session():

      def bad():
        # Non-string python objects aren't supported.
        return {"foo": dtypes.float32}

      z, = script_ops.py_func(bad, [], [dtypes.int64])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported object type"):
        self.evaluate(z)

  @test_util.run_v1_only("b/120545219")
  def testReturnInput(self):
    with self.cached_session():

      def ident(x):
        return x[0]

      p = array_ops.placeholder(dtypes.float32)

      # Create a numpy array aliasing a tensor and a tensor aliasing this array
      z, = script_ops.py_func(ident, [p], [dtypes.float32])
      z += 0.0  # Makes sure we release the tensor aliasing the numpy array x[0]
      # above instead of using its memory as the return value of
      # session.run
      self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))

  def testStateful(self):
    # Not using self.cached_session(), which disables optimization.
    # A stateful py_func is re-run on every evaluation, so the generator
    # advances each time.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 1)
      self.assertEqual(self.evaluate(x), 2)

  @test_util.enable_tf_xla_constant_folding("b/134376434")
  def testStateless(self):
    # Not using self.cached_session(), which disables optimization.
    # With stateful=False the op is expected to run only once, so every
    # evaluation sees the first value.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(
          lambda: next(producer), [], [dtypes.int64], stateful=False)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)

  @test_util.run_v1_only("b/120545219")
  def testGradientFunction(self):
    # Input to tf.compat.v1.py_func is necessary,
    # otherwise get_gradient_function() returns None per default.
    a = constant_op.constant(0)
    x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64])
    y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False)
    self.assertEqual(None, ops.get_gradient_function(x.op))
    self.assertEqual(None, ops.get_gradient_function(y.op))

  @test_util.run_v1_only("b/120545219")
  def testCOrder(self):
    # A Fortran-ordered array must come back with the same logical elements.
    with self.cached_session():
      val = [[1, 2], [3, 4]]
      x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                              [dtypes.int64])
      self.assertAllEqual(val, self.evaluate(x))

  @test_util.run_v1_only("b/120545219")
  def testParallel(self):
    # Tests that tf.compat.v1.py_func's can run in parallel if they release
    # the GIL.
    with self.cached_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = script_ops.py_func(blocking_put, [], [dtypes.int64])
      y, = script_ops.py_func(blocking_get, [], [dtypes.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])

  def testNoReturnValueStateful(self):

    class State:

      def __init__(self):
        self._value = np.array([1], np.int64)

      def _increment(self, diff):
        self._value += diff

      def increment(self, diff):
        return script_ops.py_func(self._increment, [diff], [], stateful=True)

      @property
      def value(self):
        return self._value

    with self.cached_session():
      s = State()
      op = s.increment(constant_op.constant(2, dtypes.int64))
      ret = self.evaluate(op)
      self.assertIsNone(ret)
      self.assertAllEqual([3], s.value)

  @test_util.run_v1_only("b/120545219")
  def testNoReturnValueStateless(self):

    def do_nothing(unused_x):
      pass

    f = script_ops.py_func(
        do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
    with self.cached_session():
      self.assertEqual(self.evaluate(f), [])

  @test_util.run_v1_only("b/120545219")
  def testExceptionHandling(self):
    with self.cached_session():
      self.verifyExceptionHandling(ValueError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(TypeError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(StopIteration, errors.OutOfRangeError)
      self.verifyExceptionHandling(MemoryError, errors.ResourceExhaustedError)
      self.verifyExceptionHandling(NotImplementedError,
                                   errors.UnimplementedError)

      class WeirdError(Exception):
        pass

      # Exception types without a dedicated mapping fall back to UnknownError.
      self.verifyExceptionHandling(WeirdError, errors.UnknownError)

  def testFunctionReferencesAreKept(self):
    g = ops.Graph()
    with g.as_default():
      c = constant_op.constant([1.], dtypes.float32)
      @batch_ops.batch_function(1, 10, 100000)
      def fn(x):
        # Upon exiting this function, the py_func holds the sole reference
        # to this lambda, without which it would be garbage collected.
        return script_ops.py_func(lambda x: x, [x], [dtypes.float32])
      result = fn(c)
      gc.collect()
      self.evaluate(result)


class PyFuncAndEagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests shared between py_func and eager_py_func."""

  def verifyPyFuncsNoIncrease(self, make_graph):
    """Asserts that repeatedly building and dropping graphs leaks no py_funcs.

    Args:
      make_graph: Zero-argument callable that builds a graph containing
        py_func / eager_py_func ops. It is called 1000 times.
    """
    ops.reset_default_graph()
    gc.collect()
    initial_size = script_ops._py_funcs.size()

    for _ in range(1000):
      make_graph()

    ops.reset_default_graph()
    gc.collect()
    # Entries added for the discarded graphs must have been released.
    self.assertEqual(initial_size, script_ops._py_funcs.size())

  def testCleanup(self):

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        c = constant_op.constant([1.], dtypes.float32)
        _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
        _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
        # These ops have a reference to 'c' which has a reference to the
        # graph.
        # Checks if the functions are being deleted though the graph is
        # referenced from them (see #18292).
        script_ops.py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])
        script_ops.eager_py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])

    self.verifyPyFuncsNoIncrease(make_graph)

  def testCleanupInTfFunction(self):

    # Skipped unconditionally (b/144098211); the code below is currently
    # unreachable and kept for when the bug is fixed.
    self.skipTest("b/144098211")

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        @def_function.function
        def fn():
          c = constant_op.constant([1.], dtypes.float32)
          _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
          _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
          # These ops have a reference to 'c' which has a reference to the
          # graph.
          # Checks if the functions are being deleted though the graph is
          # referenced from them (see #18292).
          script_ops.py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
          script_ops.eager_py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
        fn()

    self.verifyPyFuncsNoIncrease(make_graph)


class EagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests for eager_py_func only."""

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputInt32(self):
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
    ret = self.evaluate(output)
    self.assertAllEqual(ret, [[3], [3], [3]])

  @test_util.run_in_graph_and_eager_modes
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    with ops.device("/job:worker/task:0/cpu:0"):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
    # Use the session as a context manager so it is closed after the run
    # instead of being leaked.
    with session_lib.Session(worker.target) as session:
      ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputFloat32(self):
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
      ret = self.evaluate(output)
      self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerArrayOutput(self):
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(
          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerReturnNone(self):
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      # Eager evaluation yields an empty sequence; graph mode yields None.
      if context.executing_eagerly():
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_tfrt("b/180469928")
  def testEagerPyFuncInDefun(self):
    with test_util.device(use_gpu=True):
      def wrapper():
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)

      wrapped = function.defun(wrapper)
      ret = self.evaluate(wrapped())
      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerExceptionHandling(self):
    with test_util.device(use_gpu=True):
      self.verifyExceptionHandling(
          ValueError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          TypeError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          StopIteration, errors.OutOfRangeError, eager=True)
      self.verifyExceptionHandling(
          MemoryError, errors.ResourceExhaustedError, eager=True)
      self.verifyExceptionHandling(
          NotImplementedError, errors.UnimplementedError, eager=True)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError, eager=True)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerReturningVariableRaisesError(self):
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    with self.assertRaisesRegex(errors.UnknownError,
                                "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)

  @test_util.run_in_graph_and_eager_modes
  def testTapeCache(self):
    # Testing for b/198962664 (gh:#51839)
    old_cache_size = len(script_ops.tape_cache)

    def f(x):
      return x**2

    x = constant_op.constant(3.0)

    # Called only to check that no cache entry is created; result unused.
    _ = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)

    # No cache if there is no active tape
    self.assertEqual(len(script_ops.tape_cache), old_cache_size)

    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
      # A new cache entry is created when running eagerly.
      if context.executing_eagerly():
        self.assertEqual(len(script_ops.tape_cache), old_cache_size + 1)
      else:
        self.assertEqual(len(script_ops.tape_cache), old_cache_size)
    dy_dx = tape.gradient(y, x)
    # Force a evaluation.
    self.evaluate(dy_dx)
    # Cache entry consumed after gradient calculation.
    self.assertEqual(len(script_ops.tape_cache), old_cache_size)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTape(self):

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = tape.gradient(y, x)
    self.assertAllClose(self.evaluate(dy_dx), 6.0)

    # Test complex values
    x = constant_op.constant(3.0 + 3.0j)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.complex128)
    dy_dx = tape.gradient(y, x)
    # Gradient of complex will be the conj
    self.assertAllClose(self.evaluate(dy_dx), 6.0 - 6.0j)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraph(self):

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphTwoOutputs(self):

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    # d(x*y + x/y)/dx = y + 1/y = 2.0 + 0.5
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 2.5)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTapeMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = tape.gradient(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphLogHuber(self):

    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    with self.cached_session() as sess:
      # Takes the first branch of log_huber.
      y_val, dy_dx_val = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y_val, 1.0)
      self.assertEqual(dy_dx_val, 2.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerRespectsDevicePlacementOfOp(self):

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    with self.session() as sess:
      output = sess.run(z, feed_dict={x: 3.0})
      self.assertEqual(output, 18.0)

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncOnGPUWithStrings(self):

    def fn(a):
      return str(a.dtype)

    x = constant_op.constant("x", dtype=dtypes.string)
    output = script_ops.eager_py_func(fn, inp=[x], Tout=dtypes.string)
    self.assertEqual(self.evaluate(output), "<dtype: 'string'>".encode("utf8"))

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncNotACallable(self):
    x = constant_op.constant("x", dtype=dtypes.string)

    with self.assertRaisesRegex(ValueError, "callable"):
      _ = script_ops.eager_py_func(x, inp=[x], Tout=dtypes.string)

  def testUnsupportedToutType(self):
    with self.assertRaisesRegex(
        TypeError, "Cannot convert .* to a TensorFlow DType."):
      script_ops.eager_py_func(lambda x: x, [1], [{}])

  def testRaggedTensorArg(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    y, = script_ops.eager_py_func(math_ops.reduce_sum, [x], [dtypes.int32])
    self.assertAllEqual(y, 21)

  def testRaggedTensorReturn(self):

    def fn(v, l):
      return ragged_tensor.RaggedTensor.from_row_lengths(v, l)

    values = [1, 2, 3, 4, 5, 6]
    lengths = constant_op.constant([3, 1, 2], dtypes.int64)
    out_signature = [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]
    y, = script_ops.eager_py_func(fn, [values, lengths], out_signature)
    self.assertIsInstance(y, ragged_tensor.RaggedTensor)
    self.assertAllEqual(y, [[1, 2, 3], [4], [5, 6]])

  def testRaggedTensorBroadcast(self):
    # Check that eager_py_func preserves output shape information, which is
    # required by broadcasting.
    def fn(x):
      return 2 * x

    def foo(x):
      spec = ragged_tensor.RaggedTensorSpec.from_value(x)
      res = script_ops.eager_py_func(fn, [x], spec)
      return x + res

    x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
    expected_result = [[3.0, 6.0], [9.0]]
    result1 = foo(x)
    self.assertAllEqual(result1, expected_result)
    result2 = def_function.function(foo)(x)
    self.assertAllEqual(result2, expected_result)

  def testRaggedExpectedListGotList(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: [v], [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedExpectedListGotTuple(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: (v,), [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedExpectedListGotSingleValue(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    x_spec = type_spec.type_spec_from_value(x)
    y, = script_ops.eager_py_func(lambda v: v, [x], [x_spec])
    self.assertAllEqual(y, x)

  def testRaggedNoReturnValue(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6]])
    result = self.evaluate(script_ops.eager_py_func(lambda v: None, [x], []))
    if context.executing_eagerly():
      self.assertEqual(result, [])
    else:
      self.assertIsNone(result)

  def testRaggedBadReturnTypeExpectedTensorReturnedRagged(self):
    rt = ragged_factory_ops.constant([[1, 2], [3, 4, 5]])
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        "py_function: func=.* returned .* which did not match Tout=.*"):
      result = script_ops.eager_py_func(lambda x: x + 3, [rt], [dtypes.int32])
      self.evaluate(result)

  def testRaggedBadReturnTypeExpectedRaggedReturnedTensor(self):
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        "py_function: func=.* returned .* which did not match Tout=.*"):
      result = script_ops.eager_py_func(
          func=lambda x: x,
          inp=[constant_op.constant([[1, 2, 3]])],
          Tout=[ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)])
      self.evaluate(result)


if __name__ == "__main__":
  test.main()