1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers for working with signatures in tf.saved_model.save.""" 16 17from absl import logging 18 19from tensorflow.python.eager import def_function 20from tensorflow.python.eager import function as defun 21from tensorflow.python.framework import composite_tensor 22from tensorflow.python.framework import ops 23from tensorflow.python.framework import tensor_spec 24from tensorflow.python.ops import resource_variable_ops 25from tensorflow.python.saved_model import function_serialization 26from tensorflow.python.saved_model import revived_types 27from tensorflow.python.saved_model import signature_constants 28from tensorflow.python.trackable import base 29from tensorflow.python.util import compat 30from tensorflow.python.util import nest 31from tensorflow.python.util.compat import collections_abc 32 33 34DEFAULT_SIGNATURE_ATTR = "_default_save_signature" 35SIGNATURE_ATTRIBUTE_NAME = "signatures" 36# Max number of warnings to show if signature contains normalized input names. 
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5


def _get_signature(function):
  """Resolves `function` to a concrete function usable as a signature.

  Args:
    function: A `tf.function`, a `ConcreteFunction`, or any other object.

  Returns:
    A `ConcreteFunction`, or None. A `tf.function` is only resolved when it has
    an `input_signature`; anything that does not end up as a
    `ConcreteFunction` yields None.
  """
  if (isinstance(function, (defun.Function, def_function.Function)) and
      function.input_signature is not None):
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  if not isinstance(function, defun.ConcreteFunction):
    return None
  return function


def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature.

  A function qualifies when it has outputs, takes no `tf.Variable` inputs, and
  its structured outputs can be normalized by `_normalize_outputs`.
  """
  if not concrete_function.outputs:
    # Functions without outputs don't make sense as signatures. We just don't
    # have any way to run an Operation with no outputs as a SignatureDef in the
    # 1.x style.
    return False
  try:
    _validate_inputs(concrete_function)
    # Only the validation matters here; the placeholder names are used solely
    # in error messages.
    _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
  except ValueError:
    return False
  return True


def _validate_inputs(concrete_function):
  """Raises error if input type is tf.Variable.

  Args:
    concrete_function: The `ConcreteFunction` whose structured input signature
      is inspected.

  Raises:
    ValueError: If any flattened input spec is a `VariableSpec`.
  """
  if any(isinstance(inp, resource_variable_ops.VariableSpec)
         for inp in nest.flatten(
             concrete_function.structured_input_signature)):
    raise ValueError(
        f"Unable to serialize concrete_function '{concrete_function.name}' "
        "with tf.Variable input. Functions that expect tf.Variable "
        "inputs cannot be exported as signatures.")


def _get_signature_name_changes(concrete_function):
  """Checks for user-specified signature input names that are normalized.

  Args:
    concrete_function: The `ConcreteFunction` whose inputs are inspected.

  Returns:
    A dict mapping each user-specified input name to the name used in the
    function's `FunctionDef` signature, containing only names that differ.
  """
  # Map of {user-given name: normalized name} if the names are un-identical.
  name_changes = {}
  for signature_input_name, graph_input in zip(
      concrete_function.function_def.signature.input_arg,
      concrete_function.graph.inputs):
    try:
      user_specified_name = compat.as_str(
          graph_input.op.get_attr("_user_specified_name"))
    except ValueError:
      # Signature input does not have a user-specified name.
      continue
    if signature_input_name.name != user_specified_name:
      name_changes[user_specified_name] = signature_input_name.name
  return name_changes


