# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for trackable object SavedModel save.

Most tests build a small trackable object, write it with `save.save` or
`save.export_meta_graph` into a temporary directory, and then check the
result. Results are read back either by running a signature through the
TF 1.x loader (`_import_and_infer`) or by parsing the saved protos directly.
"""

import os

from absl.testing import parameterized

from google.protobuf import text_format

from tensorflow.core.config import flags
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.checkpoint import checkpoint
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training import saver
from tensorflow.python.util import compat


def _run_signature(session, meta_graph_def, inputs, signature_key):
  """Runs one signature of an already-loaded MetaGraph in `session`.

  Args:
    session: A `Session` whose graph already contains the loaded MetaGraph;
      tensors are looked up by name in `session.graph`.
    meta_graph_def: The MetaGraphDef returned by the loader, used to find the
      signature named `signature_key`.
    inputs: Dict mapping signature input names to the values to feed. Its
      keys must match the signature's input keys exactly (checked with a
      plain `assert`).
    signature_key: Key into `meta_graph_def.signature_def`.

  Returns:
    A dict mapping each signature output name to its evaluated value.
  """
  signature = meta_graph_def.signature_def[signature_key]
  assert set(inputs.keys()) == set(signature.inputs.keys())
  # Map each named signature input to the placeholder tensor it refers to.
  feed_dict = {}
  for arg_name in inputs.keys():
    input_tensor = session.graph.get_tensor_by_name(
        signature.inputs[arg_name].name)
    feed_dict[input_tensor] = inputs[arg_name]
  # Fetch every signature output, keyed by its output name.
  output_dict = {}
  for output_name, output_tensor_info in signature.outputs.items():
    output_dict[output_name] = session.graph.get_tensor_by_name(
        output_tensor_info.name)
  return session.run(output_dict, feed_dict=feed_dict)


def _import_and_infer(
    save_dir,
    inputs,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
  """Import a SavedModel into a TF 1.x-style graph and run `signature_key`.

  Args:
    save_dir: Directory of a SavedModel containing a MetaGraph tagged with
      `tag_constants.SERVING`.
    inputs: Dict mapping signature input names to the values to feed.
    signature_key: Signature to run; defaults to the serving default key.

  Returns:
    A dict mapping the signature's output names to their evaluated values.
  """
  # A fresh graph keeps the imported model isolated from any eager state.
  graph = ops.Graph()
  with graph.as_default(), session_lib.Session() as session:
    model = loader.load(session, [tag_constants.SERVING], save_dir)
    return _run_signature(session, model, inputs, signature_key)


class SaveTest(test.TestCase, parameterized.TestCase):
  """Tests for `save.save` and `save.export_meta_graph`."""

  def test_method_save_signature(self):
    """A tf.function with an input signature can be passed as `signatures`."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, root.f)
    self.assertEqual({"output_0": 2.}, _import_and_infer(save_dir, {"x": 1.}))

  def test_method_save_list_func(self):
    """A signature calling `switch_case` with a list of branches is saveable."""
    root = autotrackable.AutoTrackable()

    @def_function.function
    def case_fn(x):
      # Branch 1 (`x + 1`) is always selected.
      branch_index = constant_op.constant(1)
      branches = [lambda: x, lambda: x + 1]
      case_out = control_flow_ops.switch_case(branch_index, branches)
      return case_out

    root.f = def_function.function(
        lambda x: 2. * case_fn(x),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, root.f)
    # 2 * (1 + 1) == 4.
    self.assertEqual({"output_0": 4.}, _import_and_infer(save_dir, {"x": 1.}))

  def test_method_save_concrete(self):
    """A concrete function can be exported under a non-default signature key."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda z: {"out": 2. * z})
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, save_dir, {
            "non_default_key":
                root.f.get_concrete_function(
                    tensor_spec.TensorSpec(None, dtypes.float32))
        })
    self.assertEqual({"out": 2.},
                     _import_and_infer(
                         save_dir, {"z": 1.}, signature_key="non_default_key"))

  def test_method_save_annotated_function(self):
    # This test is only meaningful with Python 3 because Python 2's
    # inspect.getargspec doesn't save annotations.

    root = autotrackable.AutoTrackable()

    class UnknownType(object):  # pylint: disable=unused-variable
      pass

    def annotated_function(z):
      return {"out": 2. * z}

    # Same effect as annotating function like the following.
    # def annotated_function("z": UnknownType) -> UnknownType:
    # This is a workaround since Python 2 does not support annotations and
    # our presubmit linter catches it.
    annotated_function.__annotations__ = {
        "z": UnknownType,
        "return": UnknownType
    }

    root.f = def_function.function(annotated_function)
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, save_dir, {
            "non_default_key":
                root.f.get_concrete_function(
                    tensor_spec.TensorSpec(None, dtypes.float32))
        })
    self.assertEqual({"out": 2.},
                     _import_and_infer(
                         save_dir, {"z": 1.}, signature_key="non_default_key"))

  def test_unsaveable_func_graph(self):
    """`mark_as_unsaveable` in a nested function makes saving fail."""
    root = module.Module()

    @def_function.function(input_signature=[])
    def nested_f():
      ops.get_default_graph().mark_as_unsaveable("ERROR MSG")
      return 1

    @def_function.function(input_signature=[])
    def f():
      return nested_f()

    root.f = f
    # The message given to mark_as_unsaveable must surface in the error.
    with self.assertRaisesRegex(ValueError, "ERROR MSG"):
      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))

  def test_untracked_variable_useful_message(self):
    """Capturing a variable not tracked by `root` names it in the error."""
    root = module.Module()
    # `v` is captured by `f` but never attached to `root`.
    v = variables.Variable(1., name="some_unique_name")

    @def_function.function(input_signature=[])
    def f():
      return v.read_value()

    root.f = f
    with self.assertRaisesRegex(
        AssertionError, "Trackable referencing this tensor.*some_unique_name"):
      save.save(root, os.path.join(self.get_temp_dir(), "saved_model"))

  def test_version_information_included(self):
    """The saved MetaGraph records the TF version and git version."""
    root = autotrackable.AutoTrackable()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    saved_model_proto = loader_impl.parse_saved_model(save_dir)
    self.assertEqual(
        versions.__version__,
        saved_model_proto.meta_graphs[0].meta_info_def.tensorflow_version)
    self.assertEqual(
        versions.__git_version__,
        saved_model_proto.meta_graphs[0].meta_info_def.tensorflow_git_version)

  def test_non_concrete_error(self):
    """A tf.function without an input_signature is rejected as `signatures`."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    root.f(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "Expected a TensorFlow function"):
      save.save(root, save_dir, root.f)

  def test_captures_unreachable_variable(self):
    """Saving fails if a function captures a variable unreachable from root."""
    root = autotrackable.AutoTrackable()
    unreachable_variable = variables.Variable([5.0, 2.0])
    root.reachable_variable = variables.Variable([1.0, 3.0])

    @def_function.function
    def increase_variable(x):
      return 2 * unreachable_variable * x + root.reachable_variable

    root.f = increase_variable

    # Eager execution works; only saving should fail.
    self.assertAllEqual([101.0, 83.0],
                        root.f(constant_op.constant([10.0, 20.0])).numpy())

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")

    with self.assertRaisesRegex(KeyError, "not reachable from root"):
      save.save(root, save_dir)

  def test_nested_inputs(self):
    # Only checks that a function with a nested (list) input signature traces
    # and runs; nothing is saved here.
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2. * x[0],
        input_signature=([
            tensor_spec.TensorSpec(None, dtypes.float32),
            tensor_spec.TensorSpec(None, dtypes.float32)
        ],))
    root.f([constant_op.constant(1.), constant_op.constant(1.)])

  def test_nested_outputs(self):
    """A signature returning a nested tuple of tensors is rejected."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(lambda x: (2. * x, (3. * x, 4. * x)))
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "non-Tensor value"):
      save.save(root, save_dir, to_save)

  def test_nested_dict_outputs(self):
    """A signature returning a dict with a nested tuple value is rejected."""
    root = checkpoint.Checkpoint(
        f=def_function.function(lambda x: {  # pylint: disable=g-long-lambda
            "a": 2. * x,
            "b": (3. * x, 4. * x)
        }))
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "non-Tensor value"):
      save.save(root, save_dir, to_save)

  def test_variable(self):
    """Variables captured by a signature are restored when serving."""
    root = autotrackable.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    root.f(constant_op.constant(1.))
    to_save = root.f.get_concrete_function(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir, to_save)
    # 3 * 2 * 2 == 12.
    self.assertAllEqual({"output_0": 12.},
                        _import_and_infer(save_dir, {"x": 2.}))

  def test_single_function_default_signature(self):
    """A lone zero-argument tf.function is served as the default signature."""
    model = autotrackable.AutoTrackable()
    model.f = def_function.function(lambda: 3., input_signature=())
    model.f()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, save_dir)
    self.assertAllClose({"output_0": 3.}, _import_and_infer(save_dir, {}))

  def test_single_function_no_signature(self):
    # Saving must succeed when the only function has no input signature and
    # was never called; there are no assertions beyond "does not raise".
    model = autotrackable.AutoTrackable()
    model.f = def_function.function(lambda: 3.)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(model, save_dir)

  def test_save_function_no_trace(self):
    """Saving logs a warning naming methods that were never traced."""

    class ObjWithFunction(module.Module):

      @def_function.function
      def foo(self, a):
        return a

      @def_function.function
      def bar(self, a):
        return a + 1

    root = ObjWithFunction()
    # Only `bar` is traced, so `foo` should be reported as untraced.
    root.bar(1)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertLogs(level="WARNING") as logs:
      save.save(root, save_dir)

    expected_message = (
        "WARNING:absl:Found untraced functions such as foo while saving "
        "(showing 1 of 1). These functions will not be directly callable after "
        "loading.")
    self.assertIn(expected_message, logs.output)

  def test_find_default_save_function(self):
    """`_default_save_signature` is exported as the default signature."""

    class ObjWithDefaultSignature(checkpoint.Checkpoint):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def _default_save_signature(self, x):
        return x + x + 1

    obj = ObjWithDefaultSignature()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(obj, save_dir)
    # 3 + 3 + 1 == 7.
    self.assertAllClose({"output_0": 7.},
                        _import_and_infer(save_dir, {"x": 3.}))

  def test_docstring(self):
    # NOTE(review): despite its name, nothing here is docstring-specific; it
    # checks that a Module's single signature-bearing method is served as the
    # default signature.

    class Adder(module.Module):

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def add(self, x):
        return x + x + 1.

    to_save = Adder()
    to_save.add(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    self.assertAllClose({"output_0": 7.},
                        _import_and_infer(save_dir, {"x": 3.}))

  def test_datastructures(self):
    """Variables stored inside list/dict attributes are captured and saved."""

    class HasDatastructures(checkpoint.Checkpoint):

      def __init__(self):
        self.a = [1.]
        self.a.append(variables.Variable(2.))
        self.b = {"a": variables.Variable(3.)}

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
      ])
      def add(self, x):
        return x + math_ops.add_n(self.a) + self.b["a"]

    to_save = HasDatastructures()
    to_save.add(constant_op.constant(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    # 4 + (1 + 2) + 3 == 10.
    self.assertAllClose({"output_0": 10.},
                        _import_and_infer(save_dir, {"x": 4.}))

  def test_default_attr_stripping(self):
    """The loaded `Complex` node must not carry `T` or `Tout` attrs."""

    class Complex(checkpoint.Checkpoint):

      @def_function.function(input_signature=[])
      def __call__(self):
        return math_ops.complex(
            constant_op.constant(1.), constant_op.constant(2.), name="complex")

    to_save = Complex()
    to_save()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(to_save, save_dir)
    graph = ops.Graph()
    with graph.as_default(), self.session(graph) as session:
      loader.load(session, [tag_constants.SERVING], save_dir)
      # Exactly one loaded function is expected to contain "call" in its name.
      func, = [f for name, f in graph._functions.items() if "call" in name]
      complex_node, = [
          node for node in func.definition.node_def if node.op == "Complex"
      ]
      self.assertNotIn("T", complex_node.attr)
      self.assertNotIn("Tout", complex_node.attr)

  def test_signature_attribute_reserved(self):
    """An attribute named `signatures` blocks saving until it is deleted."""
    root = checkpoint.Checkpoint(signatures=variables.Variable(1.))
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "del obj.signatures"):
      save.save(root, save_dir)
    del root.signatures
    save.save(root, save_dir)

  def test_function_with_captured_dataset(self):
    """A function that iterates over a captured dataset can be served."""
    if test_util.is_gpu_available():
      self.skipTest("Currently broken when a GPU is available.")

    class HasDataset(module.Module):

      def __init__(self):
        super(HasDataset, self).__init__()
        self.dataset = (dataset_ops.Dataset.range(5).map(lambda x: x**2))

      @def_function.function
      def __call__(self, x):
        current_sum = array_ops.zeros([], dtype=dtypes.int64)
        for element in self.dataset:
          current_sum += x * element
        return current_sum

    root = HasDataset()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        signatures=root.__call__.get_concrete_function(
            tensor_spec.TensorSpec(None, dtypes.int64)))
    # The dataset yields 0, 1, 4, 9, 16; the 0 term contributes nothing.
    self.assertAllClose({"output_0": 3 * (1 + 4 + 9 + 16)},
                        _import_and_infer(save_dir, {"x": 3}))

  def test_variable_args_cannot_be_used_as_signature(self):
    """Signatures may not take `VariableSpec` (tf.Variable) inputs."""

    @def_function.function(input_signature=[
        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
    ])
    def f(unused_v):
      return 1

    root = autotrackable.AutoTrackable()
    root.f = f.get_concrete_function()
    with self.assertRaisesRegex(ValueError,
                                "tf.Variable inputs cannot be exported"):
      save.save(
          root,
          os.path.join(self.get_temp_dir(), "saved_model"),
          signatures=root.f)

  def test_export_correct_output_shapes(self):
    """Asserts that nodes are exported with the correct number of output shapes.

    After backpropagation rewrite, functions are rewritten with additional
    outputs. When exporting to SavedModel, the shapes of the additional outputs
    were incorrectly added to the FunctionDef proto (b/133666530).
    """
    obj = autotrackable.AutoTrackable()
    obj.v = variables.Variable(2.)

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def f(x):
      return (math_ops.multiply(obj.v, x),
              math_ops.multiply(obj.v, (x + 1)), None)

    obj.f = f

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def g(x):
      return obj.f(x)[1]

    obj.g = g

    # After the following lines, the concrete functions of obj.g and obj.f are
    # rewritten with many extra outputs.
    with backprop.GradientTape():
      obj.g(constant_op.constant(3.0))

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(obj, save_dir, signatures={"g": obj.g})
    graph_def = loader_impl.parse_saved_model(save_dir).meta_graphs[0].graph_def

    def assert_correct_number_of_output_shapes(node):
      # Calls to `f` must list 2 output shapes: `f` returns three Python
      # values, one of which is None. Calls to `g` must list 1 shape.
      if node.op == "StatefulPartitionedCall":
        fn_name = node.attr["f"].func.name
        if fn_name.startswith("__inference_f"):
          self.assertLen(node.attr["_output_shapes"].list.shape, 2)
        if fn_name.startswith("__inference_g"):
          self.assertLen(node.attr["_output_shapes"].list.shape, 1)

    # NOTE: this loop rebinds `f` to FunctionDef protos; the tf.function `f`
    # above is not used past this point.
    for f in graph_def.library.function:
      if (f.signature.name.startswith("__inference_f") or
          f.signature.name.startswith("__inference_g")):
        for node in f.node_def:
          assert_correct_number_of_output_shapes(node)

  def test_save_cached_variable(self):
    """A variable with a `caching_device` can be saved from graph mode."""
    with ops.Graph().as_default(), session_lib.Session() as session:
      obj = autotrackable.AutoTrackable()
      obj.v = variables.Variable(2., caching_device=lambda op: op.device)
      obj.w = variables.Variable(3.)
      session.run([obj.v.initializer, obj.w.initializer])

      @def_function.function(input_signature=[])
      def f():
        return obj.v + obj.w

      obj.f = f
      save_dir = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(obj, save_dir, signatures=obj.f)
      self.assertAllClose({"output_0": 5}, _import_and_infer(save_dir, {}))

  @parameterized.named_parameters(
      ("_SaveDevices_ExportMetaGraph",
       save_options.VariablePolicy.SAVE_VARIABLE_DEVICES, True),
      ("_DiscardDevices_ExportMetaGraph",
       save_options.VariablePolicy.NONE, True),
      ("_SaveDevices_Save",
       save_options.VariablePolicy.SAVE_VARIABLE_DEVICES, False),
      ("_DiscardDevices_Save", save_options.VariablePolicy.NONE, False))
  def test_save_variable_devices(self, save_devices, meta_graph_only):
    """Variable devices are recorded only under SAVE_VARIABLE_DEVICES.

    Both the GraphDef nodes and the object-graph variable entries are checked,
    for `export_meta_graph` and for full `save`.
    """
    # Make sure two logical CPUs exist so variables can go on CPU:0 and CPU:1.
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    root = autotrackable.AutoTrackable()
    with ops.device("CPU:0"):
      root.v0 = variables.Variable(1., name="v0")
    with ops.device("CPU:1"):
      root.v1 = variables.Variable(1., name="v1")

    options = save_options.SaveOptions(
        experimental_variable_policy=save_devices)
    file_name = os.path.join(self.get_temp_dir(), "saved_model")
    if meta_graph_only:
      save.export_meta_graph(obj=root, filename=file_name, options=options)
    else:
      save.save(obj=root, export_dir=file_name, options=options)

    meta = None
    if meta_graph_only:
      meta = meta_graph.read_meta_graph_file(file_name)
    else:
      meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]

    # Check devices in meta graph nodes.
    graph_def = meta.graph_def
    v0 = next((n for n in graph_def.node if n.name == "v0"), None)
    v1 = next((n for n in graph_def.node if n.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)

    # Check devices in object graph nodes.
    object_graph_def = meta.object_graph_def
    v0 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v0"), None)
    v1 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)

  @parameterized.named_parameters(
      ("_ExpandDistributedVariablesWithPolicy",
       save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, True),
      ("_ExpandDistributedVariablesWithoutPolicy",
       save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, False),
      ("_DiscardDistributedVariablesWithPolicy",
       save_options.VariablePolicy.NONE, True),
      ("_DiscardDistributedVariablesWithoutPolicy",
       save_options.VariablePolicy.NONE, False))
  def test_expand_distributed_variables(self, expand_strategy, policy):
    """EXPAND_DISTRIBUTED_VARIABLES records per-replica variable components.

    `policy` toggles the strategy's `_use_var_policy`; the expected results
    depend only on `expand_strategy`.
    """
    # 1. Create a context with both CPU:0 and CPU:1.
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    # 2. Create and save a model under a mirrored strategy.
    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"])
    strategy.extended._use_var_policy = policy
    with strategy.scope():
      root = autotrackable.AutoTrackable()
      root.v = variables.Variable([1., 1.], name="v")

      @def_function.function(input_signature=[])
      def f():
        root.v.assign([2., 2.])

      root.f = f

      save.export_meta_graph(
          obj=root,
          filename=file_name,
          options=save_options.SaveOptions(
              experimental_variable_policy=expand_strategy))

    # 3. Read the output file and test behavior.
    meta_graph_def = meta_graph.read_meta_graph_file(file_name)
    object_graph = meta_graph_def.object_graph_def
    graph_def = meta_graph_def.graph_def
    v = next((n.variable
              for n in object_graph.nodes
              if n.HasField("variable") and n.variable.name == "v"), None)
    saved_function = next((f for f in graph_def.library.function
                           if "inference_f_" in f.signature.name), None)
    self.assertIsNotNone(saved_function)
    if (expand_strategy ==
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
      # experimental_save_variable_devices should have been automatically set.
      self.assertIn("CPU:0", v.device)
      components = v.experimental_distributed_variable_components
      self.assertLen(components, 2)
      v0 = next((x for x in components if x.name == "v"), None)
      v1 = next((x for x in components if x.name == "v/replica_1"), None)
      self.assertIsNotNone(v0)
      self.assertIsNotNone(v1)
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
      # One captured input per replica component.
      self.assertLen(saved_function.signature.input_arg, 2)
    else:
      self.assertEmpty(v.device)
      self.assertEmpty(v.experimental_distributed_variable_components)
      self.assertLen(saved_function.signature.input_arg, 1)

  def test_save_uninitialized_variable(self):
    """Models with an UninitializedVariable save but do not load."""
    root = autotrackable.AutoTrackable()
    root.uninitialized_variable = resource_variable_ops.UninitializedVariable(
        name="uninitialized_variable", dtype=dtypes.float32)
    root.initialized_variable = variables.Variable(
        1.0, name="initialized_variable")

    # TODO(b/149594077): Python loading does not work now partly because it
    # shouldn't, as the public API and semantics of uninitialized variables
    # are not properly defined, and officially supporting loading would end up
    # defining semantics "by usage." We should only allow loading once the API
    # is made official.
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, export_dir)
    with self.assertRaisesRegex(FileNotFoundError,
                                "Key uninitialized_variable"):
      load.load(export_dir)
    with ops.Graph().as_default(), session_lib.Session() as session:
      # The final ValueError here (with "no variables to save") is confusing,
      # but errors upstream give the user the correct information (a
      # NotFoundError stating that the uninitialized_variable was not found in
      # the checkpoint).
      with self.assertRaises(ValueError):
        loader.load(session, [tag_constants.SERVING], export_dir)

  def test_concrete_function_with_set_shape(self,):
    # Serialized concrete function should retain the shape from the TensorSpec,
    # instead of using the shape of the inputs (which are changed by set_shape).
    @def_function.function
    def f(x):
      x.set_shape((5, 1))
      return x

    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    concrete = f.get_concrete_function(
        tensor_spec.TensorSpec((None, 1), name="name"))
    save.save(root, path, signatures={"key": concrete})
    imported = load.load(path)
    self.assertEqual(imported.signatures["key"].structured_input_signature[1],
                     {"name": tensor_spec.TensorSpec((None, 1), name="name")})

  def test_save_composite_tensor_signature(self):
    """A RaggedTensor signature is servable through both V2 and V1 loaders."""
    @def_function.function(
        input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
    def f(x):
      return {"output_key": x}
    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    inp = ragged_factory_ops.constant([[[1.0, 2.0], [3.0]], [[5.]]])
    # Flattened components of `inp`: flat values ("x") followed by the outer
    # and inner row-splits ("x_1", "x_2"), keyed as the signature names them.
    flat_inp = {
        "x": constant_op.constant([1., 2., 3., 5]),
        "x_1": constant_op.constant([0, 2, 3], dtype=dtypes.int64),
        "x_2": constant_op.constant([0, 2, 3, 4], dtype=dtypes.int64)
    }
    save.save(root, path, signatures={"key": f.get_concrete_function()})

    # Test that the ragged signature can be loaded back into Python with V2 APIs
    imported = load.load(path)
    self.assertAllEqual(inp,
                        imported.signatures["key"](**flat_inp)["output_key"])
    graph = ops.Graph()

    # Try running the signature with V1 APIs.
    with graph.as_default(), session_lib.Session() as session:
      meta_graph_def = loader.load(session, [tag_constants.SERVING], path)
      signature = meta_graph_def.signature_def["key"]

      feed_dict = {}
      for arg_name in flat_inp:
        input_tensor = session.graph.get_tensor_by_name(
            signature.inputs[arg_name].name)
        feed_dict[input_tensor] = flat_inp[arg_name].numpy()

      # Get composite tensor components
      output_components = (
          signature.outputs["output_key"].composite_tensor.components)
      fetches = {}
      components_keys = ["x", "x_1", "x_2"]
      for k, output_tensor_info in zip(components_keys, output_components):
        fetches[k] = session.graph.get_tensor_by_name(output_tensor_info.name)

      outputs = session.run(fetches, feed_dict)

    # The identity signature must hand back exactly the fed components.
    self.assertAllClose(flat_inp, outputs)

  def test_save_uses_sanitized_signature_name(self):
    """Signature keys that are invalid name scopes are sanitized in names."""

    @def_function.function(
        input_signature=[ragged_tensor.RaggedTensorSpec(ragged_rank=2)])
    def f(x):
      return {"output_key": x}

    # Colons are not usable as name scopes.
    unsanitized_name = "foo:bar"
    root = autotrackable.AutoTrackable()
    path = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root, path, signatures={unsanitized_name: f.get_concrete_function()})
    graph = ops.Graph()
    with graph.as_default(), session_lib.Session() as session:
      meta_graph_def = loader.load(session, [tag_constants.SERVING], path)
      # The signature key itself keeps the original, unsanitized name.
      signature = meta_graph_def.signature_def[unsanitized_name]
      tensor_names = [
          session.graph.get_tensor_by_name(signature.inputs[key].name).name
          for key in signature.inputs
      ]
      # The placeholder names will have the sanitized version.
      self.assertCountEqual(tensor_names,
                            ["foo_bar_x:0", "foo_bar_x_1:0", "foo_bar_x_2:0"])

  def test_save_returns_none(self):
    # Test that `tf.saved_model.save` API returns None to user.
    root = autotrackable.AutoTrackable()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    result = save.save(root, save_dir)
    self.assertIsNone(result)


class DependencyTest(test.TestCase):
  """Tests for deserialization dependencies (saving-related only)."""

  def test_validate_dependencies(self):
    """Returning the object's own children as dependencies saves cleanly."""

    class Valid(autotrackable.AutoTrackable):

      def _deserialization_dependencies(self, children):
        return children

    root = Valid()
    root.f = variables.Variable(1.0)
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)

  def test_validate_dependencies_error_untracked(self):
    """A dependency that is not part of the saved object graph is rejected."""
    untracked = variables.Variable(1.0)

    class Invalid(autotrackable.AutoTrackable):

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return {"untracked": untracked}
    invalid_deps = Invalid()
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError, "Found an untracked dependency"):
      save.save(invalid_deps, save_dir)

  def test_validate_dependencies_error_cyclic(self):
    """Two objects that depend on each other raise a cycle error."""

    class Invalid(autotrackable.AutoTrackable):

      def __init__(self):
        self.cycle_ref = None

      def _deserialization_dependencies(self, children):
        del children  # Unused.
        return {"cycle_ref": self.cycle_ref}
    cycle1 = Invalid()
    cycle2 = Invalid()
    cycle1.cycle_ref = cycle2
    cycle2.cycle_ref = cycle1
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    with self.assertRaisesRegex(ValueError,
                                "dependency cycle in the saved Trackable"):
      save.save(cycle1, save_dir)


