xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/saved_model/load.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel deserialization."""
16
17import os
18import re
19import types
20
21from google.protobuf import message
22
23from tensorflow.python.eager import context
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.keras import backend
29from tensorflow.python.keras import regularizers
30from tensorflow.python.keras.engine import input_spec
31from tensorflow.python.keras.optimizer_v2 import optimizer_v2
32from tensorflow.python.keras.protobuf import saved_metadata_pb2
33from tensorflow.python.keras.protobuf import versions_pb2
34from tensorflow.python.keras.saving import saving_utils
35from tensorflow.python.keras.saving.saved_model import constants
36from tensorflow.python.keras.saving.saved_model import json_utils
37from tensorflow.python.keras.saving.saved_model import utils
38from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
39from tensorflow.python.keras.utils import generic_utils
40from tensorflow.python.keras.utils import metrics_utils
41from tensorflow.python.keras.utils.generic_utils import LazyLoader
42from tensorflow.python.ops.ragged import ragged_tensor
43from tensorflow.python.platform import gfile
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.saved_model import load as tf_load
46from tensorflow.python.saved_model import loader_impl
47from tensorflow.python.saved_model import nested_structure_coder
48from tensorflow.python.saved_model import revived_types
49from tensorflow.python.trackable import base as trackable
50from tensorflow.python.trackable import data_structures
51from tensorflow.python.util import compat
52from tensorflow.python.util import nest
53
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
# Each LazyLoader defers the actual `tensorflow.python.keras.*` import until
# the attribute is first accessed, breaking the keras/saving <-> keras/engine
# import cycle.
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "tensorflow.python.keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes


# SavedModel attribute names that are tracked only for serialization; they are
# deleted from loaded layers in `KerasObjectLoader.del_tracking`.
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
91
92
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled
  loss functions, and metrics are not retained. This is temporary, and
  `model.save` will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the keras metadata file exists but cannot be parsed.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain the proto DecodeError so the root cause is preserved in the
      # traceback (the original code discarded it).
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e))) from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
        if model.optimizer.get_slot_names():
          # Fixed double space in the original message ("with  a new").
          logging.warning('Your optimizer uses slots. '
                          'Slots cannot be restored from saved_model, '
                          'as a result, your model is starting with '
                          'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
188
189
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels (saved prior to TF 2.5) store the Keras metadata directly
  in the object graph proto instead of the separate keras_metadata.pb file.

  Args:
    object_graph_def: `ObjectGraphDef` proto from the SavedModel meta graph.
    metadata: `SavedMetadata` proto that the recovered Keras nodes are
      added to (mutated in place).

  Raises:
    ValueError: If a Keras object in the SavedModel carries no metadata,
      meaning the model was written with `tf.saved_model.save` rather than
      the Keras saving APIs.
  """
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      if not proto.user_object.metadata:
        # NOTE: the original message ran sentences/words together
        # ("metadata.Please", "`model.save`or"); missing spaces were added.
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)
211
212
213def _generate_object_paths(object_graph_def):
214  """Traverses through an ObjectGraphDef and builds a map of all node paths."""
215  paths = {0: 'root'}
216  nodes_to_visit = [0]
217
218  while nodes_to_visit:
219    current_node = nodes_to_visit.pop()
220    current_path = paths[current_node]
221    for reference in object_graph_def.nodes[current_node].children:
222      if reference.node_id in paths:
223        continue
224      paths[reference.node_id] = '{}.{}'.format(current_path,
225                                                reference.local_name)
226      nodes_to_visit.append(reference.node_id)
227
228  return paths
229
230
def _is_graph_network(layer):
  """Determines whether the layer is a graph network."""
  # pylint: disable=protected-access
  # Networks revived straight from the SavedModel are never treated as graph
  # networks; only Functional/Sequential instances qualify.
  if isinstance(layer, RevivedNetwork):
    return False
  if isinstance(layer, functional_lib.Functional):
    return layer._is_graph_network or isinstance(layer, models_lib.Sequential)
  return False
