# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape. Product of
      the elements must equal the length of each tensor.

  Returns:
    list of `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads out the case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0
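
# Illustrative sketch (comments only, not executed; the name `x` is
# hypothetical): splitting a length-10 1D tensor into 3 pieces pads the tail
# so that every piece has equal length.
#
#   parts, pad_len = _padded_split(x, 3)
#   # chunk_size = 1 + 10 // 3 = 4, so the raw piece lengths are [4, 4, 2]
#   # and the last piece is zero-padded by pad_len = 2 up to length 4.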


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device processes one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk.  The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)
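
# Illustrative sketch (comments only): with 2 workers, 2 GPUs per worker
# (gpu_perm=[0, 1]) and a single subchunk, the global ring order is
# [0, 1, 2, 3], so each device's predecessor is the previous ring member:
#
#   pred, rank = _ring_permutations(num_workers=2, num_subchunks=1,
#                                   gpu_perm=[0, 1])
#   # pred == [[3, 0, 1, 2]]   (device 0 receives from device 3, etc.)
#   # rank == [[0, 1, 2, 3]]
#
# With num_subchunks > 1 each subchunk gets a rotated copy of this ring, so a
# device exchanges data with several distinct peers in every tick.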


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have same size.

  Returns:
    a list of `tf.Tensor` identical sum-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
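
# Illustrative sketch (comments only, assumes a graph-mode context; the name
# `grads` and the device placement are hypothetical): reducing one same-shaped
# tensor per GPU across 2 workers with 2 GPUs each, using addition:
#
#   # grads[i] lives on device i, e.g. "/job:worker/task:0/gpu:0", ...
#   summed = build_ring_all_reduce(grads, num_workers=2, num_subchunks=1,
#                                  gpu_perm=[0, 1], red_op=math_ops.add)
#   # summed[i] is the elementwise sum of all grads, placed back on device i.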


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced,
    and the number of zero-padding elements appended by _padded_split.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len
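
# Illustrative sketch (comments only): with 3 devices and num_subchunks=1 the
# gather phase runs num_ticks = 2 ticks.  The chunk a device reduces at each
# tick follows from the index arithmetic above:
#
#   # tick 0: seg_index = (rank + 3 - 2) % 3 = (rank + 1) % 3
#   # tick 1: seg_index = (rank + 3 - 3) % 3 = rank
#
# Each tick combines the predecessor's copy of that chunk with the local copy,
# so after the last tick the device at ring rank r holds the fully reduced
# chunk r, which _build_ring_scatter then circulates back around the ring.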


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over `tf.Tensor`.
    chunks_by_dev: list of lists of `tf.Tensor`.

  Returns:
    new list of lists of `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2005).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
  power of 2.  TODO(tucker): relax this restriction.  Also, the
  number of elements in each tensor must be divisible by 2^h where h
  is the number of hops in each phase.  This will also be relaxed in
  the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
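
# Illustrative sketch (comments only; the name `grads` is hypothetical): with
# 4 GPUs the gather phase runs lg(4) = 2 pairwise-exchange rounds, leaving
# each device with a fully reduced quarter of the tensor, and the scatter
# phase reassembles the full result everywhere:
#
#   summed = build_recursive_hd_all_reduce(grads, math_ops.add)
#   # len(grads) must be a power of 2, and each tensor's length must be
#   # divisible by 4 here (2^h with h = 2 hops).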


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
      by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d + span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose the tensor length is n, and there are d
  devices and g gather shards.  Each device sends an n/g length sub-tensor
  to each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
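
# Illustrative sketch (comments only; the names `grads` and the device strings
# are hypothetical): with the reduction shards hosted on two parameter-server
# style devices, each input tensor is split in half, each half is summed
# across devices on one shard, and the two reduced halves are concatenated
# back on every input device:
#
#   summed = build_shuffle_all_reduce(
#       grads,
#       gather_devices=["/job:ps/task:0/cpu:0", "/job:ps/task:1/cpu:0"],
#       red_op=math_ops.add_n)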


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op, applied to the list of
      per-source shards.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards: list of `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has the length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))
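
# Illustrative sketch (comments only; `devs` and `vals` are hypothetical):
# devices sharing the same (job, replica, task) are grouped together,
# preserving their original order:
#
#   devs = ["/job:worker/replica:0/task:0/device:GPU:0",
#           "/job:worker/replica:0/task:0/device:GPU:1",
#           "/job:worker/replica:0/task:1/device:GPU:0"]
#   per_task_devices, per_task_values = _split_by_task(devs, vals)
#   # per_task_devices == [[devs[0], devs[1]], [devs[2]]]
#   # per_task_values groups vals the same way.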


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.  Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduced values.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors
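
# Illustrative sketch (comments only; the name `grads` is hypothetical): the
# NCCL path supports only addition, with an optional unary op applied to each
# reduced copy, e.g. to average instead of sum across 4 GPUs:
#
#   averaged = build_nccl_all_reduce(grads, math_ops.add,
#                                    un_op=lambda t: t * 0.25)
#   # averaged[i] is placed on the same GPU as grads[i].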


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)
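
# Illustrative sketch (comments only; `grads` and `per_worker_cpu_devices` are
# hypothetical): on a multi-worker cluster the hybrid builders compose an
# intra-worker pass with a cross-worker pass, e.g. NCCL inside each worker and
# a ring between the workers' first GPUs:
#
#   summed = build_nccl_then_ring(grads, subdiv=1, red_op=math_ops.add)
#
# or a per-worker shuffle onto one gather device per worker, followed by a
# cross-worker ring over those gathered values:
#
#   summed = build_shuffle_then_ring(
#       grads, gather_devices=per_worker_cpu_devices, subdiv=1,
#       red_n_op=math_ops.add_n, red_op=math_ops.add)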