# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MirroredStrategy."""

import json
import sys

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.util import traceback_utils


GPU_TEST = "test_gpu" in sys.argv[0]


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
        ],
        mode=["graph", "eager"]))
class MirroredTwoDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testNumReplicasInSync(self, distribution):
    self.assertEqual(2, distribution.num_replicas_in_sync)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testRunRegroupError(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    def run_fn():
      replica_id = int(self.evaluate(_replica_id()))
      # Generates a list with different lengths on different devices.
      # Will fail in _regroup() (if more than one device).
      return list(range(replica_id))

    with distribution.scope(), self.assertRaises(AssertionError):
      distribution.extended.call_for_each_replica(run_fn)

  def testReduceToCpu(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(_replica_id)
      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
      expected = sum(range(distribution.num_replicas_in_sync))
      self.assertEqual(expected, self.evaluate(reduced))

  def testReduceToCpuNested(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")

    with distribution.scope():
      def replica_fn(input_tensor):
        return input_tensor + constant_op.constant(
            1.0), input_tensor - constant_op.constant(1.0)

      input_tensor = constant_op.constant(3.0)
      run_result = distribution.run(replica_fn, args=(input_tensor,))
      reduced_result = distribution.reduce("SUM", run_result, axis=None)
      expected_result = (4 * distribution.num_replicas_in_sync,
                         2 * distribution.num_replicas_in_sync)

      self.assertEqual(expected_result, self.evaluate(reduced_result))

  def reduce_axis_helper(self, distribution, replica_squared_fn):
    with distribution.scope():
      num_replicas = distribution.num_replicas_in_sync
      result = distribution.extended.call_for_each_replica(replica_squared_fn)
      # sum
      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
      expected = sum(x * (x + 1) for x in range(num_replicas))
      self.assertNear(expected, self.evaluate(reduced), 0.00001)

      # mean
      reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
      expected /= sum(x + 1 for x in range(num_replicas))
      self.assertNear(expected, self.evaluate(reduced), 0.00001)

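  # For example, with the two-replica strategies used by this test class and a
  # replica fn returning [replica_id] * (replica_id + 1): replica 0 yields [0]
  # and replica 1 yields [1, 1], so reduce_axis_helper expects an axis-0 SUM of
  # 0 + 1 + 1 = 2 and a MEAN of 2 / 3.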
  def testReduceAxisToCpu(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    for dtype in (dtypes.float32, dtypes.int32):
      def replica_squared_fn(dtype=dtype):
        # Lists with different lengths on different replicas.
        replica_id = _replica_id_as_int()
        return array_ops.identity(
            math_ops.cast([replica_id] * (replica_id + 1), dtype))

      self.reduce_axis_helper(distribution, replica_squared_fn)

  def set_v2_tensorshape(self, v2):
    if v2:
      tensor_shape.enable_v2_tensorshape()
    else:
      tensor_shape.disable_v2_tensorshape()

  def testReduceAxisToCpuUnknownShape(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE  # pylint: disable=protected-access
    try:
      for v2 in (False, True):
        self.set_v2_tensorshape(v2)
        for dtype in (dtypes.float32, dtypes.int32):
          for shape in ((None,), None):  # Test both unknown size and rank.
            def replica_squared_fn(dtype=dtype, shape=shape):
              # Lists with different lengths on different replicas.
              replica_id = _replica_id_as_int()
              tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype)
              # Erase shape information
              return array_ops.placeholder_with_default(tensor, shape=shape)

            self.reduce_axis_helper(distribution, replica_squared_fn)
    finally:
      self.set_v2_tensorshape(original_v2)

  def testReplicateDataset(self, distribution):
    if tf2.enabled() and not context.executing_eagerly():
      self.skipTest("Skipping test since we do not support graph mode in TF 2")

    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
                                 expected_values)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(2).interleave(
          (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    expected_values = [[i, i] for i in range(0, 10)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
                                 expected_values, test_reinitialize=False,
                                 ignore_order=True)

  def testNumpyDataset(self, distribution):
    self._test_numpy_dataset(distribution)

  def testGlobalStepUpdate(self, distribution):
    self._test_global_step_update(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)

  def testSummaryForReplicaZeroOnly(self, distribution):
    self._test_summary_for_replica_zero_only(distribution)

  def testTrainableVariables(self, distribution):
    self._test_trainable_variable(distribution)

  def test_prefetch_to_device_dataset(self, distribution):
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      item = next(iter(dataset))
    else:
      if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
        item = dataset.make_initializable_iterator().get_next()
      else:
        self.skipTest("unsupported test combination")
    device_types = [
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values]
    expected_device_types = [
        tf_device.DeviceSpec.from_string(device).device_type for
        device in distribution.extended.worker_devices]
    self.assertAllEqual(device_types, expected_device_types)

  def test_prefetch_to_host_dataset(self, distribution):
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      item = next(iter(dataset))
    else:
      if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
        item = dataset.make_initializable_iterator().get_next()
      else:
        self.skipTest("unsupported test combination")
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ["CPU"])


def one_device_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
      ],
      mode=["graph", "eager"])


@combinations.generate(one_device_combinations())
class MirroredOneDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)


class MirroredStrategyVariableCreatorStackTest(
    test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          mode=["graph"]))
  def testCreatorStacksAreThreadLocal(self, distribution):
    def model_fn():
      replica_id_str = str(self.evaluate(_replica_id()))

      def thread_creator_fn(next_creator, **kwargs):
        return next_creator(**kwargs) + ":thread_" + replica_id_str

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, **kwargs):
      # We are not using the underlying next_creator for test purposes.
      del next_creator, kwargs
      return "main_thread"

    with context.graph_mode(), \
        distribution.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = distribution.extended.call_for_each_replica(model_fn)
      result = distribution.experimental_local_results(result)
      expected = ("main_thread:thread_0", "main_thread:thread_1")
      self.assertEqual(expected, result)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredStrategyCallForEachReplicaTest(test.TestCase):

  def testExecutingEagerlyOutsideFunction(self, distribution):
    """Verify we preserve the value of executing_eagerly_outside_functions()."""
    def model_fn():
      return ops.executing_eagerly_outside_functions()

    originally = ops.executing_eagerly_outside_functions()
    with distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)

    # Verify this all again, but this time in a FuncGraph.
    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)

  def testFunctionInCallForEachReplica(self, distribution):
    traces = []
    @def_function.function
    def model_fn():
      traces.append(1)
      return ds_context.get_replica_context().replica_id_in_sync_group

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(
          (0, 1),
          self.evaluate(distribution.experimental_local_results(result)))
      self.assertLen(traces, distribution.num_replicas_in_sync)

  def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
    traces = []
    @def_function.function
    def model_fn():
      traces.append(1)
      return ds_context.get_replica_context().replica_id_in_sync_group

    @def_function.function
    def step():
      return distribution.extended.call_for_each_replica(model_fn)

    with distribution.scope():
      result = step()
      self.assertEqual(
          (0, 1),
          self.evaluate(distribution.experimental_local_results(result)))
      self.assertLen(traces, distribution.num_replicas_in_sync)

  def testControlFlowFunctionInCallForEachReplicaWithMergeCall(
      self, distribution):

    def merge_fn(strategy, value):
      return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)

    @def_function.function
    def model_fn():

      def body_fn(i):
        return ds_context.get_replica_context().merge_call(merge_fn, args=(i,))

      return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0])

    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError, "`merge_call` called while defining a new graph."):
        distribution.extended.call_for_each_replica(model_fn)

  def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution):

    def merge_fn(strategy, value):
      return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)

    def model_fn():

      @def_function.function
      def model_fn_nested():
        t = constant_op.constant(1)
        return ds_context.get_replica_context().merge_call(merge_fn, args=(t,))

      return model_fn_nested()

    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError, "`merge_call` called while defining a new graph."):
        distribution.extended.call_for_each_replica(model_fn)

  def testFunctionInCallForEachReplicaWithMergeCall(self, distribution):
    def merge_fn(_):
      pass

    @def_function.function
    def model_fn():
      ds_context.get_replica_context().merge_call(merge_fn)
      return 0.

    with distribution.scope():
      self.assertEqual(
          self.evaluate(distribution.extended.call_for_each_replica(model_fn)),
          0.)

  def testFunctionInCallForEachReplicaCached(self, distribution):
    traces = []

    @def_function.function
    def model_fn():
      traces.append(None)

    self.assertEmpty(traces)

    for i in range(10):
      distribution.extended.call_for_each_replica(model_fn)

      if i == 0:
        num_devices = len(traces)
        self.assertGreater(num_devices, 0)
      else:
        # model_fn should not have been re-evaluated so the length should
        # remain the same.
        self.assertLen(traces, num_devices)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph"]))
class MirroredStrategyNameScopeTest(test.TestCase):
  # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
  # testing this in eager mode.

  def testNameScope(self, distribution):
    def model_fn():
      with ops.name_scope("foo"):
        a = constant_op.constant(1.0, name="a")
        ds_context.get_replica_context().merge_call(lambda _: _)
        b = constant_op.constant(1.0, name="b")
      return a, b

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(2, len(result))
        for v, name in zip(result, ["a", "b"]):
          self.assertIsInstance(v, values.DistributedValues)
          v0, v1 = distribution.experimental_local_results(v)
          self.assertEqual("main/foo/" + name + ":0", v0.name)
          self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)

  def testWithDefaultName(self, distribution):
    def model_fn():
      with ops.name_scope(None, "foo"):
        a = constant_op.constant(1.0, name="a")
        ds_context.get_replica_context().merge_call(lambda _: _)
        b = constant_op.constant(2.0, name="b")
      return a, b

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(2, len(result))
      for v, name in zip(result, ["a", "b"]):
        self.assertIsInstance(v, values.DistributedValues)
        v0, v1 = distribution.experimental_local_results(v)
        self.assertEqual("foo/" + name + ":0", v0.name)
        self.assertEqual("replica_1/foo/" + name + ":0", v1.name)

  # variable_scope.variable() respects name scopes when creating
  # variables. On the other hand variable_scope.get_variable() ignores name
  # scopes but respects variable scope when creating variables. We test both
  # methods of creating variables to make sure that we have the same
  # variable names in both cases.
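  # For example (two replicas): under ops.name_scope("main"),
  # variable_scope.variable(1.0, name="b") is named "main/b:0" and
  # "main/b/replica_1:0", while variable_scope.get_variable("b", [1]) ignores
  # the name scope and is named "b:0" and "b/replica_1:0", as asserted below.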
  def testNameScopeWithVariable(self, distribution):
    def in_cross_replica(_):
      c = variable_scope.variable(1.0, name="c")
      return c

    def model_fn():
      b = variable_scope.variable(1.0, name="b")
      with ops.name_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        a = variable_scope.variable(1.0, name="a")
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("main/a:0", a0.name)
      self.assertEqual("main/a/replica_1:0", a1.name)
      self.assertEqual("main/b:0", b0.name)
      self.assertEqual("main/b/replica_1:0", b1.name)
      self.assertEqual("main/foo/c:0", c0.name)
      self.assertEqual("main/foo/c/replica_1:0", c1.name)

  def testNameScopeWithGetVariable(self, distribution):
    def in_cross_replica(_):
      c = variable_scope.get_variable("c", [1])
      return c

    def model_fn():
      b = variable_scope.get_variable("b", [1])
      with ops.name_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        a = variable_scope.get_variable("a", [1])
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("a:0", a0.name)
      self.assertEqual("a/replica_1:0", a1.name)
      self.assertEqual("b:0", b0.name)
      self.assertEqual("b/replica_1:0", b1.name)
      self.assertEqual("c:0", c0.name)
      self.assertEqual("c/replica_1:0", c1.name)

  def testVariableScopeWithGetVariable(self, distribution):

    def in_cross_replica(_):
      c = variable_scope.get_variable("c", [1])
      return c

    def model_fn():
      b = variable_scope.get_variable("b", [1])
      with variable_scope.variable_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with variable_scope.variable_scope("main"):
        a = variable_scope.get_variable("a", [1])
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("main/a:0", a0.name)
      self.assertEqual("main/a/replica_1:0", a1.name)
      self.assertEqual("main/b:0", b0.name)
      self.assertEqual("main/b/replica_1:0", b1.name)
      self.assertEqual("main/foo/c:0", c0.name)
      self.assertEqual("main/foo/c/replica_1:0", c1.name)


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored3Devices",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]),
                required_gpus=2)
        ],
        mode=["graph", "eager"]))
class MirroredThreeDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):

  def testThreeDevices(self, distribution):
    def model_fn():
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_mirrored(result))
      self.assertEqual("foo:0", result.name)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredVariableUpdateTest(test.TestCase):
  # The following tests check assign, assign_add and assign_sub on Mirrored
  # variables in replica and cross replica context.

  def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
                                                                distribution):
    def var_fn():
      v = variable_scope.variable(1.0, name="foo")
      return v

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        return mirrored_var.assign(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(5.0, self.evaluate(mirrored_var))

  def testAssignMirroredVarReplicaContextWithSum(self, distribution):
    # Test that we don't reduce a non-per-replica value with the "sum"
    # aggregation type.
    def var_fn():
      v = variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
      return v

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        return mirrored_var.assign(5.0)

      if distribution.extended._use_merge_call():
        with self.assertRaisesRegex(
            ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
            "with the given reduce op ReduceOp.SUM."):
          self.evaluate(distribution.experimental_local_results(
              distribution.extended.call_for_each_replica(model_fn)))
      else:
        result = self.evaluate(
            distribution.experimental_local_results(
                distribution.extended.call_for_each_replica(model_fn)))
        self.assertAllEqual(result[0], 5.0)

  def testAssignMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))
      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
      self.assertEqual(6.0, mirrored_var_result)

  def testAssignMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(0.5, self.evaluate(mirrored_var))

  def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(5.0, self.evaluate(mirrored_var))

  def testAssignAddMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      # read_value == True
      mirrored_var_result = self.evaluate(
          mirrored_var.assign_add(6.0, read_value=True))
      self.assertEqual(7.0, mirrored_var_result)
      self.assertEqual(
          7.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          7.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

      # read_value == False
      self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
      self.assertEqual(
          9.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          9.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

  def testAssignAddMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign_add(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(1.5, self.evaluate(mirrored_var))

  def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign_add(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(6.0, self.evaluate(mirrored_var))

  def testAssignSubMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(5.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))
      mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
      self.assertEqual(3.0, mirrored_var_result)
      self.assertEqual(
          3.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          3.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

  def testAssignSubMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign_sub(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(4.5, self.evaluate(mirrored_var))

  def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign_sub(1.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(4.0, self.evaluate(mirrored_var))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase):

  def testAssignMirroredVarInitializer(self, distribution):
    # This test is not eager compatible since in eager mode variables are
    # initialized upon construction instead of once the initialization op is
    # run.
    with context.graph_mode():
      def var_fn():
        v = variable_scope.variable(1.0, name="foo")
        return v

      with distribution.scope():
        mirrored_var = distribution.extended.call_for_each_replica(var_fn)
        self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
        self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
        self.evaluate(mirrored_var.initializer)
        self.assertTrue(self.evaluate(mirrored_var.is_initialized()))

  def testAssignReplicaLocalVarInitializer(self, distribution):
    # This test is not eager compatible since in eager mode variables are
    # initialized upon construction instead of once the initialization op is
    # run.
    with context.graph_mode():
      def model_fn():
        v_sum = variable_scope.variable(
            1.0,
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
        return v_sum

      with distribution.scope():
        sync_on_read_var = distribution.extended.call_for_each_replica(
            model_fn)
        self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
        self.assertFalse(self.evaluate(sync_on_read_var.is_initialized()))
        self.evaluate(sync_on_read_var.initializer)
        self.assertTrue(self.evaluate(sync_on_read_var.is_initialized()))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class SyncOnReadVariableAssignTest(test.TestCase):

  def testAssignReplicaLocalVarSumAggregation(self, distribution):
    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      return v_sum

    with distribution.scope():
      sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
      self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the SUM of the
      # values on all the replicas.
      self.assertEqual(2.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))
      # Assigning 6.0 in cross replica context will assign a value of
      # 6.0/num_replicas to each replica.
      tlv_ops = sync_on_read_var.assign(6.0)
      self.evaluate(tlv_ops)
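      # (With the two replicas used here, the assign above stores 3.0 on each
      # replica.)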
      # On reading the sync on read var we should get the assigned value back.
      # The values on all the replicas are added before being returned by
      # `read_var`.
      self.assertEqual(6.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))

  def testAssignReplicaLocalVarMeanAggregation(self, distribution):
    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      return v_sum

    with distribution.scope():
      sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
      self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the MEAN of
      # values on all replicas which is the value assigned in replica context.
      self.assertEqual(1.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))
      tlv_ops = sync_on_read_var.assign(6.0)
      self.evaluate(tlv_ops)
      # On reading the sync on read var we should get the MEAN of all values
      # which is equal to the value assigned.
      self.assertEqual(6.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))


class MockModel(object):

  def __init__(self, two_variables=False):
    self.variables = []
    self.variables.append(variable_scope.variable(1.25, name="dummy_var1"))
    if two_variables:
      self.variables.append(variable_scope.variable(2.0, name="dummy_var2"))

  def __call__(self, factor=2):
    x = factor * self.variables[0]
    if len(self.variables) > 1:
      x += self.variables[1]
    return x


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredStrategyDefunTest(test.TestCase):

  def _call_and_check(self, distribution, model_fn, inputs, expected_result,
                      defuns, two_variables=False):
    cpu_dev = device_util.canonicalize("CPU:0")
    gpu_dev = device_util.canonicalize("GPU:0")
    devices = [cpu_dev, gpu_dev]

    with distribution.scope():
      mock_model = MockModel(two_variables)
      self.evaluate(variables.global_variables_initializer())

      result = distribution.extended.call_for_each_replica(
          model_fn, args=[mock_model] + inputs)
      for r in range(len(devices)):
        device_result = distribute_utils.select_replica(r, result)
        device_expected_result = distribute_utils.select_replica(
            r, expected_result)
        self.assertAllClose(device_expected_result,
                            self.evaluate(device_result))

      for defun in defuns:
        # `Function`s are specialized to the current device stack, so
        # call_for_each has one trace per device. To check that the expected
        # set of variables was accessed on each trace, we first retrieve each
        # device-specific graph function.
        per_replica_graph_functions = (
            distribution.extended.call_for_each_replica(
                defun.get_concrete_function, args=[mock_model] + inputs))
        for i in range(len(devices)):
          graph_function = distribution.experimental_local_results(
              per_replica_graph_functions)[i]
          # TODO(b/129555712): re-enable an assertion here that the two sets of
          # variables are the same.
          # self.assertEqual(set(graph_function.graph.variables),
          #                  set(mock_model.variables))
          del graph_function

  def testVariableInDefun(self, distribution):
    @function.defun
    def times_two(mock_model):
      return mock_model()

    def model_fn(mock_model):
      return times_two(mock_model)

    self._call_and_check(distribution, model_fn, [], 2.5, [times_two])

  def testVariableInNestedDefun(self, distribution):
    @function.defun
    def times_two(mock_model):
      return mock_model()

    @function.defun
    def two_x_plus_one(mock_model):
      return times_two(mock_model) + 1

    def model_fn(mock_model):
      return two_x_plus_one(mock_model)

    self._call_and_check(distribution, model_fn, [], 3.5,
                         [times_two, two_x_plus_one])

  def testTwoVariablesInNestedDefun(self, distribution):
    @function.defun
    def fn1(mock_model):
      return mock_model()

    @function.defun
    def fn2(mock_model):
      return fn1(mock_model) + 1

    def model_fn(mock_model):
      return fn2(mock_model)

    self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2],
                         two_variables=True)

  def testGradientTapeOverNestedDefuns(self, distribution):
    @function.defun
    def fn1(mock_model):
      return mock_model()

    @function.defun
    def fn2(mock_model):
      return fn1(mock_model) + 1

    def model_fn(mock_model):
      with backprop.GradientTape(persistent=True) as gtape:
        result = fn2(mock_model)
      grads = gtape.gradient(result,
                             [v._get() for v in mock_model.variables])
      return grads

    self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2],
                         two_variables=True)

  def testPassPerReplica(self, distribution):
    @function.defun
    def fn1(mock_model, factor):
      return mock_model(factor)

    factors = values.PerReplica((5.0, 3.0))
    expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
    self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    devices=mirrored_strategy.all_local_devices(),
                    cross_device_ops=cross_device_ops_lib.ReductionToOneDevice(
                    ),
                ),
                required_gpus=1)
        ],
        mode=["graph"]))
class MultiWorkerMirroredStrategyTest(
    multi_worker_test_base.MultiWorkerTestBase,
    strategy_test_lib.DistributionTestBase):

  def _configure_distribution_strategy(self, distribution):
    cluster_spec = server_lib.ClusterSpec({
        "worker": ["/job:worker/task:0", "/job:worker/task:1"]
    })
    distribution.configure(cluster_spec=cluster_spec)

  def test_num_replicas_in_sync(self, distribution):
    self._configure_distribution_strategy(distribution)
    # We calculate the total number of GPUs across the workers (2) specified
    # in the cluster spec.
    self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync)

  def testMinimizeLossGraph(self, distribution):
    self._configure_distribution_strategy(distribution)
    self._test_minimize_loss_graph(distribution, learning_rate=0.05)

  def testDeviceScope(self, distribution):
    """Test the device scope of multi-worker MirroredStrategy."""
    self._configure_distribution_strategy(distribution)
    with distribution.scope():
      a = constant_op.constant(1.)
      with ops.device("/cpu:0"):
        b = constant_op.constant(1.)
      self.assertEqual(a.device, "/job:worker/task:0")
      self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")

  def testMakeInputFnIteratorWithDataset(self, distribution):
    self._configure_distribution_strategy(distribution)
    dataset_fn = lambda: dataset_ops.Dataset.range(100)
    num_gpus = context.num_gpus()
    num_workers = 2

    expected_values = [[i+j for j in range(num_gpus)] * num_workers
                       for i in range(0, 100, num_gpus)]

    with context.graph_mode(), self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    self._configure_distribution_strategy(distribution)
    def fn():
      dataset = dataset_ops.Dataset.range(100)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    num_gpus = context.num_gpus()
    num_workers = 2

    expected_values = []
    for i in range(0, 100, num_gpus):
      expected_values.append([i+j for j in range(num_gpus)] * num_workers)

    with context.graph_mode(), self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)

  def testUpdateConfigProto(self, distribution):
    distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]})

    config_proto = config_pb2.ConfigProto()
    new_config = distribution.update_config_proto(config_proto)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    devices=["/job:worker/task:0/gpu:{}".format(
                        i) for i in range(context.num_gpus())]),
                required_gpus=1)
        ],
        mode=["graph"]))
class RemoteSingleWorkerMirroredStrategyGraph(
    multi_worker_test_base.SingleWorkerTestBaseGraph,
    strategy_test_lib.RemoteSingleWorkerMirroredStrategyBase):

  def _get_num_gpus(self):
    return context.num_gpus()

  def testNumReplicasInSync(self, distribution):
    self._testNumReplicasInSync(distribution)

  def testMinimizeLoss(self, distribution):
    self._testMinimizeLoss(distribution)

  def testDeviceScope(self, distribution):
    self._testDeviceScope(distribution)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    self._testMakeInputFnIteratorWithDataset(distribution)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    self._testMakeInputFnIteratorWithCallable(distribution)


class MultiWorkerMirroredStrategyTestWithChief(
    multi_worker_test_base.MultiWorkerTestBase,
    strategy_test_lib.DistributionTestBase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers and 1 chief."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=0, has_chief=True)
    cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]

  def _make_cross_device_ops(self):
    return cross_device_ops_lib.ReductionToOneDevice()

  def testMinimizeLossGraph(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testMinimizeLossGraphMirroredStrategy(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          mirrored_strategy.all_local_devices(),
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
    with context.graph_mode():
      cluster_spec = {}
      cluster_spec["chief"] = self._cluster_spec["chief"]
      tf_config = {"cluster": cluster_spec}
      with test.mock.patch.dict("os.environ",
                                {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = mirrored_strategy.MirroredStrategy()
        if context.num_gpus() == 0:
          self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
                                cross_device_ops_lib.ReductionToOneDevice)
        self.skipTest("b/130551176, run the following once fixed.")
        self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testInitializeFromTFConfig(self):
    with context.graph_mode():
      tf_config = {"cluster": self._cluster_spec}
      with test.mock.patch.dict("os.environ",
                                {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = mirrored_strategy.MirroredStrategy(
            cross_device_ops=self._make_cross_device_ops())
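        # The in-process cluster created in setUpClass has one chief and two
        # workers (three tasks), so the strategy should see
        # 3 * max(num_gpus, 1) replicas.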
        self.assertEqual(
            max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)

  def testSummaryForReplicaZeroOnly(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          mirrored_strategy.all_local_devices(),
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_summary_for_replica_zero_only(strategy)


class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_one_gpu,
          ],
          mode=["graph"]))
  def testMirroredVariableAsStopGradient(self, distribution):
    with distribution.scope():
      inp = constant_op.constant(1.0)
      x = variables.Variable(1.0)
      y = inp*x
      grads = gradients.gradients(x, y, stop_gradients=x)
      self.assertIsNone(grads[0])


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["eager"]))
class FunctionTest(test.TestCase, parameterized.TestCase):

  def testBackwardFunctionDevicePlacement(self, distribution):
    with distribution.scope():
      w = variable_scope.variable([1.5], name="w")
      b = variable_scope.variable([0.5], name="b")

    @def_function.function
    def forward(x, w, b):
      return x * w + b

    x = array_ops.identity([1.0], name="x_useless")
    concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)

    with distribution.scope():

      def replica_fn():
        with backprop.GradientTape() as t:
          x = array_ops.identity([1.0], name="x")
          loss = concrete_forward(x, w._get(), b._get()) - [1.0]
          return t.gradient(loss, [w, b])

      def step_fn():
        return distribution.run(replica_fn)

      context.enable_run_metadata()
      g1, g2 = step_fn()
      run_metadata = context.export_run_metadata()
      context.disable_run_metadata()
      self.assertEqual(self.evaluate(g1._primary), 1.0)
      self.assertEqual(self.evaluate(g2._primary), 1.0)

      # Verify that this node runs on both devices.
      node_name = "gradients_mul_grad_mul_1_x"
      devices_for_this_node = set()
      for partition_graph in run_metadata.partition_graphs:
        for node in partition_graph.node:
          if node.name == node_name:
            devices_for_this_node.add(node.device)
      devices = [device_util.resolve("/device:GPU:0"),
                 device_util.resolve("/device:CPU:0")]
      self.assertSetEqual(devices_for_this_node, set(devices))

  def testFunctionPreservesAutoGraph(self, distribution):
    def f():
      self.assertTrue(converter_testing.is_inside_generated_code())
      return 1

    with distribution.scope():

      @def_function.function
      def replica_fn():
        return f()

      distribution.run(replica_fn)

  def testPreserveTracebackFiltering(self, distribution):
    traceback_utils.disable_traceback_filtering()
    self.assertFalse(traceback_utils.is_traceback_filtering_enabled())

    def f():
      self.assertFalse(traceback_utils.is_traceback_filtering_enabled())

    distribution.run(f)

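
# Helpers used by the tests above: _replica_id() returns the current replica id
# as a Tensor (wrapping a plain Python int in a constant when needed), while
# _replica_id_as_int() returns it as a concrete (non-Tensor) integer value.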
def _replica_id():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(replica_id, ops.Tensor):
    replica_id = constant_op.constant(replica_id)
  return array_ops.identity(replica_id)


def _replica_id_as_int():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id


if __name__ == "__main__":
  # TODO(b/172304955)
  test_util.main(config_logical_devices=False)