240
241
242class KerasObjectLoader(object):
243  """Loader that recreates Keras objects (e.g. layers, models).
244
245  Layers and models are revived from either the config or SavedModel following
246  these rules:
247  1. If object is a graph network (i.e. Sequential or Functional) then it will
248     be initialized using the structure from the config only after the children
249     layers have been created. Graph networks must be initialized with inputs
250     and outputs, so all child layers must be created beforehand.
251  2. If object's config exists and the class can be found, then revive from
252     config.
253  3. Object may have already been created if its parent was revived from config.
254     In this case, do nothing.
255  4. If nothing of the above applies, compose the various artifacts from the
256     SavedModel to create a subclassed layer or model. At this time, custom
257     metrics are not supported.
258
259  """
260
261  def __init__(self, metadata, object_graph_def):
262    self._metadata = {x.node_id: x for x in metadata.nodes}
263    self._proto = object_graph_def
264
265    self._node_paths = {node_data.node_id: node_data.node_path
266                        for node_data in metadata.nodes}
267    self.loaded_nodes = {}  # Maps node path -> loaded node
268
269    # Store all node ids that have already been traversed when tracking nodes
270    # that were recreated from the config.
271    self._traversed_nodes_from_config = set()
272
273    # Maps model id -> (blank model obj, list of child layer or their node ids)
274    # This tracks all layers in functional and sequential models. These models
275    # are only reconstructed after all of their child layers have been created.
276    self.model_layer_dependencies = {}
277    self._models_to_reconstruct = []
278
279  def del_tracking(self):
280    """Removes tracked references that are only used when loading the model."""
281    # Now that the node object has been fully loaded, and the checkpoint has
282    # been restored, the object no longer needs to track objects added from
283    # SerializedAttributes. (Note that saving a training checkpoint still
284    # functions correctly, because layers and variables are tracked separately
285    # by the Layer object.)
286    # TODO(kathywu): Instead of outright deleting these nodes (which would
287    # make restoring from a different checkpoint tricky), mark them as extra
288    # dependencies that are OK to overwrite.
289    for node in self.loaded_nodes.values():
290      node = node[0]
291      if not isinstance(node, base_layer.Layer):
292        # Loaded nodes can contain other trackable objects created when
293        # loading layers from the config, such as variables.
294        continue
295      for name in PUBLIC_ATTRIBUTES:
296        node._delete_tracking(name)  # pylint: disable=protected-access
297
298      if isinstance(node, functional_lib.Functional):
299        # Delete the temporary layer dependencies, which were used to restore
300        # the checkpointed values. When the model is live, the user can delete
301        # or add layers to the model at any time, so these layer dependencies
302        # may be obsolete.
303        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
304        for name in dependencies:
305          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
306            node._delete_tracking(name)  # pylint: disable=protected-access
307
  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config.

    Walks the object graph starting at `obj`, pairing each child tracked by
    `obj` with the corresponding SavedModel node, and records the
    (object, setter) pair in `self.loaded_nodes` so the core loader does not
    recreate it.

    Args:
      obj: Trackable object that was recreated from the config.
      proto: `SavedObject` proto of `obj` in the object graph.
      node_id: Id of `obj`'s node in the object graph.
    """
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    # Build unbuilt layers now so that their variables exist before children
    # are looked up below.
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, trackable.Trackable):
        continue
      # Choose the setter used to attach restored attributes to this child.
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      # Data structures are already fully restored; use a no-op setter so the
      # loader does not overwrite their contents.
      if isinstance(obj_child, data_structures.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter
380
381  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
382    """Load all layer nodes from the metadata."""
383    # Load metrics after models and layers, since it's likely that models
384    # and layers will create the metric when initialized (this avoids wasting
385    # time by creating objects multiple times).
386    metric_list = []
387    for node_metadata in self._metadata.values():
388      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
389        metric_list.append(node_metadata)
390        continue
391
392      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
393          node_metadata.node_id, node_metadata.identifier,
394          node_metadata.metadata)
395
396    for node_metadata in metric_list:
397      try:
398        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
399            node_metadata.node_id, node_metadata.identifier,
400            node_metadata.metadata)
401      except ValueError:
402        # Metrics are only needed when the model is compiled later. We ignore
403        # errors when trying to load custom metrics when `compile=False` until
404        # custom metrics are serialized properly (b/135550038).
405        if compile:
406          raise
407        logging.warning('Unable to restore custom metric. Please ensure that '
408                        'the layer implements `get_config` and `from_config` '
409                        'when saving. In addition, please use the '
410                        '`custom_objects` arg when calling `load_model()`.')
411
412  def _load_layer(self, node_id, identifier, metadata):
413    """Load a single layer from a SavedUserObject proto."""
414    metadata = json_utils.decode(metadata)
415
416    # If node was already created
417    if node_id in self.loaded_nodes:
418      node, setter = self.loaded_nodes[node_id]
419
420      # Revive setter requires the object to have a `_serialized_attributes`
421      # property. Add it here.
422      _maybe_add_serialized_attributes(node, metadata)
423
424      config = metadata.get('config')
425      if _is_graph_network(node) and generic_utils.validate_config(config):
426        child_nodes = self._get_child_layer_node_ids(node_id)
427        self.model_layer_dependencies[node_id] = (node, child_nodes)
428        if not child_nodes:
429          self._models_to_reconstruct.append(node_id)
430      return node, setter
431
432    # Detect whether this object can be revived from the config. If not, then
433    # revive from the SavedModel instead.
434    obj, setter = self._revive_from_config(identifier, metadata, node_id)
435    if obj is None:
436      obj, setter = revive_custom_object(identifier, metadata)
437
438    # Add an attribute that stores the extra functions/objects saved in the
439    # SavedModel. Most of these functions/objects are ignored, but some are
440    # used later in the loading process (e.g. the list of regularization
441    # losses, or the training config of compiled models).
442    _maybe_add_serialized_attributes(obj, metadata)
443    return obj, setter
444
445  def _revive_from_config(self, identifier, metadata, node_id):
446    """Revives a layer/model from config, or returns None."""
447    if identifier == constants.METRIC_IDENTIFIER:
448      obj = self._revive_metric_from_config(metadata)
449    else:
450      obj = (
451          self._revive_graph_network(identifier, metadata, node_id) or
452          self._revive_layer_or_model_from_config(metadata, node_id))
453
454    if obj is None:
455      return None, None
456
457    setter = self._config_node_setter(_revive_setter)
458    self._add_children_recreated_from_config(
459        obj, self._proto.nodes[node_id], node_id)
460    return obj, setter
461
462  def _revive_graph_network(self, identifier, metadata, node_id):
463    """Revives a graph network from config."""
464    # Determine whether the metadata contains information for reviving a
465    # functional or Sequential model.
466    config = metadata.get('config')
467    if not generic_utils.validate_config(config):
468      return None
469
470    class_name = compat.as_str(metadata['class_name'])
471    if generic_utils.get_registered_object(class_name) is not None:
472      return None
473    model_is_functional_or_sequential = (
474        metadata.get('is_graph_network', False) or
475        class_name == 'Sequential' or
476        class_name == 'Functional')
477    if not model_is_functional_or_sequential:
478      return None
479
480    # Revive functional and sequential models as blank model objects for now (
481    # must be initialized to enable setattr tracking and attribute caching).
482    # Reconstruction of the network is deferred until all of the model's layers
483    # have been revived.
484    if class_name == 'Sequential':
485      model = models_lib.Sequential(name=config['name'])
486    # The model is a custom Sequential model.
487    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
488      # Uses the custom class name, since the config does not have one.
489      model = models_lib.Sequential(name=class_name)
490    else:
491      model = models_lib.Functional(
492          inputs=[], outputs=[], name=config['name'])
493
494    # Record this model and its layers. This will later be used to reconstruct
495    # the model.
496    layers = self._get_child_layer_node_ids(node_id)
497    self.model_layer_dependencies[node_id] = (model, layers)
498    if not layers:
499      self._models_to_reconstruct.append(node_id)
500    return model
501
  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible.

    Args:
      metadata: Decoded metadata dict for this node.
      node_id: Id of the node in the object graph.

    Returns:
      The revived layer/model, or None when it cannot be deserialized or
      built from the config alone.

    Raises:
      RuntimeError: If deserialization fails for an object whose metadata
        marks it as `must_restore_from_config`.
    """
    # Check that the following requirements are met for reviving from config:
    #    1. Object can be deserialized from config.
    #    2. If the object needs to be built, then the build input shape can be
    #       found.
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError:
      # Deserialization failed (e.g. the class is not registered/known).
      if must_restore_from_config:
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name))
      else:
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      save_spec = metadata.get('save_spec')
      if save_spec is not None:
        obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj
