1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the distributed values library.""" 16 17import copy 18import os 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.python import tf2 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.distribute import combinations 27from tensorflow.python.distribute import strategy_combinations 28from tensorflow.python.distribute import test_util as ds_test_util 29from tensorflow.python.distribute import tpu_strategy 30from tensorflow.python.distribute import tpu_values 31from tensorflow.python.distribute import values as values_lib 32from tensorflow.python.eager import context 33from tensorflow.python.eager import def_function 34from tensorflow.python.eager import test 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import sparse_tensor 39from tensorflow.python.framework import test_util 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import sparse_ops 43from tensorflow.python.ops import variable_scope 44from tensorflow.python.ops import variables as variables_lib 45from tensorflow.python.training import saver as saver_lib 46 47 48def _device_str(d): 49 return "/device:GPU:" + str(d) 50 51 52def _nested_value(d): 53 return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) 54 55 56def mirrored_and_tpu_strategy_combinations(): 57 return combinations.combine( 58 distribution=[ 59 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 60 strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call, 61 strategy_combinations.tpu_strategy, 62 strategy_combinations.tpu_strategy_packed_var, 63 strategy_combinations.tpu_strategy_spmd, 64 ], 65 mode=["graph", "eager"]) 66 67 68class DistributedValuesTest(test.TestCase, parameterized.TestCase): 69 70 @combinations.generate( 71 combinations.combine( 72 distribution=(strategy_combinations.all_strategies_minus_default + 73 strategy_combinations.multiworker_strategies), 74 mode=["eager"] 75 )) 76 def testMakeDistributedValueFromTensor(self, distribution): 77 if not tf2.enabled(): 78 self.skipTest("Only V2 is supported.") 79 single_value = constant_op.constant(1) 80 def value_fn(ctx): 81 del ctx 82 return single_value 83 84 distributed_values = ( 85 distribution.experimental_distribute_values_from_function(value_fn)) 86 self.assertAllEqual( 87 ds_test_util.gather(distribution, distributed_values), 88 constant_op.constant(1., shape=(distribution.num_replicas_in_sync))) 89 90 @combinations.generate( 91 combinations.combine( 92 distribution=(strategy_combinations.all_strategies_minus_default + 93 strategy_combinations.multiworker_strategies), 94 mode=["eager"] 95 )) 96 def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution): 97 if not tf2.enabled(): 98 self.skipTest("Only V2 is supported.") 99 array_value = np.array([1., 2., 3.]) 100 def value_fn(ctx): 101 del ctx 102 return array_value 103 104 distributed_values = ( 105 distribution.experimental_distribute_values_from_function(value_fn)) 106 self.assertAllEqual( 107 ds_test_util.gather(distribution, distributed_values).numpy(), 108 [[1., 2., 3.]] * distribution.num_replicas_in_sync) 109 110 @combinations.generate( 111 combinations.combine( 112 distribution=(strategy_combinations.all_strategies_minus_default + 113 strategy_combinations.multiworker_strategies), 114 mode=["eager"] 115 )) 116 def testMakeDistributedValueTupleConstant(self, distribution): 117 if not tf2.enabled(): 118 self.skipTest("Only V2 is supported.") 119 tuple_value = (1., 2., 3.) 120 def value_fn(ctx): 121 del ctx 122 return tuple_value 123 distributed_values = ( 124 distribution.experimental_distribute_values_from_function(value_fn)) 125 distributed_values = ds_test_util.gather(distribution, distributed_values) 126 127 # Expected output for 2 replicas: 128 # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0]) 129 expected = tuple([v for i in range(distribution.num_replicas_in_sync)] 130 for v in tuple_value) 131 self.assertAllEqual(distributed_values, expected) 132 133 @combinations.generate( 134 combinations.combine( 135 distribution=(strategy_combinations.all_strategies_minus_default + 136 strategy_combinations.multiworker_strategies), 137 mode=["eager"] 138 )) 139 def testMakeDistributedValueNestedStructurePerReplica(self, distribution): 140 if not tf2.enabled(): 141 self.skipTest("Only V2 is supported.") 142 tuple_value = (1., 2., 3.) 143 def value_fn(ctx): 144 per_replica = [] 145 for val in tuple_value: 146 per_replica.append(val * ctx.replica_id_in_sync_group) 147 return tuple(per_replica) 148 distributed_values = ( 149 distribution.experimental_distribute_values_from_function(value_fn)) 150 distributed_values = ds_test_util.gather(distribution, distributed_values) 151 152 # Expected output for 2 replicas: 153 # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0]) 154 expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)] 155 for v in tuple_value) 156 self.assertAllEqual(distributed_values, expected) 157 158 # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because 159 # collective ops do not support SparseTensors. 160 @combinations.generate( 161 combinations.combine( 162 distribution=strategy_combinations.all_strategies_minus_default, 163 mode=["eager"] 164 )) 165 def testMakeDistributedValueSpareTensor(self, distribution): 166 if not tf2.enabled(): 167 self.skipTest("Only V2 is supported.") 168 def value_fn(ctx): 169 del ctx 170 return sparse_tensor.SparseTensor( 171 indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) 172 173 distributed_values = ( 174 distribution.experimental_distribute_values_from_function(value_fn)) 175 local_results = distribution.experimental_local_results(distributed_values) 176 for i in range(distribution.num_replicas_in_sync): 177 self.assertAllEqual( 178 sparse_ops.sparse_tensor_to_dense(local_results[i]), 179 [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]) 180 181 @combinations.generate( 182 combinations.combine( 183 distribution=(strategy_combinations.all_strategies_minus_default + 184 strategy_combinations.multiworker_strategies), 185 mode=["eager"] 186 )) 187 def testMakeDistributedValueExtractFromArray(self, distribution): 188 if not tf2.enabled(): 189 self.skipTest("Only V2 is supported.") 190 multiple_values = range(distribution.num_replicas_in_sync) 191 def value_fn(ctx): 192 return multiple_values[ctx.replica_id_in_sync_group] 193 distributed_values = ( 194 distribution.experimental_distribute_values_from_function(value_fn)) 195 distributed_values = ds_test_util.gather(distribution, distributed_values) 196 expected = range(distribution.num_replicas_in_sync) 197 self.assertAllEqual(distributed_values, expected) 198 199 @combinations.generate( 200 combinations.combine( 201 distribution=(strategy_combinations.all_strategies_minus_default + 202 strategy_combinations.multiworker_strategies), 203 mode=["eager"] 204 )) 205 def testMakeDistributedValueAndRun(self, distribution): 206 if not tf2.enabled(): 207 self.skipTest("Only V2 is supported.") 208 209 @def_function.function 210 def run(): 211 multiple_values = range(distribution.num_replicas_in_sync) 212 def value_fn(ctx): 213 return multiple_values[ctx.replica_id_in_sync_group] 214 distributed_values = ( 215 distribution.experimental_distribute_values_from_function(value_fn)) 216 217 def computation(x): 218 return math_ops.square(x) 219 220 outputs = ds_test_util.gather( 221 distribution, 222 distribution.run(computation, args=(distributed_values,))) 223 return outputs 224 225 results = run() 226 227 expected = [i**2 for i in range(distribution.num_replicas_in_sync)] 228 self.assertAllEqual(results, expected) 229 230 @combinations.generate( 231 combinations.combine( 232 distribution=[ 233 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 234 strategy_combinations 235 .mirrored_strategy_with_two_gpus_no_merge_call, 236 strategy_combinations.tpu_strategy, 237 strategy_combinations.tpu_strategy_packed_var, 238 strategy_combinations.central_storage_strategy_with_two_gpus, 239 ] + strategy_combinations.multiworker_strategies, 240 mode=["eager"])) 241 def testMakeDistributedValueDefaultDevicePlacement(self, distribution): 242 if not tf2.enabled(): 243 self.skipTest("Only V2 is supported.") 244 def value_fn(ctx): 245 del ctx 246 return constant_op.constant(1.0) 247 distributed_values = ( 248 distribution.experimental_distribute_values_from_function(value_fn)) 249 default_device = array_ops.identity(constant_op.constant(1.0)).device 250 for i in range(len(distribution.extended.worker_devices)): 251 self.assertAllEqual(distributed_values._values[i].device, default_device) 252 253 @combinations.generate( 254 combinations.combine( 255 distribution=[ 256 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 257 strategy_combinations 258 .mirrored_strategy_with_two_gpus_no_merge_call, 259 strategy_combinations.tpu_strategy, 260 strategy_combinations.tpu_strategy_packed_var, 261 strategy_combinations.central_storage_strategy_with_two_gpus, 262 ] + strategy_combinations.multiworker_strategies, 263 mode=["eager"], 264 op_type=[constant_op.constant, array_ops.identity])) 265 def testMakeDistributedValueExplicitDevicePlacement(self, distribution, 266 op_type): 267 if not tf2.enabled(): 268 self.skipTest("Only V2 is supported.") 269 worker_devices = distribution.extended.worker_devices 270 def value_fn(ctx): 271 # In multi client setup, worker_devices is just the devices on that 272 # worker. 273 worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices) 274 with ops.device(worker_devices[worker_device_id]): 275 return op_type(1.0) 276 277 distributed_values = ( 278 distribution.experimental_distribute_values_from_function(value_fn)) 279 for i in range(len(distribution.extended.worker_devices)): 280 self.assertAllEqual(distributed_values._values[i].device, 281 worker_devices[i]) 282 283 284class PerReplicaTest(test.TestCase, parameterized.TestCase): 285 286 @combinations.generate( 287 combinations.combine( 288 distribution=[ 289 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 290 strategy_combinations 291 .mirrored_strategy_with_two_gpus_no_merge_call, 292 strategy_combinations.tpu_strategy, 293 strategy_combinations.tpu_strategy_packed_var, 294 strategy_combinations.central_storage_strategy_with_two_gpus, 295 ] + strategy_combinations.multiworker_strategies, 296 mode=["eager"])) 297 def testUsePerReplicaInvalidContextGivesError(self, distribution): 298 if not tf2.enabled(): 299 self.skipTest("Only V2 is supported.") 300 multiple_values = range(distribution.num_replicas_in_sync) 301 def value_fn(ctx): 302 return multiple_values[ctx.replica_id_in_sync_group] 303 distributed_values = ( 304 distribution.experimental_distribute_values_from_function(value_fn)) 305 with self.assertRaisesRegex(ValueError, "not inside a replica context"): 306 math_ops.cast(distributed_values, dtypes.float32) 307 308 309class PerWorkerResourceTest(test.TestCase, parameterized.TestCase): 310 311 @combinations.generate( 312 combinations.combine(dataset_fn_as_tf_function=[True, False])) 313 def testMapFnTracing(self, dataset_fn_as_tf_function): 314 # For a PerWorkerResource to correctly behave when used in dataset.map, 315 # it has to be that the map_fn is not traced only once such that 316 # PerWorkerResource.local_table can return the correct resource. This test 317 # can detect the potential breakage of this behavior on TAP. 318 self._traced_once = 0 319 320 def map_fn(x): 321 self._traced_once += 1 322 return x 323 324 def dataset_fn(): 325 dataset = dataset_ops.DatasetV2.from_tensors([0, 1, 2]).repeat().batch( 326 2, drop_remainder=True) 327 dataset = dataset.map(map_fn) 328 return dataset 329 330 datasets = [] 331 number_of_input_pipelines = 5 332 333 if dataset_fn_as_tf_function: 334 dataset_fn = def_function.function(dataset_fn) 335 expected_tracing_times = 1 336 else: 337 expected_tracing_times = number_of_input_pipelines 338 339 for _ in range(number_of_input_pipelines): 340 datasets.append(dataset_fn()) 341 342 self.assertEqual(self._traced_once, expected_tracing_times) 343 344 345class DistributedDelegateTest(test.TestCase): 346 347 @test_util.run_in_graph_and_eager_modes 348 def testGetAttr(self): 349 class Foo(object): 350 351 def __init__(self, x): 352 self.x = x 353 354 v = values_lib.DistributedDelegate((Foo(7), Foo(8))) 355 self.assertEqual(7, v.x) 356 with self.assertRaises(AttributeError): 357 _ = v.y 358 359 @test_util.run_in_graph_and_eager_modes 360 def testOperatorOverride(self): 361 v = values_lib.DistributedDelegate((7, 8)) 362 # v should act like int(7). 363 self.assertEqual(8, v + 1) 364 self.assertEqual(10, 3 + v) 365 self.assertEqual(14, v + v) 366 self.assertEqual(5, v - 2) 367 self.assertEqual(6, 13 - v) 368 self.assertEqual(0, v - v) 369 self.assertEqual(14, v * 2) 370 self.assertEqual(21, 3 * v) 371 self.assertEqual(49, v * v) 372 self.assertEqual(3.5, v / 2) 373 self.assertEqual(1.5, 10.5 / v) 374 self.assertEqual(3, v // 2) 375 self.assertEqual(2, 15 // v) 376 self.assertEqual(1, v % 2) 377 self.assertEqual(2, 16 % v) 378 # pylint: disable=g-generic-assert 379 self.assertTrue(v < 12) 380 self.assertTrue(v <= 12) 381 self.assertFalse(v > 12) 382 self.assertFalse(v >= 12) 383 self.assertFalse(12 < v) 384 self.assertFalse(12 <= v) 385 self.assertTrue(12 > v) 386 self.assertTrue(12 >= v) 387 # pylint: enable=g-generic-assert 388 self.assertEqual(3, v & 3) 389 self.assertEqual(3, 11 & v) 390 self.assertEqual(15, v | 8) 391 self.assertEqual(23, 16 | v) 392 self.assertEqual(4, v ^ 3) 393 self.assertEqual(12, 11 ^ v) 394 self.assertEqual(343, pow(v, 3)) 395 self.assertEqual(3, pow(v, 3, 10)) 396 self.assertEqual(128, pow(2, v)) 397 self.assertEqual(-7, -v) 398 self.assertEqual(~7, ~v) 399 self.assertEqual(7, abs(v)) 400 with self.assertRaises(TypeError): 401 _ = v[2] 402 403 @test_util.run_in_graph_and_eager_modes 404 def testCopy(self): 405 406 class Foo(object): 407 408 def __init__(self, x): 409 self.x = x 410 411 v = values_lib.DistributedDelegate((Foo(7), Foo(8))) 412 v_shallow_copy = copy.copy(v) 413 self.assertEqual(v.x, v_shallow_copy.x) 414 v_deep_copy = copy.deepcopy(v) 415 self.assertEqual(v.x, v_deep_copy.x) 416 417 418_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1) 419 420 421def _make_replica_local(method, strategy=None): 422 if strategy is None: 423 devices = ("/device:GPU:0", "/device:CPU:0") 424 else: 425 devices = strategy.extended.worker_devices 426 427 v = [] 428 for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): 429 with ops.device(d): 430 v.append(variable_scope.get_variable( 431 name=n, initializer=init, use_resource=True)) 432 433 if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES): 434 var_cls = tpu_values.TPUSyncOnReadVariable 435 else: 436 var_cls = values_lib.SyncOnReadVariable 437 replica_local = var_cls(strategy, v, method) 438 return v, replica_local 439 440 441class DistributedVariableTest(test.TestCase, parameterized.TestCase): 442 443 def _assign_replica_local(self, v, new): 444 for var, n in zip(v, new): 445 with ops.device(var.device): 446 self.evaluate(var.assign(n)) 447 448 def _save_return_saver(self, sess, var): 449 saver = saver_lib.Saver(var_list=[var]) 450 test_dir = self.get_temp_dir() 451 prefix = os.path.join(test_dir, "ckpt") 452 return saver.save(sess, prefix), saver 453 454 def _save(self, sess, var): 455 save_path, _ = self._save_return_saver(sess, var) 456 return save_path 457 458 config = config_pb2.ConfigProto() 459 config.allow_soft_placement = True 460 461 @test_util.run_in_graph_and_eager_modes(config=config) 462 def testProperties(self): 463 if context.num_gpus() < 1 and context.executing_eagerly(): 464 self.skipTest("A GPU is not available for this test in eager mode.") 465 v, replica_local = _make_replica_local( 466 variable_scope.VariableAggregation.SUM) 467 468 self.assertEqual(v[0].constraint, replica_local.constraint) 469 self.assertEqual(v[0].name, replica_local.name) 470 self.assertEqual(v[0].dtype, replica_local.dtype) 471 self.assertEqual(v[0].shape, replica_local.shape) 472 self.assertEqual(variable_scope.VariableAggregation.SUM, 473 replica_local.aggregation) 474 475 @combinations.generate( 476 combinations.combine( 477 distribution=[ 478 strategy_combinations.mirrored_strategy_with_gpu_and_cpu 479 ], 480 mode=["eager"])) 481 def testCanPassToDefFun(self, distribution): 482 483 @def_function.function 484 def add1(x): 485 return x + 1. 486 487 with distribution.scope(): 488 v = variables_lib.Variable( 489 1., 490 aggregation=variables_lib.VariableAggregation.MEAN, 491 synchronization=variables_lib.VariableSynchronization.ON_READ) 492 493 self.assertEqual(2., self.evaluate(add1(v))) 494 495 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 496 def testTensorConversion(self, distribution): 497 with context.graph_mode(): 498 _, replica_local = _make_replica_local( 499 variable_scope.VariableAggregation.SUM, distribution) 500 converted = ops.convert_to_tensor(replica_local, as_ref=False) 501 self.assertIsInstance(converted, ops.Tensor) 502 self.assertEqual(converted.dtype, replica_local.dtype) 503 504 converted = ops.convert_to_tensor(replica_local, as_ref=True) 505 # Resources variable are converted to tensors as well when as_ref is True. 506 self.assertIsInstance(converted, ops.Tensor) 507 self.assertEqual(converted.dtype, replica_local.dtype) 508 509 @combinations.generate(combinations.combine( 510 distribution=[ 511 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 512 strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call, 513 strategy_combinations.tpu_strategy, 514 strategy_combinations.tpu_strategy_packed_var, 515 ], mode=["eager"])) 516 def testValueInCrossReplicaContext(self, distribution): 517 value_list, replica_local = _make_replica_local( 518 variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution) 519 520 self.assertIsInstance(replica_local.value(), ops.Tensor) 521 self.assertEqual(self.evaluate(replica_local.value()), 522 self.evaluate(value_list[0].value())) 523 524 @combinations.generate( 525 combinations.combine( 526 distribution=[ 527 strategy_combinations.mirrored_strategy_with_gpu_and_cpu, 528 strategy_combinations.tpu_strategy_packed_var, 529 ], 530 mode=["eager"])) 531 def testValueInDefaultReplicaContext(self, distribution): 532 with distribution.scope(): 533 v1 = variables_lib.Variable( 534 0.0, 535 aggregation=variables_lib.VariableAggregation.SUM, 536 synchronization=variables_lib.VariableSynchronization.ON_READ) 537 v2 = variables_lib.Variable( 538 0.0, 539 aggregation=variables_lib.VariableAggregation.SUM, 540 synchronization=variables_lib.VariableSynchronization.ON_READ) 541 542 @def_function.function 543 def replica_fn(): 544 v1.assign_add(1.0) 545 v2.assign_add(2.0) 546 547 distribution.run(replica_fn) 548 sum_v = v1 + v2 549 self.assertEqual(sum_v, 6.0) 550 551 @combinations.generate( 552 combinations.combine( 553 distribution=[ 554 strategy_combinations.tpu_strategy_packed_var, 555 ], 556 mode=["eager"])) 557 def testValueInFunctionCrossReplicaContext(self, distribution): 558 with distribution.scope(): 559 v1 = variables_lib.Variable( 560 0.0, 561 aggregation=variables_lib.VariableAggregation.NONE, 562 synchronization=variables_lib.VariableSynchronization.ON_WRITE) 563 564 @def_function.function 565 def assign_fn(): 566 v1.assign(1.0) 567 568 assign_fn() 569 self.assertEqual(v1, 1.0) 570 571 # Make sure the function graph has composite variable as inputs. 572 graph_def = assign_fn.get_concrete_function().graph.as_graph_def() 573 self.assertRegex(str(graph_def), "device:COMPOSITE:0") 574 575 @combinations.generate( 576 combinations.combine( 577 distribution=[ 578 strategy_combinations.tpu_strategy_packed_var, 579 ], 580 mode=["eager"])) 581 def testReplicatedValueNameDeterministic(self, distribution): 582 with distribution.scope(): 583 v1 = variables_lib.Variable(0.0, name="test_var_1") 584 v2 = variables_lib.Variable(0.0, name="test_var_2") 585 586 def fn(): 587 v1.assign_add(1.0) 588 v2.assign_add(2.0) 589 return v1 + v2 590 591 @def_function.function 592 def dist_run_fn(): 593 a = distribution.run(fn) 594 return a 595 596 concrete_fn = dist_run_fn.get_concrete_function() 597 inputs = concrete_fn.graph.inputs 598 self.assertLen(inputs, 2) 599 # Before cl/433948982, input name will include a non-deterministic uid, 600 # e.g. "test_var_1_139726389910864/handle/inputs_0:0" 601 self.assertEqual(inputs[0].name, "test_var_1/handle/inputs_0:0") 602 self.assertEqual(inputs[1].name, "test_var_2/handle/inputs_0:0") 603 604 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 605 def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution): 606 with self.cached_session() as sess: 607 v, replica_local = _make_replica_local( 608 variable_scope.VariableAggregation.SUM, distribution) 609 610 # Overwrite the initial values. 611 self._assign_replica_local(v, [3., 4.]) 612 613 with distribution.scope(): 614 # Saves the current value of v[0] + v[1], 7. 615 save_path, saver = self._save_return_saver(sess, replica_local) 616 617 # Change the values between save and restore. 618 self._assign_replica_local(v, [5., 6.]) 619 620 # Restores the saved value of 7. which gets divided equally 621 # between the variables. 622 saver.restore(sess, save_path) 623 self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) 624 625 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 626 def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution): 627 if context.num_gpus() < 1 and context.executing_eagerly(): 628 self.skipTest("A GPU is not available for this test in eager mode.") 629 630 with self.cached_session() as sess: 631 v, replica_local = _make_replica_local( 632 variable_scope.VariableAggregation.MEAN, distribution) 633 634 # Overwrite the initial values. 635 self._assign_replica_local(v, [3., 4.]) 636 637 with distribution.scope(): 638 # Saves the current value of (v[0] + v[1])/2, 3.5. 639 save_path, saver = self._save_return_saver(sess, replica_local) 640 641 # Change the values between save and restore. 642 self._assign_replica_local(v, [5., 6.]) 643 644 # Restores the saved value of 3.5 to both variables. 645 saver.restore(sess, save_path) 646 self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) 647 648 def _save_replica_local_mean(self, distribution): 649 """Save variables with mirroring, returns save_path.""" 650 with self.session(graph=ops.Graph()) as sess: 651 v, replica_local = _make_replica_local( 652 variable_scope.VariableAggregation.MEAN, distribution) 653 654 # Overwrite the initial values. 655 self._assign_replica_local(v, [3., 4.]) 656 657 with distribution.scope(): 658 # Saves the current value of (v[0] + v[1])/2, 3.5 659 save_path = self._save(sess, replica_local) 660 661 # Change the values between save and restore. 662 self._assign_replica_local(v, [5., 6.]) 663 return save_path 664 665 def _save_replica_local_sum(self, distribution): 666 """Save variables with mirroring, returns save_path.""" 667 with self.session(graph=ops.Graph()) as sess: 668 v, replica_local = _make_replica_local( 669 variable_scope.VariableAggregation.SUM, distribution) 670 671 # Overwrite the initial values. 672 self._assign_replica_local(v, [1.5, 2.]) 673 674 with distribution.scope(): 675 # Saves the current value of v[0] + v[1], 3.5 676 save_path = self._save(sess, replica_local) 677 678 # Change the values between save and restore. 679 self._assign_replica_local(v, [5., 6.]) 680 return save_path 681 682 def _save_normal(self): 683 """Save variables without mirroring, returns save_path.""" 684 with self.session(graph=ops.Graph()) as sess: 685 var = variable_scope.get_variable( 686 name="v", initializer=1., use_resource=True) 687 688 # Overwrite the initial value. 689 self.evaluate(var.assign(3.5)) 690 691 # Saves the current value of var, 3.5. 692 save_path = self._save(sess, var) 693 694 # Change the values between save and restore. 695 self.evaluate(var.assign(5.)) 696 return save_path 697 698 def _restore_normal(self, save_path): 699 """Restore to variables without mirroring in a fresh graph.""" 700 with self.session(graph=ops.Graph()) as sess: 701 var = variable_scope.get_variable( 702 name="v", initializer=7., use_resource=True) 703 704 # Overwrite the initial value. 705 self.evaluate(var.assign(8.)) 706 707 # Restores the saved value of 3.5 to `var`. 708 saver = saver_lib.Saver(var_list=[var]) 709 saver.restore(sess, save_path) 710 self.assertEqual(3.5, self.evaluate(var)) 711 712 def _restore_replica_local_mean(self, save_path, distribution): 713 """Restore to variables with mirroring in a fresh graph.""" 714 with self.session(graph=ops.Graph()) as sess: 715 v, replica_local = _make_replica_local( 716 variable_scope.VariableAggregation.MEAN, distribution) 717 718 # Overwrite the initial values. 719 self._assign_replica_local(v, [7., 8.]) 720 721 with distribution.scope(): 722 # Restores the saved value of 3.5 to both variables. 723 saver = saver_lib.Saver(var_list=[replica_local]) 724 saver.restore(sess, save_path) 725 self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) 726 727 def _restore_replica_local_sum(self, save_path, distribution): 728 """Restore to variables with mirroring in a fresh graph.""" 729 with self.session(graph=ops.Graph()) as sess: 730 v, replica_local = _make_replica_local( 731 variable_scope.VariableAggregation.SUM, distribution) 732 733 # Overwrite the initial values. 734 self._assign_replica_local(v, [7., 8.]) 735 736 with distribution.scope(): 737 # Restores the saved value of 3.5 to both variables. 738 saver = saver_lib.Saver(var_list=[replica_local]) 739 saver.restore(sess, save_path) 740 self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) 741 742 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 743 def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution): 744 save_path = self._save_replica_local_mean(distribution) 745 self._restore_replica_local_mean(save_path, distribution) 746 747 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 748 def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution): 749 save_path = self._save_replica_local_sum(distribution) 750 self._restore_replica_local_sum(save_path, distribution) 751 752 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 753 def testSaveReplicaLocalMeanRestoreNormal(self, distribution): 754 save_path = self._save_replica_local_mean(distribution) 755 self._restore_normal(save_path) 756 757 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 758 def testSaveReplicaLocalSumRestoreNormal(self, distribution): 759 save_path = self._save_replica_local_sum(distribution) 760 self._restore_normal(save_path) 761 762 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 763 def testSaveNormalRestoreReplicaLocalMean(self, distribution): 764 save_path = self._save_normal() 765 self._restore_replica_local_mean(save_path, distribution) 766 767 @combinations.generate(mirrored_and_tpu_strategy_combinations()) 768 def testSaveNormalRestoreReplicaLocalSum(self, distribution): 769 save_path = self._save_normal() 770 self._restore_replica_local_sum(save_path, distribution) 771 772 773if __name__ == "__main__": 774 ds_test_util.main() 775