# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests saving with registered Trackable classes and checkpoint functions."""

import os
import tempfile

from absl.testing import parameterized

from google.protobuf import wrappers_pb2
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import save
from tensorflow.python.trackable import autotrackable


@registration.register_serializable()
class Part(resource_variable_ops.ResourceVariable):
  """Variable that the "stacks" saver checkpoints as one row of a `Stack`."""

  def __init__(self, value):
    # Initializes the variable directly from `value` via `_init_from_args`
    # rather than going through `ResourceVariable.__init__`.
    self._init_from_args(value)

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):
    # Returns a zero-valued placeholder; the restored value is assigned later
    # by `restore_stacks_and_parts` through the owning `Stack`.
    return cls([0, 0])

  def _export_to_saved_model_graph(self, object_map, tensor_map, **kwargs):
    """Creates a zero-initialized copy of this part for the export graph.

    Args:
      object_map: Dict updated to map this object to its exported copy.
      tensor_map: Dict updated to map this variable's handle to the copy's
        handle.
      **kwargs: Unused.

    Returns:
      A list containing this variable's handle.
    """
    p = Part(array_ops.zeros(self.shape, self.dtype))
    object_map[self] = p
    tensor_map[self.handle] = p.handle
    return [self.handle]


@registration.register_serializable()
class Stack(autotrackable.AutoTrackable):
  """Trackable holding a list of `Part`s; `value()` stacks them along axis 0."""

  def __init__(self, parts=None):
    self.parts = parts

  @def_function.function(input_signature=[])
  def value(self):
    return array_ops.stack(self.parts)


def get_tensor_slices(trackables):
  """Builds `save_v2`/`restore_v2` arguments for the `Stack`s in `trackables`.

  `Part` objects are skipped: their values are saved as rows of the owning
  `Stack`'s stacked value and written back by `restore_stacks_and_parts`.

  Args:
    trackables: Dict mapping a checkpoint name prefix to a `Stack` or `Part`.

  Returns:
    A tuple `(tensor_names, shapes_and_slices, tensors, restored_trackables)`
    of index-aligned lists. Each saved `Stack` contributes the name
    `<prefix>/value`, an empty slice spec (the full tensor), a CPU copy of its
    current value, and the `Stack` object itself.
  """
  tensor_names = []
  shapes_and_slices = []
  tensors = []
  restored_trackables = []
  for obj_prefix, obj in trackables.items():
    if isinstance(obj, Part):
      continue  # only save stacks
    tensor_names.append(obj_prefix + "/value")
    shapes_and_slices.append("")
    x = obj.value()
    with ops.device("/device:CPU:0"):
      tensors.append(array_ops.identity(x))
    restored_trackables.append(obj)

  return tensor_names, shapes_and_slices, tensors, restored_trackables


def save_stacks_and_parts(trackables, file_prefix):
  """Save stack and part objects to a checkpoint shard.

  Args:
    trackables: Dict mapping a checkpoint name prefix to a `Stack` or `Part`.
    file_prefix: Prefix of the checkpoint shard to write.

  Returns:
    `file_prefix`, the prefix that was written.
  """
  tensor_names, shapes_and_slices, tensors, _ = get_tensor_slices(trackables)
  io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
  return file_prefix


def restore_stacks_and_parts(trackables, merged_prefix):
  """Restores each `Stack`'s value and assigns its rows back to its parts.

  Args:
    trackables: Dict mapping a checkpoint name prefix to a `Stack` or `Part`.
    merged_prefix: Prefix of the checkpoint to read from.
  """
  tensor_names, shapes_and_slices, tensors, restored_trackables = (
      get_tensor_slices(trackables))
  dtypes = [t.dtype for t in tensors]
  restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                       shapes_and_slices, dtypes)
  for trackable, restored_tensor in zip(restored_trackables, restored_tensors):
    # `unstack` needs a statically known leading dimension, so give the
    # restored tensor the Stack's static shape first.
    expected_shape = trackable.value().get_shape()
    restored_tensor = array_ops.reshape(restored_tensor, expected_shape)
    parts = array_ops.unstack(restored_tensor)
    for part, restored_part in zip(trackable.parts, parts):
      part.assign(restored_part)


