1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Exports a SavedModel from a Trackable Python object."""
16
17import collections
18import os
19import re
20import sys
21import traceback
22
23from absl import logging
24
25from tensorflow.core.config import flags
26from tensorflow.core.framework import function_pb2
27from tensorflow.core.framework import versions_pb2
28from tensorflow.core.protobuf import meta_graph_pb2
29from tensorflow.core.protobuf import saved_model_pb2
30from tensorflow.core.protobuf import saved_object_graph_pb2
31from tensorflow.python.checkpoint import checkpoint
32from tensorflow.python.checkpoint import checkpoint_options
33from tensorflow.python.checkpoint import functional_saver
34from tensorflow.python.checkpoint import graph_view
35from tensorflow.python.checkpoint import save_util_v1
36from tensorflow.python.checkpoint import util as checkpoint_util
37from tensorflow.python.eager import context
38from tensorflow.python.eager import def_function
39from tensorflow.python.eager import function as defun
40from tensorflow.python.eager import function_saved_model_utils
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import error_interpolation
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import function as framework_fn
45from tensorflow.python.framework import meta_graph
46from tensorflow.python.framework import ops
47from tensorflow.python.framework import tensor_util
48from tensorflow.python.framework import versions
49from tensorflow.python.lib.io import file_io
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import resource_variable_ops
53from tensorflow.python.saved_model import builder_impl
54from tensorflow.python.saved_model import function_serialization
55from tensorflow.python.saved_model import pywrap_saved_model
56from tensorflow.python.saved_model import registration
57from tensorflow.python.saved_model import revived_types
58from tensorflow.python.saved_model import save_context
59from tensorflow.python.saved_model import save_options
60from tensorflow.python.saved_model import signature_constants
61from tensorflow.python.saved_model import signature_def_utils
62from tensorflow.python.saved_model import signature_serialization
63from tensorflow.python.saved_model import tag_constants
64from tensorflow.python.saved_model import tracing_utils
65from tensorflow.python.saved_model import utils_impl
66from tensorflow.python.saved_model.pywrap_saved_model import constants
67from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
68from tensorflow.python.saved_model.pywrap_saved_model import metrics
69from tensorflow.python.trackable import asset
70from tensorflow.python.trackable import base
71from tensorflow.python.trackable import resource
72from tensorflow.python.trackable import trackable_utils
73from tensorflow.python.training.saving import saveable_object_util
74from tensorflow.python.util import compat
75from tensorflow.python.util import object_identity
76from tensorflow.python.util.tf_export import tf_export
77
78_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
79
80# Container for tensors captured from external functions.
81_CapturedTensor = collections.namedtuple("_CapturedTensor",
82                                         ["name", "concrete_function"])
83
84# Number of untraced functions to display to user in warning message.
85_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
86
87# API label for SavedModel metrics.
88_SAVE_V2_LABEL = "save_v2"
89
90
91class _AugmentedGraphView(graph_view.ObjectGraphView):
92  """An extendable graph which also tracks functions attached to objects.
93
94  Extensions through `add_object` appear in the object graph and any checkpoints
95  generated from it, even if they are not dependencies of the node they were
96  attached to in the saving program. For example a `.signatures` attribute is
97  added to exported SavedModel root objects without modifying the root object
98  itself.
99
100  Also tracks functions attached to objects in the graph, through the caching
101  `list_children` method. Enumerating functions only through this method
102  ensures that we get a consistent view of functions, even if object attributes
103  create new functions every time they are accessed.
104  """
105
106  def __init__(self, root):
107    super(_AugmentedGraphView, self).__init__(root)
108
109    # Cache the results of `GraphView.list_children()` to ensure that the
110    # `Trackable` children are gathered exactly once.
111    self._children_cache = object_identity.ObjectIdentityDictionary()
112
113    # Cache shared between objects in the same object graph. This is passed to
114    # `Trackable._trackable_children()`.
115    self._serialization_cache = object_identity.ObjectIdentityDictionary()
116
117    # Maps functions -> wrapped functions that capture non-cached variables.
118    self._wrapped_functions = {}
119
120    self.untraced_functions = []
121
122  def set_signature(self, signature_map, wrapped_functions):
123    """Attach signature to the root object.
124
125    Args:
126      signature_map: An object that contains signature functions.
127      wrapped_functions: A dictionary mapping functions to functions that are
128        guaranteed to not capture cached variables (functions that capture
129        cached variables can't be saved).
130    """
131    self.list_children(self.root)
132    # Overrides existing dependency.
133    name = signature_serialization.SIGNATURE_ATTRIBUTE_NAME
134    self._children_cache[self.root][name] = signature_map
135    self._wrapped_functions.update(wrapped_functions)
136
137  def _breadth_first_traversal(self):
138    """Returns all trackable objects in the SavedObjectGraph."""
139    # This method is overridden to merge all equivalent constant tensors and
140    # Assets in the object graph.
141
142    trackable_objects, _ = (
143        super(_AugmentedGraphView, self)._breadth_first_traversal())
144
145    asset_paths = object_identity.ObjectIdentityDictionary()
146    constant_captures = object_identity.ObjectIdentityDictionary()
147    for obj in trackable_objects:
148      if isinstance(obj, asset.Asset):
149        asset_paths[obj.asset_path] = obj
150      if isinstance(obj, function_saved_model_utils.TrackableConstant):
151        constant_captures[obj.capture] = obj
152
153    def _get_merged_trackable(x):
154      if isinstance(x, asset.Asset):
155        return asset_paths[x.asset_path]
156      if isinstance(x, function_saved_model_utils.TrackableConstant):
157        if x.capture in asset_paths:
158          return asset_paths[x.capture]
159        else:
160          return constant_captures[x.capture]
161      return x
162
163    for obj in list(self._children_cache.keys()):
164      if _get_merged_trackable(obj) is not obj:
165        del self._children_cache[obj]
166        continue
167      for name, child in self._children_cache[obj].items():
168        self._children_cache[obj][name] = _get_merged_trackable(child)
169
170    return super(_AugmentedGraphView, self)._breadth_first_traversal()
171
172  def list_children(self, obj):
173    """Lists children of `obj` for SavedModel."""
174    if obj not in self._children_cache:
175      children = self._children_cache[obj] = {}
176
177      for name, child in super(_AugmentedGraphView, self).list_children(
178          obj,
179          save_type=base.SaveType.SAVEDMODEL,
180          cache=self._serialization_cache):
181        if isinstance(child, defun.ConcreteFunction):
182          child = self._maybe_uncache_variable_captures(child)
183        children[name] = child
184
185      # Keep track of untraced functions for later reporting to the user.
186      if isinstance(obj, def_function.Function) and not children:
187        self.untraced_functions.append(obj.name)
188
189    for name, child in self._children_cache[obj].items():
190      yield base.TrackableReference(name, child)
191
192  def get_child(self, obj, name):
193    return self._children_cache[obj][name]
194
195  def _maybe_uncache_variable_captures(self, concrete_function):
196    if concrete_function in self._wrapped_functions:
197      return self._wrapped_functions[concrete_function]
198    for capture in concrete_function.captured_inputs:
199      if hasattr(capture, "_cached_variable"):
200        if concrete_function not in self._wrapped_functions:
201          wrapped = self._wrapped_functions[concrete_function] = (
202              function_serialization.wrap_cached_variables(concrete_function))
203          return wrapped
204    return concrete_function
205
206  def list_dependencies(self, obj):
207    """Yields `Trackables` that must be loaded before `obj`.
208
209    Dependencies and children are both dictionaries of `Trackables`. Children
210    define the object graph structure (used in both checkpoints and SavedModel),
211    while dependencies define the order used to load the SavedModel.
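
    For example (an illustrative sketch): a subclass that must be loaded
    after some of its children could declare that ordering by overriding

      def _deserialization_dependencies(self, children):
        # Require every child to be loaded first (a hypothetical policy).
        return dict(children)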
212
213    Args:
214      obj: A `Trackable` object
215
216    Yields:
217      Tuple of dependency names and trackable objects.
218
219    Raises:
220      TypeError: if any of the returned dependencies are not instances of
221        `Trackable`.
222    """
223    if obj not in self._children_cache:
224      # Slot variables do not appear in the children_cache.
225      children = {}
226    else:
227      children = self._children_cache[obj]
228    for name, dep in obj._deserialization_dependencies(children).items():  # pylint: disable=protected-access
229      if not isinstance(dep, base.Trackable):
230        raise TypeError(
231            f"The dependency of type {type(dep)} is not an instance of "
232            "`Trackable`, and can't be saved to SavedModel. Please check the "
233            "implementation of `_deserialization_dependencies` in the parent "
234            f"object {obj}.")
235      yield name, dep
236
237
238class _SaveableView(object):
239  """Provides a frozen view over a trackable root.
240
241  This class helps to create a single stable view over an object to save. The
242  saving code should access properties and functions via this class and not via
243  the original object as there are cases where an object constructs its
244  trackable attributes and functions dynamically per call and will yield
245  different objects if invoked more than once.
246
247  Changes to the graph, for example adding objects, must happen in
248  `augmented_graph_view` (an `_AugmentedGraphView`) before the `_SaveableView`
249  is constructed. Changes after the `_SaveableView` has been constructed will be
250  ignored.
251  """
252
253  def __init__(self, augmented_graph_view, options):
254    """Initializes a SaveableView.
255
256    Args:
257      augmented_graph_view: A GraphView object.
258      options: A SaveOptions instance.
259    """
260
261    self.augmented_graph_view = augmented_graph_view
262    self._options = options
263
264    (self._trackable_objects, self.node_paths, self.node_ids,
265     self._slot_variables, self.object_names) = (
266         checkpoint_util.objects_ids_and_slot_variables_and_paths(
267             self.augmented_graph_view))
268
269    untraced_functions = self.augmented_graph_view.untraced_functions
270    if untraced_functions:
271      logging.warning(
272          "Found untraced functions such as %s while saving (showing %d of %d)."
273          " These functions will not be directly callable after loading.",
274          ", ".join(untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
275          min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(untraced_functions)),
276          len(untraced_functions))
277
278    self._initialize_save_and_restore_functions()
279    self._initialize_nodes_and_concrete_functions()
280
281    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
282
283  def _initialize_save_and_restore_functions(self):
284    """Generates all checkpoint save/restore functions.
285
286    The save and restore functions are generated in the eager context (or in the
287    user's Graph/Session) before being copied to the exported GraphDef. These
288    functions record the ops for saving/restoring the entire object or
289    individual objects (e.g. variables and hash tables).
290
291    The global save and restore functions are generated for compatibility with
292    TF1 and loading from C++, and are saved in the `MetaGraphDef.saver_def`.
293
294    The individual functions are generated for the Python TF2 use case, where
295    users use the loaded SavedModel as-is, or compose new models using parts
296    of the object loaded from the SavedModel. These functions are recorded in
297    the `saveable_objects` map in the `SavedObject` proto.
298    """
299    checkpoint_factory_map, registered_savers = (
300        save_util_v1.get_checkpoint_factories_and_keys(self.object_names))
301    self._obj_to_registered_saver = object_identity.ObjectIdentityDictionary()
302    for saver_name, trackables in registered_savers.items():
303      for trackable in trackables.values():
304        self._obj_to_registered_saver[trackable] = saver_name
305    self._saveable_objects_map = (
306        _gen_save_and_restore_functions(checkpoint_factory_map))
307
308  def _initialize_nodes_and_concrete_functions(self):
309    """Creates graph with nodes for trackable objects and functions.
310
311    Adds functions for each trackable object to `self.nodes` and associated
312    concrete functions to `self.concrete_functions` for serialization.
313    """
314    self.nodes = list(self._trackable_objects)
315    self.gradient_functions = []
316    self.gradient_defs = []
317
318    for obj in self.nodes:
319      if obj in self._saveable_objects_map:
320        for save_fn, restore_fn in self._saveable_objects_map[obj].values():
321          self.node_ids[save_fn] = len(self.nodes)
322          self.nodes.append(save_fn)
323
324          self.node_ids[restore_fn] = len(self.nodes)
325          self.nodes.append(restore_fn)
326
327    self.concrete_functions = [
328        obj for obj in self.nodes if isinstance(obj, defun.ConcreteFunction)
329    ]
330
331  @property
332  def concrete_and_gradient_functions(self):
333    return self.concrete_functions + self.gradient_functions
334
335  @property
336  def root(self):
337    return self.nodes[0]
338
339  def fill_object_graph_proto(self, proto):
340    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
341    for node_id, node in enumerate(self.nodes):
342      assert self.node_ids[node] == node_id
343      object_proto = proto.nodes.add()
344      object_proto.slot_variables.extend(self._slot_variables.get(node, ()))
345      if isinstance(node, _CapturedTensor):
346        continue
347      for child in self.augmented_graph_view.list_children(node):
348        child_proto = object_proto.children.add()
349        child_proto.node_id = self.node_ids[child.ref]
350        child_proto.local_name = child.name
351      for name, ref in self.augmented_graph_view.list_dependencies(node):
352        child_proto = object_proto.dependencies.add()
353        child_proto.node_id = self.node_ids[ref]
354        child_proto.local_name = name
355
356      if node in self._saveable_objects_map:
357        assert node not in self._obj_to_registered_saver, (
358            "Objects can't have both SaveableObjects and a registered saver")
359
360        for local_name, (save_fn, restore_fn) in (
361            self._saveable_objects_map[node].items()):
362          saveable_object_proto = object_proto.saveable_objects[local_name]
363          saveable_object_proto.save_function = self.node_ids[save_fn]
364          saveable_object_proto.restore_function = self.node_ids[restore_fn]
365
366      elif node in self._obj_to_registered_saver:
367        object_proto.registered_saver = self._obj_to_registered_saver[node]
368
369  def map_resources(self):
370    """Makes new resource handle ops corresponding to existing resource tensors.
371
372    Creates resource handle ops in the current default graph, whereas
373    `accessible_objects` will be from an eager context. Resource mapping adds
374    resource handle ops to the main GraphDef of a SavedModel, which allows the
375    C++ loader API to interact with resources.
376
377    Returns:
378      A tuple of (object_map, tensor_map, asset_info):
379        object_map: A dictionary mapping from object in `accessible_objects` to
380          replacement objects created to hold the new resource tensors.
381        tensor_map: A dictionary mapping from resource tensors extracted from
382          `accessible_objects` to newly created resource tensors.
383        asset_info: An _AssetInfo tuple describing external assets referenced
384          from accessible_objects.
385    """
386    # Only makes sense when adding to the export Graph
387    assert not context.executing_eagerly()
388    # TODO(b/205007558): Handle MirroredVariables and other types of variables
389    # which may need special casing.
390    object_map = object_identity.ObjectIdentityDictionary()
391    tensor_map = {}
392    asset_info = _AssetInfo(
393        asset_defs=[],
394        asset_initializers_by_resource={},
395        asset_filename_map={},
396        asset_index={})
397
398    for node_id in _dependency_sorted_node_ids(self):
399      obj = self.nodes[node_id]
400      tensors = obj._export_to_saved_model_graph(  # pylint: disable=protected-access
401          object_map=object_map,
402          tensor_map=tensor_map,
403          options=self._options)
404      if isinstance(obj, asset.Asset):
405        _add_asset_info(obj, asset_info, tensor_map[obj.asset_path])
406      if tensors:
407        for tensor in tensors:
408          self.captured_tensor_node_ids[tensor] = node_id
409
410    return object_map, tensor_map, asset_info
411
412  def add_capture_and_node(self, capture, node):
413    node_id = len(self.nodes)
414    self.nodes.append(node)
415    self.node_ids[capture] = node_id
416    self.node_ids[node] = node_id
417    self.captured_tensor_node_ids[capture] = node_id
418    return node_id
419
420  def get_concrete_resource_initializers(self):
421    concrete_initializers = []
422    for obj in self.nodes:
423      if isinstance(obj, resource.CapturableResource):
424        concrete_initializers.append(
425            self.augmented_graph_view.get_child(
426                obj, "_initialize").get_concrete_function())
427    return concrete_initializers
428
429
430def _gen_save_and_restore_functions(checkpoint_factory_map):
431  """Generates global and individual save/restore concrete functions.
432
433  The global functions record the ops to save and restore the entire object to
434  a file prefix, while the individual functions save and restore value tensors
435  for resources.
436
437  This function is intended to run on the output of
438  `save_util_v1.get_checkpoint_factories_and_keys(object_names)`,
439  which returns the generated map of `_CheckpointFactoryData`.
440
441  Args:
442    checkpoint_factory_map: A dictionary mapping trackable objects to
443      a list of `_CheckpointFactoryData`.
444
445  Returns:
446    saveable_fn_map: An `ObjectIdentityDictionary` that maps
447      obj -> factory name -> (concrete save function, concrete restore
448      function).
449  """
450  # Maps obj -> factory attribute_name -> (concrete save, concrete restore)
452  saveable_fn_map = object_identity.ObjectIdentityDictionary()
453
454  for obj, factory_data_list in checkpoint_factory_map.items():
455    if resource_variable_ops.is_resource_variable(obj) or not factory_data_list:
456      # There is no need to trace the save and restore functions for variables.
457      continue
458
459    if factory_data_list[0].name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
460      # Trace Trackable save and restore functions.
461      assert len(factory_data_list) == 1
462      saveable_fn_map[obj] = {trackable_utils.SERIALIZE_TO_TENSORS_NAME: (
463          tracing_utils.trace_save_and_restore(obj))}
464    else:
465      # Trace deprecated SaveableObject save and restore functions.
466      saveable_fn_map[obj] = (
467          saveable_object_util.trace_save_restore_function_map(
468              obj, factory_data_list))
469  return saveable_fn_map
470
471
472def _tensor_dict_to_tensorinfo(tensor_dict):
473  return {
474      key: utils_impl.build_tensor_info_internal(value)
475      for key, value in tensor_dict.items()
476  }
477
478
479def _to_safe_name_scope(signature_key, user_input_name):
480  """Creates a sanitized name scope from user signature and input names.
481
482  Concatenates signature and input names, sanitizing as needed to be a valid
483  scope name.
484
485  Args:
486    signature_key: The user-provided key for the signature.
487    user_input_name: The user-provided name for the input placeholder.
488
489  Returns:
490    A name scope that is safe to be used in tf.name_scope().
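
  For example (an illustrative sketch of the sanitization above):

    _to_safe_name_scope("predict", "x")        # -> "predict_x"
    _to_safe_name_scope("predict", "foo bar")  # -> "predict_foo_bar"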
491  """
492  name_scope = "{}_{}".format(signature_key, user_input_name)
493  if re.match(r"^[A-Za-z0-9.][A-Za-z0-9_.\\-]*$", name_scope):
494    return name_scope
495  invalid_prefix_stripped = re.sub(r"^[^A-Za-z0-9.]*", "", name_scope)
496  return re.sub(r"[^A-Za-z0-9_.\\-]", "_", invalid_prefix_stripped)
497
498
499def _map_function_arguments_to_created_inputs(function_arguments, signature_key,
500                                              function_name):
501  """Creates exterior placeholders in the exported graph for function arguments.
502
503  Functions have two types of inputs: tensors captured from the outside (eager)
504  context, and arguments to the function which we expect to receive from the
505  user at each call. `_map_captures_to_created_tensors` replaces
506  captured tensors with stand-ins (typically these are resource dtype tensors
507  associated with variables). `_map_function_arguments_to_created_inputs` runs
508  over every argument, creating a new placeholder for each which will belong
509  to the exported graph rather than the function body.
510
511  Args:
512    function_arguments: A list of argument placeholders in the function body.
513    signature_key: The name of the signature being exported, for error messages.
514    function_name: The name of the function, for error messages.
515
516  Returns:
517    A tuple of (mapped_inputs, exterior_placeholders)
518      mapped_inputs: A list with entries corresponding to `function_arguments`
519        containing all of the inputs of the function gathered from the exported
520        graph (both captured resources and arguments).
521      exterior_argument_placeholders: A dictionary mapping from argument names
522        to placeholders in the exported graph, containing the explicit arguments
523        to the function which a user is expected to provide.
524
525  Raises:
526    ValueError: If argument names are not unique.
527  """
528  # `exterior_argument_placeholders` holds placeholders which are outside the
529  # function body, directly contained in a MetaGraph of the SavedModel. The
530  # function body itself contains nearly identical placeholders used when
531  # running the function, but these exterior placeholders allow Session-based
532  # APIs to call the function using feeds and fetches which name Tensors in the
533  # MetaGraph.
534  exterior_argument_placeholders = {}
535  mapped_inputs = []
536  for placeholder in function_arguments:
537    # `export_captures` contains an exhaustive set of captures, so if we don't
538    # find the input there then we now know we have an argument.
539    user_input_name = compat.as_str_any(
540        placeholder.op.get_attr("_user_specified_name"))
541    # If the internal placeholders for a function have names which were
542    # uniquified by TensorFlow, then a single user-specified argument name
543    # must refer to multiple Tensors. The resulting signatures would be
544    # confusing to call. Instead, we throw an exception telling the user to
545    # specify explicit names.
546    if user_input_name != placeholder.op.name:
547      # This should be unreachable, since concrete functions may not be
548      # generated with non-unique argument names.
549      raise ValueError(
550          "Got non-flat/non-unique argument names for SavedModel signature "
551          f"'{signature_key}': more than one argument to "
552          f"'{compat.as_str_any(function_name)}' was named "
553          f"'{user_input_name}'. "
554          "Signatures have one Tensor per named input, so to have "
555          "predictable names Python functions used to generate these "
556          "signatures should avoid *args and Tensors in nested "
557          "structures unless unique names are specified for each. Use "
558          "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
559          "input.")
560    arg_placeholder = array_ops.placeholder(
561        shape=placeholder.shape,
562        dtype=placeholder.dtype,
563        name=_to_safe_name_scope(signature_key, user_input_name))
564    exterior_argument_placeholders[user_input_name] = arg_placeholder
565    mapped_inputs.append(arg_placeholder)
566  return mapped_inputs, exterior_argument_placeholders
567
568
569def _generate_signatures(signature_functions, object_map):
570  """Validates and calls `signature_functions` in the exported graph.
571
572  Args:
573    signature_functions: A dictionary mapping string keys to concrete TensorFlow
574      functions (e.g. from `signature_serialization.canonicalize_signatures`)
575      which will be used to generate SignatureDefs.
576    object_map: A dictionary that contains mappings from signature functions to
577      concrete functions in the exported graph.
578
579  Returns:
580    Each function in the `signature_functions` dictionary is called with
581    placeholder Tensors, generating a function call operation and output
582    Tensors. The placeholder Tensors, the function call operation, and the
583    output Tensors from the function call are part of the default Graph.
584
585    This function then returns a dictionary with the same structure as
586    `signature_functions`, with the concrete functions replaced by SignatureDefs
587    implicitly containing information about how to call each function from a
588    TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference
589    the generated placeholders and Tensor outputs by name.
590
591    The caller is expected to include the default Graph set while calling this
592    function as a MetaGraph in a SavedModel, including the returned
593    SignatureDefs as part of that MetaGraph.
594  """
595  signatures = {}
596  for signature_key, function in sorted(signature_functions.items()):
597    if function.graph.captures:
598      argument_inputs = function.graph.inputs[:-len(function.graph.captures)]
599    else:
600      argument_inputs = function.graph.inputs
601    mapped_inputs, exterior_argument_placeholders = (
602        _map_function_arguments_to_created_inputs(argument_inputs,
603                                                  signature_key, function.name))
604    outputs = object_map[function](*mapped_inputs)
605    signatures[signature_key] = signature_def_utils.build_signature_def(
606        _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
607        _tensor_dict_to_tensorinfo(outputs),
608        method_name=signature_constants.PREDICT_METHOD_NAME)
609  return signatures
610
611
612_AssetInfo = collections.namedtuple(
613    "_AssetInfo",
614    [
615        # List of AssetFileDef protocol buffers
616        "asset_defs",
617        # Map from asset variable resource Tensors to their init ops
618        "asset_initializers_by_resource",
619        # Map from base asset filenames to full paths
620        "asset_filename_map",
621        # Map from Asset to index of corresponding AssetFileDef
622        "asset_index"
623    ])
624
625
626def _add_asset_info(trackable_asset, asset_info, mapped_path_variable):
627  """Add `trackable_asset` to `asset_info`."""
628  original_path_tensor = trackable_asset.asset_path
629  original_path = tensor_util.constant_value(original_path_tensor)
630  try:
631    original_path = str(original_path.astype(str))
632  except AttributeError:
633    # Already a string rather than a numpy array
634    pass
635
636  path = builder_impl.get_asset_filename_to_add(
637      asset_filepath=original_path,
638      asset_filename_map=asset_info.asset_filename_map)
639  asset_info.asset_filename_map[path] = original_path
640  asset_def = meta_graph_pb2.AssetFileDef()
641  asset_def.filename = path
642  asset_def.tensor_info.name = mapped_path_variable.initial_value.name
643  asset_info.asset_defs.append(asset_def)
644  asset_info.asset_initializers_by_resource[original_path_tensor] = (
645      mapped_path_variable.initializer)
646  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
647
648
649def _iterate_op_types(fn):
650  """Iterates through each op in the function and returns the op type and op."""
651  if isinstance(fn, framework_fn._DefinedFunction):  # pylint: disable=protected-access
652    for node in fn.definition.node_def:
653      op_type = node.attr["_gradient_op_type"].s
654      if op_type:
655        raise ValueError(
656            "Unable to save gradient functions when exporting a "
657            "_DefinedFunction (generally created through graph freezing utils "
658            "or through V1 graph importers). Please save with "
659            "`options=tf.SaveOptions(experimental_custom_gradients=False)`")
660  else:
661    for op in fn.graph.get_operations():
662      try:
663        op_type = op.get_attr("_gradient_op_type")
664      except ValueError:
665        continue
666      yield op_type, op
667
668
669def _get_outer_most_capture(fn, capture, func_graph_map):
670  """Tries to find the original captured tensor if captured more than once."""
671  outer_fn = fn
672  while outer_fn is not None and not isinstance(capture, ops.EagerTensor):
673    if capture.graph is not outer_fn.graph:
674      outer_fn = func_graph_map.get(outer_fn.graph.outer_graph)
675    else:
676      try:
677        capture_index = outer_fn.graph.internal_captures.index(capture)
678      except ValueError:
679        break  # Capture is a tensor inside function, and not captured from
680        # another external function
681      capture = outer_fn.graph.external_captures[capture_index]
682      outer_fn = func_graph_map.get(outer_fn.graph.outer_graph)
683  return outer_fn, capture
684
685
686def _trace_gradient_functions(graph, saveable_view):
687  """Traces gradient functions and records them in the SaveableView."""
688  functions = list(graph._functions.values())  # pylint: disable=protected-access
689  func_graph_map = {f.graph: f for f in functions if hasattr(f, "graph")}
690  seen_op_types = set()
691
692  for fn in functions:
693    for op_type, op in _iterate_op_types(fn):
694      if op_type in seen_op_types:
695        continue
696      seen_op_types.add(op_type)
697
698      try:
699        custom_gradient = ops.gradient_registry.lookup(op_type)
700      except LookupError:
701        continue
702
703      try:
704        grad_fn = (
705            def_function.function(custom_gradient).get_concrete_function(
706                None, *op.inputs))
707      except Exception as exc:
708        traceback.print_exc()
709        raise ValueError(
710            "Error when tracing gradients for SavedModel.\n\n"
711            "Check the error log to see the error that was raised when "
712            "converting a gradient function to a concrete function. You may "
713            "need to update the custom gradient, or disable saving gradients "
714            "with the option "
715            "tf.saved_model.SaveOptions(experimental_custom_gradients=False)"
716            f".\n\tProblematic op name: {op.name}\n\tGradient inputs: "
717            f"{op.inputs}") from exc
718
719      # The gradient function will capture all intermediate values. These
720      # captures must be serialized so that they can be re-bound to the
721      # function when loading.
722      bad_captures = []
723      for capture in grad_fn.captured_inputs:
724        if capture.dtype in _UNCOPIABLE_DTYPES:
725          continue
726        # Tries to find the outermost capture in case the tensor is a constant
727        # or not actually captured in the current function (this could happen if
728        # the function is a while loop body, in which case the captured input
729        # is not the internal captured tensor).
730        outer_fn, outer_capture = _get_outer_most_capture(
731            fn, capture, func_graph_map)
732        if outer_fn is None or isinstance(outer_capture, ops.EagerTensor):
733          if outer_capture not in saveable_view.captured_tensor_node_ids:
734            raise ValueError(f"Found invalid capture {outer_capture} when "
735                             "saving custom gradients.")
736          saveable_view.captured_tensor_node_ids[capture] = (
737              saveable_view.captured_tensor_node_ids[outer_capture])
738        elif outer_capture.graph is outer_fn.graph:
739          capture_name = outer_capture.name
740          # It's possible for EagerDefinedFunctions to save different names for
741          # input tensors when serialized to FunctionDef (all non-alphanumeric
742          # characters are converted to '_').
743          if isinstance(outer_fn, defun._EagerDefinedFunction):  # pylint:disable=protected-access
744            try:
745              arg_index = outer_fn.graph.inputs.index(outer_capture)
746              capture_name = outer_fn.signature.input_arg[arg_index].name + ":0"
747            except ValueError:
748              pass
749
750          node = _CapturedTensor(capture_name, outer_fn.name)
751          saveable_view.add_capture_and_node(capture, node)
752        else:
753          bad_captures.append(capture.name)
754      if not bad_captures:
755        grad_fn.add_to_graph(graph)
756      else:
757        raise ValueError(
758            f"Cannot save custom gradient {op_type} called in function {fn} "
759            "because SavedModel is unable to serialize the captured "
760            f"inputs: {bad_captures}")
761
762      saveable_view.gradient_functions.append(grad_fn)
763      func_graph_map[grad_fn.graph] = grad_fn
764
765      grad_def = function_pb2.RegisteredGradient()
766      grad_def.gradient_func = grad_fn.name
767      grad_def.registered_op_type = op_type
768      saveable_view.gradient_defs.append(grad_def)
769
770
771def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
772                         namespace_whitelist, save_custom_gradients):
773  """Generates a MetaGraph which calls `signature_functions`.
774
775  Args:
776    meta_graph_def: The MetaGraphDef proto to fill.
777    saveable_view: The _SaveableView being exported.
778    signature_functions: A dictionary mapping signature keys to concrete
779      functions containing signatures to add to the MetaGraph.
780    namespace_whitelist: List of strings containing whitelisted op namespaces.
781    save_custom_gradients: Whether to save custom gradients.
782
783  Returns:
784    A tuple of (_AssetInfo, Graph) containing the captured assets and
785    exported Graph generated from tracing the saveable_view.
786  """
787  # List objects from the eager context to make sure Optimizers give us the
788  # right Graph-dependent variables.
789  resource_initializers = saveable_view.get_concrete_resource_initializers()
790  exported_graph = ops.Graph()
791  resource_initializer_ops = []
792  with exported_graph.as_default():
793    object_map, tensor_map, asset_info = saveable_view.map_resources()
794    signatures = _generate_signatures(signature_functions, object_map)
795    if save_custom_gradients:
796      _trace_gradient_functions(exported_graph, saveable_view)
797
798    # Create initializers for assets and resources.
799    for resource_initializer_function in resource_initializers:
800      asset_dependencies = []
801      for capture in resource_initializer_function.graph.external_captures:
802        asset_initializer = asset_info.asset_initializers_by_resource.get(
803            capture, None)
804        if asset_initializer is not None:
805          asset_dependencies.append(asset_initializer)
806      with ops.control_dependencies(asset_dependencies):
807        mapped_initializer = object_map[resource_initializer_function]
808        resource_initializer_ops.append(mapped_initializer())
809    resource_initializer_ops.extend(
810        asset_info.asset_initializers_by_resource.values())
811    with ops.control_dependencies(resource_initializer_ops):
812      init_op = control_flow_ops.no_op()
813    # Add the same op to the main_op collection and to the init_op
814    # signature. The collection is for compatibility with older loader APIs;
815    # only one will be executed.
816    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
817        init_op.name)
818    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
819        signature_def_utils.op_signature_def(init_op,
820                                             constants.INIT_OP_SIGNATURE_KEY))
821
822  # Saving an object-based checkpoint again gathers variables. We need to do the
823  # gathering from the eager context so Optimizers save the right set of
824  # variables, but want any operations associated with the save/restore to be in
825  # the exported graph (thus the `to_graph` argument).
826  def call_with_mapped_captures(function, args):
827    if function in object_map:
828      return object_map[function](*args)
829    # Registered saver/restore functions do not appear in `object_map`, because
830    # they are not in the object graph.
831    return function_saved_model_utils.ExportedConcreteFunction(
832        function, tensor_map)(*args)
833
834  for obj in object_map.values():
835    obj._maybe_initialize_trackable()  # pylint: disable=protected-access
836  named_saveable_objects, registered_savers = (
837      save_util_v1.frozen_saveables_and_savers(
838          graph_view=saveable_view.augmented_graph_view,
839          object_map=object_map,
840          to_graph=exported_graph,
841          call_with_mapped_captures=call_with_mapped_captures))
842  saver = functional_saver.MultiDeviceSaver(named_saveable_objects,
843                                            registered_savers,
844                                            call_with_mapped_captures)
845
846  with exported_graph.as_default():
847    saver_def = saver.to_proto()
848    meta_graph_def.saver_def.CopyFrom(saver_def)
849
850  # At this point all nodes that can be added to the SavedObjectGraph have been
851  # added, so run the following to validate deserialization dependencies.
852  _dependency_sorted_node_ids(saveable_view)
853
854  graph_def = exported_graph.as_graph_def(add_shapes=True)
855  graph_def.library.registered_gradients.extend(saveable_view.gradient_defs)
856  _verify_ops(graph_def, namespace_whitelist)
857
858  meta_graph_def.graph_def.CopyFrom(graph_def)
859  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
860  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
861  meta_graph_def.meta_info_def.tensorflow_git_version = (
862      versions.__git_version__)
863  # We currently always strip default attributes.
864  meta_graph_def.meta_info_def.stripped_default_attrs = True
865  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
866      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
867  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
868  for signature_key, signature in signatures.items():
869    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
870  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
871  # Store tensor_content in little-endian format.
872  if sys.byteorder == "big":
873    utils_impl.swap_function_tensor_content(meta_graph_def, "big", "little")
874  return asset_info, exported_graph
875
876
877def _verify_ops(graph_def, namespace_whitelist):
878  """Verifies that all namespaced ops in the graph are whitelisted.
879
880  Args:
881   graph_def: the GraphDef to validate.
882   namespace_whitelist: a list of namespaces to allow. If `None`, all will be
883     allowed. If an op does not have a namespace, it will be allowed.
884
885  Raises:
886   ValueError: If the graph contains ops that violate the whitelist.
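
  For example (an illustrative sketch with an invented op and namespace name):
  a graph containing an op of type "MyNamespace>MyOp" passes verification
  only when "MyNamespace" is whitelisted, e.g. when the model was saved with

    tf.saved_model.save(
        obj, export_dir,
        options=tf.saved_model.SaveOptions(
            namespace_whitelist=["MyNamespace"]))

  or with the default `namespace_whitelist=None`, which allows all namespaces.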
887  """
888  # By default, if the user has not specified a whitelist, we want to allow
889  # everything.  We check for None directly rather than falseness, since the
890  # user may instead want to pass an empty list to disallow all custom
891  # namespaced ops.
892  if namespace_whitelist is None:
893    return
894
895  invalid_ops = []
896  invalid_namespaces = set()
897
898  all_operations = []
899  all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def))
900
901  for op in all_operations:
902    if ">" in op:
903      namespace = op.split(">")[0]
904      if namespace not in namespace_whitelist:
905        invalid_ops.append(op)
906        invalid_namespaces.add(namespace)
907  if invalid_ops:
908    raise ValueError(
909        "Attempted to save ops from non-whitelisted namespaces to SavedModel: "
910        f"{invalid_ops}.\nPlease verify that these ops should be saved, since "
911        "they must be available when loading the SavedModel. If loading from "
912        "Python, you must import the library defining these ops. From C++, "
913        "link the custom ops to the serving binary. Once you've confirmed this,"
914        " add the following namespaces to the `namespace_whitelist` "
915        f"argument in tf.saved_model.SaveOptions: {invalid_namespaces}.")
916
917
918def _dependency_sorted_node_ids(saveable_view):
919  """Returns topologically sorted nodes, sorted by dependencies."""
920  dependency_map = {}
921  for node in saveable_view.nodes:
922    node_id = saveable_view.node_ids[node]
923    deps = dependency_map[node_id] = []
924    # TODO(kathywu): Remove once all of these have been converted to trackable.
925    if isinstance(node, _CapturedTensor):
926      continue  # These are not `Trackable` and therefore have no dependencies.
927    for _, dep in saveable_view.augmented_graph_view.list_dependencies(node):
928      if dep not in saveable_view.node_ids:
929        node_path = trackable_utils.pretty_print_node_path(
930            saveable_view.node_paths[node])
931        raise ValueError(
932            f"Found an untracked dependency. Object {node_path} depends "
933            f"on {dep}, but this dependency isn't listed as a child. "
934            "Please track this child by overriding `_trackable_children` "
935            "or use `._track_trackable`.")
936      deps.append(saveable_view.node_ids[dep])
937  try:
938    return trackable_utils.order_by_dependency(dependency_map)
939  except trackable_utils.CyclicDependencyError as err:
940    pretty_printed_nodes = []
941    pretty_printed_dependencies = []
942
943    for x, deps in err.leftover_dependency_map.items():
944      node_path = trackable_utils.pretty_print_node_path(
945          saveable_view.node_paths[saveable_view.nodes[x]])
946      pretty_printed_nodes.append(
947          f"\tNode {x} = {node_path} (type {type(saveable_view.nodes[x])})")
948      pretty_printed_dependencies.append(f"\tNode {x} depends on nodes {deps}")
949    pretty_printed_nodes = "\n".join(pretty_printed_nodes)
950    pretty_printed_dependencies = "\n".join(pretty_printed_dependencies)
951    raise ValueError(
952        "There are one or more dependency cycles in the saved Trackable "
953        "object. Saving cannot continue until these cycles are resolved."
954        f"\n>> Unresolved nodes:\n{pretty_printed_nodes}"
955        f"\n>> Unresolved cyclic dependencies:\n{pretty_printed_dependencies}")
956
957
958def _serialize_object_graph(saveable_view, asset_file_def_index):
959  """Save a SavedObjectGraph proto for `root`."""
960  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
961  # checkpoint. It will eventually go into the SavedModel.
962  proto = saved_object_graph_pb2.SavedObjectGraph()
963  saveable_view.fill_object_graph_proto(proto)
964
965  for concrete_function in saveable_view.concrete_and_gradient_functions:
966    name = compat.as_text(concrete_function.name)
967    serialized = function_serialization.serialize_concrete_function(
968        concrete_function, saveable_view.captured_tensor_node_ids)
969    if serialized is not None:
970      proto.concrete_functions[name].CopyFrom(serialized)
971
972  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
973    _write_object_proto(obj, obj_proto, asset_file_def_index,
974                        saveable_view.augmented_graph_view.list_children)
975  return proto
976
977
978def _write_object_proto(obj, proto, asset_file_def_index, list_children_fn):
979  """Saves an object into SavedObject proto."""
980  if isinstance(obj, asset.Asset):
981    proto.asset.SetInParent()
982    proto.asset.asset_file_def_index = asset_file_def_index[obj]
983  elif resource_variable_ops.is_resource_variable(obj):
984    options = save_context.get_save_options()
985    obj._write_object_proto(proto, options)  # pylint: disable=protected-access
986  elif isinstance(obj, def_function.Function):
987    proto.function.CopyFrom(
988        function_serialization.serialize_function(
989            obj, [x.ref for x in list_children_fn(obj)]))
990  elif isinstance(obj, defun.ConcreteFunction):
991    proto.bare_concrete_function.CopyFrom(
992        function_serialization.serialize_bare_concrete_function(obj))
993  elif isinstance(obj, _CapturedTensor):
994    proto.captured_tensor.name = obj.name
995    proto.captured_tensor.concrete_function = obj.concrete_function
996  elif isinstance(obj, resource.CapturableResource):
997    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
998  else:
999    registered_type_proto = revived_types.serialize(obj)
1000    if registered_type_proto is None:
1001      # Fallback for types with no matching registration
1002      # pylint:disable=protected-access
1003      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
1004          identifier=obj._object_identifier,
1005          version=versions_pb2.VersionDef(
1006              producer=1, min_consumer=1, bad_consumers=[]))
1007      # pylint:enable=protected-access
1008    proto.user_object.CopyFrom(registered_type_proto)
1009
1010  registered_name = registration.get_registered_class_name(obj)
1011  if registered_name:
1012    proto.registered_name = registered_name
1013    serialized_user_proto = obj._serialize_to_proto(object_proto=proto)  # pylint: disable=protected-access
1014    if serialized_user_proto is not None:
1015      proto.serialized_user_proto.Pack(serialized_user_proto)
1016
1017
1018def _export_debug_info(exported_graph, export_dir):
1019  """Exports debug information from graph to file.
1020
1021  Creates and writes GraphDebugInfo with traces for ops in all functions of the
1022  exported_graph.
1023
1024  Args:
1025    exported_graph: A Graph that has been created by tracing a saveable view.
1026    export_dir: SavedModel directory in which to write the debug info.
1027  """
1028  exported_operations = []
1029  for fn_name in exported_graph._functions:  # pylint: disable=protected-access
1030    fn = exported_graph._get_function(fn_name)  # pylint: disable=protected-access
1031    if not isinstance(fn, defun._EagerDefinedFunction):  # pylint: disable=protected-access
1032      continue
1033
1034    fn_graph = fn.graph
1035    for fn_op in fn_graph.get_operations():
1036      exported_operations.append((fn_name, fn_op))
1037
1038  graph_debug_info = error_interpolation.create_graph_debug_info_def(
1039      exported_operations)
1040  file_io.atomic_write_string_to_file(
1041      file_io.join(
1042          utils_impl.get_or_create_debug_dir(export_dir),
1043          constants.DEBUG_INFO_FILENAME_PB),
1044      graph_debug_info.SerializeToString(deterministic=True))
1045
1046
1047@tf_export(
1048    "saved_model.save",
1049    v1=["saved_model.save", "saved_model.experimental.save"])
1050def save(obj, export_dir, signatures=None, options=None):
1051  # pylint: disable=line-too-long
1052  """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk).
1053
1054  The `obj` must inherit from the [`Trackable`
1055  class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591).
1056
1057  Example usage:
1058
1059  >>> class Adder(tf.Module):
1060  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
1061  ...   def add(self, x):
1062  ...     return x + x
1063
1064  >>> model = Adder()
1065  >>> tf.saved_model.save(model, '/tmp/adder')
1066
1067  The resulting SavedModel is then servable with an input named "x", a scalar
1068  with dtype float32.
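
  The saved signature can then be loaded back and invoked by keyword argument
  (a minimal sketch; `serving_default` is the default signature key):

  >>> loaded = tf.saved_model.load('/tmp/adder')
  >>> list(loaded.signatures.keys())
  ['serving_default']
  >>> result = loaded.signatures['serving_default'](x=tf.constant(3.0))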
1069
1070  _Signatures_
1071
1072  Signatures define the input and output types for a computation. The optional
1073  save `signatures` argument controls which methods in `obj` will be
1074  available to programs which consume `SavedModel`s, for example, serving
1075  APIs. Python functions may be decorated with
1076  `@tf.function(input_signature=...)` and passed as signatures directly, or
1077  lazily with a call to `get_concrete_function` on the method decorated with
1078  `@tf.function`.
1079
1080  Example:
1081
1082  >>> class Adder(tf.Module):
1083  ...   @tf.function
1084  ...   def add(self, x):
1085  ...     return x + x
1086
1087  >>> model = Adder()
1088  >>> tf.saved_model.save(
1089  ...   model, '/tmp/adder',signatures=model.add.get_concrete_function(
1090  ...     tf.TensorSpec([], tf.float32)))
1091
1092  If a `@tf.function` does not have an input signature and
1093  `get_concrete_function` is not called on that method, the function will not
1094  be directly callable in the restored SavedModel.
1095
1096  Example:
1097
1098  >>> class Adder(tf.Module):
1099  ...   @tf.function
1100  ...   def add(self, x):
1101  ...     return x + x
1102
1103  >>> model = Adder()
1104  >>> tf.saved_model.save(model, '/tmp/adder')
1105  >>> restored = tf.saved_model.load('/tmp/adder')
1106  >>> restored.add(1.)
1107  Traceback (most recent call last):
1108  ...
1109  ValueError: Found zero restored functions for caller function.
1110
1111  If the `signatures` argument is omitted, `obj` will be searched for
1112  `@tf.function`-decorated methods. If exactly one traced `@tf.function` is
1113  found, that method will be used as the default signature for the SavedModel.
1114  Otherwise, any `@tf.function` attached to `obj` or its dependencies will be
1115  exported for use with `tf.saved_model.load`.
1116
1117  When invoking a signature in an exported SavedModel, `Tensor` arguments are
1118  identified by name. These names will come from the Python function's argument
1119  names by default. They may be overridden by specifying a `name=...` argument
1120  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
1121  multiple `Tensor`s are passed through a single argument to the Python
1122  function.
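
  For example (an illustrative sketch; `scalar_input` is an arbitrary name):

  >>> class Adder(tf.Module):
  ...   @tf.function(input_signature=[
  ...       tf.TensorSpec(shape=[], dtype=tf.float32, name="scalar_input")])
  ...   def add(self, x):
  ...     return x + x

  >>> tf.saved_model.save(Adder(), '/tmp/adder_named')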
1123
1124  The outputs of functions used as `signatures` must either be flat lists, in
1125  which case outputs will be numbered, or a dictionary mapping string keys to
1126  `Tensor`, in which case the keys will be used to name outputs.
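
  For example (an illustrative sketch; the dictionary keys are user-chosen):

  >>> class Splitter(tf.Module):
  ...   @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  ...   def split(self, x):
  ...     return {"half": x / 2., "double": x * 2.}

  >>> tf.saved_model.save(Splitter(), '/tmp/splitter')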
1127
1128  Signatures are available in objects returned by `tf.saved_model.load` as a
1129  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
1130  on an object with a custom `.signatures` attribute will raise an exception.
1131
1132  _Using `tf.saved_model.save` with Keras models_
1133
1134  While Keras has its own [saving and loading
1135  API](https://www.tensorflow.org/guide/keras/save_and_serialize),
1136  this function can be used to export Keras models. For example, exporting with
1137  a signature specified:
1138
1139  >>> class Adder(tf.keras.Model):
1140  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
1141  ...   def concat(self, x):
1142  ...      return x + x
1143
1144  >>> model = Adder()
1145  >>> tf.saved_model.save(model, '/tmp/adder')
1146
1147  Exporting from a function without a fixed signature:
1148
1149  >>> class Adder(tf.keras.Model):
1150  ...   @tf.function
1151  ...   def concat(self, x):
1152  ...      return x + x
1153
1154  >>> model = Adder()
1155  >>> tf.saved_model.save(
1156  ...   model, '/tmp/adder',
1157  ...   signatures=model.concat.get_concrete_function(
1158  ...     tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))
1159
1160  `tf.keras.Model` instances constructed from inputs and outputs already have a
1161  signature and so do not require a `@tf.function` decorator or a `signatures`
1162  argument. If neither is specified, the model's forward pass is exported.
1163
1164  >>> x = tf.keras.layers.Input((4,), name="x")
1165  >>> y = tf.keras.layers.Dense(5, name="out")(x)
1166  >>> model = tf.keras.Model(x, y)
1167  >>> tf.saved_model.save(model, '/tmp/saved_model/')
1168
1169  The exported SavedModel takes "x" with shape [None, 4] and returns "out"
1170  with shape [None, 5].
1171
1172  _Variables and Checkpoints_
1173
1174  Variables must be tracked by assigning them to an attribute of a tracked
1175  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
1176  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
1177  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
1178  uses, and an exported `Checkpoint` object may be restored as a training
1179  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
1180  "variables/" subdirectory.
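
  For example (a minimal sketch; the paths and attribute names are
  illustrative):

  >>> m = tf.Module()
  >>> m.v = tf.Variable(3.)
  >>> tf.saved_model.save(m, '/tmp/saved_module')
  >>> ckpt = tf.train.Checkpoint(v=tf.Variable(0.))
  >>> status = ckpt.restore('/tmp/saved_module/variables/variables')
  >>> assert ckpt.v.numpy() == 3.0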
1181
1182  `tf.function` does not hard-code device annotations from outside the function
1183  body, instead using the calling context's device. This means, for example,
1184  that exporting a model that runs on a GPU and serving it on a CPU will
1185  generally work, with some exceptions:
1186
1187    * `tf.device` annotations inside the body of the function will be hard-coded
1188      in the exported model; this type of annotation is discouraged.
1189    * Device-specific operations, e.g. with "cuDNN" in the name or with
1190      device-specific layouts, may cause issues.
1191    * For `ConcreteFunctions`, active distribution strategies will cause device
1192      placements to be hard-coded in the function.
1193
1194  SavedModels exported with `tf.saved_model.save` [strip default-valued
1195  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
1196  automatically, which removes one source of incompatibilities when the consumer
1197  of a SavedModel is running an older TensorFlow version than the
1198  producer. There are however other sources of incompatibilities which are not
1199  handled automatically, such as when the exported model contains operations
1200  which the consumer does not have definitions for.
1201
1202  Args:
1203    obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export.
1204    export_dir: A directory in which to write the SavedModel.
1205    signatures: Optional, one of three types:
1206      * a `tf.function` with an input signature specified, which will use the
1207        default serving signature key,
1208      * the result of `f.get_concrete_function` on a `@tf.function`-decorated
1209        function `f`, in which case `f` will be used to generate a signature
1210        for the SavedModel under the default serving signature key,
1211      * a dictionary mapping signature keys to either `tf.function` instances
1212        with input signatures or concrete functions. Keys may be arbitrary
1213        strings but are typically from `tf.saved_model.signature_constants`.
1214    options: `tf.saved_model.SaveOptions` object for configuring save options.
1215
1216  Raises:
1217    ValueError: If `obj` is not trackable.
1218
1219  @compatibility(eager)
1220  Not well supported when graph building. From TensorFlow 1.x,
1221  `tf.compat.v1.enable_eager_execution()` should run first. Calling
1222  `tf.saved_model.save` in a loop when graph building from TensorFlow 1.x will
1223  add new save operations to the default graph each iteration.
1224
1225  May not be called from within a function body.
1226  @end_compatibility
1227  """
1228  if isinstance(export_dir, os.PathLike):
1229    export_dir = os.fspath(export_dir)
1230  # pylint: enable=line-too-long
1231  metrics.IncrementWriteApi(_SAVE_V2_LABEL)
1232  save_and_return_nodes(obj, export_dir, signatures, options)
1233
1234  metrics.IncrementWrite(write_version="2")
1235
1236
1237def save_and_return_nodes(obj,
1238                          export_dir,
1239                          signatures=None,
1240                          options=None,
1241                          experimental_skip_checkpoint=False):
1242  """Saves a SavedModel while returning all saved nodes and their paths.
1243
1244  Please see `tf.saved_model.save` for details.
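
  A minimal usage sketch (the module and export path here are hypothetical):

  >>> module = tf.Module()
  >>> module.v = tf.Variable(1.)
  >>> nodes, node_paths = save_and_return_nodes(module, '/tmp/module_with_v')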
1245
1246  Args:
1247    obj: A trackable object to export.
1248    export_dir: A directory in which to write the SavedModel.
1249    signatures: A function or dictionary of functions to save in the SavedModel
1250      as signatures.
1251    options: `tf.saved_model.SaveOptions` object for configuring save options.
1252    experimental_skip_checkpoint: If set to `True`, the checkpoint will not be
1253      written.
1254
1255  Returns:
1256    A tuple of (a list of saved nodes in the order they are serialized to the
1257      `SavedObjectGraph`, a dictionary mapping each node to one possible path
1258      from the root node to that node).
1259  """
1260  options = options or save_options.SaveOptions()
1261  # TODO(b/205008509): Factor out some subset of SavedModelBuilder which is 2.x
1262  # compatible (no sessions) and share it with this export API rather than
1263  # making a SavedModel proto and writing it directly.
1264  saved_model = saved_model_pb2.SavedModel()
1265  meta_graph_def = saved_model.meta_graphs.add()
1266
1267  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
1268      _build_meta_graph(obj, signatures, options, meta_graph_def))
1269  saved_model.saved_model_schema_version = (
1270      constants.SAVED_MODEL_SCHEMA_VERSION)
1271
1272  # Write the checkpoint, copy assets into the assets directory, and write out
1273  # the SavedModel proto itself.
1274  if not experimental_skip_checkpoint:
1275    utils_impl.get_or_create_variables_dir(export_dir)
1276    ckpt_options = checkpoint_options.CheckpointOptions(
1277        experimental_io_device=options.experimental_io_device)
1278    object_saver.save(
1279        utils_impl.get_variables_path(export_dir), options=ckpt_options)
1280  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
1281                                              export_dir)
1282  # Note that this needs to be the last file operation when saving the
1283  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
1284  # indication that the SavedModel is completely written.
1285  if context.executing_eagerly():
1286    try:
1287      context.async_wait()  # Ensure save operations have completed.
1288    except errors.NotFoundError as err:
1289      raise FileNotFoundError(
1290          f"{err}\n You may be trying to save on a different device from the "
1291          "computational device. Consider setting the "
1292          "`experimental_io_device` option in `tf.saved_model.SaveOptions` "
1293          "to the io_device such as '/job:localhost'.")
1294
1295  # We will slowly migrate code in this function to pywrap_saved_model.Save
1296  # as we build up the C++ API.
1297  pywrap_saved_model.Save(export_dir)
1298
1299  saved_model_serialized = saved_model.SerializeToString(deterministic=True)
1300
1301  # Write fingerprint protobuf, if requested.
1302  if flags.config().saved_model_fingerprinting.value():
1303    fingerprint_path = file_io.join(
1304        compat.as_str(export_dir),
1305        compat.as_str(constants.FINGERPRINT_FILENAME))
1306    fingerprint_proto = fingerprinting.CreateFingerprintDef(
1307        saved_model_serialized, export_dir)
1308    file_io.atomic_write_string_to_file(fingerprint_path, fingerprint_proto)
1309
1310  path = file_io.join(
1311      compat.as_str(export_dir),
1312      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
1313  file_io.atomic_write_string_to_file(
1314      path, saved_model.SerializeToString(deterministic=True))
1315
1316  # Save debug info, if requested.
1317  if options.save_debug_info:
1318    _export_debug_info(exported_graph, export_dir)
1319  # Clean reference cycles so repeated export()s don't make work for the garbage
1320  # collector. Before this point, we need to keep references to captured
1321  # constants in the saved graph.
1322  ops.dismantle_graph(exported_graph)
1323
1324  return saved_nodes, node_paths
1325
1326
1327def export_meta_graph(obj, filename, signatures=None, options=None):
1328  """Exports the MetaGraph proto of the `obj` to a file.
1329
1330  This function goes through the same procedures that `saved_model.save` goes
1331  through to produce the given object's MetaGraph, then saves it to the given
1332  file. It skips saving checkpoint information, and is useful when all one
1333  wants is the graph defining the model.
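
  A minimal usage sketch (the module, function, and output path here are
  hypothetical):

  >>> module = tf.Module()
  >>> module.double = tf.function(
  ...     lambda x: x * 2.,
  ...     input_signature=[tf.TensorSpec([], tf.float32)])
  >>> export_meta_graph(module, '/tmp/double_meta_graph.pb')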
1334
1335  Args:
1336    obj: A trackable object to build the MetaGraph from.
1337    filename: The file into which to write the MetaGraph.
1338    signatures: Optional, either a `tf.function` with an input signature
1339      specified or the result of `f.get_concrete_function` on a
1340      `@tf.function`-decorated function `f`, in which case `f` will be used to
1341      generate a signature for the SavedModel under the default serving
1342      signature key. `signatures` may also be a dictionary, in which case it
1343      maps from signature keys to either `tf.function` instances with input
1344      signatures or concrete functions. The keys of such a dictionary may be
1345      arbitrary strings, but will typically be from the
1346      `tf.saved_model.signature_constants` module.
1347    options: Optional, `tf.saved_model.SaveOptions` object that specifies
1348      options for saving.
1349  """
1350  options = options or save_options.SaveOptions()
1351  export_dir = os.path.dirname(filename)
1352  meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
1353      obj, signatures, options)
1354
1355  file_io.atomic_write_string_to_file(
1356      filename, meta_graph_def.SerializeToString(deterministic=True))
1357
1358  # Save debug info, if requested.
1359  if options.save_debug_info:
1360    _export_debug_info(exported_graph, export_dir)
1361
1362  # Clean reference cycles so repeated export()s don't make work for the garbage
1363  # collector. Before this point, we need to keep references to captured
1364  # constants in the saved graph.
1365  ops.dismantle_graph(exported_graph)
1366
1367
1368def _build_meta_graph_impl(obj, signatures, options, meta_graph_def=None):
1369  """Creates a MetaGraph containing the resources and functions of an object."""
1370  if ops.inside_function():
1371    raise AssertionError(
1372        "`tf.saved_model.save` is not supported inside a traced @tf.function. "
1373        "Move the call to the outer eagerly-executed context.")
1374  # pylint: enable=line-too-long
1375  if not isinstance(obj, base.Trackable):
1376    raise ValueError(
1377        "Expected an object of type `Trackable`, such as `tf.Module` or a "
1378        f"subclass of the `Trackable` class, for export. Got {obj} "
1379        f"with type {type(obj)}.")
1380  meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
1381
1382  augmented_graph_view = _AugmentedGraphView(obj)
1383  if signatures is None:
1384    signatures = signature_serialization.find_function_to_export(
1385        augmented_graph_view)
1386
1387  signatures, wrapped_functions = (
1388      signature_serialization.canonicalize_signatures(signatures))
1389  signature_serialization.validate_augmented_graph_view(augmented_graph_view)
1390  signature_map = signature_serialization.create_signature_map(signatures)
1391  augmented_graph_view.set_signature(signature_map, wrapped_functions)
1392
1393  # Use _SaveableView to provide a frozen listing of properties and functions.
1394  saveable_view = _SaveableView(augmented_graph_view, options)
1395  object_saver = checkpoint.TrackableSaver(augmented_graph_view)
1396  asset_info, exported_graph = _fill_meta_graph_def(
1397      meta_graph_def, saveable_view, signatures, options.namespace_whitelist,
1398      options.experimental_custom_gradients)
1399  if options.function_aliases:
1400    function_aliases = meta_graph_def.meta_info_def.function_aliases
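    # Expand each user-provided alias to every concrete function traced from
    # the aliased tf.function, recording the mapping in meta_info_def.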
1401    for alias, func in options.function_aliases.items():
1402      for fdef in func._list_all_concrete_functions():  # pylint: disable=protected-access
1403        function_aliases[fdef.name] = alias
1404
1405  object_graph_proto = _serialize_object_graph(saveable_view,
1406                                               asset_info.asset_index)
1407  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
1408
1409  return (meta_graph_def, exported_graph, object_saver, asset_info,
1410          saveable_view.nodes, saveable_view.node_paths)
1411
1412
1413def _build_meta_graph(obj, signatures, options, meta_graph_def=None):
1414  """Creates a MetaGraph under a save context.
1415
1416  Args:
1417    obj: A trackable object to build the MetaGraph from.
1418    signatures: Can be a `tf.function` with an input signature specified or the
1419      result of `f.get_concrete_function` on a `@tf.function`-decorated function
1420      `f`. `signatures` may also be a dictionary, in which case it maps from
1421      signature keys to `tf.function` instances. If None, the signature to
1422      export is found among the `@tf.function`-decorated methods of `obj`.
1423    options: `tf.saved_model.SaveOptions` object that specifies options for
1424      saving.
1425    meta_graph_def: Optional, the MetaGraphDef proto to fill.
1426
1427  Raises:
1428    AssertionError: If the export is invoked from inside a `tf.function`.
1429    ValueError: If `obj` is not trackable.
1430
1431  Returns:
1432    meta_graph_def: Filled MetaGraphDef proto
1433    exported_graph: `tf.Graph` object generated from `obj`.
1434    object_saver: `checkpoint.TrackableSaver` of the `obj` and its dependencies.
1435    asset_info: `_AssetInfo` tuple containing external assets in the `obj`.
1436    saveable_view.nodes: _SaveableView nodes.
1437    saveable_view.node_paths: _SaveableView paths.
1438  """
1439
1440  with save_context.save_context(options):
1441    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
1442