558
559  def _revive_metric_from_config(self, metadata):
560    """Revives a metric object using the config saved in the metadata."""
561    class_name = compat.as_str(metadata['class_name'])
562    config = metadata.get('config')
563
564    if not generic_utils.validate_config(config):
565      return None
566
567    try:
568      obj = metrics.deserialize(
569          generic_utils.serialize_keras_class_and_config(class_name, config))
570    except ValueError:
571      return None
572
573    build_input_shape = metadata.get('build_input_shape')
574    if build_input_shape is not None and hasattr(obj, '_build'):
575      obj._build(build_input_shape)  # pylint: disable=protected-access
576
577    return obj
578
579  def _try_build_layer(self, obj, node_id, build_input_shape):
580    """Attempts to build the layer."""
581    if obj.built or hasattr(obj.build, '_is_default'):
582      obj.built = True
583      return True
584
585    if build_input_shape is None:
586      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)
587
588    if build_input_shape is not None:
589      obj.build(build_input_shape)
590      base_layer.Layer.build(obj, build_input_shape)
591      return True
592
593    return False
594
595  def _load_edges(self):
596    """Add edges for all nodes that are not waiting on initialization."""
597    for node_id, proto in enumerate(self._proto.nodes):
598      if node_id not in self.model_layer_dependencies:
599        self._add_object_graph_edges(proto, node_id)
600
601  def get_path(self, node_id):
602    return self._node_paths[node_id]
603
  def finalize_objects(self):
    """Finish setting up Keras objects.

    This function is executed after all objects and functions have been created.
    Call functions and losses are attached to each layer, and once all layers
    have been fully set up, graph networks are initialized.

    Subclassed models that are revived from the SavedModel are treated like
    layers, and have their call/loss functions attached here.
    """
    # Finish setting up layers and subclassed models. This step attaches call
    # functions and losses to each object, and sets model inputs/outputs.
    layers_revived_from_config = []
    layers_revived_from_saved_model = []
    for node_id, (node, _) in self.loaded_nodes.items():
      if (not isinstance(node, base_layer.Layer) or
          # Don't finalize models until all layers have finished loading.
          node_id in self.model_layer_dependencies):
        continue

      # Mark this layer as no longer blocking the reconstruction of any
      # parent graph network.
      self._unblock_model_reconstruction(node_id, node)

      # InputLayers and Metric objects are excluded from both finalization
      # lists below.
      if isinstance(node, input_layer.InputLayer):
        continue
      elif isinstance(node, metrics.Metric):
        continue

      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
        layers_revived_from_saved_model.append(node)
      else:
        layers_revived_from_config.append(node)

    _finalize_saved_model_layers(layers_revived_from_saved_model)
    _finalize_config_layers(layers_revived_from_config)

    # Initialize graph networks, now that layer dependencies have been resolved.
    self._reconstruct_all_models()
