xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/saveable_compat.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Checkpoint compatibility functions with SaveableObject.

Helpers that keep checkpoints saved with the same metadata attributes before
and after the SaveableObject deprecation.
"""

# Attribute under which `legacy_saveable_name` records the legacy factory name
# on a decorated class or object; read back by `get_saveable_name`.
_LEGACY_SAVEABLE_NAME = "_LEGACY_SAVEABLE_NAME"


def legacy_saveable_name(name):
  """Decorator that records the legacy local name to use in the Checkpoint.

  Needed when migrating certain Trackables (described below) from the legacy
  `_gather_saveables_for_checkpoint` API to the new `_serialize_to_tensors`
  API.

  Use this decorator when the SaveableObject produces tensors whose names
  differ from the name that is passed to its factory.

  Example migration:

  *Before*

  ```
  class MyTrackable(Trackable):
    def _gather_saveables_for_checkpoint(self):
      return {"key": _MySaveable}

  class _MySaveable(SaveableObject):
    def __init__(self, name):
      specs = [
          SaveSpec(tensor1, "", name + "-1"),
          SaveSpec(tensor2, "", name + "-2"),
      ]
      super().__init__(None, specs, name)
  ```

  *After*

  ```
  @legacy_saveable_name("key")
  class MyTrackable(Trackable):

    def _serialize_to_tensors(self):
      return {"key-1": tensor1, "key-2": tensor2}
  ```

  The returned decorator may be applied to a class or to an individual object.
  It stores `name` in the `_LEGACY_SAVEABLE_NAME` attribute of its argument.
  It then returns that same argument.

  Args:
    name: String name of the SaveableObject factory (the key returned by the
      `_gather_saveables_for_checkpoint` function).

  Returns:
    A decorator.
  """

  def record_name(target):
    # setattr works on both classes and instances. When a class is decorated,
    # its instances see the attribute through normal attribute lookup.
    setattr(target, _LEGACY_SAVEABLE_NAME, name)
    return target

  return record_name


def get_saveable_name(cls_or_obj):
  """Returns the name recorded by `legacy_saveable_name`, if any.

  Args:
    cls_or_obj: A class or object that may have been decorated with
      `legacy_saveable_name`. Instances of a decorated class report the
      class's recorded name.

  Returns:
    The recorded legacy saveable name, or `None` when no name was recorded.
  """
  try:
    return getattr(cls_or_obj, _LEGACY_SAVEABLE_NAME)
  except AttributeError:
    return None


# Module-level flag written by `force_checkpoint_conversion()` and read by
# `force_checkpoint_conversion_enabled()`. It starts as False, i.e. conversion
# to the new checkpoint implementation is not forced.
_FORCE_CHECKPOINT_CONVERSION = False


def force_checkpoint_conversion(value=True):
  """Forces checkpoints to use the new implementation.

  The new checkpoint implementation changes the saved metadata slightly. It
  may therefore break forward compatibility of newly saved checkpoints:

    - Previous versions of TensorFlow may not be able to load new checkpoints.
    - Backwards compatibility is unchanged: old checkpoints can still be loaded.

  TensorFlow guarantees 3 weeks of forward compatibility. This flag will
  therefore be removed in a future release, after which checkpoint conversion
  will happen by default.

  **What happens when this flag is enabled?**

  The checkpoint will be saved with different metadata. This means that
  previous versions of TensorFlow (<=2.10) will not be able to load it.

  This function only stores `value` in a module-level flag. Code that needs
  the setting reads it back through `force_checkpoint_conversion_enabled()`.

  Args:
    value: Boolean value, whether or not to force checkpoint conversion to the
      new implementation. Defaults to `True`.
  """
  # TODO(kathywu): Add definite date for flag removal.
  global _FORCE_CHECKPOINT_CONVERSION
  _FORCE_CHECKPOINT_CONVERSION = value


def force_checkpoint_conversion_enabled():
  """Reports whether checkpoint conversion is currently being forced.

  Returns:
    The value most recently passed to `force_checkpoint_conversion`, or
    `False` if it has never been called. The value is returned exactly as
    stored, so it is only guaranteed to be a bool if callers passed a bool.
  """
  return _FORCE_CHECKPOINT_CONVERSION


class CheckpointConversionError(Exception):
  """Error type for failures while converting checkpoints.

  This module does not raise it itself. Presumably it is raised by code that
  converts legacy SaveableObject checkpoints to the new implementation —
  confirm at the raise sites.
  """
