xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/saving/utils_v1/signature_def_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SignatureDef utility functions implementation."""
16
17from tensorflow.python.keras.saving.utils_v1 import unexported_constants
18from tensorflow.python.saved_model import signature_def_utils
19from tensorflow.python.saved_model import utils_impl as utils
20
21
22# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a SignatureDef describing a supervised training step.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` holding the computed loss.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` of metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  method_name = unexported_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      method_name, inputs,
      loss=loss, predictions=predictions, metrics=metrics)
27      predictions=predictions, metrics=metrics)
28
29
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a SignatureDef describing a supervised evaluation step.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` holding the computed loss.
    predictions: Optional dict of string to `Tensor` of output predictions.
    metrics: Optional dict of string to `Tensor` of metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  method_name = unexported_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      method_name, inputs,
      loss=loss, predictions=predictions, metrics=metrics)
34      predictions=predictions, metrics=metrics)
35
36
37def _supervised_signature_def(
38    method_name, inputs, loss=None, predictions=None,
39    metrics=None):
40  """Creates a signature for training and eval data.
41
42  This function produces signatures that describe the inputs and outputs
43  of a supervised process, such as training or evaluation, that
44  results in loss, metrics, and the like. Note that this function only requires
45  inputs to be not None.
46
47  Args:
48    method_name: Method name of the SignatureDef as a string.
49    inputs: dict of string to `Tensor`.
50    loss: dict of string to `Tensor` representing computed loss.
51    predictions: dict of string to `Tensor` representing the output predictions.
52    metrics: dict of string to `Tensor` representing metric ops.
53
54  Returns:
55    A train- or eval-flavored signature_def.
56
57  Raises:
58    ValueError: If inputs or outputs is `None`.
59  """
60  if inputs is None or not inputs:
61    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
62
63  signature_inputs = {key: utils.build_tensor_info(tensor)
64                      for key, tensor in inputs.items()}
65
66  signature_outputs = {}
67  for output_set in (loss, predictions, metrics):
68    if output_set is not None:
69      sig_out = {key: utils.build_tensor_info(tensor)
70                 for key, tensor in output_set.items()}
71      signature_outputs.update(sig_out)
72
73  signature_def = signature_def_utils.build_signature_def(
74      signature_inputs, signature_outputs, method_name)
75
76  return signature_def
77# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)
78