def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  A `tf.function` / `ConcreteFunction` child of the root named
  `DEFAULT_SIGNATURE_ATTR` is returned as-is. Otherwise, if exactly one such
  child resolves (via `_get_signature`) to a valid signature, that concrete
  function is returned.

  Args:
    saveable_view: View of the object graph being saved; the children of its
      root are searched.

  Returns:
    The function to use as the default signature, or None.
  """
  # If the user did not specify signatures, check the root object for a function
  # that can be made into a signature.
  children = saveable_view.list_children(saveable_view.root)

  # TODO(b/205014194): Discuss removing this behaviour. It can lead to WTFs when
  # a user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  possible_signatures = []
  for name, child in children:
    if not isinstance(child, (def_function.Function, defun.ConcreteFunction)):
      continue
    if name == DEFAULT_SIGNATURE_ATTR:
      return child
    concrete = _get_signature(child)
    if concrete is not None and _valid_signature(concrete):
      possible_signatures.append(concrete)

  # Candidates were already resolved to concrete functions and validated above,
  # so a unique candidate can be returned directly.
  if len(possible_signatures) == 1:
    return possible_signatures[0]
  return None


def _create_signature_wrapper(signature_function, signature_key):
  """Wraps `signature_function` in a `tf.function` returning a dict of Tensors.

  Taking `signature_function` and `signature_key` as parameters (instead of
  closing over loop variables in the caller) ties each wrapper to its own
  signature even if it is traced again later.

  Args:
    signature_function: The concrete function to call.
    signature_key: The signature key, used in error messages.

  Returns:
    A `tf.function` taking keyword arguments and returning the outputs of
    `signature_function` normalized by `_normalize_outputs`.
  """

  @def_function.function
  def signature_wrapper(**kwargs):
    structured_outputs = signature_function(**kwargs)
    return _normalize_outputs(
        structured_outputs, signature_function.name, signature_key)

  return signature_wrapper


def _build_tensor_spec_signature(signature_function):
  """Returns a `{keyword: TensorSpec}` dict for the inputs of a function.

  Each spec is named after the matching entry of the function's argument
  keywords.

  Args:
    signature_function: The concrete function whose inputs are described.
  """
  if signature_function.structured_input_signature is not None:
    # The structured input signature may contain other non-tensor arguments.
    inputs = filter(
        lambda x: isinstance(x, tensor_spec.TensorSpec),
        nest.flatten(signature_function.structured_input_signature,
                     expand_composites=True))
  else:
    # Structured input signature isn't always defined for some functions.
    inputs = signature_function.inputs

  tensor_spec_signature = {}
  for keyword, inp in zip(
      signature_function._arg_keywords,  # pylint: disable=protected-access
      inputs):
    keyword = compat.as_str(keyword)
    if isinstance(inp, tensor_spec.TensorSpec):
      spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
    else:
      spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
    tensor_spec_signature[keyword] = spec
  return tensor_spec_signature


def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. Each function must be a `tf.function` with an input signature
      or a `ConcreteFunction`. A non-mapping value is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A tuple `(concrete_signatures, wrapped_functions)`:
      concrete_signatures: dict mapping each signature key to a
        `ConcreteFunction` whose outputs are normalized to a dictionary.
      wrapped_functions: dict mapping each original concrete function to the
        result of `function_serialization.wrap_cached_variables` on it.

  Raises:
    ValueError: If a value cannot be resolved to a concrete function or takes
      `tf.Variable` inputs.
  """
  if signatures is None:
    return {}, {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  num_normalized_signatures_counter = 0
  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          "Expected a TensorFlow function for which to generate a signature, "
          f"but got {function}. Only `tf.functions` with an input signature or "
          "concrete functions can be used as a signature.")

    # Reuse the existing wrapper when the same concrete function backs more
    # than one signature key.
    wrapped_functions[original_function] = signature_function = (
        wrapped_functions.get(original_function) or
        function_serialization.wrap_cached_variables(original_function))
    _validate_inputs(signature_function)
    if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
      signature_name_changes = _get_signature_name_changes(signature_function)
      if signature_name_changes:
        num_normalized_signatures_counter += 1
        logging.warning(
            "Function `%s` contains input name(s) %s with unsupported "
            "characters which will be renamed to %s in the SavedModel.",
            compat.as_str(signature_function.graph.name),
            ", ".join(signature_name_changes.keys()),
            ", ".join(signature_name_changes.values()))
    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    signature_wrapper = _create_signature_wrapper(signature_function,
                                                  signature_key)
    tensor_spec_signature = _build_tensor_spec_signature(signature_function)
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
  return concrete_signatures, wrapped_functions


