1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests saving with registered Trackable classes and checkpoint functions."""
16
17import os
18import tempfile
19
20from absl.testing import parameterized
21
22from google.protobuf import wrappers_pb2
23from tensorflow.python.checkpoint import checkpoint as util
24from tensorflow.python.client import session
25from tensorflow.python.eager import context
26from tensorflow.python.eager import def_function
27from tensorflow.python.eager import test
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import errors_impl
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import io_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import gfile
37from tensorflow.python.saved_model import load
38from tensorflow.python.saved_model import loader
39from tensorflow.python.saved_model import registration
40from tensorflow.python.saved_model import save
41from tensorflow.python.trackable import autotrackable
42
43
@registration.register_serializable()
class Part(resource_variable_ops.ResourceVariable):
  """Resource variable registered as a serializable SavedModel class.

  In this file, instances are used as the elements of a `Stack`.
  """

  def __init__(self, value):
    """Initializes the variable from `value` through `_init_from_args`."""
    self._init_from_args(value)

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):
    """Creates a placeholder instance holding `[0, 0]`.

    Args:
      **kwargs: Ignored.

    Returns:
      A new instance of `cls` built from `[0, 0]`.
    """
    del kwargs  # Unused.
    placeholder_value = [0, 0]
    return cls(placeholder_value)

  def _export_to_saved_model_graph(self, object_map, tensor_map, **kwargs):
    """Maps this variable to a zero-filled copy of the same shape and dtype.

    Args:
      object_map: Mapping updated in place so that `self` points to the copy.
      tensor_map: Mapping updated in place so that this variable's handle
        points to the copy's handle.
      **kwargs: Ignored.

    Returns:
      A single-element list containing this variable's handle.
    """
    del kwargs  # Unused.
    zero_filled = array_ops.zeros(self.shape, self.dtype)
    export_copy = Part(zero_filled)
    object_map[self] = export_copy
    tensor_map[self.handle] = export_copy.handle
    return [self.handle]
59
60
@registration.register_serializable()
class Stack(autotrackable.AutoTrackable):
  """Trackable container whose value is the stack of its parts."""

  def __init__(self, parts=None):
    """Stores `parts`, the sequence later handed to `array_ops.stack`."""
    self.parts = parts

  @def_function.function(input_signature=[])
  def value(self):
    """Returns `self.parts` stacked into a single tensor."""
    stacked_parts = array_ops.stack(self.parts)
    return stacked_parts
70
71
def get_tensor_slices(trackables):
  """Collects the checkpoint entries for every non-`Part` trackable.

  Args:
    trackables: Mapping from object prefix to trackable object. `Part`
      instances are skipped; every other object must provide `value()`.

  Returns:
    A 4-tuple `(tensor_names, shapes_and_slices, tensors, restored_trackables)`
    of equally long lists. Names are the object prefix with "/value" appended,
    every shape-and-slice spec is the empty string, and each tensor is a copy
    of the object's value created under the CPU device scope.
  """
  # Only stacks are saved; parts are reached through their owning stack.
  stack_items = [(prefix, candidate)
                 for prefix, candidate in trackables.items()
                 if not isinstance(candidate, Part)]

  names = [prefix + "/value" for prefix, _ in stack_items]
  slice_specs = [""] * len(stack_items)
  stacks = [stack for _, stack in stack_items]

  cpu_copies = []
  for stack in stacks:
    stacked = stack.value()
    with ops.device("/device:CPU:0"):
      cpu_copies.append(array_ops.identity(stacked))

  return names, slice_specs, cpu_copies, stacks
88
89
def save_stacks_and_parts(trackables, file_prefix):
  """Writes the stack values of `trackables` to a checkpoint shard.

  Args:
    trackables: Mapping from object prefix to trackable object.
    file_prefix: Prefix of the checkpoint shard to write.

  Returns:
    `file_prefix`, unchanged.
  """
  collected = get_tensor_slices(trackables)
  # The fourth element (the trackables themselves) is only needed on restore.
  names, slice_specs, values = collected[:3]
  io_ops.save_v2(file_prefix, names, slice_specs, values)
  return file_prefix
95
96
def restore_stacks_and_parts(trackables, merged_prefix):
  """Reads stack values from a checkpoint and assigns them to the parts.

  Args:
    trackables: Mapping from object prefix to trackable object.
    merged_prefix: Prefix of the checkpoint to read from.
  """
  names, slice_specs, current_values, stacks = get_tensor_slices(trackables)
  # The current values are only used to learn the dtype of each entry.
  value_dtypes = [value.dtype for value in current_values]
  checkpoint_values = io_ops.restore_v2(merged_prefix, names, slice_specs,
                                        value_dtypes)
  for stack, checkpoint_value in zip(stacks, checkpoint_values):
    target_shape = stack.value().get_shape()
    reshaped = array_ops.reshape(checkpoint_value, target_shape)
    pieces = array_ops.unstack(reshaped)
    for part, piece in zip(stack.parts, pieces):
      part.assign(piece)
109
110
def _is_stack_or_part(obj):
  """Returns whether `obj` is handled by the "stacks" checkpoint saver."""
  return isinstance(obj, (Stack, Part))


# Route `Stack` and `Part` objects through the custom save/restore functions.
registration.register_checkpoint_saver(
    name="stacks",
    predicate=_is_stack_or_part,
    save_fn=save_stacks_and_parts,
    restore_fn=restore_stacks_and_parts)
116
117
def cycle(obj, cycles, signatures=None, options=None):
  """Saves `obj` as a SavedModel and loads it back, `cycles` times in a row.

  Args:
    obj: Object to save in the first round.
    cycles: Number of save/load rounds; must be at least 1.
    signatures: Signatures passed to `save.save` in the first round. Later
      rounds use the signatures of the previously loaded object.
    options: Options passed to every `save.save` call.

  Returns:
    The object loaded in the final round.

  Raises:
    ValueError: If `cycles` is less than 1.
  """
  if cycles < 1:
    # With no round executed there is nothing to return; fail with a clear
    # message instead of an UnboundLocalError at the return statement.
    raise ValueError(f"`cycles` must be at least 1, got {cycles}.")

  to_save = obj
  for _ in range(cycles):
    # `dir=` creates the directory inside the test temp dir. Passing the path
    # as `prefix=` would create a sibling directory whose name merely starts
    # with the temp dir path.
    # NOTE(review): assumes `test.get_temp_dir()` returns an existing
    # directory.
    path = tempfile.mkdtemp(dir=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures, options=options)
      loaded = load.load(path)
      signatures = loaded.signatures
    to_save = loaded
  return loaded
131
132
@parameterized.named_parameters(
    {"testcase_name": "ReloadOnce", "cycles": 1},
    {"testcase_name": "ReloadTwice", "cycles": 2},
    {"testcase_name": "ReloadThrice", "cycles": 3})
class SavedModelTest(test.TestCase, parameterized.TestCase):
  """SavedModel round-trip tests, each run with one, two and three cycles."""

  def test_registered_serializable(self, cycles):
    """A registered class is restored with the name stored in its proto."""

    @registration.register_serializable(name=f"SaveAndLoad{cycles}")
    class Module(autotrackable.AutoTrackable):
      """Trackable that serializes its name as a `StringValue` proto."""

      def __init__(self, name="module"):
        self.v = variables.Variable(1.)
        self.name = name

      def _serialize_to_proto(self, **unused_kwargs):
        name_proto = wrappers_pb2.StringValue(value=self.name)
        return name_proto

      @classmethod
      def _deserialize_from_proto(cls, proto, **unused_kwargs):
        # Reject anything that is not the proto type written above.
        if not proto.Is(wrappers_pb2.StringValue.DESCRIPTOR):
          raise AssertionError(
              "Did not receive proto of correct type during deserialization. "
              f"Expected type {wrappers_pb2.StringValue.DESCRIPTOR.full_name}, "
              f"got {proto.TypeName()}")
        name_proto = wrappers_pb2.StringValue()
        proto.Unpack(name_proto)
        return cls(name=name_proto.value)

    original = Module("a")
    original.v.assign(5)
    reloaded = cycle(original, cycles)
    self.assertIsInstance(reloaded, Module)
    self.assertEqual(5, reloaded.v.numpy())
    self.assertEqual("a", reloaded.name)

  def test_none_proto(self, cycles):
    """Deserialization receives an empty proto when none was serialized."""

    @registration.register_serializable(name=f"NoneProto{cycles}")
    class Module(autotrackable.AutoTrackable):
      """Trackable that keeps the default `_serialize_to_proto`."""

      def __init__(self, name="module"):
        self.v = variables.Variable(1.)
        self.name = name

      # `_serialize_to_proto` is deliberately not overridden, so the default
      # (which returns `None`) is used.

      @classmethod
      def _deserialize_from_proto(cls, proto, **unused_kwargs):
        # `self` is the enclosing test case, captured by closure.
        proto_size = proto.ByteSize()
        self.assertEqual(proto_size, 0)
        return cls("deserialized")

    original = Module("a")
    original.v.assign(5)
    reloaded = cycle(original, cycles)
    self.assertIsInstance(reloaded, Module)
    self.assertEqual(5, reloaded.v.numpy())
    self.assertEqual("deserialized", reloaded.name)

  def test_deserialization_dependencies(self, cycles):
    """Declared dependencies are handed to `_deserialize_from_proto`."""

    @registration.register_serializable(name=f"Dependency{cycles}")
    class Module(autotrackable.AutoTrackable):
      """Trackable that needs its variable `v` at deserialization time."""

      def __init__(self, v=None):
        if v is None:
          v = variables.Variable(1.)
        self.v = v

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return dict(v=self.v)

      @classmethod
      def _deserialize_from_proto(cls, dependencies, **unused_kwargs):
        # `self` is the enclosing test case, captured by closure.
        self.assertIn("v", dependencies)
        restored_variable = dependencies["v"]
        return cls(v=restored_variable)

    original = Module()
    original.v.assign(5)
    reloaded = cycle(original, cycles)
    self.assertIsInstance(reloaded, Module)
    self.assertEqual(5, reloaded.v.numpy())

  def test_registered_saver(self, cycles):
    """A `Stack` saved through the registered saver keeps its value."""
    parts = [Part(values) for values in ([1, 4], [2, 5], [3, 6])]
    stack = Stack(parts)
    reloaded = cycle(stack, cycles)
    self.assertAllEqual(stack.value(), reloaded.value())
221
222
class SingleCycleTest(test.TestCase):
  """Tests for registered checkpoint savers that need no repeated reloads."""

  @test_util.deprecated_graph_mode_only()
  def test_registered_saver_fails_in_saved_model_graph_mode(self):
    """Loading a model saved with a registered saver fails in graph mode."""
    # The model is written under eager mode; only the load below runs in
    # graph mode.
    with context.eager_mode():
      p1 = Part([1, 4])
      p2 = Part([2, 5])
      p3 = Part([3, 6])
      s = Stack([p1, p2, p3])
      save_dir = os.path.join(self.get_temp_dir(), "save_dir")
      save.save(s, save_dir)

    with self.assertRaisesRegex(
        NotImplementedError,
        "registered checkpoint saver is not supported in graph mode"):
      load.load(save_dir)

  def test_registered_saver_checkpoint(self):
    """Stacks written to a checkpoint can be read back under either name."""
    p1 = Part([1, 4])
    p2 = Part([2, 5])
    p3 = Part([3, 6])
    # Both stacks share the same parts, in a different order.
    s = Stack([p1, p2, p3])
    s2 = Stack([p3, p1, p2])

    expected_value_s = s.value()
    expected_value_s2 = s2.value()

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(s=s, s2=s2).write(ckpt_path)

    del s, s2, p1, p2, p3

    # Restore into a fresh stack, first from the entry saved as `s` and then
    # from the entry saved as `s2`.
    restore_s = Stack([Part([0, 0]) for _ in range(3)])
    util.Checkpoint(s=restore_s).read(ckpt_path).expect_partial()
    self.assertAllEqual(expected_value_s, restore_s.value())
    util.Checkpoint(s2=restore_s).read(ckpt_path).expect_partial()
    self.assertAllEqual(expected_value_s2, restore_s.value())

  def test_compatible_with_v1_savedmodel(self):
    """A model saved with a signature can be loaded by the v1 loader."""
    p1 = Part([1, 4])
    p2 = Part([2, 5])
    p3 = Part([3, 6])
    s = Stack([p1, p2, p3])
    save_path = os.path.join(self.get_temp_dir(), "savedmodel")

    @def_function.function(input_signature=[])
    def serve():
      return {"value": s.value()}

    exported_value = serve()["value"]

    save.save(s, save_path, signatures=serve)
    # Load with the v1 loader inside a fresh graph and session, then compare
    # the signature output with the value computed before saving.
    with ops.Graph().as_default(), session.Session() as sess:
      metagraph = loader.load(sess, ["serve"], save_path)
      value_output = metagraph.signature_def["serving_default"].outputs["value"]
      self.assertAllEqual(exported_value, sess.run(value_output.name))

  def test_non_strict_predicate(self):
    """Reading into a non-matching object passes with a non-strict saver."""
    class NonStrictPredicateClass(autotrackable.AutoTrackable):
      pass
    registration.register_checkpoint_saver(
        name="NonStrictPredicate",
        predicate=lambda x: isinstance(x, NonStrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=False)

    root = NonStrictPredicateClass()
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

    # `root2` does not satisfy the saver's predicate.
    root2 = autotrackable.AutoTrackable()
    # This should run without throwing an error.
    util.Checkpoint(root2).read(ckpt_path)

  def test_strict_predicate(self):
    """Reading into a non-matching object raises with a strict saver."""
    class StrictPredicateClass(autotrackable.AutoTrackable):
      pass
    registration.register_checkpoint_saver(
        name="StrictPredicate",
        predicate=lambda x: isinstance(x, StrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=True)

    root = StrictPredicateClass()
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

    # `root2` does not satisfy the saver's predicate.
    root2 = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      util.Checkpoint(root2).read(ckpt_path)

  def test_registered_saver_is_called_before_save_after_load(self):
    """The registered `save_fn` runs while the checkpoint dir is still empty."""
    if not context.executing_eagerly():
      self.skipTest("This test must run under eager mode.")

    class RestoreClass(autotrackable.AutoTrackable):
      pass
    def save_fn(trackables, file_prefix):
      del trackables  # Unused.
      # Check that directory is empty
      files = gfile.ListDirectory(os.path.dirname(file_prefix.numpy()))
      self.assertEmpty(files)

    def restore_fn(trackables, merged_prefix):
      del merged_prefix  # Unused.
      # `values()` returns a view, not an iterator, so it has to be wrapped in
      # `iter()` before `next()` can take the first element.
      root = next(iter(trackables.values()))
      self.assertEqual(root.v.numpy(), 123)

    registration.register_checkpoint_saver(
        name="OptionalRestore",
        predicate=lambda x: isinstance(x, RestoreClass),
        save_fn=save_fn,
        restore_fn=restore_fn)

    root = RestoreClass()
    root.v = variables.Variable(123.0)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

  def test_migration_backwards_compatibility(self):
    """Pre-migration checkpoints can be read by a migrated object."""
    # Tests that objects migrated to using the advanced saver registration can
    # use pre-migration checkpoints.

    class NoRegisteredSaver(autotrackable.AutoTrackable):
      """Pre-migration object that saves its name via `_serialize_to_tensors`."""

      def __init__(self, name):
        self.name = name

      def _serialize_to_tensors(self):
        return {"name": constant_op.constant(self.name)}

    class RegisteredSaver(autotrackable.AutoTrackable):
      """Post-migration object handled by the "MigratedSaver" saver."""

      def __init__(self, name):
        self.name = name

    def _get_tensors(trackables, append_name=True):
      """Builds the checkpoint entries for the name of each trackable.

      Args:
        trackables: Mapping from object prefix to trackable object.
        append_name: Whether "name" is appended to each object prefix to form
          the tensor name. When false, the bare prefix is used.

      Returns:
        A 4-tuple `(tensor_names, shapes_and_slices, tensors,
        restored_trackables)` of equally long lists.
      """
      tensor_names = []
      shapes_and_slices = []
      tensors = []
      restored_trackables = []
      for obj_prefix, obj in trackables.items():
        tensor_names.append(obj_prefix + "name" if append_name else obj_prefix)
        shapes_and_slices.append("")
        tensors.append(constant_op.constant(obj.name))
        restored_trackables.append(obj)
      return tensor_names, shapes_and_slices, tensors, restored_trackables

    def save_fn(trackables, file_prefix):
      tensor_names, shapes_and_slices, tensors, _ = _get_tensors(trackables)
      io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
      return file_prefix

    def restore_fn(trackables, merged_prefix):
      tensor_names, shapes_and_slices, tensors, restored_trackables = (
          _get_tensors(trackables))
      dtypes = [t.dtype for t in tensors]
      try:
        restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                             shapes_and_slices, dtypes)
      except errors_impl.NotFoundError:
        # If a NotFoundError is caught, then it means that the checkpoint
        # was written prior to the saver registration migration.
        tensor_names, shapes_and_slices, tensors, restored_trackables = (
            _get_tensors(trackables, append_name=False))
        restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                             shapes_and_slices, dtypes)
      for trackable, name_tensor in zip(restored_trackables, restored_tensors):
        trackable.name = name_tensor

    registration.register_checkpoint_saver(
        name="MigratedSaver",
        predicate=lambda x: isinstance(x, RegisteredSaver),
        save_fn=save_fn,
        restore_fn=restore_fn,
    )

    before = NoRegisteredSaver("before")
    after = RegisteredSaver("after")
    before_ckpt_path = os.path.join(self.get_temp_dir(), "before_ckpt")
    util.Checkpoint(before).write(before_ckpt_path)

    after_ckpt = util.Checkpoint(after)
    after_ckpt_path = os.path.join(self.get_temp_dir(), "after_ckpt")
    after_ckpt.write(after_ckpt_path)

    # Try loading the pre-migrated checkpoint to the migrated object.
    after_ckpt.read(before_ckpt_path)
    self.assertEqual(b"before", self.evaluate(after.name))
415
416
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
419