1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Exports a SavedModel from a Trackable Python object.""" 16 17import collections 18import os 19import re 20import sys 21import traceback 22 23from absl import logging 24 25from tensorflow.core.config import flags 26from tensorflow.core.framework import function_pb2 27from tensorflow.core.framework import versions_pb2 28from tensorflow.core.protobuf import meta_graph_pb2 29from tensorflow.core.protobuf import saved_model_pb2 30from tensorflow.core.protobuf import saved_object_graph_pb2 31from tensorflow.python.checkpoint import checkpoint 32from tensorflow.python.checkpoint import checkpoint_options 33from tensorflow.python.checkpoint import functional_saver 34from tensorflow.python.checkpoint import graph_view 35from tensorflow.python.checkpoint import save_util_v1 36from tensorflow.python.checkpoint import util as checkpoint_util 37from tensorflow.python.eager import context 38from tensorflow.python.eager import def_function 39from tensorflow.python.eager import function as defun 40from tensorflow.python.eager import function_saved_model_utils 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import error_interpolation 43from tensorflow.python.framework import errors 44from tensorflow.python.framework import function as framework_fn 45from tensorflow.python.framework import meta_graph 46from tensorflow.python.framework import ops 47from tensorflow.python.framework import tensor_util 48from tensorflow.python.framework import versions 49from tensorflow.python.lib.io import file_io 50from tensorflow.python.ops import array_ops 51from tensorflow.python.ops import control_flow_ops 52from tensorflow.python.ops import resource_variable_ops 53from tensorflow.python.saved_model import builder_impl 54from tensorflow.python.saved_model import function_serialization 55from tensorflow.python.saved_model import pywrap_saved_model 56from tensorflow.python.saved_model import registration 57from tensorflow.python.saved_model import revived_types 58from tensorflow.python.saved_model import save_context 59from tensorflow.python.saved_model import save_options 60from tensorflow.python.saved_model import signature_constants 61from tensorflow.python.saved_model import signature_def_utils 62from tensorflow.python.saved_model import signature_serialization 63from tensorflow.python.saved_model import tag_constants 64from tensorflow.python.saved_model import tracing_utils 65from tensorflow.python.saved_model import utils_impl 66from tensorflow.python.saved_model.pywrap_saved_model import constants 67from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting 68from tensorflow.python.saved_model.pywrap_saved_model import metrics 69from tensorflow.python.trackable import asset 70from tensorflow.python.trackable import base 71from tensorflow.python.trackable import resource 72from tensorflow.python.trackable import trackable_utils 73from tensorflow.python.training.saving import saveable_object_util 74from tensorflow.python.util import compat 75from tensorflow.python.util import object_identity 76from tensorflow.python.util.tf_export import tf_export 77 78_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant)) 79 80# Container for tensors captured from external functions. 81_CapturedTensor = collections.namedtuple("_CapturedTensor", 82 ["name", "concrete_function"]) 83 84# Number of untraced functions to display to user in warning message. 85_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5 86 87# API label for SavedModel metrics. 88_SAVE_V2_LABEL = "save_v2" 89 90 91class _AugmentedGraphView(graph_view.ObjectGraphView): 92 """An extendable graph which also tracks functions attached to objects. 93 94 Extensions through `add_object` appear in the object graph and any checkpoints 95 generated from it, even if they are not dependencies of the node they were 96 attached to in the saving program. For example a `.signatures` attribute is 97 added to exported SavedModel root objects without modifying the root object 98 itself. 99 100 Also tracks functions attached to objects in the graph, through the caching 101 `_list_functions` method. Enumerating functions only through this method 102 ensures that we get a consistent view of functions, even if object attributes 103 create new functions every time they are accessed. 104 """ 105 106 def __init__(self, root): 107 super(_AugmentedGraphView, self).__init__(root) 108 109 # Cache the results of `GraphView.list_children()` to ensure that the 110 # `Trackable` children are gathered exactly once. 111 self._children_cache = object_identity.ObjectIdentityDictionary() 112 113 # Cache shared between objects in the same object graph. This is passed to 114 # `Trackable._trackable_children()`. 115 self._serialization_cache = object_identity.ObjectIdentityDictionary() 116 117 # Maps functions -> wrapped functions that capture non-cached variables. 118 self._wrapped_functions = {} 119 120 self.untraced_functions = [] 121 122 def set_signature(self, signature_map, wrapped_functions): 123 """Attach signature to the root object. 124 125 Args: 126 signature_map: An object that contains signature functions. 127 wrapped_functions: A dictionary mapping functions to functions that are 128 guaranteed to not capture cached variables (functions that capture 129 cached variables can't be saved). 130 """ 131 self.list_children(self.root) 132 # Overrides existing dependency. 133 name = signature_serialization.SIGNATURE_ATTRIBUTE_NAME 134 self._children_cache[self.root][name] = signature_map 135 self._wrapped_functions.update(wrapped_functions) 136 137 def _breadth_first_traversal(self): 138 """Returns all trackable objects in the SavedObjectGraph.""" 139 # This method is overriden to merge all equivalent constant tensors and 140 # Assets in the object graph. 141 142 trackable_objects, _ = ( 143 super(_AugmentedGraphView, self)._breadth_first_traversal()) 144 145 asset_paths = object_identity.ObjectIdentityDictionary() 146 constant_captures = object_identity.ObjectIdentityDictionary() 147 for obj in trackable_objects: 148 if isinstance(obj, asset.Asset): 149 asset_paths[obj.asset_path] = obj 150 if isinstance(obj, function_saved_model_utils.TrackableConstant): 151 constant_captures[obj.capture] = obj 152 153 def _get_merged_trackable(x): 154 if isinstance(x, asset.Asset): 155 return asset_paths[x.asset_path] 156 if isinstance(x, function_saved_model_utils.TrackableConstant): 157 if x.capture in asset_paths: 158 return asset_paths[x.capture] 159 else: 160 return constant_captures[x.capture] 161 return x 162 163 for obj in list(self._children_cache.keys()): 164 if _get_merged_trackable(obj) is not obj: 165 del self._children_cache[obj] 166 continue 167 for name, child in self._children_cache[obj].items(): 168 self._children_cache[obj][name] = _get_merged_trackable(child) 169 170 return super(_AugmentedGraphView, self)._breadth_first_traversal() 171 172 def list_children(self, obj): 173 """Lists children of `obj` for SavedModel.""" 174 if obj not in self._children_cache: 175 children = self._children_cache[obj] = {} 176 177 for name, child in super(_AugmentedGraphView, self).list_children( 178 obj, 179 save_type=base.SaveType.SAVEDMODEL, 180 cache=self._serialization_cache): 181 if isinstance(child, defun.ConcreteFunction): 182 child = self._maybe_uncache_variable_captures(child) 183 children[name] = child 184 185 # Keep track of untraced functions for later reporting to the user. 186 if isinstance(obj, def_function.Function) and not children: 187 self.untraced_functions.append(obj.name) 188 189 for name, child in self._children_cache[obj].items(): 190 yield base.TrackableReference(name, child) 191 192 def get_child(self, obj, name): 193 return self._children_cache[obj][name] 194 195 def _maybe_uncache_variable_captures(self, concrete_function): 196 if concrete_function in self._wrapped_functions: 197 return self._wrapped_functions[concrete_function] 198 for capture in concrete_function.captured_inputs: 199 if hasattr(capture, "_cached_variable"): 200 if concrete_function not in self._wrapped_functions: 201 wrapped = self._wrapped_functions[concrete_function] = ( 202 function_serialization.wrap_cached_variables(concrete_function)) 203 return wrapped 204 return concrete_function 205 206 def list_dependencies(self, obj): 207 """Yields `Trackables` that must be loaded before `obj`. 208 209 Dependencies and children are both dictionaries of `Trackables`. Children 210 define the object graph structure (used in both checkpoints and SavedModel), 211 while dependency defines the order used to load the SavedModel 212 213 Args: 214 obj: A `Trackable` object 215 216 Yields: 217 Tuple of dependency names and trackable objects. 218 219 Raises: 220 TypeError: if any of the returned dependencies are not instances of 221 `Trackable`. 222 """ 223 if obj not in self._children_cache: 224 # Slot variables do not appear in the children_cache. 225 children = {} 226 else: 227 children = self._children_cache[obj] 228 for name, dep in obj._deserialization_dependencies(children).items(): # pylint: disable=protected-access 229 if not isinstance(dep, base.Trackable): 230 raise TypeError( 231 f"The dependency of type {type(dep)} is not an instance `Trackable`" 232 ", and can't be saved to SavedModel. Please check the " 233 "implementation of `_deserialization_dependencies` in the parent " 234 f"object {obj}.") 235 yield name, dep 236 237 238class _SaveableView(object): 239 """Provides a frozen view over a trackable root. 240 241 This class helps to create a single stable view over an object to save. The 242 saving code should access properties and functions via this class and not via 243 the original object as there are cases where an object construct their 244 trackable attributes and functions dynamically per call and will yield 245 different objects if invoked more than once. 246 247 Changes to the graph, for example adding objects, must happen in 248 `augmented_graph_view` (an `_AugmentedGraphView`) before the `_SaveableView` 249 is constructed. Changes after the `_SaveableView` has been constructed will be 250 ignored. 251 """ 252 253 def __init__(self, augmented_graph_view, options): 254 """Initializes a SaveableView. 255 256 Args: 257 augmented_graph_view: A GraphView object. 258 options: A SaveOptions instance. 259 """ 260 261 self.augmented_graph_view = augmented_graph_view 262 self._options = options 263 264 (self._trackable_objects, self.node_paths, self.node_ids, 265 self._slot_variables, self.object_names) = ( 266 checkpoint_util.objects_ids_and_slot_variables_and_paths( 267 self.augmented_graph_view)) 268 269 untraced_functions = self.augmented_graph_view.untraced_functions 270 if untraced_functions: 271 logging.warning( 272 "Found untraced functions such as %s while saving (showing %d of %d)." 273 " These functions will not be directly callable after loading.", 274 ", ".join(untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]), 275 min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(untraced_functions)), 276 len(untraced_functions)) 277 278 self._initialize_save_and_restore_functions() 279 self._initialize_nodes_and_concrete_functions() 280 281 self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() 282 283 def _initialize_save_and_restore_functions(self): 284 """Generates all checkpoint save/restore functions. 285 286 The save and restore functions are generated in the eager context (or in the 287 user's Graph/Session) before being copied to the exported GraphDef. These 288 functions record the ops for saving/restoring the entire object or 289 individual objects (e.g. variables and hash tables). 290 291 The global save and restore functions are generated for compatibility with 292 TF1 and loading from C++, and is saved in the `MetaGraphDef.saver_def`. 293 294 The individual functions are generated for the Python TF2 use case, where 295 users use the loaded SavedModel as-is, or compose new models using parts 296 of the object loaded from the SavedModel. These functions are recorded in 297 the `saveable_objects` map in the `SavedObject` proto. 298 """ 299 checkpoint_factory_map, registered_savers = ( 300 save_util_v1.get_checkpoint_factories_and_keys(self.object_names)) 301 self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary() 302 for saver_name, trackables in registered_savers.items(): 303 for trackable in trackables.values(): 304 self._obj_to_registered_saver[trackable] = saver_name 305 self._saveable_objects_map = ( 306 _gen_save_and_restore_functions(checkpoint_factory_map)) 307 308 def _initialize_nodes_and_concrete_functions(self): 309 """Creates graph with nodes for trackable objects and functions. 310 311 Adds functions for each trackable object to `self.nodes` and associated 312 concrete functions to `self.concrete_functions` for serialization. 313 """ 314 self.nodes = list(self._trackable_objects) 315 self.gradient_functions = [] 316 self.gradient_defs = [] 317 318 for obj in self.nodes: 319 if obj in self._saveable_objects_map: 320 for save_fn, restore_fn in self._saveable_objects_map[obj].values(): 321 self.node_ids[save_fn] = len(self.nodes) 322 self.nodes.append(save_fn) 323 324 self.node_ids[restore_fn] = len(self.nodes) 325 self.nodes.append(restore_fn) 326 327 self.concrete_functions = [ 328 obj for obj in self.nodes if isinstance(obj, defun.ConcreteFunction) 329 ] 330 331 @property 332 def concrete_and_gradient_functions(self): 333 return self.concrete_functions + self.gradient_functions 334 335 @property 336 def root(self): 337 return self.nodes[0] 338 339 def fill_object_graph_proto(self, proto): 340 """Populate the nodes, children and slot_variables of a SavedObjectGraph.""" 341 for node_id, node in enumerate(self.nodes): 342 assert self.node_ids[node] == node_id 343 object_proto = proto.nodes.add() 344 object_proto.slot_variables.extend(self._slot_variables.get(node, ())) 345 if isinstance(node, _CapturedTensor): 346 continue 347 for child in self.augmented_graph_view.list_children(node): 348 child_proto = object_proto.children.add() 349 child_proto.node_id = self.node_ids[child.ref] 350 child_proto.local_name = child.name 351 for name, ref in self.augmented_graph_view.list_dependencies(node): 352 child_proto = object_proto.dependencies.add() 353 child_proto.node_id = self.node_ids[ref] 354 child_proto.local_name = name 355 356 if node in self._saveable_objects_map: 357 assert node not in self._obj_to_registered_saver, ( 358 "Objects can't have both SaveableObjects and a registered saver") 359 360 for local_name, (save_fn, restore_fn) in ( 361 self._saveable_objects_map[node].items()): 362 saveable_object_proto = object_proto.saveable_objects[local_name] 363 saveable_object_proto.save_function = self.node_ids[save_fn] 364 saveable_object_proto.restore_function = self.node_ids[restore_fn] 365 366 elif node in self._obj_to_registered_saver: 367 object_proto.registered_saver = self._obj_to_registered_saver[node] 368 369 def map_resources(self): 370 """Makes new resource handle ops corresponding to existing resource tensors. 371 372 Creates resource handle ops in the current default graph, whereas 373 `accessible_objects` will be from an eager context. Resource mapping adds 374 resource handle ops to the main GraphDef of a SavedModel, which allows the 375 C++ loader API to interact with resources. 376 377 Returns: 378 A tuple of (object_map, tensor_map, asset_info): 379 object_map: A dictionary mapping from object in `accessible_objects` to 380 replacement objects created to hold the new resource tensors. 381 tensor_map: A dictionary mapping from resource tensors extracted from 382 `accessible_objects` to newly created resource tensors. 383 asset_info: An _AssetInfo tuple describing external assets referenced 384 from accessible_objects. 385 """ 386 # Only makes sense when adding to the export Graph 387 assert not context.executing_eagerly() 388 # TODO(b/205007558): Handle MirroredVariables and other types of variables 389 # which may need special casing. 390 object_map = object_identity.ObjectIdentityDictionary() 391 tensor_map = {} 392 asset_info = _AssetInfo( 393 asset_defs=[], 394 asset_initializers_by_resource={}, 395 asset_filename_map={}, 396 asset_index={}) 397 398 for node_id in _dependency_sorted_node_ids(self): 399 obj = self.nodes[node_id] 400 tensors = obj._export_to_saved_model_graph( # pylint: disable=protected-access 401 object_map=object_map, 402 tensor_map=tensor_map, 403 options=self._options) 404 if isinstance(obj, asset.Asset): 405 _add_asset_info(obj, asset_info, tensor_map[obj.asset_path]) 406 if tensors: 407 for tensor in tensors: 408 self.captured_tensor_node_ids[tensor] = node_id 409 410 return object_map, tensor_map, asset_info 411 412 def add_capture_and_node(self, capture, node): 413 node_id = len(self.nodes) 414 self.nodes.append(node) 415 self.node_ids[capture] = node_id 416 self.node_ids[node] = node_id 417 self.captured_tensor_node_ids[capture] = node_id 418 return node_id 419 420 def get_concrete_resource_initializers(self): 421 concrete_initializers = [] 422 for obj in self.nodes: 423 if isinstance(obj, resource.CapturableResource): 424 concrete_initializers.append( 425 self.augmented_graph_view.get_child( 426 obj, "_initialize").get_concrete_function()) 427 return concrete_initializers 428 429 430def _gen_save_and_restore_functions(checkpoint_factory_map): 431 """Generates global and individual save/restore concrete functions. 432 433 The global functions records the ops to save and restore the entire object to 434 a file prefix, while the individual functions save and restore value tensors 435 for resources. 436 437 This function is intended to run on the output of 438 `save_util_v1.get_checkpoint_factories_and_keys(object_names)`, 439 which returns the generated a map of `_CheckpointFactoryData`. 440 441 Args: 442 checkpoint_factory_map: A dictionary mapping trackable objects to 443 a list of `_CheckpointFactoryData`. 444 445 Returns: 446 Tuple of ( 447 saveable_fn_map: Maps obj -> factory name -> (concrete save, restore) 448 ) 449 """ 450 # Maps obj -> factory attribute_name -> (concrete save, concrete restore) 451 # This 452 saveable_fn_map = object_identity.ObjectIdentityDictionary() 453 454 for obj, factory_data_list in checkpoint_factory_map.items(): 455 if resource_variable_ops.is_resource_variable(obj) or not factory_data_list: 456 # There is no need to trace the save and restore functions for variables. 457 continue 458 459 if factory_data_list[0].name == trackable_utils.SERIALIZE_TO_TENSORS_NAME: 460 # Trace Trackable save and restore functions. 461 assert len(factory_data_list) == 1 462 saveable_fn_map[obj] = {trackable_utils.SERIALIZE_TO_TENSORS_NAME: ( 463 tracing_utils.trace_save_and_restore(obj))} 464 else: 465 # Trace deprecated SaveableObject save and restore functions. 466 saveable_fn_map[obj] = ( 467 saveable_object_util.trace_save_restore_function_map( 468 obj, factory_data_list)) 469 return saveable_fn_map 470 471 472def _tensor_dict_to_tensorinfo(tensor_dict): 473 return { 474 key: utils_impl.build_tensor_info_internal(value) 475 for key, value in tensor_dict.items() 476 } 477 478 479def _to_safe_name_scope(signature_key, user_input_name): 480 """Creates a sanitized name scope from user signature and input names. 481 482 Concatenates signature and input names, sanitizing as needed to be a valid 483 scope name. 484 485 Args: 486 signature_key: The user-provided key for the signature. 487 user_input_name: The user-provided name for the input placeholder. 488 489 Returns: 490 A name scope that is safe to be used in tf.name_scope(). 491 """ 492 name_scope = "{}_{}".format(signature_key, user_input_name) 493 if re.match(r"^[A-Za-z0-9.][A-Za-z0-9_.\\-]*$", name_scope): 494 return name_scope 495 invalid_prefix_stripped = re.sub(r"^[^A-Za-z0-9.]*", "", name_scope) 496 return re.sub(r"[^A-Za-z0-9_.\\-]", "_", invalid_prefix_stripped) 497 498 499def _map_function_arguments_to_created_inputs(function_arguments, signature_key, 500 function_name): 501 """Creates exterior placeholders in the exported graph for function arguments. 502 503 Functions have two types of inputs: tensors captured from the outside (eager) 504 context, and arguments to the function which we expect to receive from the 505 user at each call. `_map_captures_to_created_tensors` replaces 506 captured tensors with stand-ins (typically these are resource dtype tensors 507 associated with variables). `_map_function_inputs_to_created_inputs` runs over 508 every argument, creating a new placeholder for each which will belong to the 509 exported graph rather than the function body. 510 511 Args: 512 function_arguments: A list of argument placeholders in the function body. 513 signature_key: The name of the signature being exported, for error messages. 514 function_name: The name of the function, for error messages. 515 516 Returns: 517 A tuple of (mapped_inputs, exterior_placeholders) 518 mapped_inputs: A list with entries corresponding to `function_arguments` 519 containing all of the inputs of the function gathered from the exported 520 graph (both captured resources and arguments). 521 exterior_argument_placeholders: A dictionary mapping from argument names 522 to placeholders in the exported graph, containing the explicit arguments 523 to the function which a user is expected to provide. 524 525 Raises: 526 ValueError: If argument names are not unique. 527 """ 528 # `exterior_argument_placeholders` holds placeholders which are outside the 529 # function body, directly contained in a MetaGraph of the SavedModel. The 530 # function body itself contains nearly identical placeholders used when 531 # running the function, but these exterior placeholders allow Session-based 532 # APIs to call the function using feeds and fetches which name Tensors in the 533 # MetaGraph. 534 exterior_argument_placeholders = {} 535 mapped_inputs = [] 536 for placeholder in function_arguments: 537 # `export_captures` contains an exhaustive set of captures, so if we don't 538 # find the input there then we now know we have an argument. 539 user_input_name = compat.as_str_any( 540 placeholder.op.get_attr("_user_specified_name")) 541 # If the internal placeholders for a function have names which were 542 # uniquified by TensorFlow, then a single user-specified argument name 543 # must refer to multiple Tensors. The resulting signatures would be 544 # confusing to call. Instead, we throw an exception telling the user to 545 # specify explicit names. 546 if user_input_name != placeholder.op.name: 547 # This should be unreachable, since concrete functions may not be 548 # generated with non-unique argument names. 549 raise ValueError( 550 "Got non-flat/non-unique argument names for SavedModel signature " 551 f"'{signature_key}': more than one argument to " 552 f"'{compat.as_str_any(function_name)}' was named " 553 f"'{user_input_name}'. " 554 "Signatures have one Tensor per named input, so to have " 555 "predictable names Python functions used to generate these " 556 "signatures should avoid *args and Tensors in nested " 557 "structures unless unique names are specified for each. Use " 558 "tf.TensorSpec(..., name=...) to provide a name for a Tensor " 559 "input.") 560 arg_placeholder = array_ops.placeholder( 561 shape=placeholder.shape, 562 dtype=placeholder.dtype, 563 name=_to_safe_name_scope(signature_key, user_input_name)) 564 exterior_argument_placeholders[user_input_name] = arg_placeholder 565 mapped_inputs.append(arg_placeholder) 566 return mapped_inputs, exterior_argument_placeholders 567 568 569def _generate_signatures(signature_functions, object_map): 570 """Validates and calls `signature_functions` in the exported graph. 571 572 Args: 573 signature_functions: A dictionary mapping string keys to concrete TensorFlow 574 functions (e.g. from `signature_serialization.canonicalize_signatures`) 575 which will be used to generate SignatureDefs. 576 object_map: A dictionary that contains mappings from signature functions to 577 concrete functions in the exported graph. 578 579 Returns: 580 Each function in the `signature_functions` dictionary is called with 581 placeholder Tensors, generating a function call operation and output 582 Tensors. The placeholder Tensors, the function call operation, and the 583 output Tensors from the function call are part of the default Graph. 584 585 This function then returns a dictionary with the same structure as 586 `signature_functions`, with the concrete functions replaced by SignatureDefs 587 implicitly containing information about how to call each function from a 588 TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference 589 the generated placeholders and Tensor outputs by name. 590 591 The caller is expected to include the default Graph set while calling this 592 function as a MetaGraph in a SavedModel, including the returned 593 SignatureDefs as part of that MetaGraph. 594 """ 595 signatures = {} 596 for signature_key, function in sorted(signature_functions.items()): 597 if function.graph.captures: 598 argument_inputs = function.graph.inputs[:-len(function.graph.captures)] 599 else: 600 argument_inputs = function.graph.inputs 601 mapped_inputs, exterior_argument_placeholders = ( 602 _map_function_arguments_to_created_inputs(argument_inputs, 603 signature_key, function.name)) 604 outputs = object_map[function](*mapped_inputs) 605 signatures[signature_key] = signature_def_utils.build_signature_def( 606 _tensor_dict_to_tensorinfo(exterior_argument_placeholders), 607 _tensor_dict_to_tensorinfo(outputs), 608 method_name=signature_constants.PREDICT_METHOD_NAME) 609 return signatures 610 611 612_AssetInfo = collections.namedtuple( 613 "_AssetInfo", 614 [ 615 # List of AssetFileDef protocol buffers 616 "asset_defs", 617 # Map from asset variable resource Tensors to their init ops 618 "asset_initializers_by_resource", 619 # Map from base asset filenames to full paths 620 "asset_filename_map", 621 # Map from Asset to index of corresponding AssetFileDef 622 "asset_index" 623 ]) 624 625 626def _add_asset_info(trackable_asset, asset_info, mapped_path_variable): 627 """Add `trackable_asset` to `asset_info`.""" 628 original_path_tensor = trackable_asset.asset_path 629 original_path = tensor_util.constant_value(original_path_tensor) 630 try: 631 original_path = str(original_path.astype(str)) 632 except AttributeError: 633 # Already a string rather than a numpy array 634 pass 635 636 path = builder_impl.get_asset_filename_to_add( 637 asset_filepath=original_path, 638 asset_filename_map=asset_info.asset_filename_map) 639 asset_info.asset_filename_map[path] = original_path 640 asset_def = meta_graph_pb2.AssetFileDef() 641 asset_def.filename = path 642 asset_def.tensor_info.name = mapped_path_variable.initial_value.name 643 asset_info.asset_defs.append(asset_def) 644 asset_info.asset_initializers_by_resource[original_path_tensor] = ( 645 mapped_path_variable.initializer) 646 asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1 647 648 649def _iterate_op_types(fn): 650 """Iterates through each op in the function and returns the op type and op.""" 651 if isinstance(fn, framework_fn._DefinedFunction): # pylint: disable=protected-access 652 for node in fn.definition.node_def: 653 op_type = node.attr["_gradient_op_type"].s 654 if op_type: 655 raise ValueError( 656 "Unable to save gradient functions when exporting a " 657 "_DefinedFunction (generally created through graph freezing utils " 658 "or through V1 graph importers). Please save with " 659 "`options=tf.SaveOptions(experimental_custom_gradients=False)`") 660 else: 661 for op in fn.graph.get_operations(): 662 try: 663 op_type = op.get_attr("_gradient_op_type") 664 except ValueError: 665 continue 666 yield op_type, op 667 668 669def _get_outer_most_capture(fn, capture, func_graph_map): 670 """Tries to find the original captured tensor if capture more than once.""" 671 outer_fn = fn 672 while outer_fn is not None and not isinstance(capture, ops.EagerTensor): 673 if capture.graph is not outer_fn.graph: 674 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) 675 else: 676 try: 677 capture_index = outer_fn.graph.internal_captures.index(capture) 678 except ValueError: 679 break # Capture is a tensor inside function, and not captured from 680 # another external function 681 capture = outer_fn.graph.external_captures[capture_index] 682 outer_fn = func_graph_map.get(outer_fn.graph.outer_graph) 683 return outer_fn, capture 684 685 686def _trace_gradient_functions(graph, saveable_view): 687 """Traces gradient functions and records them in the SaveableView.""" 688 functions = list(graph._functions.values()) # pylint: disable=protected-access 689 func_graph_map = {f.graph: f for f in functions if hasattr(f, "graph")} 690 seen_op_types = set() 691 692 for fn in functions: 693 for op_type, op in _iterate_op_types(fn): 694 if op_type in seen_op_types: 695 continue 696 seen_op_types.add(op_type) 697 698 try: 699 custom_gradient = ops.gradient_registry.lookup(op_type) 700 except LookupError: 701 continue 702 703 try: 704 grad_fn = ( 705 def_function.function(custom_gradient).get_concrete_function( 706 None, *op.inputs)) 707 except Exception as exc: 708 traceback.print_exc() 709 raise ValueError( 710 "Error when tracing gradients for SavedModel.\n\n" 711 "Check the error log to see the error that was raised when " 712 "converting a gradient function to a concrete function. You may " 713 "need to update the custom gradient, or disable saving gradients " 714 "with the option " 715 "tf.saved_model.SaveOptions(experimental_custom_gradients=False)" 716 f".\n\tProblematic op name: {op.name}\n\tGradient inputs: " 717 f"{op.inputs}") from exc 718 719 # The gradient function will capture all intermediate values. These 720 # captures be serialized so that they can be re-bound to the function when 721 # loading. 722 bad_captures = [] 723 for capture in grad_fn.captured_inputs: 724 if capture.dtype in _UNCOPIABLE_DTYPES: 725 continue 726 # Tries to find the outermost capture in case the tensor is a constant 727 # or not actually captured in the current function (this could happen if 728 # the function is a while loop body, in which case the captured input 729 # is not the internal captured tensor). 730 outer_fn, outer_capture = _get_outer_most_capture( 731 fn, capture, func_graph_map) 732 if outer_fn is None or isinstance(outer_capture, ops.EagerTensor): 733 if outer_capture not in saveable_view.captured_tensor_node_ids: 734 raise ValueError(f"Found invalid capture {outer_capture} when " 735 "saving custom gradients.") 736 saveable_view.captured_tensor_node_ids[capture] = ( 737 saveable_view.captured_tensor_node_ids[outer_capture]) 738 elif outer_capture.graph is outer_fn.graph: 739 capture_name = outer_capture.name 740 # It's possible for EagerDefinedFunctions to save different names for 741 # input tensors when serialized to FunctionDef (all non-alphanumeric 742 # characters are converted to '_'). 743 if isinstance(outer_fn, defun._EagerDefinedFunction): # pylint:disable=protected-access 744 try: 745 arg_index = outer_fn.graph.inputs.index(outer_capture) 746 capture_name = outer_fn.signature.input_arg[arg_index].name + ":0" 747 except ValueError: 748 pass 749 750 node = _CapturedTensor(capture_name, outer_fn.name) 751 saveable_view.add_capture_and_node(capture, node) 752 else: 753 bad_captures.append(capture.name) 754 if not bad_captures: 755 grad_fn.add_to_graph(graph) 756 else: 757 raise ValueError( 758 f"Cannot save custom gradient {op_type} called in function {fn} " 759 "because SavedModel is unable to serialize the captured " 760 f"inputs: {bad_captures}") 761 762 saveable_view.gradient_functions.append(grad_fn) 763 func_graph_map[grad_fn.graph] = grad_fn 764 765 grad_def = function_pb2.RegisteredGradient() 766 grad_def.gradient_func = grad_fn.name 767 grad_def.registered_op_type = op_type 768 saveable_view.gradient_defs.append(grad_def) 769 770 771def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, 772 namespace_whitelist, save_custom_gradients): 773 """Generates a MetaGraph which calls `signature_functions`. 774 775 Args: 776 meta_graph_def: The MetaGraphDef proto to fill. 777 saveable_view: The _SaveableView being exported. 778 signature_functions: A dictionary mapping signature keys to concrete 779 functions containing signatures to add to the MetaGraph. 780 namespace_whitelist: List of strings containing whitelisted op namespaces. 781 save_custom_gradients: Whether to save custom gradients. 782 783 Returns: 784 A tuple of (_AssetInfo, Graph) containing the captured assets and 785 exported Graph generated from tracing the saveable_view. 786 """ 787 # List objects from the eager context to make sure Optimizers give us the 788 # right Graph-dependent variables. 789 resource_initializers = saveable_view.get_concrete_resource_initializers() 790 exported_graph = ops.Graph() 791 resource_initializer_ops = [] 792 with exported_graph.as_default(): 793 object_map, tensor_map, asset_info = saveable_view.map_resources() 794 signatures = _generate_signatures(signature_functions, object_map) 795 if save_custom_gradients: 796 _trace_gradient_functions(exported_graph, saveable_view) 797 798 # Create initializers for assets and resources. 799 for resource_initializer_function in resource_initializers: 800 asset_dependencies = [] 801 for capture in resource_initializer_function.graph.external_captures: 802 asset_initializer = asset_info.asset_initializers_by_resource.get( 803 capture, None) 804 if asset_initializer is not None: 805 asset_dependencies.append(asset_initializer) 806 with ops.control_dependencies(asset_dependencies): 807 mapped_initializer = object_map[resource_initializer_function] 808 resource_initializer_ops.append(mapped_initializer()) 809 resource_initializer_ops.extend( 810 asset_info.asset_initializers_by_resource.values()) 811 with ops.control_dependencies(resource_initializer_ops): 812 init_op = control_flow_ops.no_op() 813 # Add the same op to the main_op collection and to the init_op 814 # signature. The collection is for compatibility with older loader APIs; 815 # only one will be executed. 816 meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append( 817 init_op.name) 818 meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom( 819 signature_def_utils.op_signature_def(init_op, 820 constants.INIT_OP_SIGNATURE_KEY)) 821 822 # Saving an object-based checkpoint again gathers variables. We need to do the 823 # gathering from the eager context so Optimizers save the right set of 824 # variables, but want any operations associated with the save/restore to be in 825 # the exported graph (thus the `to_graph` argument). 826 def call_with_mapped_captures(function, args): 827 if function in object_map: 828 return object_map[function](*args) 829 # Registered saver/restore functions do not appear in `object_map`, because 830 # they are not in the object graph. 831 return function_saved_model_utils.ExportedConcreteFunction( 832 function, tensor_map)(*args) 833 834 for obj in object_map.values(): 835 obj._maybe_initialize_trackable() # pylint: disable=protected-access 836 named_saveable_objects, registered_savers = ( 837 save_util_v1.frozen_saveables_and_savers( 838 graph_view=saveable_view.augmented_graph_view, 839 object_map=object_map, 840 to_graph=exported_graph, 841 call_with_mapped_captures=call_with_mapped_captures)) 842 saver = functional_saver.MultiDeviceSaver(named_saveable_objects, 843 registered_savers, 844 call_with_mapped_captures) 845 846 with exported_graph.as_default(): 847 saver_def = saver.to_proto() 848 meta_graph_def.saver_def.CopyFrom(saver_def) 849 850 # At this point all nodes that can be added to the SavedObjectGraph have been 851 # added, so run the following to validate deserialization dependencies. 852 _dependency_sorted_node_ids(saveable_view) 853 854 graph_def = exported_graph.as_graph_def(add_shapes=True) 855 graph_def.library.registered_gradients.extend(saveable_view.gradient_defs) 856 _verify_ops(graph_def, namespace_whitelist) 857 858 meta_graph_def.graph_def.CopyFrom(graph_def) 859 meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING) 860 meta_graph_def.meta_info_def.tensorflow_version = versions.__version__ 861 meta_graph_def.meta_info_def.tensorflow_git_version = ( 862 versions.__git_version__) 863 # We currently always strip default attributes. 864 meta_graph_def.meta_info_def.stripped_default_attrs = True 865 meta_graph_def.meta_info_def.stripped_op_list.MergeFrom( 866 meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)) 867 meta_graph_def.asset_file_def.extend(asset_info.asset_defs) 868 for signature_key, signature in signatures.items(): 869 meta_graph_def.signature_def[signature_key].CopyFrom(signature) 870 meta_graph.strip_graph_default_valued_attrs(meta_graph_def) 871 # store tensor_content in litle endian format 872 if sys.byteorder == "big": 873 utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little") 874 return asset_info, exported_graph 875 876 877def _verify_ops(graph_def, namespace_whitelist): 878 """Verifies that all namespaced ops in the graph are whitelisted. 879 880 Args: 881 graph_def: the GraphDef to validate. 882 namespace_whitelist: a list of namespaces to allow. If `None`, all will be 883 allowed. If an op does not have a namespace, it will be allowed. 884 885 Raises: 886 ValueError: If the graph contains ops that violate the whitelist. 887 """ 888 # By default, if the user has not specified a whitelist, we want to allow 889 # everything. We check for None directly rather than falseness, since the 890 # user may instead want to pass an empty list to disallow all custom 891 # namespaced ops. 892 if namespace_whitelist is None: 893 return 894 895 invalid_ops = [] 896 invalid_namespaces = set() 897 898 all_operations = [] 899 all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def)) 900 901 for op in all_operations: 902 if ">" in op: 903 namespace = op.split(">")[0] 904 if namespace not in namespace_whitelist: 905 invalid_ops.append(op) 906 invalid_namespaces.add(namespace) 907 if invalid_ops: 908 raise ValueError( 909 "Attempted to save ops from non-whitelisted namespaces to SavedModel: " 910 f"{invalid_ops}.\nPlease verify that these ops should be saved, since " 911 "they must be available when loading the SavedModel. If loading from " 912 "Python, you must import the library defining these ops. From C++, " 913 "link the custom ops to the serving binary. Once you've confirmed this," 914 " add the following namespaces to the `namespace_whitelist` " 915 f"argument in tf.saved_model.SaveOptions: {invalid_namespaces}.") 916 917 918def _dependency_sorted_node_ids(saveable_view): 919 """Returns topologically sorted nodes, sorted by dependencies.""" 920 dependency_map = {} 921 for node in saveable_view.nodes: 922 node_id = saveable_view.node_ids[node] 923 deps = dependency_map[node_id] = [] 924 # TODO(kathywu): Remove once all of these have been converted to trackable. 925 if isinstance(node, _CapturedTensor): 926 continue # These are not `Trackable` and therefore have no dependencies. 927 for _, dep in saveable_view.augmented_graph_view.list_dependencies(node): 928 if dep not in saveable_view.node_ids: 929 node_path = trackable_utils.pretty_print_node_path( 930 saveable_view.node_paths[node]) 931 raise ValueError( 932 f"Found an untracked dependency. Object {node_path} depends " 933 f"on {dep}, but this dependency isn't listed as a child. " 934 "Please track this child by overriding `_trackable_children` " 935 "or use `._track_trackable`.") 936 deps.append(saveable_view.node_ids[dep]) 937 try: 938 return trackable_utils.order_by_dependency(dependency_map) 939 except trackable_utils.CyclicDependencyError as err: 940 pretty_printed_nodes = [] 941 pretty_printed_dependencies = [] 942 943 for x, deps in err.leftover_dependency_map.items(): 944 node_path = trackable_utils.pretty_print_node_path( 945 saveable_view.node_paths[saveable_view.nodes[x]]) 946 pretty_printed_nodes.append( 947 f"\tNode {x} = {node_path} (type {type(saveable_view.nodes[x])})") 948 pretty_printed_dependencies.append(f"\tNode {x} depends on nodes {deps}") 949 pretty_printed_nodes = "\n".join(pretty_printed_nodes) 950 pretty_printed_dependencies = "\n".join(pretty_printed_dependencies) 951 raise ValueError( 952 "There is one or more dependency cycle in the saved Trackable object. " 953 "Saving cannot continue until this cycle is resolved." 954 f"\n>> Unresolved nodes:\n{pretty_printed_nodes}" 955 f"\n>> Unresolved cyclic dependencies:\n{pretty_printed_dependencies}") 956 957 958def _serialize_object_graph(saveable_view, asset_file_def_index): 959 """Save a SavedObjectGraph proto for `root`.""" 960 # SavedObjectGraph is similar to the TrackableObjectGraph proto in the 961 # checkpoint. It will eventually go into the SavedModel. 962 proto = saved_object_graph_pb2.SavedObjectGraph() 963 saveable_view.fill_object_graph_proto(proto) 964 965 for concrete_function in saveable_view.concrete_and_gradient_functions: 966 name = compat.as_text(concrete_function.name) 967 serialized = function_serialization.serialize_concrete_function( 968 concrete_function, saveable_view.captured_tensor_node_ids) 969 if serialized is not None: 970 proto.concrete_functions[name].CopyFrom(serialized) 971 972 for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): 973 _write_object_proto(obj, obj_proto, asset_file_def_index, 974 saveable_view.augmented_graph_view.list_children) 975 return proto 976 977 978def _write_object_proto(obj, proto, asset_file_def_index, list_children_fn): 979 """Saves an object into SavedObject proto.""" 980 if isinstance(obj, asset.Asset): 981 proto.asset.SetInParent() 982 proto.asset.asset_file_def_index = asset_file_def_index[obj] 983 elif resource_variable_ops.is_resource_variable(obj): 984 options = save_context.get_save_options() 985 obj._write_object_proto(proto, options) # pylint: disable=protected-access 986 elif isinstance(obj, def_function.Function): 987 proto.function.CopyFrom( 988 function_serialization.serialize_function( 989 obj, [x.ref for x in list_children_fn(obj)])) 990 elif isinstance(obj, defun.ConcreteFunction): 991 proto.bare_concrete_function.CopyFrom( 992 function_serialization.serialize_bare_concrete_function(obj)) 993 elif isinstance(obj, _CapturedTensor): 994 proto.captured_tensor.name = obj.name 995 proto.captured_tensor.concrete_function = obj.concrete_function 996 elif isinstance(obj, resource.CapturableResource): 997 proto.resource.device = obj._resource_device # pylint: disable=protected-access 998 else: 999 registered_type_proto = revived_types.serialize(obj) 1000 if registered_type_proto is None: 1001 # Fallback for types with no matching registration 1002 # pylint:disable=protected-access 1003 registered_type_proto = saved_object_graph_pb2.SavedUserObject( 1004 identifier=obj._object_identifier, 1005 version=versions_pb2.VersionDef( 1006 producer=1, min_consumer=1, bad_consumers=[])) 1007 # pylint:enable=protected-access 1008 proto.user_object.CopyFrom(registered_type_proto) 1009 1010 registered_name = registration.get_registered_class_name(obj) 1011 if registered_name: 1012 proto.registered_name = registered_name 1013 serialized_user_proto = obj._serialize_to_proto(object_proto=proto) # pylint: disable=protected-access 1014 if serialized_user_proto is not None: 1015 proto.serialized_user_proto.Pack(serialized_user_proto) 1016 1017 1018def _export_debug_info(exported_graph, export_dir): 1019 """Exports debug information from graph to file. 1020 1021 Creates and writes GraphDebugInfo with traces for ops in all functions of the 1022 exported_graph. 1023 1024 Args: 1025 exported_graph: A Graph that has been created by tracing a saveable view. 1026 export_dir: SavedModel directory in which to write the debug info. 1027 """ 1028 exported_operations = [] 1029 for fn_name in exported_graph._functions: # pylint: disable=protected-access 1030 fn = exported_graph._get_function(fn_name) # pylint: disable=protected-access 1031 if not isinstance(fn, defun._EagerDefinedFunction): # pylint: disable=protected-access 1032 continue 1033 1034 fn_graph = fn.graph 1035 for fn_op in fn_graph.get_operations(): 1036 exported_operations.append((fn_name, fn_op)) 1037 1038 graph_debug_info = error_interpolation.create_graph_debug_info_def( 1039 exported_operations) 1040 file_io.atomic_write_string_to_file( 1041 file_io.join( 1042 utils_impl.get_or_create_debug_dir(export_dir), 1043 constants.DEBUG_INFO_FILENAME_PB), 1044 graph_debug_info.SerializeToString(deterministic=True)) 1045 1046 1047@tf_export( 1048 "saved_model.save", 1049 v1=["saved_model.save", "saved_model.experimental.save"]) 1050def save(obj, export_dir, signatures=None, options=None): 1051 # pylint: disable=line-too-long 1052 """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk). 1053 1054 The `obj` must inherit from the [`Trackable` 1055 class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591). 1056 1057 Example usage: 1058 1059 >>> class Adder(tf.Module): 1060 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)]) 1061 ... def add(self, x): 1062 ... return x + x 1063 1064 >>> model = Adder() 1065 >>> tf.saved_model.save(model, '/tmp/adder') 1066 1067 The resulting SavedModel is then servable with an input named "x", a scalar 1068 with dtype float32. 1069 1070 _Signatures_ 1071 1072 Signatures define the input and output types for a computation. The optional 1073 save `signatures` argument controls which methods in `obj` will be 1074 available to programs which consume `SavedModel`s, for example, serving 1075 APIs. Python functions may be decorated with 1076 `@tf.function(input_signature=...)` and passed as signatures directly, or 1077 lazily with a call to `get_concrete_function` on the method decorated with 1078 `@tf.function`. 1079 1080 Example: 1081 1082 >>> class Adder(tf.Module): 1083 ... @tf.function 1084 ... def add(self, x): 1085 ... return x + x 1086 1087 >>> model = Adder() 1088 >>> tf.saved_model.save( 1089 ... model, '/tmp/adder',signatures=model.add.get_concrete_function( 1090 ... tf.TensorSpec([], tf.float32))) 1091 1092 If a `@tf.function` does not have an input signature and 1093 `get_concrete_function` is not called on that method, the function will not 1094 be directly callable in the restored SavedModel. 1095 1096 Example: 1097 1098 >>> class Adder(tf.Module): 1099 ... @tf.function 1100 ... def add(self, x): 1101 ... return x + x 1102 1103 >>> model = Adder() 1104 >>> tf.saved_model.save(model, '/tmp/adder') 1105 >>> restored = tf.saved_model.load('/tmp/adder') 1106 >>> restored.add(1.) 1107 Traceback (most recent call last): 1108 ... 1109 ValueError: Found zero restored functions for caller function. 1110 1111 If the `signatures` argument is omitted, `obj` will be searched for 1112 `@tf.function`-decorated methods. If exactly one traced `@tf.function` is 1113 found, that method will be used as the default signature for the SavedModel. 1114 Else, any `@tf.function` attached to `obj` or its dependencies will be 1115 exported for use with `tf.saved_model.load`. 1116 1117 When invoking a signature in an exported SavedModel, `Tensor` arguments are 1118 identified by name. These names will come from the Python function's argument 1119 names by default. They may be overridden by specifying a `name=...` argument 1120 in the corresponding `tf.TensorSpec` object. Explicit naming is required if 1121 multiple `Tensor`s are passed through a single argument to the Python 1122 function. 1123 1124 The outputs of functions used as `signatures` must either be flat lists, in 1125 which case outputs will be numbered, or a dictionary mapping string keys to 1126 `Tensor`, in which case the keys will be used to name outputs. 1127 1128 Signatures are available in objects returned by `tf.saved_model.load` as a 1129 `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save` 1130 on an object with a custom `.signatures` attribute will raise an exception. 1131 1132 _Using `tf.saved_model.save` with Keras models_ 1133 1134 While Keras has its own [saving and loading 1135 API](https://www.tensorflow.org/guide/keras/save_and_serialize), 1136 this function can be used to export Keras models. For example, exporting with 1137 a signature specified: 1138 1139 >>> class Adder(tf.keras.Model): 1140 ... @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) 1141 ... def concat(self, x): 1142 ... return x + x 1143 1144 >>> model = Adder() 1145 >>> tf.saved_model.save(model, '/tmp/adder') 1146 1147 Exporting from a function without a fixed signature: 1148 1149 >>> class Adder(tf.keras.Model): 1150 ... @tf.function 1151 ... def concat(self, x): 1152 ... return x + x 1153 1154 >>> model = Adder() 1155 >>> tf.saved_model.save( 1156 ... model, '/tmp/adder', 1157 ... signatures=model.concat.get_concrete_function( 1158 ... tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))) 1159 1160 `tf.keras.Model` instances constructed from inputs and outputs already have a 1161 signature and so do not require a `@tf.function` decorator or a `signatures` 1162 argument. If neither are specified, the model's forward pass is exported. 1163 1164 >>> x = tf.keras.layers.Input((4,), name="x") 1165 >>> y = tf.keras.layers.Dense(5, name="out")(x) 1166 >>> model = tf.keras.Model(x, y) 1167 >>> tf.saved_model.save(model, '/tmp/saved_model/') 1168 1169 The exported SavedModel takes "x" with shape [None, 4] and returns "out" 1170 with shape [None, 5] 1171 1172 _Variables and Checkpoints_ 1173 1174 Variables must be tracked by assigning them to an attribute of a tracked 1175 object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers 1176 from `tf.keras.layers`, optimizers from `tf.train`) track their variables 1177 automatically. This is the same tracking scheme that `tf.train.Checkpoint` 1178 uses, and an exported `Checkpoint` object may be restored as a training 1179 checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's 1180 "variables/" subdirectory. 1181 1182 `tf.function` does not hard-code device annotations from outside the function 1183 body, instead of using the calling context's device. This means for example 1184 that exporting a model that runs on a GPU and serving it on a CPU will 1185 generally work, with some exceptions: 1186 1187 * `tf.device` annotations inside the body of the function will be hard-coded 1188 in the exported model; this type of annotation is discouraged. 1189 * Device-specific operations, e.g. with "cuDNN" in the name or with 1190 device-specific layouts, may cause issues. 1191 * For `ConcreteFunctions`, active distribution strategies will cause device 1192 placements to be hard-coded in the function. 1193 1194 SavedModels exported with `tf.saved_model.save` [strip default-valued 1195 attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes) 1196 automatically, which removes one source of incompatibilities when the consumer 1197 of a SavedModel is running an older TensorFlow version than the 1198 producer. There are however other sources of incompatibilities which are not 1199 handled automatically, such as when the exported model contains operations 1200 which the consumer does not have definitions for. 1201 1202 Args: 1203 obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export. 1204 export_dir: A directory in which to write the SavedModel. 1205 signatures: Optional, one of three types: * a `tf.function` with an input 1206 signature specified, which will use the default serving signature key, * 1207 the result of `f.get_concrete_function` on a `@tf.function`-decorated 1208 function `f`, in which case `f` will be used to generate a signature for 1209 the SavedModel under the default serving signature key, * a dictionary, 1210 which maps signature keys to either `tf.function` instances with input 1211 signatures or concrete functions. Keys of such a dictionary may be 1212 arbitrary strings, but will typically be from the 1213 `tf.saved_model.signature_constants` module. 1214 options: `tf.saved_model.SaveOptions` object for configuring save options. 1215 1216 Raises: 1217 ValueError: If `obj` is not trackable. 1218 1219 @compatibility(eager) 1220 Not well supported when graph building. From TensorFlow 1.x, 1221 `tf.compat.v1.enable_eager_execution()` should run first. Calling 1222 tf.saved_model.save in a loop when graph building from TensorFlow 1.x will 1223 add new save operations to the default graph each iteration. 1224 1225 May not be called from within a function body. 1226 @end_compatibility 1227 """ 1228 if isinstance(export_dir, os.PathLike): 1229 export_dir = os.fspath(export_dir) 1230 # pylint: enable=line-too-long 1231 metrics.IncrementWriteApi(_SAVE_V2_LABEL) 1232 save_and_return_nodes(obj, export_dir, signatures, options) 1233 1234 metrics.IncrementWrite(write_version="2") 1235 1236 1237def save_and_return_nodes(obj, 1238 export_dir, 1239 signatures=None, 1240 options=None, 1241 experimental_skip_checkpoint=False): 1242 """Saves a SavedModel while returning all saved nodes and their paths. 1243 1244 Please see `tf.saved_model.save` for details. 1245 1246 Args: 1247 obj: A trackable object to export. 1248 export_dir: A directory in which to write the SavedModel. 1249 signatures: A function or dictionary of functions to save in the SavedModel 1250 as signatures. 1251 options: `tf.saved_model.SaveOptions` object for configuring save options. 1252 experimental_skip_checkpoint: If set to `True`, the checkpoint will not be 1253 written. 1254 1255 Returns: 1256 A tuple of (a list of saved nodes in the order they are serialized to the 1257 `SavedObjectGraph`, dictionary mapping nodes to one possible path from 1258 the root node to the key node) 1259 """ 1260 options = options or save_options.SaveOptions() 1261 # TODO(b/205008509): Factor out some subset of SavedModelBuilder which is 2.x 1262 # compatible (no sessions) and share it with this export API rather than 1263 # making a SavedModel proto and writing it directly. 1264 saved_model = saved_model_pb2.SavedModel() 1265 meta_graph_def = saved_model.meta_graphs.add() 1266 1267 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = ( 1268 _build_meta_graph(obj, signatures, options, meta_graph_def)) 1269 saved_model.saved_model_schema_version = ( 1270 constants.SAVED_MODEL_SCHEMA_VERSION) 1271 1272 # Write the checkpoint, copy assets into the assets directory, and write out 1273 # the SavedModel proto itself. 1274 if not experimental_skip_checkpoint: 1275 utils_impl.get_or_create_variables_dir(export_dir) 1276 ckpt_options = checkpoint_options.CheckpointOptions( 1277 experimental_io_device=options.experimental_io_device) 1278 object_saver.save( 1279 utils_impl.get_variables_path(export_dir), options=ckpt_options) 1280 builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map, 1281 export_dir) 1282 # Note that this needs to be the last file operation when saving the 1283 # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an 1284 # indication that the SavedModel is completely written. 1285 if context.executing_eagerly(): 1286 try: 1287 context.async_wait() # Ensure save operations have completed. 1288 except errors.NotFoundError as err: 1289 raise FileNotFoundError( 1290 f"{err}\n You may be trying to save on a different device from the " 1291 "computational device. Consider setting the " 1292 "`experimental_io_device` option in `tf.saved_model.SaveOptions` " 1293 "to the io_device such as '/job:localhost'.") 1294 1295 # We will slowly migrate code in this function to pywrap_saved_model.Save 1296 # as we build up the C++ API. 1297 pywrap_saved_model.Save(export_dir) 1298 1299 saved_model_serialized = saved_model.SerializeToString(deterministic=True) 1300 1301 # Write fingerprint protobuf, if requested. 1302 if flags.config().saved_model_fingerprinting.value(): 1303 fingerprint_path = file_io.join( 1304 compat.as_str(export_dir), 1305 compat.as_str(constants.FINGERPRINT_FILENAME)) 1306 fingerprint_proto = fingerprinting.CreateFingerprintDef( 1307 saved_model_serialized, export_dir) 1308 file_io.atomic_write_string_to_file(fingerprint_path, fingerprint_proto) 1309 1310 path = file_io.join( 1311 compat.as_str(export_dir), 1312 compat.as_str(constants.SAVED_MODEL_FILENAME_PB)) 1313 file_io.atomic_write_string_to_file( 1314 path, saved_model.SerializeToString(deterministic=True)) 1315 1316 # Save debug info, if requested. 1317 if options.save_debug_info: 1318 _export_debug_info(exported_graph, export_dir) 1319 # Clean reference cycles so repeated export()s don't make work for the garbage 1320 # collector. Before this point, we need to keep references to captured 1321 # constants in the saved graph. 1322 ops.dismantle_graph(exported_graph) 1323 1324 return saved_nodes, node_paths 1325 1326 1327def export_meta_graph(obj, filename, signatures=None, options=None): 1328 """Exports the MetaGraph proto of the `obj` to a file. 1329 1330 This function goes through the same procedures saved_model.save goes to 1331 produce the given object's MetaGraph, then saves it to the given file. It 1332 skips saving checkpoint information, and is useful when all one wants is the 1333 graph defining the model. 1334 1335 Args: 1336 obj: A trackable object to build the MetaGraph from. 1337 filename: The file into which to write the MetaGraph. 1338 signatures: Optional, either a `tf.function` with an input signature 1339 specified or the result of `f.get_concrete_function` on a 1340 `@tf.function`-decorated function `f`, in which case `f` will be used to 1341 generate a signature for the SavedModel under the default serving 1342 signature key. `signatures` may also be a dictionary, in which case it 1343 maps from signature keys to either `tf.function` instances with input 1344 signatures or concrete functions. The keys of such a dictionary may be 1345 arbitrary strings, but will typically be from the 1346 `tf.saved_model.signature_constants` module. 1347 options: Optional, `tf.saved_model.SaveOptions` object that specifies 1348 options for saving. 1349 """ 1350 options = options or save_options.SaveOptions() 1351 export_dir = os.path.dirname(filename) 1352 meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph( 1353 obj, signatures, options) 1354 1355 file_io.atomic_write_string_to_file( 1356 filename, meta_graph_def.SerializeToString(deterministic=True)) 1357 1358 # Save debug info, if requested. 1359 if options.save_debug_info: 1360 _export_debug_info(exported_graph, export_dir) 1361 1362 # Clean reference cycles so repeated export()s don't make work for the garbage 1363 # collector. Before this point, we need to keep references to captured 1364 # constants in the saved graph. 1365 ops.dismantle_graph(exported_graph) 1366 1367 1368def _build_meta_graph_impl(obj, signatures, options, meta_graph_def=None): 1369 """Creates a MetaGraph containing the resources and functions of an object.""" 1370 if ops.inside_function(): 1371 raise AssertionError( 1372 "`tf.saved_model.save` is not supported inside a traced @tf.function. " 1373 "Move the call to the outer eagerly-executed context.") 1374 # pylint: enable=line-too-long 1375 if not isinstance(obj, base.Trackable): 1376 raise ValueError( 1377 "Expected an object of type `Trackable`, such as `tf.Module` or a " 1378 f"subclass of the `Trackable` class, for export. Got {obj} " 1379 f"with type {type(obj)}.") 1380 meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef() 1381 1382 augmented_graph_view = _AugmentedGraphView(obj) 1383 if signatures is None: 1384 signatures = signature_serialization.find_function_to_export( 1385 augmented_graph_view) 1386 1387 signatures, wrapped_functions = ( 1388 signature_serialization.canonicalize_signatures(signatures)) 1389 signature_serialization.validate_augmented_graph_view(augmented_graph_view) 1390 signature_map = signature_serialization.create_signature_map(signatures) 1391 augmented_graph_view.set_signature(signature_map, wrapped_functions) 1392 1393 # Use _SaveableView to provide a frozen listing of properties and functions. 1394 saveable_view = _SaveableView(augmented_graph_view, options) 1395 object_saver = checkpoint.TrackableSaver(augmented_graph_view) 1396 asset_info, exported_graph = _fill_meta_graph_def( 1397 meta_graph_def, saveable_view, signatures, options.namespace_whitelist, 1398 options.experimental_custom_gradients) 1399 if options.function_aliases: 1400 function_aliases = meta_graph_def.meta_info_def.function_aliases 1401 for alias, func in options.function_aliases.items(): 1402 for fdef in func._list_all_concrete_functions(): # pylint: disable=protected-access 1403 function_aliases[fdef.name] = alias 1404 1405 object_graph_proto = _serialize_object_graph(saveable_view, 1406 asset_info.asset_index) 1407 meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) 1408 1409 return (meta_graph_def, exported_graph, object_saver, asset_info, 1410 saveable_view.nodes, saveable_view.node_paths) 1411 1412 1413def _build_meta_graph(obj, signatures, options, meta_graph_def=None): 1414 """Creates a MetaGraph under a save context. 1415 1416 Args: 1417 obj: A trackable object to build the MetaGraph from. 1418 signatures: Can be a `tf.function` with an input signature specified or the 1419 result of `f.get_concrete_function` on a `@tf.function`-decorated function 1420 `f`. `signatures` may also be a dictionary, in which case it maps from 1421 signature keys to `tf.function` instances. If None, finds signature to 1422 export from the `@tf.function`-decorated methods in `obj`. 1423 options: `tf.saved_model.SaveOptions` object that specifies options for 1424 saving. 1425 meta_graph_def: Optional, the MetaGraphDef proto fill. 1426 1427 Raises: 1428 AssertionError: If `export_meta_graph` is executing inside a `tf.function`. 1429 ValueError: If `obj` is not trackable. 1430 1431 Returns: 1432 meta_graph_def: Filled MetaGraphDef proto 1433 exported_graph: `tf.Graph` object generated from `obj`. 1434 object_saver: `checkpoint.TrackableSaver` of the `obj` and its dependencies. 1435 asset_info: `_AssetInfo` tuple containing external assets in the `obj`. 1436 saveable_view.nodes: _SaveableView nodes. 1437 saveable_view.node_paths: _SaveableView paths. 1438 """ 1439 1440 with save_context.save_context(options): 1441 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def) 1442