xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/signature_def_utils_impl.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SignatureDef utility functions implementation."""
16
17
18from tensorflow.core.framework import types_pb2
19from tensorflow.core.protobuf import meta_graph_pb2
20from tensorflow.python.framework import errors
21from tensorflow.python.framework import ops
22from tensorflow.python.saved_model import signature_constants
23from tensorflow.python.saved_model import utils_impl as utils
24from tensorflow.python.util import deprecation
25from tensorflow.python.util.tf_export import tf_export
26
27
@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a SignatureDef protocol buffer from its parts.

  Every argument is optional; an argument left as `None` leaves the
  corresponding field of the new SignatureDef untouched.

  Args:
    inputs: Mapping of string to tensor info. Each entry is copied into the
        `inputs` map of the result.
    outputs: Mapping of string to tensor info. Each entry is copied into the
        `outputs` map of the result.
    method_name: Method name to store on the SignatureDef, as a string.

  Returns:
    A new SignatureDef protocol buffer populated from the supplied arguments.
  """
  result = meta_graph_pb2.SignatureDef()
  # Inputs are copied before outputs, matching the argument order.
  for source, destination in ((inputs, result.inputs),
                              (outputs, result.outputs)):
    if source is None:
      continue
    for name in source:
      destination[name].CopyFrom(source[name])
  if method_name is not None:
    result.method_name = method_name
  return result
58
59
@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Builds a regression SignatureDef from examples and predictions.

  The resulting signature targets the TensorFlow Serving Regress API
  (tensorflow_serving/apis/prediction_service.proto), so the input must be a
  string tensor and the output must be a float tensor.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A regression-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None`, is not a `Tensor`, or is not of
      string dtype; if `predictions` is `None` or is not of float dtype.
  """
  # Argument checks run in a fixed order so that each failure reports its own
  # specific message.
  if examples is None:
    raise ValueError('Regression `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Expected regression `examples` to be of type Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  if predictions is None:
    raise ValueError('Regression `predictions` cannot be None.')

  examples_info = utils.build_tensor_info(examples)
  if examples_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression input tensors must be of type string. '
                     f'Found tensors with type {examples_info.dtype}.')

  predictions_info = utils.build_tensor_info(predictions)
  if predictions_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output tensors must be of type float. '
                     f'Found tensors with type {predictions_info.dtype}.')

  return build_signature_def(
      {signature_constants.REGRESS_INPUTS: examples_info},
      {signature_constants.REGRESS_OUTPUTS: predictions_info},
      signature_constants.REGRESS_METHOD_NAME)
