# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes that list&validate all attributes to serialize to SavedModel.
"""

from tensorflow.python.eager import def_function
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable.autotrackable import AutoTrackable

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes


class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
    object.
  - `objects_to_serialize`: Returns dict of objects to attach to the root
    object (including separate trackable object containing keras-specific
    attributes)

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    Args:
      name: Name of subclass
      checkpointable_objects: List of checkpointable objects to be serialized
        in the SavedModel.
      functions: List of functions to be serialized in the SavedModel.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists.
    """
    # Collect into fresh sets so that caller-supplied lists are never mutated
    # (extending them in place would leak the `copy_from` attributes back into
    # the caller's data). Any iterable of names is accepted.
    all_checkpointable_objects = set(checkpointable_objects or ())
    all_functions = set(functions or ())

    for cls in copy_from or ():
      all_checkpointable_objects.update(cls.all_checkpointable_objects)
      all_functions.update(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': all_checkpointable_objects,
        'all_functions': all_functions}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttributes object matching the type of `obj`.

    More specific Keras types are checked before the generic `Layer` type, so
    e.g. a Model gets `ModelAttributes` rather than `LayerAttributes`.

    Raises:
      TypeError: if `obj` is not a Keras Layer (or subclass).
    """
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Name -> checkpointable object, as passed to `set_and_validate_objects`.
    self._object_dict = {}
    # Name -> function (may be None), as passed to `set_and_validate_functions`.
    self._function_dict = {}
    # Holder for every saved attribute; exposed under `constants.KERAS_ATTR`.
    self._keras_trackable = AutoTrackable()

  @property
  def functions(self):
    """Returns dictionary of all functions."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all checkpointable objects."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization."""
    functions = {}
    for key, v in self.functions.items():
      if key in CommonEndpoints.all_functions:
        # LayerCall wrappers are unwrapped to the underlying TF function.
        functions[key] = (v.wrapped_call if isinstance(v, save_impl.LayerCall)
                          else v)
    return functions

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization."""
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Args:
      function_dict: Dict mapping every name in `all_functions` to a
        `def_function.Function`, a `save_impl.LayerCall`, or None.

    Returns:
      Dict of all non-None saved functions.

    Raises:
      ValueError: if a value is not a supported function type, or if a
        required key is missing from `function_dict`.
    """
    for key in self.all_functions:
      if key in function_dict:
        if (function_dict[key] is not None and  # Not all functions are required
            not isinstance(function_dict[key],
                           (def_function.Function, save_impl.LayerCall))):
          raise ValueError(
              'Function dictionary contained a non-function object: {} (for key'
              ' {})'.format(function_dict[key], key))
        fn = function_dict[key]
        self._function_dict[key] = fn

        # Extract TensorFlow `Function` from LayerCall.
        tf_fn = fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn
        setattr(self._keras_trackable, key, tf_fn)
      else:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Args:
      object_dict: Dict mapping every name in `all_checkpointable_objects` to
        a `trackable.Trackable` instance.

    Returns:
      Dict of all non-None saved checkpointable objects.

    Raises:
      ValueError: if a value is not Trackable, or if a required key is missing
        from `object_dict`.
    """
    for key in self.all_checkpointable_objects:
      if key in object_dict:
        if not isinstance(object_dict[key], trackable.Trackable):
          raise ValueError(
              'Object dictionary contained a non-trackable object: {} (for key'
              ' {})'.format(object_dict[key], key))
        self._object_dict[key] = object_dict[key]
        setattr(self._keras_trackable, key, object_dict[key])
      else:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
    return self.checkpointable_objects


class CommonEndpoints(SerializedAttributes.with_attributes(
    'CommonEndpoints',
    checkpointable_objects=['variables', 'trainable_variables',
                            'regularization_losses'],
    functions=['__call__', 'call_and_return_all_conditional_losses',
               '_default_save_signature'])):
  """Common endpoints shared by all models loadable by Keras.

  List of all attributes:
    variables: List of all variables in the model and its sublayers.
    trainable_variables: List of all trainable variables in the model and its
      sublayers.
    regularization_losses: List of all unconditional losses (losses not
      dependent on the inputs) in the model and its sublayers.
    __call__: Function that takes inputs and returns the outputs of the model
      call function.
    call_and_return_all_conditional_losses: Function that returns a tuple of
      (call function outputs, list of all losses that depend on the inputs).
    _default_save_signature: Traced model call function. This is only included
      if the top level exported object is a Keras model.
  """


class LayerAttributes(SerializedAttributes.with_attributes(
    'LayerAttributes',
    checkpointable_objects=['non_trainable_variables', 'layers', 'metrics',
                            'layer_regularization_losses', 'layer_metrics'],
    functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
    copy_from=[CommonEndpoints]
    )):
  """Layer checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from CommonEndpoints
    non_trainable_variables: List of non-trainable variables in the layer and
      its sublayers.
    layers: List of all sublayers.
    metrics: List of all metrics in the layer and its sublayers.
    call_and_return_conditional_losses: Function that takes inputs and returns a
      tuple of (outputs of the call function, list of input-dependent losses).
      The list of losses excludes the activity regularizer function, which is
      separate to allow the deserialized Layer object to define a different
      activity regularizer.
    activity_regularizer_fn: Callable that returns the activity regularizer loss
    layer_regularization_losses: List of losses owned only by this layer.
    layer_metrics: List of metrics owned by this layer.
  """


class ModelAttributes(SerializedAttributes.with_attributes(
    'ModelAttributes',
    copy_from=[LayerAttributes])):
  """Model checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from LayerAttributes (including CommonEndpoints)
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  #  list all losses and metrics defined by `model.compile`.


class MetricAttributes(
    SerializedAttributes.with_attributes(
        'MetricAttributes',
        checkpointable_objects=['variables'],
        functions=[],
    )):
  """Attributes that are added to Metric objects when saved to SavedModel.

  List of all attributes:
    variables: list of all variables
  """


class RNNAttributes(SerializedAttributes.with_attributes(
    'RNNAttributes',
    checkpointable_objects=['states'],
    copy_from=[LayerAttributes])):
  """RNN checkpointable objects + functions that are saved to the SavedModel.

  List of all attributes:
    All attributes from LayerAttributes (including CommonEndpoints)
    states: List of state variables
  """