# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for common methods in strategy classes."""

from absl.testing import parameterized

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceStrategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class StrategyTest(test.TestCase, parameterized.TestCase):
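  """Tests `strategy.run` with tf.function and `merge_call` in init scope."""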

  def testCaptureReplicaId(self, strategy):
    m = {}

    @def_function.function
    def f():
      return ds_context.get_replica_context().replica_id_in_sync_group

    @def_function.function
    def g():
      # Make g() a stateful function so it's traced twice.
      if m.get('v', None) is None:
        m['v'] = variables.Variable(0.)
      return strategy.run(f)

    g()

  def testMergeCallInitScope(self, strategy):
    with strategy.scope():

      @def_function.function
      def fn():

        def merge_fn(unused_strat):

          y = constant_op.constant(11)
          return y

        def replica_fn():

          with ops.init_scope():
            y = ds_context.get_replica_context().merge_call(merge_fn)
            z = y + 1
            return z

        return strategy.run(replica_fn)

      result = strategy.experimental_local_results(fn())
      self.assertAllClose(result, [12] * _get_num_replicas_per_client(strategy))


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.tpu_strategy
        ],
        mode=['graph', 'eager']))
class StrategyLocalResultTest(test.TestCase):
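  """Tests `experimental_local_results` on structured return values."""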

  def testLocalResultForDictionary(self, distribution):

    @def_function.function
    def model_fn():
      return {'a': constant_op.constant(1.), 'b': constant_op.constant(2.)}

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForList(self, distribution):

    @def_function.function
    def model_fn():
      return [constant_op.constant(1.), constant_op.constant(2.)]

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ([1., 2.], [1., 2.]))

  def testLocalResultForTuple(self, distribution):

    @def_function.function
    def model_fn():
      return (constant_op.constant(1.), constant_op.constant(2.),
              constant_op.constant(3.))

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, ((1., 2., 3.), (1., 2., 3.)))

  def testLocalResultForNestedStruct(self, distribution):

    @def_function.function
    def model_fn():
      return ({
          'a': constant_op.constant(1.),
          'b': constant_op.constant(2.)
      }, {
          'a': constant_op.constant(4.),
          'b': constant_op.constant(6.)
      })

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertEqual(got, (({
          'a': 1.,
          'b': 2.
      }, {
          'a': 4.,
          'b': 6.
      }), ({
          'a': 1.,
          'b': 2.
      }, {
          'a': 4.,
          'b': 6.
      })))

  def testLocalResultForNestedStructWithoutTensor(self, distribution):

    @def_function.function
    def model_fn():
      return {'a': 1., 'b': 2.}

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertIsInstance(v, tuple)
      self.assertAllEqual(v, ({'a': 1., 'b': 2.}, {'a': 1., 'b': 2.}))

  def testLocalResultForScalarValue(self, distribution):

    @def_function.function
    def model_fn():
      return distribution.extended._get_local_replica_id(
          ds_context.get_replica_context().replica_id_in_sync_group)

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertIsInstance(v, tuple)
      self.assertEqual(v, (0, 1))

  def testLocalResultForDictionaryDifferentReplicas(self, distribution):

    @def_function.function
    def model_fn():
      replica_id = distribution.extended._get_local_replica_id(
          ds_context.get_replica_context().replica_id_in_sync_group)
      return {
          'a': math_ops.cast(replica_id + 1, dtype=float),
          'b': math_ops.cast(replica_id + 2, dtype=float)
      }

    with distribution.scope():
      result = distribution.run(model_fn)
      got = self.evaluate(distribution.experimental_local_results(result))
      self.assertAllEqual(got, ({'a': 1., 'b': 2.}, {'a': 2., 'b': 3.}))

  def testLocalResultForTensor(self, distribution):

    @def_function.function
    def model_fn():
      return constant_op.constant([2., 3.])

    with distribution.scope():
      result = distribution.run(model_fn)
      v = self.evaluate(distribution.experimental_local_results(result))
      self.assertAllEqual(v, ([2., 3.], [2., 3.]))


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ] + strategy_combinations.all_strategies,
        mode=['eager']))
class ReduceTest(test.TestCase, parameterized.TestCase):
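  """Tests cross-replica `strategy.reduce` with and without an axis."""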

  def testBasic(self, strategy):
    per_replica_value = strategy.experimental_distribute_values_from_function(
        lambda _: array_ops.ones((), dtypes.float32))

    def fn_eager():

      return strategy.reduce(
          reduce_util.ReduceOp.SUM, value=per_replica_value, axis=None)

    fn_graph = def_function.function(fn_eager)
    # Run reduce under the strategy scope to explicitly enter
    # strategy default_device scope.
    with strategy.scope():
      self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
      self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

    # Run reduce without a strategy scope to implicitly enter
    # strategy default_device scope.
    self.assertEqual(fn_eager().numpy(), 1.0 * strategy.num_replicas_in_sync)
    self.assertEqual(fn_graph().numpy(), 1.0 * strategy.num_replicas_in_sync)

  def testAxis(self, strategy):

    @def_function.function
    def fn():
      return constant_op.constant([1., 2.])

    x = strategy.run(fn)

    x_m = strategy.reduce(reduce_util.ReduceOp.MEAN, x, axis=0)
    self.assertEqual(1.5, x_m)
    x_s = strategy.reduce(reduce_util.ReduceOp.SUM, x, axis=0)
    self.assertEqual(3 * strategy.num_replicas_in_sync, x_s)


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.default_strategy,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        update_fn=['assign', 'assign_add', 'assign_sub'],
        tf_function=[True, False],
        mode=['eager']))
class ReplicaCtxUpdateTest(test.TestCase, parameterized.TestCase):
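  """Tests `ReplicaContext._update` with assign, assign_add and assign_sub."""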

  def testDenseUpdate(self, strategy, tf_function, update_fn):
    if strategy_test_lib.is_tpu_strategy(strategy) and (not tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    with strategy.scope():
      distributed_variable1 = variables.Variable(5.0)

    def replica_fn():
      value = array_ops.constant(2.)
      python_literal = 1.
      replica_context = ds_context.get_replica_context()
      fn_sets = {
          'assign': lambda var, value: var.assign(value),
          'assign_add': lambda var, value: var.assign_add(value),
          'assign_sub': lambda var, value: var.assign_sub(value),
      }
      replica_context._update(
          distributed_variable1, fn_sets[update_fn], args=(value,))
      replica_context._update(
          distributed_variable1, fn_sets[update_fn], args=(python_literal,))

    if tf_function:
      replica_fn = def_function.function(replica_fn)
    strategy.run(replica_fn)

    expected_result = {'assign': 1., 'assign_add': 8., 'assign_sub': 2.}
    self.assertAllEqual(
        strategy.experimental_local_results(distributed_variable1),
        [expected_result[update_fn]] * _get_num_replicas_per_client(strategy))


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
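  """Tests `_replica_ctx_all_reduce` on dense, sparse and nested values."""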

  def testDense(self, strategy, tf_function):
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = array_ops.identity(1.0)
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = indexed_slices.IndexedSlices(
        values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
        indices=array_ops.identity([0]),
        dense_shape=array_ops.identity([5, 1]))
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testNestedInput(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = (array_ops.identity(1.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[1.0]]),
                     indices=array_ops.identity([0]),
                     dense_shape=array_ops.identity([5, 1])),
                 array_ops.identity(2.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[2.0]]),
                     indices=array_ops.identity([1]),
                     dense_shape=array_ops.identity([5, 1])))
        reduced = strategy.extended._replica_ctx_all_reduce(
            reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = (1.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[1.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([0]),
                  dense_shape=array_ops.identity([5, 1])),
              2.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[2.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([1]),
                  dense_shape=array_ops.identity([5, 1])))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testSyncOnReadVariableInput(self, strategy, tf_function):
    if (not strategy_test_lib.is_mirrored_strategy(strategy) and
        not strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        not strategy_test_lib.is_tpu_strategy(strategy)):
      self.skipTest('Skip strategies not using SyncOnReadVariables.')
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')
    if (strategy_test_lib.is_multi_worker_mirrored_strategy(strategy) and
        tf_function is combinations.tf_function):
      self.skipTest('Skip MWMS + graph combination until b/228512201 is fixed.')

    with strategy.scope():
      var = variables.Variable(
          0.0,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @tf_function
    def replica_fn():
      replica_context = ds_context.get_replica_context()
      replica_id = replica_context.replica_id_in_sync_group
      var.assign(math_ops.cast(replica_id, dtype=float) * 3.0)

      return replica_context.all_reduce(reduce_util.ReduceOp.SUM, var)

    if strategy_test_lib.is_multi_worker_mirrored_strategy(strategy):
      client_local_replica_num = strategy.extended._num_devices_per_worker
    else:
      client_local_replica_num = strategy.num_replicas_in_sync

    workers_num = strategy.num_replicas_in_sync
    expected_sum = sum(range(workers_num)) * 3.0

    # Expand the values on each replica if multiple devices are used; otherwise
    # simply read the value of the Tensor.
    result = strategy.run(replica_fn)
    if hasattr(result, 'values'):
      result = result.values
    result = nest.flatten(result)

    # Iterate through all replicas and verify the reduce sum result.
    for i in range(client_local_replica_num):
      self.assertEqual(result[i].numpy(), expected_sum)


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
            strategy_combinations.tpu_strategy,
        ] + strategy_combinations.strategies_minus_tpu,
        tf_function=[combinations.tf_function, combinations.no_tf_function],
        mode=['eager']))
class AllReduceTest(test.TestCase, parameterized.TestCase):
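  """Tests `ReplicaContext.all_reduce` on dense, sparse and nested values."""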

  def testDense(self, strategy, tf_function):
    if (strategy_test_lib.is_tpu_strategy(strategy) and
        tf_function is combinations.no_tf_function):
      self.skipTest('Skip TPUStrategy + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = array_ops.identity(1.0)
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)

  def testSparse(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.MEAN, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]

    if not strategy_test_lib.is_tpu_strategy(strategy):
      self.assertIsInstance(got, indexed_slices.IndexedSlices)
    expect = indexed_slices.IndexedSlices(
        values=array_ops.identity([[1.0]]),
        indices=array_ops.identity([0]),
        dense_shape=array_ops.identity([5, 1]))
    self.assertAllEqual(
        ops.convert_to_tensor(got), ops.convert_to_tensor(expect))

  def testSparseTuple(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value1 = indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        value2 = indexed_slices.IndexedSlices(
            values=array_ops.identity([[2.0]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, [value1, value2])
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]

    if not strategy_test_lib.is_tpu_strategy(strategy):
      for g in got:
        self.assertIsInstance(g, indexed_slices.IndexedSlices)
    expect = [
        indexed_slices.IndexedSlices(
            values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1])),
        indexed_slices.IndexedSlices(
            values=array_ops.identity([[2.0 * strategy.num_replicas_in_sync]]),
            indices=array_ops.identity([0]),
            dense_shape=array_ops.identity([5, 1]))
    ]
    self.assertAllEqual(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))

  def testNestedInput(self, strategy, tf_function):
    if tf_function is combinations.no_tf_function:
      self.skipTest('Skip IndexedSlices + eager combination.')

    @tf_function
    def fn():

      def replica_fn():
        value = (array_ops.identity(1.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[1.0]]),
                     indices=array_ops.identity([0]),
                     dense_shape=array_ops.identity([5, 1])),
                 array_ops.identity(2.0),
                 indexed_slices.IndexedSlices(
                     values=array_ops.identity([[2.0]]),
                     indices=array_ops.identity([1]),
                     dense_shape=array_ops.identity([5, 1])))
        rep_ctx = ds_context.get_replica_context()
        reduced = rep_ctx.all_reduce(reduce_util.ReduceOp.SUM, value)
        return reduced

      return strategy.experimental_local_results(strategy.run(replica_fn))

    got = fn()[0]
    expect = (1.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[1.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([0]),
                  dense_shape=array_ops.identity([5, 1])),
              2.0 * strategy.num_replicas_in_sync,
              indexed_slices.IndexedSlices(
                  values=array_ops.identity(
                      [[2.0 * strategy.num_replicas_in_sync]]),
                  indices=array_ops.identity([1]),
                  dense_shape=array_ops.identity([5, 1])))

    self.assertAllClose(
        nest.map_structure(ops.convert_to_tensor, got),
        nest.map_structure(ops.convert_to_tensor, expect))


def _make_indexed_slices(values, indices, dense_shape):
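  """Builds an `IndexedSlices` from the given components."""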
  tensor = indexed_slices.IndexedSlices(
      values=constant_op.constant(values),
      indices=constant_op.constant(indices),
      dense_shape=constant_op.constant(dense_shape))
  return tensor


def _get_num_replicas_per_client(strategy):
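  """Returns the number of replicas local to the current client."""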
  if isinstance(strategy, CollectiveAllReduceStrategy):
    resolver = strategy.cluster_resolver
    return max(nest.flatten(resolver.num_accelerators())[0], 1)
  else:
    return strategy.num_replicas_in_sync


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
        ],
        mode=['eager']))
class DistributedCollectiveAllReduceStrategyTest(
    strategy_test_lib.DistributionTestBase,
    parameterized.TestCase):
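  """Tests input pipelines and reductions with MultiWorkerMirroredStrategy."""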

  def testDatasetFromFunction(self, strategy):
    def dataset_fn(input_context):
      global_batch_size = 10
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = dataset_ops.DatasetV2.range(100).repeat().batch(batch_size)
      return d.shard(input_context.num_input_pipelines,
                     input_context.input_pipeline_id)

    expected_sum_on_workers = {'chief': 10, 'worker': 35}
    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(iterator):
      return strategy.experimental_local_results(iterator.get_next())

    result = run(input_iterator)
    sum_value = math_ops.reduce_sum(result)
    self.assertEqual(
        sum_value.numpy(),
        expected_sum_on_workers[multi_worker_test_base.get_task_type()])

  def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
    global_batch_size = 8
    dataset = dataset_ops.DatasetV2.range(14).batch(
        global_batch_size, drop_remainder=False)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)

    # `result` is an incomplete batch
    result = run(input_iterator)
    expected_data_on_workers = {'chief': [8, 9, 10], 'worker': [11, 12, 13]}
    self.assertAllEqual(
        expected_data_on_workers[multi_worker_test_base.get_task_type()],
        result.numpy(),
    )

  def testSimpleInputFromFnLastPartialBatch(self, strategy):

    def dataset_fn(input_context):
      global_batch_size = 8
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.DatasetV2.range(14).batch(
          batch_size, drop_remainder=False)
      return dataset.shard(input_context.num_input_pipelines,
                           input_context.input_pipeline_id)

    input_iterator = iter(
        strategy.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(input_iterator):
      return strategy.run(lambda x: x, args=(next(input_iterator),))

    # Let the complete batch go.
    run(input_iterator)
    # `result` is an incomplete batch
    result = run(input_iterator)

    expected_data_on_worker = {'chief': [8, 9, 10, 11], 'worker': [12, 13]}
    self.assertAllEqual(
        expected_data_on_worker[multi_worker_test_base.get_task_type()],
        result.numpy())

  def testReduceHostTensor(self, strategy):
    reduced = strategy.reduce(
        reduce_util.ReduceOp.SUM, array_ops.identity(1.), axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToHostTensor(self, strategy):
    value = array_ops.identity(1.)
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToHostTensor(self, strategy):
    value = array_ops.identity(1.)
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  def testReduceDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)
    self.assertEqual(reduced.numpy(), 2.)

  def testReduceToDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.reduce_to(reduce_util.ReduceOp.SUM, value,
                                          value)
    self.assertEqual(reduced.numpy(), 2.)

  def testBatchReduceToDeviceTensors(self, strategy):
    value = strategy.run(lambda: array_ops.identity(1.))
    reduced = strategy.extended.batch_reduce_to(reduce_util.ReduceOp.SUM,
                                                [(value, value),
                                                 (value, value)])
    self.assertAllEqual([2., 2.], reduced)

  # TODO(crccw): add a test that mixes device and host tensors after multi
  # worker strategy combinations can run on a fixed number of GPUs.


class StrategyClusterResolverTest(test.TestCase, parameterized.TestCase):
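  """Tests the `cluster_resolver` property of strategies."""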

  @combinations.generate(
      combinations.combine(
          strategy=[strategy_combinations.multi_worker_mirrored_2x1_cpu] +
          strategy_combinations.all_strategies,
          mode=['eager']))
  def testClusterResolverProperty(self, strategy):
    # CollectiveAllReduceStrategy and TPUStrategy must have a cluster resolver.
    # `None` otherwise.
    resolver = strategy.cluster_resolver
    if (not isinstance(strategy, CollectiveAllReduceStrategy) and
        not strategy_test_lib.is_tpu_strategy(strategy)):
      self.assertIsNone(resolver)
      return

    with strategy.scope():
      self.assertIs(strategy.cluster_resolver, resolver)

    self.assertTrue(hasattr(resolver, 'cluster_spec'))
    self.assertTrue(hasattr(resolver, 'master'))
    self.assertTrue(hasattr(resolver, 'num_accelerators'))
    self.assertTrue(hasattr(resolver, 'task_id'))
    self.assertTrue(hasattr(resolver, 'task_type'))
    if isinstance(strategy, CollectiveAllReduceStrategy):
      self.assertEqual(resolver.task_id, 0)
      self.assertAllInSet(resolver.task_type, ['chief', 'worker'])


if __name__ == '__main__':
  test_util.main()