# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CrossDeviceOps."""

import collections
import os
import threading
import time

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
CommunicationImplementation = collective_util.CommunicationImplementation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
IndexedSlices = indexed_slices.IndexedSlices


def make_per_replica_value(value, devices):
  """Creates a `PerReplica` object whose values reside in `devices`.

  Args:
    value: a tensor-convertible value, an `IndexedSlicesValue`, or a callable
      that takes one argument (`device_idx`) and returns the value to be
      created on `devices[device_idx]`.
    devices: a list of device strings to create `PerReplica` values on.

  Returns:
    A `PerReplica` object.
  """
  values = []
  for device_idx, device in enumerate(devices):
    if callable(value):
      v = value(device_idx)
    elif isinstance(value, list):
      v = value[device_idx]
    else:
      v = value
    if isinstance(v, IndexedSlicesValue):
      with ops.device(device):
        values.append(
            IndexedSlices(
                values=array_ops.identity(v.values),
                indices=array_ops.identity(v.indices),
                dense_shape=array_ops.identity(v.dense_shape)))
    else:
      with ops.device(device):
        values.append(array_ops.identity(v))
  return value_lib.PerReplica(values)
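# A minimal usage sketch (hypothetical device strings, shown for illustration
# only): make_per_replica_value(1.0, ["/device:CPU:0", "/device:CPU:0"]) wraps
# two copies of the scalar, one placed on each listed device, in a PerReplica;
# passing a callable instead derives each replica's value from its device index.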


def enable_collective_ops():
  """Enable collectives in the current process."""
  cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
  context.context().configure_collective_ops(
      collective_leader="'/job:worker/replica:0/task:0'")
  config_proto = config_pb2.ConfigProto()
  config_proto.experimental.collective_group_leader = (
      "/job:worker/replica:0/task:0")
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
      default_session_config=config_proto,
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=cluster_resolver.rpc_layer)
  context.context().enable_collective_ops(server_def)
  # Recover default flag values.
  CollectiveReplicaLauncher._prefer_unique_instance_key = True
  CollectiveReplicaLauncher._prefer_ordering_token = False
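  # Note: _prefer_unique_instance_key and _prefer_ordering_token are class-level
  # flags on CollectiveReplicaLauncher; resetting them here keeps tests that
  # override them (e.g. testNcclOrdering below) from leaking state into later
  # tests, since setUp re-runs enable_collective_ops before each test.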


class MultiProcessPoolRunner():

  def __init__(self, num_processes):
    cluster_spec_dict = multi_worker_test_base.create_cluster_spec(
        num_workers=num_processes)
    self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict)


# Global MultiProcessPoolRunners that can be shared by test cases to avoid the
# expensive cost of initializing TensorFlow in new processes.
#
# Note that they have to be globals and can't be owned by test classes because
# the replica fn usually captures the test class instance, and the test class
# instance can't be pickled if it has an mpr as a member (pickling Process
# objects is not allowed).
# TODO(crccw): Use `num_workers` combination once it is ready.
global_mpr_2p = MultiProcessPoolRunner(num_processes=2)
global_mpr_1p = MultiProcessPoolRunner(num_processes=1)


def get_global_mpr(num_processes):
  if num_processes == 1:
    return global_mpr_1p.runner
  elif num_processes == 2:
    return global_mpr_2p.runner
  else:
    raise ValueError("get_global_mpr: num_processes must be 1 or 2, got %d" %
                     num_processes)
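# For example, get_global_mpr(2).run(fn) runs `fn` once in each of the two
# worker processes backed by global_mpr_2p; the test helpers below rely on this
# to execute their replica functions on every participating process.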


class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Enabling collectives can be done in "setUpClass", but requires using
    # different collective_keys in different tests as collectives are reused
    # across tests. Always resetting collective ops before each test offers
    # better test isolation.
    global_mpr_1p.runner.run(enable_collective_ops)
    global_mpr_2p.runner.run(enable_collective_ops)

  def make_collective(self, num_processes, gpu_per_process):
    """Returns collectives and other info to be used in tests.

    Args:
      num_processes: an integer indicating the number of processes that
        participate in the collective.
      gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

    Returns:
      A tuple of (collective, devices, pid) where collective is an instance
      of `CollectiveAllReduce`, devices is a list of local devices (str)
      attached to the current process, and pid is the id of this process among
      all participating processes.
    """

    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % cluster_resolver.task_id
    ]
    if gpu_per_process > 0:
      devices = [
          "/job:worker/replica:0/task:%d/device:GPU:%d" %
          (cluster_resolver.task_id, i) for i in range(gpu_per_process)
      ]
    group_size = num_processes * len(devices)
    collective = cross_device_ops_lib.CollectiveAllReduce(
        devices=devices,
        group_size=group_size,
        options=collective_util.Options())
    return collective, devices, cluster_resolver.task_id
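  # As a concrete example: with num_processes=2 and gpu_per_process=2, each
  # process contributes its two local GPUs, so group_size above is 4 and pid is
  # that process's task id within the two-worker cluster.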

  def as_list(self, value):
    """A utility to convert a `Mirrored`, `Tensor` or `IndexedSlices` to a list.

    The reason it exists is to provide a uniform view of the values returned by
    "reduce" calls, especially across tf.function boundaries. Returning
    `Mirrored` from a tf.function only evaluates the primary value, which
    causes the collective ops on non-primary devices to be pruned and
    eventually causes hanging.

    Args:
      value: the value to convert, can be one of `Mirrored`, `Tensor` and
        `IndexedSlices`.

    Returns:
      A list of `Tensor` or `IndexedSlices`.
    """
    if isinstance(value, ops.Tensor):
      return [value]
    elif isinstance(value, IndexedSlices):
      return [value]
    elif isinstance(value, value_lib.Mirrored):
      return value.values
    else:
      raise ValueError("unwrap: unsupported input type: %s" % type(value))

  RunOptions = collections.namedtuple(  # pylint: disable=invalid-name
      "RunOptions",
      [
          "mode",  # A list of str from ["eager", "func_graph"]
          "num_processes",
          "gpus_per_process",
          "reduce_op",
          "communication_options",
          "prefer_unique_instance_key",
      ])
  RunOptions.__new__.__defaults__ = (["eager",
                                      "func_graph"], 2, 0, ReduceOp.SUM,
                                     collective_util.Options(), True)
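  # For example, RunOptions() with the defaults above selects both "eager" and
  # "func_graph" modes, 2 processes, no GPUs, SUM reduction, default collective
  # options, and prefers unique instance keys.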

  def reduce_and_verify(self, inputs, expect, options):
    """Reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a list of `Tensor` or `IndexedSlices`, where the i-th value will
        be fed to the i-th replica.
      expect: a `Tensor` or `IndexedSlices`. This should be the expected value
        for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def reduce_fn():
        value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
        per_replica_value = make_per_replica_value(value_fn, devices)
        reduced_values = collective.reduce(options.reduce_op, per_replica_value,
                                           per_replica_value,
                                           options.communication_options)
        if options.gpus_per_process > 1:
          self.assertIsInstance(reduced_values, value_lib.Mirrored)
        reduced_values = self.as_list(reduced_values)
        self.assertAllEqual(devices, [v.device for v in reduced_values])
        return [ops.convert_to_tensor(v) for v in reduced_values]

      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if "eager" in options.mode:
        got = reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  def batch_reduce_and_verify(self, inputs, expect, options):
    """Batch reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a 2-level nested list of `Tensor` or `IndexedSlices`, where the
        i-th value will be fed to the i-th replica.
      expect: a list of `Tensor` or `IndexedSlices`. This should be the expected
        value for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def batch_reduce_fn():
        batch_size = len(inputs[0])
        value_dst_pairs = []
        for i in range(batch_size):

          def value_fn(device_idx, idx=i):
            return inputs[pid * len(devices) + device_idx][idx]

          per_replica_value = make_per_replica_value(value_fn, devices)
          value_dst_pairs.append((per_replica_value, per_replica_value))
        reduced_values = collective.batch_reduce(options.reduce_op,
                                                 value_dst_pairs,
                                                 options.communication_options)
        if options.gpus_per_process > 1:
          for v in reduced_values:
            self.assertIsInstance(v, value_lib.Mirrored)
        reduced_values = [self.as_list(v) for v in reduced_values]
        for v in reduced_values:
          self.assertAllEqual(devices, [t.device for t in v])
        return nest.map_structure(ops.convert_to_tensor, reduced_values)

      per_replica_expect = nest.map_structure(
          lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

      if "eager" in options.mode:
        got = batch_reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(batch_reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testReduceDense(self, num_processes, required_gpus, implementation,
                      reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")
    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [1.0, 2.0, 3.0, 4.0]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = 1.0
    if group_size == 2:
      expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
    elif group_size == 4:
      expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testReduceSparse(self, num_processes, required_gpus, implementation,
                       reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")
    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[7, 8], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[3, 2], dense_shape=[10, 1]),
    ]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = IndexedSlices(
          values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1])
    elif group_size == 2:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.]],
          indices=[0, 1, 1, 2],
          dense_shape=[10, 1])
    elif group_size == 4:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]],
          indices=[0, 1, 1, 2, 7, 8, 3, 2],
          dense_shape=[10, 1])

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(prefer_unique_instance_key=[True, False]))
  def testReduceSparseVariableLength(self, prefer_unique_instance_key):
    # One device per process, 2 processes, 2 replicas in total.
    inputs = [
        IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[2.], [3.], [4.]], indices=[0, 1, 2], dense_shape=[10, 1]),
    ]
    expect = IndexedSlices(
        values=[[1.], [2.], [3.], [4.]],
        indices=[0, 0, 1, 2],
        dense_shape=[10, 1])
    self.reduce_and_verify(
        inputs,
        expect,
        self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=2,
            reduce_op=ReduceOp.SUM,
            prefer_unique_instance_key=prefer_unique_instance_key))

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testBatchReduceDense(self, num_processes, required_gpus, implementation,
                           reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [1.0, 2.0]
    if group_size == 2:
      expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
    elif group_size == 4:
      expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

    self.batch_reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testBatchReduceSparse(self, num_processes, required_gpus, implementation,
                            reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = ([
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[9.], [10.]], indices=[3, 4], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[11.], [12.]], indices=[3, 4], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[13.], [14.]], indices=[8, 9], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[15.], [16.]], indices=[3, 4], dense_shape=[5, 1])
    ])
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [
          IndexedSlices(
              values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
      ]
    if group_size == 2:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.]],
              indices=[0, 1, 1, 2],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.]],
              indices=[1, 2, 0, 1],
              dense_shape=[5, 1])
      ]
    elif group_size == 4:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.], [9.], [10.], [13.], [14.]],
              indices=[0, 1, 1, 2, 3, 4, 8, 9],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.], [11.], [12.], [15.], [16.]],
              indices=[1, 2, 0, 1, 3, 4, 3, 4],
              dense_shape=[5, 1])
      ]
    self.batch_reduce_and_verify(inputs, expect, options)

  def testBatchReduceMixedDenseAndSparse(self):

    options = self.RunOptions(
        num_processes=2,
        gpus_per_process=0,
        reduce_op=ReduceOp.SUM,
        mode=["func_graph"])

    inputs_data = [
        [
            1.0, 2.0,
            IndexedSlicesValue(
                values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
            IndexedSlicesValue(
                values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
        ],
        [
            3.0, 4.0,
            IndexedSlicesValue(
                values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
            IndexedSlicesValue(
                values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
        ],
    ]

    expect = [
        4.0, 6.0,
        IndexedSlices(
            values=[[1.], [2.], [5.], [6.]],
            indices=[0, 1, 1, 2],
            dense_shape=[10, 1]),
        IndexedSlices(
            values=[[3.], [4.], [7.], [8.]],
            indices=[1, 2, 0, 1],
            dense_shape=[5, 1])
    ]

    self.batch_reduce_and_verify(inputs_data, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
      ))
  def testAllReduceDense(self, num_processes, required_gpus, implementation,
                         reduce_op):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = constant_op.constant(1.0)
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [1.0 * group_size] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [1.0] * len(devices)
      self.assertAllClose(got, expect)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (constant_op.constant(1.0), constant_op.constant(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [(1.0, 2.0)] * len(devices)
      self.assertAllClose(got, expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
      ))
  def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                          reduce_op):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = IndexedSlices(
                values=array_ops.identity([[1.]]),
                indices=array_ops.identity([0]),
                dense_shape=array_ops.identity([5, 1]))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [IndexedSlices([[1. * group_size]], [0], [5, 1])
                 ] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [IndexedSlices([[1.]], [0], [5, 1])] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (IndexedSlices(
                array_ops.identity([[1.]]), array_ops.identity([0]),
                array_ops.identity([5, 1])),
                     IndexedSlices(
                         array_ops.identity([[3.]]), array_ops.identity([2]),
                         array_ops.identity([5, 1])))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [(IndexedSlices([[1. * group_size]], [0], [5, 1]),
                   IndexedSlices([[3. * group_size]], [2], [5, 1]))
                 ] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [(IndexedSlices([[1.]], [0], [5, 1]),
                   IndexedSlices([[3.]], [2], [5, 1]))] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=0,
          implementation=CommunicationImplementation.AUTO,
          reduce_op=ReduceOp.SUM))
  def testAllReduceMixedDenseAndSparse(self, num_processes, required_gpus,
                                       implementation, reduce_op):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (IndexedSlices(
                array_ops.identity([[1.]]), array_ops.identity([0]),
                array_ops.identity([5, 1])), array_ops.identity(1.0),
                     IndexedSlices(
                         array_ops.identity([[3.]]), array_ops.identity([2]),
                         array_ops.identity([5, 1])), array_ops.identity(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      expect = [
          (IndexedSlices([[1. * group_size]], [0], [5, 1]), 1.0 * group_size,
           IndexedSlices([[3. * group_size]], [2], [5, 1]), 2.0 * group_size)
      ] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          axis=[0, 1, 2],
          func_mode=["eager", "func_graph"],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
                             func_mode, axis, prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)

      def gather_fn():
        per_replica_value = make_per_replica_value(value, devices)
        gathered_values = collective._gather(
            per_replica_value, per_replica_value, axis=axis, options=options)
        gathered_values = self.as_list(gathered_values)
        # Skip checking devices in eager. In eager the device attribute doesn't
        # reflect the actual device of the tensor.
        if not context.executing_eagerly():
          self.assertAllEqual(devices, [v.device for v in gathered_values])
        return [ops.convert_to_tensor(v) for v in gathered_values]

      group_size = num_processes * (required_gpus or 1)
      expect = array_ops.concat([value] * group_size, axis=axis)
      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if func_mode == "eager":
        result = gather_fn()
        self.assertAllClose(result, per_replica_expect)

      if func_mode == "func_graph":
        result = def_function.function(gather_fn)()
        self.assertAllClose(result, per_replica_expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[CommunicationImplementation.RING]))
  def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
                                  implementation):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = make_per_replica_value(constant_op.constant([1.]), devices)

      @def_function.function
      def reduce_fn():

        def cond_body():
          reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value,
                                      options)
          return math_ops.add_n(self.as_list(reduced)) / len(devices)

        return control_flow_ops.cond(
            array_ops.identity(False), cond_body, cond_body)

      num_replicas = num_processes * len(devices)
      self.assertAllEqual(reduce_fn(), [1. * num_replicas])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
                                                    required_gpus,
                                                    implementation,
                                                    prefer_unique_instance_key):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      # We would like to simulate the following sequence:
      #   thread-0  device0                 device1
      #   thread-1          device0 device1
      # If the kernels are launched in this exact order, the program will
      # deadlock since NCCL requires the launch order to be the same on each
      # device.
      v0 = make_per_replica_value(1.0, devices)
      v1 = make_per_replica_value(2.0, devices)

      # Add a delay to collective_ops.all_reduce according to the input
      # tensor's index in `sequence`.
      sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
      all_reduce = collective_ops.all_reduce

      def delayed_all_reduce(input_tensor, *args, **kwargs):
        for idx, v in enumerate(sequence):
          if input_tensor is v:
            time.sleep(idx)
            break
        return all_reduce(input_tensor, *args, **kwargs)

      with test.mock.patch.object(collective_ops, "all_reduce",
                                  delayed_all_reduce):
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.

        def thread_fn():
          reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                            [(v0, v0), (v0, v0)], options)
          self.assertAllEqual(reduced[0].values, [2.0, 2.0])
          self.assertAllEqual(reduced[1].values, [2.0, 2.0])

        t = threading.Thread(target=thread_fn)
        t.start()
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
                                                                     (v1, v1)],
                                          options)
        self.assertAllEqual(reduced[0].values, [4.0, 4.0])
        self.assertAllEqual(reduced[1].values, [4.0, 4.0])
        t.join()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testInputsAreFunctionArgs(self, num_processes, required_gpus,
                                implementation, prefer_unique_instance_key):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      @def_function.function
      def reduce_fn(v):
        # Function inputs don't have device placement.
        self.assertEqual(v.values[0].device, "")
        self.assertEqual(v.values[1].device, "")
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v),
                                                                     (v, v)],
                                          options)
        self.assertEqual(reduced[0].values[0].device, devices[0])
        self.assertEqual(reduced[0].values[1].device, devices[1])
        self.assertEqual(reduced[1].values[0].device, devices[0])
        self.assertEqual(reduced[1].values[1].device, devices[1])
        # Returning Mirrored only evaluates the primary value, which causes
        # hanging, so return the component values instead.
        return [reduced[0].values, reduced[1].values]

      v = make_per_replica_value(1.0, devices)
      reduced = reduce_fn(v)
      self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
                             prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def reduce_dense():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                  required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def batch_reduce_dense():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceSparse(self, num_processes, implementation,
                              required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def reduce_sparse():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                   implementation, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def batch_reduce_sparse():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(combinations.combine(num_processes=1, required_gpus=2))
  def testNcclOrdering(self, num_processes, required_gpus):

    if num_processes != required_gpus:
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      CollectiveReplicaLauncher._prefer_ordering_token = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(
          implementation=CommunicationImplementation.NCCL)

      v_dense = make_per_replica_value([1.0, 1.0], devices)
      v_sparse = make_per_replica_value([
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
      ], devices)

      @def_function.function
      def nested_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)

      @def_function.function
      def nested_sparse():
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)

      # All collectives, function calls, if clauses and while loops should be
      # chained by control dependencies, so that the execution order is
      # deterministic.
      @def_function.function
      def f():
        # pylint: disable=pointless-statement
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reducing dense value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        # reducing sparse value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reduce dense value in nested tf.function.
        nested_dense()
        # reduce sparse value in nested tf.function.
        nested_sparse()
        # reduce dense value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        else:
          v_dense
        # reduce sparse value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          v_sparse
        else:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
        # reduce dense value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
          i += 1
        # reduce sparse value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
          i += 1
        # reducing dense and sparse values again.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # pylint: enable=pointless-statement

      graph = f.get_concrete_function().graph
      should_be_ordered = set([
          "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
          "StatefulPartitionedCall"
      ])
      nodes_by_device = {}
      for op in graph.get_operations():
        if op.type in should_be_ordered:
          if op.device not in nodes_by_device:
            nodes_by_device[op.device] = []
          nodes_by_device[op.device].append(op)
      order = test_util.topological_sort_operations(graph.get_operations())
      for device in devices:
        device = device_util.canonicalize(device)
        # Those function ops don't have device annotations, but they contain
        # collectives for both devices so we always include them.
        operations = nodes_by_device[device] + nodes_by_device[""]
        # Verify that we get all types of nodes we want.
        self.assertEqual(set(op.type for op in operations), should_be_ordered)
        test_util.assert_sequential_execution(order, operations)

    get_global_mpr(num_processes).run(replica_fn)


if __name__ == "__main__":
  # Set the default inter-op thread pool size to one to ensure we don't exhaust
  # the thread pool with the additional executors that run collectives in eager.
  os.environ["TF_NUM_INTEROP_THREADS"] = "1"
  # TODO(b/172304955): figure out why logical devices don't work.
  test_util.main(config_logical_devices=False)