xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/load_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for trackable object SavedModel loading."""
16
17import collections
18import contextlib
19import functools
20import gc
21import io
22import os
23import pathlib
24import sys
25import tempfile
26import weakref
27
28from absl.testing import parameterized
29import numpy as np
30from tensorflow.python.checkpoint import checkpoint
31from tensorflow.python.checkpoint import saveable_compat
32from tensorflow.python.client import session as session_lib
33from tensorflow.python.data.ops import dataset_ops
34from tensorflow.python.data.ops import readers
35from tensorflow.python.eager import backprop
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.eager import test
39from tensorflow.python.eager import wrap_function
40from tensorflow.python.framework import config
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import function as framework_function
45from tensorflow.python.framework import op_callbacks
46from tensorflow.python.framework import ops
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import test_util
50from tensorflow.python.framework import versions
51from tensorflow.python.lib.io import file_io
52from tensorflow.python.lib.io import tf_record
53from tensorflow.python.module import module
54from tensorflow.python.ops import array_ops
55from tensorflow.python.ops import cond_v2
56from tensorflow.python.ops import control_flow_ops
57from tensorflow.python.ops import custom_gradient
58from tensorflow.python.ops import lookup_ops
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import resource_variable_ops
61from tensorflow.python.ops import string_ops
62from tensorflow.python.ops import variable_scope
63from tensorflow.python.ops import variables
64from tensorflow.python.ops.ragged import ragged_factory_ops
65from tensorflow.python.ops.ragged import ragged_tensor
66from tensorflow.python.saved_model import load
67from tensorflow.python.saved_model import load_options
68from tensorflow.python.saved_model import loader_impl
69from tensorflow.python.saved_model import save
70from tensorflow.python.saved_model import save_options
71from tensorflow.python.saved_model import tag_constants
72from tensorflow.python.trackable import asset
73from tensorflow.python.trackable import autotrackable
74from tensorflow.python.trackable import resource
75from tensorflow.python.training import monitored_session
76from tensorflow.python.util import tf_inspect
77
78
def cycle(obj, cycles, signatures=None, options=None):
  """Saves `obj` and loads it back `cycles` times; returns the last load.

  Every round exports the current object into a fresh temporary directory and
  loads it again. From the second round on, the object being saved is the one
  loaded in the previous round, and the signatures passed to `save.save` are
  the ones exposed by that previous load.

  Args:
    obj: The object exported in the first round.
    cycles: Number of save/load round trips. With `cycles < 1` nothing is
      loaded and the final `return` fails because `loaded` is never bound.
    signatures: Optional signatures handed to the first `save.save` call.
    options: Optional save options forwarded to every `save.save` call.

  Returns:
    The object returned by the final `load.load` call.
  """
  to_save = obj
  # TODO(vbardiovsky): It would be nice if exported protos reached a fixed
  # point w.r.t. saving/restoring, ideally after 2nd saving.
  for _ in range(cycles):
    # Use `dir=` (not `prefix=`): passing a directory path as `prefix` makes
    # mkdtemp create a *sibling* of that directory whose name merely starts
    # with the path, instead of a new directory inside it.
    path = tempfile.mkdtemp(dir=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures, options=options)
      loaded = load.load(path)
      signatures = loaded.signatures
    to_save = loaded
  return loaded
94
95
96@parameterized.named_parameters(
97    dict(testcase_name="ReloadOnce", cycles=1),
98    dict(testcase_name="ReloadTwice", cycles=2),
99    dict(testcase_name="ReloadThrice", cycles=3)
100)
101class LoadTest(test.TestCase, parameterized.TestCase):
102
103  def test_structure_import(self, cycles):
104    root = autotrackable.AutoTrackable()
105    root.dep_one = autotrackable.AutoTrackable()
106    root.dep_two = autotrackable.AutoTrackable()
107    root.dep_two.dep = autotrackable.AutoTrackable()
108    root.dep_three = root.dep_two.dep
109    imported = cycle(root, cycles)
110    self.assertIs(imported.dep_three, imported.dep_two.dep)
111    self.assertIsNot(imported.dep_one, imported.dep_two)
112
113  @test_util.run_in_graph_and_eager_modes
114  def test_variables(self, cycles):
115    root = autotrackable.AutoTrackable()
116    root.v1 = variables.Variable(1., trainable=True)
117    root.v2 = variables.Variable(2., trainable=False)
118    self.evaluate([root.v1.initializer, root.v2.initializer])
119
120    for _ in range(cycles):
121      imported = cycle(root, 1)
122      self.evaluate([imported.v1.initializer, imported.v2.initializer])
123
124    if not context.executing_eagerly():
125      self.assertIsInstance(imported.v1.initializer, ops.Operation)
126      self.assertIsInstance(imported.v2.initializer, ops.Operation)
127
128    self.assertEqual(self.evaluate(imported.v1), 1.0)
129    self.assertTrue(imported.v1.trainable)
130    self.assertEqual(self.evaluate(imported.v2), 2.0)
131    self.assertFalse(imported.v2.trainable)
132
133  def test_variables_name(self, cycles):
134    root = autotrackable.AutoTrackable()
135    # Test 2 variables with same name: should work as the checkpoint
136    # is based on object name and not on variable name.
137    root.v1 = variables.Variable(1., trainable=True, name="v1")
138    root.v2 = variables.Variable(2., trainable=False, name="v1")
139    imported = cycle(root, cycles)
140    self.assertEqual(imported.v1.numpy(), 1.0)
141    self.assertEqual(imported.v2.numpy(), 2.0)
142    self.assertEqual(imported.v1.name, root.v1.name)
143    self.assertEqual(imported.v2.name, root.v2.name)
144    with variable_scope.variable_scope("foo"):
145      imported = cycle(root, cycles)
146      self.assertTrue(imported.v1.name.startswith("foo/"))
147      self.assertTrue(imported.v2.name.startswith("foo/"))
148
149  def test_partially_defined_variable_shape(self, cycles):
150
151    class MakeVariable(module.Module):
152
153      def __init__(self):
154        self.v = None
155
156      @def_function.function(
157          input_signature=[tensor_spec.TensorSpec([None], dtypes.int64)])
158      def make_variable(self, initial_value):
159        if self.v is None:
160          self.v = variables.Variable(initial_value)
161
162    m = MakeVariable()
163    m.make_variable([1, 2, 3])
164    m = cycle(m, cycles)
165    m.v.assign([1, 2, 3, 4])
166    self.assertEqual([None], tensor_shape.as_shape(m.v.shape).as_list())
167
168  @test_util.run_in_graph_and_eager_modes
169  def test_capture_variables(self, cycles):
170    root = autotrackable.AutoTrackable()
171    root.weights = variables.Variable(2.)
172    self.evaluate(root.weights.initializer)
173    root.f = def_function.function(
174        lambda x: root.weights * x,
175        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
176    for _ in range(cycles):
177      imported = cycle(root, 1)
178      self.evaluate(imported.weights.initializer)
179    self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.))))
180    self.evaluate(imported.weights.assign(4.0))
181    self.assertEqual(8., self.evaluate(imported.f(constant_op.constant(2.))))
182
183  @test_util.run_in_graph_and_eager_modes
184  def test_capture_constant(self, cycles):
185    root = autotrackable.AutoTrackable()
186    captured_constant = constant_op.constant(2.)
187    root.f = def_function.function(
188        lambda x: captured_constant * x,
189        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
190    imported = cycle(root, cycles)
191    self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.))))
192
193  def test_control_outputs(self, cycles):
194    exported = autotrackable.AutoTrackable()
195    exported.v = variables.Variable(1.)
196    exported.f = def_function.function(
197        lambda: exported.v.assign(2., name="should_be_control_output"))
198    exported_graph = exported.f.get_concrete_function().graph
199    self.assertIn(
200        exported_graph.get_operation_by_name("should_be_control_output"),
201        exported_graph.control_outputs)
202
203    imported = cycle(exported, cycles)
204    # Calling get_concrete_function wraps in a second call operation; we want to
205    # inspect the original function body for the control output; digging into
206    # graph.as_graph_def() and its FunctionDefLibrary is another option.
207    imported_concrete, = imported.f.concrete_functions
208    imported_graph = imported_concrete.graph
209    self.assertIn(
210        imported_graph.get_operation_by_name("should_be_control_output"),
211        imported_graph.control_outputs)
212
213  def _make_asset(self, contents):
214    fd, filename = tempfile.mkstemp(prefix=self.get_temp_dir())
215    with os.fdopen(fd, "w") as f:
216      f.write(contents)
217    return filename
218
219  @test_util.run_in_graph_and_eager_modes
220  def test_assets(self, cycles):
221    file1 = self._make_asset("contents 1")
222    file2 = self._make_asset("contents 2")
223
224    root = autotrackable.AutoTrackable()
225    root.asset1 = asset.Asset(file1)
226    root.asset2 = asset.Asset(file2)
227
228    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
229    save.save(root, save_dir)
230
231    file_io.delete_file(file1)
232    file_io.delete_file(file2)
233    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
234    file_io.rename(save_dir, load_dir)
235
236    imported = load.load(load_dir)
237    with open(self.evaluate(imported.asset1.asset_path), "r") as f:
238      self.assertEqual("contents 1", f.read())
239    with open(self.evaluate(imported.asset2.asset_path), "r") as f:
240      self.assertEqual("contents 2", f.read())
241
242  def test_cond_prune(self, cycles):
243    x_in = []
244    x_out = []
245
246    def f(x, y):
247      x_in.append(x)
248      xx = cond_v2.cond_v2(
249          math_ops.less(1, 2),
250          lambda: x + 1,
251          lambda: x + 2,
252      )
253      x_out.append(xx)
254      return xx, 2 * y
255
256    f_wrapped = wrap_function.wrap_function(
257        f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2)
258    f_pruned = f_wrapped.prune(x_in[0], [x_out[0]])
259
260    class Adder(module.Module):
261
262      @def_function.function(input_signature=[
263          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
264      def add(self, x):
265        return f_pruned(x)
266
267    root = Adder()
268    root.add(constant_op.constant(1.))
269    root = cycle(root, cycles)
270    root.add(constant_op.constant(1.))
271
272  def test_capture_assets(self, cycles):
273    root = autotrackable.AutoTrackable()
274    root.vocab = asset.Asset(self._make_asset("contents"))
275    root.f = def_function.function(
276        lambda: root.vocab.asset_path,
277        input_signature=[])
278    imported = cycle(root, cycles)
279    original_output = root.f().numpy()
280    imported_output = imported.f().numpy()
281    self.assertNotEqual(original_output, imported_output)
282    with open(imported_output, "r") as f:
283      self.assertEqual("contents", f.read())
284
285  def test_capture_assets_in_graph(self, cycles):
286    root = autotrackable.AutoTrackable()
287    root.vocab = asset.Asset(self._make_asset("contents"))
288    root.f = def_function.function(
289        lambda: root.vocab.asset_path,
290        input_signature=[])
291
292    original_output = root.f().numpy()
293
294    if cycles > 1:
295      root = cycle(root, cycles - 1)
296    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
297    save.save(root, path)
298
299    with ops.Graph().as_default():
300      imported = load.load(path)
301      imported_tensor = imported.f()
302      with monitored_session.MonitoredSession() as sess:
303        imported_output = sess.run(imported_tensor)
304        self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1)
305        self.assertNotEqual(original_output, imported_output)
306        with open(imported_output, "r") as f:
307          self.assertEqual("contents", f.read())
308
309  def test_dedup_assets(self, cycles):
310    vocab = self._make_asset("contents")
311    root = autotrackable.AutoTrackable()
312    root.asset1 = asset.Asset(vocab)
313    root.asset2 = asset.Asset(vocab)
314    imported = cycle(root, cycles)
315    self.assertEqual(imported.asset1.asset_path.numpy(),
316                     imported.asset2.asset_path.numpy())
317
318  def test_asset_fspath(self, cycles):
319    vocab = pathlib.Path(self._make_asset("contents"))
320    root = autotrackable.AutoTrackable()
321    root.asset = asset.Asset(vocab)
322    imported = cycle(root, cycles)
323    self.assertTrue(hasattr(imported, "asset"))
324
325  def test_implicit_input_signature(self, cycles):
326    @def_function.function
327    def func(x):
328      return 2 * x
329
330    root = autotrackable.AutoTrackable()
331    root.f = func
332
333    # Add two traces.
334    root.f(constant_op.constant(1.))
335    root.f(constant_op.constant(1))
336
337    imported = cycle(root, cycles)
338
339    self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy())
340    self.assertEqual(14, imported.f(constant_op.constant(7)).numpy())
341
342  def test_explicit_input_signature(self, cycles):
343    @def_function.function(
344        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
345    def func(x):
346      return 2 * x
347
348    root = autotrackable.AutoTrackable()
349    root.f = func
350
351    imported = cycle(root, cycles)
352    self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy())
353
354  def test_explicit_save_signature(self, cycles):
355    @def_function.function
356    def func(x):
357      return 2 * x
358
359    root = autotrackable.AutoTrackable()
360    root.f = func
361
362    imported = cycle(
363        root, cycles, {
364            "f":
365                root.f.get_concrete_function(
366                    tensor_spec.TensorSpec(None, dtypes.float32))
367        })
368    self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy())
369
370  def test_nested_functions(self, cycles):
371    f = def_function.function(
372        lambda x: x*2.0,
373        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
374    g = def_function.function(
375        lambda x: f(x) + 1.0,
376        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
377
378    root = autotrackable.AutoTrackable()
379    root.g = g
380    imported = cycle(root, cycles)
381    imported.g(constant_op.constant([1.0]))
382
383  def test_function_with_default_bool_input(self, cycles):
384
385    def func(x, training=False):
386      if training:
387        return 2 * x
388      else:
389        return 7
390
391    root = autotrackable.AutoTrackable()
392    root.f = def_function.function(func)
393
394    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
395    self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
396    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
397
398    imported = cycle(root, cycles)
399
400    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
401    self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
402
403  def test_function_with_default_none_input(self, cycles):
404
405    def func(x, dtype=None):
406      if dtype:
407        return array_ops.zeros(shape=x.shape, dtype=dtype)
408      else:
409        return array_ops.zeros(shape=x.shape, dtype=dtypes.float32)
410
411    root = autotrackable.AutoTrackable()
412    root.f = def_function.function(func)
413
414    self.assertAllEqual([0.0, 0.0, 0.0],
415                        root.f(constant_op.constant([1, 2, 3])).numpy())
416    self.assertAllEqual([0.0, 0.0, 0.0],
417                        root.f(constant_op.constant([1.0, 2.0, 3.0])).numpy())
418    self.assertAllEqual([0.0, 0.0, 0.0, 0.0],
419                        root.f(constant_op.constant([1, 2, 3, 4])).numpy())
420    self.assertAllEqual([0, 0, 0],
421                        root.f(
422                            constant_op.constant([1.0, 2.0, 3.0]),
423                            dtype=dtypes.int32).numpy())
424
425    concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
426    self.assertLen(concrete_functions, 4)
427
428    imported = cycle(root, cycles)
429
430    self.assertAllEqual([0.0, 0.0, 0.0],
431                        imported.f(constant_op.constant([1, 2, 3]),
432                                   None).numpy())
433    self.assertAllEqual([0.0, 0.0, 0.0],
434                        imported.f(constant_op.constant([1.0, 2.0,
435                                                         3.0])).numpy())
436    self.assertAllEqual([0.0, 0.0, 0.0, 0.0],
437                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
438    self.assertAllEqual([0, 0, 0],
439                        imported.f(
440                            constant_op.constant([1.0, 2.0, 3.0]),
441                            dtype=dtypes.int32).numpy())
442
443  def test_function_with_str_bytes_input(self, cycles):
444
445    @def_function.function
446    def func(x, y):
447      return string_ops.string_join([x, y])
448
449    root = autotrackable.AutoTrackable()
450    root.f = func
451
452    self.assertAllEqual(b"ab", root.f("a", "b"))
453    self.assertAllEqual(b"ab", root.f("a", constant_op.constant("b")))
454    self.assertAllEqual(b"ab", root.f(constant_op.constant("a"), "b"))
455
456    concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
457    self.assertLen(concrete_functions, 3)
458
459    imported = cycle(root, cycles)
460
461    self.assertAllEqual(b"ab", imported.f("a", "b"))
462    self.assertAllEqual(b"ab", imported.f("a", constant_op.constant("b")))
463    self.assertAllEqual(b"ab", imported.f(constant_op.constant("a"), "b"))
464
465  def test_function_no_return(self, cycles):
466
467    class TrackableWithOneVariable(autotrackable.AutoTrackable):
468
469      def __init__(self, initial_value=0.0):
470        super(TrackableWithOneVariable, self).__init__()
471        self.variable = variables.Variable(initial_value)
472
473      @def_function.function
474      def increase(self, by=1.0):
475        self.variable.assign_add(by)
476
477    obj = TrackableWithOneVariable(5.0)
478
479    obj.increase(constant_op.constant(10.0))
480    self.assertEqual(15.0, obj.variable.numpy())
481    obj.increase()
482    self.assertEqual(16.0, obj.variable.numpy())
483
484    imported = cycle(obj, cycles)
485
486    imported.increase(constant_op.constant(10.0))
487    self.assertEqual(26.0, imported.variable.numpy())
488    imported.increase(constant_op.constant(1.0))
489    self.assertEqual(27.0, imported.variable.numpy())
490
491  def test_structured_inputs(self, cycles):
492
493    def func(x, training=True):
494      # x is a nested structure, we care about one particular tensor.
495      _, (a, b) = x
496      if training:
497        return 2 * a["a"] + b
498      else:
499        return 7
500
501    root = autotrackable.AutoTrackable()
502    root.f = def_function.function(func)
503
504    x = constant_op.constant(10)
505    y = constant_op.constant(11)
506
507    input1 = [6, ({"a": x}, y)]
508    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
509    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.
510
511    # Note: by only calling f(input1) before serialization, only inputs with
512    # matching signature will be valid on the loaded model.
513    self.assertEqual(31, root.f(input1).numpy())
514
515    imported = cycle(root, cycles)
516
517    with self.assertRaisesRegex(
518        ValueError, "Could not find matching concrete function to call"):
519      imported.f(input2)
520
521    self.assertEqual(31, imported.f(input1).numpy())
522    self.assertEqual(32, imported.f(input3).numpy())
523
524  def test_structured_inputs_bare_concrete_function(self, cycles):
525
526    def func(x, training=True):
527      # x is a nested structure, we care about one particular tensor.
528      _, (a, b) = x
529      if training:
530        return 2 * a["a"] + b
531      else:
532        return 7
533
534    x = constant_op.constant(10)
535    y = constant_op.constant(11)
536
537    input1 = [6, ({"a": x}, y)]
538    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
539    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.
540
541    root = autotrackable.AutoTrackable()
542    root.f = def_function.function(func).get_concrete_function(input1)
543
544    imported = cycle(root, cycles)
545
546    with self.assertRaises(TypeError):
547      imported.f(input2)
548
549    self.assertEqual(31, imported.f(input1).numpy())
550    self.assertEqual(32, imported.f(input3).numpy())
551
552  def test_structured_output(self, cycles):
553
554    # Use fields with non-alphabetical order
555    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])
556
557    def func(input1, input2):
558      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
559      return [named_tuple, input2, {"x": 0.5}]
560
561    root = autotrackable.AutoTrackable()
562    root.f = def_function.function(func)
563
564    result = root.f(constant_op.constant(2), constant_op.constant(3))
565
566    self.assertEqual(5, result[0].a.numpy())
567    self.assertEqual(6, result[0].b.numpy())
568    self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
569    self.assertEqual(3, result[1].numpy())
570    self.assertEqual(0.5, result[2]["x"].numpy())
571
572    imported = cycle(root, cycles)
573
574    result = imported.f(constant_op.constant(2), constant_op.constant(5))
575    self.assertEqual(7, result[0].a.numpy())
576    self.assertEqual(10, result[0].b.numpy())
577    self.assertEqual(["b", "a"], list(result[0]._asdict().keys()))
578    self.assertEqual(5, result[1].numpy())
579    self.assertEqual(0.5, result[2]["x"].numpy())
580
581  def test_pretty_print_signature(self, cycles):
582
583    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])
584
585    def func(input1, input2):
586      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
587      return [named_tuple, input2, {"x": 0.5}]
588
589    root = autotrackable.AutoTrackable()
590    root.f = def_function.function(func).get_concrete_function(
591        constant_op.constant(2), constant_op.constant(3))
592
593    imported = cycle(root, cycles)
594    self.assertEqual(
595        imported.f.pretty_printed_signature(), """func(input1, input2)
596  Args:
597    input1: int32 Tensor, shape=()
598    input2: int32 Tensor, shape=()
599  Returns:
600    [NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}]
601      <1>: int32 Tensor, shape=()
602      <2>: int32 Tensor, shape=()
603      <3>: int32 Tensor, shape=()
604      <4>: float32 Tensor, shape=()""")
605
606  def test_positional_arguments(self, cycles):
607    def func(x, training=False, abc=7.1, defg=7.7):
608      del abc
609      if training:
610        return 2 * x
611      if defg == 7:
612        return 6
613      else:
614        return 7
615
616    root = autotrackable.AutoTrackable()
617    root.f = def_function.function(func)
618
619    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
620    self.assertEqual(7, root.f(constant_op.constant(1)).numpy())
621    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
622    self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy())
623
624    imported = cycle(root, cycles)
625
626    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
627    self.assertEqual(7, imported.f(constant_op.constant(2)).numpy())
628    self.assertEqual(6, imported.f(constant_op.constant(1), defg=7.0).numpy())
629
630  def test_additional_kwargs(self, cycles):
631    def func(x, training=False, **options):
632      del options
633      if training:
634        return 2 * x
635      else:
636        return 7
637
638    root = autotrackable.AutoTrackable()
639    root.f = def_function.function(func)
640
641    x = constant_op.constant(10)
642    self.assertEqual(7, root.f(x, learning_rate=0.5, epochs=3).numpy())
643
644    imported = cycle(root, cycles)
645
646    with self.assertRaisesRegex(
647        ValueError, "Could not find matching concrete function to call.*"):
648      imported.f(x, learning_rate=0.5, epochs=4)
649
650    self.assertEqual(7, imported.f(x, learning_rate=0.5, epochs=3).numpy())
651
652  def test_member_function(self, cycles):
653    class TrackableWithMember(autotrackable.AutoTrackable):
654
655      def __init__(self):
656        super(TrackableWithMember, self).__init__()
657        self._some_value = 20
658
659      @def_function.function
660      def f(self, x, training=False):
661        if training:
662          return 2 * x
663        else:
664          return 7 + self._some_value
665
666    root = TrackableWithMember()
667
668    self.assertEqual(20, root.f(constant_op.constant(10), True).numpy())
669    self.assertEqual(27, root.f(constant_op.constant(1)).numpy())
670    self.assertEqual(2, root.f(constant_op.constant(1), True).numpy())
671
672    imported = cycle(root, cycles)
673
674    self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy())
675    self.assertEqual(27, imported.f(constant_op.constant(2)).numpy())
676
677  def test_side_effect_listing(self, cycles):
678    class M(autotrackable.AutoTrackable):
679
680      def __init__(self):
681        super(M, self).__init__()
682        self.var = None
683
684      @def_function.function(
685          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
686      def f(self, x):
687        if self.var is None:
688          self.var = variables.Variable(2.)
689        return x * self.var
690
691    m = M()
692    cycle(m, cycles)
693    self.assertEqual(4.0, m.f(constant_op.constant(2.0)).numpy())
694
695  def test_basic_backprop(self, cycles):
696    weight = variables.Variable(1., trainable=True)
697    bias = variables.Variable(0., trainable=True)
698    g = def_function.function(
699        lambda x: x*weight + bias,
700        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
701
702    root = autotrackable.AutoTrackable()
703    root.weight = weight
704    root.bias = bias
705    root.g = g
706    imported = cycle(root, cycles)
707    with backprop.GradientTape() as t:
708      x = constant_op.constant([3.5])
709      loss = imported.g(x)
710      grad = t.gradient(loss, [imported.weight, imported.bias])
711      self.assertAllClose(grad, [3.5, 1.0])
712
713  def test_nested_backprop(self, cycles):
714    weight = variables.Variable(1., trainable=True)
715    bias = variables.Variable(0., trainable=True)
716
717    # Note: this function gets called from other function defs via a
718    # "PartitionedCall" op node.
719    @def_function.function(input_signature=[
720        tensor_spec.TensorSpec(None, dtypes.float32),
721        tensor_spec.TensorSpec(None, dtypes.float32)])
722    def mul(x, y):
723      return x * y
724
725    # Note: this function gets called from other function defs via a
726    # "StatefulPartitionedCall" op node.
727    @def_function.function(input_signature=[
728        tensor_spec.TensorSpec(None, dtypes.float32)])
729    def f(x):
730      return mul(weight.read_value(), x)
731
732    @def_function.function(input_signature=[
733        tensor_spec.TensorSpec(None, dtypes.float32)])
734    def g(x):
735      return f(x) + bias,
736
737    @def_function.function(input_signature=[
738        tensor_spec.TensorSpec(None, dtypes.float32)])
739    def h(x):
740      return g(x) + bias,
741
742    root = autotrackable.AutoTrackable()
743    root.weight = weight
744    root.bias = bias
745    root.g = h
746
747    imported = cycle(root, cycles)
748    with backprop.GradientTape() as t:
749      x = constant_op.constant([3.5])
750      loss = imported.g(x)
751    grad = t.gradient(loss, [imported.weight, imported.bias])
752    self.assertAllClose(grad, [3.5, 2.0])
753
754  def test_while_loop_backprop(self, cycles):
755    weight = variables.Variable(2., trainable=True)
756
757    @def_function.function(input_signature=[
758        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
759    def g(x):
760      """Adds rows of matrix x after multiplying each entry by v."""
761      i_0 = constant_op.constant(0)
762      s_0 = constant_op.constant([0., 0.])
763      cond = lambda i, _: i < array_ops.shape(x)[1]
764      body = lambda i, s: (i + 1, s + weight * x[:, i])
765      i_end, s_end = control_flow_ops.while_loop(cond, body, (i_0, s_0))
766      del i_end
767      return s_end
768
769    root = autotrackable.AutoTrackable()
770    root.weight = weight
771    root.g = g
772    imported = cycle(root, cycles)
773
774    def get_gradient(obj):
775      with backprop.GradientTape() as t:
776        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
777        y = obj.g(x)
778        self.assertAllClose(y, obj.weight * [6., 2.])
779        loss = math_ops.reduce_sum(y)  # weight * 8.
780        self.assertAllEqual(t.watched_variables(), [obj.weight])
781        return t.gradient(loss, obj.weight)
782
783    imported_gradient = get_gradient(imported)
784    original_gradient = get_gradient(root)
785    self.assertIsNotNone(original_gradient)
786    self.assertAllClose(original_gradient, 8.)
787    self.assertIsNotNone(imported_gradient)
788    self.assertAllClose(imported_gradient, 8.)
789
790  def _test_restored_func_with_captured_var_backprop(self, cycles, dtype):
791    weight = variables.Variable(2., trainable=True, dtype=dtype)
792
793    @def_function.function(input_signature=[
794        tensor_spec.TensorSpec(dtype=dtype, shape=())])
795    def g(x):
796      return x * weight
797
798    root = autotrackable.AutoTrackable()
799    root.weight = weight
800    root.g = g
801    imported = cycle(root, cycles)
802
803    def get_gradient(obj):
804      with backprop.GradientTape() as t:
805        x = constant_op.constant(2.)
806        y = obj.g(x)
807        self.assertAllClose(y, obj.weight * 2.)
808        self.assertAllEqual(t.watched_variables(), [obj.weight])
809        return t.gradient(y, obj.weight)
810
811    imported_gradient = get_gradient(imported)
812    original_gradient = get_gradient(root)
813    self.assertIsNotNone(original_gradient)
814    self.assertAllClose(original_gradient, 2.)
815    self.assertIsNotNone(imported_gradient)
816    self.assertAllClose(imported_gradient, 2.)
817
818  def test_nested_fn_backprop(self, cycles):
819    weight = variables.Variable(2., trainable=True)
820
821    @def_function.function(input_signature=[
822        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
823    def g(x):
824      weight.read_value()  # Just get the tape to watch the variable
825      handle = array_ops.identity(weight.handle)
826      @def_function.function
827      def launder_var_handle():
828        return array_ops.identity(handle)
829      return x + resource_variable_ops.read_variable_op(
830          launder_var_handle(), dtypes.float32)
831
832    root = autotrackable.AutoTrackable()
833    root.weight = weight
834    root.g = g
835    imported = cycle(root, cycles)
836    def get_gradient(obj, persistent):
837      with backprop.GradientTape(persistent=persistent) as t:
838        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
839        y = obj.g(x)
840        self.assertAllClose(y, obj.weight + x)
841        loss = math_ops.reduce_sum(y)
842        return t.gradient(loss, obj.weight)
843
844    imported_gradient = get_gradient(imported, persistent=False)
845    original_gradient = get_gradient(root, persistent=False)
846    self.assertIsNotNone(original_gradient)
847    self.assertAllClose(original_gradient, 6.)
848    self.assertIsNotNone(imported_gradient)
849    self.assertAllClose(imported_gradient, 6.)
850
851  def test_restored_func_with_captured_var_backprop_float32(self, cycles):
852    self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32)
853
854  def test_restored_func_with_captured_var_backprop_float64(self, cycles):
855    self.skipTest("b/144573917")
856    self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float64)
857
858  def test_callable(self, cycles):
859    class M1(autotrackable.AutoTrackable):
860
861      @def_function.function(
862          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
863      def __call__(self, x):
864        return x
865
866    root = autotrackable.AutoTrackable()
867    root.m1 = M1()
868    root.m2 = autotrackable.AutoTrackable()
869    root.m2.__call__ = def_function.function(
870        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
871            lambda x: x*3.0)
872    imported = cycle(root, cycles)
873    x = constant_op.constant(1.0)
874
875    self.assertTrue(callable(imported.m1))
876    self.assertAllEqual(root.m1(x), imported.m1(x))
877
878    # Note: `root.m2` was not callable since `__call__` attribute was set
879    # into the instance and not on the class. But after a serialization cycle
880    # that starts to work.
881    self.assertTrue(callable(imported.m2))
882    self.assertAllEqual(root.m2.__call__(x), imported.m2(x))
883
884    # Verify that user objects without `__call__` attribute are not callable.
885    self.assertFalse(callable(imported))
886
887  def test_chain_callable(self, cycles):
888    func = def_function.function(
889        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])(
890            lambda x: x*3.0)
891    root = autotrackable.AutoTrackable()
892    root.__call__ = autotrackable.AutoTrackable()
893    root.__call__.__call__ = autotrackable.AutoTrackable()
894    root.__call__.__call__.__call__ = func
895
896    imported = cycle(root, cycles)
897    self.assertTrue(callable(imported))
898    x = constant_op.constant(1.0)
899    self.assertAllEqual(imported(x).numpy(), 3.0)
900
901  def test_load_in_graph_mode(self, cycles):
902    root = autotrackable.AutoTrackable()
903    root.v1 = variables.Variable(1., name="v_one", trainable=False)
904    root.v2 = variables.Variable(2., name="v_two", trainable=True)
905    root.f = def_function.function(
906        lambda x: root.v2 * x,
907        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
908
909    if cycles > 1:
910      root = cycle(root, cycles - 1)
911    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
912    save.save(root, path)
913
914    with ops.Graph().as_default() as g:
915      imported = load.load(path)
916      var_v1 = imported.v1
917      self.assertFalse(var_v1.trainable)
918      var_v2 = imported.v2
919      self.assertTrue(var_v2.trainable)
920      output = imported.f(constant_op.constant(2.))
921      with monitored_session.MonitoredSession() as sess:
922        self.assertEqual(1.0, sess.run(var_v1))
923        self.assertEqual(4.0, sess.run(output))
924      self.assertCountEqual([var_v1, var_v2],
925                            g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
926      # load() should not add to TRAINABLE_VARIABLES. Higher levels of model
927      # building control retraining or frozen use of imported SavedModels.
928      self.assertCountEqual([],
929                            g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
930
931  def test_load_in_func_graph(self, cycles):
932    root = autotrackable.AutoTrackable()
933    root.v1 = variables.Variable(1.)
934    root.v2 = variables.Variable(2.)
935    root.f = def_function.function(
936        lambda x: root.v2 * x,
937        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
938
939    if cycles > 1:
940      root = cycle(root, cycles - 1)
941    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
942    save.save(root, path)
943
944    closure = autotrackable.AutoTrackable()
945    @def_function.function
946    def func(x):
947      if not hasattr(closure, "model"):
948        closure.model = load.load(path)
949      return closure.model.f(x)
950
951    inputs = constant_op.constant(2.)
952    self.assertEqual(4.0, func(inputs).numpy())
953
954  def test_soft_matching(self, cycles):
955
956    @def_function.function(
957        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
958    def func(x):
959      return 2 * x
960
961    root = autotrackable.AutoTrackable()
962    root.f = func
963
964    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
965    self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())
966
967    concrete_functions = root.f._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access
968    self.assertLen(concrete_functions, 1)
969
970    imported = cycle(root, cycles)
971
972    with self.assertRaisesRegex(ValueError, "Python inputs incompatible"):
973      # We cannot call the function with a constant of shape ().
974      imported.f(constant_op.constant(2)).numpy()
975
976    # TODO(vbardiovsky): When classes are revived with input_signatures, we
977    # should also check that the calls below are not generating any more
978    # concrete functions.
979    self.assertAllEqual([2, 4, 6, 8],
980                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
981    self.assertAllEqual([2, 4, 6],
982                        imported.f(constant_op.constant([1, 2, 3])).numpy())
983
984  def test_jit_compile(self, cycles):
985
986    # It'd be nice to use parameterize here, but the library does not support
987    # having parameterized test methods inside already-parameterized classes.
988    for jit_compile in (None, True, False):
989
990      @def_function.function(jit_compile=jit_compile)
991      def f(x):
992        return x + 1.
993
994      root = module.Module()
995      root.f = f
996      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
997      save.save(root, save_dir)
998
999      imported = cycle(root, cycles)
1000
1001      self.assertEqual(imported.f._jit_compile, jit_compile)
1002
1003  def test_get_concrete_function(self, cycles):
1004
1005    @def_function.function
1006    def func(x, training=False):
1007      if training:
1008        return 2 * x
1009      else:
1010        return 3 * x
1011
1012    func.get_concrete_function(
1013        tensor_spec.TensorSpec([None], dtypes.int32), True)
1014    func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32))
1015
1016    root = autotrackable.AutoTrackable()
1017    root.f = func
1018
1019    imported = cycle(root, cycles)
1020
1021    concrete = imported.f.get_concrete_function(
1022        training=True, x=tensor_spec.TensorSpec([None], dtypes.int32))
1023
1024    self.assertAllEqual([2, 4, 6, 8],
1025                        concrete(x=constant_op.constant([1, 2, 3, 4])).numpy())
1026    with self.assertRaisesRegex(
1027        ValueError, "Could not find matching concrete function to call"):
1028      imported.f.get_concrete_function(
1029          tensor_spec.TensorSpec([None], dtypes.int32))
1030    imported.f.get_concrete_function(
1031        tensor_spec.TensorSpec([None], dtypes.int32), True)
1032
1033  def test_concrete_function(self, cycles):
1034
1035    @def_function.function(
1036        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
1037    def func(x):
1038      return 2 * x
1039
1040    root = autotrackable.AutoTrackable()
1041    root.f = func.get_concrete_function()
1042
1043    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
1044    self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy())
1045
1046    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
1047    imported = cycle(root, cycles, signatures={})
1048
1049    self.assertAllEqual([2, 4, 6, 8],
1050                        imported.f(constant_op.constant([1, 2, 3, 4])).numpy())
1051    self.assertAllEqual([2, 4, 6],
1052                        imported.f(constant_op.constant([1, 2, 3])).numpy())
1053
1054  def test_concrete_function_captures(self, cycles):
1055
1056    class Root(module.Module):
1057
1058      def __init__(self):
1059        self.v = variables.Variable(1.)
1060        self.v1 = variables.Variable(1.)
1061
1062      @def_function.function(
1063          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
1064      def use_v(self, x):
1065        return self.v + self.v1 + 1.
1066
1067    root = Root()
1068    self.assertIn(root.v.handle,
1069                  root.use_v.get_concrete_function().graph.external_captures)
1070    root = cycle(root, cycles, signatures=root.use_v.get_concrete_function())
1071    func_captures = root.use_v.get_concrete_function().graph.external_captures
1072    self.assertLen(func_captures, 2)
1073    self.assertTrue(any(root.v.handle is t for t in func_captures))
1074    self.assertTrue(any(root.v1.handle is t for t in func_captures))
1075    signature_captures = root.signatures[
1076        "serving_default"].graph.external_captures
1077    self.assertLen(signature_captures, 2)
1078    self.assertTrue(any(root.v.handle is t for t in signature_captures))
1079    self.assertTrue(any(root.v1.handle is t for t in signature_captures))
1080
1081  def test_concrete_function_arg_names(self, cycles):
1082
1083    @def_function.function(
1084        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
1085    def func(x):
1086      return 2 * x
1087
1088    root = autotrackable.AutoTrackable()
1089    root.f = func.get_concrete_function()
1090
1091    self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy())
1092
1093    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
1094    imported = cycle(root, cycles, signatures={})
1095
1096    self.assertAllEqual([2, 4, 6],
1097                        imported.f(x=constant_op.constant([1, 2, 3])).numpy())
1098
1099  def test_concrete_function_no_signature(self, cycles):
1100    @def_function.function
1101    def func(x):
1102      return 2 * x
1103
1104    root = autotrackable.AutoTrackable()
1105    root.f = func.get_concrete_function(constant_op.constant([1]))
1106    self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy())
1107    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
1108    imported = cycle(root, cycles, signatures={})
1109    self.assertAllEqual([6],
1110                        imported.f(constant_op.constant([3])).numpy())
1111
1112  @test_util.run_in_graph_and_eager_modes
1113  def test_concrete_function_backprop(self, cycles):
1114    @def_function.function(
1115        input_signature=[tensor_spec.TensorSpec([], dtypes.float32)])
1116    def func(x):
1117      return x ** 2.
1118    root = autotrackable.AutoTrackable()
1119    root.f = func.get_concrete_function()
1120
1121    def _compute_gradient(function):
1122      with backprop.GradientTape() as tape:
1123        inp = constant_op.constant(1.)
1124        tape.watch(inp)
1125        output = function(inp)
1126      return tape.gradient(output, inp)
1127
1128    self.assertAllEqual(2., _compute_gradient(root.f))
1129    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
1130    imported = cycle(root, cycles, signatures={})
1131    self.assertAllEqual(2., _compute_gradient(imported.f))
1132
1133  def test_revived_concrete_function_kwargs(self, cycles):
1134
1135    @def_function.function
1136    def func(x, y):
1137      return x * (y + 1.)
1138    root = autotrackable.AutoTrackable()
1139    root.f = func.get_concrete_function(
1140        tensor_spec.TensorSpec([], dtypes.float32),
1141        tensor_spec.TensorSpec([], dtypes.float32))
1142    self.assertEqual(8., root.f(y=constant_op.constant(3.),
1143                                x=constant_op.constant(2.)).numpy())
1144    # TODO(andresp): Fix exporting of loaded concrete functions as signatures.
1145    imported = cycle(root, cycles, signatures={})
1146    self.assertEqual(8., imported.f(y=constant_op.constant(3.),
1147                                    x=constant_op.constant(2.)).numpy())
1148
1149  def test_revived_concrete_function_tensorspec_kwargs(self, cycles):
1150
1151    @def_function.function
1152    def func(*args):
1153      x, y = args
1154      return x * (y + 1.)
1155    root = autotrackable.AutoTrackable()
1156    root.f = func.get_concrete_function(
1157        tensor_spec.TensorSpec([], dtypes.float32, name="x"),
1158        tensor_spec.TensorSpec([], dtypes.float32, name="y"))
1159    self.assertEqual(8., root.f(y=constant_op.constant(3.),
1160                                x=constant_op.constant(2.)).numpy())
1161    imported = cycle(root, cycles, signatures={})
1162    self.assertEqual(8., imported.f(y=constant_op.constant(3.),
1163                                    x=constant_op.constant(2.)).numpy())
1164
1165  def test_concrete_function_variable_argument(self, cycles):
1166    capture = variables.Variable(0)
1167
1168    @def_function.function
1169    def func(v):
1170      v.assign_add(1)
1171      capture.assign_sub(1)
1172
1173    @def_function.function(input_signature=[
1174        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
1175    ])
1176    def func_with_input_signature(v):
1177      v.assign_add(5)
1178      capture.assign_sub(5)
1179      return 1
1180
1181    vsave = variables.Variable(1)
1182    root = autotrackable.AutoTrackable()
1183    root.f = func.get_concrete_function(vsave)
1184    root.f_sig = func_with_input_signature.get_concrete_function()
1185    root.capture = capture
1186
1187    self.assertEqual(1, vsave.numpy())
1188    root.f(vsave)
1189    self.assertEqual(2, vsave.numpy())
1190    self.assertEqual(-1, capture.numpy())
1191
1192    root.f_sig(vsave)
1193    self.assertEqual(7, vsave.numpy())
1194    self.assertEqual(-6, capture.numpy())
1195
1196    imported = cycle(root, cycles)
1197
1198    vload = variables.Variable(1)
1199    imported.f(vload)
1200    self.assertEqual(2, vload.numpy())
1201    imported.f(v=vload)
1202    self.assertEqual(3, vload.numpy())
1203    self.assertEqual(-8, imported.capture.numpy())
1204
1205    imported.f_sig(v=vload)
1206    self.assertEqual(8, vload.numpy())
1207    self.assertEqual(-13, imported.capture.numpy())
1208
1209    self.assertEqual(-6, capture.numpy())
1210
1211  def test_function_and_component(self, cycles):
1212
1213    @def_function.function
1214    def func(v):
1215      return v + 1
1216
1217    root = autotrackable.AutoTrackable()
1218    root.func = func
1219    root.concrete_func = func.get_concrete_function(
1220        tensor_spec.TensorSpec(None, dtypes.int32))
1221    one = constant_op.constant(1)
1222    self.assertEqual(2, root.func(one).numpy())
1223    self.assertEqual(2, root.concrete_func(one).numpy())
1224    imported = cycle(root, cycles)
1225    self.assertEqual(2, imported.func(one).numpy())
1226    self.assertEqual(2, imported.concrete_func(one).numpy())
1227
1228  def test_dict(self, cycles):
1229    root = autotrackable.AutoTrackable()
1230    root.variables = dict(a=variables.Variable(1.))
1231    root.variables["b"] = variables.Variable(2.)
1232    root.variables["c"] = 1
1233    root.funcs = dict(
1234        a=def_function.function(lambda: constant_op.constant(100.)))
1235    root.funcs["conc"] = root.funcs["a"].get_concrete_function()
1236    imported = cycle(root, cycles)
1237    self.assertEqual(1., imported.variables["a"].numpy())
1238    self.assertEqual(2., imported.variables["b"].numpy())
1239    self.assertEqual(set(["a", "b"]), set(imported.variables.keys()))
1240    self.assertEqual(100., imported.funcs["a"]().numpy())
1241    self.assertEqual(100., imported.funcs["conc"]().numpy())
1242
1243  def test_list(self, cycles):
1244    root = autotrackable.AutoTrackable()
1245    root.variables = [variables.Variable(1.)]
1246    root.variables.append(1)
1247    root.variables.append(variables.Variable(3.))
1248    imported = cycle(root, cycles)
1249    self.assertEqual(1., imported.variables[0].numpy())
1250    self.assertEqual(3., imported.variables[2].numpy())
1251    self.assertIs(None, imported.variables[1])
1252    self.assertLen(imported.variables, 3)
1253
1254  def test_tuple(self, cycles):
1255    root = autotrackable.AutoTrackable()
1256    root.variables = (variables.Variable(1.), 1, variables.Variable(3.))
1257    imported = cycle(root, cycles)
1258    self.assertEqual(1., imported.variables[0].numpy())
1259    self.assertEqual(3., imported.variables[2].numpy())
1260    self.assertIs(None, imported.variables[1])
1261    self.assertLen(imported.variables, 3)
1262
1263  def test_functions_list(self, cycles):
1264    root = autotrackable.AutoTrackable()
1265    v1 = variables.Variable(1.)
1266    root.losses = [def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))]
1267    root.variables = [v1]
1268
1269    @def_function.function
1270    def _v2_loss():
1271      if len(root.variables) == 1:
1272        v2 = variables.Variable(2.)
1273        root.variables.append(v2)
1274      return math_ops.reduce_sum(root.variables[1] ** 2)
1275
1276    root.losses.append(_v2_loss)
1277    self.assertAllClose([1., 4.], [loss() for loss in root.losses])
1278    imported = cycle(root, cycles)
1279    self.assertAllClose([1., 4.], [loss() for loss in imported.losses])
1280    imported.variables[0].assign(3.)
1281    imported.variables[1].assign(4.)
1282    self.assertAllClose([9., 16.], [loss() for loss in imported.losses])
1283
1284  def test_captured_constant(self, cycles):
1285    const = array_ops.zeros([100])
1286    root = autotrackable.AutoTrackable()
1287    root.f = def_function.function(lambda: const + 1.)
1288    root.g = def_function.function(lambda: const + 2.)
1289    self.assertAllClose(array_ops.ones([100]), root.f())
1290    self.assertAllClose(2. * array_ops.ones([100]), root.g())
1291    imported = cycle(root, cycles)
1292    self.assertAllClose(array_ops.ones([100]), imported.f())
1293    self.assertAllClose(2. * array_ops.ones([100]), imported.g())
1294    # TODO(b/123408994): Use the public get_concrete_function.
1295    f_concrete = imported.f._list_all_concrete_functions_for_serialization()[0]
1296    g_concrete = imported.g._list_all_concrete_functions_for_serialization()[0]
1297    self.assertLen(f_concrete.captured_inputs, 1)
1298    self.assertLen(g_concrete.captured_inputs, 1)
1299    # We should be using the same captured EagerTensor in both functions, not
1300    # duplicating the constant.
1301    self.assertIs(f_concrete.captured_inputs[0],
1302                  g_concrete.captured_inputs[0])
1303
1304  def test_functions_accessed_once(self, cycles):
1305
1306    class Exported(autotrackable.AutoTrackable):
1307
1308      def __init__(self):
1309        self._counter = 0
1310
1311      @property
1312      def make_func(self):
1313        @def_function.function
1314        def f():
1315          return constant_op.constant(self._counter)
1316        f.get_concrete_function()  # force a trace
1317        self._counter += 1
1318        return f
1319
1320    exported = Exported()
1321    imported = cycle(exported, cycles)
1322    self.assertEqual(0, imported.make_func().numpy())
1323    self.assertEqual(1, exported.make_func().numpy())
1324
1325  def test_overwritten_signatures_error(self, cycles):
1326    exported = autotrackable.AutoTrackable()
1327    exported.f = def_function.function(lambda: constant_op.constant(1.))
1328    imported = cycle(
1329        exported, cycles,
1330        signatures={"key": exported.f.get_concrete_function()})
1331    self.assertEqual(1., imported.signatures["key"]()["output_0"].numpy())
1332    imported.signatures = {"key1": imported.signatures["key"]}
1333    with self.assertRaisesRegex(ValueError, "signatures"):
1334      save.save(imported, tempfile.mkdtemp(prefix=self.get_temp_dir()))
1335
1336  def test_signature_loading(self, cycles):
1337
1338    class Exported(autotrackable.AutoTrackable):
1339
1340      def __init__(self):
1341        self.v = variables.Variable(3.)
1342
1343      @def_function.function
1344      def do(self, x):
1345        return self.v * x
1346
1347    exported = Exported()
1348    imported = cycle(
1349        exported,
1350        cycles,
1351        signatures=exported.do.get_concrete_function(
1352            tensor_spec.TensorSpec(None, dtypes.float32)))
1353    self.assertEqual(["serving_default"], list(imported.signatures.keys()))
1354    imported_function = imported.signatures["serving_default"]
1355    two = constant_op.constant(2.)
1356    self.assertEqual(6., imported_function(x=two)["output_0"].numpy())
1357    imported.v.assign(4.)
1358    self.assertEqual(8., imported_function(x=two)["output_0"].numpy())
1359    self.assertEqual(8., imported_function(two)["output_0"].numpy())
1360    with self.assertRaises(TypeError):
1361      # The signatures mapping is immutable
1362      imported.signatures["random_key"] = 3
1363
1364  def test_names_normalized(self, cycles):
1365    class ObjWithFunction(module.Module):
1366
1367      @def_function.function(input_signature=[
1368          tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A-b"),
1369          tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A/D"),
1370          tensor_spec.TensorSpec([], dtype=dtypes.int32, name="bar"),
1371          tensor_spec.TensorSpec([], dtype=dtypes.int32, name="e"),
1372      ])
1373      def foo(self, a, b, c, d=10, **options):
1374        del options
1375        return a + b + c + d
1376
1377    exported = ObjWithFunction()
1378
1379    with self.assertLogs(level="WARNING") as logs:
1380      imported = cycle(exported, cycles)
1381
1382    expected_message = (
1383        "WARNING:absl:Function `foo` contains input name(s) A-b, A/D with "
1384        "unsupported characters which will be renamed to a_b, a_d in the "
1385        "SavedModel.")
1386    self.assertIn(expected_message, logs.output)
1387
1388    loaded_signature = imported.signatures["serving_default"].inputs
1389    self.assertEqual("a_b:0", loaded_signature[0].name)
1390    self.assertEqual("a_d:0", loaded_signature[1].name)
1391
1392  def test_multiple_argument_signatures_no_positional(self, cycles):
1393
1394    class Exported(autotrackable.AutoTrackable):
1395
1396      @def_function.function
1397      def do(self, x, y):
1398        return x + y
1399
1400    exported = Exported()
1401    imported = cycle(
1402        exported, cycles, signatures=exported.do.get_concrete_function(
1403            tensor_spec.TensorSpec(None, dtypes.float32),
1404            tensor_spec.TensorSpec(None, dtypes.float32)))
1405    with self.assertRaises(TypeError):
1406      imported.signatures["serving_default"](
1407          constant_op.constant(1.),
1408          y=constant_op.constant(2.))
1409    self.assertEqual(
1410        {"output_0": 3.},
1411        self.evaluate(imported.signatures["serving_default"](
1412            x=constant_op.constant(1.),
1413            y=constant_op.constant(2.))))
1414
1415  def _make_model_with_tables(self):
1416    default_val = -1
1417    keys = constant_op.constant(["brain", "salad", "surgery"])
1418    values = constant_op.constant([0, 1, 2], dtypes.int64)
1419    table1_initializer = lookup_ops.KeyValueTensorInitializer(keys, values)
1420    table1 = lookup_ops.HashTable(table1_initializer, default_val)
1421
1422    table2_file = self._make_asset("test\nfoo\nbrain\n")
1423    table2_initializer = lookup_ops.TextFileIdTableInitializer(table2_file)
1424    table2 = lookup_ops.HashTable(table2_initializer, default_val)
1425
1426    def _make_lookup_function(table):
1427      signature = [tensor_spec.TensorSpec(None, dtypes.string)]
1428      return def_function.function(input_signature=signature)(
1429          lambda x: table.lookup(x))  # pylint: disable=unnecessary-lambda
1430
1431    root = autotrackable.AutoTrackable()
1432    root.table1 = table1
1433    root.lookup1 = _make_lookup_function(table1)
1434    root.table2 = table2
1435    root.lookup2 = _make_lookup_function(table2)
1436    return root
1437
1438  def test_table(self, cycles):
1439    root = self._make_model_with_tables()
1440    imported = cycle(root, cycles, signatures={})
1441    keys = constant_op.constant(["brain", "test", "foo", "surgery"])
1442    self.assertAllEqual([0, -1, -1, 2], imported.lookup1(keys).numpy())
1443    self.assertAllEqual([2, 0, 1, -1], imported.lookup2(keys).numpy())
1444
1445  def test_table_collections_untouched_eager(self, cycles):
1446
1447    def _gather_nonempty_collections():
1448      graph = ops.get_default_graph()
1449      gathered = {}
1450      for collection in graph.collections:
1451        collection_contents = graph.get_collection(collection)
1452        if collection_contents:
1453          gathered[collection] = collection_contents
1454      return gathered
1455
1456    root = self._make_model_with_tables()
1457    # Warm up collections to ignore those that don't expand every iteration,
1458    # e.g. the __varscope collection.
1459    cycle(root, 1)
1460    original_collections = _gather_nonempty_collections()
1461    cycle(root, cycles)
1462    self.assertEqual(original_collections, _gather_nonempty_collections())
1463
1464  def test_table_in_graph(self, cycles):
1465    root = self._make_model_with_tables()
1466
1467    if cycles > 1:
1468      root = cycle(root, cycles - 1)
1469    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
1470    save.save(root, path)
1471    imported = cycle(root, 1)
1472
1473    with ops.Graph().as_default():
1474      imported = load.load(path)
1475      keys = constant_op.constant(["brain", "test", "foo", "surgery"])
1476      output1 = imported.lookup1(keys)
1477      output2 = imported.lookup2(keys)
1478      with monitored_session.MonitoredSession() as sess:
1479        self.assertAllEqual([0, -1, -1, 2], sess.run(output1))
1480        self.assertAllEqual([2, 0, 1, -1], sess.run(output2))
1481
1482  def test_preserve_argspec(self, cycles):
1483
1484    def f(a, b, c):  # pylint: disable=unused-argument
1485      return None
1486
1487    original_fullargspec = tf_inspect.getfullargspec(f)
1488
1489    root = autotrackable.AutoTrackable()
1490    root.f = def_function.function(f)
1491    imported = cycle(root, cycles)
1492
1493    restored_fullargspec = tf_inspect.getfullargspec(imported.f)
1494    self.assertEqual(original_fullargspec, restored_fullargspec)
1495
1496  def test_canonicalize_inputs(self, cycles):
1497    @def_function.function(autograph=False)
1498    def func(a=1, b=2, c=3, training=True):
1499      if training:
1500        return [a, b, c, training]
1501      else:
1502        return [c, b, a, training]
1503
1504    # TODO(b/123501567): Work-around to trigger generic traces of a function
1505    # with extra non tensor args.
1506    signature = 3*[tensor_spec.TensorSpec(None, dtypes.float32)]
1507    @def_function.function(input_signature=signature)
1508    def trigger(a, b, c):
1509      func(a, b, c, True)
1510      func(a, b, c, False)
1511
1512    trigger.get_concrete_function()
1513
1514    root = autotrackable.AutoTrackable()
1515    root.f = func
1516    root = cycle(root, cycles)
1517    self.assertAllEqual(root.f(), [1.0, 2.0, 3.0, True])
1518    self.assertAllEqual(root.f(-1.0, training=False), [3.0, 2.0, -1.0, False])
1519
1520    with self.assertRaisesRegex(ValueError,
1521                                "Could not find matching concrete function"):
1522      root.f(["hello", 1.0])
1523
1524  def test_prefer_specific_trace(self, cycles):
1525    @def_function.function(autograph=False)
1526    def func(a):
1527      if isinstance(a, int):
1528        return a
1529      else:
1530        return a + 1
1531
1532    self.assertAllEqual(2, func(2).numpy())
1533    self.assertAllEqual(3, func(constant_op.constant(2)).numpy())
1534
1535    root = autotrackable.AutoTrackable()
1536    root.f = func
1537    root = cycle(root, cycles)
1538    self.assertAllEqual(2, root.f(2).numpy())
1539    self.assertAllEqual(4, root.f(3).numpy())
1540    self.assertAllEqual(3, root.f(constant_op.constant(2)).numpy())
1541    self.assertAllEqual(4, root.f(constant_op.constant(3)).numpy())
1542
1543  def test_partial(self, cycles):
1544    def f(x, y):
1545      return x + y
1546
1547    func = def_function.function(
1548        functools.partial(f, x=array_ops.zeros([1]), y=array_ops.ones([1])))
1549
1550    root = autotrackable.AutoTrackable()
1551    root.f = func
1552    self.assertAllEqual(root.f(), [1.0])
1553
1554    root = cycle(root, cycles)
1555    self.assertAllEqual(root.f(), [1.0])
1556
1557  def test_partial_with_non_tensor_defaults(self, cycles):
1558
1559    def f(x, y=3):
1560      return x + y
1561
1562    func = def_function.function(functools.partial(f, y=5))
1563
1564    root = autotrackable.AutoTrackable()
1565    root.f = func
1566    self.assertAllEqual(root.f(1), 6)
1567
1568    root = cycle(root, cycles)
1569    self.assertAllEqual(root.f(1), 6)
1570
1571  def test_partial_with_positional(self, cycles):
1572    def f(x, y):
1573      return x + y
1574
1575    func = def_function.function(functools.partial(f, constant_op.constant(5)))
1576
1577    root = autotrackable.AutoTrackable()
1578    root.f = func
1579    self.assertAllEqual(root.f(1), 6)
1580
1581    root = cycle(root, cycles)
1582    self.assertAllEqual(root.f(1), 6)
1583
1584  def test_partial_with_positional_captured_tensors(self, cycles):
1585
1586    def f(x, y):
1587      return x + y
1588
1589    tensor = constant_op.constant(5) + constant_op.constant(7)
1590    func = def_function.function(functools.partial(f, tensor))
1591
1592    root = autotrackable.AutoTrackable()
1593    root.f = func
1594    self.assertAllEqual(root.f(1), 13)
1595
1596    root = cycle(root, cycles)
1597    self.assertAllEqual(root.f(1), 13)
1598
1599  def test_partial_keyword_hiding_default(self, cycles):
1600
1601    def f(x=3, training=True, y=7):
1602      if training:
1603        return x + y
1604      else:
1605        return x + y + 2
1606
1607    func = def_function.function(functools.partial(f, y=6))
1608
1609    root = autotrackable.AutoTrackable()
1610    root.f = func
1611    self.assertEqual(root.f().numpy(), 9)
1612    self.assertEqual(root.f(training=False).numpy(), 11)
1613
1614    root = cycle(root, cycles)
1615    self.assertEqual(root.f().numpy(), 9)
1616    self.assertEqual(root.f(training=False).numpy(), 11)
1617
1618  def test_partial_with_kwargs(self, cycles):
1619
1620    def f(a, b, *args, **kwargs):
1621      args_sum = sum(args)
1622      return a + b + kwargs["some_tensor"] * kwargs["learning_rate"] + args_sum
1623
1624    constant_tensor = constant_op.constant(10)
1625    func = def_function.function(
1626        functools.partial(
1627            f, 7, 1, 2, learning_rate=3, some_tensor=constant_tensor))
1628
1629    root = autotrackable.AutoTrackable()
1630    root.f = func
1631    self.assertEqual(root.f(constant_op.constant(4)).numpy(), 44)
1632
1633    root = cycle(root, cycles)
1634    self.assertEqual(root.f(constant_op.constant(5)).numpy(), 45)
1635
1636  def test_partial_bind_only_first_argument(self, cycles):
1637    if sys.version_info[0] < 3:
1638      self.skipTest("Test is only valid in python3. Only then we get some more "
1639                    "advanced inspection of partials where this is allowed.")
1640
1641    def f(x, y):
1642      return x + y
1643
1644    partial_func = functools.partial(f, x=5)
1645    tf_func = def_function.function(partial_func)
1646
1647    root = autotrackable.AutoTrackable()
1648    root.f = tf_func
1649    self.assertAllEqual(root.f(y=constant_op.constant(7)), 12)
1650
1651    root = cycle(root, cycles)
1652    self.assertAllEqual(root.f(y=constant_op.constant(9)), 14)
1653
1654  def test_partial_with_passed_fn_as_default(self, cycles):
1655
1656    def f(x, y):
1657      return x(3) + y
1658
1659    def my_func(a):
1660      return 2 * a
1661
1662    func = def_function.function(functools.partial(f, my_func))
1663
1664    root = autotrackable.AutoTrackable()
1665    root.f = func
1666    self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
1667
1668    root = cycle(root, cycles)
1669    self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9)
1670
1671  def test_partial_with_input_signature(self, cycles):
1672
1673    def full_function(a, b, c=3.0):
1674      return a, b, c
1675
1676    partial = functools.partial(full_function, 1, c=4)
1677    self.assertAllEqual((1, 2.0, 4), partial(2.0))
1678
1679    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
1680    func = def_function.function(partial, input_signature=signature)
1681
1682    root = autotrackable.AutoTrackable()
1683    root.f = func
1684    a, b, c = root.f(2.0)
1685    self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 2.0, 4))
1686
1687    root = cycle(root, cycles)
1688    a, b, c = root.f(3.0)
1689    self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 3.0, 4))
1690
1691  def test_convert_to_input_signature(self, cycles):
1692
1693    @def_function.function(
1694        input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)])
1695    def func(x):
1696      return x
1697
1698    root = autotrackable.AutoTrackable()
1699    root.f = func
1700
1701    root = cycle(root, cycles)
1702
1703    self.assertEqual([2], root.f([2]).numpy())
1704
1705  def test_named_tuple(self, cycles):
1706
1707    class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])):
1708      pass
1709
1710    @def_function.function
1711    def f(x):
1712      return x.a + x.b
1713
1714    f.get_concrete_function(
1715        NamedTupleType(
1716            a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"),
1717            b=tensor_spec.TensorSpec(None, dtypes.float32, name="b")))
1718    obj = autotrackable.AutoTrackable()
1719    obj.__call__ = f
1720    if sys.version_info.major == 3 and sys.version_info.minor < 5:
1721      # TODO(allenl): figure out why this doesn't work in Python3.4
1722      self.skipTest("Not working in Python 3.4")
1723    imported = cycle(obj, cycles)
1724    self.assertAllClose(3.,
1725                        imported(NamedTupleType(a=constant_op.constant(1.),
1726                                                b=constant_op.constant(2.))))
1727
1728  def test_extra_args(self, cycles):
1729
1730    @def_function.function
1731    def f(x):
1732      return math_ops.add(x["a"], 1.)
1733    # Trigger a trace.
1734    f({"a": constant_op.constant(2.0)})
1735
1736    obj = autotrackable.AutoTrackable()
1737    obj.__call__ = f
1738    imported = cycle(obj, cycles)
1739
1740    self.assertEqual(4.0, imported({"a": 3.0}).numpy())
1741
1742    with self.assertRaisesRegex(
1743        ValueError, "Could not find matching concrete function to call"):
1744      imported({"a": 2.0, "b": 3.0})
1745
1746  def test_shapes_available(self, cycles):
1747
1748    @def_function.function(input_signature=[
1749        tensor_spec.TensorSpec([None, 3], dtypes.int32),
1750        tensor_spec.TensorSpec([None, 2], dtypes.int32)
1751    ])
1752    def func(x, y):
1753      return array_ops.concat([x, y], axis=1)
1754
1755    root = autotrackable.AutoTrackable()
1756    root.f = func
1757
1758    root = cycle(root, cycles)
1759
1760    imported_graph = root.f.get_concrete_function().graph
1761    input_x, input_y = imported_graph.inputs
1762    self.assertEqual([None, 3], input_x.shape.as_list())
1763    self.assertEqual([None, 2], input_y.shape.as_list())
1764    output, = imported_graph.outputs
1765    self.assertEqual([None, 5], output.shape.as_list())
1766    signature = root.signatures["serving_default"]
1767    self.assertEqual(
1768        [None, 3], signature.inputs[0].shape.as_list())
1769    self.assertEqual(
1770        [None, 2], signature.inputs[1].shape.as_list())
1771    self.assertEqual(
1772        [None, 5], signature.outputs[0].shape.as_list())
1773
1774  def test_variables_destroyed(self, cycles):
1775    v1 = variables.Variable(1.)
1776    weak_v1 = weakref.ref(v1)
1777    root = checkpoint.Checkpoint(v=v1)
1778    root = cycle(root, cycles)
1779    del v1
1780    self.assertIsNone(weak_v1())
1781    weak_v2 = weakref.ref(root.v)
1782    del root
1783    self.assertIsNone(weak_v2())
1784
1785  def test_variable_attributes_preserved(self, cycles):
1786    v = variables.Variable(
1787        1.,
1788        trainable=False,
1789        synchronization=variables.VariableSynchronization.NONE,
1790        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
1791    self.assertEqual(variables.VariableSynchronization.NONE,
1792                     v.synchronization)
1793    self.assertEqual(variables.VariableAggregation.ONLY_FIRST_REPLICA,
1794                     v.aggregation)
1795    root = autotrackable.AutoTrackable()
1796    root.v = v
1797    root = cycle(root, cycles)
1798    self.assertEqual(False, root.v.trainable)
1799    self.assertEqual(variables.VariableSynchronization.NONE,
1800                     root.v.synchronization)
1801    self.assertEqual(variables.VariableAggregation.ONLY_FIRST_REPLICA,
1802                     root.v.aggregation)
1803
1804  def test_captured_dataset(self, cycles):
1805
1806    class HasDataset(module.Module):
1807
1808      def __init__(self):
1809        super(HasDataset, self).__init__()
1810        self.dataset = (
1811            dataset_ops.Dataset.range(5)
1812            .map(lambda x: x ** 2))
1813
1814      @def_function.function
1815      def __call__(self, x):
1816        current_sum = array_ops.zeros([], dtype=dtypes.int64)
1817        for element in self.dataset:
1818          current_sum += x * element
1819        return current_sum
1820
1821    root = HasDataset()
1822    self.assertEqual(
1823        3 * (1 + 4 + 9 + 16),
1824        root(constant_op.constant(3, dtype=dtypes.int64)).numpy())
1825    root = cycle(root, cycles)
1826    self.assertEqual(
1827        3 * (1 + 4 + 9 + 16),
1828        root(constant_op.constant(3, dtype=dtypes.int64)).numpy())
1829
1830  def test_tuple_signature(self, cycles):
1831    root = checkpoint.Checkpoint()
1832    root.f = def_function.function(
1833        lambda: (array_ops.ones([]), array_ops.zeros([])),
1834        input_signature=())
1835    root = cycle(root, cycles, signatures=root.f)
1836    self.assertEqual(({"output_0": 1., "output_1": 0.}),
1837                     self.evaluate(root.signatures["serving_default"]()))
1838
1839  def test_version_info(self, cycles):
1840    root = checkpoint.Checkpoint()
1841    root = cycle(root, cycles)
1842    self.assertEqual(versions.__version__, root.tensorflow_version)
1843    self.assertEqual(versions.__git_version__, root.tensorflow_git_version)
1844
1845  def test_load_grad_save(self, cycles):
1846    root = checkpoint.Checkpoint()
1847    root.v = variables.Variable(2.)
1848    root.f = def_function.function(lambda x: root.v * x)
1849    root.g = def_function.function(root.f)
1850    for _ in range(cycles):
1851      with backprop.GradientTape() as tape:
1852        inp = constant_op.constant(2.)
1853        tape.watch(inp)
1854        output = root.g(inp)
1855        self.assertAllClose(4., output)
1856      self.assertAllClose(2., tape.gradient(output, inp))
1857      root = cycle(root, 1)
1858
1859  def test_destroy_resource(self, cycles):
1860
1861    def get_handle():
1862      return resource_variable_ops.var_handle_op(
1863          shape=tensor_shape.as_shape([]),
1864          dtype=dtypes.float32,
1865          shared_name="my_var_name",
1866          name="my_var",
1867          container="my_container")
1868
1869    class MyResource(resource.TrackableResource):
1870
1871      def _create_resource(self):
1872        return get_handle()
1873
1874      def _initialize(self):
1875        resource_variable_ops.assign_variable_op(
1876            self.resource_handle, 1.0, name="assign")
1877
1878      def _destroy_resource(self):
1879        handle = get_handle()
1880        resource_variable_ops.destroy_resource_op(
1881            handle, ignore_lookup_error=True)
1882
1883    class MyModel(autotrackable.AutoTrackable):
1884
1885      def __init__(self):
1886        super(MyModel, self).__init__()
1887        self.resource = MyResource()
1888
1889      @def_function.function(input_signature=[])
1890      def increase(self):
1891        handle = self.resource.resource_handle
1892        resource_variable_ops.assign_add_variable_op(
1893            handle, 10.0, name="assign_add")
1894        return resource_variable_ops.read_variable_op(handle, dtypes.float32)
1895
1896    root = MyModel()
1897    imported = cycle(root, cycles)
1898    self.assertEqual(11, imported.increase().numpy())  # Create the resource.
1899
1900    handle = imported.resource.resource_handle
1901
1902    # Delete the imported SaveModel. Since we explicitly set the deleter, it
1903    # should destroy the resource automatically.
1904    del imported
1905
1906    # Try to destroy the resource again, should fail.
1907    with self.assertRaisesRegex(errors.NotFoundError,
1908                                r"Resource .* does not exist."):
1909      resource_variable_ops.destroy_resource_op(
1910          handle, ignore_lookup_error=False)
1911
1912  def test_function_called_as_operation(self, cycles):
1913
1914    @framework_function.Defun(dtypes.float32)
1915    def inner(x):
1916      return x + 1.
1917
1918    @def_function.function(
1919        input_signature=[tensor_spec.TensorSpec([], dtypes.float32)])
1920    def outer(x):
1921      return inner(x)
1922
1923    root = module.Module()
1924    root.f = outer
1925    imported = cycle(root, cycles)
1926    self.assertAllClose(2., imported.f(constant_op.constant(1.)))
1927
1928  def test_ragged(self, cycles):
1929
1930    @def_function.function
1931    def f(x, c=1):
1932      """Returns Tensor x incremented by Python constant c."""
1933      return math_ops.add(x, c)
1934
1935    for c in (1, 2, 3):
1936      _ = f.get_concrete_function(
1937          ragged_tensor.RaggedTensorSpec([None, None], dtype=dtypes.int32),
1938          c)
1939
1940    obj = autotrackable.AutoTrackable()
1941    obj.f = f
1942
1943    imported1 = cycle(obj, cycles, signatures={})
1944    rt = ragged_factory_ops.constant([[1, 2], [3]])
1945    self.assertAllEqual(imported1.f(rt), [[2, 3], [4]])
1946    self.assertAllEqual(imported1.f(rt, 2), [[3, 4], [5]])
1947    self.assertAllEqual(imported1.f(rt, 3), [[4, 5], [6]])
1948
1949    imported2 = cycle(obj, cycles)
1950    rt = ragged_factory_ops.constant([[1, 2], [3]])
1951    self.assertAllEqual(imported2.f(rt, 1), [[2, 3], [4]])
1952    self.assertAllEqual(imported2.f(rt, 2), [[3, 4], [5]])
1953    self.assertAllEqual(imported2.f(rt, 3), [[4, 5], [6]])
1954
1955  def test_accepts_io_device(self, cycles):
1956    options = load_options.LoadOptions()
1957    self.assertIsNone(options.experimental_io_device)
1958    options = load_options.LoadOptions(experimental_io_device="/job:localhost")
1959    self.assertEqual("/job:localhost", options.experimental_io_device)
1960
1961  def _custom_saveable_object(self, cycles):
1962    if context.is_tfrt_enabled():
1963      self.skipTest("Disable due to b/190539415.")
1964    root = autotrackable.AutoTrackable()
1965    root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
1966    root.table.insert("foo", 15)
1967    root.table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
1968    root.table2.insert("idk", 21)
1969
1970    @def_function.function(
1971        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
1972    def lookup(key):
1973      return root.table.lookup(key)
1974
1975    root.lookup = lookup
1976
1977    imported = cycle(root, cycles)
1978    self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
1979    self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
1980
1981    if not saveable_compat.force_checkpoint_conversion_enabled():
1982      self.assertEqual({"table"},
1983                       imported.table._self_saveable_object_factories.keys())
1984
1985  def test_load_custom_saveable_object(self, cycles):
1986    self._custom_saveable_object(cycles)
1987
1988  def test_load_custom_saveable_object_ckpt_conversion(self, cycles):
1989    # Tests custom saveable object with checkpoint conversion enabled (forces
1990    # Trackable-based checkpoint implementation).
1991    saveable_compat.force_checkpoint_conversion()
1992    self._custom_saveable_object(cycles)
1993
1994  def test_load_resource_with_dependency(self, cycles):
1995    # Test with StaticHashTable, which has a _initializer attribute that tracks
1996    # the Asset vocab table.
1997
1998    class MyLookupModel(autotrackable.AutoTrackable):
1999
2000      def __init__(self, vocab_file):
2001
2002        vocab_initializer = lookup_ops.TextFileInitializer(
2003            vocab_file,
2004            key_dtype=dtypes.string,
2005            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
2006            value_dtype=dtypes.int64,
2007            value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
2008        self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer,
2009                                                       default_value=-1)
2010
2011      @def_function.function(input_signature=[
2012          tensor_spec.TensorSpec((None,), dtypes.string)])
2013      def __call__(self, inputs):
2014        return self._vocab_table.lookup(inputs)
2015
2016    vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"]))
2017    root = MyLookupModel(vocab_file)
2018    imported = cycle(root, cycles)
2019    file_io.delete_file(vocab_file)
2020    self.assertAllEqual(imported(constant_op.constant(["d", "b"])),
2021                        [3, 1])
2022
2023  def test_custom_gradients(self, cycles):
2024
2025    @custom_gradient.custom_gradient
2026    def log1pexp(x):
2027      e = math_ops.exp(x)
2028
2029      def grad(dy):
2030        return dy * e  # incorrect to check the custom gradients is respected.
2031
2032      return math_ops.log(1 + e), grad
2033
2034    @def_function.function
2035    def g(x):
2036      y = log1pexp(x)
2037
2038      @def_function.function
2039      def g_nest():
2040        return log1pexp(y)
2041
2042      return g_nest()
2043
2044    @def_function.function
2045    def f(x):
2046      return log1pexp(g(x * x))
2047
2048    v = variables.Variable(1.)
2049
2050    with backprop.GradientTape() as tape2:
2051      with backprop.GradientTape() as tape:
2052        tape.watch(v)
2053        y = f(v)
2054        expected_grads = tape.gradient(y, v)
2055      expected_grad_grads = tape2.gradient(expected_grads, v)
2056
2057    root = autotrackable.AutoTrackable()
2058    root.f = f
2059    loaded = cycle(
2060        root, cycles, options=save_options.SaveOptions(
2061            experimental_custom_gradients=True))
2062    with backprop.GradientTape() as tape2:
2063      with backprop.GradientTape() as tape:
2064        tape.watch(v)
2065        y = loaded.f(v)
2066        grads = tape.gradient(y, v)
2067      grad_grads = tape2.gradient(grads, v)
2068
2069    self.assertAllClose(grads, expected_grads)
2070    self.assertAllClose(grad_grads, expected_grad_grads)
2071
2072  def test_custom_gradients_with_none_grad(self, cycles):
2073    # https://github.com/google/jax/issues/7123
2074
2075    @custom_gradient.custom_gradient
2076    def f(params, state):
2077      def grad_fn(*args):
2078        return args
2079      return (params, state), grad_fn
2080    @def_function.function(input_signature=[
2081        tensor_spec.TensorSpec([], dtypes.float32),
2082        tensor_spec.TensorSpec([], dtypes.int32)])
2083    def predict(params, state):
2084      return f(params, state)
2085
2086    params = variables.Variable(1.0)
2087    # None grads only appear when state is an int.
2088    state = constant_op.constant(3, dtype=dtypes.int32)
2089    with backprop.GradientTape() as tape:
2090      tape.watch(params)
2091      y = predict(params, state)
2092      expected_grads = tape.gradient(y, params)
2093
2094    root = autotrackable.AutoTrackable()
2095    root.fn = predict
2096    loaded = cycle(
2097        root, cycles, options=save_options.SaveOptions(
2098            experimental_custom_gradients=True))
2099
2100    with backprop.GradientTape() as tape:
2101      tape.watch(params)
2102      y = loaded.fn(params, state)
2103      grads = tape.gradient(y, params)
2104
2105    self.assertAllClose(grads, expected_grads)
2106
2107
2108class SingleCycleTests(test.TestCase, parameterized.TestCase):
2109
2110  def test_load_with_tags(self):
2111    root = autotrackable.AutoTrackable()
2112    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
2113    save.save(root, path)
2114    with self.assertRaises(ValueError):
2115      load.load(path, tags=[tag_constants.EVAL])
2116    load.load(path, tags=[tag_constants.SERVING])
2117    load.load(path, tags=tag_constants.SERVING)
2118    load.load(path, tags=set([tag_constants.SERVING]))
2119
2120  def test_save_load_contains_with_fspath(self):
2121    root = autotrackable.AutoTrackable()
2122    path = pathlib.Path(tempfile.mkdtemp(prefix=self.get_temp_dir()))
2123    save.save(root, path)
2124    self.assertTrue(loader_impl.contains_saved_model(path))
2125    load.load(path)
2126
2127  def test_single_restore_op_used(self):
2128    root = module.Module()
2129    root.v1 = variables.Variable(1.)
2130    root.v2 = variables.Variable(2.)
2131    root.v3 = variables.Variable(3.)
2132    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
2133    save.save(root, path)
2134    restore_count = 0
2135
2136    def _count_restores(op_type, *unused_args, **unused_kwargs):
2137      nonlocal restore_count
2138      if op_type == b"RestoreV2":
2139        restore_count += 1
2140
2141    op_callbacks.add_op_callback(_count_restores)
2142    load.load(path)
2143    op_callbacks.remove_op_callback(_count_restores)
2144    self.assertEqual(1, restore_count)
2145
2146  def test_docstring_examples(self):
2147    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
2148    exported = checkpoint.Checkpoint(v=variables.Variable(3.))
2149    exported.f = def_function.function(
2150        lambda x: exported.v * x,
2151        input_signature=[
2152            tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)])
2153    save.save(exported, path)
2154    imported = load.load(path)
2155    self.assertEqual(3., imported.v.numpy())
2156    self.assertEqual(6., imported.f(x=constant_op.constant(2.)).numpy())
2157
2158    save.save(exported, path, exported.f.get_concrete_function())
2159    imported = load.load(path)
2160    f = imported.signatures["serving_default"]
2161    self.assertAllEqual(
2162        [[-3.]],
2163        f(x=constant_op.constant([[-1.]]))["output_0"].numpy())
2164
2165  def test_object_with_extra_dependencies(self):
2166
2167    class Extra(autotrackable.AutoTrackable):
2168
2169      def _trackable_children(self, save_type, **kwargs):
2170        children = super(Extra, self)._trackable_children(save_type, **kwargs)
2171        children["a"] = variables.Variable(5.)
2172        return children
2173
2174    root = Extra()
2175    path = tempfile.mkdtemp(prefix=self.get_temp_dir())
2176    save.save(root, path)
2177    imported = load.load(path)
2178    self.assertEqual(5, self.evaluate(imported.a))
2179
2180  def test_save_cached_variable(self):
2181    with ops.Graph().as_default(), session_lib.Session() as session:
2182      obj = autotrackable.AutoTrackable()
2183      obj.v = variables.Variable(2., caching_device=lambda op: op.device)
2184      obj.w = variables.Variable(3.)
2185      session.run([obj.v.initializer, obj.w.initializer])
2186
2187      @def_function.function
2188      def total():
2189        return obj.v + obj.w
2190
2191      @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
2192      def wrapped_total(x):
2193        return total() + x
2194
2195      @def_function.function
2196      def increment_v(x):
2197        obj.v.assign_add(x)
2198
2199      session.run(increment_v(constant_op.constant(3.)))  # generate signatures
2200      self.assertAllClose(8, total())
2201      self.assertAllClose(13, wrapped_total(constant_op.constant(5.)))
2202
2203      obj.total = total
2204      obj.wrapped_total = wrapped_total.get_concrete_function()
2205      obj.increment_v = increment_v
2206
2207      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
2208      save.save(obj, save_dir, signatures=total.get_concrete_function())
2209      imported = load.load(save_dir)
2210      session.run(variables.global_variables_initializer())
2211      self.assertAllClose(8, imported.total())
2212      session.run(imported.increment_v(4))
2213      self.assertAllClose(12, imported.total())
2214      self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.)))
2215      self.assertAllClose({"output_0": 12},
2216                          imported.signatures["serving_default"]())
2217
2218    # Try loading and running the function in eager mode
2219    imported = load.load(save_dir)
2220    self.assertAllClose(8, imported.total())
2221    imported.increment_v(5)
2222    self.assertAllClose(13, imported.total())
2223    self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(.5)))
2224    self.assertAllClose({"output_0": 13},
2225                        imported.signatures["serving_default"]())
2226
2227  # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3
2228  # iterations took hundreds of seconds). It would be really nice to check
2229  # allocations at a lower level.
2230  @test_util.assert_no_new_pyobjects_executing_eagerly
2231  def test_functions_cleaned(self):
2232    if sys.version_info.major < 3:
2233      self.skipTest("Not working in Python 2")
2234    root = module.Module()
2235    root.v = variables.Variable(1.)
2236    root.f = def_function.function(
2237        lambda x: x + root.v,
2238        input_signature=[
2239            tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
2240    cycle(root, 1)
2241
2242  def test_load_partial_object(self):
2243    root = module.Module()
2244    root.variables_holder = module.Module()
2245    root.variables_holder.v = variables.Variable(1.)
2246
2247    class Adder(module.Module):
2248
2249      @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=[])])
2250      def __call__(self, y):
2251        root.variables_holder.v.assign_add(y)
2252        return 1
2253
2254    root.adder = Adder()
2255
2256    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
2257    save.save(root, save_dir)
2258
2259    imported = load.load_partial(save_dir,
2260                                 ["root.variables_holder.v", "root.adder"])
2261    v = imported["root.variables_holder.v"]
2262    adder = imported["root.adder"]
2263    self.assertEqual(self.evaluate(v), 1)
2264    adder(5)
2265    self.assertEqual(self.evaluate(v), 6)
2266
2267    with self.assertRaisesRegex(
2268        ValueError, "does not include all required objects for loading"):
2269      imported = load.load_partial(save_dir, ["root.adder"])
2270
2271  def test_load_partial_checkpoint(self):
2272    root = module.Module()
2273    root.variables_holder = module.Module()
2274    root.variables_holder.v = variables.Variable(1.)
2275
2276    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
2277    save.save(root, save_dir)
2278
2279    loaded = module.Module()
2280    loaded.v = variables.Variable(2.)
2281
2282    load.load_partial(
2283        save_dir, {"root": loaded},
2284        options=load_options.LoadOptions(allow_partial_checkpoint=True))
2285    self.assertEqual(loaded.variables_holder.v.numpy(), 1)
2286    with self.assertRaisesRegex(AssertionError, "were not bound"):
2287      load.load_partial(save_dir, {"root": loaded})
2288
2289  def test_call_untraced_function_raises_error(self):
2290
2291    class ObjWithFunction(module.Module):
2292
2293      @def_function.function
2294      def foo(self, a):
2295        return a
2296
2297    root = ObjWithFunction()
2298    with self.assertLogs(level="WARNING") as logs:
2299      loaded = cycle(root, 1)
2300
2301    expected_save_message = (
2302        "WARNING:absl:Found untraced functions such as foo while saving "
2303        "(showing 1 of 1). These functions will not be directly callable after "
2304        "loading.")
2305    self.assertIn(expected_save_message, logs.output)
2306
2307    with self.assertRaisesRegex(
2308        ValueError, "Found zero restored functions for caller function."):
2309      loaded.foo(1)
2310
2311  def test_restored_function_execute_eagerly(self):
2312    try:
2313      def_function.run_functions_eagerly(True)
2314
2315      class MyModel(module.Module):
2316
2317        @def_function.function
2318        def __call__(self, inputs, training=False):
2319          return math_ops.multiply(0.5, inputs)
2320
2321      model = MyModel()
2322      model.__call__.get_concrete_function(
2323          tensor_spec.TensorSpec([None], dtypes.float32))
2324      loaded = cycle(model, 1)
2325
2326      # Calling the function should not throw an exception.
2327      loaded(constant_op.constant([1.0]))
2328
2329    finally:
2330      def_function.run_functions_eagerly(False)
2331
2332  def test_restored_model_concrete_function_is_deterministic(self):
2333    previous_concrete_function = None
2334    for _ in range(100):
2335
2336      class MyModel(module.Module):
2337
2338        @def_function.function
2339        def __call__(self, x):
2340          return x * constant_op.constant(3.0)
2341
2342      model = MyModel()
2343      model(array_ops.ones((7, 3), dtype=dtypes.float32))
2344      model.__call__.get_concrete_function(
2345          tensor_spec.TensorSpec([None, 3], dtypes.float32))
2346      loaded = cycle(model, 1)
2347
2348      # Ensure the newly loaded concrete function is the same as the previous
2349      # after a cycle of serialization / deserialization.
2350      new_concrete_function = loaded.__call__.get_concrete_function(
2351          tensor_spec.TensorSpec([None, 3], dtypes.float32))
2352      if previous_concrete_function is not None:
2353        self.assertEqual(previous_concrete_function.pretty_printed_signature(),
2354                         new_concrete_function.pretty_printed_signature())
2355
2356      previous_concrete_function = new_concrete_function
2357
2358  def test_garbage_collection_capturable_resource_doesnt_raise_exception(self):
2359    model = module.Module()
2360    model.mapping = lookup_ops.StaticHashTable(
2361        lookup_ops.KeyValueTensorInitializer(
2362            keys=math_ops.range(1, dtype=dtypes.int32),
2363            values=["foo"]),
2364        "default_value")
2365    loaded = cycle(model, 1)
2366    del model
2367    del loaded
2368    # Exceptions raised during garbage collection are simply printed to stderr
2369    # and ignored, and we have no way to access them. We'll capture stdout
2370    # during the garbage collection process and inspect to see if any
2371    # exceptions were raised.
2372    stderr = io.StringIO()
2373    with contextlib.redirect_stderr(stderr):
2374      gc.collect()
2375    if "Exception ignored in" in stderr.getvalue():
2376      raise Exception(stderr.getvalue())
2377
2378  def test_captured_dataset_with_asset(self):
2379
2380    class HasDataset(module.Module):
2381
2382      def __init__(self, temp_dir, file_name):
2383        super(HasDataset, self).__init__()
2384        file = os.path.join(temp_dir, file_name)
2385        with tf_record.TFRecordWriter(file, "GZIP") as f:
2386          for v in ["a", "aa", "aaa"]:
2387            f.write(str(v))
2388        self.dataset = readers.TFRecordDataset([file], compression_type="GZIP")
2389
2390      @def_function.function
2391      def __call__(self, x):
2392        current_sum = array_ops.zeros([], dtype=dtypes.int32)
2393        for element in self.dataset:
2394          current_sum += x * string_ops.string_length(element)
2395        return current_sum
2396
2397    temp_dir = self.get_temp_dir()
2398    file_name = "tf_record_asset.tfrecord.gz"
2399    root = HasDataset(temp_dir, file_name)
2400    self.assertEqual(
2401        18,  # 3 * (1 + 2 + 3)
2402        root(constant_op.constant(3, dtype=dtypes.int32)).numpy())
2403
2404    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
2405    save.save(root, save_dir)
2406
2407    file_io.delete_file(os.path.join(temp_dir, file_name))
2408    asset_path = os.path.join(save_dir, "assets/{}".format(file_name))
2409    self.assertTrue(file_io.file_exists(asset_path))
2410    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
2411    file_io.rename(save_dir, load_dir)
2412
2413    loaded = load.load(load_dir)
2414    self.assertEqual(
2415        18,  # 3 * (1 + 2 + 3)
2416        loaded(constant_op.constant(3, dtype=dtypes.int32)).numpy())
2417
2418
class DeferredInitModuleVariablesTest(test.TestCase):
  """Tests for SavedModels saved and loaded with the checkpoint skipped."""

  def test_deferred_init_module_variables(self):
    """Defer initialization of variables in a module to the load stage."""

    class MyModule(module.Module):
      """Module holding three variables, each created in a different way."""

      def __init__(self, size):
        super().__init__()
        self.size = size
        # variable initialized by a Tensor-compatible value
        self.w1 = variables.Variable(
            constant_op.constant(1., shape=[self.size]), trainable=False)
        # variable initialized by a function
        self.w2 = variables.Variable(
            lambda: constant_op.constant(2., shape=[self.size]))
        # variable instantiated lazily in call()
        self.w3 = None

      def call(self):
        """Adds one to every variable and returns (w1, w2, w3)."""
        if self.w3 is None:
          self.w3 = variables.Variable(
              constant_op.constant(3., shape=[self.size]))
        for w in (self.w1, self.w2, self.w3):
          w.assign_add(constant_op.constant(1., shape=[self.size]))
        return self.w1, self.w2, self.w3

    def export_initializer(initial_value, export_dir):
      """Saves a module whose `call()` yields `initial_value` to `export_dir`.

      Args:
        initial_value: Either a value or a zero-argument callable producing it.
        export_dir: Directory the initializer SavedModel is written to.
      """

      class Initializer(module.Module):
        """Wraps `initial_value` in a function taking no arguments."""

        @def_function.function(input_signature=[])
        def call(self):
          if callable(initial_value):
            return initial_value()
          return initial_value

      save.save(Initializer(), export_dir)

    def create_and_save_module(weight_size):
      """Saves `MyModule` without a checkpoint plus one initializer per variable.

      Args:
        weight_size: Value passed to `MyModule` as `size`.

      Returns:
        The export directory. The module is saved under `<dir>/module` and each
        initializer under `<dir>/variables/<variable name>`.
      """

      initial_values = {}  # For storing initial_value of created variables

      def variable_creator(next_creator, **kwargs):
        """Creates the variable as usual and records its `initial_value`."""
        variable = next_creator(**kwargs)
        # Drop everything from the first ":" on (e.g. an output-index suffix)
        # so the name can be used as a directory name.
        variable_name = variable.name.partition(":")[0]
        initial_values[variable_name] = kwargs["initial_value"]
        return variable

      export_dir = self.create_tempdir().full_path

      with ops.Graph().as_default():
        with variable_scope.variable_creator_scope(variable_creator):
          exported = MyModule(weight_size)
          exported.call = def_function.function(input_signature=[])(
              exported.call)

          module_dir = f"{export_dir}/module"
          file_io.recursive_create_dir(module_dir)
          save.save_and_return_nodes(
              exported, module_dir, experimental_skip_checkpoint=True)

      # Save the initializer of the created variables.
      for variable_name, initial_value in initial_values.items():
        export_initializer(initial_value,
                           f"{export_dir}/variables/{variable_name}")

      return export_dir

    def load_and_run_module(export_dir, weight_size):
      """Loads the module with custom variable creation and checks its outputs.

      Args:
        export_dir: Directory returned by `create_and_save_module`.
        weight_size: Size used to build the expected output arrays.
      """

      # pylint: disable=unused-argument
      def layer_variable_creator(next_creator, **kwargs):
        """Builds a ResourceVariable initialized from its saved initializer."""
        variable_dir = f"{export_dir}/variables/{kwargs['name']}"
        initializer = load.load(variable_dir)
        # Replace the initial value with the loaded initializer function.
        kwargs["initial_value"] = initializer.call
        variable = resource_variable_ops.ResourceVariable(**kwargs)
        return variable

      with ops.Graph().as_default():
        with variable_scope.variable_creator_scope(layer_variable_creator):
          imported = load.load(
              f"{export_dir}/module",
              options=load_options.LoadOptions(
                  experimental_skip_checkpoint=True))
        outputs = imported.call()

        with self.cached_session() as sess:
          variables.global_variables_initializer().run()
          # Check if variables work as expected across multiple iterations.
          for i in range(3):
            np_outputs = sess.run(outputs)
            # Variable j starts at j + 1 and has been incremented i + 1 times
            # after this run, hence the expected value i + j + 2.
            for j, np_output in enumerate(np_outputs):
              self.assertAllClose(np_output, np.full(weight_size, i + j + 2))

    # The size of the serialized content (both module and variables) stays
    # small even with a large weight_size as the initial values are not stored
    # in checkpoints.
    weight_size = 1024
    export_dir = create_and_save_module(weight_size)
    load_and_run_module(export_dir, weight_size)

  def _make_asset(self, contents):
    """Writes `contents` to a new file in the test temp dir; returns its path.

    Args:
      contents: Text to write to the file.

    Returns:
      Path of the newly created file.
    """
    # Pass the directory via `dir=`. Passing it as `prefix=` creates the file
    # next to the temp dir (named "<temp dir><random>") instead of inside it.
    fd, filename = tempfile.mkstemp(dir=self.get_temp_dir())
    with os.fdopen(fd, "w") as f:
      f.write(contents)
    return filename

  def test_assets(self):
    """Asset-backed lookup works after the original file is gone and dir moved."""

    class MyLookupModel(autotrackable.AutoTrackable):
      """Maps each whole line of a vocabulary file to its line number."""

      def __init__(self, vocab_file):

        vocab_initializer = lookup_ops.TextFileInitializer(
            vocab_file,
            key_dtype=dtypes.string,
            key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
            value_dtype=dtypes.int64,
            value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
        self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer,
                                                       default_value=-1)

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec((None,), dtypes.string)])
      def __call__(self, inputs):
        return self._vocab_table.lookup(inputs)

    vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"]))
    root = MyLookupModel(vocab_file)

    save_dir = os.path.join(self.get_temp_dir(), "save_dir")
    save.save_and_return_nodes(
        root, save_dir, experimental_skip_checkpoint=True)
    # Delete the source vocabulary file and move the SavedModel directory so
    # the lookup below cannot depend on either original path.
    file_io.delete_file(vocab_file)
    load_dir = os.path.join(self.get_temp_dir(), "load_dir")
    file_io.rename(save_dir, load_dir)

    imported = load.load(
        load_dir,
        options=load_options.LoadOptions(experimental_skip_checkpoint=True))
    # "d" and "b" are on lines 3 and 1 of the vocabulary file.
    self.assertAllEqual(imported(constant_op.constant(["d", "b"])),
                        [3, 1])
2564
2565
class _TestModel(module.Module):
  """Module that builds a `[rows, cols]` table of ones on the CPU when called."""

  def __init__(self, rows, cols):
    """Stores the table dimensions; the table itself is created in `__call__`.

    Args:
      rows: First dimension of the table.
      cols: Second dimension of the table.
    """
    super().__init__()
    self.rows, self.cols = rows, cols
    self.table = None

  def __call__(self, x):
    """Creates the table and returns `matmul(table, x)` summed over axis 0.

    A new variable is assigned to `self.table` on every call.

    Args:
      x: Right-hand operand of the matrix product with the table.

    Returns:
      The matrix product reduced with a sum along axis 0.
    """
    with ops.device("/cpu:0"):
      table_shape = [self.rows, self.cols]
      ones = constant_op.constant(1., shape=table_shape)
      self.table = variables.Variable(ones)
      product = math_ops.matmul(self.table, x)
      summed = math_ops.reduce_sum(product, axis=0)
    return summed
2581
2582
class SavedModelLoadMemoryTests(test.TestCase):
  """Checks where variables are placed when loading with a variable policy."""

  # NOTE: "tenor" in the method name is a typo for "tensor"; the name is kept
  # so existing test IDs and filters keep working.
  @test_util.run_gpu_only
  def test_no_oom_loading_large_tenor(self):
    """A saved CPU variable stays on CPU only when devices are restored.

    Saves a `_TestModel` (whose table is created under "/cpu:0") with
    `SAVE_VARIABLE_DEVICES`, then loads it twice: once with the same policy,
    expecting the table on a CPU device, and once with default options,
    expecting it on a GPU device.
    """
    if not config.get_soft_device_placement():
      self.skipTest("This test only works for soft device placement is on")
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    ncols = 16
    nrows = 32
    model = _TestModel(rows=nrows, cols=ncols)
    x = array_ops.zeros(shape=(ncols, 2), dtype=dtypes.float32)
    # Call the model once so `model.table` exists before saving; the result
    # itself is not needed.
    model(x)
    save.save(
        model,
        save_dir,
        options=save_options.SaveOptions(
            experimental_variable_policy=save_options.VariablePolicy
            .SAVE_VARIABLE_DEVICES),
    )
    loaded_on_cpu = load.load(
        export_dir=save_dir,
        options=load_options.LoadOptions(
            experimental_variable_policy=save_options.VariablePolicy
            .SAVE_VARIABLE_DEVICES),
    )
    loaded_on_gpu = load.load(export_dir=save_dir)
    # assertIn reports both operands on failure, unlike assertTrue(a in b).
    self.assertIn("CPU", loaded_on_cpu.table.device)
    self.assertIn("GPU", loaded_on_gpu.table.device)
2611
2612
# Run the tests via `test.main()` only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
  test.main()
2615