# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for testing DistributionStrategy descendants."""

import functools
import os
import tempfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


class _TestException(Exception):
  pass


# Conditionally wrap the fn in a def_function.function (so it runs in graph
# mode).
def _maybe_run_in_function(fn, run_in_function=False):
  if not run_in_function or not context.executing_eagerly():
    return fn
  else:
    return def_function.function()(fn)


# May be the argument to either distribution.extended.call_for_each_replica()
# or get_replica_context().merge_call()
def _raise_exception_fn(_=None):
  raise _TestException()


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that raises an exception.
def _merge_raises_fn():
  ds_context.get_replica_context().merge_call(_raise_exception_fn)


# Must be the argument to a get_replica_context().merge_call() call, calls
# dist.extended.call_for_each_replica() with a function that raises an
# exception.
def _call_raises_fn(dist):
  dist.extended.call_for_each_replica(_raise_exception_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that raises an exception.
def _merge_call_raises_fn():
  ds_context.get_replica_context().merge_call(_call_raises_fn)


# Must be the argument to a get_replica_context().merge_call() call, calls
# dist.extended.call_for_each_replica() with a function that calls a
# get_replica_context().merge_call() that raises an exception.
def _call_merge_raises_fn(dist):
  dist.extended.call_for_each_replica(_merge_raises_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call,
# calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that calls a get_replica_context().merge_call() that
# raises an exception.
def _merge_call_merge_raises_fn():
  ds_context.get_replica_context().merge_call(_call_merge_raises_fn)


def _events_from_logdir(test_case, logdir):
  """Reads summary events from the log directory."""
  test_case.assertTrue(gfile.Exists(logdir))
  files = gfile.ListDirectory(logdir)
  test_case.assertLen(files, 1)
  records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
  result = []
  for r in records:
    event = event_pb2.Event()
    event.ParseFromString(r)
    result.append(event)
  return result


def create_variable_like_keras_layer(name, shape, dtype):
  """Utility for creating variables that behave like Keras layer variables."""
  initializer = functools.partial(
      init_ops_v2.GlorotUniform(), shape, dtype=dtype)
  return variables.Variable(
      initial_value=initializer, name=name, trainable=True)


def is_optimizer_v2_instance(optimizer_obj):
  # For an optimizer instance, the v2 implementation has var_list as a required
  # argument.
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]


def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
  return isinstance(
      strategy,
      (mirrored_lib.MirroredStrategy, mirrored_lib.MirroredStrategyV1))


def is_multi_worker_mirrored_strategy(
    strategy: distribute_lib.Strategy) -> bool:
  return isinstance(strategy, (mwms_lib.CollectiveAllReduceStrategy,
                               mwms_lib.CollectiveAllReduceStrategyV1))


def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy."""

  def _test_minimize_loss_eager(self, d):
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      for i in range(10):
        b, a = step()
        if i == 0:
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
          after, = a  # pylint: disable=unbalanced-tuple-unpacking

      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
        ops.Graph().as_default(), \
        self.cached_session(config=config) as sess, \
        d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(learning_rate * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
          after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_summary_for_replica_zero_only(self, d):
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      with summary_writer.as_default():
        replica_id = ds_context.get_replica_context().replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because the summary writing op *always* takes
      # global_step as input, whether we always or never record summaries.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and their initializer will be None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)

  def _test_replica_id(self, d):
    with d.scope():
      expected_devices = [False] * len(d.extended.worker_devices)

      def mark_devices_fn():
        replica_id = self.evaluate(
            ds_context.get_replica_context().replica_id_in_sync_group)
        self.assertLess(replica_id, len(d.extended.worker_devices))
        self.assertFalse(expected_devices[replica_id])
        expected_devices[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(expected_devices,
                          [True] * len(d.extended.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    with dist.scope():
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_raise_exception_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_merge_raises_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    # Use a list of one element as counter so that it can be captured by the
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is not None:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)
      else:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    assert_same = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if context.executing_eagerly():
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)

      with self.assertRaises(StopIteration):
        self.evaluate(strategy.experimental_local_results(next(iterator)))

      # After re-initializing the iterator, should be able to iterate again.
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)
    else:
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    evaluate(iterator.initializer)

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
        self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate([
            distribute_utils.select_replica(r, next_element) for r in
            range(len(devices))
        ])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

  def _test_global_step_update(self, strategy):
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        train_op = global_step.assign_add(1)
        value = global_step.read_value()
        return train_op, value

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      global_step_tensors = strategy.experimental_local_results(value)
      global_step_values = self.evaluate(global_step_tensors)
      self.assertEqual((1,) * len(global_step_tensors), global_step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    cached_session = session or self.cached_session()
    with strategy.scope(), cached_session as sess:
      x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      y = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size = batch_size // strategy.num_replicas_in_sync

      ds = strategy.extended.experimental_make_numpy_dataset(
          (x, y), session=sess or self.cached_session())
      ds = ds.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
      drop_remainder = strategy.extended.experimental_require_static_shapes
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)
      i = strategy.make_dataset_iterator(ds)

      self.evaluate(i.initializer)

      def run_and_concatenate(strategy, i):
        x, y = strategy.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), i)
        x, y = self.evaluate((strategy.experimental_local_results(x),
                              strategy.experimental_local_results(y)))
        return np.concatenate(x), np.concatenate(y)

      x_1, y_1 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_1)
      self.assertAllEqual(y, y_1)
      x_2, y_2 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_2)
      self.assertAllEqual(y, y_2)
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, i)

  def _test_trainable_variable(self, strategy):
    for cls in [variables.VariableV1, variables.Variable]:
      with strategy.scope():
        v1 = cls(1.0)
        self.assertEqual(True, v1.trainable)

        v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
        self.assertEqual(False, v2.trainable)

        v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=True)
        self.assertEqual(True, v3.trainable)

        v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=False)
        self.assertEqual(False, v4.trainable)


class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_run(self, strategy):
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, inputs))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)


class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_run(self, strategy, run_in_function=False):
    out1 = strategy.run(_maybe_run_in_function(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function), inputs))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function), inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    inputs))))


class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a remote single-worker strategy."""

  def _get_num_gpus(self):
    # Overridden by concrete subclasses to return the number of GPUs used by
    # the strategy under test; the base implementation returns None.
    pass

  def _testNumReplicasInSync(self, distribution):
    self.assertEqual(self._get_num_gpus(), distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)

  def _testDeviceScope(self, distribution):
    with distribution.scope():
      a = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        b = array_ops.identity(1.)
      if context.executing_eagerly():
        device = "/job:worker/replica:0/task:0/device:CPU:0"
      else:
        device = "/job:worker/replica:0/task:0"
      self.assertEqual(a.device, device)
      self.assertEqual(b.device, "/job:worker/replica:0/task:0/device:CPU:0")

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(100)
    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    expected_values = [[i+j for j in range(num_gpus)] * num_workers
                       for i in range(0, 100, num_gpus)]

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(100)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next

    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    expected_values = []
    for i in range(0, 100, num_gpus):
      expected_values.append([i+j for j in range(num_gpus)] * num_workers)

    # Dummy cached_session is used in Eager
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)


def _all_sum(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)


def _all_mean(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
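

# Illustrative usage sketch (not part of the original library): a concrete
# test module would subclass one of the *TestBase classes above and call its
# `_test_*` helpers with a specific strategy instance. The class and device
# list below are hypothetical examples; a single-CPU MirroredStrategy is used
# only to keep the sketch cheap to run.
class _ExampleMirroredStrategyTest(DistributionTestBase):
  """Hypothetical example of exercising the helpers with MirroredStrategy."""

  def test_call_and_merge_exceptions(self):
    strategy = mirrored_lib.MirroredStrategy(["/cpu:0"])
    self._test_call_and_merge_exceptions(strategy)

  def test_replica_id(self):
    strategy = mirrored_lib.MirroredStrategy(["/cpu:0"])
    self._test_replica_id(strategy)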