class VariablePolicyEnumTest(test.TestCase):
  """Tests for parsing and naming of `save_options.VariablePolicy`."""

  def testFromObj(self):
    """`from_obj` accepts None, members, and case-insensitive member values."""
    self.assertEqual(save_options.VariablePolicy.NONE,
                     save_options.VariablePolicy.from_obj(None))
    self.assertEqual(
        save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
        save_options.VariablePolicy.from_obj(
            save_options.VariablePolicy.SAVE_VARIABLE_DEVICES))
    self.assertEqual(
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
        save_options.VariablePolicy.from_obj(
            save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES))
    self.assertEqual(
        save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
        save_options.VariablePolicy.from_obj("save_variable_devices"))
    self.assertEqual(
        save_options.VariablePolicy.SAVE_VARIABLE_DEVICES,
        save_options.VariablePolicy.from_obj("SaVe_VaRiAbLe_DeViCeS"))
    self.assertEqual(
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
        save_options.VariablePolicy.from_obj("expand_distributed_variables"))
    self.assertEqual(
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES,
        save_options.VariablePolicy.from_obj("eXpAnD_dIsTrIbUtEd_VaRiAbLeS"))
    # Unknown strings and values of other types are rejected.
    for invalid in ["not_a_valid_value", 2.0, []]:
      with self.assertRaisesRegex(ValueError, "invalid VariablePolicy value"):
        save_options.VariablePolicy.from_obj(invalid)

  def testNamingConvention(self):
    """Enforces names are uppercase versions of values (NONE has no value)."""
    for policy in save_options.VariablePolicy:
      if policy == save_options.VariablePolicy.NONE:
        self.assertIsNone(policy.value)
      else:
        self.assertEqual(policy.name, policy.name.upper())
        self.assertEqual(policy.value, policy.value.lower())
        self.assertEqual(policy.name, policy.value.upper())


