# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles types registrations for tf.saved_model.load."""

from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.saved_model.load.VersionedTypeRegistration", v1=[])
class VersionedTypeRegistration(object):
  """Holds information about one version of a revived type."""

  def __init__(self, object_factory, version, min_producer_version,
               min_consumer_version, bad_consumers=None, setter=setattr):
    """Identify a revived type version.

    Args:
      object_factory: A callable which takes a SavedUserObject proto and returns
        a trackable object. Dependencies are added later via `setter`.
      version: An integer, the producer version of this wrapper type. When
        making incompatible changes to a wrapper, add a new
        `VersionedTypeRegistration` with an incremented `version`. The most
        recent version will be saved, and all registrations with a matching
        identifier will be searched for the highest compatible version to use
        when loading.
      min_producer_version: The minimum producer version number required to use
        this `VersionedTypeRegistration` when loading a proto.
      min_consumer_version: `VersionedTypeRegistration`s with a version number
        less than `min_consumer_version` will not be used to load a proto saved
        with this object. `min_consumer_version` should be set to the lowest
        version number which can successfully load protos saved by this
        object. If no matching registration is available on load, the object
        will be revived with a generic trackable type.

        `min_consumer_version` and `bad_consumers` are a blunt tool, and using
        them will generally break forward compatibility: previous versions of
        TensorFlow will revive newly saved objects as opaque trackable
        objects rather than wrapped objects. When updating wrappers, prefer
        saving new information but preserving compatibility with previous
        wrapper versions. They are, however, useful for ensuring that
        previously-released buggy wrapper versions degrade gracefully rather
        than throwing exceptions when presented with newly-saved SavedModels.
      bad_consumers: A list of consumer versions which are incompatible (in
        addition to any version less than `min_consumer_version`). A copy is
        stored, so later changes to the caller's list have no effect.
      setter: A callable with the same signature as `setattr` to use when adding
        dependencies to generated objects.
    """
    self.setter = setter
    # Filled in by `register_revived_type`; None until then.
    self.identifier = None
    self._object_factory = object_factory
    self.version = version
    self._min_consumer_version = min_consumer_version
    self._min_producer_version = min_producer_version
    # Copy so the stored list is not aliased with the caller's argument.
    self._bad_consumers = (
        [] if bad_consumers is None else list(bad_consumers))

  def to_proto(self):
    """Create a SavedUserObject proto describing this registration.

    Returns:
      A SavedUserObject whose `identifier` is this registration's identifier
      and whose `version` records this registration's version as the producer,
      together with its minimum consumer version and bad consumers.
    """
    # For now wrappers just use dependencies to save their state, so the
    # SavedUserObject doesn't depend on the object being saved.
    # TODO(allenl): Add a wrapper which uses its own proto.
    return saved_object_graph_pb2.SavedUserObject(
        identifier=self.identifier,
        version=versions_pb2.VersionDef(
            producer=self.version,
            min_consumer=self._min_consumer_version,
            bad_consumers=self._bad_consumers))

  def from_proto(self, proto):
    """Recreate a trackable object from a SavedUserObject proto.

    Args:
      proto: A SavedUserObject proto.

    Returns:
      Whatever the registered `object_factory` returns for `proto`.
    """
    return self._object_factory(proto)

  def should_load(self, proto):
    """Checks if this registration should load the SavedUserObject `proto`.

    Args:
      proto: A SavedUserObject proto.

    Returns:
      True only if all of the following hold: the proto's identifier matches
      this registration's identifier; this registration's version is at least
      `proto.version.min_consumer`; the proto's producer version is at least
      this registration's `min_producer_version`; and this registration's
      version is not listed in `proto.version.bad_consumers`.
    """
    if proto.identifier != self.identifier:
      return False
    if self.version < proto.version.min_consumer:
      return False
    if proto.version.producer < self._min_producer_version:
      return False
    return self.version not in proto.version.bad_consumers


