# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of 1D `tf.Tensor` objects of identical length.
    shape: list of integers describing the desired shape.  Product of
      the elements must equal the length of each tensor.

  Returns:
    list of `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


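# Illustrative sketch (not part of the original API): _flatten_tensors turns a
# list of same-shaped tensors into 1D views and remembers the common shape, so
# the all-reduce builders can operate on flat vectors; _reshape_tensors
# restores the original shape afterwards.  Assumes graph-mode construction.
def _example_flatten_and_reshape():  # hypothetical helper, illustration only
  with ops.Graph().as_default():
    tensors = [array_ops.ones([2, 3]), array_ops.zeros([2, 3])]
    flat, shape = _flatten_tensors(tensors)   # two tensors of shape [6]
    restored = _reshape_tensors(flat, shape)  # back to two [2, 3] tensors
    return restored

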
def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to a length evenly divisible by pieces
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0


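# Illustrative sketch (not part of the original API): a worked example of the
# padding arithmetic above.  Splitting a length-10 tensor into 4 pieces gives
# chunk_size = 1 + 10 // 4 = 3, raw piece lengths [3, 3, 3, 1] and pad_len = 2,
# so the last piece is zero-padded and every returned piece has shape [3].
# Assumes graph-mode construction.
def _example_padded_split():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    parts, pad_len = _padded_split(array_ops.ones([10]), 4)
    # len(parts) == 4; each part has static shape [3]; pad_len == 2.
    return parts, pad_len

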
def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of 1D `tf.Tensor` objects of identical length.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may be longer
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece absorbs the remainder
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


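# Illustrative sketch (not part of the original API): _ragged_split of a
# length-10 tensor into 3 pieces uses chunk_size = 10 // 3 = 3, so the returned
# pieces have static shapes [3], [3] and [4]; no padding is added.  Assumes
# graph-mode construction.
def _example_ragged_split():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    return _ragged_split(array_ops.ones([10]), 3)

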
def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device processes one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk.  The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
       local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)


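# Illustrative sketch (not part of the original API): with 2 workers, 2
# subchunks and gpu_perm=[0, 1] there are 4 global device indices.  Subchunk 0
# uses the ring [0, 1, 2, 3] while subchunk 1 uses the rotated ring
# [1, 0, 3, 2], so every device has a different predecessor per subchunk and
# the two subchunk rings can use independent links in parallel.
def _example_ring_permutations():  # hypothetical helper, for illustration only
  pred_by_s_d, rank_by_s_d = _ring_permutations(2, 2, [0, 1])
  assert pred_by_s_d == [[3, 0, 1, 2], [1, 2, 3, 0]]
  assert rank_by_s_d == [[0, 1, 2, 3], [1, 0, 3, 2]]
  return pred_by_s_d, rank_by_s_d

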
def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have the same size.

  Returns:
    a list of `tf.Tensor` objects which are identical, fully reduced values
    of input_tensors, one per input device.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


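# Illustrative usage sketch (not part of the original API), assuming graph mode
# and two CPU device names: a single-worker, two-device ring reduction with one
# subchunk per device and addition as the reduction operator.
def _example_ring_all_reduce():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    inputs = []
    for dev in ("/device:CPU:0", "/device:CPU:1"):
      with ops.device(dev):
        inputs.append(array_ops.ones([8]))
    # num_workers=1, num_subchunks=1, gpu_perm=[0, 1] yields a plain 2-ring.
    return build_ring_all_reduce(inputs, 1, 1, [0, 1], math_ops.add)

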
def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over `tf.Tensor`.
    chunks_by_dev: list of lists of `tf.Tensor`.

  Returns:
    new list of lists of `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2005).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
    power of 2.  TODO(tucker): relax this restriction.  Also, the
    number of elements in each tensor must be divisible by 2^h where h
    is the number of hops in each phase.  This will also be relaxed in
    the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


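# Illustrative usage sketch (not part of the original API), assuming graph mode
# and four CPU device names.  Four devices is a power of 2, and a length-8
# tensor halves cleanly across the lg(4) = 2 gather hops (8 -> 4 -> 2).
def _example_recursive_hd_all_reduce():  # hypothetical helper, illustration only
  with ops.Graph().as_default():
    inputs = []
    for i in range(4):
      with ops.device("/device:CPU:%d" % i):
        inputs.append(array_ops.ones([8]))
    return build_recursive_hd_all_reduce(inputs, math_ops.add)

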
def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d+span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends an n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op, applied to a list of tensors.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


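# Illustrative usage sketch (not part of the original API), assuming graph mode
# and CPU device names.  Note that red_op here takes a list of tensors (e.g.
# math_ops.add_n), unlike the binary red_op of the ring and recursive-HD forms.
def _example_shuffle_all_reduce():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    inputs = []
    for i in range(3):
      with ops.device("/device:CPU:%d" % i):
        inputs.append(array_ops.ones([10]))
    gather_devices = ["/device:CPU:0", "/device:CPU:1"]
    return build_shuffle_all_reduce(inputs, gather_devices, math_ops.add_n)

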
def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op, applied to a list of tensors.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards:  list of `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))


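# Illustrative sketch (not part of the original API): grouping device/value
# pairs by (job, replica, task).  The values below are plain Python
# placeholders standing in for the `tf.Tensor` objects used in real calls.
def _example_split_by_task():  # hypothetical helper, for illustration only
  devices = ["/job:worker/replica:0/task:0/device:GPU:0",
             "/job:worker/replica:0/task:0/device:GPU:1",
             "/job:worker/replica:0/task:1/device:GPU:0",
             "/job:worker/replica:0/task:1/device:GPU:1"]
  values = ["v0", "v1", "v2", "v3"]
  per_task_devices, per_task_values = _split_by_task(devices, values)
  assert per_task_devices == [devices[0:2], devices[2:4]]
  assert per_task_values == [["v0", "v1"], ["v2", "v3"]]
  return per_task_devices, per_task_values

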
def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator. Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduced values.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: %s" % red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors


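# Illustrative usage sketch (not part of the original API).  This only builds
# the graph; actually running the resulting ops requires at least two
# NCCL-capable GPUs on one machine, and math_ops.add is the only supported
# reduction operator.
def _example_nccl_all_reduce():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    inputs = []
    for i in range(2):
      with ops.device("/device:GPU:%d" % i):
        inputs.append(array_ops.ones([4]))
    return build_nccl_all_reduce(inputs, math_ops.add)

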
def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


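# Illustrative usage sketch (not part of the original API): graph construction
# for a hybrid with two workers and two GPUs per worker.  NCCL reduces within
# each worker, then a two-member ring (subdiv=1) reduces across the workers'
# first GPUs.  The device names are assumptions for illustration; running the
# ops requires a real multi-worker cluster with NCCL-capable GPUs.
def _example_nccl_then_ring():  # hypothetical helper, for illustration only
  with ops.Graph().as_default():
    inputs = []
    for task in range(2):
      for gpu in range(2):
        with ops.device("/job:worker/replica:0/task:%d/device:GPU:%d"
                        % (task, gpu)):
          inputs.append(array_ops.ones([4]))
    return build_nccl_then_ring(inputs, 1, math_ops.add)

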
def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: an n-ary elementwise reduction operator, applied to a list of
      tensors (see _build_shuffle_gather).
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker.")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)