641
642  def _unblock_model_reconstruction(self, layer_id, layer):
643    """Removes layer from blocking model reconstruction."""
644    for model_id, v in self.model_layer_dependencies.items():
645      _, layers = v
646      if layer_id not in layers:
647        continue
648      layers[layers.index(layer_id)] = layer
649      if all(isinstance(x, base_layer.Layer) for x in layers):
650        self._models_to_reconstruct.append(model_id)
651
652  def _reconstruct_all_models(self):
653    """Reconstructs the network structure of all models."""
654    all_initialized_models = set()
655    while self._models_to_reconstruct:
656      model_id = self._models_to_reconstruct.pop(0)
657      all_initialized_models.add(model_id)
658      model, layers = self.model_layer_dependencies[model_id]
659      self._reconstruct_model(model_id, model, layers)
660      _finalize_config_layers([model])
661
662    if all_initialized_models != set(self.model_layer_dependencies.keys()):
663      # This should not happen.
664      uninitialized_model_ids = (
665          set(self.model_layer_dependencies.keys()) - all_initialized_models)
666      uninitialized_model_names = [
667          self.model_layer_dependencies[model_id][0].name
668          for model_id in uninitialized_model_ids]
669      raise ValueError('Error when loading from SavedModel -- the following '
670                       'models could not be initialized: {}'
671                       .format(uninitialized_model_names))
672
  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure.

    Args:
      model_id: Node id of the model in the object graph.
      model: Blank Sequential/Functional model object to initialize in place.
      layers: The model's child layers, in layer order.
    """
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      # Sequential models need an InputLayer first; create one from the config
      # when the loaded layer list does not start with one.
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          # No explicit InputLayer in the config; derive one from the first
          # layer's batch_input_shape.
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      # Re-run __init__ on the blank model with the complete layer list.
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        # Inputs still unknown; infer them via `_infer_inputs` (defined
        # elsewhere in this class) from the first layer's saved signature.
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      # Re-run __init__ with the reconstructed inputs/outputs, then attach the
      # remaining created layers to the model.
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)
714
715  def _get_child_layer_node_ids(self, node_id):
716    """Returns the node ids of each layer in a Sequential/Functional model."""
717    # Sequential and Functional track layers with names following the format
718    # "layer-N". Use this to generate the list of layers.
719    num_layers = 0
720    child_layers = {}
721    pattern = re.compile('layer-(\\d+)')
722
723    for child in self._proto.nodes[node_id].children:
724      m = pattern.match(child.local_name)
725      if m is None:
726        continue
727      layer_n = int(m.group(1))
728      num_layers = max(layer_n + 1, num_layers)
729      child_layers[layer_n] = child.node_id
730
731    ordered = []
732    for n in range(num_layers):
733      child = child_layers.get(n)
734      if child is None:
735        break
736      ordered.append(child)
737    return ordered
738
739  def _search_for_child_node(self, parent_id, path_to_child):
740    """Returns node id of child node.
741
742    A helper method for traversing the object graph proto.
743
744    As an example, say that the object graph proto in the SavedModel contains an
745    object with the following child and grandchild attributes:
746
747    `parent.child_a.child_b`
748
749    This method can be used to retrieve the node id of `child_b` using the
750    parent's node id by calling:
751
752    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.
753
754    Args:
755      parent_id: node id of parent node
756      path_to_child: list of children names.
757
758    Returns:
759      node_id of child, or None if child isn't found.
760    """
761    if not path_to_child:
762      return parent_id
763
764    for child in self._proto.nodes[parent_id].children:
765      if child.local_name == path_to_child[0]:
766        return self._search_for_child_node(child.node_id, path_to_child[1:])
767    return None
768
769  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
770    """Infers input shape of layer from SavedModel functions."""
771    call_fn_id = self._search_for_child_node(
772        layer_node_id, ['call_and_return_all_conditional_losses'])
773    if call_fn_id is None:
774      return None
775
776    concrete_functions = (
777        self._proto.nodes[call_fn_id].function.concrete_functions)
778    if not concrete_functions:
779      return None
780    call_fn_name = concrete_functions[0]
781    call_fn_proto = self._proto.concrete_functions[call_fn_name]
782    structured_input_signature = nested_structure_coder.decode_proto(
783        call_fn_proto.canonicalized_input_signature)
784    inputs = structured_input_signature[0][0]
785    if convert_to_shapes:
786      return nest.map_structure(lambda spec: spec.shape, inputs)
787    else:
788      return inputs
789
790  def _config_node_setter(self, setter):
791    """Creates edges for nodes that are recreated from config."""
792    def setattr_wrapper(obj, name, value):
793      # Avoid overwriting attributes of objects recreated from the config.
794      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
795        setter(obj, name, value)
796    return setattr_wrapper
797
798
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Args:
    layers: List of revived layer objects whose attributes have been restored
      from the SavedModel object graph.
  """
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel (
  # and not the config)
  for layer in layers:
    # Mark revived layers as built; their attributes were restored from the
    # SavedModel rather than created by a build() call.
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      # No serialized call function is available; replace `call` with a stub
      # that raises a descriptive error if the layer is ever called.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
        call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
        if not call_fn.concrete_functions:
          continue
        if call_fn.input_signature is None:
          # No explicit input signature was saved; merge the input specs of
          # the traced concrete functions instead.
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = call_fn.input_signature[0]
        layer._set_inputs(inputs)  # pylint: disable=protected-access

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access
845
846
847def _unable_to_call_layer_due_to_serialization_issue(
848    layer, *unused_args, **unused_kwargs):
849  """Replaces the `layer.call` if the layer was not fully serialized.
850
851  Keras Model/Layer serialization is relatively relaxed because SavedModels
852  are not always loaded back as keras models. Thus, when there is an issue
853  tracing a non-signature function, a warning is logged instead of raising an
854  error. This results in a SavedModel where the model's call function is saved,
855  but the internal layer call functions are not.
856
857  When deserialized with `tf.keras.models.load_model`, the internal layers
858  which do not have serialized call functions should raise an error when called.
859
860  Args:
861    layer: Layer without the serialized call function.
862
863  Raises:
864    ValueError
865  """
866
867  raise ValueError(
868      'Cannot call custom layer {} of type {}, because the call function was '
869      'not serialized to the SavedModel.'
870      'Please try one of the following methods to fix this issue:'
871      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
872      'class, and pass the object to the `custom_objects` argument when '
873      'loading the model. For more details, see: '
874      'https://www.tensorflow.org/guide/keras/save_and_serialize'
875      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
876      'and not `__call__`. The input shape and dtype will be automatically '
877      'recorded when the object is called, and used when saving. To manually '
878      'specify the input shape/dtype, decorate the call function with '
879      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
880
881
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config.

  Args:
    layers: List of layer objects that were recreated from their configs.
  """
  for layer in layers:
    # It is assumed that layers define their unconditional losses after being
    # recreated from the config and built. The exceptions to this
    # are Functional and Sequential models, which only store conditional losses
    # (losses dependent on the inputs) in the config. Unconditional losses like
    # weight regularization must be revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
    _restore_layer_metrics(layer)

    # Restore RNN layer states.
    if (isinstance(layer, recurrent.RNN) and
        layer.stateful and
        hasattr(_get_keras_attr(layer), 'states')):
      layer.states = getattr(_get_keras_attr(layer), 'states', None)
      # Re-register the restored state variables with the backend so they
      # are tracked like variables created at build time.
      for variable in nest.flatten(layer.states):
        backend.track_variable(variable)

    # Perform any layer defined finalization of the layer state.
    layer.finalize_state()
913
914
def _finalize_metric(metric):
  """Rebinds `update_state` and `result` on a revived metric object."""
  keras_api = metric.keras_api
  wrapped_update = metrics_utils.update_state_wrapper(keras_api.update_state)
  metric.update_state = types.MethodType(wrapped_update, metric)
  metric.result = keras_api.result
919
920
def _restore_layer_unconditional_losses(layer):
  """Restore unconditional losses from SavedModel."""
  keras_attr = _get_keras_attr(layer)
  if hasattr(keras_attr, 'layer_regularization_losses'):
    losses = keras_attr.layer_regularization_losses
  else:
    # Some earlier SavedModels may not have layer_regularization_losses
    # serialized separately. Fall back to using the regularization_losses
    # list if it does not exist.
    losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  for loss in losses:
    layer.add_loss(loss)
932
933
def _restore_layer_activation_loss(layer):
  """Restore activation loss from SavedModel."""
  # Use wrapped activity regularizer function if the layer's activity
  # regularizer wasn't created during initialization.
  activity_regularizer = getattr(_get_keras_attr(layer),
                                 'activity_regularizer_fn', None)
  if activity_regularizer and not layer.activity_regularizer:
    try:
      layer.activity_regularizer = activity_regularizer
    except AttributeError:
      # This may happen if a layer wrapper is saved with an activity
      # regularizer. The wrapper object's activity regularizer is unsettable.
      pass
947
948
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  Args:
    identifier: String identifying the kind of serialized Keras object
      (one of the identifiers in `constants`).
    metadata: Dict of object attributes decoded from the SavedModel metadata;
      must contain 'class_name'.

  Returns:
    The result of the revived class's `_init_from_metadata`.

  Raises:
    ValueError: If `identifier` does not map to a revivable class.
  """
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  # Maps each identifier to a (revived mixin, Keras base class) pair; the
  # concrete revived class is synthesized from the pair below.
  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)

  if parent_classes is None:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))
  # Create a subclass named after the original class, mixing the revived
  # behavior into the appropriate Keras base class.
  revived_cls = type(
      compat.as_str(metadata['class_name']), parent_classes, {})
  return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
