1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for trackable object SavedModel loading.""" 16 17import collections 18import contextlib 19import functools 20import gc 21import io 22import os 23import pathlib 24import sys 25import tempfile 26import weakref 27 28from absl.testing import parameterized 29import numpy as np 30from tensorflow.python.checkpoint import checkpoint 31from tensorflow.python.checkpoint import saveable_compat 32from tensorflow.python.client import session as session_lib 33from tensorflow.python.data.ops import dataset_ops 34from tensorflow.python.data.ops import readers 35from tensorflow.python.eager import backprop 36from tensorflow.python.eager import context 37from tensorflow.python.eager import def_function 38from tensorflow.python.eager import test 39from tensorflow.python.eager import wrap_function 40from tensorflow.python.framework import config 41from tensorflow.python.framework import constant_op 42from tensorflow.python.framework import dtypes 43from tensorflow.python.framework import errors 44from tensorflow.python.framework import function as framework_function 45from tensorflow.python.framework import op_callbacks 46from tensorflow.python.framework import ops 47from tensorflow.python.framework import tensor_shape 48from tensorflow.python.framework import tensor_spec 49from tensorflow.python.framework import test_util 50from tensorflow.python.framework import versions 51from tensorflow.python.lib.io import file_io 52from tensorflow.python.lib.io import tf_record 53from tensorflow.python.module import module 54from tensorflow.python.ops import array_ops 55from tensorflow.python.ops import cond_v2 56from tensorflow.python.ops import control_flow_ops 57from tensorflow.python.ops import custom_gradient 58from tensorflow.python.ops import lookup_ops 59from tensorflow.python.ops import math_ops 60from tensorflow.python.ops import resource_variable_ops 61from tensorflow.python.ops import string_ops 62from tensorflow.python.ops import variable_scope 63from tensorflow.python.ops import variables 64from tensorflow.python.ops.ragged import ragged_factory_ops 65from tensorflow.python.ops.ragged import ragged_tensor 66from tensorflow.python.saved_model import load 67from tensorflow.python.saved_model import load_options 68from tensorflow.python.saved_model import loader_impl 69from tensorflow.python.saved_model import save 70from tensorflow.python.saved_model import save_options 71from tensorflow.python.saved_model import tag_constants 72from tensorflow.python.trackable import asset 73from tensorflow.python.trackable import autotrackable 74from tensorflow.python.trackable import resource 75from tensorflow.python.training import monitored_session 76from tensorflow.python.util import tf_inspect 77 78 79def cycle(obj, cycles, signatures=None, options=None): 80 to_save = obj 81 # TODO(vbardiovsky): It would be nice if exported protos reached a fixed 82 # point w.r.t. saving/restoring, ideally after 2nd saving. 83 for _ in range(cycles): 84 path = tempfile.mkdtemp(prefix=test.get_temp_dir()) 85 # If available, we'll run the save and restore preferring the GPU. This 86 # just makes sure we aren't throwing errors and have enough 87 # device("CPU") blocks to satisfy the placer. 88 with test_util.use_gpu(): 89 save.save(to_save, path, signatures, options=options) 90 loaded = load.load(path) 91 signatures = loaded.signatures 92 to_save = loaded 93 return loaded 94 95 96@parameterized.named_parameters( 97 dict(testcase_name="ReloadOnce", cycles=1), 98 dict(testcase_name="ReloadTwice", cycles=2), 99 dict(testcase_name="ReloadThrice", cycles=3) 100) 101class LoadTest(test.TestCase, parameterized.TestCase): 102 103 def test_structure_import(self, cycles): 104 root = autotrackable.AutoTrackable() 105 root.dep_one = autotrackable.AutoTrackable() 106 root.dep_two = autotrackable.AutoTrackable() 107 root.dep_two.dep = autotrackable.AutoTrackable() 108 root.dep_three = root.dep_two.dep 109 imported = cycle(root, cycles) 110 self.assertIs(imported.dep_three, imported.dep_two.dep) 111 self.assertIsNot(imported.dep_one, imported.dep_two) 112 113 @test_util.run_in_graph_and_eager_modes 114 def test_variables(self, cycles): 115 root = autotrackable.AutoTrackable() 116 root.v1 = variables.Variable(1., trainable=True) 117 root.v2 = variables.Variable(2., trainable=False) 118 self.evaluate([root.v1.initializer, root.v2.initializer]) 119 120 for _ in range(cycles): 121 imported = cycle(root, 1) 122 self.evaluate([imported.v1.initializer, imported.v2.initializer]) 123 124 if not context.executing_eagerly(): 125 self.assertIsInstance(imported.v1.initializer, ops.Operation) 126 self.assertIsInstance(imported.v2.initializer, ops.Operation) 127 128 self.assertEqual(self.evaluate(imported.v1), 1.0) 129 self.assertTrue(imported.v1.trainable) 130 self.assertEqual(self.evaluate(imported.v2), 2.0) 131 self.assertFalse(imported.v2.trainable) 132 133 def test_variables_name(self, cycles): 134 root = autotrackable.AutoTrackable() 135 # Test 2 variables with same name: should work as the checkpoint 136 # is based on object name and not on variable name. 137 root.v1 = variables.Variable(1., trainable=True, name="v1") 138 root.v2 = variables.Variable(2., trainable=False, name="v1") 139 imported = cycle(root, cycles) 140 self.assertEqual(imported.v1.numpy(), 1.0) 141 self.assertEqual(imported.v2.numpy(), 2.0) 142 self.assertEqual(imported.v1.name, root.v1.name) 143 self.assertEqual(imported.v2.name, root.v2.name) 144 with variable_scope.variable_scope("foo"): 145 imported = cycle(root, cycles) 146 self.assertTrue(imported.v1.name.startswith("foo/")) 147 self.assertTrue(imported.v2.name.startswith("foo/")) 148 149 def test_partially_defined_variable_shape(self, cycles): 150 151 class MakeVariable(module.Module): 152 153 def __init__(self): 154 self.v = None 155 156 @def_function.function( 157 input_signature=[tensor_spec.TensorSpec([None], dtypes.int64)]) 158 def make_variable(self, initial_value): 159 if self.v is None: 160 self.v = variables.Variable(initial_value) 161 162 m = MakeVariable() 163 m.make_variable([1, 2, 3]) 164 m = cycle(m, cycles) 165 m.v.assign([1, 2, 3, 4]) 166 self.assertEqual([None], tensor_shape.as_shape(m.v.shape).as_list()) 167 168 @test_util.run_in_graph_and_eager_modes 169 def test_capture_variables(self, cycles): 170 root = autotrackable.AutoTrackable() 171 root.weights = variables.Variable(2.) 172 self.evaluate(root.weights.initializer) 173 root.f = def_function.function( 174 lambda x: root.weights * x, 175 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 176 for _ in range(cycles): 177 imported = cycle(root, 1) 178 self.evaluate(imported.weights.initializer) 179 self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.)))) 180 self.evaluate(imported.weights.assign(4.0)) 181 self.assertEqual(8., self.evaluate(imported.f(constant_op.constant(2.)))) 182 183 @test_util.run_in_graph_and_eager_modes 184 def test_capture_constant(self, cycles): 185 root = autotrackable.AutoTrackable() 186 captured_constant = constant_op.constant(2.) 187 root.f = def_function.function( 188 lambda x: captured_constant * x, 189 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 190 imported = cycle(root, cycles) 191 self.assertEqual(4., self.evaluate(imported.f(constant_op.constant(2.)))) 192 193 def test_control_outputs(self, cycles): 194 exported = autotrackable.AutoTrackable() 195 exported.v = variables.Variable(1.) 196 exported.f = def_function.function( 197 lambda: exported.v.assign(2., name="should_be_control_output")) 198 exported_graph = exported.f.get_concrete_function().graph 199 self.assertIn( 200 exported_graph.get_operation_by_name("should_be_control_output"), 201 exported_graph.control_outputs) 202 203 imported = cycle(exported, cycles) 204 # Calling get_concrete_function wraps in a second call operation; we want to 205 # inspect the original function body for the control output; digging into 206 # graph.as_graph_def() and its FunctionDefLibrary is another option. 207 imported_concrete, = imported.f.concrete_functions 208 imported_graph = imported_concrete.graph 209 self.assertIn( 210 imported_graph.get_operation_by_name("should_be_control_output"), 211 imported_graph.control_outputs) 212 213 def _make_asset(self, contents): 214 fd, filename = tempfile.mkstemp(prefix=self.get_temp_dir()) 215 with os.fdopen(fd, "w") as f: 216 f.write(contents) 217 return filename 218 219 @test_util.run_in_graph_and_eager_modes 220 def test_assets(self, cycles): 221 file1 = self._make_asset("contents 1") 222 file2 = self._make_asset("contents 2") 223 224 root = autotrackable.AutoTrackable() 225 root.asset1 = asset.Asset(file1) 226 root.asset2 = asset.Asset(file2) 227 228 save_dir = os.path.join(self.get_temp_dir(), "save_dir") 229 save.save(root, save_dir) 230 231 file_io.delete_file(file1) 232 file_io.delete_file(file2) 233 load_dir = os.path.join(self.get_temp_dir(), "load_dir") 234 file_io.rename(save_dir, load_dir) 235 236 imported = load.load(load_dir) 237 with open(self.evaluate(imported.asset1.asset_path), "r") as f: 238 self.assertEqual("contents 1", f.read()) 239 with open(self.evaluate(imported.asset2.asset_path), "r") as f: 240 self.assertEqual("contents 2", f.read()) 241 242 def test_cond_prune(self, cycles): 243 x_in = [] 244 x_out = [] 245 246 def f(x, y): 247 x_in.append(x) 248 xx = cond_v2.cond_v2( 249 math_ops.less(1, 2), 250 lambda: x + 1, 251 lambda: x + 2, 252 ) 253 x_out.append(xx) 254 return xx, 2 * y 255 256 f_wrapped = wrap_function.wrap_function( 257 f, [tensor_spec.TensorSpec((), dtypes.float32)] * 2) 258 f_pruned = f_wrapped.prune(x_in[0], [x_out[0]]) 259 260 class Adder(module.Module): 261 262 @def_function.function(input_signature=[ 263 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)]) 264 def add(self, x): 265 return f_pruned(x) 266 267 root = Adder() 268 root.add(constant_op.constant(1.)) 269 root = cycle(root, cycles) 270 root.add(constant_op.constant(1.)) 271 272 def test_capture_assets(self, cycles): 273 root = autotrackable.AutoTrackable() 274 root.vocab = asset.Asset(self._make_asset("contents")) 275 root.f = def_function.function( 276 lambda: root.vocab.asset_path, 277 input_signature=[]) 278 imported = cycle(root, cycles) 279 original_output = root.f().numpy() 280 imported_output = imported.f().numpy() 281 self.assertNotEqual(original_output, imported_output) 282 with open(imported_output, "r") as f: 283 self.assertEqual("contents", f.read()) 284 285 def test_capture_assets_in_graph(self, cycles): 286 root = autotrackable.AutoTrackable() 287 root.vocab = asset.Asset(self._make_asset("contents")) 288 root.f = def_function.function( 289 lambda: root.vocab.asset_path, 290 input_signature=[]) 291 292 original_output = root.f().numpy() 293 294 if cycles > 1: 295 root = cycle(root, cycles - 1) 296 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 297 save.save(root, path) 298 299 with ops.Graph().as_default(): 300 imported = load.load(path) 301 imported_tensor = imported.f() 302 with monitored_session.MonitoredSession() as sess: 303 imported_output = sess.run(imported_tensor) 304 self.assertLen(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1) 305 self.assertNotEqual(original_output, imported_output) 306 with open(imported_output, "r") as f: 307 self.assertEqual("contents", f.read()) 308 309 def test_dedup_assets(self, cycles): 310 vocab = self._make_asset("contents") 311 root = autotrackable.AutoTrackable() 312 root.asset1 = asset.Asset(vocab) 313 root.asset2 = asset.Asset(vocab) 314 imported = cycle(root, cycles) 315 self.assertEqual(imported.asset1.asset_path.numpy(), 316 imported.asset2.asset_path.numpy()) 317 318 def test_asset_fspath(self, cycles): 319 vocab = pathlib.Path(self._make_asset("contents")) 320 root = autotrackable.AutoTrackable() 321 root.asset = asset.Asset(vocab) 322 imported = cycle(root, cycles) 323 self.assertTrue(hasattr(imported, "asset")) 324 325 def test_implicit_input_signature(self, cycles): 326 @def_function.function 327 def func(x): 328 return 2 * x 329 330 root = autotrackable.AutoTrackable() 331 root.f = func 332 333 # Add two traces. 334 root.f(constant_op.constant(1.)) 335 root.f(constant_op.constant(1)) 336 337 imported = cycle(root, cycles) 338 339 self.assertEqual(4., imported.f(constant_op.constant(2.)).numpy()) 340 self.assertEqual(14, imported.f(constant_op.constant(7)).numpy()) 341 342 def test_explicit_input_signature(self, cycles): 343 @def_function.function( 344 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 345 def func(x): 346 return 2 * x 347 348 root = autotrackable.AutoTrackable() 349 root.f = func 350 351 imported = cycle(root, cycles) 352 self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy()) 353 354 def test_explicit_save_signature(self, cycles): 355 @def_function.function 356 def func(x): 357 return 2 * x 358 359 root = autotrackable.AutoTrackable() 360 root.f = func 361 362 imported = cycle( 363 root, cycles, { 364 "f": 365 root.f.get_concrete_function( 366 tensor_spec.TensorSpec(None, dtypes.float32)) 367 }) 368 self.assertEqual(4., imported.f(constant_op.constant(2.0)).numpy()) 369 370 def test_nested_functions(self, cycles): 371 f = def_function.function( 372 lambda x: x*2.0, 373 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 374 g = def_function.function( 375 lambda x: f(x) + 1.0, 376 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 377 378 root = autotrackable.AutoTrackable() 379 root.g = g 380 imported = cycle(root, cycles) 381 imported.g(constant_op.constant([1.0])) 382 383 def test_function_with_default_bool_input(self, cycles): 384 385 def func(x, training=False): 386 if training: 387 return 2 * x 388 else: 389 return 7 390 391 root = autotrackable.AutoTrackable() 392 root.f = def_function.function(func) 393 394 self.assertEqual(20, root.f(constant_op.constant(10), True).numpy()) 395 self.assertEqual(7, root.f(constant_op.constant(1)).numpy()) 396 self.assertEqual(2, root.f(constant_op.constant(1), True).numpy()) 397 398 imported = cycle(root, cycles) 399 400 self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy()) 401 self.assertEqual(7, imported.f(constant_op.constant(2)).numpy()) 402 403 def test_function_with_default_none_input(self, cycles): 404 405 def func(x, dtype=None): 406 if dtype: 407 return array_ops.zeros(shape=x.shape, dtype=dtype) 408 else: 409 return array_ops.zeros(shape=x.shape, dtype=dtypes.float32) 410 411 root = autotrackable.AutoTrackable() 412 root.f = def_function.function(func) 413 414 self.assertAllEqual([0.0, 0.0, 0.0], 415 root.f(constant_op.constant([1, 2, 3])).numpy()) 416 self.assertAllEqual([0.0, 0.0, 0.0], 417 root.f(constant_op.constant([1.0, 2.0, 3.0])).numpy()) 418 self.assertAllEqual([0.0, 0.0, 0.0, 0.0], 419 root.f(constant_op.constant([1, 2, 3, 4])).numpy()) 420 self.assertAllEqual([0, 0, 0], 421 root.f( 422 constant_op.constant([1.0, 2.0, 3.0]), 423 dtype=dtypes.int32).numpy()) 424 425 concrete_functions = root.f._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access 426 self.assertLen(concrete_functions, 4) 427 428 imported = cycle(root, cycles) 429 430 self.assertAllEqual([0.0, 0.0, 0.0], 431 imported.f(constant_op.constant([1, 2, 3]), 432 None).numpy()) 433 self.assertAllEqual([0.0, 0.0, 0.0], 434 imported.f(constant_op.constant([1.0, 2.0, 435 3.0])).numpy()) 436 self.assertAllEqual([0.0, 0.0, 0.0, 0.0], 437 imported.f(constant_op.constant([1, 2, 3, 4])).numpy()) 438 self.assertAllEqual([0, 0, 0], 439 imported.f( 440 constant_op.constant([1.0, 2.0, 3.0]), 441 dtype=dtypes.int32).numpy()) 442 443 def test_function_with_str_bytes_input(self, cycles): 444 445 @def_function.function 446 def func(x, y): 447 return string_ops.string_join([x, y]) 448 449 root = autotrackable.AutoTrackable() 450 root.f = func 451 452 self.assertAllEqual(b"ab", root.f("a", "b")) 453 self.assertAllEqual(b"ab", root.f("a", constant_op.constant("b"))) 454 self.assertAllEqual(b"ab", root.f(constant_op.constant("a"), "b")) 455 456 concrete_functions = root.f._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access 457 self.assertLen(concrete_functions, 3) 458 459 imported = cycle(root, cycles) 460 461 self.assertAllEqual(b"ab", imported.f("a", "b")) 462 self.assertAllEqual(b"ab", imported.f("a", constant_op.constant("b"))) 463 self.assertAllEqual(b"ab", imported.f(constant_op.constant("a"), "b")) 464 465 def test_function_no_return(self, cycles): 466 467 class TrackableWithOneVariable(autotrackable.AutoTrackable): 468 469 def __init__(self, initial_value=0.0): 470 super(TrackableWithOneVariable, self).__init__() 471 self.variable = variables.Variable(initial_value) 472 473 @def_function.function 474 def increase(self, by=1.0): 475 self.variable.assign_add(by) 476 477 obj = TrackableWithOneVariable(5.0) 478 479 obj.increase(constant_op.constant(10.0)) 480 self.assertEqual(15.0, obj.variable.numpy()) 481 obj.increase() 482 self.assertEqual(16.0, obj.variable.numpy()) 483 484 imported = cycle(obj, cycles) 485 486 imported.increase(constant_op.constant(10.0)) 487 self.assertEqual(26.0, imported.variable.numpy()) 488 imported.increase(constant_op.constant(1.0)) 489 self.assertEqual(27.0, imported.variable.numpy()) 490 491 def test_structured_inputs(self, cycles): 492 493 def func(x, training=True): 494 # x is a nested structure, we care about one particular tensor. 495 _, (a, b) = x 496 if training: 497 return 2 * a["a"] + b 498 else: 499 return 7 500 501 root = autotrackable.AutoTrackable() 502 root.f = def_function.function(func) 503 504 x = constant_op.constant(10) 505 y = constant_op.constant(11) 506 507 input1 = [6, ({"a": x}, y)] 508 input2 = [7, ({"a": x}, y)] # Not compatible with input1 signature. 509 input3 = [6, ({"a": y}, x)] # Compatible with input1 signature. 510 511 # Note: by only calling f(input1) before serialization, only inputs with 512 # matching signature will be valid on the loaded model. 513 self.assertEqual(31, root.f(input1).numpy()) 514 515 imported = cycle(root, cycles) 516 517 with self.assertRaisesRegex( 518 ValueError, "Could not find matching concrete function to call"): 519 imported.f(input2) 520 521 self.assertEqual(31, imported.f(input1).numpy()) 522 self.assertEqual(32, imported.f(input3).numpy()) 523 524 def test_structured_inputs_bare_concrete_function(self, cycles): 525 526 def func(x, training=True): 527 # x is a nested structure, we care about one particular tensor. 528 _, (a, b) = x 529 if training: 530 return 2 * a["a"] + b 531 else: 532 return 7 533 534 x = constant_op.constant(10) 535 y = constant_op.constant(11) 536 537 input1 = [6, ({"a": x}, y)] 538 input2 = [7, ({"a": x}, y)] # Not compatible with input1 signature. 539 input3 = [6, ({"a": y}, x)] # Compatible with input1 signature. 540 541 root = autotrackable.AutoTrackable() 542 root.f = def_function.function(func).get_concrete_function(input1) 543 544 imported = cycle(root, cycles) 545 546 with self.assertRaises(TypeError): 547 imported.f(input2) 548 549 self.assertEqual(31, imported.f(input1).numpy()) 550 self.assertEqual(32, imported.f(input3).numpy()) 551 552 def test_structured_output(self, cycles): 553 554 # Use fields with non-alphabetical order 555 named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"]) 556 557 def func(input1, input2): 558 named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2) 559 return [named_tuple, input2, {"x": 0.5}] 560 561 root = autotrackable.AutoTrackable() 562 root.f = def_function.function(func) 563 564 result = root.f(constant_op.constant(2), constant_op.constant(3)) 565 566 self.assertEqual(5, result[0].a.numpy()) 567 self.assertEqual(6, result[0].b.numpy()) 568 self.assertEqual(["b", "a"], list(result[0]._asdict().keys())) 569 self.assertEqual(3, result[1].numpy()) 570 self.assertEqual(0.5, result[2]["x"].numpy()) 571 572 imported = cycle(root, cycles) 573 574 result = imported.f(constant_op.constant(2), constant_op.constant(5)) 575 self.assertEqual(7, result[0].a.numpy()) 576 self.assertEqual(10, result[0].b.numpy()) 577 self.assertEqual(["b", "a"], list(result[0]._asdict().keys())) 578 self.assertEqual(5, result[1].numpy()) 579 self.assertEqual(0.5, result[2]["x"].numpy()) 580 581 def test_pretty_print_signature(self, cycles): 582 583 named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"]) 584 585 def func(input1, input2): 586 named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2) 587 return [named_tuple, input2, {"x": 0.5}] 588 589 root = autotrackable.AutoTrackable() 590 root.f = def_function.function(func).get_concrete_function( 591 constant_op.constant(2), constant_op.constant(3)) 592 593 imported = cycle(root, cycles) 594 self.assertEqual( 595 imported.f.pretty_printed_signature(), """func(input1, input2) 596 Args: 597 input1: int32 Tensor, shape=() 598 input2: int32 Tensor, shape=() 599 Returns: 600 [NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}] 601 <1>: int32 Tensor, shape=() 602 <2>: int32 Tensor, shape=() 603 <3>: int32 Tensor, shape=() 604 <4>: float32 Tensor, shape=()""") 605 606 def test_positional_arguments(self, cycles): 607 def func(x, training=False, abc=7.1, defg=7.7): 608 del abc 609 if training: 610 return 2 * x 611 if defg == 7: 612 return 6 613 else: 614 return 7 615 616 root = autotrackable.AutoTrackable() 617 root.f = def_function.function(func) 618 619 self.assertEqual(20, root.f(constant_op.constant(10), True).numpy()) 620 self.assertEqual(7, root.f(constant_op.constant(1)).numpy()) 621 self.assertEqual(2, root.f(constant_op.constant(1), True).numpy()) 622 self.assertEqual(6, root.f(constant_op.constant(1), defg=7.0).numpy()) 623 624 imported = cycle(root, cycles) 625 626 self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy()) 627 self.assertEqual(7, imported.f(constant_op.constant(2)).numpy()) 628 self.assertEqual(6, imported.f(constant_op.constant(1), defg=7.0).numpy()) 629 630 def test_additional_kwargs(self, cycles): 631 def func(x, training=False, **options): 632 del options 633 if training: 634 return 2 * x 635 else: 636 return 7 637 638 root = autotrackable.AutoTrackable() 639 root.f = def_function.function(func) 640 641 x = constant_op.constant(10) 642 self.assertEqual(7, root.f(x, learning_rate=0.5, epochs=3).numpy()) 643 644 imported = cycle(root, cycles) 645 646 with self.assertRaisesRegex( 647 ValueError, "Could not find matching concrete function to call.*"): 648 imported.f(x, learning_rate=0.5, epochs=4) 649 650 self.assertEqual(7, imported.f(x, learning_rate=0.5, epochs=3).numpy()) 651 652 def test_member_function(self, cycles): 653 class TrackableWithMember(autotrackable.AutoTrackable): 654 655 def __init__(self): 656 super(TrackableWithMember, self).__init__() 657 self._some_value = 20 658 659 @def_function.function 660 def f(self, x, training=False): 661 if training: 662 return 2 * x 663 else: 664 return 7 + self._some_value 665 666 root = TrackableWithMember() 667 668 self.assertEqual(20, root.f(constant_op.constant(10), True).numpy()) 669 self.assertEqual(27, root.f(constant_op.constant(1)).numpy()) 670 self.assertEqual(2, root.f(constant_op.constant(1), True).numpy()) 671 672 imported = cycle(root, cycles) 673 674 self.assertEqual(4, imported.f(constant_op.constant(2), True).numpy()) 675 self.assertEqual(27, imported.f(constant_op.constant(2)).numpy()) 676 677 def test_side_effect_listing(self, cycles): 678 class M(autotrackable.AutoTrackable): 679 680 def __init__(self): 681 super(M, self).__init__() 682 self.var = None 683 684 @def_function.function( 685 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 686 def f(self, x): 687 if self.var is None: 688 self.var = variables.Variable(2.) 689 return x * self.var 690 691 m = M() 692 cycle(m, cycles) 693 self.assertEqual(4.0, m.f(constant_op.constant(2.0)).numpy()) 694 695 def test_basic_backprop(self, cycles): 696 weight = variables.Variable(1., trainable=True) 697 bias = variables.Variable(0., trainable=True) 698 g = def_function.function( 699 lambda x: x*weight + bias, 700 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 701 702 root = autotrackable.AutoTrackable() 703 root.weight = weight 704 root.bias = bias 705 root.g = g 706 imported = cycle(root, cycles) 707 with backprop.GradientTape() as t: 708 x = constant_op.constant([3.5]) 709 loss = imported.g(x) 710 grad = t.gradient(loss, [imported.weight, imported.bias]) 711 self.assertAllClose(grad, [3.5, 1.0]) 712 713 def test_nested_backprop(self, cycles): 714 weight = variables.Variable(1., trainable=True) 715 bias = variables.Variable(0., trainable=True) 716 717 # Note: this function gets called from other function defs via a 718 # "PartitionedCall" op node. 719 @def_function.function(input_signature=[ 720 tensor_spec.TensorSpec(None, dtypes.float32), 721 tensor_spec.TensorSpec(None, dtypes.float32)]) 722 def mul(x, y): 723 return x * y 724 725 # Note: this function gets called from other function defs via a 726 # "StatefulPartitionedCall" op node. 727 @def_function.function(input_signature=[ 728 tensor_spec.TensorSpec(None, dtypes.float32)]) 729 def f(x): 730 return mul(weight.read_value(), x) 731 732 @def_function.function(input_signature=[ 733 tensor_spec.TensorSpec(None, dtypes.float32)]) 734 def g(x): 735 return f(x) + bias, 736 737 @def_function.function(input_signature=[ 738 tensor_spec.TensorSpec(None, dtypes.float32)]) 739 def h(x): 740 return g(x) + bias, 741 742 root = autotrackable.AutoTrackable() 743 root.weight = weight 744 root.bias = bias 745 root.g = h 746 747 imported = cycle(root, cycles) 748 with backprop.GradientTape() as t: 749 x = constant_op.constant([3.5]) 750 loss = imported.g(x) 751 grad = t.gradient(loss, [imported.weight, imported.bias]) 752 self.assertAllClose(grad, [3.5, 2.0]) 753 754 def test_while_loop_backprop(self, cycles): 755 weight = variables.Variable(2., trainable=True) 756 757 @def_function.function(input_signature=[ 758 tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))]) 759 def g(x): 760 """Adds rows of matrix x after multiplying each entry by v.""" 761 i_0 = constant_op.constant(0) 762 s_0 = constant_op.constant([0., 0.]) 763 cond = lambda i, _: i < array_ops.shape(x)[1] 764 body = lambda i, s: (i + 1, s + weight * x[:, i]) 765 i_end, s_end = control_flow_ops.while_loop(cond, body, (i_0, s_0)) 766 del i_end 767 return s_end 768 769 root = autotrackable.AutoTrackable() 770 root.weight = weight 771 root.g = g 772 imported = cycle(root, cycles) 773 774 def get_gradient(obj): 775 with backprop.GradientTape() as t: 776 x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]]) 777 y = obj.g(x) 778 self.assertAllClose(y, obj.weight * [6., 2.]) 779 loss = math_ops.reduce_sum(y) # weight * 8. 780 self.assertAllEqual(t.watched_variables(), [obj.weight]) 781 return t.gradient(loss, obj.weight) 782 783 imported_gradient = get_gradient(imported) 784 original_gradient = get_gradient(root) 785 self.assertIsNotNone(original_gradient) 786 self.assertAllClose(original_gradient, 8.) 787 self.assertIsNotNone(imported_gradient) 788 self.assertAllClose(imported_gradient, 8.) 789 790 def _test_restored_func_with_captured_var_backprop(self, cycles, dtype): 791 weight = variables.Variable(2., trainable=True, dtype=dtype) 792 793 @def_function.function(input_signature=[ 794 tensor_spec.TensorSpec(dtype=dtype, shape=())]) 795 def g(x): 796 return x * weight 797 798 root = autotrackable.AutoTrackable() 799 root.weight = weight 800 root.g = g 801 imported = cycle(root, cycles) 802 803 def get_gradient(obj): 804 with backprop.GradientTape() as t: 805 x = constant_op.constant(2.) 806 y = obj.g(x) 807 self.assertAllClose(y, obj.weight * 2.) 808 self.assertAllEqual(t.watched_variables(), [obj.weight]) 809 return t.gradient(y, obj.weight) 810 811 imported_gradient = get_gradient(imported) 812 original_gradient = get_gradient(root) 813 self.assertIsNotNone(original_gradient) 814 self.assertAllClose(original_gradient, 2.) 815 self.assertIsNotNone(imported_gradient) 816 self.assertAllClose(imported_gradient, 2.) 817 818 def test_nested_fn_backprop(self, cycles): 819 weight = variables.Variable(2., trainable=True) 820 821 @def_function.function(input_signature=[ 822 tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))]) 823 def g(x): 824 weight.read_value() # Just get the tape to watch the variable 825 handle = array_ops.identity(weight.handle) 826 @def_function.function 827 def launder_var_handle(): 828 return array_ops.identity(handle) 829 return x + resource_variable_ops.read_variable_op( 830 launder_var_handle(), dtypes.float32) 831 832 root = autotrackable.AutoTrackable() 833 root.weight = weight 834 root.g = g 835 imported = cycle(root, cycles) 836 def get_gradient(obj, persistent): 837 with backprop.GradientTape(persistent=persistent) as t: 838 x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]]) 839 y = obj.g(x) 840 self.assertAllClose(y, obj.weight + x) 841 loss = math_ops.reduce_sum(y) 842 return t.gradient(loss, obj.weight) 843 844 imported_gradient = get_gradient(imported, persistent=False) 845 original_gradient = get_gradient(root, persistent=False) 846 self.assertIsNotNone(original_gradient) 847 self.assertAllClose(original_gradient, 6.) 848 self.assertIsNotNone(imported_gradient) 849 self.assertAllClose(imported_gradient, 6.) 850 851 def test_restored_func_with_captured_var_backprop_float32(self, cycles): 852 self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32) 853 854 def test_restored_func_with_captured_var_backprop_float64(self, cycles): 855 self.skipTest("b/144573917") 856 self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float64) 857 858 def test_callable(self, cycles): 859 class M1(autotrackable.AutoTrackable): 860 861 @def_function.function( 862 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 863 def __call__(self, x): 864 return x 865 866 root = autotrackable.AutoTrackable() 867 root.m1 = M1() 868 root.m2 = autotrackable.AutoTrackable() 869 root.m2.__call__ = def_function.function( 870 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])( 871 lambda x: x*3.0) 872 imported = cycle(root, cycles) 873 x = constant_op.constant(1.0) 874 875 self.assertTrue(callable(imported.m1)) 876 self.assertAllEqual(root.m1(x), imported.m1(x)) 877 878 # Note: `root.m2` was not callable since `__call__` attribute was set 879 # into the instance and not on the class. But after a serialization cycle 880 # that starts to work. 881 self.assertTrue(callable(imported.m2)) 882 self.assertAllEqual(root.m2.__call__(x), imported.m2(x)) 883 884 # Verify that user objects without `__call__` attribute are not callable. 885 self.assertFalse(callable(imported)) 886 887 def test_chain_callable(self, cycles): 888 func = def_function.function( 889 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])( 890 lambda x: x*3.0) 891 root = autotrackable.AutoTrackable() 892 root.__call__ = autotrackable.AutoTrackable() 893 root.__call__.__call__ = autotrackable.AutoTrackable() 894 root.__call__.__call__.__call__ = func 895 896 imported = cycle(root, cycles) 897 self.assertTrue(callable(imported)) 898 x = constant_op.constant(1.0) 899 self.assertAllEqual(imported(x).numpy(), 3.0) 900 901 def test_load_in_graph_mode(self, cycles): 902 root = autotrackable.AutoTrackable() 903 root.v1 = variables.Variable(1., name="v_one", trainable=False) 904 root.v2 = variables.Variable(2., name="v_two", trainable=True) 905 root.f = def_function.function( 906 lambda x: root.v2 * x, 907 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 908 909 if cycles > 1: 910 root = cycle(root, cycles - 1) 911 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 912 save.save(root, path) 913 914 with ops.Graph().as_default() as g: 915 imported = load.load(path) 916 var_v1 = imported.v1 917 self.assertFalse(var_v1.trainable) 918 var_v2 = imported.v2 919 self.assertTrue(var_v2.trainable) 920 output = imported.f(constant_op.constant(2.)) 921 with monitored_session.MonitoredSession() as sess: 922 self.assertEqual(1.0, sess.run(var_v1)) 923 self.assertEqual(4.0, sess.run(output)) 924 self.assertCountEqual([var_v1, var_v2], 925 g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) 926 # load() should not add to TRAINABLE_VARIABLES. Higher levels of model 927 # building control retraining or frozen use of imported SavedModels. 928 self.assertCountEqual([], 929 g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) 930 931 def test_load_in_func_graph(self, cycles): 932 root = autotrackable.AutoTrackable() 933 root.v1 = variables.Variable(1.) 934 root.v2 = variables.Variable(2.) 935 root.f = def_function.function( 936 lambda x: root.v2 * x, 937 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 938 939 if cycles > 1: 940 root = cycle(root, cycles - 1) 941 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 942 save.save(root, path) 943 944 closure = autotrackable.AutoTrackable() 945 @def_function.function 946 def func(x): 947 if not hasattr(closure, "model"): 948 closure.model = load.load(path) 949 return closure.model.f(x) 950 951 inputs = constant_op.constant(2.) 952 self.assertEqual(4.0, func(inputs).numpy()) 953 954 def test_soft_matching(self, cycles): 955 956 @def_function.function( 957 input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]) 958 def func(x): 959 return 2 * x 960 961 root = autotrackable.AutoTrackable() 962 root.f = func 963 964 self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy()) 965 self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy()) 966 967 concrete_functions = root.f._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access 968 self.assertLen(concrete_functions, 1) 969 970 imported = cycle(root, cycles) 971 972 with self.assertRaisesRegex(ValueError, "Python inputs incompatible"): 973 # We cannot call the function with a constant of shape (). 974 imported.f(constant_op.constant(2)).numpy() 975 976 # TODO(vbardiovsky): When classes are revived with input_signatures, we 977 # should also check that the calls below are not generating any more 978 # concrete functions. 979 self.assertAllEqual([2, 4, 6, 8], 980 imported.f(constant_op.constant([1, 2, 3, 4])).numpy()) 981 self.assertAllEqual([2, 4, 6], 982 imported.f(constant_op.constant([1, 2, 3])).numpy()) 983 984 def test_jit_compile(self, cycles): 985 986 # It'd be nice to use parameterize here, but the library does not support 987 # having parameterized test methods inside already-parameterized classes. 988 for jit_compile in (None, True, False): 989 990 @def_function.function(jit_compile=jit_compile) 991 def f(x): 992 return x + 1. 993 994 root = module.Module() 995 root.f = f 996 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 997 save.save(root, save_dir) 998 999 imported = cycle(root, cycles) 1000 1001 self.assertEqual(imported.f._jit_compile, jit_compile) 1002 1003 def test_get_concrete_function(self, cycles): 1004 1005 @def_function.function 1006 def func(x, training=False): 1007 if training: 1008 return 2 * x 1009 else: 1010 return 3 * x 1011 1012 func.get_concrete_function( 1013 tensor_spec.TensorSpec([None], dtypes.int32), True) 1014 func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32)) 1015 1016 root = autotrackable.AutoTrackable() 1017 root.f = func 1018 1019 imported = cycle(root, cycles) 1020 1021 concrete = imported.f.get_concrete_function( 1022 training=True, x=tensor_spec.TensorSpec([None], dtypes.int32)) 1023 1024 self.assertAllEqual([2, 4, 6, 8], 1025 concrete(x=constant_op.constant([1, 2, 3, 4])).numpy()) 1026 with self.assertRaisesRegex( 1027 ValueError, "Could not find matching concrete function to call"): 1028 imported.f.get_concrete_function( 1029 tensor_spec.TensorSpec([None], dtypes.int32)) 1030 imported.f.get_concrete_function( 1031 tensor_spec.TensorSpec([None], dtypes.int32), True) 1032 1033 def test_concrete_function(self, cycles): 1034 1035 @def_function.function( 1036 input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]) 1037 def func(x): 1038 return 2 * x 1039 1040 root = autotrackable.AutoTrackable() 1041 root.f = func.get_concrete_function() 1042 1043 self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy()) 1044 self.assertAllEqual([2, 4], root.f(constant_op.constant([1, 2])).numpy()) 1045 1046 # TODO(andresp): Fix exporting of loaded concrete functions as signatures. 1047 imported = cycle(root, cycles, signatures={}) 1048 1049 self.assertAllEqual([2, 4, 6, 8], 1050 imported.f(constant_op.constant([1, 2, 3, 4])).numpy()) 1051 self.assertAllEqual([2, 4, 6], 1052 imported.f(constant_op.constant([1, 2, 3])).numpy()) 1053 1054 def test_concrete_function_captures(self, cycles): 1055 1056 class Root(module.Module): 1057 1058 def __init__(self): 1059 self.v = variables.Variable(1.) 1060 self.v1 = variables.Variable(1.) 1061 1062 @def_function.function( 1063 input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)]) 1064 def use_v(self, x): 1065 return self.v + self.v1 + 1. 1066 1067 root = Root() 1068 self.assertIn(root.v.handle, 1069 root.use_v.get_concrete_function().graph.external_captures) 1070 root = cycle(root, cycles, signatures=root.use_v.get_concrete_function()) 1071 func_captures = root.use_v.get_concrete_function().graph.external_captures 1072 self.assertLen(func_captures, 2) 1073 self.assertTrue(any(root.v.handle is t for t in func_captures)) 1074 self.assertTrue(any(root.v1.handle is t for t in func_captures)) 1075 signature_captures = root.signatures[ 1076 "serving_default"].graph.external_captures 1077 self.assertLen(signature_captures, 2) 1078 self.assertTrue(any(root.v.handle is t for t in signature_captures)) 1079 self.assertTrue(any(root.v1.handle is t for t in signature_captures)) 1080 1081 def test_concrete_function_arg_names(self, cycles): 1082 1083 @def_function.function( 1084 input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]) 1085 def func(x): 1086 return 2 * x 1087 1088 root = autotrackable.AutoTrackable() 1089 root.f = func.get_concrete_function() 1090 1091 self.assertAllEqual([2], root.f(constant_op.constant([1])).numpy()) 1092 1093 # TODO(andresp): Fix exporting of loaded concrete functions as signatures. 1094 imported = cycle(root, cycles, signatures={}) 1095 1096 self.assertAllEqual([2, 4, 6], 1097 imported.f(x=constant_op.constant([1, 2, 3])).numpy()) 1098 1099 def test_concrete_function_no_signature(self, cycles): 1100 @def_function.function 1101 def func(x): 1102 return 2 * x 1103 1104 root = autotrackable.AutoTrackable() 1105 root.f = func.get_concrete_function(constant_op.constant([1])) 1106 self.assertAllEqual([4], root.f(constant_op.constant([2])).numpy()) 1107 # TODO(andresp): Fix exporting of loaded concrete functions as signatures. 1108 imported = cycle(root, cycles, signatures={}) 1109 self.assertAllEqual([6], 1110 imported.f(constant_op.constant([3])).numpy()) 1111 1112 @test_util.run_in_graph_and_eager_modes 1113 def test_concrete_function_backprop(self, cycles): 1114 @def_function.function( 1115 input_signature=[tensor_spec.TensorSpec([], dtypes.float32)]) 1116 def func(x): 1117 return x ** 2. 1118 root = autotrackable.AutoTrackable() 1119 root.f = func.get_concrete_function() 1120 1121 def _compute_gradient(function): 1122 with backprop.GradientTape() as tape: 1123 inp = constant_op.constant(1.) 1124 tape.watch(inp) 1125 output = function(inp) 1126 return tape.gradient(output, inp) 1127 1128 self.assertAllEqual(2., _compute_gradient(root.f)) 1129 # TODO(andresp): Fix exporting of loaded concrete functions as signatures. 1130 imported = cycle(root, cycles, signatures={}) 1131 self.assertAllEqual(2., _compute_gradient(imported.f)) 1132 1133 def test_revived_concrete_function_kwargs(self, cycles): 1134 1135 @def_function.function 1136 def func(x, y): 1137 return x * (y + 1.) 1138 root = autotrackable.AutoTrackable() 1139 root.f = func.get_concrete_function( 1140 tensor_spec.TensorSpec([], dtypes.float32), 1141 tensor_spec.TensorSpec([], dtypes.float32)) 1142 self.assertEqual(8., root.f(y=constant_op.constant(3.), 1143 x=constant_op.constant(2.)).numpy()) 1144 # TODO(andresp): Fix exporting of loaded concrete functions as signatures. 1145 imported = cycle(root, cycles, signatures={}) 1146 self.assertEqual(8., imported.f(y=constant_op.constant(3.), 1147 x=constant_op.constant(2.)).numpy()) 1148 1149 def test_revived_concrete_function_tensorspec_kwargs(self, cycles): 1150 1151 @def_function.function 1152 def func(*args): 1153 x, y = args 1154 return x * (y + 1.) 1155 root = autotrackable.AutoTrackable() 1156 root.f = func.get_concrete_function( 1157 tensor_spec.TensorSpec([], dtypes.float32, name="x"), 1158 tensor_spec.TensorSpec([], dtypes.float32, name="y")) 1159 self.assertEqual(8., root.f(y=constant_op.constant(3.), 1160 x=constant_op.constant(2.)).numpy()) 1161 imported = cycle(root, cycles, signatures={}) 1162 self.assertEqual(8., imported.f(y=constant_op.constant(3.), 1163 x=constant_op.constant(2.)).numpy()) 1164 1165 def test_concrete_function_variable_argument(self, cycles): 1166 capture = variables.Variable(0) 1167 1168 @def_function.function 1169 def func(v): 1170 v.assign_add(1) 1171 capture.assign_sub(1) 1172 1173 @def_function.function(input_signature=[ 1174 resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) 1175 ]) 1176 def func_with_input_signature(v): 1177 v.assign_add(5) 1178 capture.assign_sub(5) 1179 return 1 1180 1181 vsave = variables.Variable(1) 1182 root = autotrackable.AutoTrackable() 1183 root.f = func.get_concrete_function(vsave) 1184 root.f_sig = func_with_input_signature.get_concrete_function() 1185 root.capture = capture 1186 1187 self.assertEqual(1, vsave.numpy()) 1188 root.f(vsave) 1189 self.assertEqual(2, vsave.numpy()) 1190 self.assertEqual(-1, capture.numpy()) 1191 1192 root.f_sig(vsave) 1193 self.assertEqual(7, vsave.numpy()) 1194 self.assertEqual(-6, capture.numpy()) 1195 1196 imported = cycle(root, cycles) 1197 1198 vload = variables.Variable(1) 1199 imported.f(vload) 1200 self.assertEqual(2, vload.numpy()) 1201 imported.f(v=vload) 1202 self.assertEqual(3, vload.numpy()) 1203 self.assertEqual(-8, imported.capture.numpy()) 1204 1205 imported.f_sig(v=vload) 1206 self.assertEqual(8, vload.numpy()) 1207 self.assertEqual(-13, imported.capture.numpy()) 1208 1209 self.assertEqual(-6, capture.numpy()) 1210 1211 def test_function_and_component(self, cycles): 1212 1213 @def_function.function 1214 def func(v): 1215 return v + 1 1216 1217 root = autotrackable.AutoTrackable() 1218 root.func = func 1219 root.concrete_func = func.get_concrete_function( 1220 tensor_spec.TensorSpec(None, dtypes.int32)) 1221 one = constant_op.constant(1) 1222 self.assertEqual(2, root.func(one).numpy()) 1223 self.assertEqual(2, root.concrete_func(one).numpy()) 1224 imported = cycle(root, cycles) 1225 self.assertEqual(2, imported.func(one).numpy()) 1226 self.assertEqual(2, imported.concrete_func(one).numpy()) 1227 1228 def test_dict(self, cycles): 1229 root = autotrackable.AutoTrackable() 1230 root.variables = dict(a=variables.Variable(1.)) 1231 root.variables["b"] = variables.Variable(2.) 1232 root.variables["c"] = 1 1233 root.funcs = dict( 1234 a=def_function.function(lambda: constant_op.constant(100.))) 1235 root.funcs["conc"] = root.funcs["a"].get_concrete_function() 1236 imported = cycle(root, cycles) 1237 self.assertEqual(1., imported.variables["a"].numpy()) 1238 self.assertEqual(2., imported.variables["b"].numpy()) 1239 self.assertEqual(set(["a", "b"]), set(imported.variables.keys())) 1240 self.assertEqual(100., imported.funcs["a"]().numpy()) 1241 self.assertEqual(100., imported.funcs["conc"]().numpy()) 1242 1243 def test_list(self, cycles): 1244 root = autotrackable.AutoTrackable() 1245 root.variables = [variables.Variable(1.)] 1246 root.variables.append(1) 1247 root.variables.append(variables.Variable(3.)) 1248 imported = cycle(root, cycles) 1249 self.assertEqual(1., imported.variables[0].numpy()) 1250 self.assertEqual(3., imported.variables[2].numpy()) 1251 self.assertIs(None, imported.variables[1]) 1252 self.assertLen(imported.variables, 3) 1253 1254 def test_tuple(self, cycles): 1255 root = autotrackable.AutoTrackable() 1256 root.variables = (variables.Variable(1.), 1, variables.Variable(3.)) 1257 imported = cycle(root, cycles) 1258 self.assertEqual(1., imported.variables[0].numpy()) 1259 self.assertEqual(3., imported.variables[2].numpy()) 1260 self.assertIs(None, imported.variables[1]) 1261 self.assertLen(imported.variables, 3) 1262 1263 def test_functions_list(self, cycles): 1264 root = autotrackable.AutoTrackable() 1265 v1 = variables.Variable(1.) 1266 root.losses = [def_function.function(lambda: math_ops.reduce_sum(v1 ** 2))] 1267 root.variables = [v1] 1268 1269 @def_function.function 1270 def _v2_loss(): 1271 if len(root.variables) == 1: 1272 v2 = variables.Variable(2.) 1273 root.variables.append(v2) 1274 return math_ops.reduce_sum(root.variables[1] ** 2) 1275 1276 root.losses.append(_v2_loss) 1277 self.assertAllClose([1., 4.], [loss() for loss in root.losses]) 1278 imported = cycle(root, cycles) 1279 self.assertAllClose([1., 4.], [loss() for loss in imported.losses]) 1280 imported.variables[0].assign(3.) 1281 imported.variables[1].assign(4.) 1282 self.assertAllClose([9., 16.], [loss() for loss in imported.losses]) 1283 1284 def test_captured_constant(self, cycles): 1285 const = array_ops.zeros([100]) 1286 root = autotrackable.AutoTrackable() 1287 root.f = def_function.function(lambda: const + 1.) 1288 root.g = def_function.function(lambda: const + 2.) 1289 self.assertAllClose(array_ops.ones([100]), root.f()) 1290 self.assertAllClose(2. * array_ops.ones([100]), root.g()) 1291 imported = cycle(root, cycles) 1292 self.assertAllClose(array_ops.ones([100]), imported.f()) 1293 self.assertAllClose(2. * array_ops.ones([100]), imported.g()) 1294 # TODO(b/123408994): Use the public get_concrete_function. 1295 f_concrete = imported.f._list_all_concrete_functions_for_serialization()[0] 1296 g_concrete = imported.g._list_all_concrete_functions_for_serialization()[0] 1297 self.assertLen(f_concrete.captured_inputs, 1) 1298 self.assertLen(g_concrete.captured_inputs, 1) 1299 # We should be using the same captured EagerTensor in both functions, not 1300 # duplicating the constant. 1301 self.assertIs(f_concrete.captured_inputs[0], 1302 g_concrete.captured_inputs[0]) 1303 1304 def test_functions_accessed_once(self, cycles): 1305 1306 class Exported(autotrackable.AutoTrackable): 1307 1308 def __init__(self): 1309 self._counter = 0 1310 1311 @property 1312 def make_func(self): 1313 @def_function.function 1314 def f(): 1315 return constant_op.constant(self._counter) 1316 f.get_concrete_function() # force a trace 1317 self._counter += 1 1318 return f 1319 1320 exported = Exported() 1321 imported = cycle(exported, cycles) 1322 self.assertEqual(0, imported.make_func().numpy()) 1323 self.assertEqual(1, exported.make_func().numpy()) 1324 1325 def test_overwritten_signatures_error(self, cycles): 1326 exported = autotrackable.AutoTrackable() 1327 exported.f = def_function.function(lambda: constant_op.constant(1.)) 1328 imported = cycle( 1329 exported, cycles, 1330 signatures={"key": exported.f.get_concrete_function()}) 1331 self.assertEqual(1., imported.signatures["key"]()["output_0"].numpy()) 1332 imported.signatures = {"key1": imported.signatures["key"]} 1333 with self.assertRaisesRegex(ValueError, "signatures"): 1334 save.save(imported, tempfile.mkdtemp(prefix=self.get_temp_dir())) 1335 1336 def test_signature_loading(self, cycles): 1337 1338 class Exported(autotrackable.AutoTrackable): 1339 1340 def __init__(self): 1341 self.v = variables.Variable(3.) 1342 1343 @def_function.function 1344 def do(self, x): 1345 return self.v * x 1346 1347 exported = Exported() 1348 imported = cycle( 1349 exported, 1350 cycles, 1351 signatures=exported.do.get_concrete_function( 1352 tensor_spec.TensorSpec(None, dtypes.float32))) 1353 self.assertEqual(["serving_default"], list(imported.signatures.keys())) 1354 imported_function = imported.signatures["serving_default"] 1355 two = constant_op.constant(2.) 1356 self.assertEqual(6., imported_function(x=two)["output_0"].numpy()) 1357 imported.v.assign(4.) 1358 self.assertEqual(8., imported_function(x=two)["output_0"].numpy()) 1359 self.assertEqual(8., imported_function(two)["output_0"].numpy()) 1360 with self.assertRaises(TypeError): 1361 # The signatures mapping is immutable 1362 imported.signatures["random_key"] = 3 1363 1364 def test_names_normalized(self, cycles): 1365 class ObjWithFunction(module.Module): 1366 1367 @def_function.function(input_signature=[ 1368 tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A-b"), 1369 tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A/D"), 1370 tensor_spec.TensorSpec([], dtype=dtypes.int32, name="bar"), 1371 tensor_spec.TensorSpec([], dtype=dtypes.int32, name="e"), 1372 ]) 1373 def foo(self, a, b, c, d=10, **options): 1374 del options 1375 return a + b + c + d 1376 1377 exported = ObjWithFunction() 1378 1379 with self.assertLogs(level="WARNING") as logs: 1380 imported = cycle(exported, cycles) 1381 1382 expected_message = ( 1383 "WARNING:absl:Function `foo` contains input name(s) A-b, A/D with " 1384 "unsupported characters which will be renamed to a_b, a_d in the " 1385 "SavedModel.") 1386 self.assertIn(expected_message, logs.output) 1387 1388 loaded_signature = imported.signatures["serving_default"].inputs 1389 self.assertEqual("a_b:0", loaded_signature[0].name) 1390 self.assertEqual("a_d:0", loaded_signature[1].name) 1391 1392 def test_multiple_argument_signatures_no_positional(self, cycles): 1393 1394 class Exported(autotrackable.AutoTrackable): 1395 1396 @def_function.function 1397 def do(self, x, y): 1398 return x + y 1399 1400 exported = Exported() 1401 imported = cycle( 1402 exported, cycles, signatures=exported.do.get_concrete_function( 1403 tensor_spec.TensorSpec(None, dtypes.float32), 1404 tensor_spec.TensorSpec(None, dtypes.float32))) 1405 with self.assertRaises(TypeError): 1406 imported.signatures["serving_default"]( 1407 constant_op.constant(1.), 1408 y=constant_op.constant(2.)) 1409 self.assertEqual( 1410 {"output_0": 3.}, 1411 self.evaluate(imported.signatures["serving_default"]( 1412 x=constant_op.constant(1.), 1413 y=constant_op.constant(2.)))) 1414 1415 def _make_model_with_tables(self): 1416 default_val = -1 1417 keys = constant_op.constant(["brain", "salad", "surgery"]) 1418 values = constant_op.constant([0, 1, 2], dtypes.int64) 1419 table1_initializer = lookup_ops.KeyValueTensorInitializer(keys, values) 1420 table1 = lookup_ops.HashTable(table1_initializer, default_val) 1421 1422 table2_file = self._make_asset("test\nfoo\nbrain\n") 1423 table2_initializer = lookup_ops.TextFileIdTableInitializer(table2_file) 1424 table2 = lookup_ops.HashTable(table2_initializer, default_val) 1425 1426 def _make_lookup_function(table): 1427 signature = [tensor_spec.TensorSpec(None, dtypes.string)] 1428 return def_function.function(input_signature=signature)( 1429 lambda x: table.lookup(x)) # pylint: disable=unnecessary-lambda 1430 1431 root = autotrackable.AutoTrackable() 1432 root.table1 = table1 1433 root.lookup1 = _make_lookup_function(table1) 1434 root.table2 = table2 1435 root.lookup2 = _make_lookup_function(table2) 1436 return root 1437 1438 def test_table(self, cycles): 1439 root = self._make_model_with_tables() 1440 imported = cycle(root, cycles, signatures={}) 1441 keys = constant_op.constant(["brain", "test", "foo", "surgery"]) 1442 self.assertAllEqual([0, -1, -1, 2], imported.lookup1(keys).numpy()) 1443 self.assertAllEqual([2, 0, 1, -1], imported.lookup2(keys).numpy()) 1444 1445 def test_table_collections_untouched_eager(self, cycles): 1446 1447 def _gather_nonempty_collections(): 1448 graph = ops.get_default_graph() 1449 gathered = {} 1450 for collection in graph.collections: 1451 collection_contents = graph.get_collection(collection) 1452 if collection_contents: 1453 gathered[collection] = collection_contents 1454 return gathered 1455 1456 root = self._make_model_with_tables() 1457 # Warm up collections to ignore those that don't expand every iteration, 1458 # e.g. the __varscope collection. 1459 cycle(root, 1) 1460 original_collections = _gather_nonempty_collections() 1461 cycle(root, cycles) 1462 self.assertEqual(original_collections, _gather_nonempty_collections()) 1463 1464 def test_table_in_graph(self, cycles): 1465 root = self._make_model_with_tables() 1466 1467 if cycles > 1: 1468 root = cycle(root, cycles - 1) 1469 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 1470 save.save(root, path) 1471 imported = cycle(root, 1) 1472 1473 with ops.Graph().as_default(): 1474 imported = load.load(path) 1475 keys = constant_op.constant(["brain", "test", "foo", "surgery"]) 1476 output1 = imported.lookup1(keys) 1477 output2 = imported.lookup2(keys) 1478 with monitored_session.MonitoredSession() as sess: 1479 self.assertAllEqual([0, -1, -1, 2], sess.run(output1)) 1480 self.assertAllEqual([2, 0, 1, -1], sess.run(output2)) 1481 1482 def test_preserve_argspec(self, cycles): 1483 1484 def f(a, b, c): # pylint: disable=unused-argument 1485 return None 1486 1487 original_fullargspec = tf_inspect.getfullargspec(f) 1488 1489 root = autotrackable.AutoTrackable() 1490 root.f = def_function.function(f) 1491 imported = cycle(root, cycles) 1492 1493 restored_fullargspec = tf_inspect.getfullargspec(imported.f) 1494 self.assertEqual(original_fullargspec, restored_fullargspec) 1495 1496 def test_canonicalize_inputs(self, cycles): 1497 @def_function.function(autograph=False) 1498 def func(a=1, b=2, c=3, training=True): 1499 if training: 1500 return [a, b, c, training] 1501 else: 1502 return [c, b, a, training] 1503 1504 # TODO(b/123501567): Work-around to trigger generic traces of a function 1505 # with extra non tensor args. 1506 signature = 3*[tensor_spec.TensorSpec(None, dtypes.float32)] 1507 @def_function.function(input_signature=signature) 1508 def trigger(a, b, c): 1509 func(a, b, c, True) 1510 func(a, b, c, False) 1511 1512 trigger.get_concrete_function() 1513 1514 root = autotrackable.AutoTrackable() 1515 root.f = func 1516 root = cycle(root, cycles) 1517 self.assertAllEqual(root.f(), [1.0, 2.0, 3.0, True]) 1518 self.assertAllEqual(root.f(-1.0, training=False), [3.0, 2.0, -1.0, False]) 1519 1520 with self.assertRaisesRegex(ValueError, 1521 "Could not find matching concrete function"): 1522 root.f(["hello", 1.0]) 1523 1524 def test_prefer_specific_trace(self, cycles): 1525 @def_function.function(autograph=False) 1526 def func(a): 1527 if isinstance(a, int): 1528 return a 1529 else: 1530 return a + 1 1531 1532 self.assertAllEqual(2, func(2).numpy()) 1533 self.assertAllEqual(3, func(constant_op.constant(2)).numpy()) 1534 1535 root = autotrackable.AutoTrackable() 1536 root.f = func 1537 root = cycle(root, cycles) 1538 self.assertAllEqual(2, root.f(2).numpy()) 1539 self.assertAllEqual(4, root.f(3).numpy()) 1540 self.assertAllEqual(3, root.f(constant_op.constant(2)).numpy()) 1541 self.assertAllEqual(4, root.f(constant_op.constant(3)).numpy()) 1542 1543 def test_partial(self, cycles): 1544 def f(x, y): 1545 return x + y 1546 1547 func = def_function.function( 1548 functools.partial(f, x=array_ops.zeros([1]), y=array_ops.ones([1]))) 1549 1550 root = autotrackable.AutoTrackable() 1551 root.f = func 1552 self.assertAllEqual(root.f(), [1.0]) 1553 1554 root = cycle(root, cycles) 1555 self.assertAllEqual(root.f(), [1.0]) 1556 1557 def test_partial_with_non_tensor_defaults(self, cycles): 1558 1559 def f(x, y=3): 1560 return x + y 1561 1562 func = def_function.function(functools.partial(f, y=5)) 1563 1564 root = autotrackable.AutoTrackable() 1565 root.f = func 1566 self.assertAllEqual(root.f(1), 6) 1567 1568 root = cycle(root, cycles) 1569 self.assertAllEqual(root.f(1), 6) 1570 1571 def test_partial_with_positional(self, cycles): 1572 def f(x, y): 1573 return x + y 1574 1575 func = def_function.function(functools.partial(f, constant_op.constant(5))) 1576 1577 root = autotrackable.AutoTrackable() 1578 root.f = func 1579 self.assertAllEqual(root.f(1), 6) 1580 1581 root = cycle(root, cycles) 1582 self.assertAllEqual(root.f(1), 6) 1583 1584 def test_partial_with_positional_captured_tensors(self, cycles): 1585 1586 def f(x, y): 1587 return x + y 1588 1589 tensor = constant_op.constant(5) + constant_op.constant(7) 1590 func = def_function.function(functools.partial(f, tensor)) 1591 1592 root = autotrackable.AutoTrackable() 1593 root.f = func 1594 self.assertAllEqual(root.f(1), 13) 1595 1596 root = cycle(root, cycles) 1597 self.assertAllEqual(root.f(1), 13) 1598 1599 def test_partial_keyword_hiding_default(self, cycles): 1600 1601 def f(x=3, training=True, y=7): 1602 if training: 1603 return x + y 1604 else: 1605 return x + y + 2 1606 1607 func = def_function.function(functools.partial(f, y=6)) 1608 1609 root = autotrackable.AutoTrackable() 1610 root.f = func 1611 self.assertEqual(root.f().numpy(), 9) 1612 self.assertEqual(root.f(training=False).numpy(), 11) 1613 1614 root = cycle(root, cycles) 1615 self.assertEqual(root.f().numpy(), 9) 1616 self.assertEqual(root.f(training=False).numpy(), 11) 1617 1618 def test_partial_with_kwargs(self, cycles): 1619 1620 def f(a, b, *args, **kwargs): 1621 args_sum = sum(args) 1622 return a + b + kwargs["some_tensor"] * kwargs["learning_rate"] + args_sum 1623 1624 constant_tensor = constant_op.constant(10) 1625 func = def_function.function( 1626 functools.partial( 1627 f, 7, 1, 2, learning_rate=3, some_tensor=constant_tensor)) 1628 1629 root = autotrackable.AutoTrackable() 1630 root.f = func 1631 self.assertEqual(root.f(constant_op.constant(4)).numpy(), 44) 1632 1633 root = cycle(root, cycles) 1634 self.assertEqual(root.f(constant_op.constant(5)).numpy(), 45) 1635 1636 def test_partial_bind_only_first_argument(self, cycles): 1637 if sys.version_info[0] < 3: 1638 self.skipTest("Test is only valid in python3. Only then we get some more " 1639 "advanced inspection of partials where this is allowed.") 1640 1641 def f(x, y): 1642 return x + y 1643 1644 partial_func = functools.partial(f, x=5) 1645 tf_func = def_function.function(partial_func) 1646 1647 root = autotrackable.AutoTrackable() 1648 root.f = tf_func 1649 self.assertAllEqual(root.f(y=constant_op.constant(7)), 12) 1650 1651 root = cycle(root, cycles) 1652 self.assertAllEqual(root.f(y=constant_op.constant(9)), 14) 1653 1654 def test_partial_with_passed_fn_as_default(self, cycles): 1655 1656 def f(x, y): 1657 return x(3) + y 1658 1659 def my_func(a): 1660 return 2 * a 1661 1662 func = def_function.function(functools.partial(f, my_func)) 1663 1664 root = autotrackable.AutoTrackable() 1665 root.f = func 1666 self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9) 1667 1668 root = cycle(root, cycles) 1669 self.assertEqual(root.f(constant_op.constant(3)).numpy(), 9) 1670 1671 def test_partial_with_input_signature(self, cycles): 1672 1673 def full_function(a, b, c=3.0): 1674 return a, b, c 1675 1676 partial = functools.partial(full_function, 1, c=4) 1677 self.assertAllEqual((1, 2.0, 4), partial(2.0)) 1678 1679 signature = [tensor_spec.TensorSpec([], dtypes.float32)] 1680 func = def_function.function(partial, input_signature=signature) 1681 1682 root = autotrackable.AutoTrackable() 1683 root.f = func 1684 a, b, c = root.f(2.0) 1685 self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 2.0, 4)) 1686 1687 root = cycle(root, cycles) 1688 a, b, c = root.f(3.0) 1689 self.assertAllEqual([a.numpy(), b.numpy(), c.numpy()], (1, 3.0, 4)) 1690 1691 def test_convert_to_input_signature(self, cycles): 1692 1693 @def_function.function( 1694 input_signature=[tensor_spec.TensorSpec([None], dtypes.int32)]) 1695 def func(x): 1696 return x 1697 1698 root = autotrackable.AutoTrackable() 1699 root.f = func 1700 1701 root = cycle(root, cycles) 1702 1703 self.assertEqual([2], root.f([2]).numpy()) 1704 1705 def test_named_tuple(self, cycles): 1706 1707 class NamedTupleType(collections.namedtuple("NamedTupleType", ["a", "b"])): 1708 pass 1709 1710 @def_function.function 1711 def f(x): 1712 return x.a + x.b 1713 1714 f.get_concrete_function( 1715 NamedTupleType( 1716 a=tensor_spec.TensorSpec(None, dtypes.float32, name="a"), 1717 b=tensor_spec.TensorSpec(None, dtypes.float32, name="b"))) 1718 obj = autotrackable.AutoTrackable() 1719 obj.__call__ = f 1720 if sys.version_info.major == 3 and sys.version_info.minor < 5: 1721 # TODO(allenl): figure out why this doesn't work in Python3.4 1722 self.skipTest("Not working in Python 3.4") 1723 imported = cycle(obj, cycles) 1724 self.assertAllClose(3., 1725 imported(NamedTupleType(a=constant_op.constant(1.), 1726 b=constant_op.constant(2.)))) 1727 1728 def test_extra_args(self, cycles): 1729 1730 @def_function.function 1731 def f(x): 1732 return math_ops.add(x["a"], 1.) 1733 # Trigger a trace. 1734 f({"a": constant_op.constant(2.0)}) 1735 1736 obj = autotrackable.AutoTrackable() 1737 obj.__call__ = f 1738 imported = cycle(obj, cycles) 1739 1740 self.assertEqual(4.0, imported({"a": 3.0}).numpy()) 1741 1742 with self.assertRaisesRegex( 1743 ValueError, "Could not find matching concrete function to call"): 1744 imported({"a": 2.0, "b": 3.0}) 1745 1746 def test_shapes_available(self, cycles): 1747 1748 @def_function.function(input_signature=[ 1749 tensor_spec.TensorSpec([None, 3], dtypes.int32), 1750 tensor_spec.TensorSpec([None, 2], dtypes.int32) 1751 ]) 1752 def func(x, y): 1753 return array_ops.concat([x, y], axis=1) 1754 1755 root = autotrackable.AutoTrackable() 1756 root.f = func 1757 1758 root = cycle(root, cycles) 1759 1760 imported_graph = root.f.get_concrete_function().graph 1761 input_x, input_y = imported_graph.inputs 1762 self.assertEqual([None, 3], input_x.shape.as_list()) 1763 self.assertEqual([None, 2], input_y.shape.as_list()) 1764 output, = imported_graph.outputs 1765 self.assertEqual([None, 5], output.shape.as_list()) 1766 signature = root.signatures["serving_default"] 1767 self.assertEqual( 1768 [None, 3], signature.inputs[0].shape.as_list()) 1769 self.assertEqual( 1770 [None, 2], signature.inputs[1].shape.as_list()) 1771 self.assertEqual( 1772 [None, 5], signature.outputs[0].shape.as_list()) 1773 1774 def test_variables_destroyed(self, cycles): 1775 v1 = variables.Variable(1.) 1776 weak_v1 = weakref.ref(v1) 1777 root = checkpoint.Checkpoint(v=v1) 1778 root = cycle(root, cycles) 1779 del v1 1780 self.assertIsNone(weak_v1()) 1781 weak_v2 = weakref.ref(root.v) 1782 del root 1783 self.assertIsNone(weak_v2()) 1784 1785 def test_variable_attributes_preserved(self, cycles): 1786 v = variables.Variable( 1787 1., 1788 trainable=False, 1789 synchronization=variables.VariableSynchronization.NONE, 1790 aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA) 1791 self.assertEqual(variables.VariableSynchronization.NONE, 1792 v.synchronization) 1793 self.assertEqual(variables.VariableAggregation.ONLY_FIRST_REPLICA, 1794 v.aggregation) 1795 root = autotrackable.AutoTrackable() 1796 root.v = v 1797 root = cycle(root, cycles) 1798 self.assertEqual(False, root.v.trainable) 1799 self.assertEqual(variables.VariableSynchronization.NONE, 1800 root.v.synchronization) 1801 self.assertEqual(variables.VariableAggregation.ONLY_FIRST_REPLICA, 1802 root.v.aggregation) 1803 1804 def test_captured_dataset(self, cycles): 1805 1806 class HasDataset(module.Module): 1807 1808 def __init__(self): 1809 super(HasDataset, self).__init__() 1810 self.dataset = ( 1811 dataset_ops.Dataset.range(5) 1812 .map(lambda x: x ** 2)) 1813 1814 @def_function.function 1815 def __call__(self, x): 1816 current_sum = array_ops.zeros([], dtype=dtypes.int64) 1817 for element in self.dataset: 1818 current_sum += x * element 1819 return current_sum 1820 1821 root = HasDataset() 1822 self.assertEqual( 1823 3 * (1 + 4 + 9 + 16), 1824 root(constant_op.constant(3, dtype=dtypes.int64)).numpy()) 1825 root = cycle(root, cycles) 1826 self.assertEqual( 1827 3 * (1 + 4 + 9 + 16), 1828 root(constant_op.constant(3, dtype=dtypes.int64)).numpy()) 1829 1830 def test_tuple_signature(self, cycles): 1831 root = checkpoint.Checkpoint() 1832 root.f = def_function.function( 1833 lambda: (array_ops.ones([]), array_ops.zeros([])), 1834 input_signature=()) 1835 root = cycle(root, cycles, signatures=root.f) 1836 self.assertEqual(({"output_0": 1., "output_1": 0.}), 1837 self.evaluate(root.signatures["serving_default"]())) 1838 1839 def test_version_info(self, cycles): 1840 root = checkpoint.Checkpoint() 1841 root = cycle(root, cycles) 1842 self.assertEqual(versions.__version__, root.tensorflow_version) 1843 self.assertEqual(versions.__git_version__, root.tensorflow_git_version) 1844 1845 def test_load_grad_save(self, cycles): 1846 root = checkpoint.Checkpoint() 1847 root.v = variables.Variable(2.) 1848 root.f = def_function.function(lambda x: root.v * x) 1849 root.g = def_function.function(root.f) 1850 for _ in range(cycles): 1851 with backprop.GradientTape() as tape: 1852 inp = constant_op.constant(2.) 1853 tape.watch(inp) 1854 output = root.g(inp) 1855 self.assertAllClose(4., output) 1856 self.assertAllClose(2., tape.gradient(output, inp)) 1857 root = cycle(root, 1) 1858 1859 def test_destroy_resource(self, cycles): 1860 1861 def get_handle(): 1862 return resource_variable_ops.var_handle_op( 1863 shape=tensor_shape.as_shape([]), 1864 dtype=dtypes.float32, 1865 shared_name="my_var_name", 1866 name="my_var", 1867 container="my_container") 1868 1869 class MyResource(resource.TrackableResource): 1870 1871 def _create_resource(self): 1872 return get_handle() 1873 1874 def _initialize(self): 1875 resource_variable_ops.assign_variable_op( 1876 self.resource_handle, 1.0, name="assign") 1877 1878 def _destroy_resource(self): 1879 handle = get_handle() 1880 resource_variable_ops.destroy_resource_op( 1881 handle, ignore_lookup_error=True) 1882 1883 class MyModel(autotrackable.AutoTrackable): 1884 1885 def __init__(self): 1886 super(MyModel, self).__init__() 1887 self.resource = MyResource() 1888 1889 @def_function.function(input_signature=[]) 1890 def increase(self): 1891 handle = self.resource.resource_handle 1892 resource_variable_ops.assign_add_variable_op( 1893 handle, 10.0, name="assign_add") 1894 return resource_variable_ops.read_variable_op(handle, dtypes.float32) 1895 1896 root = MyModel() 1897 imported = cycle(root, cycles) 1898 self.assertEqual(11, imported.increase().numpy()) # Create the resource. 1899 1900 handle = imported.resource.resource_handle 1901 1902 # Delete the imported SaveModel. Since we explicitly set the deleter, it 1903 # should destroy the resource automatically. 1904 del imported 1905 1906 # Try to destroy the resource again, should fail. 1907 with self.assertRaisesRegex(errors.NotFoundError, 1908 r"Resource .* does not exist."): 1909 resource_variable_ops.destroy_resource_op( 1910 handle, ignore_lookup_error=False) 1911 1912 def test_function_called_as_operation(self, cycles): 1913 1914 @framework_function.Defun(dtypes.float32) 1915 def inner(x): 1916 return x + 1. 1917 1918 @def_function.function( 1919 input_signature=[tensor_spec.TensorSpec([], dtypes.float32)]) 1920 def outer(x): 1921 return inner(x) 1922 1923 root = module.Module() 1924 root.f = outer 1925 imported = cycle(root, cycles) 1926 self.assertAllClose(2., imported.f(constant_op.constant(1.))) 1927 1928 def test_ragged(self, cycles): 1929 1930 @def_function.function 1931 def f(x, c=1): 1932 """Returns Tensor x incremented by Python constant c.""" 1933 return math_ops.add(x, c) 1934 1935 for c in (1, 2, 3): 1936 _ = f.get_concrete_function( 1937 ragged_tensor.RaggedTensorSpec([None, None], dtype=dtypes.int32), 1938 c) 1939 1940 obj = autotrackable.AutoTrackable() 1941 obj.f = f 1942 1943 imported1 = cycle(obj, cycles, signatures={}) 1944 rt = ragged_factory_ops.constant([[1, 2], [3]]) 1945 self.assertAllEqual(imported1.f(rt), [[2, 3], [4]]) 1946 self.assertAllEqual(imported1.f(rt, 2), [[3, 4], [5]]) 1947 self.assertAllEqual(imported1.f(rt, 3), [[4, 5], [6]]) 1948 1949 imported2 = cycle(obj, cycles) 1950 rt = ragged_factory_ops.constant([[1, 2], [3]]) 1951 self.assertAllEqual(imported2.f(rt, 1), [[2, 3], [4]]) 1952 self.assertAllEqual(imported2.f(rt, 2), [[3, 4], [5]]) 1953 self.assertAllEqual(imported2.f(rt, 3), [[4, 5], [6]]) 1954 1955 def test_accepts_io_device(self, cycles): 1956 options = load_options.LoadOptions() 1957 self.assertIsNone(options.experimental_io_device) 1958 options = load_options.LoadOptions(experimental_io_device="/job:localhost") 1959 self.assertEqual("/job:localhost", options.experimental_io_device) 1960 1961 def _custom_saveable_object(self, cycles): 1962 if context.is_tfrt_enabled(): 1963 self.skipTest("Disable due to b/190539415.") 1964 root = autotrackable.AutoTrackable() 1965 root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1) 1966 root.table.insert("foo", 15) 1967 root.table2 = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1) 1968 root.table2.insert("idk", 21) 1969 1970 @def_function.function( 1971 input_signature=[tensor_spec.TensorSpec(None, dtypes.string)]) 1972 def lookup(key): 1973 return root.table.lookup(key) 1974 1975 root.lookup = lookup 1976 1977 imported = cycle(root, cycles) 1978 self.assertEqual(self.evaluate(imported.lookup("foo")), 15) 1979 self.assertEqual(self.evaluate(imported.lookup("idk")), -1) 1980 1981 if not saveable_compat.force_checkpoint_conversion_enabled(): 1982 self.assertEqual({"table"}, 1983 imported.table._self_saveable_object_factories.keys()) 1984 1985 def test_load_custom_saveable_object(self, cycles): 1986 self._custom_saveable_object(cycles) 1987 1988 def test_load_custom_saveable_object_ckpt_conversion(self, cycles): 1989 # Tests custom saveable object with checkpoint conversion enabled (forces 1990 # Trackable-based checkpoint implementation). 1991 saveable_compat.force_checkpoint_conversion() 1992 self._custom_saveable_object(cycles) 1993 1994 def test_load_resource_with_dependency(self, cycles): 1995 # Test with StaticHashTable, which has a _initializer attribute that tracks 1996 # the Asset vocab table. 1997 1998 class MyLookupModel(autotrackable.AutoTrackable): 1999 2000 def __init__(self, vocab_file): 2001 2002 vocab_initializer = lookup_ops.TextFileInitializer( 2003 vocab_file, 2004 key_dtype=dtypes.string, 2005 key_index=lookup_ops.TextFileIndex.WHOLE_LINE, 2006 value_dtype=dtypes.int64, 2007 value_index=lookup_ops.TextFileIndex.LINE_NUMBER) 2008 self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer, 2009 default_value=-1) 2010 2011 @def_function.function(input_signature=[ 2012 tensor_spec.TensorSpec((None,), dtypes.string)]) 2013 def __call__(self, inputs): 2014 return self._vocab_table.lookup(inputs) 2015 2016 vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"])) 2017 root = MyLookupModel(vocab_file) 2018 imported = cycle(root, cycles) 2019 file_io.delete_file(vocab_file) 2020 self.assertAllEqual(imported(constant_op.constant(["d", "b"])), 2021 [3, 1]) 2022 2023 def test_custom_gradients(self, cycles): 2024 2025 @custom_gradient.custom_gradient 2026 def log1pexp(x): 2027 e = math_ops.exp(x) 2028 2029 def grad(dy): 2030 return dy * e # incorrect to check the custom gradients is respected. 2031 2032 return math_ops.log(1 + e), grad 2033 2034 @def_function.function 2035 def g(x): 2036 y = log1pexp(x) 2037 2038 @def_function.function 2039 def g_nest(): 2040 return log1pexp(y) 2041 2042 return g_nest() 2043 2044 @def_function.function 2045 def f(x): 2046 return log1pexp(g(x * x)) 2047 2048 v = variables.Variable(1.) 2049 2050 with backprop.GradientTape() as tape2: 2051 with backprop.GradientTape() as tape: 2052 tape.watch(v) 2053 y = f(v) 2054 expected_grads = tape.gradient(y, v) 2055 expected_grad_grads = tape2.gradient(expected_grads, v) 2056 2057 root = autotrackable.AutoTrackable() 2058 root.f = f 2059 loaded = cycle( 2060 root, cycles, options=save_options.SaveOptions( 2061 experimental_custom_gradients=True)) 2062 with backprop.GradientTape() as tape2: 2063 with backprop.GradientTape() as tape: 2064 tape.watch(v) 2065 y = loaded.f(v) 2066 grads = tape.gradient(y, v) 2067 grad_grads = tape2.gradient(grads, v) 2068 2069 self.assertAllClose(grads, expected_grads) 2070 self.assertAllClose(grad_grads, expected_grad_grads) 2071 2072 def test_custom_gradients_with_none_grad(self, cycles): 2073 # https://github.com/google/jax/issues/7123 2074 2075 @custom_gradient.custom_gradient 2076 def f(params, state): 2077 def grad_fn(*args): 2078 return args 2079 return (params, state), grad_fn 2080 @def_function.function(input_signature=[ 2081 tensor_spec.TensorSpec([], dtypes.float32), 2082 tensor_spec.TensorSpec([], dtypes.int32)]) 2083 def predict(params, state): 2084 return f(params, state) 2085 2086 params = variables.Variable(1.0) 2087 # None grads only appear when state is an int. 2088 state = constant_op.constant(3, dtype=dtypes.int32) 2089 with backprop.GradientTape() as tape: 2090 tape.watch(params) 2091 y = predict(params, state) 2092 expected_grads = tape.gradient(y, params) 2093 2094 root = autotrackable.AutoTrackable() 2095 root.fn = predict 2096 loaded = cycle( 2097 root, cycles, options=save_options.SaveOptions( 2098 experimental_custom_gradients=True)) 2099 2100 with backprop.GradientTape() as tape: 2101 tape.watch(params) 2102 y = loaded.fn(params, state) 2103 grads = tape.gradient(y, params) 2104 2105 self.assertAllClose(grads, expected_grads) 2106 2107 2108class SingleCycleTests(test.TestCase, parameterized.TestCase): 2109 2110 def test_load_with_tags(self): 2111 root = autotrackable.AutoTrackable() 2112 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 2113 save.save(root, path) 2114 with self.assertRaises(ValueError): 2115 load.load(path, tags=[tag_constants.EVAL]) 2116 load.load(path, tags=[tag_constants.SERVING]) 2117 load.load(path, tags=tag_constants.SERVING) 2118 load.load(path, tags=set([tag_constants.SERVING])) 2119 2120 def test_save_load_contains_with_fspath(self): 2121 root = autotrackable.AutoTrackable() 2122 path = pathlib.Path(tempfile.mkdtemp(prefix=self.get_temp_dir())) 2123 save.save(root, path) 2124 self.assertTrue(loader_impl.contains_saved_model(path)) 2125 load.load(path) 2126 2127 def test_single_restore_op_used(self): 2128 root = module.Module() 2129 root.v1 = variables.Variable(1.) 2130 root.v2 = variables.Variable(2.) 2131 root.v3 = variables.Variable(3.) 2132 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 2133 save.save(root, path) 2134 restore_count = 0 2135 2136 def _count_restores(op_type, *unused_args, **unused_kwargs): 2137 nonlocal restore_count 2138 if op_type == b"RestoreV2": 2139 restore_count += 1 2140 2141 op_callbacks.add_op_callback(_count_restores) 2142 load.load(path) 2143 op_callbacks.remove_op_callback(_count_restores) 2144 self.assertEqual(1, restore_count) 2145 2146 def test_docstring_examples(self): 2147 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 2148 exported = checkpoint.Checkpoint(v=variables.Variable(3.)) 2149 exported.f = def_function.function( 2150 lambda x: exported.v * x, 2151 input_signature=[ 2152 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)]) 2153 save.save(exported, path) 2154 imported = load.load(path) 2155 self.assertEqual(3., imported.v.numpy()) 2156 self.assertEqual(6., imported.f(x=constant_op.constant(2.)).numpy()) 2157 2158 save.save(exported, path, exported.f.get_concrete_function()) 2159 imported = load.load(path) 2160 f = imported.signatures["serving_default"] 2161 self.assertAllEqual( 2162 [[-3.]], 2163 f(x=constant_op.constant([[-1.]]))["output_0"].numpy()) 2164 2165 def test_object_with_extra_dependencies(self): 2166 2167 class Extra(autotrackable.AutoTrackable): 2168 2169 def _trackable_children(self, save_type, **kwargs): 2170 children = super(Extra, self)._trackable_children(save_type, **kwargs) 2171 children["a"] = variables.Variable(5.) 2172 return children 2173 2174 root = Extra() 2175 path = tempfile.mkdtemp(prefix=self.get_temp_dir()) 2176 save.save(root, path) 2177 imported = load.load(path) 2178 self.assertEqual(5, self.evaluate(imported.a)) 2179 2180 def test_save_cached_variable(self): 2181 with ops.Graph().as_default(), session_lib.Session() as session: 2182 obj = autotrackable.AutoTrackable() 2183 obj.v = variables.Variable(2., caching_device=lambda op: op.device) 2184 obj.w = variables.Variable(3.) 2185 session.run([obj.v.initializer, obj.w.initializer]) 2186 2187 @def_function.function 2188 def total(): 2189 return obj.v + obj.w 2190 2191 @def_function.function(input_signature=[tensor_spec.TensorSpec([])]) 2192 def wrapped_total(x): 2193 return total() + x 2194 2195 @def_function.function 2196 def increment_v(x): 2197 obj.v.assign_add(x) 2198 2199 session.run(increment_v(constant_op.constant(3.))) # generate signatures 2200 self.assertAllClose(8, total()) 2201 self.assertAllClose(13, wrapped_total(constant_op.constant(5.))) 2202 2203 obj.total = total 2204 obj.wrapped_total = wrapped_total.get_concrete_function() 2205 obj.increment_v = increment_v 2206 2207 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 2208 save.save(obj, save_dir, signatures=total.get_concrete_function()) 2209 imported = load.load(save_dir) 2210 session.run(variables.global_variables_initializer()) 2211 self.assertAllClose(8, imported.total()) 2212 session.run(imported.increment_v(4)) 2213 self.assertAllClose(12, imported.total()) 2214 self.assertAllClose(15, imported.wrapped_total(constant_op.constant(3.))) 2215 self.assertAllClose({"output_0": 12}, 2216 imported.signatures["serving_default"]()) 2217 2218 # Try loading and running the function in eager mode 2219 imported = load.load(save_dir) 2220 self.assertAllClose(8, imported.total()) 2221 imported.increment_v(5) 2222 self.assertAllClose(13, imported.total()) 2223 self.assertAllClose(13.5, imported.wrapped_total(constant_op.constant(.5))) 2224 self.assertAllClose({"output_0": 13}, 2225 imported.signatures["serving_default"]()) 2226 2227 # TODO(allenl, kkb): Use the new memory checker here once it's fast enough (3 2228 # iterations took hundreds of seconds). It would be really nice to check 2229 # allocations at a lower level. 2230 @test_util.assert_no_new_pyobjects_executing_eagerly 2231 def test_functions_cleaned(self): 2232 if sys.version_info.major < 3: 2233 self.skipTest("Not working in Python 2") 2234 root = module.Module() 2235 root.v = variables.Variable(1.) 2236 root.f = def_function.function( 2237 lambda x: x + root.v, 2238 input_signature=[ 2239 tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)]) 2240 cycle(root, 1) 2241 2242 def test_load_partial_object(self): 2243 root = module.Module() 2244 root.variables_holder = module.Module() 2245 root.variables_holder.v = variables.Variable(1.) 2246 2247 class Adder(module.Module): 2248 2249 @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=[])]) 2250 def __call__(self, y): 2251 root.variables_holder.v.assign_add(y) 2252 return 1 2253 2254 root.adder = Adder() 2255 2256 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 2257 save.save(root, save_dir) 2258 2259 imported = load.load_partial(save_dir, 2260 ["root.variables_holder.v", "root.adder"]) 2261 v = imported["root.variables_holder.v"] 2262 adder = imported["root.adder"] 2263 self.assertEqual(self.evaluate(v), 1) 2264 adder(5) 2265 self.assertEqual(self.evaluate(v), 6) 2266 2267 with self.assertRaisesRegex( 2268 ValueError, "does not include all required objects for loading"): 2269 imported = load.load_partial(save_dir, ["root.adder"]) 2270 2271 def test_load_partial_checkpoint(self): 2272 root = module.Module() 2273 root.variables_holder = module.Module() 2274 root.variables_holder.v = variables.Variable(1.) 2275 2276 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 2277 save.save(root, save_dir) 2278 2279 loaded = module.Module() 2280 loaded.v = variables.Variable(2.) 2281 2282 load.load_partial( 2283 save_dir, {"root": loaded}, 2284 options=load_options.LoadOptions(allow_partial_checkpoint=True)) 2285 self.assertEqual(loaded.variables_holder.v.numpy(), 1) 2286 with self.assertRaisesRegex(AssertionError, "were not bound"): 2287 load.load_partial(save_dir, {"root": loaded}) 2288 2289 def test_call_untraced_function_raises_error(self): 2290 2291 class ObjWithFunction(module.Module): 2292 2293 @def_function.function 2294 def foo(self, a): 2295 return a 2296 2297 root = ObjWithFunction() 2298 with self.assertLogs(level="WARNING") as logs: 2299 loaded = cycle(root, 1) 2300 2301 expected_save_message = ( 2302 "WARNING:absl:Found untraced functions such as foo while saving " 2303 "(showing 1 of 1). These functions will not be directly callable after " 2304 "loading.") 2305 self.assertIn(expected_save_message, logs.output) 2306 2307 with self.assertRaisesRegex( 2308 ValueError, "Found zero restored functions for caller function."): 2309 loaded.foo(1) 2310 2311 def test_restored_function_execute_eagerly(self): 2312 try: 2313 def_function.run_functions_eagerly(True) 2314 2315 class MyModel(module.Module): 2316 2317 @def_function.function 2318 def __call__(self, inputs, training=False): 2319 return math_ops.multiply(0.5, inputs) 2320 2321 model = MyModel() 2322 model.__call__.get_concrete_function( 2323 tensor_spec.TensorSpec([None], dtypes.float32)) 2324 loaded = cycle(model, 1) 2325 2326 # Calling the function should not throw an exception. 2327 loaded(constant_op.constant([1.0])) 2328 2329 finally: 2330 def_function.run_functions_eagerly(False) 2331 2332 def test_restored_model_concrete_function_is_deterministic(self): 2333 previous_concrete_function = None 2334 for _ in range(100): 2335 2336 class MyModel(module.Module): 2337 2338 @def_function.function 2339 def __call__(self, x): 2340 return x * constant_op.constant(3.0) 2341 2342 model = MyModel() 2343 model(array_ops.ones((7, 3), dtype=dtypes.float32)) 2344 model.__call__.get_concrete_function( 2345 tensor_spec.TensorSpec([None, 3], dtypes.float32)) 2346 loaded = cycle(model, 1) 2347 2348 # Ensure the newly loaded concrete function is the same as the previous 2349 # after a cycle of serialization / deserialization. 2350 new_concrete_function = loaded.__call__.get_concrete_function( 2351 tensor_spec.TensorSpec([None, 3], dtypes.float32)) 2352 if previous_concrete_function is not None: 2353 self.assertEqual(previous_concrete_function.pretty_printed_signature(), 2354 new_concrete_function.pretty_printed_signature()) 2355 2356 previous_concrete_function = new_concrete_function 2357 2358 def test_garbage_collection_capturable_resource_doesnt_raise_exception(self): 2359 model = module.Module() 2360 model.mapping = lookup_ops.StaticHashTable( 2361 lookup_ops.KeyValueTensorInitializer( 2362 keys=math_ops.range(1, dtype=dtypes.int32), 2363 values=["foo"]), 2364 "default_value") 2365 loaded = cycle(model, 1) 2366 del model 2367 del loaded 2368 # Exceptions raised during garbage collection are simply printed to stderr 2369 # and ignored, and we have no way to access them. We'll capture stdout 2370 # during the garbage collection process and inspect to see if any 2371 # exceptions were raised. 2372 stderr = io.StringIO() 2373 with contextlib.redirect_stderr(stderr): 2374 gc.collect() 2375 if "Exception ignored in" in stderr.getvalue(): 2376 raise Exception(stderr.getvalue()) 2377 2378 def test_captured_dataset_with_asset(self): 2379 2380 class HasDataset(module.Module): 2381 2382 def __init__(self, temp_dir, file_name): 2383 super(HasDataset, self).__init__() 2384 file = os.path.join(temp_dir, file_name) 2385 with tf_record.TFRecordWriter(file, "GZIP") as f: 2386 for v in ["a", "aa", "aaa"]: 2387 f.write(str(v)) 2388 self.dataset = readers.TFRecordDataset([file], compression_type="GZIP") 2389 2390 @def_function.function 2391 def __call__(self, x): 2392 current_sum = array_ops.zeros([], dtype=dtypes.int32) 2393 for element in self.dataset: 2394 current_sum += x * string_ops.string_length(element) 2395 return current_sum 2396 2397 temp_dir = self.get_temp_dir() 2398 file_name = "tf_record_asset.tfrecord.gz" 2399 root = HasDataset(temp_dir, file_name) 2400 self.assertEqual( 2401 18, # 3 * (1 + 2 + 3) 2402 root(constant_op.constant(3, dtype=dtypes.int32)).numpy()) 2403 2404 save_dir = os.path.join(self.get_temp_dir(), "save_dir") 2405 save.save(root, save_dir) 2406 2407 file_io.delete_file(os.path.join(temp_dir, file_name)) 2408 asset_path = os.path.join(save_dir, "assets/{}".format(file_name)) 2409 self.assertTrue(file_io.file_exists(asset_path)) 2410 load_dir = os.path.join(self.get_temp_dir(), "load_dir") 2411 file_io.rename(save_dir, load_dir) 2412 2413 loaded = load.load(load_dir) 2414 self.assertEqual( 2415 18, # 3 * (1 + 2 + 3) 2416 loaded(constant_op.constant(3, dtype=dtypes.int32)).numpy()) 2417 2418 2419class DeferredInitModuleVariablesTest(test.TestCase): 2420 2421 def test_deferred_init_module_variables(self): 2422 """Defer initialization of variables in a module to the load stage.""" 2423 2424 class MyModule(module.Module): 2425 2426 def __init__(self, size): 2427 super().__init__() 2428 self.size = size 2429 # variable initialized by a Tensor-compatible value 2430 self.w1 = variables.Variable( 2431 constant_op.constant(1., shape=[self.size]), trainable=False) 2432 # variable initialized by a function 2433 self.w2 = variables.Variable( 2434 lambda: constant_op.constant(2., shape=[self.size])) 2435 # variable instantiated lazily in call() 2436 self.w3 = None 2437 2438 def call(self): 2439 if self.w3 is None: 2440 self.w3 = variables.Variable( 2441 constant_op.constant(3., shape=[self.size])) 2442 for w in (self.w1, self.w2, self.w3): 2443 w.assign_add(constant_op.constant(1., shape=[self.size])) 2444 return self.w1, self.w2, self.w3 2445 2446 def export_initializer(initial_value, export_dir): 2447 2448 class Initializer(module.Module): 2449 2450 @def_function.function(input_signature=[]) 2451 def call(self): 2452 if callable(initial_value): 2453 return initial_value() 2454 return initial_value 2455 2456 save.save(Initializer(), export_dir) 2457 2458 def create_and_save_module(weight_size): 2459 2460 initial_values = {} # For storing initial_value of created variables 2461 2462 def variable_creator(next_creator, **kwargs): 2463 variable = next_creator(**kwargs) 2464 variable_name = variable.name 2465 if ":" in variable_name: 2466 variable_name = variable_name[:variable_name.index(":")] 2467 initial_values[variable_name] = kwargs["initial_value"] 2468 return variable 2469 2470 export_dir = self.create_tempdir().full_path 2471 2472 with ops.Graph().as_default(): 2473 with variable_scope.variable_creator_scope(variable_creator): 2474 exported = MyModule(weight_size) 2475 exported.call = def_function.function(input_signature=[])( 2476 exported.call) 2477 2478 module_dir = f"{export_dir}/module" 2479 file_io.recursive_create_dir(module_dir) 2480 save.save_and_return_nodes( 2481 exported, module_dir, experimental_skip_checkpoint=True) 2482 2483 # Save the initializer of the created variables. 2484 for variable_name, initial_value in initial_values.items(): 2485 export_initializer(initial_value, 2486 f"{export_dir}/variables/{variable_name}") 2487 2488 return export_dir 2489 2490 def load_and_run_module(export_dir, weight_size): 2491 2492 # pylint: disable=unused-argument 2493 def layer_variable_creator(next_creator, **kwargs): 2494 variable_dir = f"{export_dir}/variables/{kwargs['name']}" 2495 initializer = load.load(variable_dir) 2496 kwargs["initial_value"] = initializer.call 2497 variable = resource_variable_ops.ResourceVariable(**kwargs) 2498 return variable 2499 2500 with ops.Graph().as_default(): 2501 with variable_scope.variable_creator_scope(layer_variable_creator): 2502 imported = load.load( 2503 f"{export_dir}/module", 2504 options=load_options.LoadOptions( 2505 experimental_skip_checkpoint=True)) 2506 outputs = imported.call() 2507 2508 with self.cached_session() as sess: 2509 variables.global_variables_initializer().run() 2510 # Check if variables work as expected across multiple iterations. 2511 for i in range(3): 2512 np_outputs = sess.run(outputs) 2513 for j, np_output in enumerate(np_outputs): 2514 self.assertAllClose(np_output, np.full(weight_size, i + j + 2)) 2515 2516 # The size of the serialized content (both module and variables) stays 2517 # small even with a large weight_size as the initial values are not stored 2518 # in checkpoints. 2519 weight_size = 1024 2520 export_dir = create_and_save_module(weight_size) 2521 load_and_run_module(export_dir, weight_size) 2522 2523 def _make_asset(self, contents): 2524 fd, filename = tempfile.mkstemp(prefix=self.get_temp_dir()) 2525 with os.fdopen(fd, "w") as f: 2526 f.write(contents) 2527 return filename 2528 2529 def test_assets(self): 2530 2531 class MyLookupModel(autotrackable.AutoTrackable): 2532 2533 def __init__(self, vocab_file): 2534 2535 vocab_initializer = lookup_ops.TextFileInitializer( 2536 vocab_file, 2537 key_dtype=dtypes.string, 2538 key_index=lookup_ops.TextFileIndex.WHOLE_LINE, 2539 value_dtype=dtypes.int64, 2540 value_index=lookup_ops.TextFileIndex.LINE_NUMBER) 2541 self._vocab_table = lookup_ops.StaticHashTable(vocab_initializer, 2542 default_value=-1) 2543 2544 @def_function.function(input_signature=[ 2545 tensor_spec.TensorSpec((None,), dtypes.string)]) 2546 def __call__(self, inputs): 2547 return self._vocab_table.lookup(inputs) 2548 2549 vocab_file = self._make_asset("\n".join(["a", "b", "c", "d"])) 2550 root = MyLookupModel(vocab_file) 2551 2552 save_dir = os.path.join(self.get_temp_dir(), "save_dir") 2553 save.save_and_return_nodes( 2554 root, save_dir, experimental_skip_checkpoint=True) 2555 file_io.delete_file(vocab_file) 2556 load_dir = os.path.join(self.get_temp_dir(), "load_dir") 2557 file_io.rename(save_dir, load_dir) 2558 2559 imported = load.load( 2560 load_dir, 2561 options=load_options.LoadOptions(experimental_skip_checkpoint=True)) 2562 self.assertAllEqual(imported(constant_op.constant(["d", "b"])), 2563 [3, 1]) 2564 2565 2566class _TestModel(module.Module): 2567 2568 def __init__(self, rows, cols): 2569 super().__init__() 2570 self.rows = rows 2571 self.cols = cols 2572 self.table = None 2573 2574 def __call__(self, x): 2575 with ops.device("/cpu:0"): 2576 self.table = variables.Variable( 2577 constant_op.constant(1., shape=[self.rows, self.cols])) 2578 x = math_ops.matmul(self.table, x) 2579 x = math_ops.reduce_sum(x, axis=0) 2580 return x 2581 2582 2583class SavedModelLoadMemoryTests(test.TestCase): 2584 2585 @test_util.run_gpu_only 2586 def test_no_oom_loading_large_tenor(self): 2587 if not config.get_soft_device_placement(): 2588 self.skipTest("This test only works for soft device placement is on") 2589 save_dir = os.path.join(self.get_temp_dir(), "saved_model") 2590 ncols = 16 2591 nrows = 32 2592 model = _TestModel(rows=nrows, cols=ncols) 2593 x = array_ops.zeros(shape=(ncols, 2), dtype=dtypes.float32) 2594 y = model(x) 2595 save.save( 2596 model, 2597 save_dir, 2598 options=save_options.SaveOptions( 2599 experimental_variable_policy=save_options.VariablePolicy 2600 .SAVE_VARIABLE_DEVICES), 2601 ) 2602 loaded_on_cpu = load.load( 2603 export_dir=save_dir, 2604 options=load_options.LoadOptions( 2605 experimental_variable_policy=save_options.VariablePolicy 2606 .SAVE_VARIABLE_DEVICES), 2607 ) 2608 loaded_on_gpu = load.load(export_dir=save_dir) 2609 self.assertTrue("CPU" in loaded_on_cpu.table.device) 2610 self.assertTrue("GPU" in loaded_on_gpu.table.device) 2611 2612 2613if __name__ == "__main__": 2614 test.main() 2615