xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/revived_types.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles types registrations for tf.saved_model.load."""
16
17from tensorflow.core.framework import versions_pb2
18from tensorflow.core.protobuf import saved_object_graph_pb2
19from tensorflow.python.util.tf_export import tf_export
20
21
@tf_export("__internal__.saved_model.load.VersionedTypeRegistration", v1=[])
class VersionedTypeRegistration:
  """Describes a single version of a revived type.

  A registration pairs an object factory with the version constraints that
  decide which saved `SavedUserObject` protos it is allowed to load.
  """

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. An
        incompatible change to a wrapper should be expressed by adding another
        `VersionedTypeRegistration` whose `version` is incremented. Saving
        uses the most recent version; loading searches every registration
        with a matching identifier for the highest compatible version.
      min_producer_version: The minimum producer version number a proto must
        have for this `VersionedTypeRegistration` to load it.
      min_consumer_version: Written into saved protos. A
        `VersionedTypeRegistration` whose version number is below
        `min_consumer_version` will not be used to load a proto saved with
        this object, so set it to the lowest version number which can
        successfully load protos saved by this object. If no matching
        registration is available on load, the object will be revived with a
        generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: A list of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). Defaults
        to an empty list.
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self._object_factory = object_factory
    self.version = version
    self._min_producer_version = min_producer_version
    self._min_consumer_version = min_consumer_version
    self._bad_consumers = [] if bad_consumers is None else bad_consumers
    self.setter = setter
    # Stays None until `register_revived_type` copies the identifier in.
    self.identifier = None

  def to_proto(self):
    """Build the SavedUserObject proto that describes this registration."""
    # Wrappers currently persist their state purely through dependencies, so
    # the proto does not depend on the particular object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    version_def = versions_pb2.VersionDef(
        producer=self.version,
        min_consumer=self._min_consumer_version,
        bad_consumers=self._bad_consumers)
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier, version=version_def)

  def from_proto(self, proto):
    """Run the object factory on `proto` and return the resulting object."""
    factory = self._object_factory
    return factory(proto)

  def should_load(self, proto):
    """Report whether this registration may load the SavedUserObject `proto`.

    Loading is refused when the identifiers differ, when this registration's
    version is below the proto's `min_consumer`, when the proto's producer is
    below this registration's minimum producer version, or when this
    registration's version is listed in the proto's `bad_consumers`.
    """
    if proto.identifier != self.identifier:
      return False
    saved_version = proto.version
    too_old_consumer = self.version < saved_version.min_consumer
    if too_old_consumer:
      return False
    too_old_producer = saved_version.producer < self._min_producer_version
    if too_old_producer:
      return False
    return not any(
        self.version == bad_version
        for bad_version in saved_version.bad_consumers)
99
100
# Identifiers in the order they were registered; `serialize` tries the
# predicates in this order and uses the first one that accepts the object.
_TYPE_IDENTIFIERS = []
# string identifier -> (predicate, [VersionedTypeRegistration]), where the
# registrations are kept sorted by `version`, newest first.
_REVIVED_TYPE_REGISTRY = {}
104
105
@tf_export("__internal__.saved_model.load.register_revived_type", v1=[])
def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Every check runs before anything is modified, so a rejected call leaves
  `versions`, the registration objects and the module-level registry exactly
  as they were.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, the registered `versions` may
      be used to save and restore the object.
    versions: A non-empty list of `VersionedTypeRegistration` objects with
      pairwise distinct `version` values. On success the list is sorted in
      place (highest version first), stored in the registry, and each entry's
      `identifier` attribute is set to `identifier`.

  Raises:
    AssertionError: If `versions` is empty, if two entries share a version
      number, or if `identifier` has already been registered.
  """
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")

  # Keep registrations in order of version. We always use the highest matching
  # version (respecting the min consumer version and bad consumers). Validate
  # on a sorted copy so the caller's list is untouched if a check fails.
  ordered = sorted(versions, key=lambda reg: reg.version, reverse=True)
  version_numbers = set()
  for registration in ordered:
    if registration.version in version_numbers:
      raise AssertionError(
          f"Got multiple registrations with version {registration.version} for "
          f"type {identifier}.")
    version_numbers.add(registration.version)

  if identifier in _REVIVED_TYPE_REGISTRY:
    raise AssertionError(f"Duplicate registrations for type '{identifier}'")

  # All checks passed: commit. The caller's list object itself is reordered
  # and stored, matching the in-place sort callers have always observed.
  versions[:] = ordered
  for registration in versions:
    # Copy over the identifier for use in generating protos
    registration.identifier = identifier
  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)
  _TYPE_IDENTIFIERS.append(identifier)
137
138
def serialize(obj):
  """Create a SavedUserObject from a trackable object.

  Registered types are tried in registration order. The first type whose
  predicate accepts `obj` supplies the proto.

  Args:
    obj: The object to test against the registered predicates.

  Returns:
    The proto produced by the newest registration of the first matching type,
    or None when no registered predicate accepts `obj`.
  """
  for type_id in _TYPE_IDENTIFIERS:
    accepts, registrations = _REVIVED_TYPE_REGISTRY[type_id]
    if not accepts(obj):
      continue
    # Registrations are stored newest first, so index 0 is the most recent
    # version, which is always the one used to serialize.
    newest = registrations[0]
    return newest.to_proto()
  return None
147
148
def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None when the proto's
    identifier is not registered or no registration for it should load the
    proto.
  """
  _, candidates = _REVIVED_TYPE_REGISTRY.get(proto.identifier, (None, None))
  if candidates is None:
    return None
  # Candidates are visited in stored order; the first one willing to load the
  # proto wins.
  for candidate in candidates:
    if not candidate.should_load(proto):
      continue
    revived = candidate.from_proto(proto)
    return (revived, candidate.setter)
  return None
167
168
@tf_export("__internal__.saved_model.load.registered_identifiers", v1=[])
def registered_identifiers():
  """Return all the current registered revived object identifiers.

  Returns:
    The keys view of the revived-type registry: a set-like collection of
    identifier strings.
  """
  identifiers = _REVIVED_TYPE_REGISTRY.keys()
  return identifiers
177
178
@tf_export("__internal__.saved_model.load.get_setter", v1=[])
def get_setter(proto):
  """Gets the registered setter function for the SavedUserObject proto.

  See VersionedTypeRegistration for info about the setter function.

  Args:
    proto: SavedUserObject proto

  Returns:
    The `setter` of the first registration for `proto.identifier` that should
    load `proto`, or None when the identifier is not registered or no
    registration for it should load the proto.
  """
  _, candidates = _REVIVED_TYPE_REGISTRY.get(proto.identifier, (None, None))
  if candidates is None:
    return None
  for candidate in candidates:
    if not candidate.should_load(proto):
      continue
    return candidate.setter
  return None
198