# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""


from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils_impl as utils
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Utility function to build a SignatureDef protocol buffer.

  Args:
    inputs: Inputs of the SignatureDef defined as a proto map of string to
      tensor info. Any mapping with `.items()` works; each value is copied
      into the new proto with `CopyFrom`.
    outputs: Outputs of the SignatureDef defined as a proto map of string to
      tensor info, handled the same way as `inputs`.
    method_name: Method name of the SignatureDef as a string.

  Returns:
    A SignatureDef protocol buffer constructed based on the supplied arguments.
  """
  signature_def = meta_graph_pb2.SignatureDef()
  if inputs is not None:
    for key, tensor_info in inputs.items():
      signature_def.inputs[key].CopyFrom(tensor_info)
  if outputs is not None:
    for key, tensor_info in outputs.items():
      signature_def.outputs[key].CopyFrom(tensor_info)
  if method_name is not None:
    signature_def.method_name = method_name
  return signature_def


@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Creates regression signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Regress API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`, if `predictions` is
      `None`, if `examples` is not of string dtype, or if `predictions` is not
      of float dtype.
  """
  if examples is None:
    raise ValueError('Regression `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Expected regression `examples` to be of type Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if predictions is None:
    raise ValueError('Regression `predictions` cannot be None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression input tensors must be of type string. '
                     f'Found tensors with type {input_tensor_info.dtype}.')
  signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}

  output_tensor_info = utils.build_tensor_info(predictions)
  if output_tensor_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output tensors must be of type float. '
                     f'Found tensors with type {output_tensor_info.dtype}.')
  signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.REGRESS_METHOD_NAME)


@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`. Note that the ClassificationResponse message
      requires that class labels are strings, not integers or anything else.
      May be `None` if `scores` is given.
    scores: a float `Tensor`. May be `None` if `classes` is given.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None` or not a `Tensor`, if both `classes`
      and `scores` are `None`, if `examples` or `classes` is not of string
      dtype, or if `scores` is not of float dtype.
  """
  if examples is None:
    raise ValueError('Classification `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification `examples` must be a string Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if classes is None and scores is None:
    raise ValueError('Classification `classes` and `scores` cannot both be '
                     'None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification input tensors must be of type string. '
                     f'Found tensors of type {input_tensor_info.dtype}.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be of type string Tensor. '
                       f'Found tensors of type {classes_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor. '
                       f'Found tensors of type {scores_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)


@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Creates prediction signature from given inputs and outputs.

  This function produces signatures intended for use with the TensorFlow Serving
  Predict API (tensorflow_serving/apis/prediction_service.proto). This API
  imposes no constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `not x` is True for both `None` and an empty dict.
  if not inputs:
    raise ValueError('Prediction `inputs` cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction `outputs` cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}
  signature_outputs = {key: utils.build_tensor_info(tensor)
                       for key, tensor in outputs.items()}

  return build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.PREDICT_METHOD_NAME)


# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef with the supervised-train method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` representing metric ops.

  Returns:
    A signature_def whose method name is
    `signature_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef with the supervised-eval method name.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` representing metric ops.

  Returns:
    A signature_def whose method name is
    `signature_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  return _supervised_signature_def(
      signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Note that this function only requires
  inputs to be not None.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def. Entries of `loss`, `predictions`
    and `metrics` are merged into a single outputs map in that order, so a key
    repeated in a later dict overwrites the earlier entry.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  if not inputs:
    raise ValueError(f'{method_name} `inputs` cannot be None or empty.')

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      signature_outputs.update(
          {key: utils.build_tensor_info(tensor)
           for key, tensor in output_set.items()})

  return build_signature_def(
      signature_inputs, signature_outputs, method_name)
# LINT.ThenChange(//keras/saving/utils_v1/signature_def_utils.py)


@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: A SignatureDef proto, or `None`.

  Returns:
    True if `signature_def` is a valid classify, regress or predict signature;
    False otherwise (including when it is `None`).
  """
  if signature_def is None:
    return False
  return (_is_valid_classification_signature(signature_def) or
          _is_valid_regression_signature(signature_def) or
          _is_valid_predict_signature(signature_def))


def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Requires the predict method name and at least one input and one output.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  if not signature_def.inputs:
    return False
  if not signature_def.outputs:
    return False
  return True


def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Requires the regress method name, exactly one string input under
  `REGRESS_INPUTS`, and exactly one float output under `REGRESS_OUTPUTS`.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  if set(signature_def.inputs.keys()) != {signature_constants.REGRESS_INPUTS}:
    return False
  if (signature_def.inputs[signature_constants.REGRESS_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  if set(signature_def.outputs.keys()) != {signature_constants.REGRESS_OUTPUTS}:
    return False
  if (signature_def.outputs[signature_constants.REGRESS_OUTPUTS].dtype !=
      types_pb2.DT_FLOAT):
    return False

  return True


def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Requires the classify method name, exactly one string input under
  `CLASSIFY_INPUTS`, and a non-empty subset of {classes (string), scores
  (float)} as outputs.
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  if set(signature_def.inputs.keys()) != {signature_constants.CLASSIFY_INPUTS}:
    return False
  if (signature_def.inputs[signature_constants.CLASSIFY_INPUTS].dtype !=
      types_pb2.DT_STRING):
    return False

  allowed_outputs = {signature_constants.CLASSIFY_OUTPUT_CLASSES,
                     signature_constants.CLASSIFY_OUTPUT_SCORES}
  output_keys = set(signature_def.outputs.keys())

  if not output_keys:
    return False
  if output_keys - allowed_outputs:
    return False
  # Membership is checked before indexing so we never index a missing key.
  if (signature_constants.CLASSIFY_OUTPUT_CLASSES in output_keys and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES].dtype
      != types_pb2.DT_STRING):
    return False
  if (signature_constants.CLASSIFY_OUTPUT_SCORES in output_keys and
      signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES].dtype
      != types_pb2.DT_FLOAT):
    return False

  return True


def op_signature_def(op, key):
  """Creates a signature def with the output pointing to an op.

  Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
  It is recommended to use the build_signature_def() function for Tensors.

  Args:
    op: An Op (or possibly Tensor).
    key: Key to graph element in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # Use build_tensor_info_from_op, which creates a TensorInfo from the element's
  # name.
  return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})


def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
    SignatureDef.

  Raises:
    NotFoundError: If `key` is not in the SignatureDef outputs, or if the op
      could not be found in the graph.
  """
  # `outputs` is a message-valued proto map: indexing it with a missing key
  # would silently insert an empty TensorInfo into the caller's SignatureDef,
  # so check membership before indexing.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the SignatureDef outputs.')
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError as e:
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the graph. Please make sure the'
        ' SavedModel was created by the internal _SavedModelBuilder. If you '
        'are using the public API, please make sure the SignatureDef in the '
        f'SavedModel does not contain the key "{key}".') from e