# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/saving_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utils related to keras model saving."""
16
17import collections
18import copy
19import os
20
21from tensorflow.python.eager import def_function
22from tensorflow.python.keras import backend as K
23from tensorflow.python.keras import losses
24from tensorflow.python.keras import optimizer_v1
25from tensorflow.python.keras import optimizers
26from tensorflow.python.keras.engine import base_layer_utils
27from tensorflow.python.keras.utils import generic_utils
28from tensorflow.python.keras.utils import version_utils
29from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.util import nest
32
33
def extract_model_metrics(model):
  """Collects the metrics given to `Model.compile` into a name-keyed dict.

  Used when converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dictionary from metric name to metric instance, or `None` when the
    model's `_compile_metrics` attribute is missing or empty. If two metrics
    share a name, the later one wins.
  """
  if not getattr(model, '_compile_metrics', None):
    return None

  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  # `model.metrics` is deliberately not used: it would also contain metrics
  # registered through the `add_metric` API, which must be excluded here.
  metrics_by_name = {}
  for metric in model._compile_metric_functions:  # pylint: disable=protected-access
    metrics_by_name[metric.name] = metric
  return metrics_by_name
52
53
def model_input_signature(model, keep_original_batch_size=False):
  """Returns the input signature Keras uses when saving `model`.

  Keras requires tensor inputs to be passed as the first positional argument,
  so the signature is always a list holding one (possibly nested) object. For
  example, a model called on {'feature1': <Tensor>, 'feature2': <Tensor>}
  yields [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object.
    keep_original_batch_size: If `True`, keep the batch dimension recorded in
      the model's save spec. If `False` (the default), the batch dimension of
      the returned signature is set to `None`.

  Returns:
    `None` when `model._get_save_spec` returns `None` (presumably because the
    input shapes are not known yet). Otherwise, a list containing either a
    single TensorSpec or one nested structure of TensorSpecs. The `training`
    argument is never part of the signature.
  """
  specs = model._get_save_spec(  # pylint: disable=protected-access
      dynamic_batch=not keep_original_batch_size)
  if specs is None:
    return None
  specs = _enforce_names_consistency(specs)

  # A one-element sequence already has the required list-of-one shape. The
  # Sequence check excludes dictionaries on purpose, so a single-key dict is
  # still wrapped below.
  already_wrapped = (
      isinstance(specs, collections.abc.Sequence) and len(specs) == 1)
  return specs if already_wrapped else [specs]
87
88
def raise_model_input_error(model):
  """Raises a `ValueError` saying that `model` has no known input shapes.

  Args:
    model: The model that could not be saved. It is interpolated into the
      error message with `str.format`.

  Raises:
    ValueError: always.
  """
  message_template = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message_template.format(model))
95
96
def trace_model_call(model, input_signature=None):
  """Wraps the model's forward pass in a `tf.function` for exporting.

  Args:
    model: A Keras model.
    input_signature: Optional list of `tf.TensorSpec` objects describing the
      model inputs. When omitted, the input signature of `model.call` is used
      if `model.call` is already a `tf.function`. Otherwise the signature falls
      back to `model_input_signature(model)`.

  Returns:
    A `tf.function` with a fixed input signature. It calls the model with
    `training=False` and returns a flat dict mapping output names to output
    tensors.

  Raises:
    ValueError: if no input signature is given and none can be inferred.
  """
  if (input_signature is None and
      isinstance(model.call, def_function.Function)):
    input_signature = model.call.input_signature
  if input_signature is None:
    input_signature = model_input_signature(model)
  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # With exactly one input, Keras models are called on the bare tensor
    # instead of on a one-element list.
    if len(input_signature) == 1:
      inputs = args[0]
    else:
      inputs = list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # The exported function must always return a flat dict of outputs.
    names = model.output_names  # Populated for functional models.
    if names is None:  # Subclassed models: synthesize names from the outputs.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(outputs)
    return dict(zip(names, nest.flatten(outputs)))

  return _wrapped_model
