xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/signature_serialization.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers for working with signatures in tf.saved_model.save."""
16
17from absl import logging
18
19from tensorflow.python.eager import def_function
20from tensorflow.python.eager import function as defun
21from tensorflow.python.framework import composite_tensor
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_spec
24from tensorflow.python.ops import resource_variable_ops
25from tensorflow.python.saved_model import function_serialization
26from tensorflow.python.saved_model import revived_types
27from tensorflow.python.saved_model import signature_constants
28from tensorflow.python.trackable import base
29from tensorflow.python.util import compat
30from tensorflow.python.util import nest
31from tensorflow.python.util.compat import collections_abc
32
33
# Name of a root-object child which, when present, is exported as the default
# signature if the user did not pass `signatures` explicitly.
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
# Reserved attribute name used to store SavedModel signatures on objects.
SIGNATURE_ATTRIBUTE_NAME = "signatures"
# Upper bound on the number of functions for which a warning about normalized
# (renamed) signature input names is logged.
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5
38
39
def _get_signature(function):
  """Resolves `function` to a ConcreteFunction usable as a signature.

  A `tf.function` (either `defun.Function` or `def_function.Function`) that
  carries an `input_signature` is replaced by its concrete function. The
  result is returned if it is a `defun.ConcreteFunction`; anything else
  (including a `tf.function` without an input signature) yields None.
  """
  polymorphic_types = (defun.Function, def_function.Function)
  has_fixed_signature = (
      isinstance(function, polymorphic_types)
      and function.input_signature is not None)
  if has_fixed_signature:
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  return function if isinstance(function, defun.ConcreteFunction) else None
47
48
def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature.

  A function qualifies when it has at least one output, `_validate_inputs`
  accepts it, and its structured outputs can be normalized by
  `_normalize_outputs`. A ValueError from either check means "not valid".
  """
  # A SignatureDef is run 1.x-style; there is no way to run an Operation that
  # produces no outputs, so output-less functions are never signatures.
  if not concrete_function.outputs:
    return False
  try:
    _validate_inputs(concrete_function)
    _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
  except ValueError:
    is_valid = False
  else:
    is_valid = True
  return is_valid
62
63
def _validate_inputs(concrete_function):
  """Raises error if input type is tf.Variable.

  Args:
    concrete_function: The ConcreteFunction whose
      `structured_input_signature` is inspected (after `nest.flatten`).

  Raises:
    ValueError: If any flattened element of the structured input signature is
      a `resource_variable_ops.VariableSpec`; functions expecting tf.Variable
      inputs cannot be exported as signatures.
  """
  flat_inputs = nest.flatten(concrete_function.structured_input_signature)
  if any(isinstance(inp, resource_variable_ops.VariableSpec)
         for inp in flat_inputs):
    # Note the trailing space after the quoted name: the two literals are
    # concatenated, and without it the message read "'name'with ...".
    raise ValueError(
        f"Unable to serialize concrete_function '{concrete_function.name}' "
        "with tf.Variable input. Functions that expect tf.Variable "
        "inputs cannot be exported as signatures.")
73
74
def _get_signature_name_changes(concrete_function):
  """Checks for user-specified signature input names that are normalized.

  Pairs each `input_arg` of the function's FunctionDef signature with the
  corresponding graph input, in order.

  Returns:
    A dict `{user-specified name: signature arg name}` containing only the
    inputs whose two names differ. Inputs whose op has no
    `_user_specified_name` attribute (`get_attr` raises ValueError) are
    skipped.
  """
  renamed = {}
  arg_defs = concrete_function.function_def.signature.input_arg
  for arg_def, placeholder in zip(arg_defs, concrete_function.graph.inputs):
    try:
      requested_name = compat.as_str(
          placeholder.op.get_attr("_user_specified_name"))
    except ValueError:
      # This input carries no user-specified name; nothing to compare.
      continue
    if arg_def.name != requested_name:
      renamed[requested_name] = arg_def.name
  return renamed
