# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

import itertools

import uuid
from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.checkpoint import checkpoint_management as ckpt_manager
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import variable_utils


def strategy_and_run_tf_function_combinations():
  # Test the combination of different strategies and whether a tf.function is
  # passed into strategy.run.
  # TODO(b/197981388): re-enable MWMS test
  # return combinations.combine(
  #     distribution=[
  #         strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
  #     ],
  #     mode=["graph", "eager"],
  #     experimental_run_tf_function=[True, False],
  #     use_var_policy=[True, False]) +
  return combinations.combine(
      distribution=[
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"],
      experimental_run_tf_function=[True],
      use_var_policy=[True, False])


def strategy_with_var_policy():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          # TODO(b/197981388): re-enable MWMS test
          # strategy_combinations.multi_worker_mirrored_2x1_cpu,
          # strategy_combinations.multi_worker_mirrored_2x1_gpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"],
      use_var_policy=[True, False])


class OnWriteVariableSync(test.TestCase, parameterized.TestCase):

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # Assigning a plain value in replica context with aggregation SUM is not
      # supported (the error is "1. is not a distributed value and is
      # unsupported for aggregation SUM"); you can just assign
      # value * num_replicas instead.
      if (not cross_replica and aggregation ==
          variables_lib.VariableAggregation.SUM):
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnWriteVar(self, distribution, experimental_run_tf_function):

    with distribution.scope():
      v_to_assign = variable_scope.variable(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_scope.variable(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # Assigning a plain value in replica context with aggregation SUM is not
      # supported (the error is "1. is not a distributed value and is
      # unsupported for aggregation SUM"); you can just assign
      # value * num_replicas instead.
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(2.0, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):

    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")

    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])
      per_replica_sub_value = values.PerReplica(
          [constant_op.constant(-2.0),
           constant_op.constant(-2.0)])

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", per_replica_value), ("assign_add", per_replica_value),
               ("assign_sub", per_replica_sub_value)]
    # We don't support assigning PerReplica values to variables in replica
    # context with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      if cross_replica:
        # We don't support assigning PerReplica values to MirroredVariables in
        # cross-replica context.
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContext(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_with_var_policy())
  def testValueInReplicaContextAssignDirectValue(self, distribution,
                                                 use_var_policy):
    with distribution.scope():
      v = variables_lib.Variable(
          1., aggregation=variables_lib.VariableAggregation.MEAN)
      self.evaluate(variables_lib.global_variables_initializer())

      @def_function.function
      def f():
        with ops.control_dependencies([v.assign_add(1.)]):
          return v.value()

      results = self.evaluate(
          test_util.gather(distribution, distribution.run(f)))
      for value in results:
        self.assertEqual(2., value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_scope.variable(
            2.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value

      results = read_var_fn()
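      # Reading an ON_WRITE variable in cross-replica context returns a single
      # tensor; since every component holds the same value here, it should
      # match each component read below.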
      for component in v._values:
        self.assertEqual(self.evaluate(component.read_value()),
                         self.evaluate(results))

  @combinations.generate(strategy_with_var_policy())
  def testAssignOutOfScope(self, distribution):
    with distribution.scope():
      mirrored = variables_lib.Variable(1.)
    self.evaluate(mirrored.assign(3.))
    self.assertEqual(self.evaluate(mirrored.read_value()), 3.)
    for component in mirrored.values:
      self.assertEqual(self.evaluate(component.read_value()), 3.)

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    if not context.executing_eagerly(): self.skipTest("eager only test")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")
    v = [None]

    @def_function.function
    def step():

      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(random_ops.random_normal([]))

      distribution.run(f)

    context.set_global_seed(None)
    step()
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testAggregationOnlyFirstReplica(self, distribution):
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212945803")
    with distribution.scope():
      v = variable_scope.variable(
          15.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return v.assign(math_ops.cast(replica_id, dtypes.float32))

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    # The per-replica values should always match the first replica's value.
    self.assertAllEqual(
        array_ops.zeros(distribution.num_replicas_in_sync, dtypes.float32),
        per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testInitScope(self, distribution):
    if not context.executing_eagerly(): self.skipTest("eager only")

    class C(object):
      pass

    obj = C()
    obj.w = None
    obj.v = None

    @def_function.function
    def assign():
      with ops.init_scope():
        if obj.w is None:
          obj.w = variables_lib.Variable(
              0., aggregation=variables_lib.VariableAggregation.MEAN)
          obj.v = variables_lib.Variable(
              obj.w.read_value(),
              aggregation=variables_lib.VariableAggregation.MEAN)
          self.evaluate(variables_lib.global_variables_initializer())

      return obj.v.assign_add(2.)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(assign)))
    self.assertAllEqual([2., 2.], per_replica_results)

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):

    if not context.executing_eagerly() and isinstance(
        distribution.extended,
        collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("b/212954197")

    with distribution.scope():
      v = variable_scope.variable(
          1, aggregation=variables_lib.VariableAggregation.SUM)
      self.evaluate(variables_lib.global_variables_initializer())

    self.assertEqual(2, self.evaluate(v + 1))

    @def_function.function
    def add():
      return v + 1

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(add)))
    self.assertAllEqual([2, 2], per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnWrite(self, strategy):
    aggregation = [
        variable_scope.VariableAggregation.NONE,
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA,
        variable_scope.VariableAggregation.SUM,
        variable_scope.VariableAggregation.MEAN
    ]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(3.0)
      with strategy.scope():
        v_on_write = variables_lib.Variable(2.0, aggregation=agg)

        # Save ONWRITE, restore ONWRITE.
        # Save.
        ckpt = trackable_utils.Checkpoint(var=v_on_write)
        manager = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager.save()
        # Restore.
        ckpt.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_on_write.read_value()))

        # Save Mirrored, restore normal.
        # We've already saved Mirrored, so we only need to restore normal.
        ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
        ckpt_normal.restore(manager.latest_checkpoint)
        self.assertEqual(2.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(2.0, self.evaluate(v_normal_restore.read_value()))

        # Save normal, restore Mirrored.
        # Save.
        ckpt = trackable_utils.Checkpoint(var=v_normal_save)
        manager_2 = ckpt_manager.CheckpointManager(
            ckpt, "/tmp/ckptckpt_" + str(uuid.uuid4()), max_to_keep=None)
        manager_2.save()
        # Restore.
        ckpt_on_write = trackable_utils.Checkpoint(var=v_on_write)
        ckpt_on_write.restore(manager_2.latest_checkpoint)
        self.assertEqual(3.0, self.evaluate(v_on_write._values[0]))
        self.assertEqual(3.0, self.evaluate(v_on_write.read_value()))


ms_combination = combinations.combine(
    distribution=[strategy_combinations.mirrored_strategy_with_gpu_and_cpu],
    mode=["graph", "eager"])
tpu_combination = combinations.combine(
    distribution=[strategy_combinations.tpu_strategy_packed_var],
    mode=["graph", "eager"])
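# Note: `ms_combination` runs MirroredStrategy with one GPU and one CPU, i.e.
# two replicas, which is why the expected per-replica results in the scatter
# tests below come in pairs.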


class OnWriteVariableSyncScatterTests(test.TestCase, parameterized.TestCase):

  @combinations.generate(ms_combination)
  def testScatterSub(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [0., 0., 0.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_sub():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      value = indexed_slices.IndexedSlices(
          values=array_ops.stack([
              math_ops.cast(replica_id, dtypes.float32),
              math_ops.cast(replica_id + 1, dtypes.float32)
          ]),
          indices=array_ops.stack([replica_id, replica_id + 1]),
          dense_shape=(3,))
      return v.scatter_sub(value)

    per_replica_results = self.evaluate(
        distribution.experimental_local_results(
            distribution.run(scatter_sub)))
    self.assertAllEqual([[0., -1., -1.], [0., -1., -1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterAdd(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_add():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      value = indexed_slices.IndexedSlices(
          values=array_ops.stack([replica_id, replica_id + 1]),
          indices=array_ops.stack([replica_id, replica_id + 1]),
          dense_shape=(3,))
      return v.scatter_add(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_add)))
    self.assertAllEqual([[0, 2, 2], [0, 2, 2]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterDiv(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [1, 6, 1], aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_div():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      value = indexed_slices.IndexedSlices(
          values=array_ops.reshape(replica_id + 2, [1]),
          indices=array_ops.reshape(replica_id, [1]),
          dense_shape=(3,))
      return v.scatter_div(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_div)))
    self.assertAllEqual([[0, 2, 1], [0, 2, 1]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMul(self, distribution):
    with distribution.scope():
      v = variables_lib.Variable(
          [2., 1., 1.], aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(v.initializer)

    @def_function.function
    def scatter_mul():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      value = indexed_slices.IndexedSlices(
          values=array_ops.reshape(
              math_ops.cast(replica_id + 2, dtypes.float32), [1]),
          indices=array_ops.reshape(replica_id, [1]),
          dense_shape=(3,))
      return v.scatter_mul(value)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(scatter_mul)))
    self.assertAllClose([[2., 1.5, 1.], [2., 1.5, 1.]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMin(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 2, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 2, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())
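    # scatter_min is not implemented for SUM aggregation on mirrored variables,
    # so the call on v1 below is expected to raise NotImplementedError, while
    # v2 (ONLY_FIRST_REPLICA) applies the update.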

    @def_function.function
    def scatter_min(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_min(value)

    with self.assertRaisesRegex(NotImplementedError, "scatter_min.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_min, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_min, args=(v2,))))
    self.assertAllClose([[0, 1, 0], [0, 1, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterMax(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_max(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([1]),
          indices=array_ops.identity([0]),
          dense_shape=(3,))
      return v.scatter_max(value)

    with self.assertRaisesRegex(NotImplementedError, "scatter_max.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_max, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_max, args=(v2,))))
    self.assertAllClose([[1, 0, 0], [1, 0, 0]], per_replica_results)

  @combinations.generate(ms_combination)
  def testScatterUpdate(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          [0, 0, 0], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable(
          [0, 0, 0],
          aggregation=variables_lib.VariableAggregation.ONLY_FIRST_REPLICA)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def scatter_update(v):
      value = indexed_slices.IndexedSlices(
          values=array_ops.identity([3]),
          indices=array_ops.identity([1]),
          dense_shape=(3,))
      return v.scatter_update(value)

    with self.assertRaisesRegex(NotImplementedError, "scatter_update.*"):
      self.evaluate(
          test_util.gather(distribution,
                           distribution.run(scatter_update, args=(v1,))))

    per_replica_results = self.evaluate(
        test_util.gather(distribution,
                         distribution.run(scatter_update, args=(v2,))))
    self.assertAllClose([[0, 3, 0], [0, 3, 0]], per_replica_results)

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsWithNoneAggregation(self, distribution):

    def assert_close(v, op, delta, expect):
      scatter_op = getattr(v, op)

      @def_function.function
      def scatter_xxx():
        return scatter_op(delta)

      per_replica_results = self.evaluate(
          variable_utils.convert_variables_to_tensors(
              distribution.experimental_local_results(
                  distribution.run(scatter_xxx))))
      self.assertAllClose([expect, expect], per_replica_results)

    with distribution.scope():
      v = variables_lib.Variable(
          [4.], aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())

    delta = indexed_slices.IndexedSlices(
        values=array_ops.identity([2.]),
        indices=array_ops.identity([0]),
        dense_shape=(1,))

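    # The scatter ops below run back to back on the same variable, so each
    # expected value builds on the previous result:
    # [4.] - 2 -> [2.], + 2 -> [4.], max(., 2) -> [4.], min(., 2) -> [2.],
    # * 2 -> [4.], / 2 -> [2.], update(2) -> [2.].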
    assert_close(v, "scatter_sub", delta, [2.])
    assert_close(v, "scatter_add", delta, [4.])
    assert_close(v, "scatter_max", delta, [4.])
    assert_close(v, "scatter_min", delta, [2.])
    assert_close(v, "scatter_mul", delta, [4.])
    assert_close(v, "scatter_div", delta, [2.])
    assert_close(v, "scatter_update", delta, [2.])

  @combinations.generate(ms_combination + tpu_combination)
  def testScatterOpsInCrossReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          [1, 1, 1], aggregation=variables_lib.VariableAggregation.SUM)
      v2 = variables_lib.Variable([1, 1, 1])
    self.evaluate(variables_lib.global_variables_initializer())

    value = indexed_slices.IndexedSlices(
        values=array_ops.identity([2]),
        indices=array_ops.identity([0]),
        dense_shape=(3,))
    with distribution.scope():
      self.evaluate(v1.scatter_add(value))
      self.assertAllEqual([3, 1, 1], self.evaluate(v1.read_value()))

      self.evaluate(v2.scatter_min(value))
      self.assertAllEqual([1, 1, 1], self.evaluate(v2.read_value()))


class OnReadVariableSyncTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssign(self, distribution, experimental_run_tf_function):

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", 1.), ("assign_add", 1.), ("assign_sub", -1.)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # VariableAggregation.SUM in cross-replica mode is tested below,
      # VariableAggregation.NONE in cross-replica mode is not supported.
      if cross_replica and aggregation in [
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.NONE,
      ]:
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignOnReadVar(self, distribution, experimental_run_tf_function):

    with distribution.scope():
      v_to_assign = variable_scope.variable(
          2., aggregation=variables_lib.VariableAggregation.MEAN)
      v_to_assign_sub = variable_scope.variable(
          -2., aggregation=variables_lib.VariableAggregation.MEAN)

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", v_to_assign), ("assign_add", v_to_assign),
               ("assign_sub", v_to_assign_sub)]
    expected_cross_replica = {
        variables_lib.VariableAggregation.SUM: 1.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    expected_replica = {
        variables_lib.VariableAggregation.SUM: 2.0,
        variables_lib.VariableAggregation.MEAN: 2.0,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA: 2.0
    }
    # aggregation=NONE is not supported for OnReadVariables.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # Assigning a plain value in replica context with aggregation SUM is not
      # supported (the error is "1. is not a distributed value and is
      # unsupported for aggregation SUM"); you can just assign
      # value * num_replicas instead.
      if aggregation == variables_lib.VariableAggregation.SUM:
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      if cross_replica:
        for component in v._values:
          self.assertAllEqual(expected_cross_replica.get(aggregation),
                              self.evaluate(component.read_value()))
      else:
        for component in v._values:
          self.assertAllEqual(expected_replica.get(aggregation),
                              self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaVal(self, distribution, experimental_run_tf_function):

    if strategy_test_lib.is_tpu_strategy(distribution):
      self.skipTest("Assigning PerReplica values is not supported. See"
                    " sponge/80ba41f8-4220-4516-98ce-bbad48f9f11a.")

    self.skipTest("We don't support assigning PerReplica values in cross "
                  "replica context or replica context. See the error in "
                  "sponge/2b2e54c1-eda6-4534-82e1-c73b1dcd517f.")

    with distribution.scope():
      per_replica_value = values.PerReplica(
          [constant_op.constant(2.0),
           constant_op.constant(2.0)])

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", per_replica_value)]
    # We don't support assigning PerReplica values to variables in replica
    # context with aggregation=NONE.
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      # with self.assertRaisesRegex(ValueError, "Attempt to convert a value "):
      self.evaluate(assign(fn, v, update_value, cross_replica))
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = 4.0
      else:
        expected = 2.0
      for component in v._values:
        self.assertAllEqual(expected, self.evaluate(component.read_value()))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignDtypeConversion(self, distribution,
                                experimental_run_tf_function):

    def assign(fn, v, update_value, cross_replica):
      update_fn = lambda: getattr(v, fn)(update_value)
      if cross_replica:
        return update_fn()
      else:
        if experimental_run_tf_function:
          update_fn = def_function.function(update_fn)
        return test_util.gather(distribution, distribution.run(update_fn))

    updates = [("assign", 1), ("assign_add", 1), ("assign_sub", -1)]
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    options = list(
        x for x in itertools.product(updates, aggregations, [True, False]))
    for update, aggregation, cross_replica in options:
      # VariableAggregation.SUM in cross-replica mode is tested below,
      # VariableAggregation.NONE in cross-replica mode is not supported.
      if cross_replica and aggregation in [
          variables_lib.VariableAggregation.SUM,
          variables_lib.VariableAggregation.NONE,
      ]:
        continue
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      fn, update_value = update
      self.evaluate(assign(fn, v, update_value, cross_replica))
      for component in v._values:
        self.assertAllEqual(self.evaluate(component.read_value()),
                            self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_with_var_policy())
  def testAssignWithAggregationSum(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    self.evaluate(v.assign(1. * distribution.num_replicas_in_sync))
    for component in v._values:
      self.assertAllEqual(self.evaluate(component.read_value()),
                          self.evaluate(array_ops.ones_like(component)))

  @combinations.generate(strategy_with_var_policy())
  def testAssignAddSubWithAggregationSum(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.SUM)
    self.evaluate(variables_lib.global_variables_initializer())
    with self.assertRaisesRegex(
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_add(1.))
    with self.assertRaisesRegex(
        ValueError, "SyncOnReadVariable does not support "):
      self.evaluate(v.assign_sub(1.))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInReplicaContext(self, distribution,
                                    experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.NONE,
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
      if experimental_run_tf_function:
        read_var_fn = def_function.function(v.read_value)
      else:
        read_var_fn = v.read_value
      results = self.evaluate(
          test_util.gather(distribution, distribution.run(read_var_fn)))
      for component, value in zip(v._values, results):
        self.assertAllEqual(self.evaluate(component.read_value()), value)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testReadValueInCrossReplicaContext(self, distribution,
                                         experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      if strategy_test_lib.is_tpu_strategy(distribution):
        resolver = tpu_cluster_resolver.TPUClusterResolver("")
        tpu_strategy_util.initialize_tpu_system(resolver)
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())
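      # Each replica assigns its replica id, so the components hold
      # 0 .. num_replicas - 1. Reading in cross-replica context then applies
      # the aggregation: SUM -> n * (n - 1) / 2, MEAN -> (n - 1) / 2,
      # ONLY_FIRST_REPLICA -> 0.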

      def assign(v=v):
        ctx = ds_context.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return v.assign(math_ops.cast(replica_id, dtypes.float32))

      if experimental_run_tf_function:
        assign = def_function.function(assign)

      self.evaluate(test_util.gather(distribution, distribution.run(assign)))
      num_replicas = distribution.num_replicas_in_sync
      sum_of_replica_values = num_replicas * (num_replicas - 1) / 2.
      if aggregation == variables_lib.VariableAggregation.SUM:
        expected = sum_of_replica_values
      elif aggregation == variables_lib.VariableAggregation.MEAN:
        expected = sum_of_replica_values / num_replicas
      else:
        expected = 0
      self.assertEqual(expected, self.evaluate(v.read_value()), aggregation)
      self.assertEqual(expected, self.evaluate(v.value()), aggregation)
      self.assertEqual(expected, self.evaluate(v), aggregation)
      self.assertEqual(expected, self.evaluate(array_ops.identity(v)),
                       aggregation)

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAllReduce(self, distribution, experimental_run_tf_function):
    with distribution.scope():
      v = variable_scope.variable(
          2.,
          synchronization=variables_lib.VariableSynchronization.ON_WRITE,
          aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(variables_lib.global_variables_initializer())

    def all_reduce():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return ctx.all_reduce("SUM", v) + math_ops.cast(replica_id,
                                                      dtypes.float32)

    if experimental_run_tf_function:
      all_reduce = def_function.function(all_reduce)

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(all_reduce)))
    expected_result = []
    for i in range(distribution.num_replicas_in_sync):
      expected_result.append(2.0 * distribution.num_replicas_in_sync +
                             1.0 * i)
    self.assertAllEqual(per_replica_results, tuple(expected_result))

  @combinations.generate(strategy_and_run_tf_function_combinations())
  def testAssignPerReplicaBeforeRead(self, distribution,
                                     experimental_run_tf_function):
    aggregations = [
        variables_lib.VariableAggregation.SUM,
        variables_lib.VariableAggregation.MEAN,
        variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
    ]
    for aggregation in aggregations:
      with distribution.scope():
        v = variable_scope.variable(
            0.,
            synchronization=variables_lib.VariableSynchronization.ON_READ,
            aggregation=aggregation)
      self.evaluate(variables_lib.global_variables_initializer())

      def assign(var=v):
        ctx = ds_context.get_replica_context()
        replica_id = ctx.replica_id_in_sync_group
        return var.assign(math_ops.cast(replica_id, dtypes.float32))

      if experimental_run_tf_function:
        assign = def_function.function(assign)

      per_replica_results = self.evaluate(
          test_util.gather(distribution, distribution.run(assign)))
      expected_result = []
      for i in range(distribution.num_replicas_in_sync):
        expected_result.append(1.0 * i)
      self.assertAllEqual(per_replica_results, tuple(expected_result))

  @combinations.generate(strategy_with_var_policy())
  def testReadValueWithAggregationNoneInCrossReplicaContext(self, distribution):
    with distribution.scope():
      v = variable_scope.variable(
          0.,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.NONE)
    self.evaluate(variables_lib.global_variables_initializer())
    with self.assertRaisesRegex(
        ValueError, "Could not convert from .* VariableAggregation\\.NONE"):
      self.evaluate(v.read_value())

  @combinations.generate(strategy_with_var_policy())
  def testInitializedToSameValueInsideEagerRun(self, distribution):
    if not context.executing_eagerly(): self.skipTest("eager only")
    if isinstance(distribution.extended,
                  collective_all_reduce_strategy.CollectiveAllReduceExtended):
      self.skipTest("Test for more than 1 device per worker only.")

    v = [None]

    @def_function.function
    def step():

      def f():
        if v[0] is None:
          v[0] = variables_lib.Variable(
              random_ops.random_normal([]),
              synchronization=variables_lib.VariableSynchronization.ON_READ)

      distribution.run(f)

    context.set_global_seed(None)
    step()
    vals = self.evaluate(v[0].values)
    self.assertAllEqual(vals[0], vals[1])

  @combinations.generate(strategy_with_var_policy())
  def testOperatorOverride(self, distribution):

    with distribution.scope():
      v = variable_scope.variable(
          0.0,
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=variables_lib.VariableAggregation.MEAN)
    self.evaluate(variables_lib.global_variables_initializer())

    @def_function.function
    def assign():
      ctx = ds_context.get_replica_context()
      replica_id = ctx.replica_id_in_sync_group
      return v.assign(math_ops.cast(replica_id, dtypes.float32))

    # Assign different values to the different replicas.
    self.evaluate(test_util.gather(distribution, distribution.run(assign)))
    self.assertEqual(1.5, self.evaluate(v + 1))

    @def_function.function
    def add():
      return v + 1

    per_replica_results = self.evaluate(
        test_util.gather(distribution, distribution.run(add)))
    self.assertAllEqual([1, 2], per_replica_results)

  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.multi_worker_mirrored_2x1_cpu,
              strategy_combinations.multi_worker_mirrored_2x1_gpu,
          ],
          mode=["eager"],
          use_var_policy=[True, False]))
  def testSaveAndRestoreOnRead(self, strategy):
    aggregation = [variable_scope.VariableAggregation.SUM,
                   variable_scope.VariableAggregation.MEAN]
    for agg in aggregation:
      v_normal_restore = variables_lib.Variable(1.0)
      v_normal_save = variables_lib.Variable(2.0)

      with strategy.scope():
        v_on_read = variables_lib.Variable(
            1.0, synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=agg)

      @def_function.function
      def assign_fn():
        cluster_resolver = strategy.cluster_resolver
        replica_ctx = ds_context.get_replica_context()
        if ((cluster_resolver and cluster_resolver.task_type == "worker") or
            math_ops.equal(replica_ctx.replica_id_in_sync_group,
                           constant_op.constant(1))):
          v_on_read.assign(3.)  # pylint:disable=cell-var-from-loop
        else:
          v_on_read.assign(4.)  # pylint:disable=cell-var-from-loop

      strategy.run(assign_fn)

      # Save ONREAD, restore ONREAD
      # Saves v[0] + v[1] = 7 for SUM and 3.5 for MEAN.
      ckpt = trackable_utils.Checkpoint(var=v_on_read)
      manager = ckpt_manager.CheckpointManager(
          ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
      manager.save()
      # Restores a value of 7/2 = 3.5 for SUM and 3.5 for MEAN.
      ckpt.restore(manager.latest_checkpoint)
      self.assertEqual(3.5, self.evaluate(v_on_read._values[0]))

      # Save ONREAD, restore normal
      ckpt_normal = trackable_utils.Checkpoint(var=v_normal_restore)
      ckpt_normal.restore(manager.latest_checkpoint)
      if agg == variable_scope.VariableAggregation.SUM:
        self.assertEqual(7.0, self.evaluate(v_normal_restore.read_value()))
      else:
        self.assertEqual(3.5, self.evaluate(v_normal_restore.read_value()))

      # Save normal, restore ONREAD
      ckpt = trackable_utils.Checkpoint(var=v_normal_save)
      manager = ckpt_manager.CheckpointManager(
          ckpt, "/tmp/ckpt_" + str(uuid.uuid4()), max_to_keep=None)
      manager.save()
      # Restores a value of 2/2 = 1.0 for SUM and 2.0 for MEAN.
      ckpt_on_read = trackable_utils.Checkpoint(var=v_on_read)
      ckpt_on_read.restore(manager.latest_checkpoint)
      if agg == variable_scope.VariableAggregation.SUM:
        self.assertEqual(1.0, self.evaluate(v_on_read._values[0]))
      else:
        self.assertEqual(2.0, self.evaluate(v_on_read._values[0]))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class SyncOnReadScatterReplicaTest(test.TestCase, parameterized.TestCase):

  def testScatterSub(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [1., 1., 1.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_sub, args=(delta,)))

  def testScatterAdd(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [1., 1., 1.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[0.], [1.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[1.], [2.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_add, args=(delta,)))

  def testScatterDiv(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [2., 6., 1.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[2.], [2.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[3.], [3.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_div, args=(delta,)))

  def testScatterMul(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [2., 1., 1.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[2.], [3.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[4.], [5.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_mul, args=(delta,)))

  def testScatterMin(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [3., 4., 5.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_min, args=(delta,)))

  def testScatterMax(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [3., 4., 5.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[1.], [8.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[9.], [2.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_max, args=(delta,)))

  def testScatterUpdate(self, distribution, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          [0., 0., 0.],
          synchronization=variables_lib.VariableSynchronization.ON_READ,
          aggregation=aggregation)
    self.evaluate(v.initializer)

    delta = values.PerReplica([
        indexed_slices.IndexedSlices(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=(3,)),
        indexed_slices.IndexedSlices(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=(3,)),
    ])

    with self.assertRaises(NotImplementedError):
      self.evaluate(distribution.run(v.scatter_update, args=(delta,)))


if __name__ == "__main__":
  test_util.main()