xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/saveable_compat_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tests for SaveableObject compatibility."""
17
18import os
19
20from tensorflow.python.checkpoint import checkpoint
21from tensorflow.python.checkpoint import saveable_compat
22from tensorflow.python.checkpoint.testdata import generate_checkpoint
23from tensorflow.python.eager import test
24from tensorflow.python.ops import variables
25from tensorflow.python.trackable import base
26from tensorflow.python.training import checkpoint_utils
27from tensorflow.python.training.saving import saveable_object
28
29
# Reference checkpoint stored under testdata. Judging by its directory name it
# holds a lookup table saved in the legacy SaveableObject format; the tests
# below compare against it and restore from it.
_LEGACY_TABLE_CHECKPOINT_PATH = test.test_src_dir_path(
    "python/checkpoint/testdata/"
    "table_legacy_saveable_object")
32
33
class SaveableCompatTest(test.TestCase):
  """Compares a freshly written table checkpoint with the legacy one."""

  def test_lookup_table_compatibility(self):
    # Run with forced conversion disabled ("normal compatibility" mode). The
    # flag is process-wide, so set it explicitly rather than rely on whatever
    # a previously-run test left behind.
    saveable_compat.force_checkpoint_conversion(False)

    table_module = generate_checkpoint.TableModule()
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)

    # Ensure that the checkpoint metadata and keys are the same.
    legacy_metadata = checkpoint.object_metadata(_LEGACY_TABLE_CHECKPOINT_PATH)
    metadata = checkpoint.object_metadata(checkpoint_path)

    def _get_table_node(object_metadata):
      """Returns the object-graph node of the root's `lookup_table` child.

      Fails the test with a descriptive message when no such child exists,
      instead of returning None and crashing later with an AttributeError.
      """
      for child in object_metadata.nodes[0].children:
        if child.local_name == "lookup_table":
          return object_metadata.nodes[child.node_id]
      self.fail("No `lookup_table` child found under the root object of the "
                "checkpoint object graph.")

    table_proto = _get_table_node(metadata)
    legacy_table_proto = _get_table_node(legacy_metadata)
    # The first serialized attribute must carry the same name and checkpoint
    # key in both checkpoints.
    self.assertAllEqual(
        [table_proto.attributes[0].name,
         table_proto.attributes[0].checkpoint_key],
        [legacy_table_proto.attributes[0].name,
         legacy_table_proto.attributes[0].checkpoint_key])
    legacy_reader = checkpoint_utils.load_checkpoint(
        _LEGACY_TABLE_CHECKPOINT_PATH)
    reader = checkpoint_utils.load_checkpoint(checkpoint_path)
    # Compare as sets: same equality as comparing the key views, but a failure
    # reports exactly which keys are missing or extra.
    self.assertEqual(
        set(legacy_reader.get_variable_to_shape_map().keys()),
        set(reader.get_variable_to_shape_map().keys()))

    # Ensure that previous checkpoint can be loaded into current table.
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()
70
71
class TestForceCheckpointConversionFlag(test.TestCase):
  """Tests checkpointing with `force_checkpoint_conversion` enabled."""

  def setUp(self):
    super(TestForceCheckpointConversionFlag, self).setUp()
    # `force_checkpoint_conversion` is called without any object to scope it
    # to, i.e. it toggles a process-wide setting. The tests below enable it,
    # so reset it to disabled after each test to keep it from leaking into
    # tests that run later in the same process.
    # NOTE(review): assumes disabled (False) is the library default; confirm
    # against saveable_compat.
    self.addCleanup(saveable_compat.force_checkpoint_conversion, False)

  def test_checkpoint(self):
    saveable_compat.force_checkpoint_conversion()

    table_module = generate_checkpoint.TableModule()
    table_module.lookup_table.insert(3, 9)
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)

    # A fresh table returns -1 for key 3, which was never inserted into it.
    new_table_module = generate_checkpoint.TableModule()
    self.assertEqual(-1, self.evaluate(new_table_module.lookup_table.lookup(3)))

    # Restoring the checkpoint written above brings back the inserted entry.
    new_ckpt = checkpoint.Checkpoint(new_table_module)
    new_ckpt.read(checkpoint_path).assert_consumed()
    self.assertEqual(9, self.evaluate(new_table_module.lookup_table.lookup(3)))

  def test_backwards_compatibility(self):
    saveable_compat.force_checkpoint_conversion()

    table_module = generate_checkpoint.TableModule()
    table_module.lookup_table.insert(3, 9)
    self.assertEqual(9, self.evaluate(table_module.lookup_table.lookup(3)))

    # The legacy checkpoint must still load with the flag enabled. Restoring
    # replaces the table contents: key 3 (inserted above) is gone and key 2
    # maps to the value stored in the legacy checkpoint.
    ckpt = checkpoint.Checkpoint(table_module)
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()
    self.assertEqual(-1, self.evaluate(table_module.lookup_table.lookup(3)))
    self.assertEqual(4, self.evaluate(table_module.lookup_table.lookup(2)))

  def test_forward_compatibility(self):

    class _MultiSpecSaveable(saveable_object.SaveableObject):
      """Legacy saveable that writes `obj.a` and `obj.b` as two save specs."""

      def __init__(self, obj, name):
        self.obj = obj
        specs = [
            saveable_object.SaveSpec(obj.a, "", name + "-a"),
            saveable_object.SaveSpec(obj.b, "", name + "-b")]
        super(_MultiSpecSaveable, self).__init__(None, specs, name)

      def restore(self, restored_tensors, restored_shapes):
        del restored_shapes  # Unused.
        # Indexed in the order the specs were declared above (a, then b).
        self.obj.a.assign(restored_tensors[0])
        self.obj.b.assign(restored_tensors[1])

    class DeprecatedTrackable(base.Trackable):
      """Old-style trackable that checkpoints through `_MultiSpecSaveable`."""

      def __init__(self):
        self.a = variables.Variable(1.0)
        self.b = variables.Variable(2.0)

      def _gather_saveables_for_checkpoint(self):
        return {"foo": lambda name: _MultiSpecSaveable(self, name)}

    # New-style trackable registered under the same legacy saveable name
    # ("foo") that DeprecatedTrackable uses. Its tensor keys "-a"/"-b" mirror
    # the `name + "-a"` / `name + "-b"` spec names of the legacy saveable.
    @saveable_compat.legacy_saveable_name("foo")
    class NewTrackable(base.Trackable):

      def __init__(self):
        self.a = variables.Variable(3.0)
        self.b = variables.Variable(4.0)

      def _serialize_to_tensors(self):
        return {"-a": self.a, "-b": self.b}

      def _restore_from_tensors(self, restored_tensors):
        self.a.assign(restored_tensors["-a"])
        self.b.assign(restored_tensors["-b"])

    new = NewTrackable()

    # Test with the checkpoint conversion flag disabled (normal compatibility).
    saveable_compat.force_checkpoint_conversion(False)
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(new).write(checkpoint_path)

    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))

    # Now test with the checkpoint conversion flag enabled (forward compat).
    # The deprecated object will try to load from the new checkpoint.
    saveable_compat.force_checkpoint_conversion()
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
    checkpoint.Checkpoint(new).write(checkpoint_path)

    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))


# Run the tests when this file is executed directly.
if __name__ == "__main__":
  test.main()
166  test.main()
167