91
92
def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  Used when the user did not specify signatures. The direct children of
  `saveable_view.root` are scanned:
    * a `def_function.Function`/`defun.ConcreteFunction` child named
      `DEFAULT_SIGNATURE_ATTR` is returned immediately, as-is;
    * otherwise, if exactly one child resolves (via `_get_signature`) to a
      concrete function that passes `_valid_signature`, that concrete
      function is returned.

  Args:
    saveable_view: Object exposing `root` and `list_children(obj)`, the latter
      yielding `(name, child)` pairs (presumably the save-time graph view;
      confirm against callers).

  Returns:
    The function to export as the default signature, or None.
  """
  children = saveable_view.list_children(saveable_view.root)

  # TODO(b/205014194): Discuss removing this behaviour. It can lead to WTFs when
  # a user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  possible_signatures = []
  for name, child in children:
    if not isinstance(child, (def_function.Function, defun.ConcreteFunction)):
      continue
    if name == DEFAULT_SIGNATURE_ATTR:
      return child
    concrete = _get_signature(child)
    if concrete is not None and _valid_signature(concrete):
      possible_signatures.append(concrete)

  # Each candidate was already resolved and validated in the loop above, so
  # it is returned directly instead of being resolved/validated a second time.
  if len(possible_signatures) == 1:
    return possible_signatures[0]
  return None
118
119
def _make_signature_wrapper(signature_function, signature_key):
  """Wraps `signature_function` in a tf.function that returns a dict.

  The wrapper forwards its keyword arguments to `signature_function` and
  passes the result through `_normalize_outputs`, matching the format of
  1.x-style signatures.

  Building the wrapper in its own function binds `signature_function` and
  `signature_key` per call. Defining it inline in the loop of
  `canonicalize_signatures` would close over the loop variables, so any later
  re-trace would silently see the values of the last iteration.

  Args:
    signature_function: The ConcreteFunction to call.
    signature_key: The signature key; used only in error messages.

  Returns:
    A `def_function.function`-decorated callable named `signature_wrapper`.
  """

  @def_function.function
  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  return signature_wrapper


def _build_tensor_spec_signature(signature_function):
  """Returns `{keyword: TensorSpec}` describing `signature_function`'s inputs.

  Keywords from `signature_function._arg_keywords` are paired positionally
  with the flat inputs (extra items on either side are dropped by `zip`), and
  each resulting TensorSpec is named after its keyword.
  """
  if signature_function.structured_input_signature is not None:
    # The structured input signature may contain other non-tensor arguments.
    inputs = filter(
        lambda x: isinstance(x, tensor_spec.TensorSpec),
        nest.flatten(signature_function.structured_input_signature,
                     expand_composites=True))
  else:
    # Structured input signature isn't always defined for some functions.
    inputs = signature_function.inputs

  tensor_spec_signature = {}
  for keyword, inp in zip(
      signature_function._arg_keywords,  # pylint: disable=protected-access
      inputs):
    keyword = compat.as_str(keyword)
    if isinstance(inp, tensor_spec.TensorSpec):
      spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
    else:
      spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
    tensor_spec_signature[keyword] = spec
  return tensor_spec_signature


def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a Mapping from signature key to
      function. Each function must be a `tf.function` with an input signature
      or a ConcreteFunction. A non-Mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions)`:
      concrete_signatures: signature key -> ConcreteFunction that returns a
        dictionary of Tensors and takes keyword arguments (or one positional
        argument when it has exactly one input).
      wrapped_functions: original ConcreteFunction -> the result of
        `function_serialization.wrap_cached_variables` for it; reused when
        several keys share the same function.
    Both are empty dicts when `signatures` is None.

  Raises:
    ValueError: If a value cannot be resolved to a concrete function, or if
      `_validate_inputs` rejects it (tf.Variable inputs).
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          "Expected a TensorFlow function for which to generate a signature, "
          f"but got {function}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature.")

    # Reuse the wrapped version when the same function backs several keys.
    wrapped_functions[original_function] = signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    _validate_inputs(signature_function)
    # Warn about renamed inputs, but only for the first few functions to keep
    # the log readable.
    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      signature_name_changes = _get_signature_name_changes(signature_function)
      if signature_name_changes:
        num_normalized_signatures_counter += 1
        logging.warning(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(signature_name_changes.keys()),
            ", ".join(signature_name_changes.values()))
    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    signature_wrapper = _make_signature_wrapper(signature_function,
                                                signature_key)
    tensor_spec_signature = _build_tensor_spec_signature(signature_function)
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)
    # pylint: disable=protected-access
    # If there is only one input to the signature, a very common case, then
    # ordering is unambiguous and we can let people pass a positional
    # argument. Since SignatureDefs are unordered (protobuf "map") multiple
    # arguments means we need to be keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
  return concrete_signatures, wrapped_functions
195
196
def _normalize_outputs(outputs, function_name, signature_key):
  """Normalize outputs if necessary and check that they are tensors.

  Args:
    outputs: Structured outputs of a signature function: a Mapping, a
      namedtuple, a non-string sequence, or a single value.
    function_name: Name of the function producing `outputs`; used only in
      error messages.
    signature_key: Key of the signature being generated; used only in error
      messages.

  Returns:
    A Mapping from string keys to Tensors/CompositeTensors. Mappings are
    returned as-is, namedtuples are converted with `_asdict()`, and anything
    else becomes `{"output_0": ..., "output_1": ...}`. A single value
    (including a `str`/`bytes`) is treated as a one-element sequence.

  Raises:
    ValueError: If a key is not a string, or a value is not a Tensor or
      CompositeTensor.
  """
  # Convert `outputs` to a dictionary (if it's not one already).
  if not isinstance(outputs, collections_abc.Mapping):
    if hasattr(outputs, "_asdict"):
      # `outputs` is a namedtuple.
      outputs = outputs._asdict()
    else:
      # `str`/`bytes` are Sequences too, but must be treated as one (invalid)
      # value: enumerating them would report single characters in the error
      # below, and an empty string would become `{}` and pass unnoticed.
      if (isinstance(outputs, compat.bytes_or_text_types) or
          not isinstance(outputs, collections_abc.Sequence)):
        outputs = [outputs]
      outputs = {("output_{}".format(output_index)): output
                 for output_index, output in enumerate(outputs)}

  # Check that the keys of `outputs` are strings and the values are Tensors.
  for key, value in outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          f"Got a dictionary with a non-string key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}.")
    if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
      raise ValueError(
          f"Got a non-Tensor value {value!r} for key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}. "
          "Outputs for functions used as signatures must be a single Tensor, "
          "a sequence of Tensors, or a dictionary from string to Tensor.")
  return outputs
225
226
227# _SignatureMap is immutable to ensure that users do not expect changes to be
228# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
229# only way to create a _SignatureMap and there is no way to modify it. So we can
230# safely ignore/overwrite ".signatures" attributes attached to objects being
231# saved if they contain a _SignatureMap. A ".signatures" attribute containing
232# any other type (e.g. a regular dict) will raise an exception asking the user
233# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  A read-only Mapping from signature name to function object. Entries can
  only be added through the private `_add_signature` method.
  """

  def __init__(self):
    # signature name -> function object.
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap."""
    # Ideally this object would be immutable, but restore is streaming, so a
    # private API is needed to add new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return f"_SignatureMap({self._signatures})"

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the function-valued entries, but only for SavedModel saves.

    Checkpoint saves get no children. For `base.SaveType.SAVEDMODEL`, every
    entry whose value is a `def_function.Function` or
    `defun.ConcreteFunction` is returned, keyed by signature name.
    """
    children = {}
    if save_type == base.SaveType.SAVEDMODEL:
      exportable_types = (def_function.Function, defun.ConcreteFunction)
      for key, value in self.items():
        if isinstance(value, exportable_types):
          children[key] = value
    return children
