# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""

from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class StrategyTest(test.TestCase, parameterized.TestCase):

  def testCaptureReplicaId(self, strategy):
    m = {}

    @def_function.function
    def f():
      return ds_context.get_replica_context().replica_id_in_sync_group

    @def_function.function
    def g():
      # Make g() a stateful function so it's traced twice.
      if m.get('v', None) is None:
        m['v'] = variables.Variable(0.)
      return strategy.run(f)

    g()

  def testMergeCallInitScope(self, strategy):
    with strategy.scope():

      @def_function.function
      def fn():

        def merge_fn(unused_strat):

          y = constant_op.constant(11)
          return y

        def replica_fn():

          with ops.init_scope():
            y = ds_context.get_replica_context().merge_call(merge_fn)
            z = y + 1
            return z

        return strategy.run(replica_fn)

      result = strategy.experimental_local_results(fn())
      self.assertAllClose(result, [12] * _get_num_replicas_per_client(strategy))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.tpu_strategy
        ],
        mode=['graph', 'eager']))
class StrategyLocalResultTest(test.TestCase):

  def testLocalResultForDictionary(self, distribution):

    @def_function.function
    def model_fn():
      return {'a': constant_op.constant(1.), 'b': constant_op.constant(2.)}

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForList(self, distribution):

    @def_function.function
    def model_fn():
      return [constant_op.constant(1.), constant_op.constant(2.)]

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ([1., 2.], [1., 2.]))

  def testLocalResultForTuple(self, distribution):

    @def_function.function
    def model_fn():
      return (constant_op.constant(1.), constant_op.constant(2.),
              constant_op.constant(3.))

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ((1., 2., 3.), (1., 2., 3.)))

  def testLocalResultForNestedStruct(self, distribution):

    @def_function.function
    def model_fn():
      return ({
          'a': constant_op.constant(1.),
          'b': constant_op.constant(2.)
      }, {
          'a': constant_op.constant(4.),
          'b': constant_op.constant(6.)
      })

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, (({
          'a': 1.,
          'b': 2.
      }, {
          'a': 4.,
          'b': 6.
      }), ({
          'a': 1.,
          'b': 2.
      }, {
          'a': 4.,
          'b': 6.
      })))

  def testLocalResultForNestedStructWithoutTensor(self, distribution):

    @def_function.function
    def model_fn():
      return {'a': 1., 'b': 2.}

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertIsInstance(v, tuple)
      self.assertAllEqual(v, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForScalarValue(self, distribution):

    @def_function.function
    def model_fn():
      return distribution.extended._get_local_replica_id(
          ds_context.get_replica_context().replica_id_in_sync_group)

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertIsInstance(v, tuple)
      self.assertEqual(v, (0, 1))

  def testLocalResultForDictionaryDifferentReplicas(self, distribution):

    @def_function.function
    def model_fn():
      replica_id = distribution.extended._get_local_replica_id(
          ds_context.get_replica_context().replica_id_in_sync_group)
      return {
          'a': math_ops.cast(replica_id + 1, dtype=float),
          'b': math_ops.cast(replica_id + 2, dtype=float)
      }

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertAllEqual(got, ({'a': 1., 'b': 2.}, {'a': 2., 'b': 3.}))

  def testLocalResultForTensor(self, distribution):

    @def_function.function
    def model_fn():
      return constant_op.constant([2., 3.])

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertAllEqual(v, ([2., 3.], [2., 3.]))


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class ReduceTest(test.TestCase, parameterized.TestCase):

  def testBasic(self, strategy):
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.ones((), dtypes.float32))

    def fn_eager():

      return strategy.reduce(
          reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)

    fn_graph = def_function.function(fn_eager)
    # Run reduce under the strategy scope to explicitly enter
    # strategy default_device scope.
    with strategy.scope():
      self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
      self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

    # Run reduce without a strategy scope to implicitly enter
    # strategy default_device scope.
    self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
    self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

  def testAxis(self, strategy):

    @def_function.function
    def fn():
      return constant_op.constant([1., 2.])

    x = strategy.run(fn)

    x_m = strategy.reduce(reduce_util.ReduceOp.MEAN, x, axis=0)
    self.assertEqual(1.5, x_m)
    x_s = strategy.reduce(reduce_util.ReduceOp.SUM, x, axis=0)
    self.assertEqual(3 * strategy.num_replicas_in_sync, x_s)


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        update_fn=['assign', 'assign_add', 'assign_sub'],
        tf_function=[True, False],
        mode=['eager']))
class ReplicaCtxUpdateTest(test.TestCase, parameterized.TestCase):

  def testDenseUpdate(self, strategy, tf_function, update_fn):
    if strategy_test_lib.is_tpu_strategy(strategy) and (not tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    with strategy.scope():
      distributed_variable1 = variables.Variable(5.0)

    def replica_fn():
      value = array_ops.constant(2.)
      python_literal = 1.
      replica_context = ds_context.get_replica_context()
      fn_sets = {
          'assign': lambda var, value: var.assign(value),
          'assign_add': lambda var, value: var.assign_add(value),
          'assign_sub': lambda var, value: var.assign_sub(value),
      }
      replica_context._update(
          distributed_variable1, fn_sets[update_fn], args=(value,))
      replica_context._update(
          distributed_variable1, fn_sets[update_fn], args=(python_literal,))

    if tf_function:
      replica_fn = def_function.function(replica_fn)
    strategy.run(replica_fn)

    expected_result = {'assign': 1., 'assign_add': 8., 'assign_sub': 2.}
    self.assertAllEqual(
        strategy.experimental_local_results(distributed_variable1),
        [expected_result[update_fn]] * _get_num_replicas_per_client(strategy))


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):

  def testDense(self, strategy, tf_function):
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = array_ops.identity(1.0)
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
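    # Each replica contributes 1.0 to the SUM all-reduce, so the reduced value
    # equals the number of replicas in sync.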
    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = indexed_slices.IndexedSlices(
        values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
        indices=array_ops.identity([0]),
        dense_shape=array_ops.identity([5, 1]))
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testNestedInput(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = (array_ops.identity(1.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[1.0]]),
                     indices=array_ops.identity([0]),
                     dense_shape=array_ops.identity([5, 1])),
                 array_ops.identity(2.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[2.0]]),
                     indices=array_ops.identity([1]),
                     dense_shape=array_ops.identity([5, 1])))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = (1.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[1.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([0]),
                  dense_shape=array_ops.identity([5, 1])),
              2.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[2.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([1]),
                  dense_shape=array_ops.identity([5, 1])))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testSyncOnReadVariableInput(self, strategy, tf_function):
    if (not strategy_test_lib.is_mirrored_strategy(strategy) and
        not strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        not strategy_test_lib.is_tpu_strategy(strategy)):
      self.skipTest('Skip strategies not using SyncOnReadVariables.')
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        tf_function is combinations.tf_function):
      self.skipTest('Skip MWMS + graph combination until b/228512201 is fixed.')

    with strategy.scope():
      var = variables.Variable(
          0.0,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @tf_function
    def replica_fn():
      replica_context = ds_context.get_replica_context()
      replica_id = replica_context.replica_id_in_sync_group
      var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)

      return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)

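    # `strategy.run` only returns results for the replicas local to this
    # client: one per local device for multi-worker mirrored strategies, and
    # all replicas for single-client strategies.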
    if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
      client_local_replica_num = strategy.extended._num_devices_per_worker
    else:
      client_local_replica_num = strategy.num_replicas_in_sync

    workers_num = strategy.num_replicas_in_sync
    expected_sum = sum(range(workers_num)) * 3.0

    # Expand the values on each replica if multiple devices are used; otherwise
    # simply read the value of the Tensor.
    result = strategy.run(replica_fn)
    if hasattr(result, 'values'):
      result = result.values
    result = nest.flatten(result)

    # Iterate through all replicas and verify the reduce sum result.
    for i in range(client_local_replica_num):
      self.assertEqual(result[i].numpy(), expected_sum)


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class AllReduceTest(test.TestCase, parameterized.TestCase):

  def testDense(self, strategy, tf_function):
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = array_ops.identity(1.0)
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]

    if not strategy_test_lib.is_tpu_strategy(strategy):
      self.assertIsInstance(got, indexed_slices.IndexedSlices)
    expect = indexed_slices.IndexedSlices(
        values=array_ops.identity([[1.0]]),
        indices=array_ops.identity([0]),
        dense_shape=array_ops.identity([5, 1]))
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testSparseTuple(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value1 = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        value2 = indexed_slices.IndexedSlices(
            values=array_ops.identity([[2.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        rep_ctx = ds_context.get_replica_context()
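        # `all_reduce` also accepts a structure of values; a list of two
        # IndexedSlices is reduced in a single call here.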
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                     [value1, value2])
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]

    if not strategy_test_lib.is_tpu_strategy(strategy):
      for g in got:
        self.assertIsInstance(g, indexed_slices.IndexedSlices)
    expect = [
        indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1])),
        indexed_slices.IndexedSlices(
            values=array_ops.identity([[2.0 * strategy.num_replicas_in_sync]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
    ]
    self.assertAllEqual(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testNestedInput(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = (array_ops.identity(1.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[1.0]]),
                     indices=array_ops.identity([0]),
                     dense_shape=array_ops.identity([5, 1])),
                 array_ops.identity(2.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[2.0]]),
                     indices=array_ops.identity([1]),
                     dense_shape=array_ops.identity([5, 1])))
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = (1.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[1.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([0]),
                  dense_shape=array_ops.identity([5, 1])),
              2.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[2.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([1]),
                  dense_shape=array_ops.identity([5, 1])))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))


def _make_indexed_slices(values, indices, dense_shape):
  tensor = indexed_slices.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
  return tensor


def _get_num_replicas_per_client(strategy):
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager']))
class DistributedCollectiveAllReduceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):

  def testDatasetFromFunction(self, strategy):
    def dataset_fn(input_context):
      global_batch_size = 10
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = dataset_ops.DatasetV2.range(100).repeat().batch(batch_size)
      return d.shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)

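    # With a global batch of 10 over two single-replica workers, each worker
    # reads a per-replica batch of 5 from its shard of the batched dataset:
    # the chief gets [0..4] (sum 10) and the worker gets [5..9] (sum 35).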
    expected_sum_on_workers = {'chief': 10, 'worker': 35}
    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(iterator):
      return strategy.experimental_local_results(iterator.get_next())

    result = run(input_iterator)
    sum_value = math_ops.reduce_sum(result)
    self.assertEqual(
        sum_value.numpy(),
        expected_sum_on_workers[multi_worker_test_base.get_task_type()])

  def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
    global_batch_size = 8
    dataset = dataset_ops.DatasetV2.range(14).batch(
        global_batch_size, drop_remainder=False)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)

    # `result` is an incomplete batch.
    result = run(input_iterator)
    expected_data_on_workers = {'chief': [8, 9, 10], 'worker': [11, 12, 13]}
    self.assertAllEqual(
        expected_data_on_workers[multi_worker_test_base.get_task_type()],
        result.numpy(),
    )

  def testSimpleInputFromFnLastPartialBatch(self, strategy):

    def dataset_fn(input_context):
      global_batch_size = 8
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.DatasetV2.range(14).batch(
          batch_size, drop_remainder=False)
      return dataset.shard(input_context.num_input_pipelines,
                           input_context.input_pipeline_id)

    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)
    # `result` is an incomplete batch.
    result = run(input_iterator)

    expected_data_on_worker = {'chief': [8, 9, 10, 11], 'worker': [12, 13]}
    self.assertAllEqual(
        expected_data_on_worker[multi_worker_test_base.get_task_type()],
        result.numpy())

  def testReduceHostTensor(self, strategy):
    reduced = strategy.reduce(
        reduce_util.ReduceOp.SUM, array_ops.identity(1.), axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToHostTensor(self, strategy):
    value = array_ops.identity(1.)
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToHostTensor(self, strategy):
    value = array_ops.identity(1.)
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  def testReduceDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  # TODO(crccw): add a test that mixes device and host tensors after multi
  # worker strategy combinations can run on a fixed number of GPUs.


class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          strategy=[strategy_combinations.multi_worker_mirrored_2x1_cpu] +
          strategy_combinations.all_strategies,
          mode=['eager']))
  def testClusterResolverProperty(self, strategy):
    # CollectiveAllReduceStrategy and TPUStrategy must have a cluster resolver.
    # `None` otherwise.
    resolver = strategy.cluster_resolver
    if (not isinstance(strategy, CollectiveAllReduceStrategy) and
        not strategy_test_lib.is_tpu_strategy(strategy)):
      self.assertIsNone(resolver)
      return

    with strategy.scope():
      self.assertIs(strategy.cluster_resolver, resolver)

    self.assertTrue(hasattr(resolver, 'cluster_spec'))
    self.assertTrue(hasattr(resolver, 'master'))
    self.assertTrue(hasattr(resolver, 'num_accelerators'))
    self.assertTrue(hasattr(resolver, 'task_id'))
    self.assertTrue(hasattr(resolver, 'task_type'))
    if isinstance(strategy, CollectiveAllReduceStrategy):
      self.assertEqual(resolver.task_id, 0)
      self.assertAllInSet(resolver.task_type, ['chief', 'worker'])


if __name__ == '__main__':
  test_util.main()