# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for testing DistributionStrategy descendants."""

import functools
import os
import tempfile

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy as mirrored_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


class _TestException(Exception):
  pass


# Conditionally wrap `fn` in a def_function.function so that, when executing
# eagerly with `run_in_function` set, it runs as a graph function.
def _maybe_run_in_function(fn, run_in_function=False):
  if not run_in_function or not context.executing_eagerly():
    return fn
  else:
    return def_function.function()(fn)
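# For example, when a test executes eagerly and passes run_in_function=True,
# the callable returned above is a tf.function, so its body is traced and run
# as a graph function; in graph mode the original fn is returned unchanged.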


# May be the argument to either distribution.extended.call_for_each_replica() or
# get_replica_context().merge_call()
def _raise_exception_fn(_=None):
  raise _TestException()


# Must be the argument to a distribution.extended.call_for_each_replica() call;
# it calls a get_replica_context().merge_call() that raises an exception.
def _merge_raises_fn():
  ds_context.get_replica_context().merge_call(_raise_exception_fn)


# Must be the argument to a get_replica_context().merge_call() call; it calls
# dist.extended.call_for_each_replica() with a function that raises an
# exception.
def _call_raises_fn(dist):
  dist.extended.call_for_each_replica(_raise_exception_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call;
# it calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that raises an exception.
def _merge_call_raises_fn():
  ds_context.get_replica_context().merge_call(_call_raises_fn)


# Must be the argument to a get_replica_context().merge_call() call; it calls
# dist.extended.call_for_each_replica() with a function that calls a
# get_replica_context().merge_call() that raises an exception.
def _call_merge_raises_fn(dist):
  dist.extended.call_for_each_replica(_merge_raises_fn)


# Must be the argument to a distribution.extended.call_for_each_replica() call;
# it calls a get_replica_context().merge_call() that calls a
# call_for_each_replica() that calls a get_replica_context().merge_call() that
# raises an exception.
def _merge_call_merge_raises_fn():
  ds_context.get_replica_context().merge_call(_call_merge_raises_fn)


def _events_from_logdir(test_case, logdir):
  """Reads summary events from log directory."""
  test_case.assertTrue(gfile.Exists(logdir))
  files = gfile.ListDirectory(logdir)
  test_case.assertLen(files, 1)
  records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
  result = []
  for r in records:
    event = event_pb2.Event()
    event.ParseFromString(r)
    result.append(event)
  return result
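# Note: the first record in a TF event file is the file_version header written
# by the summary writer, so callers generally expect one more event than the
# number of summaries they wrote (see _test_summary_for_replica_zero_only
# below, which asserts len(events) == 2 for a single summary).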


def create_variable_like_keras_layer(name, shape, dtype):
  """Utility for creating variables that behave like Keras layer variables."""
  initializer = functools.partial(
      init_ops_v2.GlorotUniform(), shape, dtype=dtype)
  return variables.Variable(
      initial_value=initializer, name=name, trainable=True)
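# A note on the helper above: passing a zero-argument callable as
# `initial_value` (rather than an eager tensor) lets the variable creator
# installed by a distribution strategy invoke the initializer itself, which is
# roughly how Keras layers create their variables under a strategy scope.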


def is_optimizer_v2_instance(optimizer_obj):
  # For an optimizer instance, the v2 implementation has var_list as a required
  # argument.
  arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
  return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]


def is_mirrored_strategy(strategy: distribute_lib.Strategy) -> bool:
  return isinstance(
      strategy,
      (mirrored_lib.MirroredStrategy, mirrored_lib.MirroredStrategyV1))


def is_multi_worker_mirrored_strategy(
    strategy: distribute_lib.Strategy) -> bool:
  return isinstance(strategy, (mwms_lib.CollectiveAllReduceStrategy,
                               mwms_lib.CollectiveAllReduceStrategyV1))


def is_tpu_strategy(strategy: distribute_lib.Strategy) -> bool:
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))


class DistributionTestBase(test.TestCase):
  """Some tests that should work with any DistributionStrategy."""

  def _test_minimize_loss_eager(self, d):
    with d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)
      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y
      # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
      # common `implicit_grad` function and put it in DistributionStrategy.
      grad_fn = backprop.implicit_grad(loss)
      grad_fn = optimizer.get_filtered_grad_fn(grad_fn)

      def update(v, g):
        return v.assign_sub(0.2 * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          # control_dependencies irrelevant but harmless in eager execution
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      for i in range(10):
        b, a = step()
        if i == 0:
          before, = b  # pylint: disable=unbalanced-tuple-unpacking
        after, = a  # pylint: disable=unbalanced-tuple-unpacking

      error_before = abs(before.numpy() - 1)
      error_after = abs(after.numpy() - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_minimize_loss_graph(self,
                                d,
                                soft_placement=False,
                                learning_rate=0.2):
    config = config_pb2.ConfigProto()
    config.allow_soft_placement = soft_placement
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    with context.graph_mode(), \
         ops.Graph().as_default(), \
         self.cached_session(config=config) as sess, \
         d.scope():
      kernel = create_variable_like_keras_layer(
          name="kernel", shape=(1, 1), dtype=dtypes.float32)

      def loss(x):
        y = array_ops.reshape(
            math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
        return y * y

      grad_fn = backprop.implicit_grad(loss)

      def update(v, g):
        return v.assign_sub(learning_rate * g)

      one = array_ops.identity([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))

        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()
      variables.global_variables_initializer().run()
      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down
      self.assertLess(error_after, error_before)

  def _test_summary_for_replica_zero_only(self, d):
    logdir = tempfile.mkdtemp()

    def run_fn():
      """Function executed for each replica."""
      with summary_writer.as_default():
        replica_id = ds_context.get_replica_context().replica_id_in_sync_group
        return summary_ops.write("a", replica_id)

    with self.cached_session() as sess, d.scope(), \
        summary_ops.always_record_summaries():
      # We need global_step because the summary writing op *always* takes
      # global_step as input, even when we always or never record summaries.
      global_step = training_util.get_or_create_global_step()
      if not context.executing_eagerly():
        # When executing eagerly, variables are initialized immediately after
        # creation, and their initializers are None.
        global_step.initializer.run()
      summary_ops.set_step(0)
      summary_writer = summary_ops.create_file_writer(logdir)
      output = d.extended.call_for_each_replica(run_fn)
      unwrapped = d.unwrap(output)
      if not context.executing_eagerly():
        sess.run(summary_writer.init())
        sess.run(unwrapped)
        sess.run(summary_writer.close())

      events = _events_from_logdir(self, logdir)
      # There will be 2 entries: 1 summary file header entry, and 1 entry
      # written by replica 0.
      self.assertLen(events, 2)
      self.assertEqual(events[1].summary.value[0].tag, "a")
      self.assertEqual(events[1].summary.value[0].simple_value, 0.0)

  def _test_replica_id(self, d):
    with d.scope():
      expected_devices = [False] * len(d.extended.worker_devices)

      def mark_devices_fn():
        replica_id = self.evaluate(
            ds_context.get_replica_context().replica_id_in_sync_group)
        self.assertLess(replica_id, len(d.extended.worker_devices))
        self.assertFalse(expected_devices[replica_id])
        expected_devices[replica_id] = True

      d.extended.call_for_each_replica(mark_devices_fn)
      self.assertAllEqual(expected_devices,
                          [True] * len(d.extended.worker_devices))

  def _test_call_and_merge_exceptions(self, dist):
    with dist.scope():
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_raise_exception_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_raises_fn)
      with self.assertRaises(_TestException):
        dist.extended.call_for_each_replica(_merge_call_merge_raises_fn)

  def _input_fn_to_test_input_context(self, dataset_or_callable_fn,
                                      expected_num_replicas_in_sync,
                                      expected_num_input_pipelines,
                                      expected_input_pipeline_id):
    # Use a one-element list as a counter so that it can be captured by
    # `_input_fn`. This counter is incremented by 1 each time an input_fn is
    # called. We use this counter to check whether the `input_pipeline_id`
    # matches the counter in the in-graph replication.
    worker_id_counter = [0]

    def _input_fn(input_context):
      """Input fn for testing."""
      self.assertIsNotNone(input_context)
      self.assertEqual(expected_num_replicas_in_sync,
                       input_context.num_replicas_in_sync)
      self.assertEqual(expected_num_input_pipelines,
                       input_context.num_input_pipelines)
      if expected_input_pipeline_id is not None:
        self.assertEqual(expected_input_pipeline_id,
                         input_context.input_pipeline_id)
      else:
        self.assertEqual(worker_id_counter[0], input_context.input_pipeline_id)
        worker_id_counter[0] += 1

      return dataset_or_callable_fn()

    return _input_fn

  def _test_input_fn_iterable(
      self, strategy, input_fn, expected_values, ignore_order=False):
    assert_same = self.assertCountEqual if ignore_order else self.assertEqual

    iterable = strategy.distribute_datasets_from_function(input_fn)
    if context.executing_eagerly():
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)

      with self.assertRaises(StopIteration):
        self.evaluate(strategy.experimental_local_results(next(iterator)))

      # After re-initializing the iterator, we should be able to iterate again.
      iterator = iter(iterable)

      for expected_value in expected_values:
        computed_value = self.evaluate(
            list(strategy.experimental_local_results(next(iterator))))
        assert_same(expected_value, computed_value)
    else:
      iterator = dataset_ops.make_initializable_iterator(iterable)
      self._test_input_fn_iterator(iterator, strategy.extended.worker_devices,
                                   expected_values, test_reinitialize=True,
                                   ignore_order=ignore_order)

  def _test_input_fn_iterator(self,
                              iterator,
                              devices,
                              expected_values,
                              sess=None,
                              test_reinitialize=True,
                              ignore_order=False):
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    evaluate(iterator.initializer)

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])
      if ignore_order:
        self.assertCountEqual(expected_value, computed_value)
      else:
        self.assertEqual(expected_value, computed_value)

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [distribute_utils.select_replica(r, next_element) for r in
           range(len(devices))])

    # After re-initializing the iterator, we should be able to iterate again.
    if test_reinitialize:
      evaluate(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate([
            distribute_utils.select_replica(r, next_element) for r in
            range(len(devices))
        ])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

  def _test_global_step_update(self, strategy):
    with strategy.scope():
      global_step = variable_scope.get_variable(
          "global_step",
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)
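      # With ONLY_FIRST_REPLICA aggregation, cross-replica writes take the
      # value from replica 0, so one assign_add(1) per replica should leave
      # the global step at 1 rather than num_replicas_in_sync (asserted below).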
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        train_op = global_step.assign_add(1)
        value = global_step.read_value()
        return train_op, value

      train_ops, value = strategy.extended.call_for_each_replica(model_fn)
      self.evaluate(strategy.group(train_ops))
      global_step_tensors = strategy.experimental_local_results(value)
      global_step_values = self.evaluate(global_step_tensors)
      self.assertEqual((1,) * len(global_step_tensors), global_step_values)

  def _test_numpy_dataset(self, strategy, session=None, run_in_function=False):
    if not isinstance(strategy, distribute_lib.StrategyV1):
      self.skipTest("n/a: V1 only")
    cached_session = session or self.cached_session()
    with strategy.scope(), cached_session as sess:
      x = np.asarray([[1, 2], [6, 12], [2, 4], [5, 10], [3, 6], [4, 8]])
      y = np.asarray([5, 4, 3, 2, 1, 0])
      batch_size = 6
      if not strategy.extended._global_batch_size:  # pylint: disable=protected-access
        batch_size = batch_size // strategy.num_replicas_in_sync

      ds = strategy.extended.experimental_make_numpy_dataset(
          (x, y), session=sess or self.cached_session())
      ds = ds.repeat(2)  # 2 epochs
      # We need to use the drop_remainder argument to get a known static
      # input shape which is required for TPUs.
      drop_remainder = strategy.extended.experimental_require_static_shapes
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)
      i = strategy.make_dataset_iterator(ds)

      self.evaluate(i.initializer)

      def run_and_concatenate(strategy, i):
        x, y = strategy.experimental_run(
            _maybe_run_in_function(lambda z: z, run_in_function), i)
        x, y = self.evaluate((strategy.experimental_local_results(x),
                              strategy.experimental_local_results(y)))
        return np.concatenate(x), np.concatenate(y)

      x_1, y_1 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_1)
      self.assertAllEqual(y, y_1)
      x_2, y_2 = run_and_concatenate(strategy, i)
      self.assertAllEqual(x, x_2)
      self.assertAllEqual(y, y_2)
      with self.assertRaises(errors.OutOfRangeError):
        run_and_concatenate(strategy, i)

  def _test_trainable_variable(self, strategy):
    for cls in [variables.VariableV1, variables.Variable]:
      with strategy.scope():
        v1 = cls(1.0)
        self.assertEqual(True, v1.trainable)

        v2 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ)
        self.assertEqual(False, v2.trainable)

        v3 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=True)
        self.assertEqual(True, v3.trainable)

        v4 = cls(1.0, synchronization=variables.VariableSynchronization.ON_READ,
                 trainable=False)
        self.assertEqual(False, v4.trainable)
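        # Note: variables created with ON_READ synchronization default to
        # trainable=False unless trainable is passed explicitly, which is the
        # behavior the assertions above exercise.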


class OneDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any one-device DistributionStrategy."""

  def _test_run(self, strategy):
    out1 = strategy.run(lambda: array_ops.identity(4.))
    self.assertAllEqual([4.], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([8.], out2_vals["a"])
    self.assertAllEqual([16.], out2_vals["b"])

    out3 = strategy.run(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([42.], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy):
    self._test_collective_comms(
        strategy, _all_sum, inputs=(4., [42., 43.]), expected=(4., [42., 43.]))

  def _test_all_reduce_sum_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_sum_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[4.], expected_grads=[4.])

  def _test_all_reduce_mean(self, strategy):
    self._test_collective_comms(
        strategy, _all_mean, inputs=(2., [21., 22.]), expected=(2., [21., 22.]))

  def _test_all_reduce_mean_gradients(self, strategy):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_all_reduce_mean_gradient_tape(self, strategy):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[5.], expected_grads=[5.])

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(comm_fn, inputs))))
    self.assertAllEqual([expected[0]], outputs[0])
    self.assertAllEqual([expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads):
    if context.executing_eagerly():
      self.skipTest("`tf.gradients` is not supported with eager execution.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensors(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(step, inputs))))

  def _test_device_and_input_device_are_colocated(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    run_op = strategy.experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)

  def _test_device_and_input_device_are_colocated_with_function(self, strategy):
    if context.executing_eagerly():
      self.skipTest(
          "cross-device tests are not supported with eager execution.")
    workers, _ = test_util.create_local_cluster(2, 0)
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.range(5))
    comm_fn = lambda x: x + 1
    experimental_run = def_function.function()(strategy.experimental_run)
    with ops.device("/job:worker/replica:0/task:1/device:CPU:0"):
      # The tf.function must be defined on the right device as well.
      run_op = experimental_run(comm_fn, inputs)
    with session_lib.Session(target=workers[1].target) as sess:
      sess.run(inputs.initialize())
      sess.run(run_op)


class TwoDeviceDistributionTestBase(test.TestCase):
  """Some tests that should work with any two-device DistributionStrategy."""

  def _test_run(self, strategy, run_in_function=False):
    out1 = strategy.run(_maybe_run_in_function(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1,
        run_in_function))
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.run(_maybe_run_in_function(
        lambda x: {"a": x * 2, "b": x * x}, run_in_function), args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    out3 = strategy.run(_maybe_run_in_function(
        lambda b, a: a + 2 * b + 2, run_in_function), kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))

  def _test_all_reduce_sum(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_sum,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(4., [42., 43.]),
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_sum_gradient_tape(self, strategy, run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_sum, inputs=[1., 3.], expected_grads=[4., 4.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean(self, strategy, run_in_function=False):
    self._test_collective_comms(
        strategy,
        _all_mean,
        inputs=([1., 3.], [[39., 2.], [3., 41.]]),
        expected=(2., [21., 21.5]),
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradients(self, strategy, run_in_function=False):
    self._test_collective_comms_gradients(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_all_reduce_mean_gradient_tape(self, strategy,
                                          run_in_function=False):
    self._test_collective_comms_gradient_tape(
        strategy, _all_mean, inputs=[1., 3.], expected_grads=[2., 2.],
        run_in_function=run_in_function)

  def _test_collective_comms(self, strategy, comm_fn, inputs, expected,
                             run_in_function=False):
    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    outputs = self.evaluate(
        list(
            map(strategy.experimental_local_results,
                strategy.experimental_run(
                    _maybe_run_in_function(comm_fn, run_in_function), inputs))))
    self.assertAllEqual([expected[0], expected[0]], outputs[0])
    self.assertAllEqual([expected[1], expected[1]], outputs[1])

  def _test_collective_comms_gradients(self, strategy, comm_fn, inputs,
                                       expected_grads, run_in_function=False):
    if context.executing_eagerly() and not run_in_function:
      self.skipTest("`tf.gradients` is not supported with eager execution "
                    "without using tf.functions.")

    def step(c):
      x = array_ops.identity(42.)
      y = comm_fn(x) * c
      return gradients_impl.gradients(y, [x])[0]

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function), inputs))))

  def _test_collective_comms_gradient_tape(self, strategy, comm_fn, inputs,
                                           expected_grads,
                                           run_in_function=False):

    def step(c):
      x = array_ops.identity(42.)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = comm_fn(x) * c
      return tape.gradient(y, x)

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices(inputs))

    self.evaluate(inputs.initialize())
    self.assertAllEqual(
        expected_grads,
        self.evaluate(
            strategy.experimental_local_results(
                strategy.experimental_run(
                    _maybe_run_in_function(step, run_in_function),
                    inputs))))


class RemoteSingleWorkerMirroredStrategyBase(DistributionTestBase):
  """Tests for a Remote single worker."""

  def _get_num_gpus(self):
    pass

  def _testNumReplicasInSync(self, distribution):
    self.assertEqual(self._get_num_gpus(), distribution.num_replicas_in_sync)

  def _testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution, learning_rate=0.05)

  def _testDeviceScope(self, distribution):
    with distribution.scope():
      a = array_ops.identity(1.)
      with ops.device("/cpu:0"):
        b = array_ops.identity(1.)
      if context.executing_eagerly():
        device = "/job:worker/replica:0/task:0/device:CPU:0"
      else:
        device = "/job:worker/replica:0/task:0"
      self.assertEqual(a.device, device)
      self.assertEqual(b.device, "/job:worker/replica:0/task:0/device:CPU:0")

  def _testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(100)
    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    expected_values = [[i+j for j in range(num_gpus)] * num_workers
                       for i in range(0, 100, num_gpus)]

    # A dummy cached_session is used in eager mode.
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess)

  def _testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(100)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next

    num_gpus = self._get_num_gpus()  # pylint: disable=assignment-from-no-return
    num_workers = 1

    expected_values = []
    for i in range(0, 100, num_gpus):
      expected_values.append([i+j for j in range(num_gpus)] * num_workers)

    # A dummy cached_session is used in eager mode.
    with self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be
      # called multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)


def _all_sum(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.SUM, value)


def _all_mean(value):
  ctx = ds_context.get_replica_context()
  return ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)

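# Illustrative usage sketch (not part of the public API): the helpers above are
# meant to be invoked once per replica, e.g.
#
#   per_replica = strategy.experimental_run(_all_sum, iterator)
#
# so every replica contributes its local value to the all-reduce.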