registration.register_checkpoint_saver(
    name="stacks",
    predicate=lambda x: isinstance(x, (Stack, Part)),
    save_fn=save_stacks_and_parts,
    restore_fn=restore_stacks_and_parts)


def cycle(obj, cycles, signatures=None, options=None):
  """Saves `obj` as a SavedModel and reloads it, `cycles` times in a row.

  Args:
    obj: The trackable object to save.
    cycles: Number of save/load round trips.
    signatures: Signatures passed to the first `save.save`; later rounds reuse
      the signatures of the previously loaded object.
    options: Passed through to `save.save`.

  Returns:
    The object produced by the last load, or `obj` itself when `cycles` is 0.
  """
  to_save = obj
  for _ in range(cycles):
    # `dir=` (not `prefix=`) so the SavedModel lands inside the test temp dir.
    path = tempfile.mkdtemp(dir=test.get_temp_dir())
    # If available, we'll run the save and restore preferring the GPU. This
    # just makes sure we aren't throwing errors and have enough
    # device("CPU") blocks to satisfy the placer.
    with test_util.use_gpu():
      save.save(to_save, path, signatures, options=options)
      loaded = load.load(path)
      signatures = loaded.signatures
    to_save = loaded
  return to_save


@parameterized.named_parameters(
    dict(testcase_name="ReloadOnce", cycles=1),
    dict(testcase_name="ReloadTwice", cycles=2),
    dict(testcase_name="ReloadThrice", cycles=3))
class SavedModelTest(test.TestCase, parameterized.TestCase):

  def test_registered_serializable(self, cycles):

    @registration.register_serializable(name=f"SaveAndLoad{cycles}")
    class Module(autotrackable.AutoTrackable):

      def __init__(self, name="module"):
        self.v = variables.Variable(1.)
        self.name = name

      def _serialize_to_proto(self, **unused_kwargs):
        return wrappers_pb2.StringValue(value=self.name)

      @classmethod
      def _deserialize_from_proto(cls, proto, **unused_kwargs):
        if proto.Is(wrappers_pb2.StringValue.DESCRIPTOR):
          unpacked = wrappers_pb2.StringValue()
          proto.Unpack(unpacked)
          return cls(name=unpacked.value)
        raise AssertionError(
            "Did not receive proto of correct type during deserialization. "
            f"Expected type {wrappers_pb2.StringValue.DESCRIPTOR.full_name}, "
            f"got {proto.TypeName()}")

    m = Module("a")
    m.v.assign(5)
    loaded = cycle(m, cycles)
    self.assertIsInstance(loaded, Module)
    self.assertEqual(5, loaded.v.numpy())
    self.assertEqual("a", loaded.name)

  def test_none_proto(self, cycles):

    @registration.register_serializable(name=f"NoneProto{cycles}")
    class Module(autotrackable.AutoTrackable):

      def __init__(self, name="module"):
        self.v = variables.Variable(1.)
        self.name = name

      # Leave _serialize_to_proto as the default (returns `None`).

      @classmethod
      def _deserialize_from_proto(cls, proto, **unused_kwargs):
        # `self` is the enclosing test case, captured by closure.
        self.assertEqual(proto.ByteSize(), 0)
        return cls("deserialized")

    m = Module("a")
    m.v.assign(5)
    loaded = cycle(m, cycles)
    self.assertIsInstance(loaded, Module)
    self.assertEqual(5, loaded.v.numpy())
    self.assertEqual("deserialized", loaded.name)

  def test_deserialization_dependencies(self, cycles):
    @registration.register_serializable(name=f"Dependency{cycles}")
    class Module(autotrackable.AutoTrackable):

      def __init__(self, v=None):
        self.v = v if v is not None else variables.Variable(1.)

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return {"v": self.v}

      @classmethod
      def _deserialize_from_proto(cls, dependencies, **unused_kwargs):
        # `self` is the enclosing test case, captured by closure.
        self.assertIn("v", dependencies)
        return cls(v=dependencies["v"])

    m = Module()
    m.v.assign(5)
    loaded = cycle(m, cycles)
    self.assertIsInstance(loaded, Module)
    self.assertEqual(5, loaded.v.numpy())

  def test_registered_saver(self, cycles):
    p1 = Part([1, 4])
    p2 = Part([2, 5])
    p3 = Part([3, 6])
    s = Stack([p1, p2, p3])
    loaded = cycle(s, cycles)
    self.assertAllEqual(s.value(), loaded.value())