class SavingOptionsTest(test.TestCase):
  """Tests for op-namespace verification and `SaveOptions` settings."""

  def testOpNameSpace(self):
    # TODO(kathywu): Add test that saves out SavedModel with a custom op when
    # the ">" character is allowed in op names.
    graph_def = graph_pb2.GraphDef()
    text_format.Parse("node { name: 'A' op: 'Test>CustomOp' }", graph_def)
    with self.assertRaisesRegex(
        ValueError, "Attempted to save ops from non-whitelisted namespaces"):
      save._verify_ops(graph_def, [])
    save._verify_ops(graph_def, ["Test"])

    # Test with multiple carets ('>') in op name.
    text_format.Parse("node { name: 'A' op: 'Test>>A>CustomOp' }", graph_def)
    with self.assertRaisesRegex(
        ValueError, "Attempted to save ops from non-whitelisted namespaces"):
      save._verify_ops(graph_def, [])
    save._verify_ops(graph_def, ["Test"])

  def test_save_custom_op_with_no_whitelist_specified(self):
    # Test that we are able to save a model that contains a custom op with a
    # custom namespace when the user has not explicitly specified a namespace
    # whitelist (i.e. that we default to allowing all custom ops when saving
    # and no whitelist is specified, rather than throwing an exception).
    graph_def = graph_pb2.GraphDef()
    text_format.Parse("node { name: 'A' op: 'Test>CustomOp' }", graph_def)
    save._verify_ops(graph_def, namespace_whitelist=None)

    # If the user passes an empty list for the namespace whitelist rather than
    # nothing, we should then throw an exception if a custom op is used.
    with self.assertRaisesRegex(
        ValueError, "Attempted to save ops from non-whitelisted namespaces"):
      save._verify_ops(graph_def, [])

  def test_save_debug_info_enabled(self):
    """`save_debug_info=True` writes a GraphDebugInfo file with op traces."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        root.f,
        options=save_options.SaveOptions(save_debug_info=True))
    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertTrue(os.path.exists(debug_info_file_name))
    debug_info = graph_debug_info_pb2.GraphDebugInfo()
    with open(debug_info_file_name, "rb") as f:
      debug_info.ParseFromString(f.read())

    # Verify that there is a trace for DEBUG_INFO_OP just to ensure that
    # function debug info tracing is nominally functioning.
    found_op = False
    for key in debug_info.traces.keys():
      if key.startswith("DEBUG_INFO_OP@"):
        found_op = True
        break
    self.assertTrue(found_op, "Did not find DEBUG_INFO_OP in trace")

  def test_save_debug_info_disabled(self):
    """`save_debug_info=False` writes no debug info file."""
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: math_ops.mul(2., x, name="DEBUG_INFO_OP"),
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(
        root,
        save_dir,
        root.f,
        options=save_options.SaveOptions(save_debug_info=False))
    debug_info_file_name = os.path.join(save_dir, "debug",
                                        "saved_model_debug_info.pb")
    self.assertFalse(os.path.exists(debug_info_file_name))

  def test_function_aliases(self):
    root = autotrackable.AutoTrackable()
    root.f = def_function.function(
        lambda x: 2.
* x, 936 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 937 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 938 options = save_options.SaveOptions(function_aliases={ 939 "my_func": root.f, 940 }) 941 save.save(root, save_dir, root.f, options=options) 942 function_cache = root.f._stateful_fn._list_all_concrete_functions() 943 function_aliases = loader_impl.parse_saved_model( 944 save_dir).meta_graphs[0].meta_info_def.function_aliases 945 self.assertLen(function_cache, 1) 946 self.assertEqual(function_cache[0].name.decode("utf-8"), 947 list(function_aliases.keys())[0]) 948 949 def test_accepts_io_device(self): 950 options = save_options.SaveOptions() 951 self.assertIsNone(options.experimental_io_device) 952 options = save_options.SaveOptions(experimental_io_device="/job:localhost") 953 self.assertEqual("/job:localhost", options.experimental_io_device) 954 955 def test_accepts_variable_policy(self): 956 options = save_options.SaveOptions() 957 self.assertEqual(save_options.VariablePolicy.NONE, 958 options.experimental_variable_policy) 959 # VariablePolicy instances. 960 options = save_options.SaveOptions(experimental_variable_policy=save_options 961 .VariablePolicy.SAVE_VARIABLE_DEVICES) 962 self.assertEqual(save_options.VariablePolicy.SAVE_VARIABLE_DEVICES, 963 options.experimental_variable_policy) 964 options = save_options.SaveOptions( 965 experimental_variable_policy=save_options.VariablePolicy 966 .EXPAND_DISTRIBUTED_VARIABLES) 967 self.assertEqual(save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, 968 options.experimental_variable_policy) 969 # String conversions. 
970 options = save_options.SaveOptions( 971 experimental_variable_policy="save_variable_devices") 972 self.assertEqual(save_options.VariablePolicy.SAVE_VARIABLE_DEVICES, 973 options.experimental_variable_policy) 974 options = save_options.SaveOptions( 975 experimental_variable_policy="expand_distributed_variables") 976 self.assertEqual(save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES, 977 options.experimental_variable_policy) 978 with self.assertRaisesRegex(ValueError, "invalid VariablePolicy value"): 979 options = save_options.SaveOptions( 980 experimental_variable_policy="not_a_valid_value") 981 982 983class AssetTests(test.TestCase): 984 985 def setUp(self): 986 super(AssetTests, self).setUp() 987 self._vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt") 988 with open(self._vocab_path, "w") as f: 989 f.write("alpha\nbeta\ngamma\n") 990 991 def test_asset_path_returned(self): 992 root = autotrackable.AutoTrackable() 993 root.path = asset.Asset(self._vocab_path) 994 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 995 root.get_asset = def_function.function(lambda: root.path.asset_path) 996 save.save(root, save_dir, signatures=root.get_asset.get_concrete_function()) 997 second_dir = os.path.join(self.get_temp_dir(), "second_dir") 998 file_io.rename(save_dir, second_dir) 999 imported_path = _import_and_infer(second_dir, {})["output_0"] 1000 self.assertIn( 1001 compat.as_str_any(second_dir), compat.as_str_any(imported_path)) 1002 1003 def test_table(self): 1004 initializer = lookup_ops.TextFileInitializer( 1005 self._vocab_path, 1006 key_dtype=dtypes.string, 1007 key_index=lookup_ops.TextFileIndex.WHOLE_LINE, 1008 value_dtype=dtypes.int64, 1009 value_index=lookup_ops.TextFileIndex.LINE_NUMBER) 1010 root = checkpoint.Checkpoint( 1011 table=lookup_ops.HashTable(initializer, default_value=-1)) 1012 root.table_user = def_function.function( 1013 root.table.lookup, 1014 input_signature=[tensor_spec.TensorSpec(None, dtypes.string)]) 1015 
self.assertEqual( 1016 2, self.evaluate(root.table_user(constant_op.constant("gamma")))) 1017 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 1018 save.save(root, save_dir) 1019 file_io.delete_file(self._vocab_path) 1020 self.assertAllClose({"output_0": [2, 0]}, 1021 _import_and_infer(save_dir, 1022 {"keys": ["gamma", "alpha"]})) 1023 second_dir = os.path.join(self.get_temp_dir(), "second_dir") 1024 # Asset paths should track the location the SavedModel is loaded from. 1025 file_io.rename(save_dir, second_dir) 1026 self.assertAllClose({"output_0": [2, 1]}, 1027 _import_and_infer(second_dir, 1028 {"keys": ["gamma", "beta"]})) 1029 1030 def test_untracked_table_useful_message(self): 1031 root = module.Module() 1032 initializer = lookup_ops.TextFileInitializer( 1033 self._vocab_path, 1034 key_dtype=dtypes.string, 1035 key_index=lookup_ops.TextFileIndex.WHOLE_LINE, 1036 value_dtype=dtypes.int64, 1037 value_index=lookup_ops.TextFileIndex.LINE_NUMBER) 1038 table = lookup_ops.HashTable(initializer, default_value=-1) 1039 root.table_user = def_function.function( 1040 table.lookup, 1041 input_signature=[tensor_spec.TensorSpec(None, dtypes.string)]) 1042 root.table_user(constant_op.constant("gamma")) 1043 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 1044 with self.assertRaisesRegexp(AssertionError, "HashTable"): 1045 save.save(root, save_dir) 1046 1047 def test_unused_asset(self): 1048 root = autotrackable.AutoTrackable() 1049 root.f = def_function.function( 1050 lambda x: 2. 
* x, 1051 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 1052 root.asset = asset.Asset(self._vocab_path) 1053 1054 export_dir = os.path.join(self.get_temp_dir(), "save_dir") 1055 save.save(root, export_dir) 1056 self.assertAllClose({"output_0": [0.2]}, 1057 _import_and_infer(export_dir, {"x": [0.1]})) 1058 1059 def test_sensible_function_building_exception(self): 1060 root = checkpoint.Checkpoint(v=variables.Variable(2.)) 1061 root.f = def_function.function( 1062 lambda x: 2. * root.v, 1063 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 1064 export_dir = os.path.join(self.get_temp_dir(), "save_dir") 1065 1066 @def_function.function 1067 def _calls_save(): 1068 save.save(root, export_dir) 1069 1070 with self.assertRaisesRegex(AssertionError, "tf.function"): 1071 _calls_save() 1072 1073 1074class ExportMetaGraphTests(test.TestCase): 1075 1076 def test_export_meta_graph(self): 1077 root = autotrackable.AutoTrackable() 1078 root.variable = resource_variable_ops.UninitializedVariable( 1079 name="some_variable", dtype=dtypes.float32) 1080 1081 @def_function.function(input_signature=[tensor_spec.TensorSpec(None)]) 1082 def multiply_var(x): 1083 return root.variable * x 1084 1085 @def_function.function(input_signature=[tensor_spec.TensorSpec([])]) 1086 def update(y): 1087 root.variable.assign_add(y) 1088 # TODO(b/150393409): All functions exported as signatures must have at 1089 # least one output. 1090 return 0 1091 1092 @def_function.function(input_signature=[]) 1093 def initialize(): 1094 root.variable.assign(1.0) 1095 # TODO(b/150393409): All functions exported as signatures must have at 1096 # least one output. 
1097 return 0 1098 1099 save_path = os.path.join(self.get_temp_dir(), "meta_graph.pb") 1100 save.export_meta_graph( 1101 root, 1102 save_path, 1103 signatures={ 1104 "multiply_var": multiply_var, 1105 "initialize": initialize, 1106 "update": update 1107 }) 1108 1109 with ops.Graph().as_default(), session_lib.Session() as session: 1110 saver.import_meta_graph(save_path) 1111 meta_graph_def = meta_graph.read_meta_graph_file(save_path) 1112 1113 # Initialize variable to 1 1114 _run_signature(session, meta_graph_def, {}, "initialize") 1115 out = _run_signature(session, meta_graph_def, {"x": 3}, "multiply_var") 1116 self.assertAllEqual(out, {"output_0": 3}) 1117 1118 # Adds 2 to the variable. Variable is now 3 1119 _run_signature(session, meta_graph_def, {"y": 2}, "update") 1120 out = _run_signature(session, meta_graph_def, {"x": 4}, "multiply_var") 1121 self.assertAllEqual(out, {"output_0": 12}) 1122 1123 1124class FingerprintingTests(test.TestCase): 1125 1126 def test_toggle_flag(self): 1127 self.assertFalse(flags.config().saved_model_fingerprinting.value()) 1128 flags.config().saved_model_fingerprinting.reset(True) 1129 self.assertTrue(flags.config().saved_model_fingerprinting.value()) 1130 1131 1132if __name__ == "__main__": 1133 test.main() 1134