141
142
def model_metadata(model, include_optimizer=True, require_config=True):
  """Builds the metadata dictionary stored alongside a saved Keras model.

  Args:
    model: The Keras model being saved.
    include_optimizer: Whether to also record the training configuration and
      the optimizer config.
    require_config: If `True`, a `NotImplementedError` raised by
      `model.get_config()` is propagated. If `False`, the model config is
      stored without its 'config' entry.

  Returns:
    A dict with 'keras_version', 'backend' and 'model_config' entries. It also
    has a 'training_config' entry (with a nested 'optimizer_config') when all
    of the following hold:
      - `include_optimizer` is set;
      - the model has an optimizer that is not an `optimizer_v1.TFOptimizer`;
      - `compile` has been called.

  Raises:
    NotImplementedError: if `model.get_config()` is not implemented and
      `require_config` is `True`, or if the optimizer is an
      `optimizer_v2.RestoredOptimizer` (restored from a SavedModel).
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as e:
    if require_config:
      raise e

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  if not model.optimizer or not include_optimizer:
    return metadata

  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
    # Native TF optimizers cannot be introspected, so no training config is
    # recorded; the user is told to recompile after loading.
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  # The optimizer is recorded separately under 'optimizer_config' below.
  training_config.pop('optimizer', None)
  metadata['training_config'] = _serialize_nested_config(training_config)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
190
191
def should_overwrite(filepath, overwrite):
  """Decides whether `filepath` may be written.

  Args:
    filepath: Destination path.
    overwrite: If truthy, always allow writing without asking.

  Returns:
    `True` if `overwrite` is set or no regular file exists at `filepath`.
    Otherwise, the result of `ask_to_proceed_with_overwrite(filepath)`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  # An existing file the caller did not explicitly allow us to replace:
  # defer the decision to the overwrite prompt.
  return ask_to_proceed_with_overwrite(filepath)
198
199
def compile_args_from_training_config(training_config, custom_objects=None):
  """Rebuilds `model.compile` keyword arguments from a saved training config.

  Args:
    training_config: Dict of serialized compile arguments, such as the
      'training_config' entry built by `model_metadata`. It must contain
      'optimizer_config'. The 'loss', 'metrics', 'weighted_metrics',
      'loss_weights' and 'sample_weight_mode' entries are optional.
    custom_objects: Optional dict of custom classes/functions made available
      (via `generic_utils.CustomObjectScope`) while deserializing.

  Returns:
    Dict with the keys 'optimizer', 'loss', 'metrics', 'weighted_metrics',
    'loss_weights' and 'sample_weight_mode'. Optional entries that are absent
    from `training_config` come back as `None`.

  Raises:
    KeyError: if `training_config` has no 'optimizer_config' entry.
    ValueError: if a loss or metric config has an unrecognized structure.
  """
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

  # `training_config` is a dict, so the key must be looked up. The previous
  # `hasattr(training_config, 'sample_weight_mode')` tested for an object
  # attribute, which is always False for a dict. That silently discarded any
  # saved sample_weight_mode.
  sample_weight_mode = training_config.get('sample_weight_mode', None)
  # Configs without 'loss_weights' yield None instead of raising KeyError.
  loss_weights = training_config.get('loss_weights', None)

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
239
240
241def _deserialize_nested_config(deserialize_fn, config):
242  """Deserializes arbitrary Keras `config` using `deserialize_fn`."""
243
244  def _is_single_object(obj):
245    if isinstance(obj, dict) and 'class_name' in obj:
246      return True  # Serialized Keras object.
247    if isinstance(obj, str):
248      return True  # Serialized function or string.
249    return False
250
251  if config is None:
252    return None
253  if _is_single_object(config):
254    return deserialize_fn(config)
255  elif isinstance(config, dict):
256    return {
257        k: _deserialize_nested_config(deserialize_fn, v)
258        for k, v in config.items()
259    }
260  elif isinstance(config, (tuple, list)):
261    return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]
262
263  raise ValueError('Saved configuration not understood.')
264
265
def _serialize_nested_config(config):
  """Serializes the callable leaves of a nested config structure.

  Args:
    config: A nested structure, as traversed by `nest.map_structure`, whose
      leaves may be Keras objects or functions.

  Returns:
    A structure of the same shape. Every callable leaf is replaced by
    `generic_utils.serialize_keras_object(leaf)`; all other leaves are kept
    unchanged.
  """
  return nest.map_structure(
      lambda leaf: (generic_utils.serialize_keras_object(leaf)
                    if callable(leaf) else leaf),
      config)