class SingleCycleTest(test.TestCase):

  @test_util.deprecated_graph_mode_only()
  def test_registered_saver_fails_in_saved_model_graph_mode(self):
    with context.eager_mode():
      p1 = Part([1, 4])
      p2 = Part([2, 5])
      p3 = Part([3, 6])
      s = Stack([p1, p2, p3])
      save_dir = os.path.join(self.get_temp_dir(), "save_dir")
      save.save(s, save_dir)

    with self.assertRaisesRegex(
        NotImplementedError,
        "registered checkpoint saver is not supported in graph mode"):
      load.load(save_dir)

  def test_registered_saver_checkpoint(self):
    p1 = Part([1, 4])
    p2 = Part([2, 5])
    p3 = Part([3, 6])
    s = Stack([p1, p2, p3])
    s2 = Stack([p3, p1, p2])

    expected_value_s = s.value()
    expected_value_s2 = s2.value()

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(s=s, s2=s2).write(ckpt_path)

    del s, s2, p1, p2, p3

    # Restore both saved stacks, one at a time, into the same fresh Stack.
    restore_s = Stack([Part([0, 0]) for _ in range(3)])
    util.Checkpoint(s=restore_s).read(ckpt_path).expect_partial()
    self.assertAllEqual(expected_value_s, restore_s.value())
    util.Checkpoint(s2=restore_s).read(ckpt_path).expect_partial()
    self.assertAllEqual(expected_value_s2, restore_s.value())

  def test_compatible_with_v1_savedmodel(self):
    p1 = Part([1, 4])
    p2 = Part([2, 5])
    p3 = Part([3, 6])
    s = Stack([p1, p2, p3])
    save_path = os.path.join(self.get_temp_dir(), "savedmodel")

    @def_function.function(input_signature=[])
    def serve():
      return {"value": s.value()}

    exported_value = serve()["value"]

    save.save(s, save_path, signatures=serve)
    with ops.Graph().as_default(), session.Session() as sess:
      metagraph = loader.load(sess, ["serve"], save_path)
      value_output = metagraph.signature_def["serving_default"].outputs["value"]
      self.assertAllEqual(exported_value, sess.run(value_output.name))

  def test_non_strict_predicate(self):
    class NonStrictPredicateClass(autotrackable.AutoTrackable):
      pass
    registration.register_checkpoint_saver(
        name="NonStrictPredicate",
        predicate=lambda x: isinstance(x, NonStrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=False)

    root = NonStrictPredicateClass()
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

    root2 = autotrackable.AutoTrackable()
    # This should run without throwing an error.
    util.Checkpoint(root2).read(ckpt_path)

  def test_strict_predicate(self):
    class StrictPredicateClass(autotrackable.AutoTrackable):
      pass
    registration.register_checkpoint_saver(
        name="StrictPredicate",
        predicate=lambda x: isinstance(x, StrictPredicateClass),
        save_fn=lambda **kwargs: [],
        restore_fn=lambda **kwargs: None,
        strict_predicate_restore=True)

    root = StrictPredicateClass()
    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

    root2 = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      util.Checkpoint(root2).read(ckpt_path)

  def test_registered_saver_is_called_before_save_after_load(self):
    if not context.executing_eagerly():
      self.skipTest("This test must run under eager mode.")

    class RestoreClass(autotrackable.AutoTrackable):
      pass
    def save_fn(trackables, file_prefix):
      del trackables  # Unused.
      # Check that directory is empty
      files = gfile.ListDirectory(os.path.dirname(file_prefix.numpy()))
      self.assertEmpty(files)

    def restore_fn(trackables, merged_prefix):
      del merged_prefix  # Unused.
      # Dict views are not iterators; wrap in iter() before calling next().
      root = next(iter(trackables.values()))
      self.assertEqual(root.v.numpy(), 123)

    registration.register_checkpoint_saver(
        name="OptionalRestore",
        predicate=lambda x: isinstance(x, RestoreClass),
        save_fn=save_fn,
        restore_fn=restore_fn)

    root = RestoreClass()
    root.v = variables.Variable(123.0)

    ckpt_path = os.path.join(self.get_temp_dir(), "ckpt")
    util.Checkpoint(root).write(ckpt_path)

  def test_migration_backwards_compatibility(self):
    # Tests that objects migrated to using the advanced saver registration can
    # use pre-migration checkpoints.

    class NoRegisteredSaver(autotrackable.AutoTrackable):

      def __init__(self, name):
        self.name = name

      def _serialize_to_tensors(self):
        return {"name": constant_op.constant(self.name)}

    class RegisteredSaver(autotrackable.AutoTrackable):

      def __init__(self, name):
        self.name = name

    def _get_tensors(trackables, append_name=True):
      # Returns index-aligned save_v2/restore_v2 argument lists. With
      # `append_name=False` the bare prefix is used as the tensor name, which
      # the fallback in `restore_fn` relies on for pre-migration checkpoints.
      tensor_names = []
      shapes_and_slices = []
      tensors = []
      restored_trackables = []
      for obj_prefix, obj in trackables.items():
        tensor_names.append(
            (obj_prefix + "name") if append_name else obj_prefix)
        shapes_and_slices.append("")
        tensors.append(constant_op.constant(obj.name))
        restored_trackables.append(obj)
      return tensor_names, shapes_and_slices, tensors, restored_trackables

    def save_fn(trackables, file_prefix):
      tensor_names, shapes_and_slices, tensors, _ = _get_tensors(trackables)
      io_ops.save_v2(file_prefix, tensor_names, shapes_and_slices, tensors)
      return file_prefix

    def restore_fn(trackables, merged_prefix):
      tensor_names, shapes_and_slices, tensors, restored_trackables = (
          _get_tensors(trackables))
      dtypes = [t.dtype for t in tensors]
      try:
        restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                             shapes_and_slices, dtypes)
      except errors_impl.NotFoundError:
        # If a NotFoundError is caught, then it means that the checkpoint
        # was written prior to the saver registration migration.
        tensor_names, shapes_and_slices, tensors, restored_trackables = (
            _get_tensors(trackables, append_name=False))
        restored_tensors = io_ops.restore_v2(merged_prefix, tensor_names,
                                             shapes_and_slices, dtypes)
      for trackable, name_tensor in zip(restored_trackables, restored_tensors):
        trackable.name = name_tensor

    registration.register_checkpoint_saver(
        name="MigratedSaver",
        predicate=lambda x: isinstance(x, RegisteredSaver),
        save_fn=save_fn,
        restore_fn=restore_fn,
    )

    before = NoRegisteredSaver("before")
    after = RegisteredSaver("after")
    before_ckpt_path = os.path.join(self.get_temp_dir(), "before_ckpt")
    util.Checkpoint(before).write(before_ckpt_path)

    after_ckpt = util.Checkpoint(after)
    after_ckpt_path = os.path.join(self.get_temp_dir(), "after_ckpt")
    after_ckpt.write(after_ckpt_path)

    # Try loading the pre-migrated checkpoint to the migrated object.
    after_ckpt.read(before_ckpt_path)
    self.assertEqual(b"before", self.evaluate(after.name))


if __name__ == "__main__":
  test.main()