977
978
def _restore_layer_metrics(layer):
  """Appends saved metrics to the layer, skipping names it already tracks."""
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  existing_names = {m.name for m in layer._metrics}  # pylint: disable=protected-access
  for name, metric in saved_metrics.items():
    # Metrics may be added during initialization/building of custom layers.
    if name not in existing_names:
      layer._metrics.append(metric)  # pylint: disable=protected-access
986
987
988# TODO(kathywu): Centrally define keys and functions for both  serialization and
989# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto.

    Args:
      metadata: Dict of layer attributes decoded from the SavedModel's
        Keras metadata.

    Returns:
      Tuple of (revived layer object, setter function used when attaching
      further attributes from the object graph).
    """
    # Only forward optional constructor arguments that were actually saved.
    init_args = dict(
        name=metadata['name'],
        trainable=metadata['trainable'])
    if metadata.get('dtype') is not None:
      init_args['dtype'] = metadata['dtype']
    if metadata.get('batch_input_shape') is not None:
      init_args['batch_input_shape'] = metadata['batch_input_shape']

    revived_obj = cls(**init_args)

    # The assignments below must not be recorded as checkpoint dependencies.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config
      if metadata.get('input_spec') is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})
      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      if metadata.get('_is_feature_layer') is not None:
        revived_obj._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        revived_obj.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    # Attributes restored from the SavedModel's "keras_api" object, or None
    # if they have not been restored.
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    # The config is only available if it was serialized with the metadata.
    if hasattr(self, '_config'):
      return self._config
    else:
      raise NotImplementedError
1036
1037
def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary."""
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        # Fix: `\d+` matches the numeric suffix; the original `[\d+]` was a
        # character class that also matched a single literal '+'.
        re.match(r'^layer(_with_weights)?-\d+', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.

    # Set `overwrite=True` in the case that `layer` already tracks a different
    # layer-n. This may cause variable values to not be loaded properly in the
    # original layer-n, but we already warn the users about this
    # (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)
1066
1067
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata."""
    revived_obj = cls(
        name=metadata['name'],
        dtype=metadata['dtype'],
        sparse=metadata['sparse'],
        ragged=metadata['ragged'],
        batch_input_shape=metadata['batch_input_shape'])
    # Stash the config without creating a checkpoint dependency on it.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    return self._config
1088
1089
def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure."""
  # A dict containing 'class_name' is a serialized Keras object; any other
  # dict is a plain container whose values are deserialized individually.
  if isinstance(config, dict):
    if 'class_name' in config:
      return generic_utils.deserialize_keras_object(
          config, module_objects=module_objects)
    return {
        key: recursively_deserialize_keras_object(value, module_objects)
        for key, value in config.items()
    }
  if isinstance(config, (tuple, list)):
    return [
        recursively_deserialize_keras_object(item, module_objects)
        for item in config
    ]
  raise ValueError('Unable to decode config: {}'.format(config))
1105
1106
def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: `TensorShape` or None.
    y: `TensorShape` or None.

  Returns:
    None if both inputs are None; otherwise a `TensorShape` whose dimensions
    are set only where `x` and `y` agree on a known dimension value.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is None.
    TypeError: If either input is neither None nor a `TensorShape`.
  """
  # The parentheses are required. Without them, `x is None != y is None`
  # parses as the chained comparison
  # `(x is None) and (None != y) and (y is None)`, which can never be True,
  # silently disabling this mismatch check.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    # Ranks disagree or are unknown: only a fully-unknown shape is compatible.
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
1130
1131
def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
        one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def common_spec(x, y):
    # Merge two specs into one whose shape is compatible with both; the
    # spec type and dtype are taken from `x`.
    common_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
    if isinstance(x, ragged_tensor.RaggedTensorSpec):
      return ragged_tensor.RaggedTensorSpec(common_shape, x.dtype)
    return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)

  concrete_fns = fn.concrete_functions
  # Fold the input specs of every traced concrete function together.
  merged = concrete_fns[0].structured_input_signature[0][0]
  for other in concrete_fns[1:]:
    other_inputs = other.structured_input_signature[0][0]
    merged = nest.map_structure(common_spec, merged, other_inputs)
  return merged
1155
1156
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    revived_obj = cls(name=metadata['name'])

    # Attributes revived from SerializedAttributes are stored in an
    # un-tracked dictionary (keys listed in CommonEndpoints or "keras_api"
    # for keras-specific attributes).
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            regularizer_config)
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access
1181
1182
def _set_network_attributes_from_metadata(revived_obj):
  """Sets dtype policy and trainable flag recorded in the metadata."""
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    dtype = metadata.get('dtype')
    if dtype is not None:
      revived_obj._set_dtype_policy(dtype)
    revived_obj._trainable = metadata['trainable']
    # pylint:enable=protected-access
1192
1193
def _maybe_add_serialized_attributes(layer, metadata):
  """Ensures `layer` has an un-tracked `_serialized_attributes` dict."""
  # Attributes revived from SerializedAttributes are stored in an un-tracked
  # dictionary (keys listed in CommonEndpoints or "keras_api" for
  # keras-specific attributes).
  if hasattr(layer, '_serialized_attributes'):
    return
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
1201
1202
def _get_keras_attr(layer):
  """Returns the layer's revived "keras_api" attributes, or None."""
  serialized = getattr(layer, '_serialized_attributes', {})
  return serialized.get(constants.KERAS_ATTR, None)
1206