xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/load_v1_in_v2.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Import a TF v1-style SavedModel when executing eagerly."""
16
17import functools
18
19from tensorflow.python.eager import context
20from tensorflow.python.eager import lift_to_graph
21from tensorflow.python.eager import wrap_function
22from tensorflow.python.framework import composite_tensor
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import func_graph
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.saved_model import function_deserialization
31from tensorflow.python.saved_model import loader_impl
32from tensorflow.python.saved_model import signature_serialization
33from tensorflow.python.saved_model.pywrap_saved_model import metrics
34from tensorflow.python.trackable import asset
35from tensorflow.python.trackable import autotrackable
36from tensorflow.python.trackable import resource
37from tensorflow.python.training import monitored_session
38from tensorflow.python.training import saver as tf_saver
39from tensorflow.python.util import nest
40
# Label passed to the SavedModel read-API metrics whenever this module is used
# to load a TF 1.x SavedModel.
_LOAD_V1_V2_LABEL: str = "load_v1_in_v2"
43
44
class _Initializer(resource.CapturableResource):
  """Holds the initialization procedure of an imported TF 1.x SavedModel.

  `tf.saved_model.load` builds one of these when a TF 1.x-style SavedModel
  carries an initialization op. The object wraps a function that performs
  that initialization, so re-exporting the loaded object keeps the original
  SavedModel's initialization procedure. Users never need to touch it:
  `tf.saved_model.save` picks it up and adds it to the exported SavedModel,
  and `tf.saved_model.load` runs the initialization function automatically.
  """

  def __init__(self, init_fn, asset_paths):
    """Stores the initialization function and the assets that feed it.

    Args:
      init_fn: Callable invoked with one positional argument per entry of
        `asset_paths`, in the same order.
      asset_paths: Sequence of objects exposing an `asset_path` attribute.
    """
    super().__init__()
    self._asset_paths = asset_paths
    self._init_fn = init_fn

  def _create_resource(self):
    # Returns a scalar resource placeholder named "unused_resource"; the
    # handle itself is not used by this class.
    return array_ops.placeholder(
        name="unused_resource", shape=[], dtype=dtypes.resource)

  def _initialize(self):
    # Resolve each asset to its path and pass them positionally to init_fn.
    resolved_paths = [entry.asset_path for entry in self._asset_paths]
    return self._init_fn(*resolved_paths)
70
71
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
  """Loads a SavedModel without using Sessions.

  The selected MetaGraph is imported inside a graph built by
  `wrap_function.wrap_function`. Variable restoration, initialization and
  signatures are then extracted from that graph as pruned functions via
  `wrapped.prune`.
  """

  def get_meta_graph_def_from_tags(self, tags):
    """Override to support implicit one-MetaGraph loading with tags=None.

    Args:
      tags: Tag set identifying the MetaGraph, or None.

    Returns:
      The selected MetaGraphDef. With `tags=None` this is the only MetaGraph
      in the SavedModel; otherwise selection is delegated to the base class.

    Raises:
      ValueError: If `tags` is None and the SavedModel does not contain
        exactly one MetaGraph.
    """
    if tags is None:
      if len(self._saved_model.meta_graphs) != 1:
        tag_sets = [mg.meta_info_def.tags
                    for mg in self._saved_model.meta_graphs]
        raise ValueError(
            "Importing a SavedModel with `tf.saved_model.load` requires a "
            "`tags=` argument if there is more than one MetaGraph. Got "
            f"`tags=None`, but there are {len(self._saved_model.meta_graphs)} "
            f"MetaGraphs in the SavedModel with tag sets: {tag_sets}. Pass a "
            "`tags=` argument to load this SavedModel.")
      return self._saved_model.meta_graphs[0]
    return super().get_meta_graph_def_from_tags(tags)

  def load_graph(self, returns, meta_graph_def):
    """Called from wrap_function to import `meta_graph_def`.

    Args:
      returns: Single-element list; slot 0 receives the Saver (or None)
        produced by importing the MetaGraph, since the wrapped function
        itself has no outputs.
      meta_graph_def: The MetaGraphDef to import into the current graph.
    """
    # pylint: disable=protected-access
    saver, _ = tf_saver._import_meta_graph_with_return_elements(
        meta_graph_def)
    # pylint: enable=protected-access
    returns[0] = saver

  def _extract_saver_restore(self, wrapped, saver):
    """Builds a pruned function that runs the Saver's restore op.

    Args:
      wrapped: The WrappedFunction holding the imported graph.
      saver: The Saver recovered by `load_graph`, or None.

    Returns:
      None if there is no saver. Otherwise a pruned function that takes the
      checkpoint filename tensor and fetches
      `[filename_tensor, restore_op]`.
    """
    if saver is None:
      return None
    saver_def = saver.saver_def
    filename_tensor = wrapped.graph.as_graph_element(
        saver_def.filename_tensor_name)
    # We both feed and fetch filename_tensor so we have an operation to use to
    # feed into variable initializers (only relevant for v1 graph building).
    return wrapped.prune(
        feeds=[filename_tensor],
        fetches=[filename_tensor,
                 wrapped.graph.as_graph_element(saver_def.restore_op_name)])

  def restore_variables(self, wrapped, restore_from_saver):
    """Restores variables from the checkpoint.

    When not executing eagerly outside functions (v1 graph building), the
    first output of the restore function is added to the
    "saved_model_initializers" collection and installed as the initializer
    op of every global variable in the wrapped graph. A warning is logged if
    any of those variables still belongs to the wrapped graph, i.e. was not
    lifted out of it.

    Args:
      wrapped: The WrappedFunction holding the imported graph.
      restore_from_saver: Function from `_extract_saver_restore`, or None, in
        which case nothing is done.
    """
    if restore_from_saver is not None:
      initializer, _ = restore_from_saver(
          constant_op.constant(self._variables_path))
      if not ops.executing_eagerly_outside_functions():
        # Add the initialization operation to the "saved_model_initializers"
        # collection in case we don't have any lifted variables to attach it to.
        ops.add_to_collection("saved_model_initializers", initializer)
        one_unlifted = False

        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.GLOBAL_VARIABLES):
          if variable.graph is wrapped.graph:
            one_unlifted = True
          # pylint: disable=protected-access
          variable._initializer_op = initializer
          # pylint: enable=protected-access
        if one_unlifted:
          logging.warning(
              "Some variables could not be lifted out of a loaded function. "
              "Please run "
              "`sess.run(tf.get_collection(\"saved_model_initializers\"))` to "
              "restore these variables.")

  def _extract_signatures(self, wrapped, meta_graph_def):
    """Creates ConcreteFunctions for signatures in `meta_graph_def`.

    Inputs are ordered by their signature key, alphabetically. A SparseTensor
    input `x` is exposed as three keyword arguments: `x_indices`, `x_values`
    and `x_dense_shape`. Any other composite input `x` is exposed as
    `x_component_<i>`, one per flattened component. A signature accepts a
    positional argument only when it has exactly one flattened input.

    Args:
      wrapped: The WrappedFunction holding the imported graph.
      meta_graph_def: The MetaGraphDef whose `signature_def` map is read.

    Returns:
      A dict mapping signature keys to pruned ConcreteFunctions.

    Raises:
      lift_to_graph.UnliftableError: If a signature output depends on a
        placeholder that is not one of its inputs. The exception message is
        prefixed with an explanation naming the signature.
    """
    signature_functions = {}
    for signature_key, signature_def in meta_graph_def.signature_def.items():
      if signature_def.inputs:
        input_items = sorted(
            signature_def.inputs.items(), key=lambda item: item[0])
        original_input_names, input_specs = zip(*input_items)
      else:
        original_input_names = []
        input_specs = []
      # TODO(b/205015292): Support optional arguments
      feeds = [
          wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph)  # pylint: disable=protected-access
          for input_spec in input_specs
      ]
      input_names = []
      input_tensors = []
      for original_input_name, feed in zip(original_input_names, feeds):
        if isinstance(feed, sparse_tensor.SparseTensor):
          # We have to give explicit name for SparseTensor arguments, because
          # these are not present in the TensorInfo.
          indices_name = f"{original_input_name}_indices"
          values_name = f"{original_input_name}_values"
          dense_shape_name = f"{original_input_name}_dense_shape"
          input_names.extend([indices_name, values_name, dense_shape_name])
          input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
        elif isinstance(feed, composite_tensor.CompositeTensor):
          component_tensors = nest.flatten(feed, expand_composites=True)
          input_names.extend(f"{original_input_name}_component_{n}"
                             for n in range(len(component_tensors)))
          input_tensors.extend(component_tensors)
        else:
          input_names.append(original_input_name)
          input_tensors.append(feed)
      fetches = dict(signature_def.outputs.items())
      try:
        signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
      except lift_to_graph.UnliftableError as ex:
        # Mutate the exception to add a bit more detail.
        args = ex.args
        message = args[0] if args else ""
        message = (
            ("A SavedModel signature needs an input for each placeholder the "
             "signature's outputs use. An output for signature '{}' depends on "
             "a placeholder which is not an input (i.e. the placeholder is not "
             "fed a value).\n\n").format(signature_key)
            + message)
        ex.args = (message,) + args[1:]
        raise
      # pylint: disable=protected-access
      signature_fn._arg_keywords = input_names
      signature_fn._func_graph.structured_input_signature = (
          (),
          func_graph.convert_structure_to_signature(
              dict(zip(input_names, input_tensors))))

      if len(input_names) == 1:
        # Allowing positional arguments does not create any ambiguity if there's
        # only one.
        signature_fn._num_positional_args = 1
      else:
        signature_fn._num_positional_args = 0
      # pylint: enable=protected-access
      signature_functions[signature_key] = signature_fn
    return signature_functions

  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`.

    Args:
      tags: Tag set selecting the MetaGraph. None is accepted when the
        SavedModel has exactly one MetaGraph; see
        `get_meta_graph_def_from_tags`.

    Returns:
      An `AutoTrackable` root with the attributes `initializer`,
      `asset_paths`, `signatures`, `variables`, `tensorflow_version`,
      `tensorflow_git_version`, `graph` and `prune`. It also has `restore`
      (taking a checkpoint path) when the MetaGraph has a Saver.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_shared_name_suffix = "_load_{}".format(ops.uid())
    functions = function_deserialization.load_function_def_library(
        meta_graph_def.graph_def.library,
        load_shared_name_suffix=load_shared_name_suffix)
    # Replace existing functions in the MetaGraphDef with renamed functions so
    # we don't have duplicates or name collisions.
    meta_graph_def.graph_def.library.Clear()
    for function in functions.values():
      meta_graph_def.graph_def.library.function.add().CopyFrom(
          function.function_def)
    # We've renamed functions and shared names. We need the same operation on
    # the GraphDef itself for consistency.
    for node_def in meta_graph_def.graph_def.node:
      function_deserialization.fix_node_def(node_def, functions,
                                            load_shared_name_suffix)

    # `load_graph` runs inside wrap_function and hands the imported Saver
    # back through this single-element list.
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[])
    saver, = load_graph_returns
    restore_from_saver = self._extract_saver_restore(wrapped, saver)
    self.restore_variables(wrapped, restore_from_saver)
    with wrapped.graph.as_default():
      init_op = loader_impl.get_init_op(
          meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
      # Add a dummy Tensor we know we can fetch to add control dependencies to.
      init_anchor = constant_op.constant(0., name="dummy_fetch")

    root = autotrackable.AutoTrackable()
    if restore_from_saver is not None:
      root.restore = (
          lambda path: restore_from_saver(constant_op.constant(path)))
    # Each asset tensor in the graph becomes a feed of the init function; the
    # matching Asset objects supply the values when it is run.
    asset_feed_tensors = []
    asset_paths = []
    for tensor_name, value in loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def).items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(asset.Asset(value))
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
    initializer = _Initializer(init_fn, asset_paths)
    # pylint: disable=protected-access
    local_init_op, _ = initializer._initialize()
    # pylint: enable=protected-access
    with ops.init_scope():
      # In graph mode, register the init output as a table initializer and
      # make it the initializer of every local variable in the wrapped graph.
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.LOCAL_VARIABLES):
          # pylint: disable=protected-access
          variable._initializer_op = local_init_op
          # pylint: enable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    root.variables = list(wrapped.graph.variables)
    root.tensorflow_version = (
        meta_graph_def.meta_info_def.tensorflow_version)
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
    root.graph = wrapped.graph
    root.prune = wrapped.prune
    return root
278
279
def load(export_dir, tags):
  """Load a v1-style SavedModel as an object.

  Read-API metrics are recorded under the "load_v1_in_v2" label before the
  load. A read with `write_version="1"` is recorded after it succeeds.

  Args:
    export_dir: Directory containing the SavedModel.
    tags: Tag set selecting the MetaGraph, or None when the SavedModel holds
      exactly one MetaGraph.

  Returns:
    The root object produced by `_EagerSavedModelLoader.load`.
  """
  metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
  loaded_root = _EagerSavedModelLoader(export_dir).load(tags=tags)
  metrics.IncrementRead(write_version="1")
  return loaded_root
286  return result
287