# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel deserialization."""

import os
import re
import types

from google.protobuf import message

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.protobuf import saved_metadata_pb2
from tensorflow.python.keras.protobuf import versions_pb2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import compat
from tensorflow.python.util import nest

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "tensorflow.python.keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "tensorflow.python.keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "tensorflow.python.keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes
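

# `PUBLIC_ATTRIBUTES` collects the attribute names that only exist for
# SavedModel serialization: the checkpointable object collections and call
# functions listed in `CommonEndpoints`, plus the `keras_api` attribute.
# They are tracked while the checkpoint is being restored and are deleted
# again in `KerasObjectLoader.del_tracking`.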


PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)


def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled loss
  functions, and metrics are not retained. This is temporary, and `model.save`
  will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if gfile.Exists(path_to_metadata_pb):
    try:
      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e)))
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf_load.load(path, options=options)

  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf_load.load_partial(path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
        if (model.optimizer.get_slot_names()):
          logging.warning('Your optimizer uses slots. '
                          'Slots cannot be restored from saved_model, '
                          'as a result, your model is starting with '
                          'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not context.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS))

  return model
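

# Example usage (an illustrative sketch; the path is hypothetical):
#
#   model = tf.keras.models.load_model('/tmp/saved_model_dir')
#
# `tf.keras.models.load_model` dispatches to `load()` above when the path
# points to a SavedModel directory, so it is roughly equivalent to calling
# `load('/tmp/saved_model_dir', compile=True)` directly.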


def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef."""
  # Older SavedModels store the metadata directly in the proto instead of the
  # separate pb file.
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      if not proto.user_object.metadata:
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)


def _generate_object_paths(object_graph_def):
  """Traverses through an ObjectGraphDef and builds a map of all node paths."""
  paths = {0: 'root'}
  nodes_to_visit = [0]

  while nodes_to_visit:
    current_node = nodes_to_visit.pop()
    current_path = paths[current_node]
    for reference in object_graph_def.nodes[current_node].children:
      if reference.node_id in paths:
        continue
      paths[reference.node_id] = '{}.{}'.format(current_path,
                                                reference.local_name)
      nodes_to_visit.append(reference.node_id)

  return paths
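

# For example, given an object graph in which the root node (id 0) tracks a
# child named "layer-0" at node 1, and node 1 tracks a child named "kernel"
# at node 2, `_generate_object_paths` returns:
#
#   {0: 'root', 1: 'root.layer-0', 2: 'root.layer-0.kernel'}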


def _is_graph_network(layer):
  """Determines whether the layer is a graph network."""
  # pylint: disable=protected-access
  if isinstance(layer, RevivedNetwork):
    return False
  elif isinstance(layer, functional_lib.Functional):
    return (layer._is_graph_network or
            isinstance(layer, models_lib.Sequential))
  return False


class KerasObjectLoader(object):
  """Loader that recreates Keras objects (e.g. layers, models).

  Layers and models are revived from either the config or SavedModel following
  these rules:
  1. If object is a graph network (i.e. Sequential or Functional) then it will
     be initialized using the structure from the config only after the children
     layers have been created. Graph networks must be initialized with inputs
     and outputs, so all child layers must be created beforehand.
  2. If object's config exists and the class can be found, then revive from
     config.
  3. Object may have already been created if its parent was revived from config.
     In this case, do nothing.
  4. If nothing of the above applies, compose the various artifacts from the
     SavedModel to create a subclassed layer or model. At this time, custom
     metrics are not supported.
  """

  def __init__(self, metadata, object_graph_def):
    self._metadata = {x.node_id: x for x in metadata.nodes}
    self._proto = object_graph_def

    self._node_paths = {node_data.node_id: node_data.node_path
                        for node_data in metadata.nodes}
    self.loaded_nodes = {}  # Maps node id -> (loaded object, setter)

    # Store all node ids that have already been traversed when tracking nodes
    # that were recreated from the config.
    self._traversed_nodes_from_config = set()

    # Maps model id -> (blank model obj, list of child layers or their node ids)
    # This tracks all layers in functional and sequential models. These models
    # are only reconstructed after all of their child layers have been created.
    self.model_layer_dependencies = {}
    self._models_to_reconstruct = []

  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        node._delete_tracking(name)  # pylint: disable=protected-access

      if isinstance(node, functional_lib.Functional):
        # Delete the temporary layer dependencies, which were used to restore
        # the checkpointed values. When the model is live, the user can delete
        # or add layers to the model at any time, so these layer dependencies
        # may be obsolete.
        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
        for name in dependencies:
          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
            node._delete_tracking(name)  # pylint: disable=protected-access

  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config."""
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, trackable.Trackable):
        continue
      if (child_proto.user_object.identifier in
          revived_types.registered_identifiers()):
        setter = revived_types.get_setter(child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
      # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      if isinstance(obj_child, data_structures.TrackableDataStructure):
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter

  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
    """Load all layer nodes from the metadata."""
    # Load metrics after models and layers, since it's likely that models
    # and layers will create the metric when initialized (this avoids wasting
    # time by creating objects multiple times).
    metric_list = []
    for node_metadata in self._metadata.values():
      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
        metric_list.append(node_metadata)
        continue

      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)

    for node_metadata in metric_list:
      try:
        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
            node_metadata.node_id, node_metadata.identifier,
            node_metadata.metadata)
      except ValueError:
        # Metrics are only needed when the model is compiled later. We ignore
        # errors when trying to load custom metrics when `compile=False` until
        # custom metrics are serialized properly (b/135550038).
        if compile:
          raise
        logging.warning('Unable to restore custom metric. Please ensure that '
                        'the layer implements `get_config` and `from_config` '
                        'when saving. In addition, please use the '
                        '`custom_objects` arg when calling `load_model()`.')

  def _load_layer(self, node_id, identifier, metadata):
    """Load a single layer from a SavedUserObject proto."""
    metadata = json_utils.decode(metadata)

    # If node was already created
    if node_id in self.loaded_nodes:
      node, setter = self.loaded_nodes[node_id]

      # Revive setter requires the object to have a `_serialized_attributes`
      # property. Add it here.
      _maybe_add_serialized_attributes(node, metadata)

      config = metadata.get('config')
      if _is_graph_network(node) and generic_utils.validate_config(config):
        child_nodes = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (node, child_nodes)
        if not child_nodes:
          self._models_to_reconstruct.append(node_id)
      return node, setter

    # Detect whether this object can be revived from the config. If not, then
    # revive from the SavedModel instead.
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
    if obj is None:
      obj, setter = revive_custom_object(identifier, metadata)

    # Add an attribute that stores the extra functions/objects saved in the
    # SavedModel. Most of these functions/objects are ignored, but some are
    # used later in the loading process (e.g. the list of regularization
    # losses, or the training config of compiled models).
    _maybe_add_serialized_attributes(obj, metadata)
    return obj, setter

  def _revive_from_config(self, identifier, metadata, node_id):
    """Revives a layer/model from config, or returns None."""
    if identifier == constants.METRIC_IDENTIFIER:
      obj = self._revive_metric_from_config(metadata)
    else:
      obj = (
          self._revive_graph_network(identifier, metadata, node_id) or
          self._revive_layer_or_model_from_config(metadata, node_id))

    if obj is None:
      return None, None

    setter = self._config_node_setter(_revive_setter)
    self._add_children_recreated_from_config(
        obj, self._proto.nodes[node_id], node_id)
    return obj, setter

  def _revive_graph_network(self, identifier, metadata, node_id):
    """Revives a graph network from config."""
    # Determine whether the metadata contains information for reviving a
    # functional or Sequential model.
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = compat.as_str(metadata['class_name'])
    if generic_utils.get_registered_object(class_name) is not None:
      return None
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        class_name == 'Sequential' or
        class_name == 'Functional')
    if not model_is_functional_or_sequential:
      return None

    # Revive functional and sequential models as blank model objects for now (
    # must be initialized to enable setattr tracking and attribute caching).
    # Reconstruction of the network is deferred until all of the model's layers
    # have been revived.
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    # The model is a custom Sequential model.
    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
      # Uses the custom class name, since the config does not have one.
      model = models_lib.Sequential(name=class_name)
    else:
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
    layers = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, layers)
    if not layers:
      self._models_to_reconstruct.append(node_id)
    return model

  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible."""
    # Check that the following requirements are met for reviving from config:
    #    1. Object can be deserialized from config.
    #    2. If the object needs to be built, then the build input shape can be
    #       found.
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError:
      if must_restore_from_config:
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name))
      else:
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      save_spec = metadata.get('save_spec')
      if save_spec is not None:
        obj._set_save_spec(save_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj

  def _revive_metric_from_config(self, metadata):
    """Revives a metric object using the config saved in the metadata."""
    class_name = compat.as_str(metadata['class_name'])
    config = metadata.get('config')

    if not generic_utils.validate_config(config):
      return None

    try:
      obj = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    build_input_shape = metadata.get('build_input_shape')
    if build_input_shape is not None and hasattr(obj, '_build'):
      obj._build(build_input_shape)  # pylint: disable=protected-access

    return obj

  def _try_build_layer(self, obj, node_id, build_input_shape):
    """Attempts to build the layer."""
    if obj.built or hasattr(obj.build, '_is_default'):
      obj.built = True
      return True

    if build_input_shape is None:
      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)

    if build_input_shape is not None:
      obj.build(build_input_shape)
      base_layer.Layer.build(obj, build_input_shape)
      return True

    return False

  def _load_edges(self):
    """Add edges for all nodes that are not waiting on initialization."""
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id not in self.model_layer_dependencies:
        self._add_object_graph_edges(proto, node_id)

  def get_path(self, node_id):
    return self._node_paths[node_id]

  def finalize_objects(self):
    """Finish setting up Keras objects.

    This function is executed after all objects and functions have been created.
    Call functions and losses are attached to each layer, and once all layers
    have been fully set up, graph networks are initialized.

    Subclassed models that are revived from the SavedModel are treated like
    layers, and have their call/loss functions attached here.
    """
    # Finish setting up layers and subclassed models. This step attaches call
    # functions and losses to each object, and sets model inputs/outputs.
    layers_revived_from_config = []
    layers_revived_from_saved_model = []
    for node_id, (node, _) in self.loaded_nodes.items():
      if (not isinstance(node, base_layer.Layer) or
          # Don't finalize models until all layers have finished loading.
          node_id in self.model_layer_dependencies):
        continue

      self._unblock_model_reconstruction(node_id, node)

      if isinstance(node, input_layer.InputLayer):
        continue
      elif isinstance(node, metrics.Metric):
        continue

      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
        layers_revived_from_saved_model.append(node)
      else:
        layers_revived_from_config.append(node)

    _finalize_saved_model_layers(layers_revived_from_saved_model)
    _finalize_config_layers(layers_revived_from_config)

    # Initialize graph networks, now that layer dependencies have been resolved.
    self._reconstruct_all_models()

  def _unblock_model_reconstruction(self, layer_id, layer):
    """Removes layer from blocking model reconstruction."""
    for model_id, v in self.model_layer_dependencies.items():
      _, layers = v
      if layer_id not in layers:
        continue
      layers[layers.index(layer_id)] = layer
      if all(isinstance(x, base_layer.Layer) for x in layers):
        self._models_to_reconstruct.append(model_id)

  def _reconstruct_all_models(self):
    """Reconstructs the network structure of all models."""
    all_initialized_models = set()
    while self._models_to_reconstruct:
      model_id = self._models_to_reconstruct.pop(0)
      all_initialized_models.add(model_id)
      model, layers = self.model_layer_dependencies[model_id]
      self._reconstruct_model(model_id, model, layers)
      _finalize_config_layers([model])

    if all_initialized_models != set(self.model_layer_dependencies.keys()):
      # This should not happen.
      uninitialized_model_ids = (
          set(self.model_layer_dependencies.keys()) - all_initialized_models)
      uninitialized_model_names = [
          self.model_layer_dependencies[model_id][0].name
          for model_id in uninitialized_model_ids]
      raise ValueError('Error when loading from SavedModel -- the following '
                       'models could not be initialized: {}'
                       .format(uninitialized_model_names))

  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure."""
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

  def _get_child_layer_node_ids(self, node_id):
    """Returns the node ids of each layer in a Sequential/Functional model."""
    # Sequential and Functional track layers with names following the format
    # "layer-N". Use this to generate the list of layers.
    num_layers = 0
    child_layers = {}
    pattern = re.compile('layer-(\\d+)')

    for child in self._proto.nodes[node_id].children:
      m = pattern.match(child.local_name)
      if m is None:
        continue
      layer_n = int(m.group(1))
      num_layers = max(layer_n + 1, num_layers)
      child_layers[layer_n] = child.node_id

    ordered = []
    for n in range(num_layers):
      child = child_layers.get(n)
      if child is None:
        break
      ordered.append(child)
    return ordered

  def _search_for_child_node(self, parent_id, path_to_child):
    """Returns node id of child node.

    A helper method for traversing the object graph proto.

    As an example, say that the object graph proto in the SavedModel contains an
    object with the following child and grandchild attributes:

    `parent.child_a.child_b`

    This method can be used to retrieve the node id of `child_b` using the
    parent's node id by calling:

    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.

    Args:
      parent_id: node id of parent node
      path_to_child: list of children names.

    Returns:
      node_id of child, or None if child isn't found.
    """
    if not path_to_child:
      return parent_id

    for child in self._proto.nodes[parent_id].children:
      if child.local_name == path_to_child[0]:
        return self._search_for_child_node(child.node_id, path_to_child[1:])
    return None

  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infers input shape of layer from SavedModel functions."""
    call_fn_id = self._search_for_child_node(
        layer_node_id, ['call_and_return_all_conditional_losses'])
    if call_fn_id is None:
      return None

    concrete_functions = (
        self._proto.nodes[call_fn_id].function.concrete_functions)
    if not concrete_functions:
      return None
    call_fn_name = concrete_functions[0]
    call_fn_proto = self._proto.concrete_functions[call_fn_name]
    structured_input_signature = nested_structure_coder.decode_proto(
        call_fn_proto.canonicalized_input_signature)
    inputs = structured_input_signature[0][0]
    if convert_to_shapes:
      return nest.map_structure(lambda spec: spec.shape, inputs)
    else:
      return inputs

  def _config_node_setter(self, setter):
    """Creates edges for nodes that are recreated from config."""
    def setattr_wrapper(obj, name, value):
      # Avoid overwriting attributes of objects recreated from the config.
      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
        setter(obj, name, value)
    return setattr_wrapper


def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel."""
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel (
  # and not the config)
  for layer in layers:
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      if hasattr(_get_keras_attr(layer), 'call_and_return_conditional_losses'):
        call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
        if not call_fn.concrete_functions:
          continue
        if call_fn.input_signature is None:
          inputs = infer_inputs_from_restored_call_function(call_fn)
        else:
          inputs = call_fn.input_signature[0]
        layer._set_inputs(inputs)  # pylint: disable=protected-access

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)

  # pylint: enable=protected-access
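

# If a layer's call function could not be traced when the model was saved,
# `layer.call` is replaced with the fallback defined below, so the problem
# surfaces when the layer is called rather than when the SavedModel is loaded.
# Illustrative (hypothetical) example:
#
#   loaded = load('/tmp/saved_model_dir')  # Succeeds.
#   loaded.layers[1](inputs)               # Raises ValueError with guidance.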


def _unable_to_call_layer_due_to_serialization_issue(
    layer, *unused_args, **unused_kwargs):
  """Replaces the `layer.call` if the layer was not fully serialized.

  Keras Model/Layer serialization is relatively relaxed because SavedModels
  are not always loaded back as keras models. Thus, when there is an issue
  tracing a non-signature function, a warning is logged instead of raising an
  error. This results in a SavedModel where the model's call function is saved,
  but the internal layer call functions are not.

  When deserialized with `tf.keras.models.load_model`, the internal layers
  which do not have serialized call functions should raise an error when called.

  Args:
    layer: Layer without the serialized call function.

  Raises:
    ValueError
  """

  raise ValueError(
      'Cannot call custom layer {} of type {}, because the call function was '
      'not serialized to the SavedModel. '
      'Please try one of the following methods to fix this issue:'
      '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
      'class, and pass the object to the `custom_objects` argument when '
      'loading the model. For more details, see: '
      'https://www.tensorflow.org/guide/keras/save_and_serialize'
      '\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
      'and not `__call__`. The input shape and dtype will be automatically '
      'recorded when the object is called, and used when saving. To manually '
      'specify the input shape/dtype, decorate the call function with '
      '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))


def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config."""
  for layer in layers:
    # It is assumed that layers define their unconditional losses after being
    # recreated from the config and built. The exceptions to this
    # are Functional and Sequential models, which only store conditional losses
    # (losses dependent on the inputs) in the config. Unconditional losses like
    # weight regularization must be revived from the SavedModel.
    if _is_graph_network(layer):
      _restore_layer_unconditional_losses(layer)

    # Some layers, like Dense, record their activation loss function in the
    # config. However, not all layers do this, so the activation loss may be
    # missing when restored from the config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure consistent
    # loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(layer)

    # Restore metrics list.
    _restore_layer_metrics(layer)

    # Restore RNN layer states.
    if (isinstance(layer, recurrent.RNN) and
        layer.stateful and
        hasattr(_get_keras_attr(layer), 'states')):
      layer.states = getattr(_get_keras_attr(layer), 'states', None)
      for variable in nest.flatten(layer.states):
        backend.track_variable(variable)

    # Perform any layer defined finalization of the layer state.
    layer.finalize_state()


def _finalize_metric(metric):
  metric.update_state = types.MethodType(metrics_utils.update_state_wrapper(
      metric.keras_api.update_state), metric)
  metric.result = metric.keras_api.result


def _restore_layer_unconditional_losses(layer):
  """Restore unconditional losses from SavedModel."""
  if hasattr(_get_keras_attr(layer), 'layer_regularization_losses'):
    losses = getattr(_get_keras_attr(layer), 'layer_regularization_losses', [])
  else:
    # Some earlier SavedModels may not have layer_regularization_losses
    # serialized separately. Fall back to using the regularization_losses
    # list if it does not exist.
    losses = layer._serialized_attributes.get('regularization_losses', [])  # pylint: disable=protected-access
  for loss in losses:
    layer.add_loss(loss)


def _restore_layer_activation_loss(layer):
  """Restore activation loss from SavedModel."""
  # Use wrapped activity regularizer function if the layer's activity
  # regularizer wasn't created during initialization.
  activity_regularizer = getattr(_get_keras_attr(layer),
                                 'activity_regularizer_fn', None)
  if activity_regularizer and not layer.activity_regularizer:
    try:
      layer.activity_regularizer = activity_regularizer
    except AttributeError:
      # This may happen if a layer wrapper is saved with an activity
      # regularizer. The wrapper object's activity regularizer is unsettable.
      pass


def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel."""
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)

  if parent_classes is not None:
    parent_classes = revived_classes[identifier]
    revived_cls = type(
        compat.as_str(metadata['class_name']), parent_classes, {})
    return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
  else:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))


def _restore_layer_metrics(layer):
  metrics_list = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  layer_metrics = {m.name: m for m in layer._metrics}  # pylint: disable=protected-access
  for name, metric in metrics_list.items():
    if name not in layer_metrics:
      # Metrics may be added during initialization/building of custom layers.
      layer._metrics.append(metric)  # pylint: disable=protected-access
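

# For example, a subclassed layer saved with `class_name` "MyDense" (a
# hypothetical name used only for illustration) and the generic layer
# identifier is revived, when it cannot be recreated from its config, as a
# dynamically created class roughly equivalent to:
#
#   revived_cls = type('MyDense', (RevivedLayer, base_layer.Layer), {})
#   layer, setter = revived_cls._init_from_metadata(metadata)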


# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto."""
    init_args = dict(
        name=metadata['name'],
        trainable=metadata['trainable'])
    if metadata.get('dtype') is not None:
      init_args['dtype'] = metadata['dtype']
    if metadata.get('batch_input_shape') is not None:
      init_args['batch_input_shape'] = metadata['batch_input_shape']

    revived_obj = cls(**init_args)

    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config
      if metadata.get('input_spec') is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})
      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      if metadata.get('_is_feature_layer') is not None:
        revived_obj._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        revived_obj.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    if hasattr(self, '_config'):
      return self._config
    else:
      raise NotImplementedError


def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary."""
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, trackable.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
  elif (isinstance(layer, functional_lib.Functional) and
        re.match(r'^layer(_with_weights)?-[\d+]', name) is not None):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.

    # Set `overwrite=True` in the case that `layer` already tracks a different
    # layer-n. This may cause variable values to not be loaded properly in the
    # original layer-n, but we already warn the users about this
    # (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is not None:
    # Don't overwrite already defined attributes.
    pass
  else:
    setattr(layer, name, value)


class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata."""
    init_args = dict(
        name=metadata['name'],
        dtype=metadata['dtype'],
        sparse=metadata['sparse'],
        ragged=metadata['ragged'],
        batch_input_shape=metadata['batch_input_shape'])
    revived_obj = cls(**init_args)
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    return self._config


def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure."""
  if isinstance(config, dict):
    if 'class_name' in config:
      return generic_utils.deserialize_keras_object(
          config, module_objects=module_objects)
    else:
      return {key: recursively_deserialize_keras_object(config[key],
                                                        module_objects)
              for key in config}
  if isinstance(config, (tuple, list)):
    return [recursively_deserialize_keras_object(x, module_objects)
            for x in config]
  else:
    raise ValueError('Unable to decode config: {}'.format(config))


def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`."""
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
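

# For example:
#
#   get_common_shape(tensor_shape.TensorShape([2, 3]),
#                    tensor_shape.TensorShape([5, 3]))  # TensorShape([None, 3])
#   get_common_shape(tensor_shape.TensorShape([2, 3]),
#                    tensor_shape.TensorShape([2]))     # TensorShape(None)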


def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
      one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs.
  """
  def common_spec(x, y):
    common_shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensorSpec(common_shape, x.dtype)
    elif isinstance(x, ragged_tensor.RaggedTensorSpec):
      return ragged_tensor.RaggedTensorSpec(common_shape, x.dtype)
    return tensor_spec.TensorSpec(common_shape, x.dtype, x.name)

  spec = fn.concrete_functions[0].structured_input_signature[0][0]
  for concrete in fn.concrete_functions[1:]:
    spec2 = concrete.structured_input_signature[0][0]
    spec = nest.map_structure(common_spec, spec, spec2)
  return spec


class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    revived_obj = cls(name=metadata['name'])

    # Store attributes revived from SerializedAttributes in an un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        revived_obj._config = config

      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access


def _set_network_attributes_from_metadata(revived_obj):
  """Sets attributes recorded in the metadata."""
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    metadata = revived_obj._serialized_attributes['metadata']
    if metadata.get('dtype') is not None:
      revived_obj._set_dtype_policy(metadata['dtype'])
    revived_obj._trainable = metadata['trainable']
    # pylint:enable=protected-access


def _maybe_add_serialized_attributes(layer, metadata):
  # Store attributes revived from SerializedAttributes in an un-tracked
  # dictionary. The attributes are the ones listed in CommonEndpoints or
  # "keras_api" for keras-specific attributes.
  if not hasattr(layer, '_serialized_attributes'):
    with utils.no_automatic_dependency_tracking_scope(layer):
      layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access


def _get_keras_attr(layer):
  return getattr(layer, '_serialized_attributes', {}).get(constants.KERAS_ATTR,
                                                          None)
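

# The `_serialized_attributes` dictionary used by the helpers above is
# populated during loading and, as a rough sketch, looks like:
#
#   layer._serialized_attributes = {
#       'metadata': ...,    # decoded metadata saved for the layer
#       'keras_api': ...,   # restored object holding attributes such as
#                           # call_and_return_conditional_losses and
#                           # layer_metrics
#   }
#
# `_get_keras_attr(layer)` returns the 'keras_api' entry, or None when it is
# missing.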