109
110
@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Creates classification signature from given examples and predictions.

  This function produces signatures intended for use with the TensorFlow Serving
  Classify API (tensorflow_serving/apis/prediction_service.proto), and so
  constrains the input and output types to those allowed by TensorFlow Serving.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None` to omit the classes output.  Note that
      the ClassificationResponse message requires that class labels are
      strings, not integers or anything else.
    scores: A float `Tensor`, or `None` to omit the scores output.

  Returns:
    A classification-flavored signature_def.

  Raises:
    ValueError: If `examples` is `None`, is not a `Tensor`, or is not of
      string dtype; if `classes` and `scores` are both `None`; if `classes` is
      given but is not of string dtype; or if `scores` is given but is not of
      float dtype.
  """
  if examples is None:
    raise ValueError('Classification `examples` cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification `examples` must be a string Tensor. '
                     f'Found `examples` of type {type(examples)}.')
  # At least one of the two outputs has to be present.
  if classes is None and scores is None:
    raise ValueError('Classification `classes` and `scores` cannot both be '
                     'None.')

  input_tensor_info = utils.build_tensor_info(examples)
  if input_tensor_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification input tensors must be of type string. '
                     f'Found tensors of type {input_tensor_info.dtype}.')
  signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}

  # Only the outputs that were actually supplied are added to the signature.
  signature_outputs = {}
  if classes is not None:
    classes_tensor_info = utils.build_tensor_info(classes)
    if classes_tensor_info.dtype != types_pb2.DT_STRING:
      raise ValueError('Classification classes must be of type string Tensor. '
                       f'Found tensors of type {classes_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
        classes_tensor_info)
  if scores is not None:
    scores_tensor_info = utils.build_tensor_info(scores)
    if scores_tensor_info.dtype != types_pb2.DT_FLOAT:
      raise ValueError('Classification scores must be a float Tensor. '
                       f'Found tensors of type {scores_tensor_info.dtype}.')
    signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
        scores_tensor_info)

  signature_def = build_signature_def(
      signature_inputs, signature_outputs,
      signature_constants.CLASSIFY_METHOD_NAME)

  return signature_def
172
173
@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Builds a prediction SignatureDef from named inputs and outputs.

  The resulting signature targets the TensorFlow Serving Predict API
  (tensorflow_serving/apis/prediction_service.proto). That API imposes no
  constraints on the input and output types.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A prediction-flavored signature_def.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `not x` is true for both `None` and an empty dict.
  if not inputs:
    raise ValueError('Prediction `inputs` cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction `outputs` cannot be None or empty.')

  input_infos = {}
  for name, tensor in inputs.items():
    input_infos[name] = utils.build_tensor_info(tensor)

  output_infos = {}
  for name, tensor in outputs.items():
    output_infos[name] = utils.build_tensor_info(tensor)

  return build_signature_def(
      input_infos, output_infos, signature_constants.PREDICT_METHOD_NAME)
213
214
215# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a supervised signature tagged with the train method name.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    The signature_def produced by `_supervised_signature_def`.
  """
  train_method = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      train_method, inputs,
      loss=loss, predictions=predictions, metrics=metrics)
221
222
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a supervised signature tagged with the eval method name.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    The signature_def produced by `_supervised_signature_def`.
  """
  eval_method = signature_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      eval_method, inputs,
      loss=loss, predictions=predictions, metrics=metrics)
228
229
230def _supervised_signature_def(
231    method_name, inputs, loss=None, predictions=None,
232    metrics=None):
233  """Creates a signature for training and eval data.
234
235  This function produces signatures that describe the inputs and outputs
236  of a supervised process, such as training or evaluation, that
237  results in loss, metrics, and the like. Note that this function only requires
238  inputs to be not None.
239
240  Args:
241    method_name: Method name of the SignatureDef as a string.
242    inputs: dict of string to `Tensor`.
243    loss: dict of string to `Tensor` representing computed loss.
244    predictions: dict of string to `Tensor` representing the output predictions.
245    metrics: dict of string to `Tensor` representing metric ops.
246
247  Returns:
248    A train- or eval-flavored signature_def.
249
250  Raises:
251    ValueError: If inputs or outputs is `None`.
252  """
253  if inputs is None or not inputs:
254    raise ValueError(f'{method_name} `inputs` cannot be None or empty.')
255
256  signature_inputs = {key: utils.build_tensor_info(tensor)
257                      for key, tensor in inputs.items()}
258
259  signature_outputs = {}
260  for output_set in (loss, predictions, metrics):
261    if output_set is not None:
262      sig_out = {key: utils.build_tensor_info(tensor)
263                 for key, tensor in output_set.items()}
264      signature_outputs.update(sig_out)
265
266  signature_def = build_signature_def(
267      signature_inputs, signature_outputs, method_name)
268
269  return signature_def
270# LINT.ThenChange(//keras/saving/utils_v1/signature_def_utils.py)
271
272
@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: A SignatureDef proto, or `None`.

  Returns:
    `False` when `signature_def` is `None`; otherwise `True` if it passes the
    classification, regression or predict validity check.
  """
  if signature_def is None:
    return False
  # Validators are tried in this order; `any` stops at the first success.
  validators = (
      _is_valid_classification_signature,
      _is_valid_regression_signature,
      _is_valid_predict_signature,
  )
  return any(validate(signature_def) for validate in validators)
287
288
def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  A predict signature is accepted when it carries the predict method name and
  has at least one input and at least one output.
  """
  return (
      signature_def.method_name == signature_constants.PREDICT_METHOD_NAME
      and bool(signature_def.inputs.keys())
      and bool(signature_def.outputs.keys()))
298
299
def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  A regress signature is accepted when it carries the regress method name,
  has exactly one input (the regress input key, of string dtype) and exactly
  one output (the regress output key, of float dtype).
  """
  input_key = signature_constants.REGRESS_INPUTS
  output_key = signature_constants.REGRESS_OUTPUTS
  # The conditions short-circuit left to right, so each map is only indexed
  # after its key set has been confirmed.
  return (
      signature_def.method_name == signature_constants.REGRESS_METHOD_NAME
      and set(signature_def.inputs.keys()) == {input_key}
      and signature_def.inputs[input_key].dtype == types_pb2.DT_STRING
      and set(signature_def.outputs.keys()) == {output_key}
      and signature_def.outputs[output_key].dtype == types_pb2.DT_FLOAT)
320
321
def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  A classify signature is accepted when it carries the classify method name,
  has exactly one input (the classify input key, of string dtype), and has a
  non-empty set of outputs drawn only from the classes key (string dtype) and
  the scores key (float dtype).
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  input_key = signature_constants.CLASSIFY_INPUTS
  if set(signature_def.inputs.keys()) != {input_key}:
    return False
  if signature_def.inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  # Each permitted output key paired with the dtype it must have.
  required_dtypes = (
      (signature_constants.CLASSIFY_OUTPUT_CLASSES, types_pb2.DT_STRING),
      (signature_constants.CLASSIFY_OUTPUT_SCORES, types_pb2.DT_FLOAT),
  )
  permitted_keys = {output_key for output_key, _ in required_dtypes}

  if not signature_def.outputs.keys():
    return False
  if not set(signature_def.outputs.keys()) <= permitted_keys:
    return False

  for output_key, dtype in required_dtypes:
    # Index the map only when the key is present.
    if (output_key in signature_def.outputs
        and signature_def.outputs[output_key].dtype != dtype):
      return False

  return True
353
354
def op_signature_def(op, key):
  """Builds a SignatureDef whose only output refers to the given op.

  `op` is not strictly required to be an Op object; it may also be a Tensor.
  For Tensors, build_signature_def() is the recommended function.

  Args:
    op: An Op (or possibly Tensor).
    key: Key to graph element in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output pointing to the op.
  """
  # build_tensor_info_from_op creates a TensorInfo from the element's name.
  op_info = utils.build_tensor_info_from_op(op)
  single_output = {key: op_info}
  return build_signature_def(outputs=single_output)
371
372
def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
      SignatureDef.

  Raises:
    NotFoundError: If the op could not be found in the graph, i.e. the graph
      element lookup raised a `KeyError`. The original `KeyError` is attached
      as the cause of the raised error.
  """
  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError as err:
    # Chain explicitly so the traceback shows the KeyError as the direct cause
    # rather than as an error raised while handling another one.
    raise errors.NotFoundError(
        None, None,
        f'The key "{key}" could not be found in the graph. Please make sure the'
        ' SavedModel was created by the internal _SavedModelBuilder. If you '
        'are using the public API, please make sure the SignatureDef in the '
        f'SavedModel does not contain the key "{key}".') from err
401