266
267
def _is_signature_map(obj):
  """Returns True for objects that should be saved as a "signature_map"."""
  return isinstance(obj, _SignatureMap)


def _create_empty_signature_map(proto):
  """Object factory for "signature_map": returns an empty `_SignatureMap`.

  Standard dependencies are enough to reconstruct the trackable items in
  dictionaries, so nothing extra is read from `proto`.
  """
  del proto  # Unused.
  return _SignatureMap()


revived_types.register_revived_type(
    "signature_map",
    _is_signature_map,
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_create_empty_signature_map,
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            # Restore is streaming, so restored signatures are presumably
            # added one at a time through this private method.
            setter=_SignatureMap._add_signature,  # pylint: disable=protected-access
        )
    ])
280
281
def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: A Mapping from signature key to ConcreteFunction, as produced
      by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.

  Raises:
    AssertionError: If an entry violates the invariants that
      `canonicalize_signatures` establishes: it is not a
      `defun.ConcreteFunction`, its `structured_outputs` is not a Mapping, or
      `_num_positional_args` is not 1 for single-input functions and 0
      otherwise.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # This is true of any signature that came from canonicalize_signatures.
    # It is checked here as a sanity check on saving; crashing on load (e.g.
    # in _add_signature) would be more problematic in case future export
    # changes violated these invariants. The checks raise explicitly instead
    # of using `assert` so they still run under `python -O`.
    if not isinstance(func, defun.ConcreteFunction):
      raise AssertionError(
          f"Signature {name!r} must be a ConcreteFunction, got "
          f"{type(func).__name__}.")
    if not isinstance(func.structured_outputs, collections_abc.Mapping):
      raise AssertionError(
          f"Signature {name!r} must return a Mapping, got "
          f"{type(func.structured_outputs).__name__}.")
    # pylint: disable=protected-access
    expected_positional_args = 1 if len(func._arg_keywords) == 1 else 0
    if func._num_positional_args != expected_positional_args:
      raise AssertionError(
          f"Signature {name!r} has _num_positional_args="
          f"{func._num_positional_args}, expected {expected_positional_args}.")
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map
300
301
def validate_augmented_graph_view(augmented_graph_view):
  """Performs signature-related sanity checks on `augmented_graph_view`.

  Only the first child of the root named `SIGNATURE_ATTRIBUTE_NAME` is
  inspected; it must be a `_SignatureMap`.

  Raises:
    ValueError: If that child exists and is not a `_SignatureMap`, i.e. the
      user attached their own object to the reserved attribute.
  """
  root_children = augmented_graph_view.list_children(
      augmented_graph_view.root)
  for child_name, child in root_children:
    if child_name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if isinstance(child, _SignatureMap):
      return
    raise ValueError(
        f"Exporting an object {augmented_graph_view.root} which has an attribute "
        f"named '{SIGNATURE_ATTRIBUTE_NAME}'. This is a reserved attribute "
        "used to store SavedModel signatures in objects which come from "
        "`tf.saved_model.load`. Delete this attribute "
        f"(e.g. `del obj.{SIGNATURE_ATTRIBUTE_NAME}`) before saving if "
        "this shadowing is acceptable.")
315      break
316