1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different algorithms of reduction and broadcasting."""
16
17import collections
18import copy
19import multiprocessing.dummy
20import multiprocessing.pool
21import threading
22
23import six
24
25from tensorflow.python.client import device_lib
26from tensorflow.python.distribute import collective_util
27from tensorflow.python.distribute import cross_device_utils
28from tensorflow.python.distribute import device_util
29from tensorflow.python.distribute import distribute_utils
30from tensorflow.python.distribute import ps_values
31from tensorflow.python.distribute import reduce_util
32from tensorflow.python.distribute import tpu_values
33from tensorflow.python.distribute import values as value_lib
34from tensorflow.python.distribute import values_util
35from tensorflow.python.eager import context
36from tensorflow.python.eager import def_function
37from tensorflow.python.framework import indexed_slices
38from tensorflow.python.framework import kernels
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.ops import array_ops
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import resource_variable_ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import tf_export
47from tensorflow.tools.docs import doc_controls
48
49
50def check_destinations(destinations):
51  """Checks whether `destinations` is not empty.
52
53  Args:
54    destinations: a `DistributedValues`, variable, or string object.
55
56  Returns:
57    Boolean which is True if `destinations` is not empty.
58  """
59  # Calling bool() on a ResourceVariable is not allowed.
60  if isinstance(destinations,
61                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
62    return bool(destinations.device)
63  return bool(destinations)
64
65
66def validate_destinations(destinations):
67  """Validates the `destination` is one of expected types."""
68  if not isinstance(
69      destinations,
70      (value_lib.DistributedValues, ops.Tensor, indexed_slices.IndexedSlices,
71       ps_values.AggregatingVariable, six.string_types,
72       tpu_values.TPUMirroredVariable
73      )) and not resource_variable_ops.is_resource_variable(destinations):
74    raise ValueError("destinations must be one of a `DistributedValues` object,"
75                     " a tf.Variable object, or a device string.")
76
77  if not check_destinations(destinations):
78    raise ValueError("destinations can not be empty")
79
80
81def reduce_non_distributed_value(reduce_op,
82                                 value,
83                                 destinations,
84                                 num_replicas_in_graph,
85                                 canonicalize_devices=True):
86  """Reduce a non-DistributedValue `value` to `destinations`."""
87  if isinstance(value, value_lib.DistributedValues):
88    raise ValueError("You are passing a `DistributedValues` to "
89                     "`reduce_non_distributed_value`, which is not allowed.")
90
91  # If the same value is present on all replicas then the PerReplica value will
92  # be a single value. We also handle the case when `value` is a single value
93  # and equal to 0.
94  # TODO(b/138823479): handle the tensor value properly.
95  if not tensor_util.is_tf_type(value) and value == 0:
96    return 0
97  # If there is only a single value and the reduce op is MEAN,
98  # that value should be on all destinations.
99  if reduce_op == reduce_util.ReduceOp.MEAN:
100    return value
101  elif num_replicas_in_graph != 1:
102    # We do not support a reduce op of SUM when the same value is present on
103    # all replicas. This path is reached from the assign functions of
104    # MirroredVariables, and summing up identical values across replicas is not
105    # clearly defined.
106    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
107                     "the given reduce op %s." % (value, reduce_op))
108  else:
109    validate_destinations(destinations)
110    return simple_broadcast(
111        value, destinations, canonicalize_devices=canonicalize_devices)
112
113
114def _make_tensor_into_per_replica(input_tensor):
115  """Converts a single tensor into a PerReplica object."""
116  if isinstance(input_tensor, value_lib.DistributedValues):
117    return input_tensor
118
119  # If input is not a Tensor, convert it to a Tensor first.
120  if not tensor_util.is_tensor(input_tensor):
121    input_tensor = ops.convert_to_tensor(input_tensor)
122
123  if hasattr(input_tensor, "device"):
124    return value_lib.PerReplica((input_tensor,))
125
126  raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
127                   "because it doesn't have device set.")
128
129
130def _normalize_value_destination_pairs(value_destination_pairs):
131  """Converts each tensor into a PerReplica object in the input list."""
132  result = []
133
134  value_destination_pairs = list(value_destination_pairs)
135
136  if not isinstance(value_destination_pairs, (list, tuple)):
137    raise ValueError("`value_destination_pairs` should be a list or tuple")
138  for pair in value_destination_pairs:
139    if not isinstance(pair, tuple):
140      raise ValueError(
141          "Each element of `value_destination_pairs` should be a tuple.")
142    if len(pair) != 2:
143      raise ValueError("Each element of `value_destination_pairs` should be a "
144                       "tuple of size 2.")
145
146    per_replica = _make_tensor_into_per_replica(pair[0])
147    result.append((per_replica, pair[1]))
148  return result
149
150
151def _validate_value_destination_pairs(value_destination_pairs):
152  """Validates value_destination_pairs are valid."""
153  # TODO(yuefengz): raise exceptions instead of returning False.
154  if not value_destination_pairs: return False
155  if not isinstance(value_destination_pairs, (list, tuple)): return False
156  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
157    return False
158  if not all(isinstance(v[0], value_lib.PerReplica)
159             for v in value_destination_pairs):
160    return False
161  return True
162
163
164# TODO(yuefengz): consider calling this function in the caller of
165# CrossDeviceOps.
166def get_devices_from(destinations, canonicalize_devices=True):
167  if isinstance(destinations, value_lib.DistributedValues):
168    return destinations._devices  # pylint: disable=protected-access
169  if canonicalize_devices:
170    if isinstance(destinations, six.string_types):
171      return (device_util.resolve(destinations),)
172    return (device_util.resolve(destinations.device),)
173
174  # Let placer canonicalize and resolve destination devices.
175  if isinstance(destinations, six.string_types):
176    return (device_util.canonicalize_without_job_and_task(destinations),)
177  return (device_util.canonicalize_without_job_and_task(destinations.device),)
178
179
180def _devices_match(left, right, canonicalize_devices=True):
181  return left is right or set(get_devices_from(
182      left, canonicalize_devices)) == set(
183          get_devices_from(right, canonicalize_devices))
184
185
186def _all_devices_match(value_destination_pairs, canonicalize_devices=True):
187  if not all(
188      _devices_match(v, d, canonicalize_devices)
189      for v, d in value_destination_pairs):
190    return False
191  if not all(
192      _devices_match(v, value_destination_pairs[0][0], canonicalize_devices)
193      for v, _ in value_destination_pairs[1:]):
194    return False
195  return True
196
197
198def simple_broadcast(value,
199                     destinations,
200                     always_mirrored=False,
201                     canonicalize_devices=True):
202  """Broadcast `value` to `destinations` using simple copies."""
203  devices = get_devices_from(destinations, canonicalize_devices)
204  if len(devices) == 1 and not always_mirrored:
205    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
206        value, devices[0])
207  else:
208    value_updates = []
209    for d in devices:
210      value_updates.append(
211          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
212    return distribute_utils.regroup(value_updates,
213                                    wrap_class=value_lib.Mirrored)
214
215
216def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
217                   reduce_op):
218  """Reduces the value by accumulation_fn and reduce_op."""
219  all_values = per_replica_value.values
220  if not all_values:
221    raise ValueError("`per_replica_value` must be non-empty")
222  count = len(all_values)
223
224  with ops.device(reduce_to_device):
225    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
226      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
227          all_values, accumulation_fn)
228      if reduce_op == reduce_util.ReduceOp.MEAN:
229        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
230            reduced, count)
231      elif reduce_op != reduce_util.ReduceOp.SUM:
232        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
233  return reduced
234
235
236def _simple_gather(per_replica_value, reduce_to_device, axis):
237  """Concatenate all values in the DistributedValues input and return."""
238  all_values = per_replica_value.values
239  if not all_values:
240    raise ValueError("`per_replica_value` must be non-empty")
241
242  with ops.device(reduce_to_device):
243    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
244      gathered = array_ops.concat(all_values, axis)
245  return gathered
246
247
248@tf_export("distribute.CrossDeviceOps")
249class CrossDeviceOps(object):
250  """Base class for cross-device reduction and broadcasting algorithms.
251
252  The main purpose of this class is to be passed to
253  `tf.distribute.MirroredStrategy` in order to choose among different cross
254  device communication implementations. Prefer using the methods of
255  `tf.distribute.Strategy` instead of the ones of this class.
256
257  Implementations:
258  * `tf.distribute.ReductionToOneDevice`
259  * `tf.distribute.NcclAllReduce`
260  * `tf.distribute.HierarchicalCopyAllReduce`
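
  For example, a minimal sketch of picking one of these implementations when
  creating a `tf.distribute.MirroredStrategy` (the device names below are
  placeholders for your own setup):

  ```
    strategy = tf.distribute.MirroredStrategy(
        devices=["/gpu:0", "/gpu:1"],
        cross_device_ops=tf.distribute.NcclAllReduce())
  ```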
261  """
262
263  def __init__(self):
264    self._canonicalize_devices = True
266
267  @property
268  def _num_between_graph_workers(self):
269    # Returns 1 by default; the value may be overridden by subclasses.
270    return 1
271
272  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
273    """Reduce `per_replica_value` to `destinations`.
274
275    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
276    the cross-replica context.
277
278    Args:
279      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
280        combined.
281      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
282        like object.
283      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
284        `tf.Tensor` alike object, or a device string. It specifies the devices
285        to reduce to. To perform an all-reduce, pass the same to `value` and
286        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
287        to the devices of that variable, and this method doesn't update the
288        variable.
289      options: a `tf.distribute.experimental.CommunicationOptions`. See
290        `tf.distribute.experimental.CommunicationOptions` for details.
291
292    Returns:
293      A `tf.Tensor` or `tf.distribute.DistributedValues`.
294
295    Raises:
296      ValueError: if per_replica_value can't be converted to a
297        `tf.distribute.DistributedValues` or if destinations is not a string,
298        `tf.Variable` or `tf.distribute.DistributedValues`.
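
    For example, a sketch of an all-reduce, where `per_replica_value` is a
    `tf.distribute.DistributedValues` produced by a strategy; passing the same
    value as `destinations` makes this an all-reduce:

    ```
      cross_device_ops = tf.distribute.NcclAllReduce()
      reduced = cross_device_ops.reduce(
          tf.distribute.ReduceOp.SUM, per_replica_value,
          destinations=per_replica_value)
    ```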
299    """
300    if options is None:
301      options = collective_util.Options()
302
303    per_replica_value = _make_tensor_into_per_replica(per_replica_value)
304
305    validate_destinations(destinations)
306
307    # Shortcut if `per_replica_value` only contains one value.
308    if self._num_between_graph_workers == 1 and len(
309        per_replica_value.values) == 1 and _devices_match(
310            per_replica_value, destinations, self._canonicalize_devices):
311      with ops.device(per_replica_value.values[0].device):
312        v = array_ops.identity(per_replica_value.values[0])
313      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
314
317    return self.reduce_implementation(reduce_op, per_replica_value,
318                                      destinations, options)
319
320  def _gather(self, per_replica_value, destinations, axis, options=None):
321    """Gather `per_replica_value` to `destinations`.
322
323    Args:
324      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
325        like object.
326      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
327        `tf.Tensor` alike object, or a device string. It specifies the devices
328        to gather to. To perform an all-gather, pass the same to `value` and
329        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
330        to the devices of that variable, and this method doesn't update the
331        variable.
332      axis: specifies the dimension to gather along within each replica's
333        tensor.
334      options: a `tf.distribute.experimental.CommunicationOptions`. See
335        `tf.distribute.experimental.CommunicationOptions` for details.
336
337    Returns:
338      A `tf.Tensor` or `tf.distribute.DistributedValues`
339
340    Raises:
341      ValueError: if per_replica_value can't be converted to a
342        `tf.distribute.DistributedValues` or if destinations is not a string,
343        `tf.Variable` or `tf.distribute.DistributedValues`.
344    """
345    if isinstance(per_replica_value, indexed_slices.IndexedSlices):
346      raise NotImplementedError("gather/all_gather does not support "
347                                "IndexedSlices")
348    if options is None:
349      options = collective_util.Options()
350
351    per_replica_value = _make_tensor_into_per_replica(per_replica_value)
352
353    validate_destinations(destinations)
354
355    # Shortcut if `per_replica_value` only contains one value.
356    if self._num_between_graph_workers == 1 and len(
357        per_replica_value.values) == 1 and _devices_match(
358            per_replica_value, destinations, self._canonicalize_devices):
359      with ops.device(per_replica_value.values[0].device):
360        v = array_ops.identity(per_replica_value.values[0])
361      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)
362
363    return self._gather_implementation(per_replica_value, destinations, axis,
364                                       options)
365
366  def _gather_implementation(self, per_replica_value, destinations, axis,
367                             options):
368    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.
369
370    Overriding this method is useful for subclass implementers.
371
372    Args:
373      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
374        like object.
375      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
376        `tf.Tensor` alike object, or a device string. It specifies the devices
377        to gather to. To perform an all-gather, pass the same to `value` and
378        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
379        to the devices of that variable; this method doesn't update the
380        variable.
381      axis: specifies the dimension to gather along within each replica's
382        tensor.
383      options: a `tf.distribute.experimental.CommunicationOptions`. See
384        `tf.distribute.experimental.CommunicationOptions` for details.
385
386    Returns:
387      A `tf.Tensor` or `tf.distribute.DistributedValues`.
388
389    Raises:
390      ValueError: if per_replica_value can't be converted to a
391        `tf.distribute.DistributedValues` or if destinations is not a string,
392        `tf.Variable` or `tf.distribute.DistributedValues`.
393    """
394    raise NotImplementedError(
395        "_gather method must be implemented in descendants.")
396
397  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
398    """Reduce values to destinations in batches.
399
400    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
401    called in the cross-replica context.
402
403    Args:
404      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
405        combined.
406      value_destination_pairs: a sequence of (value, destinations) pairs. See
407        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
408      options: a `tf.distribute.experimental.CommunicationOptions`. See
409        `tf.distribute.experimental.CommunicationOptions` for details.
410
411    Returns:
412      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
413      in `value_destination_pairs`.
414
415    Raises:
416      ValueError: if `value_destination_pairs` is not an iterable of
417        tuples of `tf.distribute.DistributedValues` and destinations.
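
    For example, a sketch that all-reduces two per-replica values in one batch,
    where `per_replica_a` and `per_replica_b` are hypothetical
    `tf.distribute.DistributedValues`:

    ```
      cross_device_ops = tf.distribute.NcclAllReduce()
      reduced_list = cross_device_ops.batch_reduce(
          tf.distribute.ReduceOp.MEAN,
          [(per_replica_a, per_replica_a), (per_replica_b, per_replica_b)])
    ```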
418    """
419    if options is None:
420      options = collective_util.Options()
421    # TODO(yuefengz): if destinations are different, split into several
422    # `_batch_reduce` invocations.
423    if not _validate_value_destination_pairs(value_destination_pairs):
424      # If the first element of each pair is a tensor, we try to turn it into a
425      # PerReplica object.
426      value_destination_pairs = _normalize_value_destination_pairs(
427          value_destination_pairs)
428
429    for _, d in value_destination_pairs:
430      validate_destinations(d)
431
432    # Shortcut if all PerReplica objects only contain one value.
433    if self._num_between_graph_workers == 1 and _all_devices_match(
434        value_destination_pairs, self._canonicalize_devices) and len(
435            value_destination_pairs[0][0].values) == 1:
436      return [
437          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
438          for v, _ in value_destination_pairs
439      ]
440
443    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
444                                            options)
445
446  def broadcast(self, tensor, destinations):
447    """Broadcast `tensor` to `destinations`.
448
449    This can only be called in the cross-replica context.
450
451    Args:
452      tensor: a `tf.Tensor` like object. The value to broadcast.
453      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
454        `tf.Tensor` alike object, or a device string. It specifies the devices
455        to broadcast to. Note that if it's a `tf.Variable`, the value is
456        broadcast to the devices of that variable; this method doesn't update
457        the variable.
458
459    Returns:
460      A `tf.Tensor` or `tf.distribute.DistributedValues`.
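
    For example, a sketch that mirrors a constant onto the devices of a
    hypothetical `tf.distribute.DistributedValues` named `per_replica_value`:

    ```
      cross_device_ops = tf.distribute.ReductionToOneDevice()
      mirrored = cross_device_ops.broadcast(
          tf.constant(1.0), destinations=per_replica_value)
    ```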
461    """
462    validate_destinations(destinations)
463    return self.broadcast_implementation(tensor, destinations)
464
465  @doc_controls.for_subclass_implementers
466  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
467                            options):
468    """Implementation of `reduce`.
469
470    Overriding this method is useful for subclass implementers.
471
472    Args:
473      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
474        combined.
475      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
476        like object.
477      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
478        `tf.Tensor` alike object, or a device string. It specifies the devices
479        to reduce to. To perform an all-reduce, pass the same to `value` and
480        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
481        to the devices of that variable; this method doesn't update the
482        variable.
483      options: a `tf.distribute.experimental.CommunicationOptions`. See
484        `tf.distribute.experimental.CommunicationOptions` for details.
485
486    Returns:
487      A `tf.Tensor` or `tf.distribute.DistributedValues`.
488
489    Raises:
490      ValueError: if per_replica_value can't be converted to a
491        `tf.distribute.DistributedValues` or if destinations is not a string,
492        `tf.Variable` or `tf.distribute.DistributedValues`.
493    """
494    raise NotImplementedError(
495        "_reduce method must be implemented in descendants.")
496
497  @doc_controls.for_subclass_implementers
498  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
499                                  options):
500    """Implementation of `batch_reduce`.
501
502    Overriding this method is useful for subclass implementers.
503
504    Args:
505      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
506        combined.
507      value_destination_pairs: a sequence of (value, destinations) pairs. See
508        `reduce` for descriptions.
509      options: a `tf.distribute.experimental.CommunicationOptions`. See
510        `tf.distribute.experimental.CommunicationOptions` for details.
511
512    Returns:
513      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
514      in `value_destination_pairs`.
515
516    Raises:
517      ValueError: if `value_destination_pairs` is not an iterable of
518        tuples of `tf.distribute.DistributedValues` and destinations.
519    """
520    raise NotImplementedError(
521        "batch_reduce_implementation method must be implemented in descendants."
522    )
523
524  @doc_controls.for_subclass_implementers
525  def broadcast_implementation(self, tensor, destinations):
526    """Implementation of `broadcast`.
527
528    Args:
529      tensor: a `tf.Tensor` like object. The value to broadcast.
530      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
531        `tf.Tensor` alike object, or a device string. It specifies the devices
532        to broadcast to. Note that if it's a `tf.Variable`, the value is
533        broadcast to the devices of that variable; this method doesn't update
534        the variable.
536
537    Returns:
538      A `tf.Tensor` or `tf.distribute.DistributedValues`.
539    """
540    return simple_broadcast(
541        tensor,
542        destinations,
543        always_mirrored=True,
544        canonicalize_devices=self._canonicalize_devices)
545
546  # ========================== Collective APIs ================================
547  #
548  # Different than `reduce`, `batch_reduce` and `broadcast` which must be called
549  # in cross-replcia context, collective APIs are to be called in replica
550  # context.
551
552  def _all_reduce(self, reduce_op, value, replica_id, options):
553    """All-reduce the `value` across all replicas so that all get the result.
554
555    `value` can be a nested structure of tensors or `IndexedSlices`. The
556    implementation should generally batch the all-reduces when possible.
557    `options` can be set to hint the batching behavior.
558
559    This API must be called in a replica context.
560
561    Args:
562      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
563        be combined.
564      value: Value to be reduced. A tensor or a nested structure of tensors or
565        `IndexedSlices`.
566      replica_id: An integer indicating the id of the replica on which this
567        all_reduce is called. This is the local replica id that ranges
568        from 0 to len(local_devices) - 1.
569      options: A `tf.distribute.experimental.CommunicationOptions`.
570
571    Returns:
572      A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
573      the reduced values. The structure is the same as `value`.
574    """
575    raise NotImplementedError("_all_reduce must be implemented in descendants.")
576
577
578@tf_export("distribute.ReductionToOneDevice")
579class ReductionToOneDevice(CrossDeviceOps):
580  """A CrossDeviceOps implementation that copies values to one device to reduce.
581
582  This implementation always copies values to one device to reduce them, then
583  broadcasts the reduced values to the destinations. It doesn't support efficient
584  batching.
585
586  Here is how you can use `ReductionToOneDevice` in
587  `tf.distribute.MirroredStrategy`:
588
589  ```
590    strategy = tf.distribute.MirroredStrategy(
591      cross_device_ops=tf.distribute.ReductionToOneDevice())
592  ```
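
  To pick the intermediate device yourself, pass `reduce_to_device` (a sketch;
  the device string depends on your setup):

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice(
          reduce_to_device="/cpu:0"))
  ```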
593  """
594
595  def __init__(self, reduce_to_device=None, accumulation_fn=None):
596    """Initializes with a device to reduce to and a way to accumulate.
597
598    Args:
599      reduce_to_device: the intermediate device to reduce to. If None, reduce
600        to the first device in `destinations` of the `reduce` method.
601      accumulation_fn: a function that does accumulation.  If None,
602        `tf.math.add_n` is used.
603    """
604    self.reduce_to_device = reduce_to_device
605    self.accumulation_fn = accumulation_fn or math_ops.add_n
606    super(ReductionToOneDevice, self).__init__()
607
608  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
609                            options):
610    del options  # Unused.
611    if check_destinations(destinations):
612      devices = get_devices_from(destinations, self._canonicalize_devices)
613    else:
614      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
615    reduce_to_device = self.reduce_to_device or devices[0]
616    logging.log_first_n(
617        logging.INFO,
618        "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
619    reduced = _simple_reduce(per_replica_value, reduce_to_device,
620                             self.accumulation_fn, reduce_op)
621    return self.broadcast(reduced, destinations)
622
623  def _gather_implementation(self, per_replica_value, destinations, axis,
624                             options):
625    del options  # Unused.
626    if check_destinations(destinations):
627      devices = get_devices_from(destinations, self._canonicalize_devices)
628    else:
629      devices = get_devices_from(per_replica_value, self._canonicalize_devices)
630    reduce_to_device = self.reduce_to_device or devices[0]
631    logging.log_first_n(
632        logging.INFO,
633        "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10)
634    gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
635    return self.broadcast(gathered, destinations)
636
637  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
638                                  options):
639    return [
640        self.reduce_implementation(
641            reduce_op, t, destinations=v, options=options)
642        for t, v in value_destination_pairs
643    ]
644
645
646def _group_value_by_device(per_replica_values):
647  """Group values into sublists by their devices.
648
649  This grouping is needed to call the all-reduce library because it expects a
650  list of the following form:
651    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
652     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
653     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
654     ...
655    ]
656
657  Args:
658    per_replica_values: a list of PerReplica objects.
659
660  Returns:
661    a list of lists, one sublist per device; each sublist contains the components
662      of the PerReplica objects on that device, each paired with a None.
663  """
664  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
665  grouped = [[] for _ in range(len(destinations))]
666  for per_replica_value in per_replica_values:
667    # pylint: disable=protected-access
668    for i, v in enumerate(per_replica_value.values):
669      assert per_replica_value._devices == destinations
670      grouped[i].append((v, None))
671  return grouped
672
673
674def _ungroup_and_make_mirrored(grouped_reduced,
675                               destinations,
676                               reduce_op,
677                               num_between_graph_workers=1):
678  """Ungroup results from all-reduce and make Mirrored objects.
679
680  Each all-reduce result will be divided by the number of destinations before
681  Mirrored objects are created if reduce_op is "mean".
682
683  Args:
684    grouped_reduced: a list of lists, each sublist has components for each
685      device, paired with a None. It is the result from
686      cross_device_utils.aggregate_gradients_using*.
687    destinations: a value to colocate the result with.
688    reduce_op: Indicates how values will be aggregated. Accepted values
689      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
690    num_between_graph_workers: number of workers in the between-graph
691      replication.
692
693  Returns:
694    a list of Mirrored objects.
695  """
696  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
697  index = [[] for _ in range(len(grouped_reduced[0]))]
698  for per_replica_reduced in grouped_reduced:
699    for i, (v, _) in enumerate(per_replica_reduced):
700      if reduce_op == reduce_util.ReduceOp.MEAN:
701        with ops.device(v.device):
702          index[i].append(v / num_replicas)
703      else:
704        index[i].append(v)
705  return [distribute_utils.regroup(
706      v, wrap_class=value_lib.Mirrored) for v in index]
707
708
709class _ConcatAndSplitPacker(object):
710  """Concatenate and split tensors for reduction."""
711
712  def __init__(self, num_packs=1):
713    """Initialize the _ConcatAndSplitPacker object.
714
715    Args:
716      num_packs: specifies the number of split packs that will be
717        formed.
718
719    Raises:
720      ValueError: if num_packs is not greater than 0.
721    """
722    if num_packs <= 0:
723      raise ValueError("num_packs must be greater than zero.")
724    self.num_packs = num_packs
725
726  def pack(self, grouped_grads_and_vars):
727    """Pack tensors."""
728    self.grouped_grads_and_vars = grouped_grads_and_vars
729    self.all_device_shapes = []
730    self.all_device_sizes = []
731
732    device_grad_packs = []
733    for device_grads_and_vars in grouped_grads_and_vars:
734      with ops.colocate_with(device_grads_and_vars[0][0]):
735        # Flatten all the grads.
736        flat_grads = [
737            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
738        ]
739        # Remember the original shape of all the grads.
740        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
741        # Remember the original sizes of all the grads.
742        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
743        # Concat all the flat grads into a big flat tensor.
744        concat_grads = array_ops.concat(flat_grads, 0)
745
746        # Split the big tensor into num_splits packs. In cases where the
747        # total size is not divisible by num_splits, the last pack gets
748        # more elements.
749        # TODO(zhengxq): it is also possible to optimize away all the concat
750        # as well.
751        num_splits = self.num_packs
752
753        # The array_ops.size function will sometimes remove static shapes. So if
754        # all gradient shapes are defined, we use another method to get the
755        # total size.
756        # TODO(yuefengz): move this logic to array_ops.size.
757        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
758          total_grad_size = sum(
759              [g.shape.num_elements() for g, _ in device_grads_and_vars])
760        else:
761          total_grad_size = array_ops.size(concat_grads)
762
763        split_size = total_grad_size // num_splits
764        split_size_last = total_grad_size - split_size * (num_splits - 1)
765        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
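        # For example (hypothetical sizes): with total_grad_size = 10 and
        # num_splits = 3, split_size = 3, split_size_last = 4 and
        # split_sizes = [3, 3, 4], so the last pack gets the remainder.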
766        grad_packs = array_ops.split(concat_grads, split_sizes)
767
768        # Ready to aggregate the repacked gradients, with fake variables.
769        # TODO(zhengxq): It is hacky to have to use fake variables.
770        # We should remove the need for variables in
771        # aggregate_gradients_using*.
772        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
773        self.all_device_shapes.append(device_shapes)
774        self.all_device_sizes.append(device_sizes)
775
776    return device_grad_packs
777
778  def unpack(self, summed_device_grad_packs):
779    """Reverse the pack."""
780    aggregated_device_grads = []
781    for (summed_device_grad_packs,
782         device_grads_and_vars, device_shapes, device_sizes) in zip(
783             summed_device_grad_packs, self.grouped_grads_and_vars,
784             self.all_device_shapes, self.all_device_sizes):
785      # pylint: enable=line-too-long
786      # Reverse the packing operations in the previous steps. Form the
787      # summed gradients back into their original shapes.
788      with ops.colocate_with(summed_device_grad_packs[0][0]):
789        # Form a list of the summed grad packs.
790        device_grad_packs = [g for g, _ in summed_device_grad_packs]
791
792        # Concat them back into a big flat tensor.
793        device_grads_concat = array_ops.concat(device_grad_packs, 0)
794
795        # Split the tensors back into their original sizes.
796        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)
797
798        # Reshape the tensors back into their original shapes.
799        grads_with_shapes = [
800            array_ops.reshape(grad, shape)
801            for shape, grad in zip(device_shapes, grads_with_sizes)
802        ]
803
804        # Form the list with the original list of variables.
805        summed_device_grads = [
806            (g, v) for g, (_, v) in zip(grads_with_shapes,
807                                        device_grads_and_vars)
808        ]
809        aggregated_device_grads.append(summed_device_grads)
810    return aggregated_device_grads
811
812
813def _pack_tensors(device_grads, num_packs=0):
814  """Pack tensors if specified."""
815  if num_packs > 0:
816    tensor_packer = _ConcatAndSplitPacker(num_packs)
817    device_grad_packs = tensor_packer.pack(device_grads)
818  else:
819    tensor_packer = None
820    device_grad_packs = device_grads
821  return device_grad_packs, tensor_packer
822
823
824def _unpack_tensors(reduced, tensor_packer=None):
825  """Unpack tensors if they are packed before all-reduce."""
826  if tensor_packer:
827    return tensor_packer.unpack(reduced)
828  return reduced
829
830
831class AllReduceCrossDeviceOps(CrossDeviceOps):
832  """All-reduce implementation of CrossDeviceOps.
833
834  It performs all-reduce when applicable using NCCL or hierarchical copy. For
835  the batch API, tensors will be repacked or aggregated for more efficient
836  cross-device transportation.
837
838  For reduces that are not all-reduce, it falls back to
839  `tf.distribute.ReductionToOneDevice`.
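
  A rough sketch of constructing this class directly (normally you would use
  one of the exported subclasses, `tf.distribute.NcclAllReduce` or
  `tf.distribute.HierarchicalCopyAllReduce`, instead):

  ```
    cross_device_ops = AllReduceCrossDeviceOps(
        all_reduce_alg="hierarchical_copy", num_packs=2)
  ```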
840  """
841
842  def __init__(self, all_reduce_alg="nccl", num_packs=1):
843    """Initializes the object.
844
845    Args:
846      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
847        "hierarchical_copy" are supported.
848      num_packs: a non-negative integer. The number of packs to split values
849        into. If zero, no packing will be done.
850    """
851    self._all_reduce_alg = all_reduce_alg
852    self._num_packs = num_packs
853    self._simple_cross_replica_ops = ReductionToOneDevice()
854    super(AllReduceCrossDeviceOps, self).__init__()
855
856  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
857                            options):
858    del options  # Unused.
859    # To use the all-reduce path (NCCL or hierarchical copy), source and
860    # destination devices should match, and none of the devices should be CPU.
861    if (_devices_match(per_replica_value, destinations) and
862        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
863      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
864    else:
865      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
866                                                   destinations)
867
868  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
869                                  options):
870    if _all_devices_match(value_destination_pairs):
871      return self._batch_all_reduce(reduce_op,
872                                    [v[0] for v in value_destination_pairs])
873    else:
874      return [
875          self.reduce_implementation(reduce_op, value, dest, options)
876          for value, dest in value_destination_pairs
877      ]
878
879  def _batch_all_reduce(self, reduce_op, per_replica_values):
880    """All-reduce algorithm in a batch."""
881    dense_values, dense_indices, sparse_values, sparse_indices = (
882        cross_device_utils.split_by_sparsity(per_replica_values))
883    if dense_values:
884      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
885    else:
886      dense_results = []
887    if sparse_values:
888      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
889                                                        sparse_values)
890    else:
891      sparse_results = []
892    return cross_device_utils.stitch_values(((dense_results, dense_indices),
893                                             (sparse_results, sparse_indices)))
894
895  def _do_batch_all_reduce(self, reduce_op, dense_values):
896    """Run batch all-reduces."""
897    logging.log_first_n(
898        logging.INFO,
899        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
900        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)
901
902    destinations = dense_values[0]._devices  # pylint: disable=protected-access
903    grouped = _group_value_by_device(dense_values)
904
905    # device_grad_packs:
906    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
907    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)
908
909    # The actual aggregation of the repacked gradients. Note that they are
910    # sharded among different aggregation trees. So it is important to strike
911    # the right balance on num_splits.
912    if self._all_reduce_alg == "nccl":
913      # TODO(yuefengz): merge this into the all-reduce library.
914      reduced = cross_device_utils.aggregate_gradients_using_nccl(
915          device_grad_packs)
916    else:
917      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
918      # order.
919      reduced = (
920          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
921              destinations, device_grad_packs))
922
923    reduced = _unpack_tensors(reduced, tensor_packer)
924    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)
925
926  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
927    """Run batch all-reduce for sparse values."""
928    logging.log_first_n(
929        logging.WARN,
930        "Efficient allreduce is not supported for %d IndexedSlices" %
931        len(sparse_values), 10)
932    # Use `sparse_values` as destinations to do all-reduces. It is effectively
933    # an allgather under the hood but not an efficient one.
934    return self._simple_cross_replica_ops.batch_reduce(
935        reduce_op, zip(sparse_values, sparse_values))
936
937  def _gather_implementation(self, per_replica_value, destinations, axis,
938                             options):
939    logging.log_first_n(
940        logging.WARN,
941        "gather/all_gather with NCCL or HierarchicalCopy is not supported. "
942        "Falling back to gather on one device and then broadcast. We're working"
943        " on a more efficient implementation.", 3)
944    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
945                                          options)
946
947
948# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
949AllReduceCrossTowerOps = AllReduceCrossDeviceOps
950
951
952AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
953                                            "alg shards limit")
954
955
956@tf_export("distribute.NcclAllReduce")
957class NcclAllReduce(AllReduceCrossDeviceOps):
958  """NCCL all-reduce implementation of CrossDeviceOps.
959
960  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
961  repacked or aggregated for more efficient cross-device transportation.
962
963  For reduces that are not all-reduce, it falls back to
964  `tf.distribute.ReductionToOneDevice`.
965
966  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:
967
969  ```
970    strategy = tf.distribute.MirroredStrategy(
971      cross_device_ops=tf.distribute.NcclAllReduce())
972  ```
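
  To control how values are packed for the batch API, pass `num_packs`
  (a sketch):

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce(num_packs=2))
  ```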
973  """
974
975  def __init__(self, num_packs=1):
976    """Initializes the object.
977
978    Args:
979      num_packs: a non-negative integer. The number of packs to split values
980        into. If zero, no packing will be done.
981
982    Raises:
983      ValueError: if `num_packs` is negative.
984    """
985    if num_packs < 0:
986      raise ValueError(
987          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
988              num_packs))
989    super(NcclAllReduce, self).__init__(
990        all_reduce_alg="nccl", num_packs=num_packs)
991
992
993@tf_export("distribute.HierarchicalCopyAllReduce")
994class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
995  """Hierarchical copy all-reduce implementation of CrossDeviceOps.
996
997  It reduces to one GPU along edges in some hierarchy and broadcasts back to
998  each GPU along the same path. For the batch API, tensors will be repacked or
999  aggregated for more efficient cross-device transportation.
1000
1001  This is a reduction created for the Nvidia DGX-1, which assumes GPUs are
1002  connected like those on a DGX-1 machine. With different GPU inter-connections,
1003  it is likely to be slower than `tf.distribute.ReductionToOneDevice`.
1004
1005  For reduces that are not all-reduce, it falls back to
1006  `tf.distribute.ReductionToOneDevice`.
1007
1008  Here is how you can use `HierarchicalCopyAllReduce` in
1009  `tf.distribute.MirroredStrategy`:
1010
1011  ```
1012    strategy = tf.distribute.MirroredStrategy(
1013      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
1014  ```
1015  """
1016
1017  def __init__(self, num_packs=1):
1018    """Initializes the object.
1019
1020    Args:
1021      num_packs: a non-negative integer. The number of packs to split values
1022        into. If zero, no packing will be done.
1023
1024    Raises:
1025      ValueError: if `num_packs` is negative.
1026    """
1027    if num_packs < 0:
1028      raise ValueError(
1029          "HierarchicalCopy requires num_packs >= 0, but {} is specified"
1030          .format(num_packs))
1031    super(HierarchicalCopyAllReduce, self).__init__(
1032        all_reduce_alg="hierarchical_copy",
1033        num_packs=num_packs)
1034
1035
1036# TODO(crccw): remove after migrating all callers.
1037CollectiveCommunication = collective_util.CommunicationImplementation
1038CommunicationImplementation = collective_util.CommunicationImplementation
1039
1040
1041# TODO(yuefengz): support in-graph collective all-reduce.
1042class CollectiveAllReduce(CrossDeviceOps):
1043  """All-reduce cross device ops using collective ops.
1044
1045  In between-graph replicated training, it will still do all-reduces across
1046  all workers and then put the results on the right destinations.
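
  A rough sketch of constructing this class directly (it is normally created by
  a `tf.distribute` strategy rather than by user code; the devices and group
  size below are placeholders):

  ```
    collective_ops = CollectiveAllReduce(
        devices=["/gpu:0", "/gpu:1"],
        group_size=2,
        options=collective_util.Options())
  ```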
1047  """
1048
1049  def __init__(self,
1050               devices,
1051               group_size,
1052               options,
1053               collective_keys=None,
1054               canonicalize_devices=True):
1055    """Initializes the object.
1056
1057    Args:
1058      devices: a list of device strings to run collectives on.
1059      group_size: the global group size. For between-graph replicated training
1060        it's the total number of devices across all workers.
1061      options: a `tf.distribute.experimental.CommunicationOptions`.
1062      collective_keys: an optional CollectiveKey object.
1063      canonicalize_devices: Whether to canonicalize devices for workers or not.
1064    """
1065    if group_size % len(devices) > 0:
1066      raise ValueError("group_size must be divisible by the number of devices.")
1067
1068    self._group_size = group_size
1069    self._options = options
1070    self._collective_keys = (collective_keys or
1071                             cross_device_utils.CollectiveKeys())
1072    # This lock guards all collective launches, i.e. calls to
1073    # cross_device_utils.build_collective_*.
1074    #
1075    # In a multi threaded eager program we need to ensure different groups of
1076    # collectives don't interleave each other, otherwise there could be
1077    # deadlocks. E.g. if two user threads both are launching collectives:
1078    #   user-thread-0  device0                 device1
1079    #   user-thread-1          device0 device1
1080    # In eager mode, we use one thread per device to launch collective ops, so
1081    # the above launch sequences end up with the following queues:
1082    #   device-0  collective-0  collective-1
1083    #   device-1  collective-1  collective-0
1084    # This deadlocks since neither collective is able to finish.
1085    self._lock = threading.Lock()
1086
1087    if canonicalize_devices:
1088      self._devices = tuple(device_util.canonicalize(d) for d in devices)
1089    else:
1090      self._devices = tuple(
1091          device_util.canonicalize_without_job_and_task(d) for d in devices)
1092    group_key = self._collective_keys.get_group_key(self._devices)
1093    self._launchers = []
1094    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
1095    # This is because of the lack of a mechanism to order NCCL operations
1096    # deterministically.
1097    self._limited_nccl = False
1098    for device in self._devices:
1099      launcher = cross_device_utils.CollectiveReplicaLauncher(
1100          group_key, group_size, self._collective_keys, device, options)
1101      self._launchers.append(launcher)
1102      if not launcher.can_order_nccl():
1103        self._limited_nccl = True
1104
1105    super(CollectiveAllReduce, self).__init__()
1106    self._canonicalize_devices = canonicalize_devices
1107
1108  @property
1109  def _num_between_graph_workers(self):
1110    # Currently we only support an equal number of devices on each worker.
1111    return self._group_size / len(self._devices)
1112
1113  def _all_reduce(self, reduce_op, value, replica_id, options):
1114    """Implements CrossDeviceOps.all_reduce."""
1115    # TODO(b/122840926): reuse this method in _batch_all_reduce.
1116    flat_values = nest.flatten(value)
1117
1118    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
1119    # use NCCL when batch_size > 1, hoping that there's only one batched
1120    # all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
1121    # NCCL launches are always ordered.
1122    if (self._limited_nccl and options.implementation
1123        == collective_util.CommunicationImplementation.NCCL and
1124        len(flat_values) == 1):
1125      options = options.merge(
1126          collective_util.Options(
1127              implementation=collective_util.CommunicationImplementation.RING))
1128
1129    launcher = self._launchers[replica_id]
1130    dense_values, dense_indices, sparse_values, sparse_indices = (
1131        cross_device_utils.split_by_sparsity(flat_values))
1132    dense_results = []
1133    sparse_results = []
1134
1135    if dense_values:
1136      # Reverse the lists so that there's a better chance that values follow
1137      # the order in which they are calculated (e.g. when they're gradients), so
1138      # as to overlap calculation with communication. However, this may not be
1139      # optimal for cases like gradients of complicated non-sequential models.
1140      #
1141      # Note that we reverse the list before packing so that the first pack
1142      # won't be too small, since it's more likely for first few packs to have
1143      # long queuing time due to concurrent intense computation.
1144      #
1145      # TODO(b/147393503): explore solutions for optimal ordering.
1146      dense_values.reverse()
1147      packs = cross_device_utils.group_by_size(dense_values,
1148                                               options.bytes_per_pack)
1149
1150      if not context.executing_eagerly() and replica_id == 0:
1151        logging.info(
1152            "Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
1153            "group_size = %d, implementation = %s, num_packs = %d",
1154            len(dense_values), len(self._launchers), self._group_size,
1155            options.implementation, len(packs))
1156
1157      dense_results = launcher.batch_all_reduce(packs, options)
1158      if reduce_op == reduce_util.ReduceOp.MEAN:
1159        for i, v in enumerate(dense_results):
1160          with ops.device(self._devices[replica_id]):
1161            dense_results[i] = v / self._group_size
1162      dense_results.reverse()
1163
1164    if sparse_values:
1165      if not context.executing_eagerly() and replica_id == 0:
1166        logging.info(
1167            "Collective all_reduce IndexedSlices: %d all_reduces, num_devices ="
1168            "%d, group_size = %d, implementation = %s", len(sparse_values),
1169            len(self._launchers), self._group_size, options.implementation)
1170
1171      for indexed_slice in sparse_values:
1172        sparse_results.append(
1173            launcher.all_reduce_indexed_slices(indexed_slice, options))
1174
1175      if reduce_op == reduce_util.ReduceOp.MEAN:
1176        for i, v in enumerate(sparse_results):
1177          with ops.device(self._devices[replica_id]):
1178            sparse_results[i] = indexed_slices.IndexedSlices(
1179                values=sparse_results[i].values / self._group_size,
1180                indices=sparse_results[i].indices,
1181                dense_shape=sparse_results[i].dense_shape)
1182
1183    flat_results = cross_device_utils.stitch_values(
1184        ((dense_results, dense_indices), (sparse_results, sparse_indices)))
1185    return nest.pack_sequence_as(value, flat_results)
1186
1187  def _all_reduce_per_replica_values(self, reduce_op, per_replica_values,
1188                                     options):
1189    """All reduce a list of per_replica_value."""
1190    values_by_device = [[] for _ in self._devices]
1191    num_devices = len(self._devices)
1192    for per_replica in per_replica_values:
1193      for i in range(num_devices):
1194        values_by_device[i].append(per_replica.values[i])
1195
1196    if context.executing_eagerly():
1197
1198      def thread_fn(device_id):
1199        with context.eager_mode():
1200          return self._all_reduce(reduce_op, values_by_device[device_id],
1201                                  device_id, options)
1202
1203      with self._lock:
1204        pool = multiprocessing.pool.ThreadPool(len(self._devices))
1205        outputs_by_device = pool.map(thread_fn, list(range(num_devices)))
1206        pool.close()
1207    else:
1208      outputs_by_device = []
1209      with self._lock:
1210        for i in range(num_devices):
1211          outputs_by_device.append(
1212              self._all_reduce(reduce_op, values_by_device[i], i, options))
1213
1214    result = []
1215    for values in zip(*outputs_by_device):
1216      result.append(
1217          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
1218    return result
1219
1220  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
1221                            options):
1222    values_util.mark_as_unsaveable()
1223    all_reduced = self._all_reduce_per_replica_values(reduce_op,
1224                                                      [per_replica_value],
1225                                                      options)[0]
1226    devices = get_devices_from(destinations, self._canonicalize_devices)
1227
1228    if _devices_match(per_replica_value, destinations,
1229                      self._canonicalize_devices):
1230      return all_reduced
1231
1232    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
1233    # utility to access component for a particular device.
1234    if not isinstance(all_reduced, value_lib.Mirrored):
1235      all_reduced = value_lib.Mirrored([all_reduced])
1236
1237    # If we got this far, the destination devices do not match the all-reduce
1238    # devices, so we must map from one to the other.
1239    index = []
1240    # We must add these control dependencies, otherwise we can get deadlock.
1241    with ops.control_dependencies(all_reduced.values):
1242      for d in devices:
1243        with ops.device(d):
1244          for v in all_reduced.values:
1245            if v.device == d:
1246              index.append(array_ops.identity(v))
1247              break
1248          else:
1249            # TODO(josh11b): Once we add support for model parallelism, get the
1250            # copy from the corresponding replica instead of the primary.
1251            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
1252    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
1253
1254  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
1255                                  options):
1256    values_util.mark_as_unsaveable()
1257    all_devices_match = _all_devices_match(value_destination_pairs,
1258                                           self._canonicalize_devices)
1259    if all_devices_match:
1260      return self._all_reduce_per_replica_values(
1261          reduce_op, [v[0] for v in value_destination_pairs], options)
1262    else:
1263      logging.log_first_n(
1264          logging.WARN, "Efficient batch_reduce is not supported if "
1265          "destinations are different.", 10)
1267
1268      return [
1269          self.reduce_implementation(reduce_op, value, dest, options)
1270          for value, dest in value_destination_pairs
1271      ]
1272
1273  def _gather_implementation(self, per_replica_value, destinations, axis,
1274                             options):
1275    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
1276    values_util.mark_as_unsaveable()
1277    devices = get_devices_from(destinations, self._canonicalize_devices)
1278
1279    if _devices_match(per_replica_value, destinations,
1280                      self._canonicalize_devices):
1281      return all_gathered
1282
1283    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
1284    # utility to access component for a particular device.
1285    if not isinstance(all_gathered, value_lib.Mirrored):
1286      all_gathered = value_lib.Mirrored([all_gathered])
1287
1288    # If we got this far, the destination devices do not match the all-gather
1289    # devices, so we must map from one to the other.
1290    index = []
1291    # We must add these control dependencies, otherwise we can get deadlock.
1292    with ops.control_dependencies(all_gathered.values):
1293      for d in devices:
1294        with ops.device(d):
1295          for v in all_gathered.values:
1296            if v.device == d:
1297              index.append(array_ops.identity(v))
1298              break
1299          else:
1300            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
1301    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)
1302
1303  def _batch_all_gather(self, per_replica_values, axis, options):
1304    """all gather multiple per-replica-values."""
1305    batch_size = len(per_replica_values)
1306    # For now, we use NCCL only when batch_size > 1.
1307    # TODO(b/132575814): switch to NCCL for all collectives when implementation
1308    # is NCCL.
1309    if (self._limited_nccl and options.implementation
1310        == collective_util.CommunicationImplementation.NCCL and
1311        batch_size == 1):
1312      options = options.merge(
1313          collective_util.Options(
1314              implementation=collective_util.CommunicationImplementation.RING))
1315
1316    logging.log_first_n(
1317        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
1318        "num_devices = %d, group_size = %d, implementation = %s, " %
1319        (batch_size, len(
1320            self._devices), self._group_size, options.implementation), 10)
1321
1322    def compute_gathered_values():
1323      gathered_values = []
1324      with self._lock, ops.name_scope("allgather"):
1325        for per_replica in per_replica_values:
1326          outputs = []
1327          for i in range(len(self._devices)):
1328            outputs.append(self._launchers[i].all_gather(
1329                per_replica.values[i], axis, options))
1330          gathered_values.append(outputs)
1331      return gathered_values
1332
1333    if context.executing_eagerly():
1334      gathered_values = def_function.function(compute_gathered_values)()
1335    else:
1336      gathered_values = compute_gathered_values()
1337
1338    mirrored = []
1339    for value in gathered_values:
1340      mirrored.append(
1341          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
1342    return mirrored
1343
1344  def __deepcopy__(self, memo):
1345    # distribute_coordinator deep-copies the strategy object, so
1346    # CollectiveAllReduce needs to support deep copy as well.
1347    collective_keys = copy.deepcopy(self._collective_keys, memo)
1348    return CollectiveAllReduce(self._devices, self._group_size, self._options,
1349                               collective_keys, self._canonicalize_devices)
1350
1351
1352def select_cross_device_ops(devices, session_config=None):
1353  """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.
1354
1355  Args:
1356    devices: a list of devices passed to `tf.distribute.Strategy`.
1357    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
1358      make the decision based on all logical devices.
1359
1360  Returns:
1361    An instance of a `CrossDeviceOps` subclass.
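
  For example, a sketch of letting this helper pick the implementation for two
  local GPUs (device names are placeholders):

  ```
    cross_device_ops = select_cross_device_ops(["/gpu:0", "/gpu:1"])
  ```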
1362  """
1363  requested_devices = set(device_util.canonicalize(d) for d in devices)
1364  if ops.executing_eagerly_outside_functions():
1365    logical_gpus = context.context().list_logical_devices(device_type="GPU")
1366    physical_gpus = context.context().list_physical_devices(device_type="GPU")
1367    if len(logical_gpus) != len(physical_gpus):
1368      logging.warning("NCCL is not supported when using virtual GPUs, falling"
1369                      "back to reduction to one device")
1370      return ReductionToOneDevice()
1371
1372    machine_devices = context.context().list_logical_devices()
1373  else:
1374    machine_devices = device_lib.list_local_devices(
1375        session_config=session_config)
1376  using_devices = set()
1377  for d in machine_devices:
1378    if device_util.canonicalize(d.name) in requested_devices:
1379      using_devices.add(d.name)
1380
1381  if len(using_devices) != len(requested_devices):
1382    logging.warning(
1383        "Some requested devices in `tf.distribute.Strategy` are not visible "
1384        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))
1385
1386  if any("gpu" not in d.lower() for d in requested_devices):
1387    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
1388                    "not using nccl allreduce.")
1389    return ReductionToOneDevice()
1390
1391  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
1392    return NcclAllReduce(num_packs=1)
1393  else:
1394    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
1395    return ReductionToOneDevice()
1396