# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MirroredStrategy."""

import json
import sys

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.util import traceback_utils


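# True iff the test binary name (sys.argv[0]) contains "test_gpu".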
GPU_TEST = "test_gpu" in sys.argv[0]


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
        ],
        mode=["graph", "eager"]))
class MirroredTwoDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testNumReplicasInSync(self, distribution):
    self.assertEqual(2, distribution.num_replicas_in_sync)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testRunRegroupError(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    def run_fn():
      replica_id = int(self.evaluate(_replica_id()))
      # Generates a list with different lengths on different devices.
      # Will fail in _regroup() (if more than one device).
      return list(range(replica_id))

    with distribution.scope(), self.assertRaises(AssertionError):
      distribution.extended.call_for_each_replica(run_fn)

  def testReduceToCpu(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(_replica_id)
      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=None)
      expected = sum(range(distribution.num_replicas_in_sync))
      self.assertEqual(expected, self.evaluate(reduced))

  def testReduceToCpuNested(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")

    with distribution.scope():
      def replica_fn(input_tensor):
        return input_tensor + constant_op.constant(
            1.0), input_tensor - constant_op.constant(1.0)

      input_tensor = constant_op.constant(3.0)
      run_result = distribution.run(replica_fn, args=(input_tensor,))
      reduced_result = distribution.reduce("SUM", run_result, axis=None)
      expected_result = (4 * distribution.num_replicas_in_sync,
                         2 * distribution.num_replicas_in_sync)

      self.assertEqual(expected_result, self.evaluate(reduced_result))

  def reduce_axis_helper(self, distribution, replica_squared_fn):
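    """Checks SUM and MEAN reductions over axis 0 of per-replica results."""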
    with distribution.scope():
      num_replicas = distribution.num_replicas_in_sync
      result = distribution.extended.call_for_each_replica(replica_squared_fn)
      # sum
      reduced = distribution.reduce(reduce_util.ReduceOp.SUM, result, axis=0)
      expected = sum(x * (x + 1) for x in range(num_replicas))
      self.assertNear(expected, self.evaluate(reduced), 0.00001)

      # mean
      reduced = distribution.reduce(reduce_util.ReduceOp.MEAN, result, axis=0)
      expected /= sum(x + 1 for x in range(num_replicas))
      self.assertNear(expected, self.evaluate(reduced), 0.00001)

  def testReduceAxisToCpu(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    for dtype in (dtypes.float32, dtypes.int32):
      def replica_squared_fn(dtype=dtype):
        # Lists with different lengths on different replicas.
        replica_id = _replica_id_as_int()
        return array_ops.identity(
            math_ops.cast([replica_id] * (replica_id + 1), dtype))

      self.reduce_axis_helper(distribution, replica_squared_fn)

  def set_v2_tensorshape(self, v2):
    if v2:
      tensor_shape.enable_v2_tensorshape()
    else:
      tensor_shape.disable_v2_tensorshape()

  def testReduceAxisToCpuUnknownShape(self, distribution):
    if not distribution.extended._use_merge_call():
      self.skipTest("Collective all-reduce does not support int32 on GPU.")
    original_v2 = tensor_shape._TENSORSHAPE_V2_OVERRIDE  # pylint: disable=protected-access
    try:
      for v2 in (False, True):
        self.set_v2_tensorshape(v2)
        for dtype in (dtypes.float32, dtypes.int32):
          for shape in ((None,), None):  # Test both unknown size and rank.
            def replica_squared_fn(dtype=dtype, shape=shape):
              # Lists with different lengths on different replicas.
              replica_id = _replica_id_as_int()
              tensor = math_ops.cast([replica_id] * (replica_id + 1), dtype)
              # Erase shape information
              return array_ops.placeholder_with_default(tensor, shape=shape)

            self.reduce_axis_helper(distribution, replica_squared_fn)
    finally:
      self.set_v2_tensorshape(original_v2)

  def testReplicateDataset(self, distribution):
    if tf2.enabled() and not context.executing_eagerly():
      self.skipTest("Skipping test since we do not support graph mode in TF 2")

    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
                                 expected_values)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    def fn():
      dataset = dataset_ops.Dataset.range(2).interleave(
          (lambda _: dataset_ops.Dataset.range(10)), cycle_length=2)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    expected_values = [[i, i] for i in range(0, 10)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
                                 expected_values, test_reinitialize=False,
                                 ignore_order=True)

  def testNumpyDataset(self, distribution):
    self._test_numpy_dataset(distribution)

  def testGlobalStepUpdate(self, distribution):
    self._test_global_step_update(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)

  def testSummaryForReplicaZeroOnly(self, distribution):
    self._test_summary_for_replica_zero_only(distribution)

  def testTrainableVariables(self, distribution):
    self._test_trainable_variable(distribution)

  def test_prefetch_to_device_dataset(self, distribution):
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      item = next(iter(dataset))
    else:
      if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
        item = dataset.make_initializable_iterator().get_next()
      else:
        self.skipTest("unsupported test combination")
    device_types = [
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values]
    expected_device_types = [
        tf_device.DeviceSpec.from_string(device).device_type for
        device in distribution.extended.worker_devices]
    self.assertAllEqual(device_types, expected_device_types)

  def test_prefetch_to_host_dataset(self, distribution):
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if context.executing_eagerly():
      item = next(iter(dataset))
    else:
      if isinstance(dataset, input_lib_v1.DistributedDatasetV1):
        item = dataset.make_initializable_iterator().get_next()
      else:
        self.skipTest("unsupported test combination")
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ["CPU"])


def one_device_combinations():
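  """Returns combinations for a single-device (CPU or GPU) MirroredStrategy."""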
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_one_cpu,
          strategy_combinations.mirrored_strategy_with_one_gpu,
      ],
      mode=["graph", "eager"])


@combinations.generate(one_device_combinations())
class MirroredOneDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.OneDeviceDistributionTestBase,
    parameterized.TestCase):

  def testMinimizeLoss(self, distribution):
    if context.executing_eagerly():
      self._test_minimize_loss_eager(distribution)
    else:
      self._test_minimize_loss_graph(distribution)

  def testReplicaId(self, distribution):
    self._test_replica_id(distribution)

  def testCallAndMergeExceptions(self, distribution):
    self._test_call_and_merge_exceptions(distribution)

  def testRun(self, distribution):
    self._test_run(distribution)

  def testAllReduceSum(self, distribution):
    self._test_all_reduce_sum(distribution)

  def testAllReduceSumGradients(self, distribution):
    self._test_all_reduce_sum_gradients(distribution)

  def testAllReduceSumGradientTape(self, distribution):
    self._test_all_reduce_sum_gradient_tape(distribution)

  def testAllReduceMean(self, distribution):
    self._test_all_reduce_mean(distribution)

  def testAllReduceMeanGradients(self, distribution):
    self._test_all_reduce_mean_gradients(distribution)

  def testAllReduceMeanGradientTape(self, distribution):
    self._test_all_reduce_mean_gradient_tape(distribution)


class MirroredStrategyVariableCreatorStackTest(
    test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          mode=["graph"]))
  def testCreatorStacksAreThreadLocal(self, distribution):
    def model_fn():
      replica_id_str = str(self.evaluate(_replica_id()))

      def thread_creator_fn(next_creator, **kwargs):
        return next_creator(**kwargs) + ":thread_" + replica_id_str

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    def main_thread_creator(next_creator, **kwargs):
      # We are not using the underlying next_creator for test purposes.
      del next_creator, kwargs
      return "main_thread"

    with context.graph_mode(), \
        distribution.scope(), \
        variable_scope.variable_creator_scope(main_thread_creator):
      result = distribution.extended.call_for_each_replica(model_fn)
      result = distribution.experimental_local_results(result)
      expected = ("main_thread:thread_0", "main_thread:thread_1")
      self.assertEqual(expected, result)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredStrategyCallForEachReplicaTest(test.TestCase):

  def testExecutingEagerlyOutsideFunction(self, distribution):
    """Verify we preserve the value of executing_eagerly_outside_functions()."""
    def model_fn():
      return ops.executing_eagerly_outside_functions()

    originally = ops.executing_eagerly_outside_functions()
    with distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)

    # Verify this all again, but this time in a FuncGraph.
    with func_graph.FuncGraph("fg").as_default(), distribution.scope():
      in_scope = ops.executing_eagerly_outside_functions()
      in_model_fn = distribution.extended.call_for_each_replica(model_fn)
      unwrapped = distribution.experimental_local_results(in_model_fn)
      self.assertEqual(in_scope, unwrapped[0])
      self.assertEqual(in_scope, originally)

  def testFunctionInCallForEachReplica(self, distribution):
    traces = []
    @def_function.function
    def model_fn():
      traces.append(1)
      return ds_context.get_replica_context().replica_id_in_sync_group

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(
          (0, 1),
          self.evaluate(distribution.experimental_local_results(result)))
      self.assertLen(traces, distribution.num_replicas_in_sync)

  def testFunctionInCallForEachReplicaInsideAnotherFunction(self, distribution):
    traces = []
    @def_function.function
    def model_fn():
      traces.append(1)
      return ds_context.get_replica_context().replica_id_in_sync_group

    @def_function.function
    def step():
      return distribution.extended.call_for_each_replica(model_fn)

    with distribution.scope():
      result = step()
      self.assertEqual(
          (0, 1),
          self.evaluate(distribution.experimental_local_results(result)))
      self.assertLen(traces, distribution.num_replicas_in_sync)

  def testControlFlowFunctionInCallForEachReplicaWithMergeCall(
      self, distribution):

    def merge_fn(strategy, value):
      return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)

    @def_function.function
    def model_fn():

      def body_fn(i):
        return ds_context.get_replica_context().merge_call(merge_fn, args=(i,))

      return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0])

    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError, "`merge_call` called while defining a new graph."):
        distribution.extended.call_for_each_replica(model_fn)

  def testNestedFunctionInCallForEachReplicaWithMergeCall(self, distribution):

    def merge_fn(strategy, value):
      return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)

    def model_fn():

      @def_function.function
      def model_fn_nested():
        t = constant_op.constant(1)
        return ds_context.get_replica_context().merge_call(merge_fn, args=(t,))

      return model_fn_nested()

    with distribution.scope():
      with self.assertRaisesRegex(
          RuntimeError, "`merge_call` called while defining a new graph."):
        distribution.extended.call_for_each_replica(model_fn)

  def testFunctionInCallForEachReplicaWithMergeCall(self, distribution):
    def merge_fn(_):
      pass

    @def_function.function
    def model_fn():
      ds_context.get_replica_context().merge_call(merge_fn)
      return 0.

    with distribution.scope():
      self.assertEqual(
          self.evaluate(distribution.extended.call_for_each_replica(model_fn)),
          0.)

  def testFunctionInCallForEachReplicaCached(self, distribution):
    traces = []

    @def_function.function
    def model_fn():
      traces.append(None)

    self.assertEmpty(traces)

    for i in range(10):
      distribution.extended.call_for_each_replica(model_fn)

      if i == 0:
        num_devices = len(traces)
        self.assertGreater(num_devices, 0)
      else:
        # model_fn should not have been re-evaluated so the length should remain
        # the same.
        self.assertLen(traces, num_devices)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph"]))
class MirroredStrategyNameScopeTest(test.TestCase):
  # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
  # testing this in eager mode.

  def testNameScope(self, distribution):
    def model_fn():
      with ops.name_scope("foo"):
        a = constant_op.constant(1.0, name="a")
        ds_context.get_replica_context().merge_call(lambda _: _)
        b = constant_op.constant(1.0, name="b")
      return a, b

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        result = distribution.extended.call_for_each_replica(model_fn)
        self.assertEqual(2, len(result))
        for v, name in zip(result, ["a", "b"]):
          self.assertIsInstance(v, values.DistributedValues)
          v0, v1 = distribution.experimental_local_results(v)
          self.assertEqual("main/foo/" + name + ":0", v0.name)
          self.assertEqual("main/replica_1/foo/" + name + ":0", v1.name)

  def testWithDefaultName(self, distribution):
    def model_fn():
      with ops.name_scope(None, "foo"):
        a = constant_op.constant(1.0, name="a")
        ds_context.get_replica_context().merge_call(lambda _: _)
        b = constant_op.constant(2.0, name="b")
      return a, b

    with context.graph_mode(), distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertEqual(2, len(result))
      for v, name in zip(result, ["a", "b"]):
        self.assertIsInstance(v, values.DistributedValues)
        v0, v1 = distribution.experimental_local_results(v)
        self.assertEqual("foo/" + name + ":0", v0.name)
        self.assertEqual("replica_1/foo/" + name + ":0", v1.name)

  # variable_scope.variable() respects name scopes when creating
  # variables. On the other hand variable_scope.get_variable() ignores name
  # scopes but respects variable scope when creating variables. We test both
  # methods of creating variables to make sure that we have the same
  # variable names in both cases.
  def testNameScopeWithVariable(self, distribution):
    def in_cross_replica(_):
      c = variable_scope.variable(1.0, name="c")
      return c

    def model_fn():
      b = variable_scope.variable(1.0, name="b")
      with ops.name_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        a = variable_scope.variable(1.0, name="a")
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("main/a:0", a0.name)
      self.assertEqual("main/a/replica_1:0", a1.name)
      self.assertEqual("main/b:0", b0.name)
      self.assertEqual("main/b/replica_1:0", b1.name)
      self.assertEqual("main/foo/c:0", c0.name)
      self.assertEqual("main/foo/c/replica_1:0", c1.name)

  def testNameScopeWithGetVariable(self, distribution):
    def in_cross_replica(_):
      c = variable_scope.get_variable("c", [1])
      return c

    def model_fn():
      b = variable_scope.get_variable("b", [1])
      with ops.name_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with ops.name_scope("main"):
        a = variable_scope.get_variable("a", [1])
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("a:0", a0.name)
      self.assertEqual("a/replica_1:0", a1.name)
      self.assertEqual("b:0", b0.name)
      self.assertEqual("b/replica_1:0", b1.name)
      self.assertEqual("c:0", c0.name)
      self.assertEqual("c/replica_1:0", c1.name)

  def testVariableScopeWithGetVariable(self, distribution):

    def in_cross_replica(_):
      c = variable_scope.get_variable("c", [1])
      return c

    def model_fn():
      b = variable_scope.get_variable("b", [1])
      with variable_scope.variable_scope("foo"):
        c = ds_context.get_replica_context().merge_call(in_cross_replica)
      return b, c

    with context.graph_mode(), distribution.scope():
      with variable_scope.variable_scope("main"):
        a = variable_scope.get_variable("a", [1])
        result = distribution.extended.call_for_each_replica(model_fn)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = distribution.experimental_local_results(a)
      b0, b1 = distribution.experimental_local_results(result_b)
      c0, c1 = distribution.experimental_local_results(result_c)
      self.assertEqual("main/a:0", a0.name)
      self.assertEqual("main/a/replica_1:0", a1.name)
      self.assertEqual("main/b:0", b0.name)
      self.assertEqual("main/b/replica_1:0", b1.name)
      self.assertEqual("main/foo/c:0", c0.name)
      self.assertEqual("main/foo/c/replica_1:0", c1.name)


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored3Devices",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]),
                required_gpus=2)
        ],
        mode=["graph", "eager"]))
class MirroredThreeDeviceDistributionTest(
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):

  def testThreeDevices(self, distribution):
    def model_fn():
      v = variable_scope.variable(1.0, name="foo")
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v

    with distribution.scope():
      result = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_mirrored(result))
      self.assertEqual("foo:0", result.name)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredVariableUpdateTest(test.TestCase):
  # The following tests check assign, assign_add and assign_sub on Mirrored
  # variables in replica and cross replica context.

  def testAssignMirroredVarReplicaContextWithoutAggregationType(self,
                                                                distribution):
    def var_fn():
      v = variable_scope.variable(1.0, name="foo")
      return v

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        return mirrored_var.assign(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(5.0, self.evaluate(mirrored_var))

  def testAssignMirroredVarReplicaContextWithSum(self, distribution):
    # Test that we don't reduce a non-per-replica value with the "sum"
    # aggregation type.
    def var_fn():
      v = variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
      return v

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())

      def model_fn():
        return mirrored_var.assign(5.0)

      if distribution.extended._use_merge_call():
        with self.assertRaisesRegex(
            ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
            "with the given reduce op ReduceOp.SUM."):
          self.evaluate(distribution.experimental_local_results(
              distribution.extended.call_for_each_replica(model_fn)))
      else:
        result = self.evaluate(
            distribution.experimental_local_results(
                distribution.extended.call_for_each_replica(model_fn)))
        self.assertAllEqual(result[0], 5.0)

  def testAssignMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))
      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
      self.assertEqual(6.0, mirrored_var_result)

  def testAssignMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(0.5, self.evaluate(mirrored_var))

  def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(5.0, self.evaluate(mirrored_var))

  def testAssignAddMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      # read_value == True
      mirrored_var_result = self.evaluate(
          mirrored_var.assign_add(6.0, read_value=True))
      self.assertEqual(7.0, mirrored_var_result)
      self.assertEqual(
          7.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          7.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

      # read_value == False
      self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
      self.assertEqual(
          9.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          9.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

  def testAssignAddMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign_add(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(1.5, self.evaluate(mirrored_var))

  def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign_add(5.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(6.0, self.evaluate(mirrored_var))

  def testAssignSubMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(5.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))
      mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
      self.assertEqual(3.0, mirrored_var_result)
      self.assertEqual(
          3.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[0]))
      self.assertEqual(
          3.0,
          self.evaluate(
              distribution.experimental_local_results(mirrored_var)[1]))
      self.assertEqual(
          distribution.extended.worker_devices[0], mirrored_var._devices[0])
      self.assertEqual(
          distribution.extended.worker_devices[1], mirrored_var._devices[1])

  def testAssignSubMirroredVarReplicaContext(self, distribution):
    def var_fn():
      return variable_scope.variable(
          5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))

      def model_fn():
        value = math_ops.cast(
            ds_context.get_replica_context().replica_id_in_sync_group,
            mirrored_var.dtype)
        return mirrored_var.assign_sub(value)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(4.5, self.evaluate(mirrored_var))

  def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution):
    def var_fn():
      return variable_scope.variable(
          5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(5.0, self.evaluate(mirrored_var))

      def model_fn():
        return mirrored_var.assign_sub(1.0)

      self.evaluate(distribution.experimental_local_results(
          distribution.extended.call_for_each_replica(model_fn)))
      self.assertEqual(4.0, self.evaluate(mirrored_var))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredAndSyncOnReadVariableInitializerTest(test.TestCase):

  def testAssignMirroredVarInitializer(self, distribution):
    # This test is not eager compatible since, in eager mode, variables are
    # initialized on construction instead of when the initialization op runs.
    with context.graph_mode():
      def var_fn():
        v = variable_scope.variable(1.0, name="foo")
        return v

      with distribution.scope():
        mirrored_var = distribution.extended.call_for_each_replica(var_fn)
        self.assertTrue(distribute_utils.is_mirrored(mirrored_var))
        self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
        self.evaluate(mirrored_var.initializer)
        self.assertTrue(self.evaluate(mirrored_var.is_initialized()))

  def testAssignReplicaLocalVarInitializer(self, distribution):
    # This test is not eager compatible since, in eager mode, variables are
    # initialized on construction instead of when the initialization op runs.
    with context.graph_mode():
      def model_fn():
        v_sum = variable_scope.variable(
            1.0,
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertTrue(distribute_utils.is_sync_on_read(v_sum))
        return v_sum

      with distribution.scope():
        sync_on_read_var = distribution.extended.call_for_each_replica(
            model_fn)
        self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
        self.assertFalse(self.evaluate(sync_on_read_var.is_initialized()))
        self.evaluate(sync_on_read_var.initializer)
        self.assertTrue(self.evaluate(sync_on_read_var.is_initialized()))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class SyncOnReadVariableAssignTest(test.TestCase):

  def testAssignReplicaLocalVarSumAggregation(self, distribution):
    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.SUM)
      return v_sum

    with distribution.scope():
      sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
      self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the SUM of the
      # values on all the replicas.
      self.assertEqual(2.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))
      # Assigning 6.0 in cross replica context will assign a value of
      # 6.0/num_replicas to each replica.
      tlv_ops = sync_on_read_var.assign(6.0)
      self.evaluate(tlv_ops)
      # On reading the sync on read var we should get the assigned value back.
      # The values on all the replicas are added before being returned by
      # `read_var`.
      self.assertEqual(6.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))

  def testAssignReplicaLocalVarMeanAggregation(self, distribution):
    def model_fn():
      v_sum = variable_scope.variable(
          1.0,
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          aggregation=variable_scope.VariableAggregation.MEAN)
      return v_sum

    with distribution.scope():
      sync_on_read_var = distribution.extended.call_for_each_replica(model_fn)
      self.assertTrue(distribute_utils.is_sync_on_read(sync_on_read_var))
      self.evaluate(variables.global_variables_initializer())
      # Each replica has a value of 1.0 assigned to it in replica context.
      # When we read the value using `read_var` we should see the MEAN of values
      # on all replicas which is the value assigned in replica context.
      self.assertEqual(1.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))
      tlv_ops = sync_on_read_var.assign(6.0)
      self.evaluate(tlv_ops)
      # On reading the sync on read var we should get the MEAN of all values
      # which is equal to the value assigned.
      self.assertEqual(6.0, self.evaluate(
          distribution.extended.read_var(sync_on_read_var)))


class MockModel(object):
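  """Trivial model with one or two variables, used by the defun tests below."""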

  def __init__(self, two_variables=False):
    self.variables = []
    self.variables.append(variable_scope.variable(1.25, name="dummy_var1"))
    if two_variables:
      self.variables.append(variable_scope.variable(2.0, name="dummy_var2"))

  def __call__(self, factor=2):
    x = factor * self.variables[0]
    if len(self.variables) > 1:
      x += self.variables[1]
    return x


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["graph", "eager"]))
class MirroredStrategyDefunTest(test.TestCase):

  def _call_and_check(self, distribution, model_fn, inputs, expected_result,
                      defuns, two_variables=False):
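    """Runs `model_fn` per replica and checks the per-device results."""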
    cpu_dev = device_util.canonicalize("CPU:0")
    gpu_dev = device_util.canonicalize("GPU:0")
    devices = [cpu_dev, gpu_dev]

    with distribution.scope():
      mock_model = MockModel(two_variables)
      self.evaluate(variables.global_variables_initializer())

      result = distribution.extended.call_for_each_replica(
          model_fn, args=[mock_model] + inputs)
      for r in range(len(devices)):
        device_result = distribute_utils.select_replica(r, result)
        device_expected_result = distribute_utils.select_replica(
            r, expected_result)
        self.assertAllClose(device_expected_result,
                            self.evaluate(device_result))

      for defun in defuns:
        # `Function`s are specialized to the current device stack, so
        # call_for_each has one trace per device. To check that the expected set
        # of variables was accessed on each trace, we first retrieve each
        # device-specific graph function.
        per_replica_graph_functions = (
            distribution.extended.call_for_each_replica(
                defun.get_concrete_function, args=[mock_model] + inputs))
        for i in range(len(devices)):
          graph_function = distribution.experimental_local_results(
              per_replica_graph_functions)[i]
          # TODO(b/129555712): re-enable an assertion here that the two sets of
          # variables are the same.
          # self.assertEqual(set(graph_function.graph.variables),
          #  set(mock_model.variables))
          del graph_function

  def testVariableInDefun(self, distribution):
    @function.defun
    def times_two(mock_model):
      return mock_model()

    def model_fn(mock_model):
      return times_two(mock_model)

    self._call_and_check(distribution, model_fn, [], 2.5, [times_two])

  def testVariableInNestedDefun(self, distribution):
    @function.defun
    def times_two(mock_model):
      return mock_model()

    @function.defun
    def two_x_plus_one(mock_model):
      return times_two(mock_model) + 1

    def model_fn(mock_model):
      return two_x_plus_one(mock_model)

    self._call_and_check(distribution, model_fn, [], 3.5,
                         [times_two, two_x_plus_one])

  def testTwoVariablesInNestedDefun(self, distribution):
    @function.defun
    def fn1(mock_model):
      return mock_model()

    @function.defun
    def fn2(mock_model):
      return fn1(mock_model) + 1

    def model_fn(mock_model):
      return fn2(mock_model)

    self._call_and_check(distribution, model_fn, [], 5.5, [fn1, fn2],
                         two_variables=True)

  def testGradientTapeOverNestedDefuns(self, distribution):
    @function.defun
    def fn1(mock_model):
      return mock_model()

    @function.defun
    def fn2(mock_model):
      return fn1(mock_model) + 1

    def model_fn(mock_model):
      with backprop.GradientTape(persistent=True) as gtape:
        result = fn2(mock_model)
      grads = gtape.gradient(result,
                             [v._get() for v in mock_model.variables])
      return grads

    self._call_and_check(distribution, model_fn, [], [2.0, 1.0], [fn1, fn2],
                         two_variables=True)

  def testPassPerReplica(self, distribution):
    @function.defun
    def fn1(mock_model, factor):
      return mock_model(factor)

    factors = values.PerReplica((5.0, 3.0))
    expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
    self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    devices=mirrored_strategy.all_local_devices(),
                    cross_device_ops=cross_device_ops_lib.ReductionToOneDevice(
                    ),
                ),
                required_gpus=1)
        ],
        mode=["graph"]))
class MultiWorkerMirroredStrategyTest(
    multi_worker_test_base.MultiWorkerTestBase,
    strategy_test_lib.DistributionTestBase):

  def _configure_distribution_strategy(self, distribution):
    cluster_spec = server_lib.ClusterSpec({
        "worker": ["/job:worker/task:0", "/job:worker/task:1"]
    })
    distribution.configure(cluster_spec=cluster_spec)

  def test_num_replicas_in_sync(self, distribution):
    self._configure_distribution_strategy(distribution)
    # The expected count is the total number of GPUs across the two workers
    # specified in the cluster spec.
    self.assertEqual(context.num_gpus() * 2, distribution.num_replicas_in_sync)

  def testMinimizeLossGraph(self, distribution):
    self._configure_distribution_strategy(distribution)
    self._test_minimize_loss_graph(distribution, learning_rate=0.05)

  def testDeviceScope(self, distribution):
    """Test the device scope of multi-worker MirroredStrategy."""
    self._configure_distribution_strategy(distribution)
    with distribution.scope():
      a = constant_op.constant(1.)
      with ops.device("/cpu:0"):
        b = constant_op.constant(1.)
      self.assertEqual(a.device, "/job:worker/task:0")
      self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")

  def testMakeInputFnIteratorWithDataset(self, distribution):
    self._configure_distribution_strategy(distribution)
    dataset_fn = lambda: dataset_ops.Dataset.range(100)
    num_gpus = context.num_gpus()
    num_workers = 2

    expected_values = [[i+j for j in range(num_gpus)] * num_workers
                       for i in range(0, 100, num_gpus)]

    with context.graph_mode(), self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          dataset_fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    self._configure_distribution_strategy(distribution)
    def fn():
      dataset = dataset_ops.Dataset.range(100)
      it = dataset_ops.make_one_shot_iterator(dataset)
      return it.get_next
    num_gpus = context.num_gpus()
    num_workers = 2

    expected_values = []
    for i in range(0, 100, num_gpus):
      expected_values.append([i+j for j in range(num_gpus)] * num_workers)

    with context.graph_mode(), self.cached_session() as sess:
      # `expected_input_pipeline_id` is None because the input_fn will be called
      # multiple times, each with a different input_pipeline_id.
      input_fn = self._input_fn_to_test_input_context(
          fn,
          expected_num_replicas_in_sync=num_workers*num_gpus,
          expected_num_input_pipelines=num_workers,
          expected_input_pipeline_id=None)
      iterator = distribution.make_input_fn_iterator(input_fn)
      self._test_input_fn_iterator(
          iterator, distribution.extended.worker_devices, expected_values, sess,
          test_reinitialize=False, ignore_order=True)

  def testUpdateConfigProto(self, distribution):
    distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]})

    config_proto = config_pb2.ConfigProto()
    new_config = distribution.update_config_proto(config_proto)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)


@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored",
                # pylint: disable=g-long-lambda
                lambda: mirrored_strategy.MirroredStrategy(
                    devices=["/job:worker/task:0/gpu:{}".format(
                        i) for i in range(context.num_gpus())]),
                required_gpus=1)
        ],
        mode=["graph"]))
class RemoteSingleWorkerMirroredStrategyGraph(
    multi_worker_test_base.SingleWorkerTestBaseGraph,
    strategy_test_lib.RemoteSingleWorkerMirroredStrategyBase):

  def _get_num_gpus(self):
    return context.num_gpus()

  def testNumReplicasInSync(self, distribution):
    self._testNumReplicasInSync(distribution)

  def testMinimizeLoss(self, distribution):
    self._testMinimizeLoss(distribution)

  def testDeviceScope(self, distribution):
    self._testDeviceScope(distribution)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    self._testMakeInputFnIteratorWithDataset(distribution)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    self._testMakeInputFnIteratorWithCallable(distribution)


class MultiWorkerMirroredStrategyTestWithChief(
    multi_worker_test_base.MultiWorkerTestBase,
    strategy_test_lib.DistributionTestBase):

  @classmethod
  def setUpClass(cls):
    """Create a local cluster with 2 workers and 1 chief."""
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=0, has_chief=True)
    cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]

  def _make_cross_device_ops(self):
    return cross_device_ops_lib.ReductionToOneDevice()

  def testMinimizeLossGraph(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testMinimizeLossGraphMirroredStrategy(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          mirrored_strategy.all_local_devices(),
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testMinimizeLossGraphMirroredStrategyWithOneNode(self):
    with context.graph_mode():
      cluster_spec = {}
      cluster_spec["chief"] = self._cluster_spec["chief"]
      tf_config = {"cluster": cluster_spec}
      with test.mock.patch.dict("os.environ",
                                {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = mirrored_strategy.MirroredStrategy()
        if context.num_gpus() == 0:
          self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
                                cross_device_ops_lib.ReductionToOneDevice)
      self.skipTest("b/130551176, run the following once fixed.")
      self._test_minimize_loss_graph(strategy, learning_rate=0.05)

  def testInitializeFromTFConfig(self):
    with context.graph_mode():
      tf_config = {"cluster": self._cluster_spec}
      with test.mock.patch.dict("os.environ",
                                {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = mirrored_strategy.MirroredStrategy(
            cross_device_ops=self._make_cross_device_ops())
        self.assertEqual(
            max(context.num_gpus(), 1) * 3, strategy.num_replicas_in_sync)

  def testSummaryForReplicaZeroOnly(self):
    with context.graph_mode():
      strategy = mirrored_strategy.MirroredStrategy(
          mirrored_strategy.all_local_devices(),
          cross_device_ops=self._make_cross_device_ops())
      strategy.configure(cluster_spec=self._cluster_spec)
      self._test_summary_for_replica_zero_only(strategy)


class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_one_gpu,
          ],
          mode=["graph"]))
  def testMirroredVariableAsStopGradient(self, distribution):
    with distribution.scope():
      inp = constant_op.constant(1.0)
      x = variables.Variable(1.0)
      y = inp*x
      grads = gradients.gradients(x, y, stop_gradients=x)
      self.assertIsNone(grads[0])


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
        ],
        mode=["eager"]))
class FunctionTest(test.TestCase, parameterized.TestCase):

  def testBackwardFunctionDevicePlacement(self, distribution):
    with distribution.scope():
      w = variable_scope.variable([1.5], name="w")
      b = variable_scope.variable([0.5], name="b")

    @def_function.function
    def forward(x, w, b):
      return x * w + b

    x = array_ops.identity([1.0], name="x_useless")
    concrete_forward = forward.get_concrete_function(x, w._primary, b._primary)

    with distribution.scope():

      def replica_fn():
        with backprop.GradientTape() as t:
          x = array_ops.identity([1.0], name="x")
          loss = concrete_forward(x, w._get(), b._get()) - [1.0]
          return t.gradient(loss, [w, b])

      def step_fn():
        return distribution.run(replica_fn)

      context.enable_run_metadata()
      g1, g2 = step_fn()
      run_metadata = context.export_run_metadata()
      context.disable_run_metadata()
      self.assertEqual(self.evaluate(g1._primary), 1.0)
      self.assertEqual(self.evaluate(g2._primary), 1.0)

      # Verify that this node runs on both devices.
      node_name = "gradients_mul_grad_mul_1_x"
      devices_for_this_node = set()
      for partition_graph in run_metadata.partition_graphs:
        for node in partition_graph.node:
          if node.name == node_name:
            devices_for_this_node.add(node.device)
      devices = [device_util.resolve("/device:GPU:0"),
                 device_util.resolve("/device:CPU:0")]
      self.assertSetEqual(devices_for_this_node, set(devices))

  def testFunctionPreservesAutoGraph(self, distribution):
    def f():
      self.assertTrue(converter_testing.is_inside_generated_code())
      return 1

    with distribution.scope():

      @def_function.function
      def replica_fn():
        return f()

      distribution.run(replica_fn)

  def testPreserveTracebackFiltering(self, distribution):
    traceback_utils.disable_traceback_filtering()
    self.assertFalse(traceback_utils.is_traceback_filtering_enabled())

    def f():
      self.assertFalse(traceback_utils.is_traceback_filtering_enabled())

    distribution.run(f)


def _replica_id():
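  """Returns the current replica id as a tensor."""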
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if not isinstance(replica_id, ops.Tensor):
    replica_id = constant_op.constant(replica_id)
  return array_ops.identity(replica_id)


def _replica_id_as_int():
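  """Returns the current replica id as a Python integer."""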
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id


if __name__ == "__main__":
  # TODO(b/172304955)
  test_util.main(config_logical_devices=False)