# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for SaveableObject compatibility."""

import os

from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.checkpoint.testdata import generate_checkpoint
from tensorflow.python.eager import test
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training.saving import saveable_object


_LEGACY_TABLE_CHECKPOINT_PATH = test.test_src_dir_path(
    "python/checkpoint/testdata/table_legacy_saveable_object")


class SaveableCompatTest(test.TestCase):
  """Checks that newly written table checkpoints match the legacy layout."""

  def setUp(self):
    super().setUp()
    # The conversion flag is process-global state. Reset it once the test
    # finishes so the value chosen here cannot leak into later tests.
    self.addCleanup(saveable_compat.force_checkpoint_conversion, False)

  def test_lookup_table_compatibility(self):
    """Compares a freshly written table checkpoint with the legacy one."""
    saveable_compat.force_checkpoint_conversion(False)

    table_module = generate_checkpoint.TableModule()
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)

    # Ensure that the checkpoint metadata and keys are the same.
    legacy_metadata = checkpoint.object_metadata(_LEGACY_TABLE_CHECKPOINT_PATH)
    metadata = checkpoint.object_metadata(checkpoint_path)

    def _get_table_node(object_metadata):
      """Returns the node referenced by the root's `lookup_table` child."""
      for child in object_metadata.nodes[0].children:
        if child.local_name == "lookup_table":
          return object_metadata.nodes[child.node_id]
      # Fail loudly here instead of returning None, which would otherwise
      # surface later as an unrelated AttributeError.
      self.fail("No child named 'lookup_table' found under the root node.")

    table_proto = _get_table_node(metadata)
    legacy_table_proto = _get_table_node(legacy_metadata)
    self.assertAllEqual(
        [table_proto.attributes[0].name,
         table_proto.attributes[0].checkpoint_key],
        [legacy_table_proto.attributes[0].name,
         legacy_table_proto.attributes[0].checkpoint_key])
    legacy_reader = checkpoint_utils.load_checkpoint(
        _LEGACY_TABLE_CHECKPOINT_PATH)
    reader = checkpoint_utils.load_checkpoint(checkpoint_path)
    self.assertEqual(
        legacy_reader.get_variable_to_shape_map().keys(),
        reader.get_variable_to_shape_map().keys())

    # Ensure that previous checkpoint can be loaded into current table.
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()


class TestForceCheckpointConversionFlag(test.TestCase):
  """Exercises checkpointing with `force_checkpoint_conversion` turned on."""

  def setUp(self):
    super().setUp()
    # The conversion flag is process-global state. Reset it once the test
    # finishes so the forced setting cannot leak into later tests.
    self.addCleanup(saveable_compat.force_checkpoint_conversion, False)

  def test_checkpoint(self):
    """Writes a table with the flag on and reads it into a fresh table."""
    saveable_compat.force_checkpoint_conversion()

    table_module = generate_checkpoint.TableModule()
    table_module.lookup_table.insert(3, 9)
    ckpt = checkpoint.Checkpoint(table_module)
    checkpoint_directory = self.get_temp_dir()
    checkpoint_path = os.path.join(checkpoint_directory, "ckpt")
    ckpt.write(checkpoint_path)

    # Before restoring, key 3 is absent from the new table (lookup gives -1).
    new_table_module = generate_checkpoint.TableModule()
    self.assertEqual(-1, self.evaluate(new_table_module.lookup_table.lookup(3)))

    new_ckpt = checkpoint.Checkpoint(new_table_module)
    new_ckpt.read(checkpoint_path).assert_consumed()
    self.assertEqual(9, self.evaluate(new_table_module.lookup_table.lookup(3)))

  def test_backwards_compatibility(self):
    """Reads the legacy table checkpoint while the flag is on."""
    saveable_compat.force_checkpoint_conversion()

    table_module = generate_checkpoint.TableModule()
    table_module.lookup_table.insert(3, 9)
    self.assertEqual(9, self.evaluate(table_module.lookup_table.lookup(3)))

    # After the read, the locally inserted key 3 is gone and key 2 maps to 4.
    ckpt = checkpoint.Checkpoint(table_module)
    ckpt.read(_LEGACY_TABLE_CHECKPOINT_PATH).assert_consumed()
    self.assertEqual(-1, self.evaluate(table_module.lookup_table.lookup(3)))
    self.assertEqual(4, self.evaluate(table_module.lookup_table.lookup(2)))

  def test_forward_compatibility(self):
    """Loads checkpoints written by a new-style object into a legacy one.

    `NewTrackable` saves through `_serialize_to_tensors` and is tagged with the
    legacy saveable name "foo"; `DeprecatedTrackable` restores through a
    `SaveableObject` registered under that same name. The round trip is checked
    with the conversion flag both disabled and enabled.
    """

    class _MultiSpecSaveable(saveable_object.SaveableObject):
      """Saveable with two specs, named `<name>-a` and `<name>-b`."""

      def __init__(self, obj, name):
        self.obj = obj
        specs = [
            saveable_object.SaveSpec(obj.a, "", name + "-a"),
            saveable_object.SaveSpec(obj.b, "", name + "-b")]
        super().__init__(None, specs, name)

      def restore(self, restored_tensors, restored_shapes):
        del restored_shapes  # Unused.
        self.obj.a.assign(restored_tensors[0])
        self.obj.b.assign(restored_tensors[1])

    class DeprecatedTrackable(base.Trackable):
      """Trackable that checkpoints via `_gather_saveables_for_checkpoint`."""

      def __init__(self):
        self.a = variables.Variable(1.0)
        self.b = variables.Variable(2.0)

      def _gather_saveables_for_checkpoint(self):
        return {"foo": lambda name: _MultiSpecSaveable(self, name)}

    @saveable_compat.legacy_saveable_name("foo")
    class NewTrackable(base.Trackable):
      """Trackable that checkpoints via `_serialize_to_tensors`."""

      def __init__(self):
        self.a = variables.Variable(3.0)
        self.b = variables.Variable(4.0)

      def _serialize_to_tensors(self):
        return {"-a": self.a, "-b": self.b}

      def _restore_from_tensors(self, restored_tensors):
        self.a.assign(restored_tensors["-a"])
        self.b.assign(restored_tensors["-b"])

    new = NewTrackable()

    # Test with the checkpoint conversion flag disabled (normal compatibility).
    saveable_compat.force_checkpoint_conversion(False)
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.Checkpoint(new).write(checkpoint_path)

    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))

    # Now test with the checkpoint conversion flag enabled (forward compat).
    # The deprecated object will try to load from the new checkpoint.
    saveable_compat.force_checkpoint_conversion()
    checkpoint_path = os.path.join(self.get_temp_dir(), "ckpt2")
    checkpoint.Checkpoint(new).write(checkpoint_path)

    dep = DeprecatedTrackable()
    checkpoint.Checkpoint(dep).read(checkpoint_path).assert_consumed()
    self.assertEqual(3, self.evaluate(dep.a))
    self.assertEqual(4, self.evaluate(dep.b))


if __name__ == "__main__":
  test.main()