def _normalize_outputs(outputs, function_name, signature_key):
  """Normalize outputs if necessary and check that they are tensors.

  Args:
    outputs: A mapping, namedtuple, sequence, or single value returned by a
      signature function.
    function_name: Name of the function, used only in error messages.
    signature_key: Signature key, used only in error messages.

  Returns:
    A mapping whose values are Tensors or CompositeTensors. A mapping input is
    returned as-is; a namedtuple is converted with `_asdict()`; any other value
    is keyed `output_0`, `output_1`, ... (a non-sequence counts as one output).

  Raises:
    ValueError: If a key is not a string or a value is not a Tensor or
      CompositeTensor.
  """
  # Convert `outputs` to a dictionary (if it's not one already).
  if not isinstance(outputs, collections_abc.Mapping):
    # Check if `outputs` is a namedtuple.
    if hasattr(outputs, "_asdict"):
      outputs = outputs._asdict()
    else:
      if not isinstance(outputs, collections_abc.Sequence):
        outputs = [outputs]
      outputs = {"output_{}".format(output_index): output
                 for output_index, output in enumerate(outputs)}

  # Check that the keys of `outputs` are strings and the values are Tensors.
  for key, value in outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          f"Got a dictionary with a non-string key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}.")
    if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
      raise ValueError(
          f"Got a non-Tensor value {value!r} for key {key!r} in the output of "
          f"the function {compat.as_str_any(function_name)} used to generate "
          f"the SavedModel signature {signature_key!r}. "
          "Outputs for functions used as signatures must be a single Tensor, "
          "a sequence of Tensors, or a dictionary from string to Tensor.")
  return outputs


# _SignatureMap is immutable to ensure that users do not expect changes to be
# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
# only way to create a _SignatureMap and there is no way to modify it. So we can
# safely ignore/overwrite ".signatures" attributes attached to objects being
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
# any other type (e.g. a regular dict) will raise an exception asking the user
# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  A read-only mapping from signature name to function. When saving a
  SavedModel, its function-valued entries are reported as trackable children.
  """

  def __init__(self):
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap."""
    # Ideally this object would be immutable, but restore is streaming so we do
    # need a private API for adding new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    return self._signatures[key]

  def __iter__(self):
    return iter(self._signatures)

  def __len__(self):
    return len(self._signatures)

  def __repr__(self):
    return "_SignatureMap({})".format(self._signatures)

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    # Signature functions are only tracked as children for SavedModel saves;
    # any other save type sees no children.
    if save_type != base.SaveType.SAVEDMODEL:
      return {}

    return {
        key: value for key, value in self.items()
        if isinstance(value, (def_function.Function, defun.ConcreteFunction))
    }


revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])


def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: A mapping from signature name to `ConcreteFunction`, as
      produced by `canonicalize_signatures`.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # This is true of any signature that came from canonicalize_signatures. Here
    # as a sanity check on saving; crashing on load (e.g. in _add_signature)
    # would be more problematic in case future export changes violated these
    # assertions.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    if len(func._arg_keywords) == 1:
      assert 1 == func._num_positional_args
    else:
      assert 0 == func._num_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map


def validate_augmented_graph_view(augmented_graph_view):
  """Performs signature-related sanity checks on `augmented_graph_view`.

  Raises:
    ValueError: If the root has a child named `SIGNATURE_ATTRIBUTE_NAME` that
      is not a `_SignatureMap`.
  """
  for name, dep in augmented_graph_view.list_children(
      augmented_graph_view.root):
    if name == SIGNATURE_ATTRIBUTE_NAME:
      if not isinstance(dep, _SignatureMap):
        raise ValueError(
            f"Exporting an object {augmented_graph_view.root} which has an "
            f"attribute named '{SIGNATURE_ATTRIBUTE_NAME}'. This is a reserved "
            "attribute used to store SavedModel signatures in objects which "
            "come from `tf.saved_model.load`. Delete this attribute "
            f"(e.g. `del obj.{SIGNATURE_ATTRIBUTE_NAME}`) before saving if "
            "this shadowing is acceptable.")
      break