# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import collections
import copy
import functools
import itertools
import multiprocessing.pool
import os
import re
import sys
import time
import weakref

from absl.testing import parameterized
import numpy

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.lang import directives
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

try:
  import attr  # pylint:disable=g-import-not-at-top
except ImportError:
  attr = None


def total_function_cache(defined):
  return defined._list_all_concrete_functions()  # pylint: disable=protected-access
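

# Illustration (not executed): the cache grows by one concrete function per
# distinct input signature. For a hypothetical fresh tf.function `f`:
#
#   f(constant_op.constant(1.0))    # first trace: scalar float32
#   f(constant_op.constant([1.0]))  # second trace: rank-1 float32
#   assert len(total_function_cache(f)) == 2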


def _example_indexed_slices_with_dense_shape():
  return indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]), constant_op.constant([0, 1]),
      constant_op.constant([2]))


def _example_indexed_slices_without_dense_shape():
  return indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]), constant_op.constant([0, 1]))


def _spec_for_value(value):
  """Returns the (nested) TypeSpec for a value."""
  if nest.is_nested(value):
    return nest.map_structure(_spec_for_value, value)
  elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return type_spec.type_spec_from_value(value)
  else:
    return value
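

# Illustration (not executed): _spec_for_value maps each tensor inside a
# nested structure to its TypeSpec and passes other values through, e.g.
#
#   _spec_for_value({'t': constant_op.constant(1.0), 'n': 3})
#   ==> {'t': TensorSpec(shape=(), dtype=tf.float32), 'n': 3}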


# This dummy decorator imitates ordinary decorators built on tf_decorator.
def dummy_tf_decorator(method):

  def wrapper(*args, **kwargs):
    return method(*args, **kwargs)

  return tf_decorator.make_decorator(method, wrapper)
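

# Illustration (not executed): because dummy_tf_decorator is built with
# tf_decorator.make_decorator, introspection utilities can still see the
# wrapped function's signature, e.g.
#
#   @dummy_tf_decorator
#   def plus_one(x):
#     return x + 1
#
#   tf_inspect.getfullargspec(plus_one).args  # ==> ['x']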


# TODO(mdan): Organize these tests.
class FunctionTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 4 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

  def testBasic(self):
    matmul = def_function.function(math_ops.matmul)
    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq = matmul(t, t, transpose_a=True)
    sq2 = matmul(sq, t, transpose_a=True)
    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])

  def testPythonFunctionNotCallable(self):
    with self.assertRaisesRegex(TypeError, 'is not a callable object'):
      def_function.function(1)

  def testOnExitCallback(self):
    values = []
    def append_1():
      values.append(1)

    def append_2():
      values.append(2)

    def g(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_1)
      self.assertEqual(old_values, values)
      return x + 1

    tf_g = def_function.function(g)

    def f(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_2)
      self.assertEqual(old_values, values)
      return tf_g(x)

    tf_f = def_function.function(f)
    self.assertEmpty(values)
    tf_f(constant_op.constant(1.0))
    self.assertEqual(values, [1, 2])  # Once for g, once for f.
    tf_f(constant_op.constant([1.0]))  # force a retrace
    self.assertEqual(values, [1, 2, 1, 2])  # And again.

  def testCannotAddExitCallbackWhenNotInFunctionScope(self):
    with self.assertRaisesRegex(RuntimeError, 'when not building a function.'):
      ops.add_exit_callback_to_default_func_graph(lambda: None)

  def testVariable(self):
    v1 = variables.Variable(1.0)
    add = def_function.function(lambda x, v: x + v1 + v)
    v2 = variables.Variable(1.0)
    x = constant_op.constant(1.0)
    r = add(x, v2)
    self.assertEqual(3.0, self.evaluate(r))

  def testVariableOnly(self):
    v = variables.Variable(1.0)
    add = def_function.function(lambda x: x.assign_add(1.0))
    r1 = add(v)
    self.assertEqual(2.0, self.evaluate(r1))
    c = constant_op.constant(1.0)
    with self.assertRaisesRegex(AttributeError, 'no attribute'):
      add(c)

  def testVariableMultiFunction(self):
    @def_function.function
    def second(dup_var, dup_var_2, some_const):
      return dup_var + dup_var_2 + some_const

    @def_function.function
    def first(dup_var, some_const):
      return second(dup_var, dup_var, some_const)

    my_const = constant_op.constant(1)
    my_var = variables.Variable(2, dtype=dtypes.int32)
    self.assertEqual(second(my_var, my_var, my_const).numpy(), 5)
    self.assertEqual(first(my_var, my_const).numpy(), 5)

  @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
  def testPackedVariable(self):
    with ops.device('/cpu:0'):
      v0_0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/cpu:1'):
      v0_1 = resource_variable_ops.ResourceVariable(2.0)
      v1_0 = resource_variable_ops.ResourceVariable(3.0)
    with ops.device('/cpu:2'):
      v1_1 = resource_variable_ops.ResourceVariable(4.0)

    packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
    packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])

    # TODO(b/145922293): use ResourceVariable.assign_add and
    # ResourceVariable.read_value directly once we support packing multiple
    # ResourceVariables into one ResourceVariable.
    @def_function.function
    def read_var():
      resource_variable_ops.assign_add_variable_op(
          packed_var_0, constant_op.constant(5.0))
      resource_variable_ops.assign_add_variable_op(
          packed_var_1, constant_op.constant(6.0))
      with ops.device('/cpu:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
      with ops.device('/cpu:1'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
        read2 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)
      with ops.device('/cpu:2'):
        read3 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)

      return read0, read1, read2, read3

    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
    self.assertLen(arg_attrs, 2)
    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_0.device))
    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_1.device))

    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
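
    # Sketch of the mechanism exercised above (an assumption, mirroring the
    # assertions): a packed handle aggregates one resource handle per device,
    # and the device scope at each read site selects which component variable
    # the read resolves to.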

  def testImplementsAttributeBasic(self):
    v = def_function.function(
        experimental_implements='func')(lambda x, y: x + y)
    with context.graph_mode(), self.cached_session():
      a = array_ops.placeholder(dtypes.float32, ())
      b = array_ops.placeholder(dtypes.float32, ())
      v(a, b)
      gradients_impl.gradients(v(a, b), [a, b])
      fdefs = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(fdefs, 3)
      not_present = 0
      present = 0
      for f in fdefs:
        name = f.signature.name
        if 'forward' in name or 'backward' in name:
          not_present += 1
          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
        else:
          present += 1
          self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s,
                           'func'.encode('ascii'), f)
      self.assertEqual(not_present, 2, fdefs)
      self.assertEqual(present, 1, fdefs)

  def testImplementsAttributeAssertsOnSideInput(self):
    with context.graph_mode(), self.cached_session():
      z = array_ops.zeros(0)
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y + z)
      a = array_ops.ones((1,))
      b = array_ops.ones((1,))
      with self.assertRaisesRegex(AssertionError,
                                  'variables are always captured'):
        v(a, b)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertEmpty(functions)

  def testImplementsAttributeWorksWithGradientTape(self):
    add = lambda x, y: x + y ** 2
    add = def_function.function(experimental_implements='MyFunc')(add)
    x = variables.Variable(3.0)
    y = variables.Variable(2.0)

    with backprop.GradientTape() as tape:
      g = add(x, y)

    dg_dy, dg_dx = tape.gradient(g, [y, x])
    self.assertEqual(dg_dy.numpy(), 4.0)
    self.assertEqual(dg_dx.numpy(), 1.0)

  def testImplementsAttributeWorksOnVariables(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable((1.0,))
      b = variables.Variable((1.0,))
      r1 = v(a, b)
      _ = v(a, a)
      functions = ops.get_default_graph().as_graph_def().library.function
      # Verify that we created only one function
      self.assertLen(functions, 1)
      # Verify that eval() reads the current values.
      a.initializer.run()
      b.initializer.run()
      self.assertEqual(r1.eval(), 2)

      a.assign_add([1]).eval()
      self.assertEqual(r1.eval(), 3)

  def testImplementsAttributeWorksOnConstants(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, 2.)
      r2 = v(2., a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 1)
      self.assertLen(functions[0].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      self.assertEqual(r1.eval(), 3)
      self.assertEqual(r2.eval(), 3)

  def testImplementsAttributeSpecializes(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, [2.])
      r2 = v([2., 2], a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 2)
      # Ensure that all parameters are still there and haven't been inlined!

      self.assertLen(functions[0].signature.input_arg, 2)
      self.assertLen(functions[1].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      numpy.testing.assert_equal(r1.eval(), [3.])
      numpy.testing.assert_equal(r2.eval(), [3., 3.])

  def testImplementsWorksWithTensorSpec(self):
    v = def_function.function(
        experimental_implements='func')(lambda x, y: x + y)
    v = v.get_concrete_function(
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32))
    x = v(1., 2.)
    self.assertEqual(x.numpy(), 3.)

  def testImplementsAttributeAsNameAttrList(self):
    implements_attr = (
        'name: "embedding_matmul" attr {   key: "key1"   value {     i: 2   } '
        '} attr {   key: "key2"   value {     b: false   } }')
    v = def_function.function(
        experimental_implements=implements_attr)(lambda x, y: x + y)
    with context.graph_mode(), self.cached_session():
      a = array_ops.placeholder(dtypes.float32, ())
      b = array_ops.placeholder(dtypes.float32, ())
      v(a, b)
      gradients_impl.gradients(v(a, b), [a, b])
      fdefs = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(fdefs, 3)
      not_present = 0
      present = 0
      for f in fdefs:
        name = f.signature.name
        if 'forward' in name or 'backward' in name:
          not_present += 1
          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
        else:
          present += 1
          attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME]
          self.assertIsNotNone(attr_value.func, f)
          self.assertEqual(attr_value.func.name, 'embedding_matmul')
          name_attrs = attr_value.func.attr
          self.assertLen(name_attrs, 2)
      self.assertEqual(not_present, 2, fdefs)
      self.assertEqual(present, 1, fdefs)

  def testExternalControlDependency(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.0)
      v.initializer.run()

      op = v.assign_add(1.0)

      @function.defun
      def f():
        with ops.control_dependencies([op]):
          return 1.0

      self.evaluate(f())
      self.assertAllEqual(self.evaluate(v), 2.0)

  def testInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun(reduce_retracing=True)
    def func(a):
      if a._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return a + 1

    func(constant_op.constant([]))
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(constant_op.constant([1.0]))
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(constant_op.constant([1.0, 2.0]))
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)
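
    # Summary of the relaxation behavior asserted above:
    #
    #   func(constant_op.constant([]))          # traces TensorSpec([0])
    #   func(constant_op.constant([1.0]))       # retraces as TensorSpec([None])
    #   func(constant_op.constant([1.0, 2.0]))  # reuses the relaxed trace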

  def testInputShapeRelaxationOnInstanceMethod(self):
    # Test that reduce_retracing is passed through when the
    # instance method is bound.
    unknown_dim = [False]

    class Foo:

      @def_function.function(reduce_retracing=True)
      def func(self, a):
        if a._shape_tuple()[0] is None:
          unknown_dim[0] = True
        return a + 1

    foo = Foo()
    foo.func(constant_op.constant([]))
    self.assertFalse(unknown_dim[0])

    foo.func(constant_op.constant([1.0]))
    self.assertTrue(unknown_dim[0])

    foo.func(constant_op.constant([1.0, 2.0]))
    self.assertTrue(unknown_dim[0])

  def testInputShapeFunctionRelaxationWithRaggedTensors(self):
    traced_type_spec = [None]

    @def_function.function(reduce_retracing=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    check_trace(  # Initial call gets traced.
        ragged_factory_ops.constant([[1], [2, 3, 4]]),
        ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
    check_trace(  # Input TypeSpec is the same -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
    check_trace(  # Even if component tensor shapes change -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
    check_trace(  # Different TypeSpec shape (nrows): relax & retrace
        ragged_factory_ops.constant([[1], [2], [3]]),
        ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
    check_trace(  # Different nrows again: no retrace (relaxed shape matches)
        ragged_factory_ops.constant([[1], [2], [3], [4]]), None)
    check_trace(  # Different nrows yet again: no retrace
        ragged_factory_ops.constant([[1]]), None)
    check_trace(  # Different ragged_rank: retrace
        ragged_factory_ops.constant([[[1]]]),
        ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
    check_trace(  # Different ragged_rank again: retrace & relax
        ragged_factory_ops.constant([[[1]], [[2]]]),
        ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))

  def testInputShapeFunctionRelaxationWithStructuredTensors(self):
    traced_type_spec = [None]

    @def_function.function(reduce_retracing=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    # If we have TypeSpecs that differ in ways other than just their shape,
    # then retrace each time.
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
        structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
            fields={'a': tensor_spec.TensorSpec((1,), dtypes.int32)},
            rank=0))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
        structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
            fields={'b': tensor_spec.TensorSpec((1,), dtypes.int32)},
            rank=0))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
        structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
            fields={'c': tensor_spec.TensorSpec((1,), dtypes.int32)},
            rank=0))

    # But if we call again with only the shape different, then relax:
    check_trace(  # relax & retrace
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
        structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
            fields={'a': tensor_spec.TensorSpec((None,), dtypes.int32)},
            rank=0))
    check_trace(  # use relaxed graph
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}),
        None)
    check_trace(  # use relaxed graph
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
        None)

  def testInputShapeFunctionRelaxationWithDatasetIterators(self):
    # For dataset iterators, the TypeSpec includes type information that's
    # not derivable from the component tensors.  Make sure that the TypeSpec
    # shapes get relaxed as appropriate.

    traced_type_spec = [None]

    @def_function.function(reduce_retracing=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
    ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
    ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
    ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
    ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
    check_trace(  # shape=[1, 2]: retrace
        dataset_ops.make_one_shot_iterator(ds_1_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([1, 2], dtypes.float32)))
    check_trace(  # shape=[1, 2]: no retrace (use the [1, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_1_2), None)
    check_trace(  # shape=[2, 2]: relax to [None, 2] and retrace
        dataset_ops.make_one_shot_iterator(ds_2_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, 2], dtypes.float32)))
    check_trace(  # shape=[3, 2]: no retrace (use the [None, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_3_2), None)
    check_trace(  # shape=[4, 2]: no retrace (use the [None, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_4_2), None)
    check_trace(  # shape=[2, 1]: relax to [None, None] and retrace
        dataset_ops.make_one_shot_iterator(ds_2_1),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, None], dtypes.float32)))

  def testCapturesVariables(self):
    a = variables.Variable(1.0, trainable=False)
    b = variables.Variable(1.0)
    cc = [None]

    @def_function.function
    def f():
      c = cc[0]
      if c is None:
        c = cc[0] = variables.Variable(1.)
      return a + b + c + 1

    cf = f.get_concrete_function()
    c = cc[0]

    captured_variables = {v.ref() for v in (a, b, c)}
    trainable_variables = {v.ref() for v in (b, c)}
    self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
    self.assertEqual({v.ref() for v in cf.trainable_variables},
                     trainable_variables)
    self.assertEqual(cf.variables, cf.graph.variables)
    self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)

  def testNestedInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun(reduce_retracing=True)
    def func(a_, b_=None):
      del a_  # Only used to check which cache is used.
      self.assertEqual(b_[0]._shape_tuple(), ())
      if b_[1]._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return b_[0] + 1

    a = 'hi'
    b0 = constant_op.constant(1.0)
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    unknown_dim[0] = False

    # Now do the same except with a new value for `a` (still not a tensor);
    # this should change the cache key.
    a = 'bye'
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

    # We relax the type traced previously.
    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 4)

  def testNestedShapeFunctionRelaxation(self):
    traced_shape = None
    # The inner function will go through shape relaxation because the shapes it
    # receives will be [1], [2], [3], ...
    @def_function.function(reduce_retracing=True)
    def bar(x_shape):
      nonlocal traced_shape
      traced_shape = x_shape._shape_tuple()
      return x_shape

    # The outer function will not go through shape relaxation because its
    # inputs change rank on every call: ones([1]), ones([1, 1]),
    # ones([1, 1, 1]), ...
    @def_function.function(reduce_retracing=True)
    def foo(ones):
      return bar(array_ops.shape(ones))

    self.assertAllEqual(self.evaluate(foo(array_ops.ones([1]))), [1])
    self.assertEqual(traced_shape, (1,))

    for rank in range(2, 6):
      x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
      self.assertAllEqual(x_shape, [1] * rank)
      self.assertEqual(traced_shape, (None,))
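
    # Summary (restating the comments above): bar's argument is always a
    # rank-1 shape tensor whose length grows, so its spec relaxes to
    # TensorSpec([None]); foo's inputs change rank on every call, so no
    # single relaxed fixed-rank shape covers them.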

  def testNoHash(self):

    @def_function.function()
    def f(_):
      return 1.0

    with self.assertRaisesRegex(
        TypeError,
        r'could not be represented through the generic tracing'):
      f(set([]))

  def testFuncName(self):

    @function.defun_with_attributes(attributes={'func_name': 'multiply'})
    def add(x, y):
      _ = x * y
      return x + y

    @function.defun
    def add_2(x, y):
      _ = x * y
      return x + y

    self.assertEqual(add._name, 'multiply')
    self.assertEqual(add_2._name, 'add_2')

  def testBasicGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out = sq(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedInputsGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    out = a_times_b(pair({'a': t}, {'b': t}))
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputsGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function()
    def pairs_mul(pair_a, pair_b):
      return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))

    a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
    b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])

    out = pairs_mul(pair(a, b), pair(b, a))
    expected = pair(math_ops.matmul(a, b).numpy(),
                    math_ops.matmul(b, a).numpy())
    self.assertAllClose(out, expected)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testNestedFunctionGraphNotOutOfDate(self, function_decorator):
    @function_decorator
    def f():
      return constant_op.constant(1.)

    class _Model(object):

      @function_decorator
      def g(self):
        self.f = f.get_concrete_function()

    model = _Model()
    model.g()
    concrete = model.f
    weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
    self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
    weak_g = weakref.ref(model.g)
    del model
    self.assertIsNone(weak_g())
    self.assertIsNone(weak_g_graph())
    self.assertIsNotNone(concrete.graph.outer_graph)
    self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)

  def testGraphEagerIsolation(self):

    @function.defun
    def f():
      self.v = variables.Variable(1.0)
      return self.v.read_value()

    self.assertAllEqual(f(), 1.0)

    with ops.Graph().as_default():
      self.assertEqual(f().shape, ())

  def testBasicGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testGetConcreteFunctionThreadSafety(self):

    @def_function.function
    def sq():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      return math_ops.matmul(t, t)

    concrete_functions = []

    def thread_func(_):
      cf = sq.get_concrete_function()
      concrete_functions.append(cf)

    num_threads = 100
    pool = multiprocessing.pool.ThreadPool(num_threads)
    _ = pool.map(thread_func, list(range(num_threads)))

    self.assertLen(set(concrete_functions), 1)

  def testGetConcreteFunctionThreadSafetyWithArgs(self):
    @def_function.function
    def add_100(*args):
      return math_ops.add_n(args)

    p = multiprocessing.pool.ThreadPool(2)
    args = (constant_op.constant(1.),) * 100
    f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
    # I see about len(args) + max(0, len(args) - 3) arguments expected.
    f1(*args)
    del f2

  def testInputSpecGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    sq_op = sq.get_concrete_function(
        tensor_spec.TensorSpec((None, None), dtypes.float32))
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out1 = sq_op(t1)
    self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())

    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out2 = sq_op(t2)
    self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())

  def testNestedInputSpecGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(mats):
      ((a, b),) = mats
      return matmul(a, b)

    sq_op_autonamed = sq.get_concrete_function(
        [(tensor_spec.TensorSpec((None, None), dtypes.float32),
          tensor_spec.TensorSpec((None, None), dtypes.float32))])
    self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())

    sq_op = sq.get_concrete_function(
        [(tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='first_mat'),
          tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='second_mat'))])
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
    out = sq_op(first_mat=t1, second_mat=t2)
    self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
    self.assertAllEqual(sq_op_autonamed(t1, t2),
                        math_ops.matmul(t1, t2).numpy())

  def testExecutingStatelessDefunConcurrently(self):

    @def_function.function
    def stateless(x):
      return math_ops.multiply(2.0, x)

    pool = multiprocessing.pool.ThreadPool()
    inputs = [constant_op.constant(1.0 * x) for x in range(100)]
    outputs = [float(out) for out in pool.map(stateless, inputs)]
    expected = [float(2.0 * x) for x in inputs]
    self.assertSequenceEqual(outputs, expected)

  def testExecutingManyStatelessDefunsConcurrently(self):

    @def_function.function
    def stateless(x):
      del x
      return math_ops.multiply(2.0, 2.0)

    pool = multiprocessing.pool.ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    objects = [object() for _ in range(100)]
    outputs = [float(out) for out in pool.map(stateless, objects)]
    expected = [4.0] * 100
    self.assertSequenceEqual(outputs, expected)

  @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
  def testExecutingStatefulDefunConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      v.assign(x)

    pool = multiprocessing.pool.ThreadPool()
    inputs = [constant_op.constant(0.0)] * 100
    pool.map(stateful, inputs)
    self.assertEqual(float(v.read_value()), 0.0)

  def testExecutingManyStatefulDefunsConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      del x
      return v.assign(0.0)

    pool = multiprocessing.pool.ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    pool.map(stateful, [object() for _ in range(100)])
    self.assertEqual(float(v.read_value()), 0.0)

  def testShareRendezvous(self):

    # Disable grappler so it does not inline the functions. Note that we run
    # the send & recv in graph mode, since in eager mode the functions would
    # be inlined automatically.
    context.context().set_optimizer_experimental_options(
        {'disable_meta_optimizer': True})

    cpu = '/device:CPU:0'

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]

    @def_function.function
    def send():
      x = constant_op.constant(1)
      gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
      return x

    send._shared_rendezvous = True  # pylint: disable=protected-access

    @def_function.function(input_signature=signature)
    def send_body(n):
      send()
      return n - 1

    @def_function.function
    def recv():
      return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)

    recv._shared_rendezvous = True  # pylint: disable=protected-access

    @def_function.function(input_signature=signature)
    def recv_body(n):
      recv()
      return n - 1

    @def_function.function(input_signature=signature)
    def cond(n):
      return n > 0

    # Instead of calling the send & recv functions directly we want to call
    # them through a functional while to ensure the rendezvous is shared
    # across the while boundary.
    @def_function.function
    def fn(n):
      functional_ops.While([n], cond.get_concrete_function(),
                           send_body.get_concrete_function())
      return functional_ops.While([n], cond.get_concrete_function(),
                                  recv_body.get_concrete_function())

    # Use a graph context since functions will not be automatically inlined
    with context.graph_mode(), self.cached_session():
      self.evaluate(fn(2))
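
    # Rationale sketch (assumption, based on the setup above): the Send and
    # Recv ops execute in two separate functional While loops, so the
    # rendezvous pairing them must outlive any single function call;
    # _shared_rendezvous makes the marked functions reuse the caller's
    # rendezvous rather than creating a per-call one.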

  def disabled_testRandomSeed(self):

    @def_function.function
    def f():
      return random_ops.random_normal(())

    random_seed.set_random_seed(1)
    x = f()
    self.assertNotEqual(x, f())
    random_seed.set_random_seed(1)
    self.assertAllEqual(f(), x)

  def testNestedInputsGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = a_times_b.get_concrete_function(
        pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
             dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(a=t, b=t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return (matmul(a, a), {'b': constant_op.constant(1.0)})

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes,
                     (tensor_shape.TensorShape([2, 2]),
                      {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(sq_op.output_dtypes,
                     (dtypes.float32, {'b': dtypes.float32}))
    (a, b) = sq_op(t)
    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
    self.assertAllEqual(b['b'].numpy(), 1.0)

  def testGraphFunctionNoneOutput(self):
    @def_function.function
    def fn(unused_a, unused_b):
      return None

    x = constant_op.constant(1)
    fn_op = fn.get_concrete_function(x, x)
    self.assertEqual(fn_op.output_dtypes, None)
    self.assertEqual(fn_op.output_shapes, None)
    self.assertAllEqual(fn_op(x, x), None)

  def testDefunNumpyArraysConvertedToTensors(self):

    def f(x):
      self.assertIsInstance(x, ops.Tensor)
      return x

    x = random_ops.random_uniform([2, 2]).numpy()
    defined = function.defun(f)
    defined(x)
    self.assertLen(total_function_cache(defined), 1)

    x = random_ops.random_uniform([2, 2]).numpy()
    defined(x)
    # A NumPy array with different values but the same shape and dtype
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    np_ones = numpy.ones([], numpy.float32)
    np_zeros = numpy.zeros([], numpy.float32)
    tf_ones = array_ops.ones([])
    tf_zeros = array_ops.zeros([])

    # Test that the numpy array is properly an argument to the graph function.
    self.assertEqual(1., defined(np_ones).numpy())
    self.assertLen(total_function_cache(defined), 2)
    self.assertEqual(0., defined(np_zeros).numpy())
    self.assertEqual(1., defined(tf_ones).numpy())
    self.assertEqual(0., defined(tf_zeros).numpy())
    self.assertLen(total_function_cache(defined), 2)

    # Test that mutable inputs are supported.
    mutable = numpy.ones([], numpy.float32)
    self.assertEqual(1., defined(mutable).numpy())
    mutable.fill(0)
    self.assertEqual(0., defined(mutable).numpy())

    class MyNdarray(numpy.ndarray):
      pass

    # Test that the subclasses of ndarray are converted too.
    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())

    # We should not have triggered any re-tracing of the python function.
    self.assertLen(total_function_cache(defined), 2)

  def testNumpyDtypeInputSupported(self):
    @function.defun
    def f(x, dtype):
      return constant_op.constant(dtype(x))

    self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
    self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
    self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
    self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))

  def testDefunNumpyArraysConvertedToTensorsInKwargs(self):

    def f(**kwargs):
      x = kwargs.pop('x')
      self.assertIsInstance(x, ops.Tensor)
      return x

    x = random_ops.random_uniform([2, 2]).numpy()
    defined = function.defun(f)
    defined(x=x)
    self.assertLen(total_function_cache(defined), 1)

    x = random_ops.random_uniform([2, 2]).numpy()
    defined(x=x)
    # A NumPy array with different values but the same shape and dtype
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    # Test that the numpy array is properly an argument to the graph function.
    self.assertEqual(1., defined(x=numpy.ones([])).numpy())
    self.assertEqual(0., defined(x=numpy.zeros([])).numpy())
    self.assertEqual(1., defined(x=array_ops.ones([])).numpy())
    self.assertEqual(0., defined(x=array_ops.zeros([])).numpy())

  def testDefunCapturedInt32(self):
    x = constant_op.constant(1, dtype=dtypes.int32)

    @def_function.function
    def add_int32s():
      return x + x

    self.assertEqual(2, int(add_int32s()))

  def testDefunReadVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      return v.read_value()

    self.assertEqual(1.0, float(f()))

  def testDefunAssignAddVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)
    x = constant_op.constant(2.0)

    @def_function.function
    def test_assign_add():
      v.assign_add(x)
      return v.read_value()

    self.assertEqual(3.0, float(test_assign_add()))

  @test_util.run_in_graph_and_eager_modes
  def testTensorInitializationInFunctionRaisesError(self):

    @def_function.function
    def tensor_init():
      with self.assertRaisesRegex(ValueError, 'could not be lifted out'):
        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))

    tensor_init()

  @test_util.run_in_graph_and_eager_modes
  def testCallableTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      self.v = resource_variable_ops.ResourceVariable(
          lambda: constant_op.constant(2.0))
      return self.v.read_value()

    value = tensor_init()
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual(self.evaluate(value), 2.0)

  @test_util.also_run_as_tf_function
  def testInitScopeTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      with ops.init_scope():
        const = constant_op.constant(2.0)
      # Note: this variable bypasses tf.function's variable-creation
      # requirements because ResourceVariable, unlike Variable, does not go
      # through variable_creator_scope.
      self.v = resource_variable_ops.ResourceVariable(const)
      return self.v.read_value()

    value = tensor_init()
    self.assertAllEqual(value, 2.0)
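
    # init_scope note: ops created under ops.init_scope() are lifted out of
    # the FuncGraph into the outermost context, which is why the constant
    # above is usable as an initial value from inside a tf.function.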

  @test_util.run_in_graph_and_eager_modes
  def testGetConcreteFunctionCreatesVariables(self):

    v_holder = []

    @def_function.function
    def tensor_init():
      if not v_holder:
        v_holder.append(variables.Variable(5.))
      return v_holder[0].read_value()

    concrete = tensor_init.get_concrete_function()
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(5., self.evaluate(concrete()))
    self.assertAllEqual(5., self.evaluate(tensor_init()))

  def testFuncGraphCaptureByValue(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return v.read_value()

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)
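
    # Capture-by-value note (mirroring the assertions above): with
    # capture_by_value=True the variable's value at trace time is baked into
    # the graph as a constant, so the later v.assign(2.0) does not change
    # what the traced function returns.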

  def testFuncGraphCaptureByValueNested(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return control_flow_ops.cond(
          array_ops.placeholder_with_default(True, ()),
          v.read_value, v.read_value)

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)

  def testDefunShapeInferenceWithCapturedResourceVariable(self):
    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

    def f():
      x = constant_op.constant([[1, 2], [3, 4]])
      out = math_ops.matmul(v, x)
      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
      # We do not return v directly since the tensor conversion function of
      # ResourceVariable returns the read value and not the resource itself.
      return v._handle

    compiled = def_function.function(f)
    var_handle = compiled()
    self.assertEqual(var_handle.dtype, dtypes.resource)
    self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))

  def testShapeInferenceForMoreSpecificInput(self):

    def f(a):
      return array_ops.reshape(a, [-1, 3])

    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
    compiled = def_function.function(f, input_signature=signature)

    @def_function.function
    def use_f():
      inputs = array_ops.zeros([10, 10, 3])
      self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)

    use_f()

  def testFuncListAttr(self):

    @function.defun
    def test_function(val):

      def fn1():
        return array_ops.ones([10])

      fn2 = lambda: array_ops.ones([10]) * 2

      def fn3(x=3):
        return array_ops.ones([10]) * x
      fn4 = functools.partial(fn3, x=4)
      fn5 = functools.partial(fn3, 5)

      return gen_functional_ops.case(val, [], [dtypes.float32],
                                     [function.defun(f).get_concrete_function()
                                      for f in (fn1, fn2, fn3, fn4, fn5)])

    ones = array_ops.ones([10])
    self.assertAllEqual([ones], test_function(0))
    self.assertAllEqual([ones * 2], test_function(1))
    self.assertAllEqual([ones * 3], test_function(2))
    self.assertAllEqual([ones * 4], test_function(3))
    self.assertAllEqual([ones * 5], test_function(4))
    self.assertAllEqual([ones * 5], test_function(22))  # default branch
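
    # Branch-selection note (assumption restating the last assertion): the
    # case op dispatches on `val`, and an out-of-range index falls through to
    # the final branch, so test_function(22) runs fn5.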
1290
1291  @test_util.enable_control_flow_v2
1292  def testVariableInLoopInFunction(self):
1293
1294    @function.defun
1295    def test_function():
1296
1297      def loop_test(_):
1298        return False
1299
1300      def loop_body(_):
1301        return variable_scope.get_variable('a', shape=())
1302
1303      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
1304
1305    self.assertEqual(test_function().shape, [])
1306
1307  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
1308    with context.graph_mode():
1309      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
1310
1311      def f():
1312        x = constant_op.constant([[1, 2], [3, 4]])
1313        out = math_ops.matmul(v, x)
1314        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
1315        # We do not return v directly since the tensor conversion function of
1316        # ResourceVariable returns the read value and not the resource itself.
1317        return v._handle
1318
1319      compiled = def_function.function(f)
1320      var_handle = compiled()
1321      self.assertEqual(var_handle.dtype, dtypes.resource)
1322      self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
1323      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
1324      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
1325
1326  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
1327    with context.graph_mode():
1328      v = variables.Variable([[1, 2], [3, 4]])
1329
1330      def f():
1331        x = constant_op.constant([[1, 2], [3, 4]])
1332        out = math_ops.matmul(v, x)
1333        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
1334
1335      # Check that shape inference works while creating the defun
1336      compiled = def_function.function(f)
1337      compiled()
1338
1339  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
1340    with context.graph_mode():
1341      tensor_list = list_ops.empty_tensor_list(
1342          element_dtype=dtypes.float32,
1343          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1344      tensor_list = list_ops.tensor_list_push_back(tensor_list,
1345                                                   constant_op.constant(1.0))
1346      tensor_list = list_ops.tensor_list_push_back(tensor_list,
1347                                                   constant_op.constant(2.0))
1348
1349      def f():
1350        tl, value = list_ops.tensor_list_pop_back(
1351            tensor_list, element_dtype=dtypes.float32)
1352        self.assertEqual(value.shape, tensor_shape.TensorShape([]))
1353        return tl
1354
1355      compiled = def_function.function(f)
1356      output_tensor_list = compiled()
1357      _, value = list_ops.tensor_list_pop_back(
1358          output_tensor_list, element_dtype=dtypes.float32)
1359      self.assertEqual(value.shape, tensor_shape.TensorShape([]))
1360
1361  @test_util.run_in_graph_and_eager_modes
1362  def testDefunForcesResourceVariables(self):
1363
1364    def variable_creator():
1365      self.v = variables.Variable(0.0)
1366      return self.v.read_value()
1367
1368    self.v = None
1369    defined = function.defun(variable_creator)
1370    defined()  # Create the variable.
1371    self.assertIsInstance(
1372        self.v, resource_variable_ops.ResourceVariable)
1373
1374  def testRunMetadata(self):
1375
1376    @def_function.function
1377    def f(x):
1378      return x * x
1379
1380    with ops.device('cpu:0'):
1381      context.enable_run_metadata()
1382      f(constant_op.constant(1.0))
1383    run_metadata = context.export_run_metadata()
1384    context.disable_run_metadata()
1385    self.assertLen(run_metadata.partition_graphs, 1)
1386
1387  def testGraphModeCaptureVariable(self):
1388    with context.graph_mode(), self.cached_session():
1389
1390      class HasAVar:
1391
1392        def __init__(self):
1393          self.v = resource_variable_ops.ResourceVariable(1.0)
1394
1395        def call(self):
1396          return self.v * 2
1397
1398      o = HasAVar()
1399      self.evaluate(variables.global_variables_initializer())
1400      call = def_function.function(o.call)
1401      op = call()
1402      self.assertAllEqual(self.evaluate(op), 2.0)
1403
1404  def testGraphModeManyFunctions(self):
1405    with ops.Graph().as_default(), self.cached_session():
1406
1407      @def_function.function
1408      def f(x):
1409        return x * x
1410
1411      @def_function.function
1412      def g(x):
1413        return f(x) + 1
1414
1415      self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)
1416
1417  def testDict(self):
1418
1419    @def_function.function
1420    def f(x):
1421      return {'name': x + 1}
1422
1423    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)
1424
1425  def testWeakrefInputsRejected(self):
1426
1427    @def_function.function
1428    def f(x):
1429      return x
1430
1431    class Dummy:
1432      pass
1433    o = Dummy()
1434    wr = weakref.ref(o)
1435
1436    with self.assertRaisesRegex(ValueError, 'weakref'):
1437      f(wr)
1438
1439  def testTensorConversionWithDefun(self):
1440
1441    @def_function.function
1442    def f(x):
1443      return math_ops.add(x, constant_op.constant(3))
1444
1445    self.assertAllEqual(5, f(constant_op.constant(2)))
1446
1447  def testTensorConversionCall(self):
1448
1449    @def_function.function
1450    def f(x):
1451      return math_ops.add(x, constant_op.constant(3))
1452
1453    @def_function.function
1454    def g(x):
1455      return f(f(x))
1456
1457    self.assertAllEqual(8, g(constant_op.constant(2)))
1458
1459  def testCallShape(self):
1460
1461    @def_function.function
1462    def f(x):
1463      return x + 1
1464
1465    @def_function.function
1466    def g(x):
1467      x = f(x)
1468      self.assertEqual(x.shape.as_list(), [])
1469      return None
1470
1471    g(constant_op.constant(1.0))
1472
1473  def testNestedDefunWithNoOutputAndTapedInput(self):
1474    three = resource_variable_ops.ResourceVariable(3.0, name='v')
1475
1476    @def_function.function
1477    def f(x):
1478      # This function intentionally takes a taped variable as input,
1479      # but does not return any values
1480      math_ops.add(x, three)
1481
1482    @def_function.function
1483    def g(x):
1484      y = math_ops.add(x, three)
1485      f(y)
1486
1487    g(three)
1488
1489  def testGatherResourceWithDefun(self):
1490    with ops.device('cpu:0'):
1491      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1492
1493    def sum_gather():
1494      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
1495
1496    defined = def_function.function(sum_gather)
1497    self.assertAllEqual(sum_gather(), defined())
1498
1499  @parameterized.named_parameters([
1500      ('IndexedSlicesWithDenseShape',
1501       _example_indexed_slices_with_dense_shape,),
1502      ('IndexedSlicesWithoutDenseShape',
1503       _example_indexed_slices_without_dense_shape,),
1504      ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
1505       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
1506      ('RaggedTensorRaggedRank2',
1507       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1508       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
1509      ('SparseTensor', sparse_tensor.SparseTensor,
1510       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
1511  ])  # pyformat: disable
1512  def testReturnCompositeTensorWithDefun(self,
1513                                         factory_fn,
                                         factory_kwargs=None,
                                         input_signature=None):
    input_ct = factory_fn(**(factory_kwargs or {}))
1517
1518    @def_function.function(input_signature=input_signature)
1519    def f():
1520      return input_ct
1521
1522    output_ct = f()
1523    self.assertIsInstance(output_ct, type(input_ct))
1524    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
1525
1526    input_flat = nest.flatten(input_ct, expand_composites=True)
1527    output_flat = nest.flatten(output_ct, expand_composites=True)
1528    for (input_component, output_component) in zip(input_flat, output_flat):
1529      self.assertAllEqual(input_component, output_component)
1530
1531  @parameterized.named_parameters([
1532      ('IndexedSlicesWithDenseShape',
1533       _example_indexed_slices_with_dense_shape,),
1534      ('IndexedSlicesWithoutDenseShape',
1535       _example_indexed_slices_without_dense_shape,),
1536      ('RaggedTensorRaggedRank1',
1537       ragged_tensor.RaggedTensor.from_row_lengths,
1538       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
1539      ('RaggedTensorRaggedRank2',
1540       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1541       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
1542      ('SparseTensor',
1543       sparse_tensor.SparseTensor,
1544       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
1545      ('RaggedTensorRaggedRank1WithSignature',
1546       ragged_tensor.RaggedTensor.from_row_lengths,
1547       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
1548       [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
1549      ('RaggedTensorRaggedRank2WithSignature',
1550       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1551       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
1552       [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
1553      ('SparseTensorWithSignature',
1554       sparse_tensor.SparseTensor,
1555       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
1556       [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
1557  ])  # pyformat: disable
1558  def testCompositeAsArgumentTensorWithDefun(self,
1559                                             factory_fn,
                                             factory_kwargs=None,
                                             input_signature=None):
    input_ct = factory_fn(**(factory_kwargs or {}))
1563
1564    @def_function.function(input_signature=input_signature)
1565    def f(x):
1566      return x
1567
1568    output_ct = f(input_ct)
1569    self.assertIsInstance(output_ct, type(input_ct))
1570    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
1571
1572    input_flat = nest.flatten(input_ct, expand_composites=True)
1573    output_flat = nest.flatten(output_ct, expand_composites=True)
1574    for (input_component, output_component) in zip(input_flat, output_flat):
1575      self.assertAllEqual(input_component, output_component)
1576
1577  def testTracedCompositeDiscardsShapeInfo(self):
1578    # SparseTensorSpec intentionally excludes info about the number of elements
1579    # that are in a sparse tensor (which is recorded as st.indices.shape[0] and
1580    # st.values.shape[0]).  Similarly, RaggedTensorSpec intentionally excludes
1581    # info about the total number of values in a RaggedTensor (stored as
1582    # rt.values.shape[0]).  This test checks that the placeholders created by
1583    # tf.function() properly mask this shape info.
1584    @def_function.function
1585    def f(rt, st):
1586      self.assertEqual(st.indices.shape.as_list()[:1], [None])
1587      self.assertEqual(st.values.shape.as_list(), [None])
1588      return (rt, st)
1589
1590    rt = ragged_factory_ops.constant([[1, 2], [3]])
1591    st = sparse_tensor.SparseTensor([[0]], [0], [10])
1592    f(rt, st)
1593
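  # A minimal sketch (not part of the original suite) of the same masking for
  # the ragged argument: RaggedTensorSpec drops the total number of values,
  # so the traced rt.values placeholder has an unknown leading dimension.
  def testTracedRaggedValuesShapeMasked(self):

    @def_function.function
    def f(rt):
      self.assertEqual(rt.values.shape.as_list(), [None])
      return rt

    f(ragged_factory_ops.constant([[1, 2], [3]]))
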
1594  @test_util.run_gpu_only
1595  def testFunctionOnDevice(self):
1596    x = constant_op.constant([1.]).gpu()
1597    f = def_function.function(math_ops.add)
1598    y = f(x, x).cpu()
1599    self.assertAllEqual(y, [2.])
1600
1601  @test_util.run_gpu_only
1602  @test_util.run_in_graph_and_eager_modes
1603  def testFunctionWithResourcesOnDifferentDevices(self):
1604    with ops.device('/cpu:0'):
1605      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1606
1607    with ops.device('/gpu:0'):
1608      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1609
1610    def sum_gather():
1611      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
1612      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
1613      return cpu_result, gpu_result
1614
1615    defined = function.defun(sum_gather)
1616    if not context.executing_eagerly():
1617      self.evaluate(variables.global_variables_initializer())
1618    expected = self.evaluate(sum_gather())
1619    self.assertAllEqual(expected, self.evaluate(defined()))
1620
1621  @test_util.run_gpu_only
1622  @test_util.run_in_graph_and_eager_modes
1623  def testOpInFunctionWithConflictingResourceInputs(self):
1624    with ops.device('/cpu:0'):
1625      v_cpu = resource_variable_ops.ResourceVariable(
1626          [0.0, 1.0, 2.0], name='cpu')
1627      v_also_cpu = resource_variable_ops.ResourceVariable(
1628          [0.0, 1.0, 2.0], name='also_cpu')
1629
1630    with ops.device('/gpu:0'):
1631      v_gpu = resource_variable_ops.ResourceVariable(
1632          [0.0, 1.0, 2.0], name='gpu')
1633
1634    @def_function.function
1635    def resource_apply_adam():
1636      training_ops.resource_apply_adam(
1637          v_cpu.handle,
1638          v_gpu.handle,
1639          v_also_cpu.handle,
1640          1.0,  # beta1_power
1641          1.0,  # beta2_power
1642          1.0,  # learning_rate
1643          1.0,  # beta1
1644          1.0,  # beta2
          1.0,  # epsilon
1646          [1.0, 1.0, 1.0],  # grad
1647          False)  # use_locking
1648      return None
1649
1650    with self.assertRaisesRegex(
1651        errors.InvalidArgumentError,
1652        'Cannot place the graph because a reference or resource edge connects '
1653        'colocation groups with incompatible assigned devices'):
1654      if not context.executing_eagerly():
1655        self.evaluate(variables.global_variables_initializer())
1656      self.evaluate(resource_apply_adam())
1657
1658  @test_util.run_gpu_only
1659  def testFunctionHandlesInputsOnDifferentDevices(self):
1660    # The Reshape op requires the shape tensor to be placed in host memory.
1661    reshape = def_function.function(array_ops.reshape)
1662    value = constant_op.constant([1., 2.]).gpu()
1663    shape = constant_op.constant([2, 1])
1664    reshaped = reshape(value, shape).cpu()
1665    self.assertAllEqual(reshaped, [[1], [2]])
1666
1667  @test_util.run_gpu_only
1668  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
1669    # The Reshape op requires the shape tensor to be placed in host memory.
1670    reshape = def_function.function(array_ops.reshape)
1671    value = constant_op.constant([1., 2.])
1672    shape = constant_op.constant([2, 1]).gpu()
1673    reshape(value, shape)  # No error is raised
1674
1675  def testNoneOutput(self):
1676
1677    @def_function.function
1678    def my_function(_):
1679      return None
1680
1681    self.assertAllEqual(my_function(1), None)
1682
1683  def testNestedFunctions(self):
1684    # TensorFlow function (which is what would be used in TensorFlow graph
1685    # construction).
1686    @tf_function.Defun(dtypes.int32, dtypes.int32)
1687    def add(a, b):
1688      return math_ops.add(a, b)
1689
1690    @def_function.function
1691    def add_one(x):
1692      return add(x, 1)
1693
1694    self.assertAllEqual(3, add_one(constant_op.constant(2)))
1695
1696  def testVariableCaptureInNestedFunctions(self):
1697    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
1698
1699    @def_function.function
1700    def inner_read():
1701      return v.read_value()
1702
1703    @def_function.function
1704    def outer():
1705      return inner_read()
1706
1707    self.assertEqual(1, int(outer()))
1708
1709  def testReturnCapturedEagerTensor(self):
1710    t = constant_op.constant(1)
1711
1712    @def_function.function
1713    def read():
1714      return t
1715
1716    self.assertEqual(1, int(read()))
1717
1718  def testReturnCapturedGraphTensor(self):
1719    with context.graph_mode(), self.cached_session():
1720      t = constant_op.constant(1)
1721
1722      @def_function.function
1723      def read():
1724        return t
1725
1726      self.assertEqual(1, int(self.evaluate(read())))
1727
1728  def testSequenceInputs(self):
1729    clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
1730    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
1731    clipped_list, global_norm = clip_by_global_norm(t_list,
1732                                                    constant_op.constant(.2))
1733    for t in clipped_list:
1734      self.assertIsInstance(t, ops.Tensor)
1735    self.assertIsInstance(global_norm, ops.Tensor)
1736
1737  def testNestedSequenceInputs(self):
1738
1739    def my_op(inputs):
1740      a, b, c = inputs
1741      e, f = b
1742      g, h = e
1743      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
1744
1745    my_eager_op = def_function.function(my_op)
1746    ret = my_eager_op([
1747        constant_op.constant(1), [(constant_op.constant(2),
1748                                   constant_op.constant(3)),
1749                                  constant_op.constant(4)],
1750        constant_op.constant(5)
1751    ])
1752    self.assertLen(ret, 2)
1753    self.assertAllEqual(ret[0][0], 2)
1754    self.assertAllEqual(ret[0][1][0][0], 8)
1755    self.assertAllEqual(ret[0][1][0][1], 4)
1756    self.assertIsInstance(ret[0][1][0], tuple)
1757    self.assertAllEqual(ret[0][1][1], 6)
1758    self.assertAllEqual(ret[0][2], 10)
1759    self.assertAllEqual(ret[1], 15)
1760
1761  def testVariableNamesRespectNameScopesWithDefun(self):
1762    @def_function.function
1763    def create_variable():
1764      with ops.name_scope('foo', skip_on_eager=False):
1765        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
1766      self.assertEqual(v.name, 'foo/bar:0')
1767
1768    create_variable()
1769
1770  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
1771    with context.graph_mode():
1772      @def_function.function
1773      def create_variable():
1774        with ops.name_scope('foo', skip_on_eager=False):
1775          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
1776        self.assertEqual(v.name, 'foo/bar:0')
1777
1778      with ops.get_default_graph().as_default():
1779        create_variable()
1780
1781  @test_util.assert_no_new_pyobjects_executing_eagerly
1782  def testCallOptionsMemory(self):
1783
1784    @function.defun
1785    def model(x):
1786      return x + constant_op.constant(1.)
1787
1788    # This happens with a lot of option toggles, e.g. soft device placement
1789    context.context().function_call_options = None
1790    model(constant_op.constant(2.))
1791
1792  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
1793  def testLayerInDefun(self):
1794    conv = convolutional.Conv2D(
1795        filters=1,
1796        kernel_size=2,
1797        kernel_initializer=init_ops.ones_initializer(),
1798        bias_initializer=init_ops.zeros_initializer())
1799
1800    @function.defun
1801    def model(x):
1802      return conv(x)
1803
1804    x = array_ops.ones([1, 2, 2, 1])
1805    y = model(x)
1806
1807    if not context.executing_eagerly():
1808      self.evaluate(variables.global_variables_initializer())
1809
1810    self.assertAllClose([[[[4.0]]]], self.evaluate(y))
1811
1812  # Variable lifting is somewhat different between defun/tf.function, so testing
1813  # device placement on both makes sense.
1814  @parameterized.named_parameters(
1815      dict(testcase_name='Defun',
1816           function_decorator=function.defun),
1817      dict(testcase_name='DefFunction',
1818           function_decorator=def_function.function))
1819  @test_util.run_in_graph_and_eager_modes
1820  def testVariablesPlacedOnOutsideDevice(self, function_decorator):
1821
1822    class _Obj(object):
1823
1824      def __init__(self):
1825        self.v = None
1826
1827      @function_decorator
1828      def f(self):
1829        if self.v is None:
1830          self.v = variables.Variable(1.)
1831        return self.v + 1.
1832
1833    has_device = _Obj()
1834    with ops.device('cpu:0'):
1835      has_device.f()
1836    self.assertIn('CPU', has_device.v.device)
1837
1838  @test_util.run_in_graph_and_eager_modes
1839  def testMultipleDeviceCheck(self):
1840
1841    def f():
1842      with ops.device('cpu'):
1843        return test_ops.device_placement_op()
1844
1845    func = function.defun(f)
1846    with ops.device('cpu:0'):
1847      output = self.evaluate(func())
1848      self.assertIn(compat.as_bytes('CPU:0'), output)
1849
1850  @test_util.run_in_graph_and_eager_modes
1851  def testDeviceAnnotationsRespected(self):
1852
1853    def multi_device_fn():
1854      with ops.device('/cpu:0'):
1855        s0 = test_ops.device_placement_op()
1856      with ops.device('/cpu:1'):
1857        s1 = test_ops.device_placement_op()
1858      with ops.device('/cpu:2'):
1859        s2 = test_ops.device_placement_op()
1860      s3 = test_ops.device_placement_op()
1861      return s0, s1, s2, s3
1862
1863    defined = function.defun(multi_device_fn)
1864    outputs = self.evaluate(defined())
1865    self.assertLen(total_function_cache(defined), 1)
1866    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1867    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1868    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1869
1870    with ops.device('/cpu:3'):
1871      outputs = self.evaluate(defined())
1872    # All function definitions are agnostic to call site devices.
1873    self.assertLen(total_function_cache(defined), 1)
1874    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1875    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1876    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1877    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
1878
1879    with ops.device('/cpu:0'):
1880      outputs = self.evaluate(defined())
1881    self.assertLen(total_function_cache(defined), 1)
1882    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1883    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1884    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1885    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
1886
1887  @test_util.run_in_graph_and_eager_modes
1888  def testCallingGraphFunctionOnDifferentDevice(self):
1889
1890    def func():
1891      return constant_op.constant(0)
1892
1893    defined = def_function.function(func)
1894    with ops.device('cpu:0'):
1895      cpu_graph_function = defined.get_concrete_function()
1896
1897    with ops.device('cpu:0'):
1898      self.assertEqual(
1899          self.evaluate(cpu_graph_function()), self.evaluate(func()))
1900
1901    with ops.device('cpu:1'):
1902      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1903
1904    with ops.device(None):
1905      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1906
1907    default_graph_function = defined.get_concrete_function()
1908    self.assertEqual(
1909        self.evaluate(default_graph_function()), self.evaluate(func()))
1910
1911    with ops.device('cpu:1'):
1912      self.assertEqual(0., self.evaluate(default_graph_function()))
1913
1914  @test_util.run_gpu_only
1915  @test_util.run_in_graph_and_eager_modes
1916  def testColocateWithRespected(self):
1917    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
1918    with ops.device('cpu:0'):
1919      x = array_ops.identity(1.0)
1920
1921    with ops.device('gpu:0'):
1922      y = array_ops.identity(1.0)
1923
1924    @def_function.function
1925    def foo():
1926      return test_ops.device_placement_op()
1927
1928    with ops.colocate_with(x):
1929      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
1930
1931    with ops.colocate_with(y):
1932      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
1933
1934  def testVariablesAreTracked(self):
1935    v = resource_variable_ops.ResourceVariable(1.0)
1936
1937    def foo(x):
1938      return v * x
1939
1940    defined = def_function.function(foo)
1941
1942    x = constant_op.constant([1.0])
1943    self.assertEqual(1., self.evaluate(defined(x)))
1944    v.assign(2.)
1945
1946    x = constant_op.constant([1.0, 2.0])
1947    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
1948
1949  def testCacheObjectHashCollisions(self):
1950
1951    class Foo:
1952
1953      def __hash__(self):
1954        return 42
1955
1956    def func(foo):
1957      return constant_op.constant([id(foo)])
1958
1959    defined = function.defun(func)
1960    foo_1 = Foo()
1961    defined(foo_1)
1962    self.assertLen(total_function_cache(defined), 1)
1963
1964    foo_2 = Foo()
1965    defined(foo_2)
1966    self.assertLen(total_function_cache(defined), 2)
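    # Both instances hash to 42, so the cache must tell them apart by more
    # than their hash (e.g. object identity) instead of reusing the first
    # trace.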
1967
1968  def testCacheTensorDtypeCollision(self):
1969
1970    def func(t):
1971      return t + t
1972
1973    defined = function.defun(func)
1974    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1975    defined(t)
1976    self.assertLen(total_function_cache(defined), 1)
1977
1978    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
1979    defined(t)
1980    self.assertLen(total_function_cache(defined), 2)
1981
1982  def testCacheTensorShapeCollision(self):
1983
1984    def func(t):
1985      return t + t
1986
1987    defined = function.defun(func)
1988    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1989    defined(t)
1990    self.assertLen(total_function_cache(defined), 1)
1991
1992    t = constant_op.constant([1.0], dtype=dtypes.complex64)
1993    defined(t)
1994    self.assertLen(total_function_cache(defined), 2)
1995
1996  def testCacheTensorShapeDtypeCollision(self):
1997
1998    def func(t):
1999      return t + t
2000
2001    defined = function.defun(func)
2002    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
2003    defined(t)
2004    self.assertLen(total_function_cache(defined), 1)
2005
2006    t = constant_op.constant([1.0], dtype=dtypes.complex128)
2007    defined(t)
2008    self.assertLen(total_function_cache(defined), 2)
2009
2010  def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):
2011
2012    def func(t):
2013      return t + t
2014
2015    with context.graph_mode(), self.cached_session():
2016      defined = function.defun(func, reduce_retracing=True)
2017
2018      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
2019      defined(p)
2020      self.assertLen(total_function_cache(defined), 1)
2021
2022      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
2023      defined(p)
2024      self.assertLen(total_function_cache(defined), 2)
2025
2026      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
2027      defined(p)
      # Gradual shape relaxation is performed: the common shape between
      # [1] and [2] is one containing unknown dimensions.
2030      self.assertLen(total_function_cache(defined), 2)
2031
2032      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
2033      defined(t)
2034      # Shape (3,) matches the relaxed shape TensorShape([None])
2035      self.assertLen(total_function_cache(defined), 2)
2036
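  # A minimal eager-mode sketch of the relaxation above (assumption: the same
  # reduce_retracing semantics apply to concrete eager tensors): shapes [1]
  # and [2] yield a relaxed [None] trace that shape [3] then reuses.
  def testCacheRelaxedShapesEagerSketch(self):

    def func(t):
      return t + t

    defined = function.defun(func, reduce_retracing=True)
    defined(constant_op.constant([1.0]))
    defined(constant_op.constant([1.0, 2.0]))
    defined(constant_op.constant([1.0, 2.0, 3.0]))
    # At most two traces: the initial [1] trace plus the relaxed one.
    self.assertLessEqual(len(total_function_cache(defined)), 2)
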
2037  def testPythonFunctionWithDefaultArgs(self):
2038
2039    def func(foo, bar=1, baz=2):
2040      del foo
2041      del bar
2042      del baz
2043      return
2044
2045    defined = function.defun(func)
2046    defined(0, baz=20)
2047    self.assertLen(total_function_cache(defined), 1)
2048
2049    defined(1)  # bar=1, baz=2
2050    self.assertLen(total_function_cache(defined), 2)
2051
2052    # This matches the previous call.
2053    defined(foo=1)
2054    self.assertLen(total_function_cache(defined), 2)
2055
2056    defined(1, 2, 3)
2057    self.assertLen(total_function_cache(defined), 3)
2058
2059    # This matches the previous call.
2060    defined(1, bar=2, baz=3)
2061    self.assertLen(total_function_cache(defined), 3)
2062
2063    # This matches the previous call.
2064    defined(1, baz=3, bar=2)
2065    self.assertLen(total_function_cache(defined), 3)
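    # The pattern above holds because calls are canonicalized before cache
    # lookup: defaults are filled in and keywords are bound to parameter
    # positions, so any spelling of the same bound values shares one trace.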
2066
2067  def testDatasetIteratorCaching(self):
2068    def func(it1, it2):
2069      next(it1)
2070      next(it2)
2071      return 0
2072
2073    defined = function.defun(func)
2074
2075    d = dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3])
2076    it1 = iter(d)
2077    it2 = iter(d)
2078    _ = defined(it1, it2)  # The two iterators are different
2079    self.assertLen(total_function_cache(defined), 1)
2080
2081    it3 = iter(d)
2082    it4 = iter(d)
2083    _ = defined(it3, it4)  # The two iterators are different, should not retrace
2084    self.assertLen(total_function_cache(defined), 1)
2085
2086    it5 = iter(d)
2087    _ = defined(it5, it5)  # The two iterators are the same, should retrace
2088    self.assertLen(total_function_cache(defined), 2)
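    # The retrace above occurs because both arguments alias one iterator
    # resource, a different pattern from the two-distinct-iterators trace.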
2089
2090    it6 = iter(d)
2091    _ = defined(it6, it6)  # The two iterators are the same, should not retrace
2092    self.assertLen(total_function_cache(defined), 2)
2093
2094  def testFunctoolsPartialUnwrappedCorrectly(self):
2095
2096    def full_function(a, b, c=3):
2097      return a, b, c
2098
2099    partial = functools.partial(full_function, 1, c=4)
2100    a, b, c = partial(2)
2101
2102    defined = function.defun(partial)
2103    func_a, func_b, func_c = defined(2)
2104    self.assertEqual(func_a.numpy(), a)
2105    self.assertEqual(func_b.numpy(), b)
2106    self.assertEqual(func_c.numpy(), c)
2107
2108  def testInputSignatureWithMatchingInputs(self):
2109
2110    def foo(a):
2111      self.assertEqual(a.shape, (2,))
2112      return a
2113
2114    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
2115    defined = function.defun(foo, input_signature=signature)
2116    a = array_ops.ones([2])
2117    self.assertAllEqual(a, defined(a))
2118    self.assertLen(total_function_cache(defined), 1)
2119    self.assertAllEqual(a, defined.get_concrete_function()(a))
2120    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
2121    self.assertAllEqual(a, defined.get_concrete_function(
2122        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
2123    self.assertLen(total_function_cache(defined), 1)
2124
2125    def bar(a):
2126      self.assertEqual(a._shape_tuple(), (2, None))
2127      return a
2128
2129    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
2130    defined = function.defun(bar, input_signature=signature)
2131    a = array_ops.ones([2, 1])
2132    out = defined(a)
2133    self.assertLen(total_function_cache(defined), 1)
2134    self.assertAllEqual(out, a)
2135
2136    # Changing the second dimension shouldn't create a new function.
2137    b = array_ops.ones([2, 3])
2138    out = defined(b)
2139    self.assertLen(total_function_cache(defined), 1)
2140    self.assertAllEqual(out, b)
2141
2142  def testInputSignatureWithDictInPositionalArgs(self):
2143
2144    @function.defun
2145    def f(*_args, **_kwargs):
2146      return None
2147
2148    f(1, x=2)
2149    self.assertLen(total_function_cache(f), 1)
2150    f(1, x=2)
2151    self.assertLen(total_function_cache(f), 1)
2152    f(1, {'x': 2})
2153    self.assertLen(total_function_cache(f), 2)
2154
2155  def testInputSignatureWithCompatibleInputs(self):
2156
2157    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
2158                                        dtype=dtypes.float32)
2159
2160    @function.defun(input_signature=[rank2_spec])
2161    def func(a):
2162      self.assertEqual([None, None], a.shape.as_list())
2163      return array_ops.shape(a)
2164
2165    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
2166    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))
2167
2168    with self.assertRaisesRegex(ValueError, 'incompatible'):
2169      func([0.0, 1.0, 2.0])  # Wrong shape.
2170
2171    with self.assertRaisesRegex(ValueError, 'incompatible'):
2172      func([['wrong dtype']])
2173
2174  def testNestedInputSignatures(self):
2175
2176    def expected_foo(a, b):
2177      return [a, b]
2178
2179    @function.defun(input_signature=[
2180        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
2181        tensor_spec.TensorSpec((1,), dtypes.float32),
2182    ])
2183    def foo(a, b):
2184      self.assertEqual(a[0]._shape_tuple(), (2, None))
2185      self.assertEqual(a[1]._shape_tuple(), (2, None))
2186      self.assertEqual(b._shape_tuple(), (1,))
2187      return [a, b]
2188
2189    a = array_ops.ones([2, 1])
2190    b = array_ops.ones([1])
2191    expected = expected_foo([a, a], b)
2192    out = foo([a, a], b)
2193    self.assertLen(total_function_cache(foo), 1)
2194    nest.assert_same_structure(out, expected)
2195    self.assertAllEqual(out[0][0], a)
2196    self.assertAllEqual(out[0][1], a)
2197    self.assertAllEqual(out[1], b)
2198
2199    # Changing the unspecified dimensions shouldn't create a new function.
2200    a = array_ops.ones([2, 3])
2201    b = array_ops.ones([2, 5])
2202    c = array_ops.ones([1])
2203    expected = expected_foo([a, b], c)
2204    out = foo([a, b], c)
2205    self.assertLen(total_function_cache(foo), 1)
2206    nest.assert_same_structure(out, expected)
2207    self.assertAllEqual(out[0][0], a)
2208    self.assertAllEqual(out[0][1], b)
2209    self.assertAllEqual(out[1], c)
2210
2211    # Passing compatible inputs should work.
2212    a = a.numpy().tolist()
2213    b = b.numpy().tolist()
2214    c = c.numpy().tolist()
2215    out = foo([a, b], c)
2216    self.assertLen(total_function_cache(foo), 1)
2217    nest.assert_same_structure(out, expected)
2218    self.assertAllEqual(out[0][0], a)
2219    self.assertAllEqual(out[0][1], b)
2220    self.assertAllEqual(out[1], c)
2221
2222  def testNestedInputSignaturesWithDict(self):
2223    def expected_bar(a):
2224      return a
2225
2226    @function.defun(input_signature=[{
2227        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
2228        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
2229        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
2230    def bar(a):
2231      self.assertEqual(a['a']._shape_tuple(), (2, None))
2232      self.assertEqual(a['b']._shape_tuple(), (2, None))
2233      self.assertEqual(a['c']._shape_tuple(), (1,))
2234      return a
2235
2236    a = array_ops.ones([2, 3])
2237    b = array_ops.ones([1])
2238    inputs = {'a': a, 'b': a, 'c': b}
2239    expected = expected_bar(inputs)
2240    out = bar(inputs)
2241    nest.assert_same_structure(out, expected)
2242    self.assertAllEqual(out['a'], expected['a'])
2243    self.assertAllEqual(out['b'], expected['b'])
2244    self.assertAllEqual(out['c'], expected['c'])
2245
2246    # Passing compatible inputs should work.
2247    a = a.numpy().tolist()
2248    b = b.numpy().tolist()
2249    inputs = {'a': a, 'b': a, 'c': b}
2250    out = bar(inputs)
2251    nest.assert_same_structure(out, expected)
2252    self.assertAllEqual(out['a'], expected['a'])
2253    self.assertAllEqual(out['b'], expected['b'])
2254    self.assertAllEqual(out['c'], expected['c'])
2255
2256  def testInputSignatureMustBeSequenceOfTensorSpecs(self):
2257
2258    def foo(a, b):
2259      del a
2260      del b
2261
2262    # Signatures must consist exclusively of `TensorSpec` objects.
2263    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
2264    with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'):
2265      def_function.function(foo, input_signature=signature)
2266
2267    # Signatures must be either lists or tuples on their outermost levels.
2268    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
2269    with self.assertRaisesRegex(
2270        TypeError, 'input_signature must be either a '
2271        'tuple or a list.*'):
2272      function.defun(foo, input_signature=signature)
2273
2274  @test_util.run_in_graph_and_eager_modes
2275  def testInputsIncompatibleWithSignatureRaisesError(self):
2276
2277    def foo(a):
2278      return a
2279
2280    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
2281    defined = def_function.function(foo, input_signature=signature)
2282
2283    # Invalid shapes.
2284    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
2285      defined(array_ops.ones([3]))
2286
2287    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
2288      defined(array_ops.ones([2, 1]))
2289
2290    # Wrong number of arguments.
2291    with self.assertRaisesRegex(TypeError, 'specifies 1 .* got 2'):
2292      defined(array_ops.ones([2]), array_ops.ones([2]))
2293    with self.assertRaisesRegex(ValueError,
2294                                'Structure of Python function inputs.*'):
2295      defined()
2296
2297    with self.assertRaisesRegex(ValueError,
2298                                'inputs incompatible with input_signature'):
2299      defined.get_concrete_function(
2300          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
2301
2302  def testMismatchedConcreteSignatureRaisesError(self):
2303
2304    @def_function.function
2305    def run_test():
2306      @def_function.function
2307      def f(x):
2308        return x
2309
2310      with self.assertRaisesRegex(
2311          TypeError, 'ConcreteFunction .* was constructed .* but was called'):
2312        f.get_concrete_function(1)(constant_op.constant(1))
2313
2314      with self.assertRaisesRegex(TypeError, r'f\(x\) expected .* but got .*'):
2315        f.get_concrete_function(constant_op.constant(1))(1)
2316
2317      with self.assertRaisesRegex(
2318          TypeError, 'ConcreteFunction .* was constructed .* but was called'):
2319        f.get_concrete_function(1)(2)
2320
2321    run_test()
2322
2323  def testInputsIncompatibleWithNestedSignatureRaisesError(self):
2324
2325    def foo(a, b):
2326      return [a, b]
2327
2328    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
2329                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
2330    defined = function.defun(foo, input_signature=signature)
2331    a = array_ops.ones([1])
2332
2333    with self.assertRaisesRegex(ValueError,
2334                                'Structure of Python function inputs.*'):
2335      defined([a, a, a], [a])
2336
2337    with self.assertRaisesRegex(ValueError,
2338                                'Structure of Python function inputs.*'):
2339      defined([a], [a, a, a])
2340    defined([a, a], [a, a])
2341
2342  def testUnderspecifiedInputSignature(self):
2343    @function.defun(input_signature=[
2344        tensor_spec.TensorSpec([], dtypes.float32),
2345    ])
2346    def foo(a, training=True):
2347      if training:
2348        return a
2349      else:
2350        return -1.0 * a
2351
2352    x = constant_op.constant(1.0)
2353    with self.assertRaisesRegex(
2354        TypeError, 'got keyword argument `training` '
2355        'that was not included in input_signature'):
2356      foo(x, training=True)
2357
2358    with self.assertRaisesRegex(
2359        TypeError, 'got keyword argument `training` '
2360        'that was not included in input_signature'):
2361      foo(x, training=False)
2362
2363    self.assertAllEqual(x.numpy(), foo(x).numpy())
2364
2365  def testInputSignatureWithPartialFunction(self):
2366    def full_function(a, b, c=3.0):
2367      return a, b, c
2368
2369    partial = functools.partial(full_function, 1, c=4)
2370    a, b, c = partial(2.0)
2371    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
2372    defined = function.defun(partial, input_signature=signature)
2373    x = constant_op.constant(2.0)
2374    func_a, func_b, func_c = defined(x)
2375    self.assertEqual(func_a.numpy(), a)
2376    self.assertEqual(func_b.numpy(), b)
2377    self.assertEqual(func_c.numpy(), c)
2378
2379  def testInputSignatureConversionWithDefaultArg(self):
2380
2381    def foo(a, training=True):
2382      if training:
2383        return a
2384      else:
2385        return -1.0 * a
2386
2387    signature = [
2388        tensor_spec.TensorSpec([], dtypes.float32),
2389        tensor_spec.TensorSpec([], dtypes.bool),
2390    ]
2391    defined = def_function.function(foo, input_signature=signature)
2392    a = constant_op.constant(1.0)
2393    self.assertAllEqual(a.numpy(), defined(a))
2394    self.assertAllEqual(a.numpy(), defined(a, training=True))
2395    self.assertAllEqual(-a.numpy(), defined(a, training=False))
2396
2397  def testInputSignatureWithKeywordPositionalArgs(self):
2398
2399    @function.defun(input_signature=[
2400        tensor_spec.TensorSpec([], dtypes.float32),
2401        tensor_spec.TensorSpec([], dtypes.int64)
2402    ])
2403    def foo(flt, integer):
2404      return flt, integer
2405
2406    flt = constant_op.constant(1.0)
2407    integer = constant_op.constant(2, dtypes.int64)
2408
2409    out1, out2 = foo(flt, integer)
2410    self.assertLen(total_function_cache(foo), 1)
2411    self.assertEqual(out1.numpy(), 1.0)
2412    self.assertEqual(out2.numpy(), 2)
2413
2414    out1, out2 = foo(flt=flt, integer=integer)
2415    self.assertLen(total_function_cache(foo), 1)
2416    self.assertEqual(out1.numpy(), 1.0)
2417    self.assertEqual(out2.numpy(), 2)
2418
2419    out1, out2 = foo(integer=integer, flt=flt)
2420    self.assertLen(total_function_cache(foo), 1)
2421    self.assertEqual(out1.numpy(), 1.0)
2422    self.assertEqual(out2.numpy(), 2)
2423
2424    out1, out2 = foo(flt, integer=integer)
2425    self.assertLen(total_function_cache(foo), 1)
2426    self.assertEqual(out1.numpy(), 1.0)
2427    self.assertEqual(out2.numpy(), 2)
2428
2429  def testInputSignatureWithKeywordArgs(self):
2430    def foo(a, b, **kwargs):
2431      del kwargs
2432      return a, b
2433
2434    x = function.defun(
2435        foo,
2436        input_signature=[
2437            tensor_spec.TensorSpec([], dtypes.float32),
2438            tensor_spec.TensorSpec([], dtypes.int32)
2439        ]).get_concrete_function()
2440    result = x(constant_op.constant(5.0), constant_op.constant(5))
2441    self.assertAllEqual(result, [5.0, 5])
2442
2443  def testInputSignatureWithCompositeTensors(self):
2444    def f(rt):
2445      self.assertEqual(rt.values.shape.as_list(), [None])
2446      self.assertEqual(rt.row_splits.shape.as_list(), [4])
2447      return rt
2448
2449    signature = [ragged_tensor.RaggedTensorSpec(
2450        shape=[3, None], dtype=dtypes.int32)]
2451    defined = function.defun(f, input_signature=signature)
2452    rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
2453    out1 = defined(rt1)
2454    self.assertLen(total_function_cache(defined), 1)
2455    self.assertAllEqual(out1.values, rt1.values)
2456    self.assertAllEqual(out1.row_splits, rt1.row_splits)
2457
2458    # Changing the row lengths shouldn't create a new function.
2459    rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
2460    out2 = defined(rt2)
2461    self.assertLen(total_function_cache(defined), 1)
2462    self.assertAllEqual(out2.values, rt2.values)
2463    self.assertAllEqual(out2.row_splits, rt2.row_splits)
2464
2465    # Different number of rows
2466    rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
2467    with self.assertRaisesRegex(ValueError, 'incompatible'):
2468      defined(rt3)
2469
2470    # Different dtype
2471    rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
2472    with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
2473      defined(rt4)
2474
2475    # Different rank
2476    rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
2477    with self.assertRaisesRegex(ValueError, 'does not match'):
2478      defined(rt5)
2479
2480  def testInputSignatureWithVariableArgs(self):
2481
2482    def f(v):
2483      v.assign_add(1)
2484
2485    signature = [
2486        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
2487    ]
2488    defined = function.defun(f, input_signature=signature)
2489
2490    v1 = variables.Variable(0)
2491    v2 = variables.Variable(0)
2492
2493    defined(v1)
2494    self.assertEqual(v1.numpy(), 1)
2495    self.assertEqual(v2.numpy(), 0)
2496
2497    defined(v=v2)
2498    self.assertEqual(v1.numpy(), 1)
2499    self.assertEqual(v2.numpy(), 1)
2500
2501  def testInputSignatureWithKeywordOnlyArgs(self):
2502
2503    def f(a, b, c=3, *, d=4):
2504      self.assertIsInstance(a, ops.Tensor)
2505      self.assertIsInstance(b, ops.Tensor)
2506      self.assertIsInstance(c, int)
2507      self.assertIsInstance(d, (int, ops.Tensor))
2508      return a + b + c + d
2509
2510    signature = [
2511        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
2512        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
2513    ]
2514    defined = function.defun(f, input_signature=signature)
2515    self.assertEqual(defined(1, 2).numpy(), 10)
2516
2517    defined = function.defun(
2518        functools.partial(f, c=4), input_signature=signature)
2519    self.assertEqual(defined(1, 2).numpy(), 11)
2520
2521    defined = function.defun(
2522        functools.partial(f, d=5), input_signature=signature)
2523    self.assertEqual(defined(1, 2).numpy(), 11)
2524
2525    defined = function.defun(
2526        functools.partial(f, d=array_ops.constant(5)),
2527        input_signature=signature)
2528    self.assertEqual(defined(1, 2).numpy(), 11)
2529
2530    mod = module.Module()
2531    save(mod, '/tmp/kwonlyf', defined.get_concrete_function(*signature))
2532    loaded = load('/tmp/kwonlyf')
2533    result = loaded.signatures['serving_default'](
2534        a=array_ops.constant(1), b=array_ops.constant(2))
2535    self.assertEqual(result['output_0'].numpy(), 11)
2536
2537  def testInputSignatureWithKeywordOnlyArgsNoDefaults(self):
2538    signature = [
2539        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
2540        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32),
2541    ]
2542
2543    def test_func(a, *, b):
2544      return a + b
2545
2546    with self.assertRaisesRegex(
2547        ValueError, "keyword-only arguments must have default values.*'b'"):
2548      function.defun(test_func, input_signature=signature)
2549
2550    test_func_lambda = lambda a, *, b: a + b
2551    with self.assertRaisesRegex(
2552        ValueError, "keyword-only arguments must have default values.*'b'"):
2553      function.defun(test_func_lambda, input_signature=signature)
2554
2555  def testTensorKeywordArguments(self):
2556
2557    def foo(a, b):
2558      del a
2559      return b
2560
2561    defined = function.defun(foo)
2562    a = constant_op.constant(2.0)
2563    b = constant_op.constant([1.0, 2.0])
2564    one = defined(a, b)
2565    self.assertLen(total_function_cache(defined), 1)
2566
2567    two = defined(a=a, b=b)
2568    self.assertLen(total_function_cache(defined), 1)
2569
2570    three = defined(b=b, a=a)
2571    self.assertLen(total_function_cache(defined), 1)
2572
2573    four = defined(a, b=b)
2574    self.assertLen(total_function_cache(defined), 1)
2575
2576    # The next call corresponds to a new input signature, hence
2577    # we expect another function to be defined.
2578    five = defined(b, a)
2579    self.assertLen(total_function_cache(defined), 2)
2580
2581    six = defined(a=b, b=a)
2582    self.assertLen(total_function_cache(defined), 2)
2583
2584    seven = defined(b=a, a=b)
2585    self.assertLen(total_function_cache(defined), 2)
2586
2587    self.assertAllEqual(one, [1.0, 2.0])
2588    self.assertAllEqual(two, [1.0, 2.0])
2589    self.assertAllEqual(three, [1.0, 2.0])
2590    self.assertAllEqual(four, [1.0, 2.0])
2591    self.assertAllEqual(five, 2.0)
2592    self.assertAllEqual(six, 2.0)
2593    self.assertAllEqual(seven, 2.0)
2594
2595  def testDefuningInstanceMethod(self):
2596
2597    integer = constant_op.constant(2, dtypes.int64)
2598
2599    class Foo:
2600
2601      def one(self, tensor):
2602        return tensor
2603
2604      @def_function.function
2605      def two(self, tensor, other=integer):
2606        return self.one(tensor), other
2607
2608    foo = Foo()
2609    t = constant_op.constant(1.0)
2610    one, two = foo.two(t)
2611    self.assertEqual(one.numpy(), 1.0)
2612    self.assertEqual(two.numpy(), 2)
2613
2614  def testDefuningInstanceMethodWithDefaultArgument(self):
2615
2616    integer = constant_op.constant(2, dtypes.int64)
2617
2618    class Foo:
2619
2620      @def_function.function
2621      def func(self, other=integer):
2622        return other
2623
2624    foo = Foo()
2625    self.assertEqual(foo.func().numpy(), int(integer))
2626
2627  def testPythonCallWithSideEffects(self):
2628    state = []
2629
2630    @def_function.function
2631    def side_effecting_function():
2632      state.append(0)
2633
2634    side_effecting_function()
2635    self.assertAllEqual(state, [0])
2636
2637    # The second invocation should call the graph function, which shouldn't
2638    # trigger the list append.
2639    side_effecting_function()
2640    self.assertAllEqual(state, [0])
2641
    # Calling the Python function directly, however, should create a side
    # effect.
2643    side_effecting_function.python_function()
2644    self.assertAllEqual(state, [0, 0])
2645
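  # A minimal sketch of the same behavior from the tracing side (assumption:
  # one retrace per new input signature): the Python body only runs while
  # tracing, so a Python-level counter advances once per trace, not per call.
  def testPythonSideEffectsHappenOncePerTrace(self):
    trace_count = [0]

    @def_function.function
    def f(x):
      trace_count[0] += 1
      return x + 1.0

    f(constant_op.constant(1.0))
    f(constant_op.constant(2.0))  # Same signature; the trace is reused.
    self.assertEqual(trace_count[0], 1)

    f(constant_op.constant([1.0]))  # New shape; a retrace occurs.
    self.assertEqual(trace_count[0], 2)
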
2646  def testFunctionWithNestedFunctionCallAndSideEffects(self):
2647    v1 = variables.Variable(1.0)
2648    v2 = variables.Variable(1.0)
2649
2650    @def_function.function
2651    def add_one(a):
2652      a.assign_add(1.0)
2653
    # Grappler will inline calls to `add_one` into the function body; we check
    # that all side effects were executed.
2656    @def_function.function
2657    def side_effecting_function(a, b):
2658      add_one(a)
2659      add_one(b)
2660      return a + b
2661
2662    result = side_effecting_function(v1, v2)
2663    self.assertEqual(result.numpy(), 4.0)
2664
2665  def testFunctionWithExtraAttributes(self):
2666    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
2667                                                'experimental_2': 2})
2668    def matmul(x, y):
2669      return math_ops.matmul(x, y)
2670
2671    def add(x, y):
2672      return math_ops.add(x, y)
2673    defun_add = function.defun_with_attributes(
2674        add, attributes={'experimental_3': True, 'experimental_4': 1.0})
2675
2676    with context.graph_mode(), self.cached_session():
2677      with ops.get_default_graph().as_default():
2678        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2679        sq = matmul(t, t)
2680        double = defun_add(t, t)
2681        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
2682        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
2683
2684        graph = ops.get_default_graph()
2685        # pylint: disable=protected-access
2686        self.assertLen(graph._functions, 2)
2687        functions = list(graph._functions.values())
2688        self.assertRegex(functions[0].definition.signature.name, '.*matmul.*')
2689        attrs = functions[0].definition.attr
2690        self.assertLen(attrs, 2)
2691        self.assertEqual(attrs['experimental_1'].s, b'value1')
2692        self.assertEqual(attrs['experimental_2'].i, 2)
2693
2694        self.assertRegex(functions[1].definition.signature.name, '.*add.*')
2695        attrs = functions[1].definition.attr
2696        self.assertLen(attrs, 2)
2697        self.assertEqual(attrs['experimental_3'].b, True)
2698        self.assertEqual(attrs['experimental_4'].f, 1.0)
2699        # pylint: enable=protected-access
2700
2701  def testFunctionWithInvalidAttribute(self):
2702    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
2703    def add(x, y):
2704      return math_ops.add(x, y)
2705
2706    with self.assertRaisesRegex(ValueError,
2707                                'Attribute experimental_1 must be .* Got .*'):
2708      with context.graph_mode(), self.cached_session():
2709        with ops.get_default_graph().as_default():
2710          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2711          add(t, t)
2712
2713  def testRegisterFunction(self):
2714
2715    @function.defun
2716    def add(x, y):
2717      return math_ops.add(x, y)
2718
2719    def matmul(x, y):
2720      return math_ops.matmul(x, y)
2721    defun_matmul = function.defun(matmul)
2722
2723    with context.graph_mode(), self.cached_session():
2724      with ops.get_default_graph().as_default():
2725        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2726        function.register(defun_matmul, t, t)
2727        function.register(add, t, t)
2728
2729        graph = ops.get_default_graph()
2730        # pylint: disable=protected-access
2731        self.assertLen(graph._functions, 6)
        # Two sets of functions, each of which is (inference, forward,
        # backward).
2733        functions = list(graph._functions.values())
2734        captured_function_names = [
2735            f.definition.signature.name for f in functions
2736        ]
2737        expected_func_name_regex = [
2738            '.*inference.*matmul.*',
2739            '.*forward.*matmul.*',
2740            '.*inference.*backward.*matmul.*',
2741            '.*inference.*add.*',
2742            '.*forward.*add.*',
2743            '.*inference.*backward.*add.*',
2744        ]
2745        for i in range(len(functions)):
2746          self.assertRegex(captured_function_names[i],
2747                           expected_func_name_regex[i])
2748
2749        # Check the forward and backward function has the correct attributes.
2750        self.assertEqual(
2751            functions[1].definition.attr['backward_function_name'].s,
2752            functions[2].name)
2753        self.assertEqual(
2754            functions[2].definition.attr['forward_function_name'].s,
2755            functions[1].name)
2756
2757        self.assertEqual(
2758            functions[4].definition.attr['backward_function_name'].s,
2759            functions[5].name)
2760        self.assertEqual(
2761            functions[5].definition.attr['forward_function_name'].s,
2762            functions[4].name)
2763
2764        sq = defun_matmul(t, t)
2765        double = add(t, t)
2766        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
2767        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
2768        # Make sure the pre registered function is used, and no other function
2769        # is added.
2770        self.assertLen(graph._functions, 6)
2771        functions = list(graph._functions.values())
2772        for i in range(len(functions)):
2773          self.assertEqual(captured_function_names[i],
2774                           functions[i].definition.signature.name)
2775
2776  @parameterized.named_parameters(
2777      dict(testcase_name='Defun',
2778           function_decorator=function.defun),
2779      dict(testcase_name='DefFunction',
2780           function_decorator=def_function.function))
2781  def testRegisterConcreteFunction(self, function_decorator):
2782    @function_decorator
2783    def py_add(x, y):
2784      return math_ops.add(x, y)
2785
2786    py_add(array_ops.ones([]), array_ops.ones([]))
2787    add = py_add.get_concrete_function(
2788        tensor_spec.TensorSpec(None, dtypes.float32),
2789        tensor_spec.TensorSpec(None, dtypes.float32))
2790
2791    @function_decorator
2792    def py_composite(x, y):
2793      return x, add(x, y)
2794
2795    py_composite(array_ops.ones([]), array_ops.ones([]))
2796    composite = py_composite.get_concrete_function(
2797        tensor_spec.TensorSpec(None, dtypes.float32),
2798        tensor_spec.TensorSpec(None, dtypes.float32))
2799
2800    with context.graph_mode(), self.cached_session():
2801      with ops.get_default_graph().as_default():
2802        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2803        composite.add_to_graph()
2804        composite.add_gradient_functions_to_graph()
2805
2806        graph = ops.get_default_graph()
2807        # pylint: disable=protected-access
2808        self.assertLen(graph._functions, 6)
        # Two sets of functions, each of which is (inference, forward,
        # backward).
2810        functions = list(graph._functions.values())
2811        captured_function_names = [
2812            f.definition.signature.name for f in functions
2813        ]
2814        expected_func_name_regex = [
2815            '.*inference.*py_composite.*',
2816            '.*inference.*py_add.*',
2817            '.*forward.*py_composite.*',
2818            '.*forward.*py_add.*',
2819            '.*inference.*backward.*py_composite.*',
2820            '.*inference.*backward.*py_add.*',
2821        ]
2822        for expected, found in zip(
2823            expected_func_name_regex,
2824            captured_function_names):
2825          self.assertRegex(found, expected)
2826
2827        composite_t, composite_double = composite(t, t)
2828        double = add(t, t)
2829        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
2830        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
2831        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
2832        # Make sure the pre registered function is used, and no other function
2833        # is added.
2834        self.assertLen(graph._functions, 6)
2835
2836  @parameterized.named_parameters(
2837      dict(testcase_name='Defun',
2838           function_decorator=function.defun),
2839      dict(testcase_name='DefFunction',
2840           function_decorator=def_function.function))
2841  def testEagerCaptures(self, function_decorator):
2842    with context.eager_mode():
2843      large_tensor = array_ops.ones(shape=(256,))
2844      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
2845
2846      small_tensor = array_ops.ones(shape=(4,))
2847      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
2848
2849      v = resource_variable_ops.ResourceVariable(0.0)
2850
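    # Checked below: tensors at or under _EAGER_CONST_THRESHOLD elements are
    # captured by value as Const nodes, while larger tensors and variables
    # are captured through placeholders.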
2851    for captured, op_type in [(large_tensor, 'Placeholder'),
2852                              (small_tensor, 'Const'), (v, 'Placeholder')]:
2853      @function_decorator
2854      def test_fn():
2855        return captured + 1  # pylint: disable=cell-var-from-loop
2856
2857      g = test_fn.get_concrete_function().graph
2858      internal_captures = g.internal_captures
2859      self.assertLen(internal_captures, 1)
2860      self.assertEqual(internal_captures[0].op.type, op_type)
2861
2862  def testRegisterFunctionWithInputSignature(self):
2863    def matmul(x, y):
2864      return math_ops.matmul(x, y)
2865    defun_matmul = function.defun(
2866        matmul,
2867        input_signature=[
2868            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
2869            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
2870        ])
2871    with context.graph_mode(), self.cached_session():
2872      with ops.get_default_graph().as_default():
2873        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2874        function.register(defun_matmul, t, t)
2875
2876        graph = ops.get_default_graph()
2877        # pylint: disable=protected-access
2878        self.assertLen(graph._functions, 3)
2879
        # Test registering an already-cached function; the inputs are ignored.
2881        function.register(defun_matmul)
2882        graph = ops.get_default_graph()
2883        self.assertLen(graph._functions, 3)
2884
2885  def testRegisterFunctionWithCache(self):
2886    def matmul(x, y):
2887      return math_ops.matmul(x, y)
2888    defun_matmul = function.defun(matmul)
2889
2890    with context.graph_mode(), self.cached_session():
2891      with ops.get_default_graph().as_default():
2892        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2893        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
2894        function.register(defun_matmul, t, t)
2895        function.register(defun_matmul, t2, t2)
2896
2897        graph = ops.get_default_graph()
        # Only one function is registered since the inputs have the same type.
2899        # pylint: disable=protected-access
2900        self.assertLen(graph._functions, 3)
2901
2902  def testCallingFunctionWithDifferentVariables(self):
2903
2904    @function.defun
2905    def foo(v):
2906      v.assign_add(1.0)
2907      return v.read_value()
2908
2909    v = resource_variable_ops.ResourceVariable(0.0)
2910    graph_function = foo.get_concrete_function(v)
2911    self.assertLen(graph_function.inputs, 1)
2912    self.assertEmpty(graph_function.captured_inputs)
2913
2914    self.assertEqual(float(graph_function(v)), 1.0)
2915    self.assertEqual(float(graph_function(v)), 2.0)
2916
2917    w = resource_variable_ops.ResourceVariable(0.0)
2918
2919    @function.defun
2920    def bar(v):
2921      del v
2922      return constant_op.constant(1.0)
2923
2924    graph_function = bar.get_concrete_function(v)
2925    self.assertEqual(float(graph_function(v)), 1.0)
2926    self.assertEqual(float(graph_function(w)), 1.0)
2927
2928  def testCallingFunctionWithNonTensorsFails(self):
2929
2930    @function.defun
2931    def foo(x):
2932      return x
2933
2934    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
2935    with self.assertRaises((TypeError, ValueError)):
2936      graph_function('Not a Tensor.')
2937
2938  def testSwapImplementationWithGrapplerPlugin(self):
    # Set min_graph_nodes to -1 since the graph in this test is too small and
    # would otherwise be ignored by grappler.
2941    rewrites = rewriter_config_pb2.RewriterConfig()
2942    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
2943    rewrites.min_graph_nodes = -1
2944    graph_options = config_pb2.GraphOptions(
2945        rewrite_options=rewrites, build_cost_model=1)
2946    config_proto = config_pb2.ConfigProto(graph_options=graph_options)
2947
2948    with context.graph_mode(), self.cached_session(
2949        config=config_proto, graph=ops.Graph(), use_gpu=True):
2950
2951      @function.defun_with_attributes(
2952          attributes={
2953              'api_implements': 'random_boost',
2954              'api_preferred_device': 'CPU'
2955          })
2956      def cpu_boost(x):
2957        return math_ops.add(x, 2.0)
2958
2959      @function.defun_with_attributes(
2960          attributes={
2961              'api_implements': 'random_boost',
2962              'api_preferred_device': 'GPU'
2963          })
2964      def gpu_boost(x):
2965        return math_ops.add(x, 4.0)
2966
2967      x = constant_op.constant(1.0)
2968
2969      function.register(cpu_boost, x)
2970      y = gpu_boost(x)
2971      y_value = self.evaluate(y)
2972
2973      if test.is_gpu_available():
2974        self.assertEqual(y_value, 5.0)
2975      else:
        # Grappler falls back to the CPU impl even though the GPU function
        # was called.
2977        self.assertEqual(y_value, 3.0)
2978
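  # Like the test above, the next test relies on the 'api_implements'
  # contract: functions sharing that attribute are interchangeable, and the
  # implementation_selector rewrite picks the variant whose
  # 'api_preferred_device' matches an available device.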
2979  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
2980                          'equivalent to implementation_selector for function')
2981  def testSwapImplementationInEager(self):
2982    if not context.executing_eagerly():
2983      self.skipTest('eager only')
2984
    # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
    # that subtest runs before this one, the flag remains set and causes this
    # subtest to fail. To avoid that scenario, explicitly set the
    # disable_meta_optimizer flag to False here.
2989    context.context().set_optimizer_experimental_options({
2990        'min_graph_nodes': -1,
2991        'implementation_selector': True,
2992        'disable_meta_optimizer': False
2993    })
2994
2995    @function.defun_with_attributes(
2996        attributes={'api_implements': 'foo',
2997                    'api_preferred_device': 'CPU'})
2998    def on_cpu(x):
2999      return x + 2
3000
3001    @function.defun_with_attributes(
3002        attributes={'api_implements': 'foo',
3003                    'api_preferred_device': 'GPU'})
3004    def on_gpu(x):
3005      return x + 4
3006
3007    @function.defun
3008    def run_on_cpu(t):
3009      function.register(on_cpu, t)
3010      with ops.device('CPU:0'):
3011        return on_gpu(t)
3012
3013    # Expect to run the on_cpu branch, regardless of whether a GPU is available.
3014    self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)
3015
3016  def testDefunFunctionSeparateGraphs(self):
3017    with context.graph_mode():
3018
3019      @function.defun
3020      def add(x):
3021        return x + 5
3022
3023      @function.defun
3024      def maybe_add(x, should_add):
3025        if should_add:
3026          return add(x)
3027        else:
3028          return x
3029
3030      with ops.Graph().as_default():
3031        x = constant_op.constant(11)
3032        maybe_add(x, True)
3033        self.assertLen(total_function_cache(maybe_add), 1)
3034        self.assertLen(total_function_cache(add), 1)
3035
3036        maybe_add(x, False)
3037        self.assertLen(total_function_cache(maybe_add), 2)
3038        self.assertLen(total_function_cache(add), 1)
3039
3040      with ops.Graph().as_default():
3041        x = constant_op.constant(11)
3042        maybe_add(x, True)
3043        self.assertLen(total_function_cache(maybe_add), 3)
3044        self.assertLen(total_function_cache(add), 2)
3045
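  # NOTE: As the test above shows, defun caches are keyed per ops.Graph: the
  # same Python callable traced under a fresh graph produces new concrete
  # functions instead of reusing traces from another graph. A sketch of the
  # counting idiom used here (illustrative, not an additional assertion):
  #
  #   with ops.Graph().as_default():
  #     maybe_add(constant_op.constant(11), True)
  #     # total_function_cache(maybe_add) grows by one entry for this graph.
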
3046  def testCacheKeyOverlappingShapes(self):
3047    @function.defun
3048    def defined(t):
3049      return t
3050
3051    defined(array_ops.zeros([12, 1]))
3052    self.assertLen(total_function_cache(defined), 1)
3053    defined(array_ops.zeros([1, 21]))
3054    self.assertLen(total_function_cache(defined), 2)
3055
3056    @function.defun
3057    def defined_again(t):
3058      return defined(t)
3059
3060    defined_again.get_concrete_function(array_ops.zeros([12, 1]))
3061    self.assertLen(total_function_cache(defined_again), 1)
3062    defined_again.get_concrete_function(array_ops.zeros([1, 21]))
3063    self.assertLen(total_function_cache(defined_again), 2)
3064
3065  def testCacheTensorSpecIdenticalToTensor(self):
3066    @function.defun
3067    def defined(t):
3068      return t
3069
3070    z = array_ops.zeros([2, 2])
3071    z_spec = tensor_spec.TensorSpec.from_tensor(z)
3072    self.assertIs(
3073        defined.get_concrete_function(z_spec), defined.get_concrete_function(z))
3074
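  # NOTE: A TensorSpec carrying the same shape and dtype as a Tensor maps to
  # the same cache entry, so a signature can be described without materializing
  # data. A minimal sketch, building the spec by hand instead of from_tensor:
  #
  #   spec = tensor_spec.TensorSpec(shape=[2, 2], dtype=dtypes.float32)
  #   assert defined.get_concrete_function(spec) is (
  #       defined.get_concrete_function(array_ops.zeros([2, 2])))
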
3075  def testCacheKeyNestedLists(self):
3076    @function.defun
3077    def defined(l):
3078      return l
3079
3080    a = constant_op.constant(1.)
3081    b = constant_op.constant(2.)
3082    c = constant_op.constant(3.)
3083    defined([[a], b, c])
3084    self.assertLen(total_function_cache(defined), 1)
3085
3086    defined([[a, b], c])
3087    self.assertLen(total_function_cache(defined), 2)
3088
3089  def testCacheKeyAttrsClass(self):
3090    if attr is None:
3091      self.skipTest('attr module is unavailable.')
3092
3093    @attr.s
3094    class TestClass:
3095      a = attr.ib()
3096      b = attr.ib()
3097
3098    @function.defun
3099    def defined(l):
3100      return l
3101
3102    defined(
3103        TestClass(
3104            constant_op.constant(1.),
3105            [constant_op.constant(2.),
3106             constant_op.constant(3.)]))
3107    self.assertLen(total_function_cache(defined), 1)
3108    defined(
3109        TestClass(
3110            constant_op.constant(1.),
3111            [constant_op.constant(2.),
3112             constant_op.constant(3.)]))
3113    self.assertLen(total_function_cache(defined), 1)
3114
3115    defined(
3116        TestClass([constant_op.constant(1.),
3117                   constant_op.constant(2.)], constant_op.constant(3.)))
3118    self.assertLen(total_function_cache(defined), 2)
3119
3120  def testDistinctVariablesNoRetracing(self):
3121    @function.defun
3122    def defined(a, b, c):
3123      return a + b + c
3124
3125    x = resource_variable_ops.ResourceVariable(0.0)
3126    y = resource_variable_ops.ResourceVariable(0.0)
3127    z = resource_variable_ops.ResourceVariable(0.0)
3128
3129    # We generate cache keys based on unique combinations of resource ids.
3130    defined(x, y, z)
3131    self.assertLen(total_function_cache(defined), 1)
3132
3133    # Re-arranging the arguments should not cause a cache miss,
3134    # because the three inputs are still distinct.
3135    defined(z, y, x)
3136    self.assertLen(total_function_cache(defined), 1)
3137
3138  def testRetracingOnDifferentVariableCombinationPatterns(self):
3139    @function.defun
3140    def defined(a, b, c):
3141      return a + b + c
3142
3143    x = resource_variable_ops.ResourceVariable(0.0)
3144    y = resource_variable_ops.ResourceVariable(0.0)
3145    z = resource_variable_ops.ResourceVariable(0.0)
3146
3147    defined(x, y, z)
3148    self.assertLen(total_function_cache(defined), 1)
3149
3150    # Retracing because the first two arguments are the same
3151    defined(x, x, z)
3152    self.assertLen(total_function_cache(defined), 2)
3153
3154    # Replacing x with y does not cause a cache miss, because the
3155    # combination pattern stays the same as in (x, x, z).
3156    defined(y, y, z)
3157    self.assertLen(total_function_cache(defined), 2)
3158
3159    # A different combination pattern causes a cache miss.
3160    defined(z, y, y)
3161    self.assertLen(total_function_cache(defined), 3)
3162    defined(z, y, y)
3163    self.assertLen(total_function_cache(defined), 3)
3164
3165  def testDeepcopyVariableNoRetracing(self):
3166    @function.defun
3167    def defined(a, b, c):
3168      return a + b + c
3169
3170    x = resource_variable_ops.ResourceVariable(0.0)
3171    y = resource_variable_ops.ResourceVariable(0.0)
3172    z = resource_variable_ops.ResourceVariable(0.0)
3173    defined(x, y, z)
3174    self.assertLen(total_function_cache(defined), 1)
3175
3176    x_copy = copy.deepcopy(x)
3177    defined(x_copy, y, z)
3178    self.assertLen(total_function_cache(defined), 1)
3179
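  # NOTE: Taken together, the three tests above suggest that variables enter
  # the cache key by their aliasing pattern plus dtype/shape, not by object
  # identity: distinct variables in any order hit the same trace, a repeated
  # variable in (x, x, z) forces a new one, and a deepcopy (fresh resource,
  # same dtype/shape, same pattern) reuses the existing trace. Sketch, not
  # asserted here:
  #
  #   w = resource_variable_ops.ResourceVariable(1.0)
  #   defined(w, y, z)  # same pattern as (x, y, z) -> no retrace
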
3180  def _total_function_cache_def_func(self, defined):
3181    return defined._list_all_concrete_functions()  # pylint: disable=protected-access
3182
3183  def testVariableRetracingOnDtypeChanges(self):
3184
3185    @def_function.function
3186    def defined(a, b):
3187      return a + b
3188
3189    x1 = resource_variable_ops.ResourceVariable(0.0)
3190    x2 = resource_variable_ops.ResourceVariable(0.0)
3191
3192    defined(x1, x2)
3193    self.assertLen(self._total_function_cache_def_func(defined), 1)
3194
3195    # Expect retracing for the new dtypes.
3196    y1 = resource_variable_ops.ResourceVariable(0)
3197    y2 = resource_variable_ops.ResourceVariable(1)
3198    defined(y1, y2)
3199    self.assertLen(self._total_function_cache_def_func(defined), 2)
3200
3201  def testVariableRetracingDtypeShape(self):
3202
3203    @def_function.function
3204    def defined(a, b):
3205      return a + b
3206
3207    x1 = resource_variable_ops.ResourceVariable(0.0)
3208    x2 = resource_variable_ops.ResourceVariable(0.0)
3209
3210    defined(x1, x2)
3211    self.assertLen(self._total_function_cache_def_func(defined), 1)
3212
3213    y1 = resource_variable_ops.ResourceVariable([0.0, 1.0])
3214    y2 = resource_variable_ops.ResourceVariable([0.0, 1.0])
3215
3216    defined(y1, y2)
3217    self.assertLen(self._total_function_cache_def_func(defined), 2)
3218
3219    z1 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
3220    z2 = resource_variable_ops.ResourceVariable([[0.0, 1.0]])
3221    defined(z1, z2)
3222    self.assertLen(self._total_function_cache_def_func(defined), 3)
3223
3224  def testDecoratedMethodInspect(self):
3225
3226    class DefunnedMiniModel:
3227
3228      @function.defun
3229      def call(self, inputs, training=True):
3230        pass
3231
3232    m = DefunnedMiniModel()
3233    fullargspec = tf_inspect.getfullargspec(m.call)
3234    self.assertIn('training', fullargspec.args)
3235
3236  def testFunctionModifiesInputList(self):
3237    # Tests `list` methods that do in-place modification, except `list.sort`,
3238    # since it cannot even be "defunned" in the first place.
3239
3240    def get_list():
3241      return [constant_op.constant(0.), constant_op.constant(1.)]
3242
3243    expected_msg = '.*() should not modify'
3244
3245    with self.assertRaisesRegex(ValueError, expected_msg):
3246
3247      @def_function.function
3248      def append(l):
3249        l.append(constant_op.constant(0.))
3250
3251      append(get_list())
3252
3253    with self.assertRaisesRegex(ValueError, expected_msg):
3254
3255      @def_function.function
3256      def extend(l):
3257        l.extend([constant_op.constant(0.)])
3258
3259      extend(get_list())
3260
3261    with self.assertRaisesRegex(ValueError, expected_msg):
3262
3263      @def_function.function
3264      def insert(l):
3265        l.insert(0, constant_op.constant(0.))
3266
3267      insert(get_list())
3268
3269    with self.assertRaisesRegex(ValueError, expected_msg):
3270
3271      @def_function.function
3272      def pop(l):
3273        l.pop()
3274
3275      pop(get_list())
3276
3277    with self.assertRaisesRegex(ValueError, expected_msg):
3278
3279      @def_function.function
3280      def reverse(l):
3281        l.reverse()
3282
3283      reverse(get_list())
3284
3285    with self.assertRaisesRegex(ValueError, expected_msg):
3286
3287      @def_function.function
3288      def remove(l):
3289        l.remove(l[0])
3290
3291      remove(get_list())
3292
3293    # `list.clear` is a method that exists in Py3 but not in Py2.
3294    if sys.version_info.major == 3:
3295
3296      with self.assertRaisesRegex(ValueError, expected_msg):
3297
3298        @def_function.function
3299        def clear(l):
3300          l.clear()
3301
3302        clear(get_list())
3303
3304    # One last test for keyword arguments
3305    with self.assertRaisesRegex(ValueError, expected_msg):
3306
3307      @def_function.function
3308      def kwdappend(**kwargs):
3309        l = kwargs['l']
3310        l.append(constant_op.constant(0.))
3311
3312      kwdappend(l=get_list())
3313
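  # NOTE: The errors above enforce that traced functions must not mutate their
  # Python-structured inputs, since the mutation would not be replayed on
  # later calls of the cached trace. A minimal sketch of the functional
  # alternative (illustrative only):
  #
  #   @def_function.function
  #   def appended(l):
  #     return l + [constant_op.constant(0.)]  # build a new list, don't mutate
  #
  #   appended(get_list())  # traces and runs without the ValueError
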
3314  def testFunctionModifiesInputDict(self):
3315
3316    def get_dict():
3317      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
3318
3319    expected_msg = '.* should not modify'
3320
3321    with self.assertRaisesRegex(ValueError, expected_msg):
3322
3323      @def_function.function
3324      def clear(m):
3325        m.clear()
3326
3327      clear(get_dict())
3328
3329    with self.assertRaisesRegex(ValueError, expected_msg):
3330
3331      @def_function.function
3332      def pop(m):
3333        m.pop('t1')
3334
3335      pop(get_dict())
3336
3337    with self.assertRaisesRegex(ValueError, expected_msg):
3338
3339      @def_function.function
3340      def popitem(m):
3341        m.popitem()
3342
3343      popitem(get_dict())
3344
3345    with self.assertRaisesRegex(ValueError, expected_msg):
3346
3347      @def_function.function
3348      def update(m):
3349        m.update({'t1': constant_op.constant(3.)})
3350
3351      update(get_dict())
3352
3353    with self.assertRaisesRegex(ValueError, expected_msg):
3354
3355      @def_function.function
3356      def setdefault(m):
3357        m.setdefault('t3', constant_op.constant(3.))
3358
3359      setdefault(get_dict())
3360
3361  def testFunctionModifiesInputNest(self):
3362    with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):
3363
3364      @def_function.function
3365      def modify(n):
3366        n[0]['t1'].append(constant_op.constant(1.))
3367
3368      nested_input = [{
3369          't1': [constant_op.constant(0.),
3370                 constant_op.constant(1.)],
3371      },
3372                      constant_op.constant(2.)]
3373
3374      modify(nested_input)
3375
3376    with self.assertRaisesRegex(ValueError,
3377                                'modify_same_flat.* should not modify'):
3378
3379      # The flat list doesn't change whereas the true structure changes
3380      @def_function.function
3381      def modify_same_flat(n):
3382        n[0].append(n[1].pop(0))
3383
3384      nested_input = [[constant_op.constant(0.)],
3385                      [constant_op.constant(1.),
3386                       constant_op.constant(2.)]]
3387
3388      modify_same_flat(nested_input)
3389
3390  @test_util.disable_tfrt('b/173429686')
3391  def testExecutorType(self):
3392    @function.defun
3393    def add_five(x):
3394      return x + 5
3395
3396    self.assertEqual(
3397        5,
3398        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
3399
3400    with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
3401      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
3402        add_five(constant_op.constant(0, dtype=dtypes.int32))
3403
3404    for executor_type in ('', 'DEFAULT', None):
3405      with context.function_executor_type(executor_type):
3406        self.assertAllEqual(
3407            5,
3408            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
3409
3410  @test_util.assert_no_garbage_created
3411  def testReferenceCycles(self):
3412
3413    fn = function.defun(lambda x: 2. * x)
3414
3415    fn(constant_op.constant(4.0))
3416    weak_fn = weakref.ref(fn)
3417    del fn
3418    # Tests that the weak reference we made to the function is now dead, which
3419    # means the object has been deleted. This should be true as long as the
3420    # function itself is not involved in a reference cycle.
3421    self.assertIs(None, weak_fn())
3422
3423  def testFunctionStackInErrorMessage(self):
3424    if context.executing_eagerly():
3425      # TODO(b/122736651): Remove this skipTest once fixed.
3426      self.skipTest('Error interpolation is not working when function is '
3427                    'invoked without PartitionedCallOp.')
3428
3429    @def_function.function()
3430    def fn3(x):
3431      return x + 2
3432
3433    @def_function.function()
3434    def fn2(x):
3435      check_ops.assert_equal(fn3(x), 3)
3436      return 2
3437
3438    @def_function.function()
3439    def fn(x):
3440      return fn2(x)
3441
3442    with self.assertRaises(errors.InvalidArgumentError) as cm:
3443      fn(2)
3444    e = cm.exception
3445    self.assertIn('fn -> fn2', e.message)
3446    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
3447    self.assertNotIn('fn3', e.message)
3448
3449  @test_util.run_gpu_only
3450  def testFunctionIsNotPinned(self):
3451    """Tests that functions aren't pinned to the CPU by the eager runtime."""
3452    seed1, seed2 = 79, 25
3453    shape = constant_op.constant([4, 7])
3454    dtype = dtypes.float32
3455
3456    @def_function.function
3457    def func():
3458      with ops.device('GPU:0'):
3459        return gen_random_ops.random_standard_normal(
3460            shape, dtype=dtype, seed=seed1, seed2=seed2)
3461
3462    with ops.device('GPU:0'):
3463      x = func()
3464      self.assertRegex(x.device, 'GPU')
3465
3466  @test_util.run_in_graph_and_eager_modes
3467  def testShapeCaching(self):
3468
3469    @function.defun
3470    def func(x):
3471      return array_ops.shape(x)
3472
3473    @function.defun(
3474        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
3475    def calls_func(x):
3476      return func(x)
3477
3478    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
3479    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
3480    self.assertAllEqual(
3481        [3, 3],
3482        self.evaluate(calls_func(array_ops.zeros([3, 3]))))
3483
3484  def testLimitedRetracing(self):
3485    trace_count = [0]
3486    @function.defun
3487    def func(x):
3488      trace_count[0] += 1
3489      return x
3490
3491    for _ in range(50):
3492      func(constant_op.constant(3.))
3493      func(constant_op.constant(4.))
3494      func(constant_op.constant([[1., 2.]]))
3495      func(constant_op.constant([[]]))
3496      func(constant_op.constant([[3., 4.], [5., 6.]]))
3497      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
3498    # Tracing more than twice per input doesn't make sense.
3499    self.assertLess(trace_count[0], 13)
3500
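  # NOTE: The bound above relies on the runtime generalizing ("relaxing")
  # tensor shapes after repeated retraces, so later calls with new shapes can
  # reuse a trace with unknown dimensions instead of tracing again. The exact
  # relaxation schedule is an implementation detail; the test only bounds the
  # total trace count.
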
3501  def testLimitedRetracingWithCompositeTensors(self):
3502    trace_count = [0]
3503
3504    @def_function.function
3505    def f(x):
3506      trace_count[0] += 1
3507      return x
3508
3509    for i in range(10):
3510      f(ragged_factory_ops.constant([[1, 2], [i]]))
3511      f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
3512      f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
3513      self.assertEqual(trace_count[0], 3)
3514
3515  def test_concrete_function_shape_mismatch(self):
3516
3517    @def_function.function
3518    def f(argument_name):
3519      return argument_name + 1.
3520
3521    f_concrete = f.get_concrete_function(constant_op.constant([1.]))
3522
3523    # Calling a concrete function from eager doesn't do any shape checking
3524    # beyond what the kernels do while executing.
3525    self.assertAllEqual(
3526        [2., 3.],
3527        f_concrete(constant_op.constant([1., 2.])).numpy())
3528
3529    @def_function.function
3530    def g():
3531      f_concrete(constant_op.constant([1., 2.]))
3532
3533    with self.assertRaisesRegex(ValueError, 'is not compatible with the shape'):
3534      g()
3535
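  # NOTE: The asymmetry above is deliberate: an eager call into a concrete
  # function defers shape checking to the kernels, while capturing the same
  # call inside another tf.function validates arguments against the traced
  # signature and raises before anything executes.
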
3536  @test_util.run_in_graph_and_eager_modes
3537  def test_shape_inference_with_symbolic_shapes(self):
3538
3539    @def_function.function
3540    def _uses_symbolic_shapes(w, x, y):
3541      x = array_ops.identity(x, name='name_collision')
3542      x = array_ops.transpose(x, [1, 0, 2])
3543      x_batch = array_ops.shape(x)[0]
3544      y_batch = array_ops.shape(y)[0]
3545      y *= w
3546      n = y_batch // x_batch
3547      return array_ops.reshape(y, [n, x_batch, -1])
3548
3549    conc = _uses_symbolic_shapes.get_concrete_function(
3550        tensor_spec.TensorSpec(None, dtypes.float32),
3551        tensor_spec.TensorSpec(None, dtypes.float32),
3552        tensor_spec.TensorSpec(None, dtypes.float32))
3553
3554    @def_function.function
3555    def _call_concrete():
3556      c = constant_op.constant(1.)
3557      array_ops.identity(c, name='name_collision')
3558      output1 = conc(array_ops.ones([2]),
3559                     array_ops.ones([5, 4, 2]),
3560                     array_ops.ones([20, 2]))
3561      self.assertEqual([5, 4, 2], output1.shape)
3562      output2 = conc(array_ops.ones([3]),
3563                     array_ops.ones([5, 4, 3]),
3564                     array_ops.ones([40, 3]))
3565      self.assertEqual([10, 4, 3], output2.shape)
3566      return output1, output2
3567
3568    output1, output2 = _call_concrete()
3569    self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
3570    self.assertEqual((10, 4, 3), self.evaluate(output2).shape)
3571
3572  def testAutoGraphContext(self):
3573
3574    @def_function.function
3575    def test_fn():
3576      self.assertEqual(
3577          ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)
3578
3579    prev_status = ag_ctx.control_status_ctx().status
3580    test_fn()
3581    self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)
3582
3583  @test_util.disable_tfrt('b/170435618')
3584  def testCancelBeforeFunctionExecution(self):
3585    if not context.executing_eagerly():
3586      self.skipTest('eager only')
3587
3588    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3589
3590    @def_function.function
3591    def f():
3592      return q.dequeue()
3593
3594    c_mgr = cancellation.CancellationManager()
3595    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3596
3597    c_mgr.start_cancel()
3598    with self.assertRaises(errors.CancelledError):
3599      cancelable_func()
3600
3601  @test_util.disable_tfrt('b/170435618')
3602  def testCancelBlockedFunctionExecution(self):
3603    if not context.executing_eagerly():
3604      self.skipTest('eager only')
3605
3606    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3607
3608    @def_function.function
3609    def f():
3610      return q.dequeue()
3611
3612    c_mgr = cancellation.CancellationManager()
3613    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3614
3615    def cancel_thread():
3616      time.sleep(0.5)
3617      c_mgr.start_cancel()
3618
3619    t = self.checkedThread(cancel_thread)
3620    t.start()
3621    with self.assertRaises(errors.CancelledError):
3622      cancelable_func()
3623    t.join()
3624
3625  @test_util.disable_tfrt('b/170435618')
3626  def testCancelAfterFunctionExecution(self):
3627    if not context.executing_eagerly():
3628      self.skipTest('eager only')
3629
3630    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3631    q.enqueue(37)
3632
3633    @def_function.function
3634    def f():
3635      return q.dequeue()
3636
3637    c_mgr = cancellation.CancellationManager()
3638    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3639
3640    self.assertAllEqual(37, cancelable_func().numpy())
3641
3642    # Cancellation after the function executes is a no-op.
3643    c_mgr.start_cancel()
3644
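  # NOTE: The three cancellation tests above exercise one pattern: wrap a
  # concrete function with a CancellationManager, then call start_cancel()
  # before, during, or after execution. A minimal sketch (illustrative only):
  #
  #   c_mgr = cancellation.CancellationManager()
  #   cancelable = c_mgr.get_cancelable_function(f.get_concrete_function())
  #   c_mgr.start_cancel()  # pending or future calls raise CancelledError;
  #                         # cancelling after completion is a no-op.
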
3645  def testAddFunctionCallback(self):
3646    functions = []
3647    def function_callback(f, name, graph, inputs, outputs):
3648      del name, graph, inputs, outputs
3649      functions.append(f)
3650
3651    @def_function.function
3652    def plus_one(x):
3653      return x + 1
3654
3655    try:
3656      function.add_function_callback(function_callback)
3657      x_float32 = numpy.array(3.0, dtype=numpy.float32)
3658      self.assertAllClose(plus_one(x_float32), 4.0)
3659      self.assertLen(functions, 1)
3660      # The function has already been created. Executing it again should not
3661      # invoke the function callback.
3662      self.assertAllClose(plus_one(x_float32), 4.0)
3663      self.assertLen(functions, 1)
3664      # Signature change leads to a new Function being built.
3665      x_float64 = numpy.array(3.0, dtype=numpy.float64)
3666      self.assertAllClose(plus_one(x_float64), 4.0)
3667      self.assertLen(functions, 2)
3668    finally:
3669      function.clear_function_callbacks()
3670
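  # NOTE: Function callbacks receive (f, name, graph, inputs, outputs) and are
  # invoked once per newly built ConcreteFunction, not once per call. A sketch
  # of a logging callback (illustrative only):
  #
  #   def log_new_function(f, name, graph, inputs, outputs):
  #     del f, graph, inputs, outputs
  #     print('traced:', name)
  #
  #   function.add_function_callback(log_new_function)
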
3671  def testFunctionCallbackAddOps(self):
3672    file_name = os.path.join(self.get_temp_dir(), 'test')
3673
3674    def function_callback(f, name, graph, inputs, outputs):
3675      del f, name, inputs
3676
3677      with graph.as_default():
3678        printer = logging_ops.print_v2(
3679            'hello',
3680            output_stream='file://' + file_name
3681        )
3682        outputs[0].op._add_control_input(printer)
3683
3684    @def_function.function
3685    def plus_one(x):
3686      return x + 1
3687
3688    self.addCleanup(function.clear_function_callbacks)
3689    function.add_function_callback(function_callback)
3690    x_float32 = numpy.array(3.0, dtype=numpy.float32)
3691
3692    self.assertAllClose(plus_one(x_float32), 4.0)
3693
3694    with open(file_name, 'r') as f:
3695      self.assertEqual(f.read().strip(), 'hello')
3696
3697  def testRemoveFunctionCallback(self):
3698    functions_1 = []
3699    def function_callback_1(f, name, graph, inputs, outputs):
3700      del name, graph, inputs, outputs
3701      functions_1.append(f)
3702
3703    functions_2 = []
3704    def function_callback_2(f, name, graph, inputs, outputs):
3705      del name, graph, inputs, outputs
3706      functions_2.append(f)
3707
3708    @def_function.function
3709    def plus_one(x):
3710      return x + 1
3711
3712    try:
3713      function.add_function_callback(function_callback_1)
3714      function.add_function_callback(function_callback_2)
3715      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
3716      self.assertLen(functions_1, 1)
3717      self.assertLen(functions_2, 1)
3718      function.remove_function_callback(function_callback_1)
3719      # The 1st callback should not be invoked after remove_function_callback()
3720      # is called.
3721      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0)
3722      self.assertLen(functions_1, 1)
3723      self.assertLen(functions_2, 2)
3724    finally:
3725      function.clear_function_callbacks()
3726
3727  def testClearFunctionCallbacks(self):
3728    function.add_function_callback(lambda f: None)
3729    function.add_function_callback(lambda f: None)
3730    self.assertLen(function._function_callbacks, 2)
3731    function.clear_function_callbacks()
3732    self.assertEmpty(function._function_callbacks)  # pylint:disable=protected-access
3733
3734  @test_util.run_in_graph_and_eager_modes
3735  def testConcreteFunctionWithNestedTensorInputs(self):
3736
3737    @def_function.function
3738    def f(x, y):
3739      return (x['a'] + x['b'], y[0] + y[1])
3740
3741    a = constant_op.constant(1000)
3742    b = constant_op.constant(200)
3743    c = constant_op.constant(30)
3744    d = {'a': a, 'b': b}
3745    e = (c, 4)
3746
3747    # Test different argument signatures when constructing the concrete func.
3748    for cf in [
3749        f.get_concrete_function(d, e),
3750        f.get_concrete_function(d, y=e),
3751        f.get_concrete_function(y=e, x=d),
3752        f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
3753        f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
3754        f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
3755    ]:
3756      # Test different calling conventions when calling the concrete func.
3757      for output in [
3758          cf(d, e),  # structured signature
3759          cf(d, y=e),  # structured signature w/ kwarg
3760          cf(y=e, x=d),  # structured signature w/ 2 kwargs
3761          cf(a, b, c),  # flat signature
3762      ]:
3763        self.assertIsInstance(output, tuple)
3764        self.assertLen(output, 2)
3765        self.assertAllEqual(output[0], 1200)
3766        self.assertAllEqual(output[1], 34)
3767
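  # NOTE: As exercised above, a ConcreteFunction accepts two calling
  # conventions: the structured signature mirrors the original Python
  # arguments (dicts/tuples, keywords allowed), while the flat signature takes
  # one positional argument per component tensor. Sketch for the function
  # above (illustrative only):
  #
  #   cf = f.get_concrete_function(d, e)
  #   cf({'a': a, 'b': b}, (c, 4))  # structured
  #   cf(a, b, c)                   # flat: the int 4 is baked into the trace
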
3768  @test_util.run_in_graph_and_eager_modes
3769  def testConcreteFunctionWithNestedNonTensorInputs(self):
3770
3771    @def_function.function
3772    def f(x, y):
3773      return (x['a'] + x['b'], y[0] + y[1])
3774
3775    a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
3776    b = (50, 3)
3777
3778    for cf in [  # argument y is bound to non-Tensor value (50, 3).
3779        f.get_concrete_function(a, b),
3780        f.get_concrete_function(a, y=b),
3781        f.get_concrete_function(x=a, y=b)
3782    ]:
3783      for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
3784        self.assertAllEqual(output[0] + output[1], 1253)
3785
3786  @test_util.run_in_graph_and_eager_modes
3787  def testConcreteFunctionWithNonTensorStringInputs(self):
3788
3789    @def_function.function
3790    def f(x, y):
3791      return string_ops.string_join([x, y])
3792
3793    a = constant_op.constant('a')
3794    b = 'b'
3795
3796    cf = f.get_concrete_function(a, b)
3797    for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
3798      self.assertAllEqual(output, b'ab')
3799
3800  @test_util.run_in_graph_and_eager_modes
3801  def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
3802
3803    @def_function.function
3804    def f(x, y):
3805      return (x['a'] + x['b'], y[0] + y[1])
3806
3807    a = {'a': 3000, 'b': 200, 'c': 9000}
3808    b = (constant_op.constant(30), 4)
3809
3810    for cf in [  # argument x is bound to non-tensor value `a`
3811        f.get_concrete_function(a, b),
3812        f.get_concrete_function(a, y=b),
3813        f.get_concrete_function(x=a, y=b)
3814    ]:
3815      for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
3816        self.assertAllEqual(output[0] + output[1], 3234)
3817
3818  @test_util.run_in_graph_and_eager_modes
3819  def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
3820
3821    @def_function.function
3822    def f(x, y):
3823      return (x['a'] + x['b'], y[0] + y[1])
3824
3825    a = {'a': 5000, 'b': 500}
3826    b = (50, 5)
3827
3828    cf = f.get_concrete_function(a, b)
3829    for output in [cf(), cf(a), cf(y=b)]:
3830      self.assertAllEqual(output[0] + output[1], 5555)
3831
3832  @test_util.run_in_graph_and_eager_modes
3833  def testConcreteFunctionMethodWithVarargs(self):
3834    float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
3835
3836    class MyModel(module.Module):
3837
3838      @def_function.function(input_signature=[float32_scalar, float32_scalar])
3839      def add(self, *arg):
3840        return math_ops.add(*arg)
3841
3842    m = MyModel()
3843    cf = m.add.get_concrete_function()
3844    cf(-12.0, 3.0)
3845
3846  @test_util.run_in_graph_and_eager_modes
3847  def testConcreteFunctionStructuredSignatureKeywordOrder(self):
3848    # Check that keyword-only arguments are sorted appropriately, so that they
3849    # feed the right tensor into each input.
3850    @def_function.function
3851    def g(**kwargs):
3852      return string_ops.reduce_join(
3853          string_ops.reduce_join(
3854              ops.convert_to_tensor(sorted(kwargs.items())),
3855              axis=1,
3856              separator='='),
3857          axis=0,
3858          separator=', ')
3859
3860    s = constant_op.constant('s')
3861    g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
3862    self.assertAllEqual(
3863        g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
3864        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3865    self.assertAllEqual(
3866        g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
3867        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3868    self.assertAllEqual(
3869        g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
3870        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3871
3872  # pylint: disable=g-long-lambda
3873  @parameterized.named_parameters([
3874      dict(
3875          testcase_name='MissingArg',
3876          conc_args=lambda: (1, constant_op.constant(2)),
3877          call_args=lambda: (1,),
3878          error=r'func\(x, y\) missing required arguments: y'),
3879      dict(
3880          testcase_name='MissingVararg',
3881          conc_args=lambda: (1, 2, constant_op.constant(1.0)),
3882          call_args=lambda: (1, 2),
3883          error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
3884      dict(
3885          testcase_name='ExtraPositionalArg',
3886          conc_args=lambda: (1, 2),
3887          call_args=lambda: (1, 2, 3),
3888          error=r'func\(x, y\) takes 2 .* got 3'),
3889      dict(
3890          testcase_name='MissingKeywordOnlyArg',
3891          conc_args=lambda: (1, 2),
3892          conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
3893          call_args=lambda: (1, 2),
3894          error=r'func\(x, y, \*, c\) missing required arguments: c'),
3895      dict(
3896          testcase_name='ExtraKeywordArg',
3897          conc_args=lambda: (1, 2),
3898          call_args=lambda: (1, 2),
3899          call_kwargs=lambda: {'c': constant_op.constant(1.0)},
3900          error=r'func\(x, y\) got unexpected keyword arguments: c'),
3901      dict(
3902          testcase_name='ExpectedRaggedGotNest',
3903          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
3904          call_args=lambda: ({
3905              'a': constant_op.constant([1, 2, 3])
3906          },),
3907          error=r'func\(x, y\): argument x had incorrect type\n'
3908          r'  expected: RaggedTensor\n'
3909          r"       got: {'a': (Eager)?Tensor}"),
3910      dict(
3911          testcase_name='WrongRaggedRank',
3912          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
3913          call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
3914          error=r'func\(x, y\): argument x had incorrect type\n'),
3915      dict(
3916          testcase_name='WrongRaggedDType',
3917          conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
3918          call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
3919          error=r'func\(x, y\): argument x had incorrect type\n'),
3920      dict(
3921          testcase_name='ExpectedDictGotTensor',
3922          conc_args=lambda: ({
3923              'a': constant_op.constant(1),
3924              'b': constant_op.constant(1)
3925          },),
3926          call_args=lambda: (constant_op.constant(1),),
3927          error=r'func\(x, y\): argument x had incorrect type\n'),
3928      dict(
3929          testcase_name='ExpectedTupleGotTensor',
3930          conc_args=lambda:
3931          ((constant_op.constant(1), constant_op.constant(2)),),
3932          call_args=lambda: (constant_op.constant(1),),
3933          error=r'func\(x, y\): argument x had incorrect type\n'),
3934      dict(
3935          testcase_name='WrongDType',
3936          conc_args=lambda: (constant_op.constant(1),),
3937          call_args=lambda: (constant_op.constant(1.0),),
3938          exception=(ValueError, errors.InvalidArgumentError,
3939                     # on xla_gpu, we get InternalError instead.
3940                     errors.InternalError)),
3941      dict(
3942          testcase_name='ExpectedTensorGotInt',
3943          conc_args=lambda: (constant_op.constant(1),),
3944          call_args=lambda: (5,),
3945          error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
3946      dict(
3947          testcase_name='ExpectedIntGotDifferentInt',
3948          conc_args=lambda: (5,),
3949          call_args=lambda: (8,),
3950          error=r'ConcreteFunction func\(x, y\) was constructed with int '
3951          r'value 5 in x, but was called with int value 8'),
3952      dict(
3953          testcase_name='ExpectedIntGotTensor',
3954          conc_args=lambda: (5,),
3955          call_args=lambda: (constant_op.constant(6),),
3956          error=r'ConcreteFunction func\(x, y\) was constructed with int '
3957          'value 5 in x, but was called with (Eager)?Tensor value .*'),
3958      dict(
3959          testcase_name='TwoValuesForArgument',
3960          conc_args=lambda: (1, 2),
3961          call_args=lambda: (1, 2),
3962          call_kwargs=lambda: {'x': 3},
3963          error=r"func\(x, y\) got two values for 'x'"),
3964  ])
3965  # pylint: enable=g-long-lambda
3966  @test_util.run_in_graph_and_eager_modes
3967  def testConcreteFunctionStructuredSignatureError(self,
3968                                                   conc_args=(),
3969                                                   conc_kwargs=None,
3970                                                   call_args=(),
3971                                                   call_kwargs=None,
3972                                                   error='.*',
3973                                                   exception=TypeError):
3974    """Tests for errors in the structrued signature.
3975
3976    Args:
3977      conc_args: Positional arguments used for get_concrete_function.
3978      conc_kwargs: Keyword arguments used for get_concrete_function.
3979      call_args: Positional arguments used to call the function.
3980      call_kwargs: Keyword arguments used to call the function.
3981      error: Expected exception message.
3982      exception: Expected exception type.
3983    """
3984    conc_args = conc_args() if callable(conc_args) else conc_args
3985    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
3986    call_args = call_args() if callable(call_args) else call_args
3987    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
3988    self.assertIsInstance(conc_args, tuple)
3989    self.assertIsInstance(call_args, tuple)
3990    self.assertIsInstance(conc_kwargs, dict)
3991    self.assertIsInstance(call_kwargs, dict)
3992
3993    @def_function.function
3994    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
3995      del y, varargs, kwargs
3996      return x
3997
3998    conc = func.get_concrete_function(*conc_args, **conc_kwargs)
3999    with self.assertRaisesRegex(exception, error):
4000      self.evaluate(conc(*call_args, **call_kwargs))
4001
4002  # pylint: disable=g-long-lambda
4003  @parameterized.named_parameters([
4004      dict(
4005          testcase_name='MissingArg',
4006          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4007          call_args=lambda: (constant_op.constant(1),),
4008          error=r'func\(x, y\) missing required arguments: y'),
4009      dict(
4010          testcase_name='TwoValuesForArg',
4011          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4012          call_args=lambda: (constant_op.constant(1),),
4013          call_kwargs=lambda: {
4014              'x': constant_op.constant(1),
4015              'y': constant_op.constant(1)
4016          },
4017          error=r"func\(x, y\) got two values for 'x'"),
4018      dict(
4019          testcase_name='ExtraPositionalArg',
4020          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4021          call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
4022                             constant_op.constant(3)),
4023          error=r'func\(x, y\) takes 2 .* got 3'),
4024      dict(
4025          testcase_name='UnexpectedKeywordArg',
4026          conc_args=lambda: (constant_op.constant(1),),
4027          call_args=lambda: (constant_op.constant(1),),
4028          call_kwargs=lambda: {'c': constant_op.constant(1)},
4029          error=r'func\(x\) got unexpected keyword arguments: c'),
4030      dict(
4031          testcase_name='MissingVararg',
4032          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
4033                             constant_op.constant(3)),
4034          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4035          error=r'func\(x, y, varargs_0\) missing required '
4036          r'arguments: varargs_0'),
4037      dict(
4038          testcase_name='MissingKeywordArg',
4039          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4040          conc_kwargs=lambda: {'c': constant_op.constant(1)},
4041          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4042          error=r'func\(x, y, c\) missing required arguments: c'),
4043      dict(
4044          testcase_name='ExpectedTensorGotInt',
4045          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4046          call_args=lambda: (5, constant_op.constant(2)),
4047          error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
4048          r'a Tensor; got int \(5\)'),
4049      dict(
4050          testcase_name='WrongDType',
4051          conc_args=lambda: (constant_op.constant(1),),
4052          call_args=lambda: (constant_op.constant(1.0),),
4053          exception=(ValueError, errors.InvalidArgumentError,
4054                     # on xla_gpu, we get InternalError instead.
4055                     errors.InternalError)),
4056      dict(
4057          testcase_name='MissingKeywordArgNestPiece',
4058          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4059          conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
4060          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
4061          call_kwargs=lambda: {'c': constant_op.constant(1)},
4062          error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
4063  ])
4064  # pylint: enable=g-long-lambda
4065  @test_util.run_in_graph_and_eager_modes
4066  def testConcreteFunctionFlatSignatureError(self,
4067                                             conc_args=(),
4068                                             conc_kwargs=None,
4069                                             call_args=(),
4070                                             call_kwargs=None,
4071                                             error='.*',
4072                                             exception=TypeError):
4073    """Tests for errors in the flat signature.
4074
4075    Args:
4076      conc_args: Positional arguments used for get_concrete_function.
4077      conc_kwargs: Keyword arguments used for get_concrete_function.
4078      call_args: Positional arguments used to call the function.
4079      call_kwargs: Keyword arguments used to call the function.
4080      error: Expected exception message.
4081      exception: Expected exception type.
4082    """
4083    conc_args = conc_args() if callable(conc_args) else conc_args
4084    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
4085    call_args = call_args() if callable(call_args) else call_args
4086    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
4087    self.assertIsInstance(conc_args, tuple)
4088    self.assertIsInstance(call_args, tuple)
4089    self.assertIsInstance(conc_kwargs, dict)
4090    self.assertIsInstance(call_kwargs, dict)
4091
4092    @def_function.function
4093    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
4094      del y, varargs, kwargs
4095      return x
4096
4097    conc = func.get_concrete_function(*conc_args, **conc_kwargs)
4098
4099    # Remove _function_spec to disable the structured signature.
4100    conc._set_function_spec(None)  # pylint: disable=protected-access
4101
4102    with self.assertRaisesRegex(exception, error):
4103      self.evaluate(conc(*call_args, **call_kwargs))
4104
4105  @test_util.run_in_graph_and_eager_modes
4106  def testConcreteFunctionAmbiguousSignature(self):
4107    # When both the flat & structured signatures are applicable but give
4108    # different results, we use the structured signature. Note: we expect
4109    # this to be extremely rare.
4110    @def_function.function
4111    def f(x, y):
4112      return x * 10 + y
4113
4114    conc = f.get_concrete_function(
4115        x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
4116        y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
4117
4118    result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
4119    self.assertAllEqual(result, 56)
4120
4121  def testPrettyPrintedSignature(self):
4122
4123    @def_function.function
4124    def func(x, kangaroo=None, octopus=7):
4125      del octopus, kangaroo
4126      return x
4127
4128    scalar = constant_op.constant(5)
4129    vector = constant_op.constant([10, 10, 20])
4130    ragged = ragged_factory_ops.constant([[10, 20], [40]])
4131
4132    c1 = func.get_concrete_function(scalar, vector)
4133    c1_summary = r'func\(x, kangaroo, octopus=7\)'
4134    c1_details = (r'  Args:\n'
4135                  r'    x: int32 Tensor, shape=\(\)\n'
4136                  r'    kangaroo: int32 Tensor, shape=\(3,\)\n'
4137                  r'  Returns:\n'
4138                  r'    int32 Tensor, shape=\(\)')
4139    self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
4140    self.assertRegex(
4141        c1.pretty_printed_signature(verbose=True),
4142        c1_summary + '\n' + c1_details)
4143    self.assertRegex(
4144        repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
4145    self.assertRegex(
4146        str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
4147
4148    c2 = func.get_concrete_function(scalar, ragged, 3)
4149    c2_summary = r'func\(x, kangaroo, octopus=3\)'
4150    c2_details = (r'  Args:\n'
4151                  r'    x: int32 Tensor, shape=\(\)\n'
4152                  r'    kangaroo: RaggedTensorSpec\(.*\)\n'
4153                  r'  Returns:\n'
4154                  r'    int32 Tensor, shape=\(\)')
4155    self.assertRegex(c2.pretty_printed_signature(),
4156                     c2_summary + '\n' + c2_details)
4157
4158    c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
4159    c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
4160    c3_details = (r'  Args:\n'
4161                  r"    x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
4162                  r'      <1>: int32 Tensor, shape=\(\)\n'
4163                  r'      <2>: RaggedTensorSpec\(.*\)\n'
4164                  r'      <3>: RaggedTensorSpec\(.*\)\n'
4165                  r'  Returns:\n'
4166                  r"    {'a': <1>, 'b': \[<2>, <3>\]}\n"
4167                  r'      <1>: int32 Tensor, shape=\(\)\n'
4168                  r'      <2>: RaggedTensorSpec\(.*\)\n'
4169                  r'      <3>: RaggedTensorSpec\(.*\)')
4170
4171    # Python 3.5 does not guarantee deterministic iteration of dict contents,
4172    # which can lead to a mismatch in the pretty_printed_signature "Args" output.
4173    if sys.version_info >= (3, 6):
4174      self.assertRegex(c3.pretty_printed_signature(),
4175                       c3_summary + '\n' + c3_details)
4176
4177    # pylint: disable=keyword-arg-before-vararg
4178    @def_function.function
4179    def func2(x, y=3, *args, **kwargs):
4180      return (x, y, args, kwargs)
4181
4182    c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
4183    c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
4184    self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
4185
4186    c5 = func2.get_concrete_function(8, vector)
4187    c5_summary = 'func2(x=8, y)'
4188    self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
4189
4190  def testPrettyPrintedExplicitSignatureWithKeywordArg(self):  # b/159639913
4191
4192    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
4193    def fn(a, b=1):
4194      return a + b
4195
4196    concrete_fn = fn.get_concrete_function()
4197    self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)')
4198    self.assertEqual(
4199        concrete_fn.pretty_printed_signature(True), 'fn(a)\n'
4200        '  Args:\n'
4201        '    a: float32 Tensor, shape=<unknown>\n'
4202        '  Returns:\n'
4203        '    float32 Tensor, shape=<unknown>')
4204
4205  def testPrettyPrintedSignatureLoadedNamedTuple(self):
4206    Point = collections.namedtuple('Point', ['x', 'y'])
4207
4208    @def_function.function
4209    def fn(b, a):  # pylint: disable=unused-argument
4210      return 1.
4211
4212    b = Point(
4213        x=constant_op.constant(1., dtype=dtypes.float32),
4214        y=constant_op.constant(1., dtype=dtypes.float32))
4215    a = Point(
4216        x=constant_op.constant(1, dtype=dtypes.int32),
4217        y=constant_op.constant(1, dtype=dtypes.int32))
4218
4219    mod = module.Module()
4220    f = fn.get_concrete_function(b, a)
4221    save(mod, '/tmp/f', signatures=f)
4222    loaded = load('/tmp/f')
4223
4224    printed = loaded.signatures['serving_default'].pretty_printed_signature()
4225    self.assertIn('a: int32 Tensor, shape=()', printed)
4226    self.assertIn('a_1: int32 Tensor, shape=()', printed)
4227    self.assertIn('b: float32 Tensor, shape=()', printed)
4228    self.assertIn('b_1: float32 Tensor, shape=()', printed)
4229
4230  @test_util.run_in_graph_and_eager_modes
4231  def testIndexedSlicesAsGradientsForConcreteFunctions(self):
4232
4233    @def_function.function
4234    def summing_rnn(inputs):
4235      return math_ops.reduce_sum(inputs, axis=1)
4236
4237    @def_function.function
4238    def gradients(inputs):
4239      with backprop.GradientTape() as tape:
4240        tape.watch(inputs)
4241        hidden = summing_rnn(inputs)
4242        hidden = array_ops.gather(hidden, constant_op.constant([0]))
4243        loss = math_ops.reduce_mean(hidden)
4244      return tape.gradient(loss, inputs)
4245
4246    gradients(constant_op.constant([[[1.0], [2.0]]]))  # No error is raised
4247
4248  def testFollowTypeHintsTraceBasic(self):
4249    trace_count = [0]
4250
4251    def func(x: ops.Tensor):
4252      trace_count[0] += 1
4253      return x
4254
4255    enabled = def_function.function(func, experimental_follow_type_hints=True)
4256    disabled = def_function.function(func, experimental_follow_type_hints=False)
4257
4258    enabled(1)  # Initial call gets traced
4259    enabled(2)
4260    enabled(3)
4261    self.assertEqual(trace_count[0], 1)
4262
4263    trace_count = [0]
4264    disabled(1)
4265    disabled(2)  # Retrace
4266    disabled(3)  # Retrace
4267    self.assertEqual(trace_count[0], 3)
4268
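  # NOTE: With experimental_follow_type_hints=True, arguments annotated as
  # ops.Tensor are converted to tensors before the cache lookup, so varying
  # Python values (1, 2, 3) all map to one traced signature; without it, each
  # distinct Python value is a new cache key. A minimal sketch (illustrative):
  #
  #   @def_function.function(experimental_follow_type_hints=True)
  #   def square(x: ops.Tensor):
  #     return x * x
  #
  #   square(2); square(3)  # one trace: both become int32 scalar tensors
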
4269  def testFollowTypeHintsTraceWithArgs(self):
4270    trace_count = [0]
4271
4272    def func(*args: ops.Tensor):
4273      trace_count[0] += 1
4274      return args
4275
4276    enabled = def_function.function(func, experimental_follow_type_hints=True)
4277    disabled = def_function.function(func, experimental_follow_type_hints=False)
4278
4279    args = (
4280        'abc',
4281        'def',
4282    ) * 20
4283    args2 = (
4284        'def',
4285        'abc',
4286    ) * 20
4287
4288    enabled(args)
4289    enabled(args2)
4290    self.assertEqual(trace_count[0], 1)
4291
4292    trace_count = [0]
4293    disabled(args)
4294    disabled(args2)  # Retrace
4295    self.assertEqual(trace_count[0], 2)
4296
4297  def testFollowTypeHintsTraceWithKwargs(self):
4298    trace_count = [0]
4299
4300    def func(t: ops.Tensor, **kwargs: ops.Tensor):
4301      del kwargs
4302      trace_count[0] += 1
4303      return t
4304
4305    enabled = def_function.function(func, experimental_follow_type_hints=True)
4306    disabled = def_function.function(func, experimental_follow_type_hints=False)
4307
4308    enabled(1, x=1, y=1.0, z='one')
4309    enabled(2, x=2, y=2.0, z='two')
4310    self.assertEqual(trace_count[0], 1)
4311
4312    trace_count = [0]
4313    disabled(1, x=1, y=1.0, z='one')
4314    disabled(2, x=2, y=2.0, z='two')  # Retrace
4315    self.assertEqual(trace_count[0], 2)
4316
4317  def testFollowTypeHintsTraceWithMultipleInputTypes(self):
4318    trace_count = [0]
4319
4320    def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
4321      del args, kwargs
4322      trace_count[0] += 1
4323      return t
4324
4325    enabled = def_function.function(func, experimental_follow_type_hints=True)
4326    disabled = def_function.function(func, experimental_follow_type_hints=False)
4327
4328    enabled(1, constant_op.constant(1), 'str', x=4.0)
4329    enabled(2, constant_op.constant(2), 'str2', x=5.0)
4330    self.assertEqual(trace_count[0], 1)
4331
4332    trace_count = [0]
4333    disabled(1, constant_op.constant(1), 'str', x=4.0)
4334    disabled(2, constant_op.constant(2), 'str2', x=5.0)  # Retrace
4335    self.assertEqual(trace_count[0], 2)
4336
4337  def testFollowTypeHintsTraceWithOnlyArgNamed(self):
4338    trace_count = [0]
4339
4340    def func(t: ops.Tensor, i: int = 1, **kwargs):  # pylint: disable=bad-whitespace
4341      del i, kwargs
4342      trace_count[0] += 1
4343      return t
4344
4345    enabled = def_function.function(func, experimental_follow_type_hints=True)
4346
4347    enabled(1, 3, x=4.0, y='str')
4348    enabled(2, 4, x=4.0, y='str')  # Retrace
4349    self.assertEqual(trace_count[0], 2)
4350
4351  def testFollowTypeHintsTraceWithNotAllNamed(self):
4352    trace_count = [0]
4353
4354    def func(x, y: ops.Tensor, z: int):
4355      del y, z
4356      trace_count[0] += 1
4357      return x
4358
4359    enabled = def_function.function(func, experimental_follow_type_hints=True)
4360
4361    enabled(1, 2, 3)
4362    enabled(1, 20, 3)  # No retrace - change in ops.Tensor typed arg
4363    enabled(2, 2, 3)  # Retrace - change in untyped arg
4364    enabled(2, 2, 4)  # Retrace - change in typed arg
4365    self.assertEqual(trace_count[0], 3)
4366
4367  def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
4368    trace_count = [0]
4369
4370    def func(x, y, *args: ops.Tensor):
4371      del y, args
4372      trace_count[0] += 1
4373      return x
4374
4375    enabled = def_function.function(func, experimental_follow_type_hints=True)
4376
4377    enabled(1, 20, 3, 4, 5, 6)
4378    enabled(1, 20, 3, 4, 5, 60)  # No retrace - change in *args
4379    enabled(1, 30, 7, 8, 9, 10)  # Retrace - change in args
4380    self.assertEqual(trace_count[0], 2)
4381
4382  def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
4383    trace_count = [0]
4384
4385    def func(x, y, *args, **kwargs: ops.Tensor):
4386      del y, args, kwargs
4387      trace_count[0] += 1
4388      return x
4389
4390    enabled = def_function.function(func, experimental_follow_type_hints=True)
4391
4392    enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
4393    enabled(
4394        1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
4395        c=3.5)  # No retrace - change in **kwargs
4396    enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)  # Retrace - change in args
4397    enabled(
4398        1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0)  # Retrace - change in *args
4399    self.assertEqual(trace_count[0], 3)
4400
4401  def testFollowTypeHintsTraceWithArgsEquals(self):
4402    trace_count = [0]
4403
4404    def func(
4405        x: ops.Tensor = 0,  # pylint:disable=bad-whitespace
4406        y: int = 1,  # pylint:disable=bad-whitespace
4407        **kwargs: ops.Tensor):
4408      del y, kwargs
4409      trace_count[0] += 1
4410      return x
4411
4412    enabled = def_function.function(func, experimental_follow_type_hints=True)
4413
4414    enabled(x=1, y=2, z=3)
4415    enabled(x=1, y=3, z=3)  # Retrace - change in args
4416    enabled(x=2, y=2, z=4)  # No retrace - change in args and **kwargs
4417    enabled(x=2, y=2, z=4, u=5)  # Retrace - change in **kwargs
4418    self.assertEqual(trace_count[0], 3)
4419
4420  def testFollowTypeHintsWithTensorSpec(self):
4421    def func(x: ops.Tensor, y):
4422      return x + y
4423    v = def_function.function(experimental_follow_type_hints=True)(func)
4424    v = v.get_concrete_function(
4425        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3)
4426    x = v(constant_op.constant(1.), 3)
4427    self.assertEqual(x.numpy(), 4.)
4428
  def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self):
    trace_count = [0]

    def func(a: int, b: ops.Tensor,
             x: ops.Tensor = 0, y: int = 1):
      del a, b, y
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(0, 0, x=1, y=2)
    enabled(0, 0, x=2, y=2,)  # No retrace, since only tensor changed
    self.assertEqual(trace_count[0], 1)

    # Pass args as keyword args.
    enabled(a=0, b=0, x=2, y=2,)  # No retrace, args are the same
    self.assertEqual(trace_count[0], 1)

    enabled(a=1, b=0, x=2, y=2,)  # Retrace, since non-tensor arg changed
    self.assertEqual(trace_count[0], 2)

    enabled(a=1, b=2, x=2, y=2)  # No retrace, since only tensor changed
    self.assertEqual(trace_count[0], 2)

    trace_count[0] = 0
    disabled = def_function.function(func, experimental_follow_type_hints=False)
    disabled(0, 0, x=1, y=2)
    disabled(0, 0, x=2, y=2,)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
    trace_count = [0]

    def func(x, y, **kwargs: ops.Tensor):
      del y, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(x=1, y=2, z=3)
    enabled(x=1, y=3, z=3)  # Retrace
    enabled(x=1, y=2, z=4)  # No retrace
    enabled(x=2, y=2, z=4)  # Retrace
    enabled(x=2, y=2, z=4, u=5)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
    trace_count = [0]

    def func(x: ops.Tensor, y: int, **kwargs):
      del y, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(x=1, y=2, z=3)
    enabled(x=1, y=3, z=3)  # Retrace
    enabled(x=1, y=2, z=4)  # Retrace
    enabled(x=2, y=2, z=3)  # No retrace
    enabled(x=2, y=2, z=4, u=5)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
    trace_count = [0]

    def func(*, a: ops.Tensor = None, b=1):  # pylint: disable=bad-whitespace
      del b
      trace_count[0] += 1
      return a

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(a=1, b=2)
    enabled(a=2, b=2)  # No retrace
    enabled(a=1, b=1)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
    trace_count = [0]

    def func(arg: ops.Tensor, *args, kwonly, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
    trace_count = [0]

    def func(arg, *args: ops.Tensor, kwonly, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
    trace_count = [0]

    def func(arg, *args, kwonly: ops.Tensor, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
    trace_count = [0]

    def func(arg, *args, kwonly, **kwargs: ops.Tensor):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # No retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700)  # No retrace
    self.assertEqual(trace_count[0], 4)

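  # The tests below check that tf.function resolves the signature of the
  # original Python function even when extra decorator wrappers sit between
  # tf.function and the method, so argument-binding errors stay accurate.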
  def testWithExtraWrapper(self):

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function
      @dummy_tf_decorator
      def add(self, x, y, z=1):
        if self.var is None:
          return x + y + z

    foo = Foo()
    self.assertEqual(foo.add(2, 3).numpy(), 6)

  @parameterized.parameters([(def_function.function, dummy_tf_decorator),
                             (dummy_tf_decorator, def_function.function),
                             (def_function.function, def_function.function)])
  def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @decorator1
      @decorator2
      def add1(self, x, y):
        if self.var is None:
          return x + y

    foo = Foo()
    with self.assertRaisesRegex(TypeError, 'got two values'):
      foo.add1(2, x=3)  # pylint: disable=redundant-keyword-arg,no-value-for-parameter

  def testWithExtraWrapperMissingArgs(self):

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function
      @dummy_tf_decorator
      def add1(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @dummy_tf_decorator
      def add2(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @def_function.function
      def add3(self, x, y):
        if self.var is None:
          return x + y

    foo = Foo()
    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add1(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add1(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add2(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add2(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add3(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add3(y=2)  # pylint: disable=no-value-for-parameter

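  # The tests below check the error messages raised when required positional
  # arguments are missing from calls to tf.function-wrapped methods, callable
  # objects, and plain functions.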
  def testMissingArgsTfFunctionedMethod(self):

    class A:

      def func(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

      @def_function.function
      def decorated_method(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()
    tf_method_pos = def_function.function(a_instance.func)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_method_pos(position_arg2='foo')

    # tf.function-decorated instance methods need to be tested because of
    # the __get__ method implementation.
    tf_func_decorated_method = def_function.function(
        a_instance.decorated_method)
    tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_decorated_method(position_arg2='bar')

  def testMissingArgsTfFunctionedObject(self):

    class A:

      def __call__(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()

    # A tf.function-decorated callable object needs to be tested because of
    # the special inspect results.
    tf_func_obj = def_function.function(a_instance)
    tf_func_obj(position_arg1=1, position_arg2=2)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_obj(position_arg2='bar')

  def testMissingArgsTfFunctionedFunctions(self):

    def func_pos(position_arg1, position_arg2):
      return position_arg1, position_arg2

    def func_with_default(position_arg, named_arg=None):
      return position_arg, named_arg

    def func_pos_3args(position_arg1, position_arg2, position_arg3):
      return position_arg1, position_arg2, position_arg3

    tf_func_pos = def_function.function(func_pos)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_pos(position_arg2='foo')

    tf_func_with_default = def_function.function(func_with_default)
    tf_func_with_default(position_arg='bar')
    with self.assertRaisesRegex(TypeError,
                                '.* missing 1 required argument: position_arg'):
      tf_func_with_default(named_arg='foo')

    tf_func_pos_3args = def_function.function(func_pos_3args)
    with self.assertRaisesRegex(
        TypeError,
        '.* missing required arguments: position_arg1, position_arg3'):
      tf_func_pos_3args(position_arg2='foo')

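  # The shape-inference tests below feed scalar int32 dimension values into a
  # nested function through an input signature; constant propagation across
  # the function boundary should make the resulting shapes fully defined at
  # trace time, which the inner asserts verify.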
  def testShapeInferencePropagateConstNestedStack(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      old_shape = array_ops.shape(x)
      new_shape = array_ops.stack([old_shape[0], s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedUnstackStack(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
      new_shape = array_ops.stack([s0, s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedConcat(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = f(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  def testShapeInferencePropagateConstDoubleNested(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = def_function.function(f)(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  @test_util.run_v2_only
  def testControlDependencyAfterInline(self):
    v = variables.Variable(0.)

    @def_function.function
    def assign():
      return v.assign(1.)

    @def_function.function
    def assign_add():
      return v.assign_add(1.)

    @def_function.function
    def f():
      check_ops.assert_equal_v2(assign(), 1.)
      check_ops.assert_equal_v2(assign_add(), 2.)
    # We don't have a way to inspect the inlined graph in Python, so we run
    # it multiple times to gain more confidence that the dependency is
    # correct.
    for _ in range(30):
      f()

  @test_util.run_v2_only
  def testReadInFuncWriteOutside(self):
    # Run many times since we are testing for a potential race condition.
    for _ in range(30):
      # pylint: disable=cell-var-from-loop
      v = variables.Variable(1.)

      @def_function.function
      def add_one():
        return v + 1.

      @def_function.function
      def get_v_plus_one():
        v_plus_one = add_one()
        v.assign_add(2.0)
        return v_plus_one

      self.assertAllEqual(get_v_plus_one(), 2.0)

  def testOpExpandErrorMessage(self):
    @def_function.function
    def test_fn():
      if array_ops.constant(False):
        return array_ops.constant(1)
      else:
        return script_ops.eager_py_func(
            func=lambda: array_ops.constant([2.]), inp=(), Tout=dtypes.int32)

    error_pattern = re.compile(r'Graph execution error.*func=lambda', re.DOTALL)
    with self.assertRaisesRegex(errors.InvalidArgumentError, error_pattern):
      test_fn()


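# The tests below place ops from a single function on multiple devices and
# check both the computed values and the device backing each output.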
class MultiDeviceTest(test.TestCase, parameterized.TestCase):

  @test_util.run_gpu_only
  def testMultiDeviceOutput(self):
    """Tests that functions can produce outputs on multiple devices."""
    @function.defun
    def func(a, b, transpose_a):
      with ops.device('/device:CPU:0'):
        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
      with ops.device('/device:GPU:0'):
        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
      return m1, m2

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    m1, m2 = func(t, t, transpose_a=True)
    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testEmptyBody(self):
    @function.defun
    def func(a, b):
      return b, a

    with ops.device('/device:CPU:0'):
      a = array_ops.identity(3.0)
    with ops.device('/device:GPU:0'):
      b = array_ops.identity(5.0)

    m1, m2 = func(a, b)
    self.assertAllEqual(m1.numpy(), 5.0)
    self.assertRegex(m1.backing_device, 'GPU')
    self.assertAllEqual(m2.numpy(), 3.0)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceInt32(self):
    """Tests that multi-device functions can take and output INT32s.

    When an INT32 device tensor is fed into a function, it is copied to CPU
    by the eager runtime. The function sees all INT32 inputs on CPU.

    We set allocator attribute 'on_host' for INT32 outputs. They can be
    partitioned into the GPU component function, but will be allocated on
    CPU nevertheless.

    FunctionLibraryRuntime now has experimental support for `ints_on_device`,
    which this test could adopt in the future.
    """
    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @function.defun
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation will happen on GPU but m2 will be copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    m1, m2 = func(int_cpu, resource, int_gpu)
    self.assertAllEqual(m1.numpy(), 22)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 39)
    self.assertRegex(m2.backing_device, 'CPU')

    # Flip the arguments.
    m1, m2 = func(int_gpu, resource, int_cpu)
    self.assertAllEqual(m1.numpy(), 38)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 23)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceColocateWith(self):
    """Tests that function's outputs respect colocation constraints."""
    @function.defun
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = array_ops.identity(1.0)
      with ops.device(dev2):
        b = array_ops.identity(10.0)

      ra, rb = func(a, b)
      self.assertEqual(ra.numpy(), 2.0)
      self.assertRegex(ra.backing_device, dev1)
      self.assertEqual(rb.numpy(), 30.0)
      self.assertRegex(rb.backing_device, dev2)

  @test_util.run_gpu_only
  def testMultiDeviceResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    r1, r2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    # Call with flipped inputs. Check that we look at the resource's device
    # and reinstantiate the function when the inputs' devices change.
    r1, r2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testOutputResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated
    # and the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)

  @test_util.run_gpu_only
  def testPassResourceThroughNestedFunctionCall(self):
    """Tests passing a GPU resource to a noinline function call placed on CPU.

    PartitionedCallOp must not enforce any particular device assignment for
    the resource output. The inner function is marked `_nospecialize` so that
    Grappler does not prune the unused function output.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True,
        '_nospecialize': True
    })
    def inner(resource1):
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, _ = inner(resource1)
      return r1

    r1 = outer(g1)

    self.assertEqual(r1.numpy(), 6.0)
    self.assertRegex(r1.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testReturnResourceFromNestedFunctionCall(self):
    """Tests returning a GPU resource from a noinline function call on CPU.

    When inferring output devices for the return value, do not set a device
    for DT_RESOURCE returns based on the device assignment of the node that
    produced the resource. For example, a function call placed on CPU can
    return resources that live on the GPU.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True
    })
    def inner(resource1):
      resource1.assign_add(2.0)
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, res1 = inner(resource1)
      return r1, res1

    r1, res1 = outer(g1)

    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 5.0)

  @test_util.run_gpu_only
  def testComplexInputOutputDevicePattern(self):
    """Tests input/output mapping logic in partitioning."""
    with ops.device('/device:CPU:0'):
      rc0 = resource_variable_ops.ResourceVariable(2.0)
      rc1 = resource_variable_ops.ResourceVariable(3.0)
      cc0 = array_ops.identity(5.0)
      cc1 = array_ops.identity(7.0)
    with ops.device('/device:GPU:0'):
      rg0 = resource_variable_ops.ResourceVariable(11.0)
      rg1 = resource_variable_ops.ResourceVariable(13.0)
      cg0 = array_ops.identity(17.0)
      cg1 = array_ops.identity(19.0)

    # Make sure tensors are on expected devices.
    for tensor in [cc0, cc1]:
      self.assertRegex(tensor.backing_device, 'CPU:0')
    for tensor in [cg0, cg1]:
      self.assertRegex(tensor.backing_device, 'GPU:0')

    @function.defun
    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
      with ops.device('/device:CPU:0'):
        m1 = rc0 * cg0
      with ops.device('/device:GPU:0'):
        m2 = rg0 * cc0

      with ops.device('/device:CPU:0'):
        r1 = 1000.0 * m2 + rc1 * cg1
      with ops.device('/device:GPU:0'):
        r2 = 1000.0 * m1 + rg1 * cc1

      return r1, r2, m2, m1

    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertRegex(m2.backing_device, 'GPU')
    self.assertRegex(r2.backing_device, 'GPU')
    self.assertEqual(m1.numpy(), 34.0)
    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
    self.assertEqual(m2.numpy(), 55.0)
    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

  @test_util.run_gpu_only
  def testArgumentPruning(self):
    """Tests functions taking unnecessary arguments."""
    with ops.device('/device:CPU:0'):
      c1 = constant_op.constant(5.0)
      c2 = constant_op.constant(7.0)

    with ops.device('/device:GPU:0'):
      g1 = constant_op.constant(11.0)
      g2 = constant_op.constant(13.0)
      g3 = constant_op.constant(17.0)

    @function.defun
    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
      # Arguments g1 and g2 are unused and can be pruned by Grappler.
      return c1 * g3 * c2

    result = func(g1, g2, c1, g3, c2)
    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)

  def testNestedCallWatchedVariables(self):

    v = variables.Variable(4.)

    @def_function.function
    def f():
      return v ** 2.

    with backprop.GradientTape() as tape:
      f()

    self.assertEqual((v,), tape.watched_variables())

    @def_function.function
    def g():
      return f()

    with backprop.GradientTape() as tape:
      g()

    self.assertEqual((v,), tape.watched_variables())

    # f() can rely on the variable being read during its trace. g() checks that
    # variables from a function which knows about them are recorded on the
    # tape. h() tests that functions forward knowledge of variables to callers.

    @def_function.function
    def h():
      return g()

    with backprop.GradientTape() as tape:
      h()

    self.assertEqual((v,), tape.watched_variables())

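  # The tests below exercise deferred captures: a captured tensor in a
  # concrete function is replaced by a closure that is re-evaluated on each
  # call, so rebinding the Python variable changes later results without
  # retracing.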
  def testReplaceCaptureWithDeferred(self):

    x = constant_op.constant(1.0)
    y = constant_op.constant(2.0)
    z = constant_op.constant(3.0)

    @def_function.function
    def fn():
      a = x + y
      b = a + z
      return b

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), 6.0)

    value = constant_op.constant(4.0)

    def closure():
      return value

    concrete_fn.replace_capture_with_deferred_capture(
        concrete_fn.captured_inputs[1],
        closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
        placeholder=concrete_fn.inputs[1])

    self.assertAllEqual(concrete_fn(), 8.0)

    value = constant_op.constant(5.0)
    self.assertAllEqual(concrete_fn(), 9.0)

  def testRaiseReplaceCaptureWithDeferredTypeSpecMismatch(self):
    bool_captured_tensor = constant_op.constant(True)
    float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    value = constant_op.constant([2.], dtype=dtypes.float32)

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      if bool_captured_tensor:
        return deferred_tensor
      else:
        return deferred_tensor + float_captured_tensor

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), [2.])

    new_bool_captured_tensor = constant_op.constant(False)
    def bool_closure():
      return new_bool_captured_tensor

    # Test that an error is raised when replacing a bool capture with a
    # closure whose output type is float32.
    new_float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    def float_closure():
      return new_float_captured_tensor

    with self.assertRaisesRegex(ValueError,
                                'Attempting to substitute closure with spec*'):
      concrete_fn.replace_capture_with_deferred_capture(
          bool_captured_tensor,
          float_closure,
          spec=tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))

    # Test replacing the capture without providing a placeholder.
    concrete_fn.replace_capture_with_deferred_capture(
        bool_captured_tensor,
        bool_closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool))

    self.assertAllEqual(concrete_fn(), [5.])

  def testConcreteFunctionSetExternalCapture(self):
    captured_tensor = constant_op.constant([1.])
    value = constant_op.constant([2.])

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      return deferred_tensor + captured_tensor

    cf = fn.get_concrete_function()
    self.assertLen(cf._captured_inputs, 2)
    self.assertEqual(list(map(callable, cf._captured_inputs)), [False, True])
    self.assertAllEqual(cf(), [3.])

    # Reset capture to a deferred one, reset deferred capture to a capture.
    cf.set_external_captures([cf._captured_inputs[1], cf._captured_inputs[0]])

    value = constant_op.constant([3.])
    self.assertAllEqual(cf(), [4.])

  def testGraphReplaceCaptureAndSetExternalCapture(self):
    bool_captured_tensor = constant_op.constant(True)
    float_captured_tensor = constant_op.constant([3.], dtype=dtypes.float32)
    value = constant_op.constant([2.], dtype=dtypes.float32)

    @def_function.function
    def fn():
      deferred_tensor = ops.get_default_graph().capture_call_time_value(
          lambda: value,
          tensor_spec.TensorSpec(shape=(1,), dtype=dtypes.float32))
      if bool_captured_tensor:
        return deferred_tensor
      else:
        return deferred_tensor + float_captured_tensor

    concrete_fn = fn.get_concrete_function()
    self.assertAllEqual(concrete_fn(), [2.])

    new_bool_captured_tensor = constant_op.constant(False)

    def closure():
      return new_bool_captured_tensor

    concrete_fn.graph.replace_capture_with_deferred_capture(
        concrete_fn.captured_inputs[0],
        closure,
        spec=tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool),
        placeholder=concrete_fn.inputs[1])

    concrete_fn.set_external_captures([
        closure, concrete_fn._captured_inputs[1],
        concrete_fn._captured_inputs[2]
    ])
    self.assertAllEqual(concrete_fn(), [5.])

  def testDeferredCapture(self):
    value = 1.0

    @def_function.function
    def lazy_capture(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    self.assertAllEqual(lazy_capture(2.0), 3.0)
    # After changing the value of `value` the function call should return a
    # different result.
    value = 2.0
    self.assertAllEqual(lazy_capture(2.0), 4.0)

  def testNestedDeferredCapture(self):
    value = 1.0

    @def_function.function
    def inner(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    @def_function.function
    def outer(x):
      return inner(x)

    self.assertAllEqual(outer(2.0), 3.0)
    # After changing the value of `value` the function call should return a
    # different result.
    value = 2.0
    self.assertAllEqual(outer(2.0), 4.0)

  def testNestedDeferredCaptureInTFWhileLoop(self):

    value = 1.

    @def_function.function
    def inner(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    @def_function.function
    def outer():
      dummy = constant_op.constant(True)
      sums = constant_op.constant(0.)
      while dummy:
        directives.set_loop_options(
            shape_invariants=[(sums, tensor_shape.TensorShape(None))])
        sums += inner(2.)
        dummy = constant_op.constant(False)
      return sums

    self.assertAllEqual(outer(), 3.)

    value = constant_op.constant(2.)
    self.assertAllEqual(outer(), 4.)

    value = constant_op.constant(3.)
    self.assertAllEqual(outer(), 5.)

  def testDeferredCaptureWithKey(self):
    value0 = 1.0
    value1 = 2.0

    @def_function.function
    def lazy_capture(x):
      w = ops.get_default_graph().capture_call_time_value(
          lambda: value0, tensor_spec.TensorSpec(None), key=0)
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value1, tensor_spec.TensorSpec(None), key=1)
      def bad_closure():
        raise ValueError('Should not run')
      z = ops.get_default_graph().capture_call_time_value(
          bad_closure, tensor_spec.TensorSpec(None), key=1)
      return x + y + w + z

    self.assertAllEqual(lazy_capture(2.0), 7.0)
    value0 = 2.0
    value1 = 3.0
    self.assertAllEqual(lazy_capture(2.0), 10.0)

  def testDeferredCaptureTypeError(self):
    value = constant_op.constant(1.0)

    @def_function.function
    def lazy_capture(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(()))
      return x + y

    self.assertAllEqual(lazy_capture(2.0), 3.0)

    # dtype mismatch
    value = constant_op.constant(1)
    with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'):
      lazy_capture(2.0)

    # shape mismatch
    value = constant_op.constant([1.0])
    with self.assertRaisesRegex(ValueError, 'Value .* shape'):
      lazy_capture(2.0)

  def testDeferredCaptureReturnNestWithCompositeTensor(self):
    i_s = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64),
        constant_op.constant([2]))
    r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
    s_t = sparse_tensor.SparseTensor(
        values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])

    @def_function.function
    def lazy_capture():
      y = ops.get_default_graph().capture_call_time_value(
          lambda: {'i': i_s, 't': (r_t, s_t)},
          {'i': indexed_slices.IndexedSlicesSpec(
              dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
           't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
                 sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
      return y['i'], y['t']

    i, (r, s) = lazy_capture()
    self.assertAllEqual(i_s.values, i.values)
    self.assertAllEqual(i_s.indices, i.indices)
    self.assertAllEqual(i_s.dense_shape, i.dense_shape)
    self.assertAllEqual(r_t, r)
    self.assertAllEqual(s_t.indices, s.indices)
    self.assertAllEqual(s_t.values, s.values)
    self.assertAllEqual(s_t.dense_shape, s.dense_shape)

  def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64))

    @def_function.function
    def lazy_capture():
      return ops.get_default_graph().capture_call_time_value(
          lambda: value,
          indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))

    # Type matches spec.
    lazy_capture()

    # Extra dense shape component.
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64),
        constant_op.constant([2]))
    with self.assertRaises(ValueError):
      lazy_capture()

    # Index dtype mismatch int32 vs. int64.
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1]))
    with self.assertRaises(ValueError):
      lazy_capture()

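  # The tests below capture a nested structure (a dict/list of tensors and
  # plain Python values) as a side input; sets are rejected.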
  def testMaybeCreateCapturePlaceholderWithValidCapture(self):
    @def_function.function
    def f():
      func = lambda: x
      return ops.get_default_graph()._maybe_create_capture_placeholder(func)

    x = {
        'tensor': constant_op.constant(0),
        'list': [constant_op.constant(1), 2],
        'dict': {
            'float': constant_op.constant(0.5)
        }
    }

    out = f()
    # The tf.function output should have the same structure/values as the
    # side input.
    self.assertEqual(x['tensor'].numpy(), out['tensor'].numpy())
    self.assertEqual(x['list'][0].numpy(), out['list'][0].numpy())
    self.assertEqual(x['list'][1], out['list'][1].numpy())
    self.assertEqual(x['dict']['float'].numpy(), out['dict']['float'].numpy())

  def testMaybeCreateCapturePlaceholderWithInvalidCapture(self):
    @def_function.function
    def f():
      func = lambda: x
      return ops.get_default_graph()._maybe_create_capture_placeholder(func)

    # Sets are not supported.
    x = set([1, 2])
    with self.assertRaises(NotImplementedError):
      f()

  # TODO(panzf): remove this test after exposing manual API, as the integration
  # testcase can be turned on at that time.
  def test_inner_nested_tf_function_raise_error(self):
    @def_function.function
    def tf_f():

      @def_function.function
      def tf_g():
        cx = ops.get_default_graph()._experimental_capture_side_input_by_ref(  # pylint: disable=protected-access
            'lambda: x', lambda: x)
        return cx

      return tf_g()

    x = constant_op.constant(0)  # pylint: disable=unused-variable
    with self.assertRaisesRegex(
        NotImplementedError, 'Manual side input usage for inner nested'):
      tf_f()

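  # When a side input is captured by reference, a Python int is part of the
  # cache key (two values yield two cache entries), while a tensor is captured
  # as a placeholder (a single trace serves both values).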
  @parameterized.parameters(
      (1, int, 2, int, 2),
      (1, constant_op.constant, 2, constant_op.constant, 1))
  def testRetraceLogicWithSideInputs(self, val_before, type_before, val_after,
                                     type_after, expected_len):
    @def_function.function
    def f():
      func = lambda: x
      return ops.get_default_graph()._experimental_capture_side_input_by_ref(  # pylint: disable=protected-access
          'lambda: x', func)

    x = type_before(val_before)
    _ = f()
    x = type_after(val_after)
    _ = f()
    self.assertLen(total_function_cache(f), expected_len)

  def testFunctoolsLruCache(self):
    self.skipTest(
        "b/194845243: inspect.getfullargspec doesn't unwrap Python decorators.")

    @def_function.function
    @functools.lru_cache(maxsize=2)
    def f(a):
      return 2 * a

    self.assertAllEqual(f(1), array_ops.constant(2))


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()