xref: /aosp_15_r20/external/tensorflow/tensorflow/python/saved_model/simple_save_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SavedModel simple save functionality."""
16
17import os
18
19from tensorflow.python.framework import ops
20from tensorflow.python.ops import variables
21from tensorflow.python.platform import test
22from tensorflow.python.saved_model import loader
23from tensorflow.python.saved_model import signature_constants
24from tensorflow.python.saved_model import simple_save
25from tensorflow.python.saved_model import tag_constants
26
27
class SimpleSaveTest(test.TestCase):
  """Tests for `simple_save.simple_save` exercised in graph (TF1-style) mode."""

  def _init_and_validate_variable(self, variable_name, variable_value):
    """Creates a variable in the default graph, initializes it, checks it.

    Args:
      variable_name: Name to give the new variable.
      variable_value: Initial value; the evaluated variable must equal it.

    Returns:
      The newly created variable.
    """
    v = variables.Variable(variable_value, name=variable_name)
    # Re-running the global initializer is harmless here: every variable
    # created so far is simply reset to its initial value.
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(variable_value, self.evaluate(v))
    return v

  def _check_variable_info(self, actual_variable, expected_variable):
    """Asserts two variables share the same name, dtype, and static shape."""
    self.assertEqual(actual_variable.name, expected_variable.name)
    self.assertEqual(actual_variable.dtype, expected_variable.dtype)
    # The rank check comes first so that zip() below cannot silently
    # truncate a mismatched shape.
    self.assertEqual(len(actual_variable.shape), len(expected_variable.shape))
    for actual_dim, expected_dim in zip(actual_variable.shape,
                                        expected_variable.shape):
      self.assertEqual(actual_dim, expected_dim)

  def _check_tensor_info(self, actual_tensor_info, expected_tensor):
    """Asserts a `TensorInfo` proto describes `expected_tensor`.

    Compares the proto's name, dtype, and per-dimension sizes against the
    static name, dtype, and shape of `expected_tensor`.

    NOTE(review): presumably only meaningful for fully-defined shapes — verify
    how unknown dimensions (stored as sizes in the proto) compare against the
    `TensorShape` entries before reusing this helper more widely.
    """
    self.assertEqual(actual_tensor_info.name, expected_tensor.name)
    self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype)
    actual_dims = actual_tensor_info.tensor_shape.dim
    # Rank check first so zip() below cannot hide a length mismatch.
    self.assertEqual(len(actual_dims), len(expected_tensor.shape))
    for actual_dim, expected_size in zip(actual_dims, expected_tensor.shape):
      self.assertEqual(actual_dim.size, expected_size)

  def testSimpleSave(self):
    """Test simple_save that uses the default parameters."""
    # self.get_temp_dir() is unique to this test, so the export never lands in
    # a directory left behind by another test or an earlier run (SavedModel
    # export is expected to reject an existing non-empty directory).
    export_dir = os.path.join(self.get_temp_dir(), "test_simple_save")

    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that both requires a session and uses
    # functionality that does not work with eager tensors (such as
    # build_tensor_info as called by predict_signature_def).
    with ops.Graph().as_default():
      # Initialize input and output variables and save a prediction graph using
      # the default parameters.
      with self.session(graph=ops.Graph()) as sess:
        var_x = self._init_and_validate_variable("var_x", 1)
        var_y = self._init_and_validate_variable("var_y", 2)
        inputs = {"x": var_x}
        outputs = {"y": var_y}
        simple_save.simple_save(sess, export_dir, inputs, outputs)

      # Restore the graph with a valid tag and check the global variables and
      # signature def map.
      with self.session(graph=ops.Graph()) as sess:
        # loader.load returns the MetaGraphDef proto that was loaded.
        meta_graph_def = loader.load(sess, [tag_constants.SERVING], export_dir)
        collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

        # Check value and metadata of the saved variables. The indices below
        # rely on the collection listing var_x before var_y (creation order).
        self.assertLen(collection_vars, 2)
        self.assertEqual(1, self.evaluate(collection_vars[0]))
        self.assertEqual(2, self.evaluate(collection_vars[1]))
        self._check_variable_info(collection_vars[0], var_x)
        self._check_variable_info(collection_vars[1], var_y)

        # Check that the appropriate signature_def_map is created with the
        # default key and method name, and the specified inputs and outputs.
        signature_def_map = meta_graph_def.signature_def
        self.assertLen(signature_def_map, 1)
        self.assertIn(signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
                      signature_def_map)

        signature_def = signature_def_map[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                         signature_def.method_name)

        self.assertLen(signature_def.inputs, 1)
        self._check_tensor_info(signature_def.inputs["x"], var_x)
        self.assertLen(signature_def.outputs, 1)
        self._check_tensor_info(signature_def.outputs["y"], var_y)
100
101
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
104