# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SavedModel simple save functionality."""

import os

from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model import tag_constants


class SimpleSaveTest(test.TestCase):

  def _init_and_validate_variable(self, variable_name, variable_value):
    """Creates a variable, runs the global initializer, and checks its value.

    Args:
      variable_name: Name given to the new variable.
      variable_value: Initial value of the variable; the initialized variable
        is asserted to evaluate to this value.

    Returns:
      The created variable.
    """
    v = variables.Variable(variable_value, name=variable_name)
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(variable_value, self.evaluate(v))
    return v

  def _check_variable_info(self, actual_variable, expected_variable):
    """Asserts that two variables agree in name, dtype, and shape.

    Args:
      actual_variable: Variable recovered from the loaded graph.
      expected_variable: Variable that was originally saved.
    """
    self.assertEqual(actual_variable.name, expected_variable.name)
    self.assertEqual(actual_variable.dtype, expected_variable.dtype)
    self.assertEqual(len(actual_variable.shape), len(expected_variable.shape))
    # The ranks were just asserted equal, so zip visits every dimension.
    for actual_dim, expected_dim in zip(actual_variable.shape,
                                        expected_variable.shape):
      self.assertEqual(actual_dim, expected_dim)

  def _check_tensor_info(self, actual_tensor_info, expected_tensor):
    """Asserts that a TensorInfo proto describes `expected_tensor`.

    Compares the name, dtype, and per-dimension sizes recorded in
    `actual_tensor_info` against those of `expected_tensor`.

    Args:
      actual_tensor_info: TensorInfo entry taken from a loaded SignatureDef.
      expected_tensor: Tensor or variable the entry was built from.
    """
    self.assertEqual(actual_tensor_info.name, expected_tensor.name)
    self.assertEqual(actual_tensor_info.dtype, expected_tensor.dtype)
    actual_dims = actual_tensor_info.tensor_shape.dim
    self.assertEqual(len(actual_dims), len(expected_tensor.shape))
    # The ranks were just asserted equal, so zip visits every dimension.
    for actual_dim, expected_size in zip(actual_dims, expected_tensor.shape):
      self.assertEqual(actual_dim.size, expected_size)

  def testSimpleSave(self):
    """Test simple_save that uses the default parameters."""
    export_dir = os.path.join(test.get_temp_dir(), "test_simple_save")

    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that both requires a session and uses
    # functionality that does not work with eager tensors (such as
    # build_tensor_info as called by predict_signature_def).
    with ops.Graph().as_default():
      # Initialize input and output variables and save a prediction graph using
      # the default parameters.
      with self.session(graph=ops.Graph()) as sess:
        var_x = self._init_and_validate_variable("var_x", 1)
        var_y = self._init_and_validate_variable("var_y", 2)
        inputs = {"x": var_x}
        outputs = {"y": var_y}
        simple_save.simple_save(sess, export_dir, inputs, outputs)

      # Restore the graph with a valid tag and check the global variables and
      # signature def map.
      with self.session(graph=ops.Graph()) as sess:
        # loader.load returns the loaded MetaGraphDef proto, not a tf.Graph.
        meta_graph_def = loader.load(sess, [tag_constants.SERVING], export_dir)
        collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

        # Check value and metadata of the saved variables.
        self.assertEqual(len(collection_vars), 2)
        self.assertEqual(1, self.evaluate(collection_vars[0]))
        self.assertEqual(2, self.evaluate(collection_vars[1]))
        self._check_variable_info(collection_vars[0], var_x)
        self._check_variable_info(collection_vars[1], var_y)

        # Check that the appropriate signature_def_map is created with the
        # default key and method name, and the specified inputs and outputs.
        signature_def_map = meta_graph_def.signature_def
        # Comparing the whole key list checks the count and the key at once,
        # and reports any unexpected keys on failure.
        self.assertEqual(
            [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY],
            list(signature_def_map.keys()))

        signature_def = signature_def_map[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                         signature_def.method_name)

        self.assertEqual(1, len(signature_def.inputs))
        self._check_tensor_info(signature_def.inputs["x"], var_x)
        self.assertEqual(1, len(signature_def.outputs))
        self._check_tensor_info(signature_def.outputs["y"], var_y)


if __name__ == "__main__":
  test.main()