# string identifier -> (predicate, [VersionedTypeRegistration]); each list is
# sorted from highest to lowest version by `register_revived_type`.
_REVIVED_TYPE_REGISTRY = {}
# Identifiers in registration order. `serialize` checks predicates in this
# order and uses the first one that matches.
_TYPE_IDENTIFIERS = []


@tf_export("__internal__.saved_model.load.register_revived_type", v1=[])
def register_revived_type(identifier, predicate, versions):
  """Register a type for revived objects.

  Args:
    identifier: A unique string identifying this class of objects.
    predicate: A Boolean predicate for this registration. Takes a
      trackable object as an argument. If True, `type_registration` may be
      used to save and restore the object.
    versions: A list of `VersionedTypeRegistration` objects. On success the
      list is sorted in place (highest version first), stored in the registry,
      and each registration's `identifier` is set to `identifier`.

  Raises:
    AssertionError: If `versions` is empty, if two registrations share a
      version number, or if `identifier` is already registered. Nothing is
      modified when an error is raised.
  """
  if not versions:
    raise AssertionError("Need at least one version of a registered type.")

  # Validate everything before mutating anything, so a failed call does not
  # leave registrations pointing at an identifier that was never stored.
  version_numbers = set()
  for registration in versions:
    if registration.version in version_numbers:
      raise AssertionError(
          f"Got multiple registrations with version {registration.version} for "
          f"type {identifier}.")
    version_numbers.add(registration.version)

  if identifier in _REVIVED_TYPE_REGISTRY:
    raise AssertionError(f"Duplicate registrations for type '{identifier}'")

  # Keep registrations in order of version. We always use the highest matching
  # version (respecting the min consumer version and bad consumers).
  versions.sort(key=lambda reg: reg.version, reverse=True)
  for registration in versions:
    # Copy over the identifier for use in generating protos.
    registration.identifier = identifier

  _REVIVED_TYPE_REGISTRY[identifier] = (predicate, versions)
  _TYPE_IDENTIFIERS.append(identifier)


def serialize(obj):
  """Create a SavedUserObject from a trackable object.

  Args:
    obj: The object to serialize.

  Returns:
    A SavedUserObject produced by the most recent version of the first
    registered type (in registration order) whose predicate accepts `obj`,
    or None if no registered predicate accepts it.
  """
  for identifier in _TYPE_IDENTIFIERS:
    predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
    if predicate(obj):
      # Always uses the most recent version to serialize.
      return versions[0].to_proto()
  return None


def _find_registration(proto):
  """Returns the highest-version registration able to load `proto`, or None."""
  _, type_registrations = _REVIVED_TYPE_REGISTRY.get(
      proto.identifier, (None, ()))
  # Lists are sorted highest version first, so the first match is the best.
  for type_registration in type_registrations:
    if type_registration.should_load(proto):
      return type_registration
  return None


def deserialize(proto):
  """Create a trackable object from a SavedUserObject proto.

  Args:
    proto: A SavedUserObject to deserialize.

  Returns:
    A tuple of (trackable, assignment_fn) where assignment_fn has the same
    signature as setattr and should be used to add dependencies to
    `trackable` when they are available. Returns None if the identifier is
    not registered or no registered version is compatible with `proto`.
  """
  registration = _find_registration(proto)
  if registration is None:
    return None
  return (registration.from_proto(proto), registration.setter)


@tf_export("__internal__.saved_model.load.registered_identifiers", v1=[])
def registered_identifiers():
  """Return all the current registered revived object identifiers.

  Returns:
    A set-like, read-only view of the registered identifier strings. The view
    is live, so it also reflects identifiers registered after this call.
  """
  return _REVIVED_TYPE_REGISTRY.keys()


@tf_export("__internal__.saved_model.load.get_setter", v1=[])
def get_setter(proto):
  """Gets the registered setter function for the SavedUserObject proto.

  See VersionedTypeRegistration for info about the setter function.

  Args:
    proto: SavedUserObject proto

  Returns:
    The setter function of the highest-version registration that can load
    `proto`, or None if there is no such registration.
  """
  registration = _find_registration(proto)
  if registration is None:
    return None
  return registration.setter