# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Checkpoint compatibility functions with SaveableObject.

Compatibility methods to ensure that checkpoints are saved with the same
metadata attributes before/after the SaveableObject deprecation.
"""

_LEGACY_SAVEABLE_NAME = "_LEGACY_SAVEABLE_NAME"


def legacy_saveable_name(name):
  """Decorator to set the local name to use in the Checkpoint.

  Needed for migrating certain Trackables (see the example below) from the
  legacy `_gather_saveables_for_checkpoint` to the new `_serialize_to_tensors`
  function.

  This decorator should be used if the SaveableObject generates tensors with
  different names from the name that is passed to the factory.

  Example migration:

  *Before*

  ```
  class MyTrackable(Trackable):

    def _gather_saveables_for_checkpoint(self):
      return {"key": _MySaveable}


  class _MySaveable(SaveableObject):

    def __init__(self, name):
      specs = [
          SaveSpec(tensor1, "", name + "-1"),
          SaveSpec(tensor2, "", name + "-2"),
      ]
      super().__init__(None, specs, name)
  ```

  *After*

  ```
  @legacy_saveable_name("key")
  class MyTrackable(Trackable):

    def _serialize_to_tensors(self):
      return {"key-1": tensor1, "key-2": tensor2}
  ```

  Args:
    name: String name of the SaveableObject factory (the key returned by the
      `_gather_saveables_for_checkpoint` function).

  Returns:
    A class decorator that records the legacy saveable name.
  """
  def decorator(cls_or_obj):
    setattr(cls_or_obj, _LEGACY_SAVEABLE_NAME, name)
    return cls_or_obj
  return decorator


def get_saveable_name(cls_or_obj):
  """Returns the legacy saveable name set on the class/object, or None."""
  return getattr(cls_or_obj, _LEGACY_SAVEABLE_NAME, None)


_FORCE_CHECKPOINT_CONVERSION = False


def force_checkpoint_conversion(value=True):
  """Forces checkpoints to use the new implementation.

  The new checkpoint implementation changes the saved metadata slightly, and
  therefore may break forward compatibility in newly saved checkpoints. This
  means:

  - Previous versions of TensorFlow may not be able to load new checkpoints.
  - Backwards compatibility is unchanged: old checkpoints can still be loaded.

  TensorFlow guarantees 3 weeks of forward compatibility, so this flag will be
  removed in the coming weeks, after which checkpoint conversion will happen
  by default.

  **What happens when this flag is enabled?**

  The checkpoint will be saved with different metadata, meaning that previous
  versions of TensorFlow (<= 2.10) will not be able to load it.

  Args:
    value: Boolean value, whether or not to force checkpoint conversion to the
      new implementation.
  """
  # TODO(kathywu): Add definite date for flag removal.
  global _FORCE_CHECKPOINT_CONVERSION
  _FORCE_CHECKPOINT_CONVERSION = value


def force_checkpoint_conversion_enabled():
  """Returns whether conversion to the new checkpoint format is forced."""
  return _FORCE_CHECKPOINT_CONVERSION


class CheckpointConversionError(Exception):
  """Raised when a checkpoint cannot be converted between formats."""
  pass
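

# A minimal, illustrative self-check of how the helpers in this module
# compose. It is not part of the TensorFlow API surface, and `_DemoTrackable`
# below is a hypothetical stand-in for a real Trackable. It only exercises the
# pure-Python logic defined above, so it runs without constructing tensors.
if __name__ == "__main__":

  @legacy_saveable_name("key")
  class _DemoTrackable:
    pass

  # The decorator records the legacy factory name on the class...
  assert get_saveable_name(_DemoTrackable) == "key"
  # ...while objects without the attribute report None.
  assert get_saveable_name(object()) is None

  # Opting into the new checkpoint metadata flips the module-level flag.
  force_checkpoint_conversion()
  assert force_checkpoint_conversion_enabled()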