xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/saved_model/serialized_attributes.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Helper classes that list and validate all attributes saved to SavedModel.
16"""
17
18from tensorflow.python.eager import def_function
19from tensorflow.python.keras.saving.saved_model import constants
20from tensorflow.python.keras.saving.saved_model import save_impl
21from tensorflow.python.keras.utils.generic_utils import LazyLoader
22from tensorflow.python.trackable import base as trackable
23from tensorflow.python.trackable.autotrackable import AutoTrackable
24
# Module aliases resolved through `LazyLoader` (from keras.utils.generic_utils).
# NOTE(review): presumably the real import is deferred until first attribute
# access, to avoid circular imports between this saving module and the
# engine/layers/metrics modules -- TODO confirm against LazyLoader. In this
# file the aliases are only dereferenced inside `SerializedAttributes.new`.
# Per the TODO below, the double-quoted module paths appear to matter to the
# copybara export tooling, so leave their spelling and quoting unchanged.
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes
40
41
class SerializedAttributes(object):
  """Class that tracks and validates all serialization attributes.

  Keras models contain many Python-defined components. For example, the
  trainable_variables property lists the model's trainable variables by
  recursively retrieving the trainable variables from each of the child layers.
  Another example is model.call, a python function that calls child layers and
  adds ops to the backend graph.

  Only TensorFlow checkpointable objects and functions can be serialized to
  SavedModel. Serializing a Keras model as-is results in a checkpointable object
  that does not resemble a Keras model at all. Thus, extra checkpointable
  objects and functions must be created during serialization.

  **Defining new serialized attributes**
  Child classes should be defined using:
    SerializedAttributes.with_attributes(
        'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
  This class is used to cache generated checkpointable objects and functions,
  ensuring that new objects and functions are generated a single time.

  **Usage during serialization**
  Each Layer/Model object should have a corresponding instance of
  SerializedAttributes. Create a new instance by calling
  `SerializedAttributes.new(obj)`. Objects and functions may be saved using
  `.set_and_validate_objects`/`.set_and_validate_functions`.
  The properties `.checkpointable_objects` and `.functions` return the cached
  values.

  **Adding/changing attributes to save to SavedModel**
  1. Change the call to `SerializedAttributes.with_attributes` in the correct
     class:
     - CommonEndpoints: Base attributes to be added during serialization. If
       these attributes are present in a Trackable object, it can be
       deserialized to a Keras Model.
     - LayerAttributes: Attributes to serialize for Layer objects.
     - ModelAttributes: Attributes to serialize for Model objects.
  2. Update class docstring
  3. Update arguments to any calls to `set_and_validate_*`. For example, if
     `call_raw_tensors` is added to the ModelAttributes function list, then
     a `call_raw_tensors` function should be passed to
     `set_and_validate_functions`.

  **Common endpoints vs other attributes**
  Only common endpoints are attached directly to the root object. Keras-specific
  attributes are saved to a separate trackable object with the name "keras_api".
  The number of objects attached to the root is limited because any naming
  conflicts will cause user code to break.

  Another reason is that this will only affect users who call
  `tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
  advanced users who are likely to have defined their own tf.functions and
  trackable objects. The added Keras-specific attributes are kept out of the way
  in the "keras_api" namespace.

  Properties defined in this class may be used to filter out keras-specific
  attributes:
  - `functions_to_serialize`: Returns dict of functions to attach to the root
      object.
  - `objects_to_serialize`: Returns dict of objects to attach to the root
      object (including separate trackable object containing keras-specific
      attributes)

  All changes to the serialized attributes must be backwards-compatible, so
  attributes should not be removed or modified without sufficient justification.
  """

  @staticmethod
  def with_attributes(
      name, checkpointable_objects=None, functions=None, copy_from=None):
    """Creates a subclass with all attributes as specified in the arguments.

    The argument sequences are copied before use, so lists owned by the caller
    are never modified (even when `copy_from` contributes extra names).

    Args:
      name: Name of subclass
      checkpointable_objects: List of checkpointable objects to be serialized
        in the SavedModel.
      functions: List of functions to be serialized in the SavedModel.
      copy_from: List of other SerializedAttributes subclasses. The returned
        class will copy checkpoint objects/functions from each subclass.

    Returns:
      Child class with attributes as defined in the `checkpointable_objects`
      and `functions` lists. The names are stored as sets in the class
      attributes `all_checkpointable_objects` and `all_functions`.
    """
    # Copy rather than alias: extending the caller's list in place would leak
    # the inherited attribute names back into the caller's data.
    checkpointable_objects = list(checkpointable_objects or [])
    functions = list(functions or [])

    for cls in copy_from or ():
      checkpointable_objects.extend(cls.all_checkpointable_objects)
      functions.extend(cls.all_functions)

    classdict = {
        'all_checkpointable_objects': set(checkpointable_objects),
        'all_functions': set(functions)}
    return type(name, (SerializedAttributes,), classdict)

  @staticmethod
  def new(obj):
    """Returns a new SerializedAttribute object matching the type of `obj`.

    The `base_layer.Layer` check is last, so the more specific attribute
    classes above it take precedence for objects that match several checks.

    Args:
      obj: The Keras object being serialized.

    Returns:
      A `ModelAttributes`, `MetricAttributes`, `RNNAttributes` or
      `LayerAttributes` instance.

    Raises:
      TypeError: If `obj` is an instance of none of the checked Keras types.
    """
    if isinstance(obj, training_lib.Model):
      return ModelAttributes()
    elif isinstance(obj, metrics.Metric):
      return MetricAttributes()
    elif isinstance(obj, recurrent.RNN):
      return RNNAttributes()
    elif isinstance(obj, base_layer.Layer):
      return LayerAttributes()
    else:
      raise TypeError('Internal error during serialization: Expected Keras '
                      'Layer object, got {} of type {}'.format(obj, type(obj)))

  def __init__(self):
    # Attribute name -> trackable object saved via `set_and_validate_objects`.
    self._object_dict = {}
    # Attribute name -> function (or None) saved via
    # `set_and_validate_functions`.
    self._function_dict = {}
    # Mirrors every saved attribute; exposed under `constants.KERAS_ATTR` by
    # `objects_to_serialize`.
    self._keras_trackable = AutoTrackable()

  @staticmethod
  def _unwrap_layer_call(fn):
    """Returns the TensorFlow `Function` wrapped by a `LayerCall`, else `fn`."""
    return fn.wrapped_call if isinstance(fn, save_impl.LayerCall) else fn

  @property
  def functions(self):
    """Returns dictionary of all functions (entries set to None are omitted)."""
    return {key: value for key, value in self._function_dict.items()
            if value is not None}

  @property
  def checkpointable_objects(self):
    """Returns dictionary of all checkpointable objects (None omitted)."""
    return {key: value for key, value in self._object_dict.items()
            if value is not None}

  @property
  def functions_to_serialize(self):
    """Returns functions to attach to the root object during serialization.

    Only names listed in `CommonEndpoints.all_functions` are included, and
    `LayerCall` wrappers are replaced by the TensorFlow `Function` they wrap.
    """
    return {key: self._unwrap_layer_call(fn)
            for key, fn in self.functions.items()
            if key in CommonEndpoints.all_functions}

  @property
  def objects_to_serialize(self):
    """Returns objects to attach to the root object during serialization.

    Only names listed in `CommonEndpoints.all_checkpointable_objects` are
    included, plus the Keras-specific trackable under `constants.KERAS_ATTR`.
    """
    objects = {key: value for key, value in self.checkpointable_objects.items()
               if key in CommonEndpoints.all_checkpointable_objects}
    objects[constants.KERAS_ATTR] = self._keras_trackable
    return objects

  def set_and_validate_functions(self, function_dict):
    """Saves function dictionary, and validates dictionary values.

    Every name in `all_functions` must be a key of `function_dict`; keys not
    in `all_functions` are ignored. Values may be `None` (not all functions
    are required). Such entries are recorded but left out of `.functions`.

    Args:
      function_dict: Dict mapping function names to `def_function.Function`,
        `save_impl.LayerCall` or `None`.

    Returns:
      The `functions` property after the update.

    Raises:
      ValueError: If a required name is missing or maps to a non-function.
    """
    for key in self.all_functions:
      if key not in function_dict:
        raise ValueError('Function {} missing from serialized function dict.'
                         .format(key))
      fn = function_dict[key]
      if (fn is not None and  # Not all functions are required
          not isinstance(fn, (def_function.Function, save_impl.LayerCall))):
        raise ValueError(
            'Function dictionary contained a non-function object: {} (for key'
            ' {})'.format(fn, key))
      self._function_dict[key] = fn
      # The Keras trackable stores the TensorFlow `Function`, not the LayerCall.
      setattr(self._keras_trackable, key, self._unwrap_layer_call(fn))
    return self.functions

  def set_and_validate_objects(self, object_dict):
    """Saves objects to a dictionary, and validates the values.

    Every name in `all_checkpointable_objects` must be a key of
    `object_dict`, and its value must be a `trackable.Trackable`. Keys not in
    `all_checkpointable_objects` are ignored.

    Args:
      object_dict: Dict mapping attribute names to trackable objects.

    Returns:
      The `checkpointable_objects` property after the update.

    Raises:
      ValueError: If a required name is missing or maps to a non-trackable.
    """
    for key in self.all_checkpointable_objects:
      if key not in object_dict:
        raise ValueError(
            'Object {} missing from serialized object dict.'.format(key))
      obj = object_dict[key]
      if not isinstance(obj, trackable.Trackable):
        raise ValueError(
            'Object dictionary contained a non-trackable object: {} (for key'
            ' {})'.format(obj, key))
      self._object_dict[key] = obj
      setattr(self._keras_trackable, key, obj)
    return self.checkpointable_objects
224
225
class CommonEndpoints(
    SerializedAttributes.with_attributes(
        'CommonEndpoints',
        checkpointable_objects=[
            'variables',
            'trainable_variables',
            'regularization_losses',
        ],
        functions=[
            '__call__',
            'call_and_return_all_conditional_losses',
            '_default_save_signature',
        ],
    )):
  """Endpoints shared by every model that Keras is able to load back.

  Attributes saved:
    variables: Every variable of the model and of its sublayers.
    trainable_variables: Every trainable variable of the model and of its
      sublayers.
    regularization_losses: Every unconditional loss (a loss that does not
      depend on the inputs) of the model and of its sublayers.
    __call__: Function that takes inputs and returns what the model's call
      function outputs.
    call_and_return_all_conditional_losses: Function returning a tuple of
      (call function outputs, list of every input-dependent loss).
    _default_save_signature: Traced model call function. Present only when
      the top-level exported object is a Keras model.
  """
247
248
class LayerAttributes(
    SerializedAttributes.with_attributes(
        'LayerAttributes',
        checkpointable_objects=[
            'non_trainable_variables',
            'layers',
            'metrics',
            'layer_regularization_losses',
            'layer_metrics',
        ],
        functions=[
            'call_and_return_conditional_losses',
            'activity_regularizer_fn',
        ],
        copy_from=[CommonEndpoints],
    )):
  """Checkpointable objects and functions saved to SavedModel for a Layer.

  Attributes saved:
    Everything saved by CommonEndpoints, plus:
    non_trainable_variables: Non-trainable variables of the layer and of its
      sublayers.
    layers: Every sublayer.
    metrics: Every metric of the layer and of its sublayers.
    call_and_return_conditional_losses: Function that takes inputs and
      returns a tuple of (call function outputs, list of input-dependent
      losses). The activity regularizer loss is left out of that list so that
      a deserialized Layer may define its own activity regularizer.
    activity_regularizer_fn: Callable returning the activity regularizer loss.
    layer_regularization_losses: Losses owned by this layer alone.
    layer_metrics: Metrics owned by this layer.
  """
273
274
class ModelAttributes(
    SerializedAttributes.with_attributes(
        'ModelAttributes',
        copy_from=[LayerAttributes],
    )):
  """Checkpointable objects and functions saved to SavedModel for a Model.

  Attributes saved:
    Everything saved by LayerAttributes (which includes CommonEndpoints).
  """
  # TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
  #  list all losses and metrics defined by `model.compile`.
285
286
class MetricAttributes(SerializedAttributes.with_attributes(
    'MetricAttributes',
    checkpointable_objects=['variables'],
    functions=[])):
  """Attributes attached to Metric objects when they are saved to SavedModel.

  Attributes saved:
    variables: Every variable of the metric.
  """
299
300
class RNNAttributes(
    SerializedAttributes.with_attributes(
        'RNNAttributes',
        checkpointable_objects=['states'],
        copy_from=[LayerAttributes],
    )):
  """Checkpointable objects and functions saved to SavedModel for an RNN.

  Attributes saved:
    Everything saved by LayerAttributes (which includes CommonEndpoints), plus:
    states: The RNN's state variables.
  """
311
312