# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""

from tensorflow.python.keras.saving.utils_v1 import unexported_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as utils


# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised training process.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`. Must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: dict of string to `Tensor` representing the output
      predictions, or `None`.
    metrics: dict of string to `Tensor` representing metric ops, or `None`.

  Returns:
    A train-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  return _supervised_signature_def(
      unexported_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a SignatureDef describing a supervised evaluation process.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`. Must be non-empty.
    loss: dict of string to `Tensor` representing computed loss, or `None`.
    predictions: dict of string to `Tensor` representing the output
      predictions, or `None`.
    metrics: dict of string to `Tensor` representing metric ops, or `None`.

  Returns:
    An eval-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  return _supervised_signature_def(
      unexported_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
      predictions=predictions, metrics=metrics)


def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Only `inputs` is required, and it
  must be a non-empty dict. Each of `loss`, `predictions` and `metrics` may
  be `None`, in which case it contributes no outputs.

  All provided output dicts are merged into a single flat output map, in the
  order loss, predictions, metrics. Keys are not prefixed. If the same key
  appears in more than one dict, the later dict's entry wins.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output
      predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def, as built by
    `signature_def_utils.build_signature_def`.

  Raises:
    ValueError: If `inputs` is `None` or empty. Outputs are not validated.
  """
  # `not inputs` is true for both `None` and an empty dict.
  if not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  # Flatten every provided output dict into one map. Later sets overwrite
  # earlier ones on key collision (metrics > predictions > loss).
  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      signature_outputs.update(
          (key, utils.build_tensor_info(tensor))
          for key, tensor in output_set.items())

  return signature_def_utils.build_signature_def(
      signature_inputs, signature_outputs, method_name)
# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)