275
276
def _deserialize_metric(metric_config):
  """Deserializes one metric config, passing special metric names through.

  'accuracy', 'acc', 'crossentropy' and 'ce' are returned as plain strings,
  because `compile` special-cases them based on the model's output shape.
  Anything else goes to `metrics.deserialize`.
  """
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  passthrough_names = ('accuracy', 'acc', 'crossentropy', 'ce')
  if metric_config not in passthrough_names:
    return metrics_module.deserialize(metric_config)
  return metric_config
285
286
def _enforce_names_consistency(specs):
  """Clears spec names unless either all of them or none of them are named.

  Args:
    specs: A (possibly nested) structure of specs, e.g. `tf.TensorSpec`s.

  Returns:
    `specs` itself when every flattened spec has a non-`None` `name`, or when
    none do. Otherwise, a structure of deep copies in which each spec that has
    a `name` attribute gets its private `_name` set to `None`.
  """

  def _is_named(spec):
    return hasattr(spec, 'name') and spec.name is not None

  def _without_name(spec):
    unnamed = copy.deepcopy(spec)
    if hasattr(unnamed, 'name'):
      unnamed._name = None  # pylint:disable=protected-access
    return unnamed

  leaves = nest.flatten(specs)
  if not any(_is_named(leaf) for leaf in leaves):
    return specs  # Nothing is named: already consistent.
  if all(_is_named(leaf) for leaf in leaves):
    return specs  # Everything is named: already consistent.
  return nest.map_structure(_without_name, specs)
307
308
def try_build_compiled_arguments(model):
  """Best-effort build of a compiled model's loss and metric containers.

  Does nothing for v1 layers/models, or when `model.outputs` is `None`.
  Otherwise, builds `model.compiled_loss` and `model.compiled_metrics` if they
  are not built yet. Any failure is logged as a warning instead of raised.
  The warning text suggests this runs right after a model is loaded.

  Args:
    model: A compiled Keras model.
  """
  if version_utils.is_v1_layer_or_model(model) or model.outputs is None:
    return
  try:
    if not model.compiled_loss.built:
      model.compiled_loss.build(model.outputs)
    if not model.compiled_metrics.built:
      # `model.outputs` is passed for both y_true and y_pred. Presumably only
      # its structure matters for building; confirm against MetricsContainer.
      model.compiled_metrics.build(model.outputs, model.outputs)
  except Exception:  # pylint: disable=broad-except
    # Building is best-effort and must not fail the caller. Catch `Exception`
    # rather than using a bare `except`, so that KeyboardInterrupt and
    # SystemExit still propagate.
    logging.warning(
        'Compiled the loaded model, but the compiled metrics have yet to '
        'be built. `model.compiled_metrics` will be empty until you train '
        'or evaluate the model.')
322
323
def is_hdf5_filepath(filepath):
  """Returns whether `filepath` has an HDF5 / Keras file suffix.

  The check is purely suffix-based and case-sensitive: '.h5', '.keras' or
  '.hdf5'.

  Args:
    filepath: A `str` path, or an `os.PathLike` object (e.g. `pathlib.Path`),
      which is converted with `os.fspath` first.

  Returns:
    `True` if the path ends with one of the suffixes above, else `False`.
  """
  # Accept path objects too; other input types behave exactly as before.
  if isinstance(filepath, os.PathLike):
    filepath = os.fspath(filepath)
  return filepath.endswith(('.h5', '.keras', '.hdf5'))
327